// burn_store/pytorch/store.rs

1//! PyTorch store implementation for saving and loading models in PyTorch format.
2
3use crate::{
4    ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter,
5    TensorSnapshot, map_indices_contiguous,
6};
7
8use alloc::collections::BTreeMap;
9
10use alloc::format;
11use alloc::string::{String, ToString};
12use alloc::vec::Vec;
13use burn_tensor::backend::Backend;
14use core::fmt;
15use std::path::PathBuf;
16
17use super::reader::{PytorchError as ReaderError, PytorchReader};
18
19/// Errors that can occur during PyTorch operations.
#[derive(Debug)]
pub enum PytorchStoreError {
    /// Error produced while reading/parsing the PyTorch checkpoint file.
    Reader(ReaderError),

    /// Underlying filesystem I/O error.
    Io(std::io::Error),

    /// A tensor expected by the module was not found in the checkpoint
    /// (raised when `allow_partial` is disabled).
    TensorNotFound(String),

    /// Import produced per-tensor errors and validation is enabled.
    ValidationFailed(String),

    /// Any other error (e.g. unsupported operations such as saving).
    Other(String),
}
37
38impl fmt::Display for PytorchStoreError {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            Self::Reader(e) => write!(f, "PyTorch reader error: {}", e),
42            Self::Io(e) => write!(f, "I/O error: {}", e),
43            Self::TensorNotFound(name) => write!(f, "Tensor not found: {}", name),
44            Self::ValidationFailed(msg) => write!(f, "Validation failed: {}", msg),
45            Self::Other(msg) => write!(f, "{}", msg),
46        }
47    }
48}
49
50impl std::error::Error for PytorchStoreError {}
51
52impl From<ReaderError> for PytorchStoreError {
53    fn from(e: ReaderError) -> Self {
54        PytorchStoreError::Reader(e)
55    }
56}
57
58impl From<std::io::Error> for PytorchStoreError {
59    fn from(e: std::io::Error) -> Self {
60        PytorchStoreError::Io(e)
61    }
62}
63
/// PyTorch store for file-based storage only.
///
/// This store allows loading models from PyTorch checkpoint files (.pt/.pth)
/// with automatic weight transformation using `PyTorchToBurnAdapter`.
/// Linear weights are automatically transposed and normalization parameters
/// are renamed (gamma -> weight, beta -> bias).
///
/// Note that saving to PyTorch format is not yet supported.
pub struct PytorchStore {
    /// Path to the checkpoint file on disk (.pt or .pth).
    pub(crate) path: PathBuf,
    /// Filter deciding which tensors are applied; evaluated at apply time,
    /// not when the cache is built.
    pub(crate) filter: PathFilter,
    /// Regex-based key renaming applied while populating the cache.
    pub(crate) remapper: KeyRemapper,
    /// When true, per-tensor import errors abort loading (default: true).
    pub(crate) validate: bool,
    /// When true, missing tensors are tolerated instead of being an error.
    pub(crate) allow_partial: bool,
    /// Optional top-level dictionary key (e.g. "state_dict") to read tensors from.
    pub(crate) top_level_key: Option<String>,
    /// Skip enum variant names when matching paths (default: true for PyTorch).
    pub(crate) skip_enum_variants: bool,
    /// Enable contiguous mapping of layer indices (default: true)
    pub(crate) map_indices_contiguous: bool,
    /// Cached tensor snapshots (parsed once, reused)
    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
}
85
86impl PytorchStore {
87    /// Create a store for loading from a PyTorch file.
88    ///
89    /// # Arguments
90    /// * `path` - Path to the PyTorch checkpoint file (.pt or .pth)
91    ///
92    /// # Example
93    /// ```rust,no_run
94    /// use burn_store::PytorchStore;
95    ///
96    /// let store = PytorchStore::from_file("model.pth");
97    /// ```
98    pub fn from_file(path: impl Into<PathBuf>) -> Self {
99        Self {
100            path: path.into(),
101            filter: PathFilter::new(),
102            remapper: KeyRemapper::new(),
103            validate: true,
104            allow_partial: false,
105            top_level_key: None,
106            // PyTorch models never include enum variant names in paths
107            skip_enum_variants: true,
108            // Enable contiguous index mapping by default for PyTorch files
109            // This handles nn.Sequential models with gaps in layer indices
110            map_indices_contiguous: true,
111            snapshots_cache: None,
112        }
113    }
114
115    /// Set a top-level key to extract tensors from.
116    ///
117    /// PyTorch files often contain nested dictionaries. Use this to extract
118    /// tensors from a specific top-level key like "state_dict" or "model_state_dict".
119    ///
120    /// # Example
121    /// ```rust,no_run
122    /// # use burn_store::PytorchStore;
123    /// let store = PytorchStore::from_file("checkpoint.pth")
124    ///     .with_top_level_key("model_state_dict");
125    /// ```
126    pub fn with_top_level_key(mut self, key: impl Into<String>) -> Self {
127        self.top_level_key = Some(key.into());
128        self
129    }
130
131    /// Filter which tensors to load.
132    pub fn filter(mut self, filter: PathFilter) -> Self {
133        self.filter = filter;
134        self
135    }
136
137    /// Add a regex pattern to filter tensors.
138    ///
139    /// Multiple patterns can be added and they work with OR logic.
140    ///
141    /// # Example
142    /// ```rust,no_run
143    /// # use burn_store::PytorchStore;
144    /// let store = PytorchStore::from_file("model.pth")
145    ///     .with_regex(r"^encoder\..*")  // Match all encoder tensors
146    ///     .with_regex(r".*\.weight$");   // OR match any weight tensors
147    /// ```
148    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
149        self.filter = self.filter.with_regex(pattern);
150        self
151    }
152
153    /// Add multiple regex patterns to filter tensors.
154    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
155    where
156        I: IntoIterator<Item = S>,
157        S: AsRef<str>,
158    {
159        self.filter = self.filter.with_regexes(patterns);
160        self
161    }
162
163    /// Add an exact full path to match.
164    ///
165    /// # Example
166    /// ```rust,no_run
167    /// # use burn_store::PytorchStore;
168    /// let store = PytorchStore::from_file("model.pth")
169    ///     .with_full_path("encoder.layer1.weight")
170    ///     .with_full_path("decoder.output.bias");
171    /// ```
172    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
173        self.filter = self.filter.with_full_path(path);
174        self
175    }
176
177    /// Add multiple exact full paths to match.
178    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
179    where
180        I: IntoIterator<Item = S>,
181        S: Into<String>,
182    {
183        self.filter = self.filter.with_full_paths(paths);
184        self
185    }
186
187    /// Add a predicate function for custom filtering logic.
188    ///
189    /// The predicate receives the tensor path and container path.
190    ///
191    /// # Example
192    /// ```rust,no_run
193    /// # use burn_store::PytorchStore;
194    /// let store = PytorchStore::from_file("model.pth")
195    ///     .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias"));
196    /// ```
197    pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
198        self.filter = self.filter.with_predicate(predicate);
199        self
200    }
201
202    /// Add multiple predicate functions.
203    pub fn with_predicates<I>(mut self, predicates: I) -> Self
204    where
205        I: IntoIterator<Item = fn(&str, &str) -> bool>,
206    {
207        self.filter = self.filter.with_predicates(predicates);
208        self
209    }
210
211    /// Set the filter to match all paths (disables filtering).
212    pub fn match_all(mut self) -> Self {
213        self.filter = self.filter.match_all();
214        self
215    }
216
217    /// Remap tensor names during load.
218    pub fn remap(mut self, remapper: KeyRemapper) -> Self {
219        self.remapper = remapper;
220        self
221    }
222
223    /// Add a regex pattern to remap tensor names during load.
224    ///
225    /// # Example
226    /// ```rust,no_run
227    /// # use burn_store::PytorchStore;
228    /// let store = PytorchStore::from_file("model.pth")
229    ///     .with_key_remapping(r"^encoder\.", "transformer.encoder.")  // encoder.X -> transformer.encoder.X
230    ///     .with_key_remapping(r"\.gamma$", ".weight");               // X.gamma -> X.weight
231    /// ```
232    pub fn with_key_remapping(
233        mut self,
234        from_pattern: impl AsRef<str>,
235        to_pattern: impl Into<String>,
236    ) -> Self {
237        self.remapper = self
238            .remapper
239            .add_pattern(from_pattern, to_pattern)
240            .expect("Invalid regex pattern");
241        self
242    }
243
244    /// Set whether to validate tensors during loading (default: true).
245    pub fn validate(mut self, validate: bool) -> Self {
246        self.validate = validate;
247        self
248    }
249
250    /// Allow partial loading of tensors (continue even if some tensors are missing).
251    pub fn allow_partial(mut self, allow: bool) -> Self {
252        self.allow_partial = allow;
253        self
254    }
255
256    /// Skip enum variant names when matching tensor paths (default: true).
257    ///
258    /// When enabled, tensor paths from PyTorch that don't include enum variants
259    /// can be matched against Burn module paths that do include them.
260    /// For example, PyTorch path "feature.weight" can match Burn path "feature.BaseConv.weight".
261    ///
262    /// This defaults to `true` for PytorchStore since PyTorch models never include
263    /// enum variant names in their parameter paths.
264    ///
265    /// # Example
266    /// ```rust,no_run
267    /// # use burn_store::PytorchStore;
268    /// // Disable enum variant skipping (not typical)
269    /// let store = PytorchStore::from_file("model.pth")
270    ///     .skip_enum_variants(false);
271    /// ```
272    pub fn skip_enum_variants(mut self, skip: bool) -> Self {
273        self.skip_enum_variants = skip;
274        self
275    }
276
277    /// Enable or disable automatic contiguous mapping of layer indices (default: true).
278    ///
279    /// When enabled, non-contiguous numeric indices in tensor paths are renumbered
280    /// to be contiguous. This is useful when loading PyTorch models that have gaps
281    /// in layer numbering, such as when using `nn.Sequential` with mixed layer types
282    /// (e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5).
283    ///
284    /// # Example
285    ///
286    /// With index mapping enabled (default):
287    /// - `fc.0.weight` → `fc.0.weight`
288    /// - `fc.2.weight` → `fc.1.weight` (gap filled)
289    /// - `fc.4.weight` → `fc.2.weight` (gap filled)
290    ///
291    /// # Arguments
292    ///
293    /// * `map` - `true` to enable contiguous index mapping, `false` to disable
294    ///
295    /// # Example
296    /// ```rust,no_run
297    /// # use burn_store::PytorchStore;
298    /// // Disable contiguous index mapping if your model already has contiguous indices
299    /// let store = PytorchStore::from_file("model.pth")
300    ///     .map_indices_contiguous(false);
301    /// ```
302    pub fn map_indices_contiguous(mut self, map: bool) -> Self {
303        self.map_indices_contiguous = map;
304        self
305    }
306
307    /// Apply remapping to tensor snapshots.
308    fn apply_remapping(&self, snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {
309        if self.remapper.is_empty() {
310            return snapshots;
311        }
312
313        let (remapped, _) = self.remapper.remap(snapshots);
314        remapped
315    }
316
317    /// Create a PytorchReader for the configured path and options.
318    fn create_reader(&self) -> Result<PytorchReader, PytorchStoreError> {
319        let reader = if let Some(ref key) = self.top_level_key {
320            PytorchReader::with_top_level_key(&self.path, key)?
321        } else {
322            PytorchReader::new(&self.path)?
323        };
324        Ok(reader)
325    }
326}
327
328impl ModuleStore for PytorchStore {
329    type Error = PytorchStoreError;
330
331    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
332        &mut self,
333        _module: &M,
334    ) -> Result<(), Self::Error> {
335        // Saving to PyTorch format is not yet supported
336        Err(PytorchStoreError::Other(
337            "Saving to PyTorch format is not yet supported. Use other formats for saving."
338                .to_string(),
339        ))
340    }
341
342    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
343        &mut self,
344        module: &mut M,
345    ) -> Result<ApplyResult, Self::Error> {
346        // Get snapshots from cache
347        let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();
348
349        // Get filter (convert to Option for apply)
350        let filter_opt = if self.filter.is_empty() {
351            None
352        } else {
353            Some(self.filter.clone())
354        };
355
356        // Apply to module with PyTorchToBurnAdapter (always used for PyTorch files)
357        // This adapter handles:
358        // - Transposing linear weights from PyTorch format to Burn format
359        // - Renaming normalization parameters (gamma -> weight, beta -> bias)
360        // Filter is applied here during apply, not during cache population
361        let result = module.apply(
362            snapshots,
363            filter_opt,
364            Some(Box::new(PyTorchToBurnAdapter)),
365            self.skip_enum_variants,
366        );
367
368        // Validate if needed
369        if self.validate && !result.errors.is_empty() {
370            return Err(PytorchStoreError::ValidationFailed(format!(
371                "Import errors:\n{}",
372                result
373            )));
374        }
375
376        if !self.allow_partial && !result.missing.is_empty() {
377            return Err(PytorchStoreError::TensorNotFound(format!("\n{}", result)));
378        }
379
380        Ok(result)
381    }
382
383    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
384        self.ensure_snapshots_cache()?;
385        Ok(self.snapshots_cache.as_ref().unwrap().get(name))
386    }
387
388    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
389        self.ensure_snapshots_cache()?;
390        Ok(self.snapshots_cache.as_ref().unwrap())
391    }
392
393    fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
394        // Always use the cache to ensure remapping is applied consistently
395        Ok(self.get_all_snapshots()?.keys().cloned().collect())
396    }
397}
398
399impl PytorchStore {
400    /// Ensure the snapshots cache is populated
401    fn ensure_snapshots_cache(&mut self) -> Result<(), PytorchStoreError> {
402        if self.snapshots_cache.is_some() {
403            return Ok(());
404        }
405
406        let reader = self.create_reader()?;
407
408        // Convert to tensor snapshots
409        let mut snapshots: Vec<TensorSnapshot> = reader
410            .into_tensors()
411            .into_iter()
412            .map(|(key, mut snapshot)| {
413                // Parse the key into path parts (split by '.')
414                let path_parts: Vec<String> = key.split('.').map(|s| s.to_string()).collect();
415
416                // Set the path stack from the key
417                snapshot.path_stack = Some(path_parts);
418                snapshot.container_stack = None;
419                snapshot.tensor_id = None;
420
421                snapshot
422            })
423            .collect();
424
425        // Apply remapping (but NOT filtering - that's done at apply time)
426        snapshots = self.apply_remapping(snapshots);
427
428        // Apply contiguous index mapping if enabled
429        // This must be done after remapping so that remapped paths are mapped
430        if self.map_indices_contiguous {
431            let (mapped, _) = map_indices_contiguous(snapshots);
432            snapshots = mapped;
433        }
434
435        // Build cache as BTreeMap
436        let cache: BTreeMap<String, TensorSnapshot> =
437            snapshots.into_iter().map(|s| (s.full_path(), s)).collect();
438
439        self.snapshots_cache = Some(cache);
440        Ok(())
441    }
442}