use safetensors;
use ndarray;
use std::{borrow::Cow, mem::size_of};
use std::error::Error;
/// An owned tensor payload: the encoded element bytes plus the dtype and shape needed to
/// interpret them.
///
/// Built from an `ndarray` array by `TensorViewWithDataBuffer::new`. Because the buffer is
/// owned, values can be handed to `safetensors::serialize` (through the `safetensors::View`
/// impl) without borrowing the source array.
#[derive(Debug, Clone, PartialEq)]
pub struct TensorViewWithDataBuffer {
    /// Encoded element bytes, one element after another in row-major order.
    buf: Vec<u8>,
    /// safetensors dtype tag describing each element in `buf`.
    dtype: safetensors::Dtype,
    /// Tensor dimensions; their product is the number of elements in `buf`.
    shape: Vec<usize>,
}
impl TensorViewWithDataBuffer {
    /// Borrows the buffered bytes as a `safetensors::tensor::TensorView`.
    ///
    /// # Panics
    ///
    /// Panics if `TensorView::new` rejects the buffer. `new` writes exactly one encoded
    /// element per entry of `shape`, so this only happens if that invariant is broken.
    pub fn to_tensor_view(&self) -> safetensors::tensor::TensorView<'_> {
        safetensors::tensor::TensorView::new(self.dtype, self.shape.clone(), self.buf.as_slice())
            .expect("buffer length matches shape and dtype by construction")
    }

    /// Encodes `array` into an owned byte buffer tagged with `A`'s safetensors dtype.
    ///
    /// Elements are written in the array's logical row-major order, whatever its memory
    /// layout is (transposed or strided views are handled too). Each element is encoded
    /// with `CommonSupportedElement::extend_byte_vec`.
    pub fn new<A, S, D>(array: &ndarray::ArrayBase<S, D>) -> TensorViewWithDataBuffer
    where
        A: CommonSupportedElement,
        S: ndarray::Data<Elem = A>,
        D: ndarray::Dimension,
    {
        let shape = array.shape().to_vec();
        let mut buf: Vec<u8> = Vec::with_capacity(size_of::<A>() * array.len());
        // `iter` walks the array in logical (row-major) order. That gives the same sequence
        // the old reshape-to-1D + `to_vec` produced, without two intermediate copies.
        for value in array.iter() {
            value.extend_byte_vec(&mut buf);
        }
        Self {
            dtype: A::safetensors_dtype(),
            shape,
            buf,
        }
    }
}
/// Lets a `TensorViewWithDataBuffer` be passed directly to `safetensors::serialize`.
impl safetensors::View for TensorViewWithDataBuffer {
    fn dtype(&self) -> safetensors::Dtype {
        self.dtype
    }

    fn shape(&self) -> &[usize] {
        &self.shape
    }

    fn data(&self) -> Cow<'_, [u8]> {
        // The bytes are already encoded, so lend them instead of copying.
        Cow::Borrowed(self.buf.as_slice())
    }

    fn data_len(&self) -> usize {
        self.buf.len()
    }
}
/// Floating-point element types that can be written into, and read back out of, a
/// safetensors byte buffer.
pub trait CommonSupportedElement: Clone + ndarray::NdFloat {
    /// The safetensors dtype tag that describes `Self`.
    fn safetensors_dtype() -> safetensors::Dtype;

    /// Appends the byte encoding of `self` to the end of `v`.
    fn extend_byte_vec(&self, v: &mut Vec<u8>);

    /// Decodes one value from the leading bytes of `bytes`.
    /// Implementations may panic when `bytes` is shorter than one encoded element.
    fn from_bytes(bytes: &[u8]) -> Self;
}
/// `f32` is stored as 4 little-endian bytes and tagged `Dtype::F32`.
impl CommonSupportedElement for f32 {
    fn safetensors_dtype() -> safetensors::Dtype {
        safetensors::Dtype::F32
    }

    fn extend_byte_vec(&self, v: &mut Vec<u8>) {
        let encoded = self.to_le_bytes();
        v.extend(encoded);
    }

    /// Reads the first 4 bytes as a little-endian `f32`; panics if fewer than 4 are given.
    fn from_bytes(bytes: &[u8]) -> Self {
        let mut word = [0u8; 4];
        word.copy_from_slice(&bytes[..4]);
        Self::from_le_bytes(word)
    }
}
/// `f64` is stored as 8 little-endian bytes and tagged `Dtype::F64`.
impl CommonSupportedElement for f64 {
    fn safetensors_dtype() -> safetensors::Dtype {
        safetensors::Dtype::F64
    }

    fn extend_byte_vec(&self, v: &mut Vec<u8>) {
        let encoded = self.to_le_bytes();
        v.extend(encoded);
    }

    /// Reads the first 8 bytes as a little-endian `f64`; panics if fewer than 8 are given.
    fn from_bytes(bytes: &[u8]) -> Self {
        let mut word = [0u8; 8];
        word.copy_from_slice(&bytes[..8]);
        Self::from_le_bytes(word)
    }
}
/// Returned when a tensor's stored dtype differs from the dtype of the requested element type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TypeMismatchedError {
    expected_type: safetensors::Dtype,
    actual_type: safetensors::Dtype,
}

impl TypeMismatchedError {
    /// The dtype corresponding to the element type the caller asked for.
    pub fn expected_type(&self) -> safetensors::Dtype {
        self.expected_type
    }

    /// The dtype actually recorded for the tensor.
    pub fn actual_type(&self) -> safetensors::Dtype {
        self.actual_type
    }
}
// Marker impl: relies on the default `source()` (no underlying cause) and on the `Display` and
// `Debug` impls, so the error can be boxed as `dyn Error` or propagated with `?`.
impl Error for TypeMismatchedError {}
/// Human-readable message naming both the expected and the found dtype (in `Debug` form).
impl std::fmt::Display for TypeMismatchedError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { expected_type, actual_type } = self;
        write!(f, "Expected data type to be {expected_type:?}, but found {actual_type:?}")
    }
}
/// Decodes a safetensors tensor into a dynamically-dimensioned `ndarray` array of `A`.
///
/// The raw bytes are split into `size_of::<A>()`-sized chunks. Each chunk is decoded with
/// `CommonSupportedElement::from_bytes`, and the result is laid out row-major in the
/// tensor's shape.
///
/// # Errors
///
/// Returns `TypeMismatchedError` when the tensor's dtype is not `A::safetensors_dtype()`.
///
/// # Panics
///
/// Panics if the byte length does not match the shape. `safetensors` validates this when a
/// `TensorView` is created, so it indicates a broken upstream invariant.
pub fn parse_tensor_view_data<A>(
    view: &safetensors::tensor::TensorView<'_>,
) -> Result<ndarray::ArrayD<A>, TypeMismatchedError>
where
    A: CommonSupportedElement,
{
    let expected_type = A::safetensors_dtype();
    let actual_type = view.dtype();
    if expected_type != actual_type {
        return Err(TypeMismatchedError {
            expected_type,
            actual_type,
        });
    }
    let values: Vec<A> = view
        .data()
        .chunks_exact(size_of::<A>())
        .map(A::from_bytes)
        .collect();
    let array = ndarray::Array::from_vec(values);
    Ok(array
        .into_shape_with_order(view.shape().to_vec())
        .expect("tensor byte length equals product(shape) * element size"))
}
/// Decodes every tensor in `tensors` as element type `A`, pairing each with its name.
///
/// Tensors are processed in the order `SafeTensors::iter` yields them, and processing stops
/// at the first tensor whose dtype does not match `A` (that error is returned).
pub fn parse_tensors<A>(tensors: &safetensors::SafeTensors) -> Result<Vec<(String, ndarray::ArrayBase<ndarray::OwnedRepr<A>, ndarray::Dim<ndarray::IxDynImpl>>)>, TypeMismatchedError>
where
    A: CommonSupportedElement,
{
    tensors
        .iter()
        .map(|(name, view)| {
            parse_tensor_view_data::<A>(&view).map(|decoded| (name.to_owned(), decoded))
        })
        .collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Serializes `tensors` with safetensors, then parses every entry back as `A`, keyed by name.
    fn round_trip<A: CommonSupportedElement>(
        tensors: Vec<(&str, TensorViewWithDataBuffer)>,
    ) -> HashMap<String, ndarray::ArrayD<A>> {
        let serialized = safetensors::serialize(tensors, &None).unwrap();
        let parsed = safetensors::SafeTensors::deserialize(&serialized).unwrap();
        parse_tensors::<A>(&parsed).unwrap().into_iter().collect()
    }

    #[test]
    pub fn test_serialize_and_deserialize_f32() {
        // Values deliberately avoid approximations of well-known constants (such as 3.14),
        // because clippy's deny-by-default `approx_constant` lint rejects them.
        let x = ndarray::array![[1.0f32, -1.0f32], [2.0f32, -2.0f32]];
        let y = ndarray::array![1.5f32, 0.25f32];
        let data = vec![
            ("x", TensorViewWithDataBuffer::new(&x)),
            ("y", TensorViewWithDataBuffer::new(&y)),
        ];
        let arrays = round_trip::<f32>(data);
        // Compare as dynamic-dimension arrays so a wrong shape fails instead of being reshaped away.
        assert_eq!(arrays["x"], x.into_dyn());
        assert_eq!(arrays["y"], y.into_dyn());
    }

    #[test]
    pub fn test_serialize_and_deserialize_f64() {
        let x = ndarray::array![[0.5f64, -4.0], [8.0, 16.0], [-0.125, 1e-300]];
        let arrays = round_trip::<f64>(vec![("x", TensorViewWithDataBuffer::new(&x))]);
        assert_eq!(arrays["x"], x.into_dyn());
    }

    #[test]
    pub fn test_non_contiguous_input_is_written_in_logical_order() {
        let x = ndarray::array![[1.0f32, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let transposed = x.t();
        let arrays = round_trip::<f32>(vec![("t", TensorViewWithDataBuffer::new(&transposed))]);
        assert_eq!(arrays["t"], transposed.to_owned().into_dyn());
    }

    #[test]
    pub fn test_type_mismatch_is_reported() {
        let x = ndarray::array![1.0f32, 2.0f32];
        let serialized =
            safetensors::serialize(vec![("x", TensorViewWithDataBuffer::new(&x))], &None).unwrap();
        let parsed = safetensors::SafeTensors::deserialize(&serialized).unwrap();
        let err = parse_tensors::<f64>(&parsed).unwrap_err();
        assert_eq!(err.expected_type, safetensors::Dtype::F64);
        assert_eq!(err.actual_type, safetensors::Dtype::F32);
    }
}