burn_store/pytorch/store.rs

//! PyTorch store implementation for loading models in PyTorch format (saving is not yet supported).

use crate::{
    ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter,
    TensorSnapshot,
};

use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use core::fmt;
use std::path::PathBuf;

use super::reader::{PytorchError as ReaderError, PytorchReader};
/// Errors that can occur during PyTorch operations.
#[derive(Debug)]
pub enum PytorchStoreError {
    /// Reader error.
    Reader(ReaderError),

    /// I/O error.
    Io(std::io::Error),

    /// Tensor not found.
    TensorNotFound(String),

    /// Validation failed.
    ValidationFailed(String),

    /// Other error.
    Other(String),
}

impl fmt::Display for PytorchStoreError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Reader(e) => write!(f, "PyTorch reader error: {}", e),
            Self::Io(e) => write!(f, "I/O error: {}", e),
            Self::TensorNotFound(name) => write!(f, "Tensor not found: {}", name),
            Self::ValidationFailed(msg) => write!(f, "Validation failed: {}", msg),
            Self::Other(msg) => write!(f, "{}", msg),
        }
    }
}

impl std::error::Error for PytorchStoreError {}

impl From<ReaderError> for PytorchStoreError {
    fn from(e: ReaderError) -> Self {
        PytorchStoreError::Reader(e)
    }
}

impl From<std::io::Error> for PytorchStoreError {
    fn from(e: std::io::Error) -> Self {
        PytorchStoreError::Io(e)
    }
}

/// PyTorch store for file-based storage only.
///
/// This store allows loading models from PyTorch checkpoint files (.pt/.pth)
/// with automatic weight transformation using `PyTorchToBurnAdapter`.
/// Linear weights are automatically transposed and normalization parameters
/// are renamed (gamma -> weight, beta -> bias).
///
/// Note that saving to PyTorch format is not yet supported.
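///
/// # Example
/// ```rust,no_run
/// # use burn_store::PytorchStore;
/// // A minimal configuration sketch; "checkpoint.pth", the top-level key, and
/// // the pattern below are placeholders for your own checkpoint layout.
/// let store = PytorchStore::from_file("checkpoint.pth")
///     .with_top_level_key("state_dict")
///     .with_regex(r"^encoder\..*")
///     .allow_partial(true);
/// ```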
pub struct PytorchStore {
    pub(crate) path: PathBuf,
    pub(crate) filter: PathFilter,
    pub(crate) remapper: KeyRemapper,
    pub(crate) validate: bool,
    pub(crate) allow_partial: bool,
    pub(crate) top_level_key: Option<String>,
}

impl PytorchStore {
    /// Create a store for loading from a PyTorch file.
    ///
    /// # Arguments
    /// * `path` - Path to the PyTorch checkpoint file (.pt or .pth)
    ///
    /// # Example
    /// ```rust,no_run
    /// use burn_store::PytorchStore;
    ///
    /// let store = PytorchStore::from_file("model.pth");
    /// ```
    pub fn from_file(path: impl Into<PathBuf>) -> Self {
        Self {
            path: path.into(),
            filter: PathFilter::new(),
            remapper: KeyRemapper::new(),
            validate: true,
            allow_partial: false,
            top_level_key: None,
        }
    }

    /// Set a top-level key to extract tensors from.
    ///
    /// PyTorch files often contain nested dictionaries. Use this to extract
    /// tensors from a specific top-level key like "state_dict" or "model_state_dict".
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("checkpoint.pth")
    ///     .with_top_level_key("model_state_dict");
    /// ```
    pub fn with_top_level_key(mut self, key: impl Into<String>) -> Self {
        self.top_level_key = Some(key.into());
        self
    }

    /// Filter which tensors to load.
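    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::{PathFilter, PytorchStore};
    /// // Sketch: assumes `PathFilter` is re-exported at the burn_store crate root,
    /// // like `PytorchStore` in the examples above.
    /// let filter = PathFilter::new().with_regex(r"^encoder\..*");
    /// let store = PytorchStore::from_file("model.pth").filter(filter);
    /// ```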
    pub fn filter(mut self, filter: PathFilter) -> Self {
        self.filter = filter;
        self
    }

    /// Add a regex pattern to filter tensors.
    ///
    /// Multiple patterns can be added and they work with OR logic.
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_regex(r"^encoder\..*")   // Match all encoder tensors
    ///     .with_regex(r".*\.weight$");   // OR match any weight tensors
    /// ```
    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
        self.filter = self.filter.with_regex(pattern);
        self
    }

    /// Add multiple regex patterns to filter tensors.
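    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// // Sketch: any iterator of string-like patterns works (placeholder patterns).
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_regexes([r"^encoder\..*", r"^decoder\..*"]);
    /// ```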
    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.filter = self.filter.with_regexes(patterns);
        self
    }

    /// Add an exact full path to match.
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_full_path("encoder.layer1.weight")
    ///     .with_full_path("decoder.output.bias");
    /// ```
    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
        self.filter = self.filter.with_full_path(path);
        self
    }

    /// Add multiple exact full paths to match.
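    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// // Sketch with placeholder tensor paths.
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_full_paths(["encoder.layer1.weight", "encoder.layer1.bias"]);
    /// ```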
    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.filter = self.filter.with_full_paths(paths);
        self
    }

    /// Add a predicate function for custom filtering logic.
    ///
    /// The predicate receives the tensor path and container path.
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias"));
    /// ```
    pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
        self.filter = self.filter.with_predicate(predicate);
        self
    }

    /// Add multiple predicate functions.
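    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// // Sketch: the explicit fn-pointer type lets the non-capturing closures coerce.
    /// let predicates: [fn(&str, &str) -> bool; 2] = [
    ///     |path, _| path.starts_with("encoder."),
    ///     |path, _| path.ends_with(".bias"),
    /// ];
    /// let store = PytorchStore::from_file("model.pth").with_predicates(predicates);
    /// ```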
    pub fn with_predicates<I>(mut self, predicates: I) -> Self
    where
        I: IntoIterator<Item = fn(&str, &str) -> bool>,
    {
        self.filter = self.filter.with_predicates(predicates);
        self
    }

    /// Set the filter to match all paths (disables filtering).
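    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// // Load every tensor in the file, regardless of any configured patterns.
    /// let store = PytorchStore::from_file("model.pth").match_all();
    /// ```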
    pub fn match_all(mut self) -> Self {
        self.filter = self.filter.match_all();
        self
    }

    /// Remap tensor names during load.
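    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::{KeyRemapper, PytorchStore};
    /// // Sketch: assumes `KeyRemapper` is re-exported at the burn_store crate root.
    /// let remapper = KeyRemapper::new()
    ///     .add_pattern(r"^encoder\.", "transformer.encoder.")
    ///     .expect("valid regex pattern");
    /// let store = PytorchStore::from_file("model.pth").remap(remapper);
    /// ```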
    pub fn remap(mut self, remapper: KeyRemapper) -> Self {
        self.remapper = remapper;
        self
    }

    /// Add a regex pattern to remap tensor names during load.
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_key_remapping(r"^encoder\.", "transformer.encoder.")  // encoder.X -> transformer.encoder.X
    ///     .with_key_remapping(r"\.gamma$", ".weight");                // X.gamma -> X.weight
    /// ```
    pub fn with_key_remapping(
        mut self,
        from_pattern: impl AsRef<str>,
        to_pattern: impl Into<String>,
    ) -> Self {
        self.remapper = self
            .remapper
            .add_pattern(from_pattern, to_pattern)
            .expect("Invalid regex pattern");
        self
    }

    /// Set whether to validate tensors during loading (default: true).
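    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// // With validation off, apply errors are reported in the returned
    /// // ApplyResult instead of failing the load.
    /// let store = PytorchStore::from_file("model.pth").validate(false);
    /// ```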
    pub fn validate(mut self, validate: bool) -> Self {
        self.validate = validate;
        self
    }

    /// Allow partial loading of tensors (continue even if some tensors are missing).
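    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// // Tolerate checkpoints that do not cover every module parameter.
    /// let store = PytorchStore::from_file("model.pth").allow_partial(true);
    /// ```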
    pub fn allow_partial(mut self, allow: bool) -> Self {
        self.allow_partial = allow;
        self
    }

    /// Apply filter to tensor snapshots.
    fn apply_filter(&self, mut snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {
        if self.filter.is_empty() {
            return snapshots;
        }

        snapshots.retain(|snapshot| {
            let path = snapshot.full_path();
            self.filter.matches(&path)
        });

        snapshots
    }

    /// Apply remapping to tensor snapshots.
    fn apply_remapping(&self, snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {
        if self.remapper.is_empty() {
            return snapshots;
        }

        let (remapped, _) = self.remapper.remap(snapshots);
        remapped
    }
}

impl ModuleStore for PytorchStore {
    type Error = PytorchStoreError;

    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        _module: &M,
    ) -> Result<(), Self::Error> {
        // Saving to PyTorch format is not yet supported
        Err(PytorchStoreError::Other(
            "Saving to PyTorch format is not yet supported. Use other formats for saving."
                .to_string(),
        ))
    }

    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<ApplyResult, Self::Error> {
        // Load tensors from PyTorch file
        let reader = if let Some(ref key) = self.top_level_key {
            PytorchReader::with_top_level_key(&self.path, key)?
        } else {
            PytorchReader::new(&self.path)?
        };

        // Convert to tensor snapshots
        let mut snapshots: Vec<TensorSnapshot> = reader
            .into_tensors()
            .into_iter()
            .map(|(key, mut snapshot)| {
                // Parse the key into path parts (split by '.')
                let path_parts: Vec<String> = key.split('.').map(|s| s.to_string()).collect();

                // Set the path stack from the key
                // Note: container_stack should NOT be set here - it will be managed by the module during apply
                snapshot.path_stack = Some(path_parts);
                snapshot.container_stack = None;
                snapshot.tensor_id = None;

                snapshot
            })
            .collect();

        // Apply filtering
        snapshots = self.apply_filter(snapshots);

        // Apply remapping
        snapshots = self.apply_remapping(snapshots);

        // Apply to module with PyTorchToBurnAdapter (always used for PyTorch files)
        // This adapter handles:
        // - Transposing linear weights from PyTorch format to Burn format
        // - Renaming normalization parameters (gamma -> weight, beta -> bias)
        let result = module.apply(snapshots, None, Some(Box::new(PyTorchToBurnAdapter)));

        // Validate if needed
        if self.validate && !result.errors.is_empty() {
            return Err(PytorchStoreError::ValidationFailed(format!(
                "Import errors: {:?}",
                result.errors
            )));
        }

        if !self.allow_partial && !result.missing.is_empty() {
            return Err(PytorchStoreError::TensorNotFound(format!(
                "Missing tensors: {:?}",
                result.missing
            )));
        }

        Ok(result)
    }
}