use std::{collections::HashMap, path::Path};
use burn::{
record::{
PrecisionSettings,
serde::{
adapter::DefaultAdapter,
data::{remap, unflatten},
de::Deserializer,
},
},
tensor::backend::Backend,
};
use candle_core::{Device, safetensors};
use regex::Regex;
use serde::de::DeserializeOwned;
use super::super::common::adapter::PyTorchAdapter;
use super::recorder::AdapterType;
use crate::common::candle::{CandleTensor, Error, print_debug_info};
/// Loads a safetensors file and deserializes its tensors into a record of type `D`.
///
/// The pipeline is:
/// 1. read every tensor in the file onto the CPU device and wrap it in [`CandleTensor`];
/// 2. rename tensor keys according to `key_remap` (regex pattern → replacement);
/// 3. optionally print diagnostic information about the (remapped) tensors;
/// 4. rebuild a nested value from the flat key/tensor map using precision settings `PS`;
/// 5. deserialize that nested value into `D`, either through the PyTorch adapter or the
///    default (pass-through) adapter depending on `adapter_type`.
///
/// # Type parameters
/// - `PS`: precision settings used when unflattening the tensors.
/// - `D`: the record type produced by deserialization.
/// - `B`: backend type; it is only used to parameterize the PyTorch adapter.
///
/// # Arguments
/// - `path`: location of the safetensors file.
/// - `key_remap`: ordered list of `(pattern, replacement)` rules applied to tensor keys.
/// - `debug`: when `true`, calls `print_debug_info` with the tensors and the remapped keys.
/// - `adapter_type`: selects which deserialization adapter to use.
///
/// # Errors
/// Returns an [`Error`] if the file cannot be loaded, if the tensors cannot be unflattened,
/// or if deserialization into `D` fails.
pub fn from_file<PS, D, B>(
    path: &Path,
    key_remap: Vec<(Regex, String)>,
    debug: bool,
    adapter_type: AdapterType,
) -> Result<D, Error>
where
    D: DeserializeOwned,
    PS: PrecisionSettings,
    B: Backend,
{
    // Every tensor is materialized on the CPU; no other device is considered here.
    let loaded: HashMap<String, CandleTensor> = safetensors::load(path, &Device::Cpu)?
        .into_iter()
        .map(|(name, raw)| (name, CandleTensor(raw)))
        .collect();

    let (renamed, renamed_keys) = remap(loaded, key_remap);

    // `renamed_keys` is only needed for diagnostics, so it is consumed only in debug mode.
    if debug {
        print_debug_info(&renamed, renamed_keys);
    }

    let nested = unflatten::<PS, _>(renamed)?;

    // Second argument of `Deserializer::new`; presumably burn's
    // `default_for_missing_fields` flag — confirm against the burn version in use.
    let default_for_missing_fields = true;

    let record = match adapter_type {
        AdapterType::PyTorch => {
            let deserializer =
                Deserializer::<PyTorchAdapter<PS, B>>::new(nested, default_for_missing_fields);
            D::deserialize(deserializer)?
        }
        AdapterType::NoAdapter => {
            let deserializer =
                Deserializer::<DefaultAdapter>::new(nested, default_for_missing_fields);
            D::deserialize(deserializer)?
        }
    };

    Ok(record)
}