use crate::{
ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter,
TensorSnapshot, map_indices_contiguous,
};
use alloc::collections::BTreeMap;
use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use core::fmt;
use std::path::PathBuf;
use super::reader::{PytorchError as ReaderError, PytorchReader};
/// Errors that can occur while loading tensors from a PyTorch checkpoint
/// through a `PytorchStore`.
#[derive(Debug)]
pub enum PytorchStoreError {
    /// Failure reported by the underlying PyTorch checkpoint reader.
    Reader(ReaderError),
    /// Filesystem error while accessing the checkpoint file.
    Io(std::io::Error),
    /// Expected tensors were missing during apply and partial loading was not allowed.
    TensorNotFound(String),
    /// Applying snapshots produced errors while validation was enabled.
    ValidationFailed(String),
    /// Any other error, described by the contained message.
    Other(String),
}
impl fmt::Display for PytorchStoreError {
    /// Renders a human-readable description of the error, prefixing the
    /// wrapped error or message with the failure category.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Reader(err) => write!(f, "PyTorch reader error: {err}"),
            Self::Io(err) => write!(f, "I/O error: {err}"),
            Self::TensorNotFound(tensor) => write!(f, "Tensor not found: {tensor}"),
            Self::ValidationFailed(reason) => write!(f, "Validation failed: {reason}"),
            // No prefix for the generic case: the message is the whole display.
            Self::Other(message) => f.write_str(message),
        }
    }
}
// Marker impl: the default `Error` methods suffice since `Debug` and `Display`
// are provided above and no variant exposes a distinct `source()`.
impl std::error::Error for PytorchStoreError {}
impl From<ReaderError> for PytorchStoreError {
fn from(e: ReaderError) -> Self {
PytorchStoreError::Reader(e)
}
}
impl From<std::io::Error> for PytorchStoreError {
fn from(e: std::io::Error) -> Self {
PytorchStoreError::Io(e)
}
}
/// Store that lazily loads tensor snapshots from a PyTorch checkpoint file
/// and applies them to a module. Built via `from_file` plus chained builder
/// methods; tensors are read on first access and cached.
pub struct PytorchStore {
    // Path to the checkpoint file on disk.
    pub(crate) path: PathBuf,
    // Filter selecting which tensor paths to apply; an empty filter means
    // "apply everything" (see `apply_to`).
    pub(crate) filter: PathFilter,
    // Regex-based renaming applied to tensor keys after loading.
    pub(crate) remapper: KeyRemapper,
    // When true, apply errors are turned into `ValidationFailed`.
    pub(crate) validate: bool,
    // When true, missing tensors are tolerated instead of producing
    // `TensorNotFound`.
    pub(crate) allow_partial: bool,
    // Optional top-level key passed to the reader — presumably selects a
    // nested state dict inside the checkpoint (see `create_reader`).
    pub(crate) top_level_key: Option<String>,
    // Forwarded to `module.apply` to control enum-variant path matching.
    pub(crate) skip_enum_variants: bool,
    // When true, numeric indices in tensor paths are remapped to be
    // contiguous via `map_indices_contiguous`.
    pub(crate) map_indices_contiguous: bool,
    // Lazily-populated cache of snapshots keyed by full tensor path;
    // `None` until `ensure_snapshots_cache` runs.
    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
}
impl PytorchStore {
    /// Creates a store that reads tensors from the given PyTorch checkpoint file.
    ///
    /// Defaults: validation on, partial loading off, enum-variant skipping on,
    /// contiguous index remapping on, no filter, no key remapping.
    pub fn from_file(path: impl Into<PathBuf>) -> Self {
        Self {
            path: path.into(),
            filter: PathFilter::new(),
            remapper: KeyRemapper::new(),
            validate: true,
            allow_partial: false,
            top_level_key: None,
            skip_enum_variants: true,
            map_indices_contiguous: true,
            snapshots_cache: None,
        }
    }

    /// Reads tensors from the given top-level key of the checkpoint
    /// (e.g. `"state_dict"`) instead of its root.
    pub fn with_top_level_key(mut self, key: impl Into<String>) -> Self {
        self.top_level_key = Some(key.into());
        self
    }

    /// Replaces the path filter entirely.
    pub fn filter(mut self, filter: PathFilter) -> Self {
        self.filter = filter;
        self
    }

    /// Adds a single regex pattern to the path filter.
    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
        self.filter = self.filter.with_regex(pattern);
        self
    }

    /// Adds several regex patterns to the path filter.
    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.filter = self.filter.with_regexes(patterns);
        self
    }

    /// Adds an exact full tensor path to the filter.
    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
        self.filter = self.filter.with_full_path(path);
        self
    }

    /// Adds several exact full tensor paths to the filter.
    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.filter = self.filter.with_full_paths(paths);
        self
    }

    /// Adds a predicate function to the filter.
    pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
        self.filter = self.filter.with_predicate(predicate);
        self
    }

    /// Adds several predicate functions to the filter.
    pub fn with_predicates<I>(mut self, predicates: I) -> Self
    where
        I: IntoIterator<Item = fn(&str, &str) -> bool>,
    {
        self.filter = self.filter.with_predicates(predicates);
        self
    }

    /// Makes the filter match every tensor path.
    pub fn match_all(mut self) -> Self {
        self.filter = self.filter.match_all();
        self
    }

    /// Replaces the key remapper entirely.
    pub fn remap(mut self, remapper: KeyRemapper) -> Self {
        self.remapper = remapper;
        self
    }

    /// Adds a regex-based key remapping rule.
    ///
    /// # Panics
    ///
    /// Panics if `from_pattern` is not a valid regular expression. The panic
    /// message includes the offending pattern to ease debugging.
    pub fn with_key_remapping(
        mut self,
        from_pattern: impl AsRef<str>,
        to_pattern: impl Into<String>,
    ) -> Self {
        let pattern = from_pattern.as_ref();
        self.remapper = self
            .remapper
            .add_pattern(pattern, to_pattern)
            // Include the pattern in the panic message instead of the bare
            // "Invalid regex pattern" the caller previously got.
            .unwrap_or_else(|e| panic!("Invalid regex pattern `{pattern}`: {e:?}"));
        self
    }

    /// Enables or disables validation of apply errors (defaults to `true`).
    pub fn validate(mut self, validate: bool) -> Self {
        self.validate = validate;
        self
    }

    /// Allows or forbids partially-loaded modules (defaults to `false`).
    pub fn allow_partial(mut self, allow: bool) -> Self {
        self.allow_partial = allow;
        self
    }

    /// Controls whether enum variant names are skipped during path matching.
    pub fn skip_enum_variants(mut self, skip: bool) -> Self {
        self.skip_enum_variants = skip;
        self
    }

    /// Controls whether numeric path indices are remapped to be contiguous.
    pub fn map_indices_contiguous(mut self, map: bool) -> Self {
        self.map_indices_contiguous = map;
        self
    }

    /// Runs the configured key remapper, returning the snapshots unchanged
    /// when no remapping rules were configured.
    fn apply_remapping(&self, snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {
        if self.remapper.is_empty() {
            return snapshots;
        }
        let (remapped, _) = self.remapper.remap(snapshots);
        remapped
    }

    /// Opens a reader over the checkpoint, honoring `top_level_key` when set.
    fn create_reader(&self) -> Result<PytorchReader, PytorchStoreError> {
        let reader = if let Some(ref key) = self.top_level_key {
            PytorchReader::with_top_level_key(&self.path, key)?
        } else {
            PytorchReader::new(&self.path)?
        };
        Ok(reader)
    }
}
impl ModuleStore for PytorchStore {
    type Error = PytorchStoreError;

    /// Saving is not implemented for the PyTorch format; always returns an error.
    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        _module: &M,
    ) -> Result<(), Self::Error> {
        Err(PytorchStoreError::Other(
            "Saving to PyTorch format is not yet supported. Use other formats for saving."
                .to_string(),
        ))
    }

    /// Applies all cached (remapped, optionally filtered) snapshots to `module`.
    ///
    /// # Errors
    ///
    /// - [`PytorchStoreError::ValidationFailed`] when validation is enabled and
    ///   any tensor failed to apply.
    /// - [`PytorchStoreError::TensorNotFound`] when partial loading is disallowed
    ///   and some module tensors had no matching snapshot.
    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<ApplyResult, Self::Error> {
        // Clone out of the cache so the borrow of `self` ends before `module.apply`.
        let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();
        // An empty filter means "apply everything", not "match nothing".
        let filter_opt = (!self.filter.is_empty()).then(|| self.filter.clone());
        let result = module.apply(
            snapshots,
            filter_opt,
            Some(Box::new(PyTorchToBurnAdapter)),
            self.skip_enum_variants,
        );
        if self.validate && !result.errors.is_empty() {
            return Err(PytorchStoreError::ValidationFailed(format!(
                "Import errors:\n{}",
                result
            )));
        }
        if !self.allow_partial && !result.missing.is_empty() {
            return Err(PytorchStoreError::TensorNotFound(format!("\n{}", result)));
        }
        Ok(result)
    }

    /// Returns the cached snapshot registered under `name`, if any.
    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
        self.ensure_snapshots_cache()?;
        Ok(self
            .snapshots_cache
            .as_ref()
            .expect("cache populated by ensure_snapshots_cache")
            .get(name))
    }

    /// Returns the full map of cached snapshots, loading them on first use.
    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
        self.ensure_snapshots_cache()?;
        Ok(self
            .snapshots_cache
            .as_ref()
            .expect("cache populated by ensure_snapshots_cache"))
    }

    /// Returns the full paths of all cached snapshots.
    fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
        Ok(self.get_all_snapshots()?.keys().cloned().collect())
    }
}
impl PytorchStore {
fn ensure_snapshots_cache(&mut self) -> Result<(), PytorchStoreError> {
if self.snapshots_cache.is_some() {
return Ok(());
}
let reader = self.create_reader()?;
let mut snapshots: Vec<TensorSnapshot> = reader
.into_tensors()
.into_iter()
.map(|(key, mut snapshot)| {
let path_parts: Vec<String> = key.split('.').map(|s| s.to_string()).collect();
snapshot.path_stack = Some(path_parts);
snapshot.container_stack = None;
snapshot.tensor_id = None;
snapshot
})
.collect();
snapshots = self.apply_remapping(snapshots);
if self.map_indices_contiguous {
let (mapped, _) = map_indices_contiguous(snapshots);
snapshots = mapped;
}
let cache: BTreeMap<String, TensorSnapshot> =
snapshots.into_iter().map(|s| (s.full_path(), s)).collect();
self.snapshots_cache = Some(cache);
Ok(())
}
}