use std::collections::HashMap;
use crate::{Device, Tensor, TensorError};
/// Writes `tensors` to `file_path` in the safetensors format.
///
/// Every tensor is first converted with `Tensor::to_candle_tensor`; the file is
/// only written after all conversions have succeeded, so a failed conversion
/// leaves nothing on disk from this call.
///
/// # Errors
///
/// Returns the first conversion error encountered, or any error reported by
/// `candle_core::safetensors::save`, converted into [`TensorError`] via `?`.
pub fn save(file_path: &str, tensors: HashMap<String, Tensor>) -> Result<(), TensorError> {
    let mut converted = HashMap::with_capacity(tensors.len());
    for (name, tensor) in &tensors {
        // Keys borrow the caller's names; only the tensor payload is converted.
        converted.insert(name, tensor.to_candle_tensor()?);
    }
    candle_core::safetensors::save(&converted, file_path)?;
    Ok(())
}
/// Reads every tensor stored in the safetensors file at `file_path`.
///
/// The file is loaded with `candle_core::safetensors::load` onto `device`, and
/// each entry is then wrapped as a crate [`Tensor`] via
/// `Tensor::from_candle_tensor(.., device, false)`.
/// NOTE(review): confirm what the `false` flag controls (e.g. gradient tracking).
///
/// # Errors
///
/// Returns an error if the file cannot be loaded, or the first error produced
/// while converting an individual tensor.
pub fn load(file_path: &str, device: &Device) -> Result<HashMap<String, Tensor>, TensorError> {
    let candle_tensors = candle_core::safetensors::load(file_path, device)?;
    // The loaded map is owned here and not used afterwards, so consume it:
    // names and tensors are moved instead of re-allocated / cloned per entry.
    candle_tensors
        .into_iter()
        .map(|(name, tensor)| {
            let tensor = Tensor::from_candle_tensor(tensor, device, false)?;
            Ok((name, tensor))
        })
        .collect::<Result<HashMap<_, _>, TensorError>>()
}