use std::collections::HashMap;
use std::path::Path;
use crate::common::{
adapter::PyTorchAdapter,
tensor_snapshot::{TensorSnapshotWrapper, print_debug_info},
};
use burn::record::PrecisionSettings;
use burn::{
record::serde::{
data::{remap, unflatten},
de::Deserializer,
},
tensor::backend::Backend,
};
use burn_store::pytorch::PytorchReader;
use regex::Regex;
use serde::de::DeserializeOwned;
/// Errors that can occur while loading and deserializing a PyTorch checkpoint.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// Failure reported by the underlying `burn_store` PyTorch reader
/// (e.g. the file could not be opened or parsed as a checkpoint).
#[error("Store error: {0}")]
Store(#[from] burn_store::pytorch::PytorchError),
/// Failure while converting the flattened tensor map into the target
/// record type via Burn's serde machinery.
#[error("Serde error: {0}")]
Serde(#[from] burn::record::serde::error::Error),
/// Catch-all for errors that do not fit the other variants.
#[error("Other error: {0}")]
Other(String),
}
/// Deserializes a record of type `D` from the PyTorch checkpoint at `path`.
///
/// Tensor keys are optionally rewritten by the `(pattern, replacement)` pairs
/// in `key_remap`, and `top_level_key` selects a nested entry of the
/// checkpoint when present. When `debug` is set, the tensor map and the
/// applied key remappings are printed for troubleshooting.
///
/// # Errors
///
/// Returns [`Error::Store`] when the checkpoint cannot be read, and
/// [`Error::Serde`] when the tensor data cannot be deserialized into `D`.
pub fn from_file<PS, D, B>(
    path: &Path,
    key_remap: Vec<(Regex, String)>,
    top_level_key: Option<&str>,
    debug: bool,
) -> Result<D, Error>
where
    D: DeserializeOwned,
    PS: PrecisionSettings,
    B: Backend,
{
    // Open the checkpoint, descending into the requested top-level entry if any.
    let reader = match top_level_key {
        Some(key) => PytorchReader::with_top_level_key(path, key)?,
        None => PytorchReader::new(path)?,
    };

    // Wrap every snapshot so it can participate in serde deserialization.
    let mut wrapped: HashMap<String, TensorSnapshotWrapper> = HashMap::new();
    for (name, snapshot) in reader.into_tensors() {
        wrapped.insert(name, TensorSnapshotWrapper(snapshot));
    }

    // Apply the caller-supplied key rewrites.
    let (wrapped, remapped_keys) = remap(wrapped, key_remap);

    if debug {
        print_debug_info(&wrapped, remapped_keys);
    }

    // Rebuild the nested structure and run it through the PyTorch-aware
    // deserializer (the `true` flag matches the original call site).
    let nested_value = unflatten::<PS, _>(wrapped)?;
    let deserializer = Deserializer::<PyTorchAdapter<PS, B>>::new(nested_value, true);
    Ok(D::deserialize(deserializer)?)
}
/// Bridges this module's [`Error`] into Burn's recorder error type so that
/// `?` works inside recorder implementations; the full error chain is
/// flattened into the display message.
impl From<Error> for burn::record::RecorderError {
    fn from(error: Error) -> Self {
        let message = error.to_string();
        Self::DeserializeError(message)
    }
}