#[cfg(feature = "safetensors")]
use crate::tensor::Tensor;
#[cfg(feature = "safetensors")]
use memmap2::Mmap;
#[cfg(feature = "safetensors")]
use num_traits::Float;
#[cfg(feature = "safetensors")]
use safetensors::{tensor::Dtype, SafeTensors};
#[cfg(feature = "safetensors")]
use std::collections::HashMap;
#[cfg(feature = "safetensors")]
use std::fs::File;
#[cfg(feature = "safetensors")]
use std::path::Path;
/// Error type returned by the safetensors loading and saving helpers in this module.
#[cfg(feature = "safetensors")]
#[derive(Debug)]
pub enum SafetensorsError {
/// An underlying I/O failure (for example opening, memory-mapping or writing a file).
IoError(std::io::Error),
/// An error reported by the `safetensors` crate (for example while deserializing a
/// file, looking up a tensor by name, or serializing).
SafetensorsError(safetensors::SafeTensorError),
/// A tensor could not be converted; the string carries a human-readable description
/// (for example an unsupported dtype or a `TensorView` construction failure).
ConversionError(String),
}
#[cfg(feature = "safetensors")]
impl From<std::io::Error> for SafetensorsError {
    /// Wraps an I/O error in the `IoError` variant so `?` works on `std::io` results.
    fn from(io_err: std::io::Error) -> Self {
        Self::IoError(io_err)
    }
}
#[cfg(feature = "safetensors")]
impl From<safetensors::SafeTensorError> for SafetensorsError {
    /// Wraps a `safetensors` crate error in the `SafetensorsError` variant so `?`
    /// works on results returned by that crate.
    fn from(st_err: safetensors::SafeTensorError) -> Self {
        Self::SafetensorsError(st_err)
    }
}
#[cfg(feature = "safetensors")]
impl std::fmt::Display for SafetensorsError {
    /// Writes a category prefix followed by the wrapped error or message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::IoError(ref inner) => write!(f, "IO Error: {}", inner),
            Self::SafetensorsError(ref inner) => write!(f, "Safetensors Error: {}", inner),
            Self::ConversionError(ref message) => write!(f, "Conversion Error: {}", message),
        }
    }
}
#[cfg(feature = "safetensors")]
impl std::error::Error for SafetensorsError {
    /// Returns the wrapped lower-level error, if any, so callers and error-reporting
    /// helpers can walk the full cause chain.
    ///
    /// `ConversionError` only carries a message and therefore has no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            SafetensorsError::IoError(e) => Some(e),
            SafetensorsError::SafetensorsError(e) => Some(e),
            SafetensorsError::ConversionError(_) => None,
        }
    }
}
/// Read-only view over a memory-mapped safetensors file.
///
/// The parsed `SafeTensors` value borrows directly from the memory map, so both are
/// kept together in this struct.
#[cfg(feature = "safetensors")]
pub struct SafetensorsLoader {
    // Field order matters: struct fields are dropped in declaration order, and
    // `tensors` holds a (lifetime-extended) borrow of the bytes owned by `_mmap`.
    // Declaring `tensors` first guarantees the borrower is gone before the mapping
    // is unmapped.
    tensors: SafeTensors<'static>,
    // Never read directly; held only to keep the mapping alive for `tensors`.
    _mmap: Mmap,
}
#[cfg(feature = "safetensors")]
impl SafetensorsLoader {
    /// Memory-maps the file at `path` and parses its safetensors header.
    ///
    /// # Errors
    ///
    /// Returns `SafetensorsError::IoError` if the file cannot be opened or mapped, and
    /// `SafetensorsError::SafetensorsError` if the contents fail to deserialize.
    pub fn from_file<P: AsRef<Path>>(path: P) -> Result<Self, SafetensorsError> {
        let file = File::open(path)?;
        // SAFETY: mapping a file is inherently unsafe because another process could
        // modify or truncate it while mapped; this loader assumes the file is left
        // untouched for as long as the loader is alive.
        let mmap = unsafe { Mmap::map(&file)? };
        // SAFETY: the slice points into the mapping, not into the `Mmap` handle, so it
        // stays valid when `mmap` is moved into `Self`. The extended lifetime never
        // escapes: `tensors` is private, lives exactly as long as `_mmap`, and every
        // accessor ties returned borrows to `&self` or returns owned data.
        let data: &'static [u8] =
            unsafe { std::mem::transmute::<&[u8], &'static [u8]>(mmap.as_ref()) };
        let tensors = SafeTensors::deserialize(data)?;
        Ok(Self {
            _mmap: mmap,
            tensors,
        })
    }

    /// Returns the names of all tensors stored in the file.
    pub fn tensor_names(&self) -> Vec<&str> {
        self.tensors.names().iter().map(|s| s.as_str()).collect()
    }

    /// Loads the tensor called `name` and converts every element to `T`.
    ///
    /// Supported stored dtypes are `F32`, `F64`, `I32` and `I64`. Elements are decoded
    /// from little-endian bytes one at a time, so the tensor data does not need to be
    /// aligned in the file and the result is the same on any host endianness.
    ///
    /// # Errors
    ///
    /// * `SafetensorsError::SafetensorsError` if no tensor with that name exists.
    /// * `SafetensorsError::ConversionError` if the dtype is unsupported, the byte
    ///   length is not a multiple of the element size, or a value cannot be
    ///   represented as `T`.
    pub fn load_tensor<T: Float + 'static>(
        &self,
        name: &str,
    ) -> Result<Tensor<T>, SafetensorsError> {
        let view = self.tensors.tensor(name)?;
        let shape = view.shape().to_vec();
        let bytes = view.data();
        let data = match view.dtype() {
            Dtype::F32 => Self::decode_elements(name, bytes, |raw: [u8; 4]| {
                T::from(f32::from_le_bytes(raw))
            })?,
            Dtype::F64 => Self::decode_elements(name, bytes, |raw: [u8; 8]| {
                T::from(f64::from_le_bytes(raw))
            })?,
            // Integers are converted directly to `T` (no intermediate `as f64` cast),
            // which avoids a second rounding step when `T` is narrower than `f64`.
            Dtype::I32 => Self::decode_elements(name, bytes, |raw: [u8; 4]| {
                T::from(i32::from_le_bytes(raw))
            })?,
            Dtype::I64 => Self::decode_elements(name, bytes, |raw: [u8; 8]| {
                T::from(i64::from_le_bytes(raw))
            })?,
            other => {
                return Err(SafetensorsError::ConversionError(format!(
                    "Unsupported dtype: {:?}",
                    other
                )))
            }
        };
        Ok(Tensor::from_vec(data, shape))
    }

    /// Loads every tensor in the file, keyed by tensor name.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first error produced by [`Self::load_tensor`].
    pub fn load_all_tensors<T: Float + 'static>(
        &self,
    ) -> Result<HashMap<String, Tensor<T>>, SafetensorsError> {
        self.tensor_names()
            .into_iter()
            .map(|name| {
                self.load_tensor::<T>(name)
                    .map(|tensor| (name.to_string(), tensor))
            })
            .collect()
    }

    /// Splits `bytes` into consecutive `N`-byte elements and maps each through
    /// `convert`, failing with a `ConversionError` instead of panicking.
    ///
    /// `name` is only used to make error messages actionable.
    fn decode_elements<T, F, const N: usize>(
        name: &str,
        bytes: &[u8],
        convert: F,
    ) -> Result<Vec<T>, SafetensorsError>
    where
        F: Fn([u8; N]) -> Option<T>,
    {
        if bytes.len() % N != 0 {
            return Err(SafetensorsError::ConversionError(format!(
                "Tensor '{}' has {} bytes, which is not a multiple of the {}-byte element size",
                name,
                bytes.len(),
                N
            )));
        }
        bytes
            .chunks_exact(N)
            .map(|chunk| {
                // Copy into a fixed-size array; `chunks_exact` guarantees `chunk.len() == N`.
                let mut raw = [0u8; N];
                raw.copy_from_slice(chunk);
                convert(raw).ok_or_else(|| {
                    SafetensorsError::ConversionError(format!(
                        "Tensor '{}' contains a value that cannot be represented in the target type",
                        name
                    ))
                })
            })
            .collect()
    }
}
/// Stateless namespace type for writing tensors to safetensors files; all
/// functionality is exposed through associated functions.
#[cfg(feature = "safetensors")]
pub struct SafetensorsSaver;
#[cfg(feature = "safetensors")]
impl SafetensorsSaver {
    /// Serializes `tensors` to a safetensors file at `path`.
    ///
    /// Every tensor is stored with dtype `F32` regardless of `T`, so values wider than
    /// `f32` lose precision. Bytes are written explicitly in little-endian order, so
    /// the output does not depend on the host's endianness.
    ///
    /// # Errors
    ///
    /// * `SafetensorsError::ConversionError` if an element cannot be converted to
    ///   `f32` or a `TensorView` cannot be constructed.
    /// * `SafetensorsError::SafetensorsError` if serialization fails.
    /// * `SafetensorsError::IoError` if the file cannot be written.
    pub fn save_to_file<T: Float + 'static, P: AsRef<Path>>(
        tensors: &HashMap<String, Tensor<T>>,
        path: P,
    ) -> Result<(), SafetensorsError> {
        use safetensors::tensor::{Dtype, TensorView};

        // Phase 1: materialize the byte buffers. They must outlive the `TensorView`s
        // built in phase 2, which only borrow them.
        let mut tensor_data: Vec<(&str, Vec<usize>, Vec<u8>)> = Vec::with_capacity(tensors.len());
        for (name, tensor) in tensors {
            let shape = tensor.shape().to_vec();
            let values = tensor
                .data
                .iter()
                .map(|&x| {
                    x.to_f32().ok_or_else(|| {
                        SafetensorsError::ConversionError(format!(
                            "Value in tensor '{}' cannot be represented as f32",
                            name
                        ))
                    })
                })
                .collect::<Result<Vec<f32>, SafetensorsError>>()?;
            let mut bytes = Vec::with_capacity(values.len() * std::mem::size_of::<f32>());
            for value in &values {
                bytes.extend_from_slice(&value.to_le_bytes());
            }
            tensor_data.push((name.as_str(), shape, bytes));
        }

        // Phase 2: wrap each buffer in a borrowed view for the serializer.
        let mut data_map = HashMap::with_capacity(tensor_data.len());
        for (name, shape, bytes) in &tensor_data {
            let tensor_view = TensorView::new(Dtype::F32, shape.clone(), bytes.as_slice())
                .map_err(|e| {
                    SafetensorsError::ConversionError(format!("TensorView error: {}", e))
                })?;
            data_map.insert((*name).to_string(), tensor_view);
        }

        let safetensor_data = safetensors::serialize(&data_map, &None)?;
        std::fs::write(path, safetensor_data)?;
        Ok(())
    }
}
#[cfg(test)]
#[cfg(feature = "safetensors")]
mod tests {
    use super::*;
    use tempfile::NamedTempFile;

    /// A single tensor survives a save/load round trip with shape and values intact.
    #[test]
    fn test_safetensors_save_load() {
        let mut tensors = HashMap::new();
        let tensor = Tensor::<f32>::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
        tensors.insert("test_tensor".to_string(), tensor);

        // `temp_file` is declared before `loader` so the mapping is dropped first.
        let temp_file = NamedTempFile::new().unwrap();
        SafetensorsSaver::save_to_file(&tensors, temp_file.path()).unwrap();

        let loader = SafetensorsLoader::from_file(temp_file.path()).unwrap();
        let loaded_tensor: Tensor<f32> = loader.load_tensor("test_tensor").unwrap();

        assert_eq!(loaded_tensor.shape(), &[2, 2]);
        let expected_data = vec![1.0, 2.0, 3.0, 4.0];
        assert_eq!(
            loaded_tensor.data.iter().copied().collect::<Vec<_>>(),
            expected_data
        );
    }

    /// Loading everything returns each saved tensor under its own name, with the
    /// right shape and values (not merely the right keys).
    #[test]
    fn test_load_all_tensors() {
        let mut tensors = HashMap::new();
        let tensor1 = Tensor::<f32>::from_vec(vec![1.0, 2.0], vec![2]);
        let tensor2 = Tensor::<f32>::from_vec(vec![3.0, 4.0, 5.0], vec![3]);
        tensors.insert("tensor1".to_string(), tensor1);
        tensors.insert("tensor2".to_string(), tensor2);

        let temp_file = NamedTempFile::new().unwrap();
        SafetensorsSaver::save_to_file(&tensors, temp_file.path()).unwrap();

        let loader = SafetensorsLoader::from_file(temp_file.path()).unwrap();

        let mut names = loader.tensor_names();
        names.sort_unstable();
        assert_eq!(names, vec!["tensor1", "tensor2"]);

        let loaded_tensors: HashMap<String, Tensor<f32>> = loader.load_all_tensors().unwrap();
        assert_eq!(loaded_tensors.len(), 2);
        assert!(loaded_tensors.contains_key("tensor1"));
        assert!(loaded_tensors.contains_key("tensor2"));

        let loaded1 = &loaded_tensors["tensor1"];
        assert_eq!(loaded1.shape(), &[2]);
        assert_eq!(
            loaded1.data.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0]
        );

        let loaded2 = &loaded_tensors["tensor2"];
        assert_eq!(loaded2.shape(), &[3]);
        assert_eq!(
            loaded2.data.iter().copied().collect::<Vec<_>>(),
            vec![3.0, 4.0, 5.0]
        );
    }

    /// Asking for a tensor that is not in the file yields an error, not a panic.
    #[test]
    fn test_load_missing_tensor_is_error() {
        let mut tensors = HashMap::new();
        tensors.insert(
            "present".to_string(),
            Tensor::<f32>::from_vec(vec![1.0], vec![1]),
        );

        let temp_file = NamedTempFile::new().unwrap();
        SafetensorsSaver::save_to_file(&tensors, temp_file.path()).unwrap();

        let loader = SafetensorsLoader::from_file(temp_file.path()).unwrap();
        let result = loader.load_tensor::<f32>("does_not_exist");
        assert!(matches!(
            result,
            Err(SafetensorsError::SafetensorsError(_))
        ));
    }
}