use std::collections::HashMap;
use std::fs;
use std::path::Path;
use safetensors::tensor::SafeTensors;
use safetensors::Dtype;
use crate::autograd::AutogradError;
use crate::tensor::Tensor;
/// Serializes every tensor in `dict` to the safetensors format and writes the
/// result to `path`.
///
/// Each tensor is stored under its map key with dtype `F32` and the shape
/// reported by `Tensor::shape`. Element bytes are encoded little-endian
/// explicitly (rather than reinterpreting memory in native byte order) so the
/// file matches what `load_safetensors` decodes with `f32::from_le_bytes`,
/// regardless of the host's endianness.
///
/// # Errors
///
/// Returns `AutogradError::StateError` when:
/// - a tensor's byte length and shape are rejected by `TensorView::new`
///   (`key` is the tensor name),
/// - safetensors serialization fails (`key` is empty),
/// - writing to `path` fails (`key` is empty).
pub fn save_safetensors(
    dict: &HashMap<String, Tensor>,
    path: &Path,
) -> Result<(), AutogradError> {
    // The owned byte buffers must outlive the `TensorView`s that borrow them,
    // so they are collected first. Shape is captured while the data guard is
    // still held, exactly as the data snapshot is.
    let mut buffers: Vec<(&str, Vec<u8>, Vec<usize>)> = Vec::with_capacity(dict.len());
    for (name, tensor) in dict {
        let guard = tensor.storage.data();
        let values: &[f32] = &*guard;
        // safetensors stores little-endian data; do not rely on native layout.
        let bytes: Vec<u8> = values.iter().flat_map(|v| v.to_le_bytes()).collect();
        buffers.push((name.as_str(), bytes, tensor.shape().to_vec()));
    }

    // Build the borrowed views, reporting an invalid view as an error tied to
    // the offending tensor name instead of panicking.
    let views = buffers
        .iter()
        .map(|(name, bytes, shape)| {
            safetensors::tensor::TensorView::new(Dtype::F32, shape.clone(), bytes)
                .map(|view| (*name, view))
                .map_err(|e| AutogradError::StateError {
                    key: name.to_string(),
                    message: format!("invalid tensor view: {}", e),
                })
        })
        .collect::<Result<Vec<(&str, safetensors::tensor::TensorView<'_>)>, AutogradError>>()?;

    let serialized = safetensors::tensor::serialize(views, &None).map_err(|e| {
        AutogradError::StateError {
            key: String::new(),
            message: format!("safetensors serialize error: {}", e),
        }
    })?;

    fs::write(path, &serialized).map_err(|e| AutogradError::StateError {
        key: String::new(),
        message: format!("IO write error: {}", e),
    })?;

    Ok(())
}
/// Reads the safetensors file at `path` and rebuilds a map from tensor name to
/// `Tensor`.
///
/// Only entries with dtype `F32` are accepted. Each entry's raw bytes are
/// decoded four at a time as little-endian `f32` values and passed, together
/// with the entry's shape, to `Tensor::new`.
///
/// # Errors
///
/// Returns `AutogradError::StateError` when:
/// - the file cannot be read (`key` is empty),
/// - the bytes cannot be parsed as safetensors (`key` is empty),
/// - an entry has a dtype other than `F32` (`key` is the entry name),
/// - an entry's byte length is not a multiple of 4 (`key` is the entry name).
pub fn load_safetensors(
    path: &Path,
) -> Result<HashMap<String, Tensor>, AutogradError> {
    // Small constructor so every failure path builds its error the same way.
    let state_err = |key: String, message: String| AutogradError::StateError { key, message };

    let raw = match fs::read(path) {
        Ok(contents) => contents,
        Err(e) => return Err(state_err(String::new(), format!("IO read error: {}", e))),
    };

    let archive = match SafeTensors::deserialize(&raw) {
        Ok(parsed) => parsed,
        Err(e) => {
            return Err(state_err(
                String::new(),
                format!("safetensors parse error: {}", e),
            ))
        }
    };

    // Convert each entry in turn; collecting into `Result` stops at the first
    // entry that fails validation.
    archive
        .tensors()
        .into_iter()
        .map(|(name, view)| {
            if view.dtype() != Dtype::F32 {
                return Err(state_err(
                    name.to_string(),
                    format!("unsupported dtype {:?}, expected F32", view.dtype()),
                ));
            }

            let payload = view.data();
            if payload.len() % 4 != 0 {
                return Err(state_err(
                    name.to_string(),
                    format!(
                        "data length {} is not a multiple of 4 (F32 size)",
                        payload.len(),
                    ),
                ));
            }

            let values: Vec<f32> = payload
                .chunks_exact(4)
                .map(|quad| {
                    // `chunks_exact(4)` only ever yields 4-byte slices.
                    let word = <[u8; 4]>::try_from(quad).expect("chunk must be 4 bytes");
                    f32::from_le_bytes(word)
                })
                .collect();

            let dims: Vec<usize> = view.shape().to_vec();
            Ok((name.to_string(), Tensor::new(values, dims)))
        })
        .collect()
}