burn_store/pytorch/store.rs
//! PyTorch store implementation for loading models saved in PyTorch format.
//! Saving to PyTorch format is not yet supported.

use crate::{
    ApplyResult, KeyRemapper, ModuleSnapshot, ModuleStore, PathFilter, PyTorchToBurnAdapter,
    TensorSnapshot, map_indices_contiguous,
};

use alloc::collections::BTreeMap;

use alloc::format;
use alloc::string::{String, ToString};
use alloc::vec::Vec;
use burn_tensor::backend::Backend;
use core::fmt;
use std::path::PathBuf;

use super::reader::{PytorchError as ReaderError, PytorchReader};

/// Errors that can occur during PyTorch operations.
#[derive(Debug)]
pub enum PytorchStoreError {
    /// Reader error.
    Reader(ReaderError),

    /// I/O error.
    Io(std::io::Error),

    /// Tensor not found.
    TensorNotFound(String),

    /// Validation failed.
    ValidationFailed(String),

    /// Other error.
    Other(String),
}

impl fmt::Display for PytorchStoreError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Reader(e) => write!(f, "PyTorch reader error: {}", e),
            Self::Io(e) => write!(f, "I/O error: {}", e),
            Self::TensorNotFound(name) => write!(f, "Tensor not found: {}", name),
            Self::ValidationFailed(msg) => write!(f, "Validation failed: {}", msg),
            Self::Other(msg) => write!(f, "{}", msg),
        }
    }
}

impl std::error::Error for PytorchStoreError {}

impl From<ReaderError> for PytorchStoreError {
    fn from(e: ReaderError) -> Self {
        PytorchStoreError::Reader(e)
    }
}

impl From<std::io::Error> for PytorchStoreError {
    fn from(e: std::io::Error) -> Self {
        PytorchStoreError::Io(e)
    }
}

/// PyTorch store for file-based storage only.
///
/// This store allows loading models from PyTorch checkpoint files (.pt/.pth)
/// with automatic weight transformation using `PyTorchToBurnAdapter`.
/// Linear weights are automatically transposed and normalization parameters
/// are renamed (gamma -> weight, beta -> bias).
///
/// Note that saving to PyTorch format is not yet supported.
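///
/// A typical load looks like the sketch below (marked `ignore` because it needs
/// a concrete model); `model` stands in for any Burn module that implements
/// `ModuleSnapshot`:
///
/// ```rust,ignore
/// use burn_store::{ModuleStore, PytorchStore};
///
/// let mut store = PytorchStore::from_file("model.pth")
///     .with_top_level_key("state_dict")
///     .allow_partial(true);
/// let result = store.apply_to(&mut model).expect("failed to load PyTorch checkpoint");
/// println!("{}", result);
/// ```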
pub struct PytorchStore {
    pub(crate) path: PathBuf,
    pub(crate) filter: PathFilter,
    pub(crate) remapper: KeyRemapper,
    pub(crate) validate: bool,
    pub(crate) allow_partial: bool,
    pub(crate) top_level_key: Option<String>,
    pub(crate) skip_enum_variants: bool,
    /// Enable contiguous mapping of layer indices (default: true)
    pub(crate) map_indices_contiguous: bool,
    /// Cached tensor snapshots (parsed once, reused)
    snapshots_cache: Option<BTreeMap<String, TensorSnapshot>>,
}

impl PytorchStore {
    /// Create a store for loading from a PyTorch file.
    ///
    /// # Arguments
    /// * `path` - Path to the PyTorch checkpoint file (.pt or .pth)
    ///
    /// # Example
    /// ```rust,no_run
    /// use burn_store::PytorchStore;
    ///
    /// let store = PytorchStore::from_file("model.pth");
    /// ```
    pub fn from_file(path: impl Into<PathBuf>) -> Self {
        Self {
            path: path.into(),
            filter: PathFilter::new(),
            remapper: KeyRemapper::new(),
            validate: true,
            allow_partial: false,
            top_level_key: None,
            // PyTorch models never include enum variant names in paths
            skip_enum_variants: true,
            // Enable contiguous index mapping by default for PyTorch files
            // This handles nn.Sequential models with gaps in layer indices
            map_indices_contiguous: true,
            snapshots_cache: None,
        }
    }

    /// Set a top-level key to extract tensors from.
    ///
    /// PyTorch files often contain nested dictionaries. Use this to extract
    /// tensors from a specific top-level key like "state_dict" or "model_state_dict".
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("checkpoint.pth")
    ///     .with_top_level_key("model_state_dict");
    /// ```
    pub fn with_top_level_key(mut self, key: impl Into<String>) -> Self {
        self.top_level_key = Some(key.into());
        self
    }

    /// Filter which tensors to load.
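    ///
    /// For example, building a [`PathFilter`] separately and passing it in
    /// (the pattern below is illustrative):
    ///
    /// ```rust,no_run
    /// # use burn_store::{PathFilter, PytorchStore};
    /// let filter = PathFilter::new().with_regex(r"^encoder\..*");
    /// let store = PytorchStore::from_file("model.pth").filter(filter);
    /// ```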
    pub fn filter(mut self, filter: PathFilter) -> Self {
        self.filter = filter;
        self
    }

    /// Add a regex pattern to filter tensors.
    ///
    /// Multiple patterns can be added and they work with OR logic.
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_regex(r"^encoder\..*") // Match all encoder tensors
    ///     .with_regex(r".*\.weight$"); // OR match any weight tensors
    /// ```
    pub fn with_regex<S: AsRef<str>>(mut self, pattern: S) -> Self {
        self.filter = self.filter.with_regex(pattern);
        self
    }

    /// Add multiple regex patterns to filter tensors.
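    ///
    /// For example, passing several patterns at once (patterns are illustrative
    /// and combined with OR logic, just like repeated [`Self::with_regex`] calls):
    ///
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_regexes([r"^encoder\..*", r"^decoder\..*"]);
    /// ```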
    pub fn with_regexes<I, S>(mut self, patterns: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        self.filter = self.filter.with_regexes(patterns);
        self
    }

    /// Add an exact full path to match.
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_full_path("encoder.layer1.weight")
    ///     .with_full_path("decoder.output.bias");
    /// ```
    pub fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
        self.filter = self.filter.with_full_path(path);
        self
    }

    /// Add multiple exact full paths to match.
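    ///
    /// For example, listing the exact tensor paths to load (paths are illustrative):
    ///
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_full_paths(["encoder.layer1.weight", "decoder.output.bias"]);
    /// ```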
    pub fn with_full_paths<I, S>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.filter = self.filter.with_full_paths(paths);
        self
    }

    /// Add a predicate function for custom filtering logic.
    ///
    /// The predicate receives the tensor path and container path.
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias"));
    /// ```
    pub fn with_predicate(mut self, predicate: fn(&str, &str) -> bool) -> Self {
        self.filter = self.filter.with_predicate(predicate);
        self
    }

    /// Add multiple predicate functions.
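    ///
    /// For example, registering several predicate functions in one call (the
    /// predicates below are illustrative):
    ///
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// fn is_encoder(path: &str, _container: &str) -> bool {
    ///     path.starts_with("encoder.")
    /// }
    /// fn is_bias(path: &str, _container: &str) -> bool {
    ///     path.ends_with(".bias")
    /// }
    /// let predicates: [fn(&str, &str) -> bool; 2] = [is_encoder, is_bias];
    /// let store = PytorchStore::from_file("model.pth").with_predicates(predicates);
    /// ```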
    pub fn with_predicates<I>(mut self, predicates: I) -> Self
    where
        I: IntoIterator<Item = fn(&str, &str) -> bool>,
    {
        self.filter = self.filter.with_predicates(predicates);
        self
    }

    /// Set the filter to match all paths (disables filtering).
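    ///
    /// For example, explicitly loading every tensor in the file:
    ///
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth").match_all();
    /// ```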
    pub fn match_all(mut self) -> Self {
        self.filter = self.filter.match_all();
        self
    }

    /// Remap tensor names during load.
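    ///
    /// For example, building a [`KeyRemapper`] separately and passing it in
    /// (the pattern below is illustrative):
    ///
    /// ```rust,no_run
    /// # use burn_store::{KeyRemapper, PytorchStore};
    /// let remapper = KeyRemapper::new()
    ///     .add_pattern(r"^encoder\.", "transformer.encoder.")
    ///     .expect("valid regex pattern");
    /// let store = PytorchStore::from_file("model.pth").remap(remapper);
    /// ```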
    pub fn remap(mut self, remapper: KeyRemapper) -> Self {
        self.remapper = remapper;
        self
    }

    /// Add a regex pattern to remap tensor names during load.
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth")
    ///     .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X
    ///     .with_key_remapping(r"\.gamma$", ".weight"); // X.gamma -> X.weight
    /// ```
    pub fn with_key_remapping(
        mut self,
        from_pattern: impl AsRef<str>,
        to_pattern: impl Into<String>,
    ) -> Self {
        self.remapper = self
            .remapper
            .add_pattern(from_pattern, to_pattern)
            .expect("Invalid regex pattern");
        self
    }

    /// Set whether to validate tensors during loading (default: true).
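    ///
    /// When validation is disabled, apply errors are reported in the returned
    /// [`ApplyResult`] instead of failing the load. For example:
    ///
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth").validate(false);
    /// ```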
    pub fn validate(mut self, validate: bool) -> Self {
        self.validate = validate;
        self
    }

    /// Allow partial loading of tensors (continue even if some tensors are missing).
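    ///
    /// For example, loading a checkpoint that only covers part of the module;
    /// missing tensors are reported in the returned [`ApplyResult`] rather than
    /// causing an error:
    ///
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// let store = PytorchStore::from_file("model.pth").allow_partial(true);
    /// ```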
    pub fn allow_partial(mut self, allow: bool) -> Self {
        self.allow_partial = allow;
        self
    }

    /// Skip enum variant names when matching tensor paths (default: true).
    ///
    /// When enabled, tensor paths from PyTorch that don't include enum variants
    /// can be matched against Burn module paths that do include them.
    /// For example, PyTorch path "feature.weight" can match Burn path "feature.BaseConv.weight".
    ///
    /// This defaults to `true` for PytorchStore since PyTorch models never include
    /// enum variant names in their parameter paths.
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// // Disable enum variant skipping (not typical)
    /// let store = PytorchStore::from_file("model.pth")
    ///     .skip_enum_variants(false);
    /// ```
    pub fn skip_enum_variants(mut self, skip: bool) -> Self {
        self.skip_enum_variants = skip;
        self
    }

    /// Enable or disable automatic contiguous mapping of layer indices (default: true).
    ///
    /// When enabled, non-contiguous numeric indices in tensor paths are renumbered
    /// to be contiguous. This is useful when loading PyTorch models that have gaps
    /// in layer numbering, such as when using `nn.Sequential` with mixed layer types
    /// (e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5).
    ///
    /// With index mapping enabled (the default), gaps are filled in:
    /// - `fc.0.weight` → `fc.0.weight`
    /// - `fc.2.weight` → `fc.1.weight` (gap filled)
    /// - `fc.4.weight` → `fc.2.weight` (gap filled)
    ///
    /// # Arguments
    ///
    /// * `map` - `true` to enable contiguous index mapping, `false` to disable
    ///
    /// # Example
    /// ```rust,no_run
    /// # use burn_store::PytorchStore;
    /// // Disable contiguous index mapping if your model already has contiguous indices
    /// let store = PytorchStore::from_file("model.pth")
    ///     .map_indices_contiguous(false);
    /// ```
    pub fn map_indices_contiguous(mut self, map: bool) -> Self {
        self.map_indices_contiguous = map;
        self
    }

    /// Apply remapping to tensor snapshots.
    fn apply_remapping(&self, snapshots: Vec<TensorSnapshot>) -> Vec<TensorSnapshot> {
        if self.remapper.is_empty() {
            return snapshots;
        }

        let (remapped, _) = self.remapper.remap(snapshots);
        remapped
    }

    /// Create a PytorchReader for the configured path and options.
    fn create_reader(&self) -> Result<PytorchReader, PytorchStoreError> {
        let reader = if let Some(ref key) = self.top_level_key {
            PytorchReader::with_top_level_key(&self.path, key)?
        } else {
            PytorchReader::new(&self.path)?
        };
        Ok(reader)
    }
}

impl ModuleStore for PytorchStore {
    type Error = PytorchStoreError;

    fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        _module: &M,
    ) -> Result<(), Self::Error> {
        // Saving to PyTorch format is not yet supported
        Err(PytorchStoreError::Other(
            "Saving to PyTorch format is not yet supported. Use other formats for saving."
                .to_string(),
        ))
    }

    fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
        &mut self,
        module: &mut M,
    ) -> Result<ApplyResult, Self::Error> {
        // Get snapshots from cache
        let snapshots: Vec<TensorSnapshot> = self.get_all_snapshots()?.values().cloned().collect();

        // Get filter (convert to Option for apply)
        let filter_opt = if self.filter.is_empty() {
            None
        } else {
            Some(self.filter.clone())
        };

        // Apply to module with PyTorchToBurnAdapter (always used for PyTorch files)
        // This adapter handles:
        // - Transposing linear weights from PyTorch format to Burn format
        // - Renaming normalization parameters (gamma -> weight, beta -> bias)
        // Filter is applied here during apply, not during cache population
        let result = module.apply(
            snapshots,
            filter_opt,
            Some(Box::new(PyTorchToBurnAdapter)),
            self.skip_enum_variants,
        );

        // Validate if needed
        if self.validate && !result.errors.is_empty() {
            return Err(PytorchStoreError::ValidationFailed(format!(
                "Import errors:\n{}",
                result
            )));
        }

        if !self.allow_partial && !result.missing.is_empty() {
            return Err(PytorchStoreError::TensorNotFound(format!("\n{}", result)));
        }

        Ok(result)
    }

    fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error> {
        self.ensure_snapshots_cache()?;
        Ok(self.snapshots_cache.as_ref().unwrap().get(name))
    }

    fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error> {
        self.ensure_snapshots_cache()?;
        Ok(self.snapshots_cache.as_ref().unwrap())
    }

    fn keys(&mut self) -> Result<Vec<String>, Self::Error> {
        // Always use the cache to ensure remapping is applied consistently
        Ok(self.get_all_snapshots()?.keys().cloned().collect())
    }
}

impl PytorchStore {
    /// Ensure the snapshots cache is populated
    fn ensure_snapshots_cache(&mut self) -> Result<(), PytorchStoreError> {
        if self.snapshots_cache.is_some() {
            return Ok(());
        }

        let reader = self.create_reader()?;

        // Convert to tensor snapshots
        let mut snapshots: Vec<TensorSnapshot> = reader
            .into_tensors()
            .into_iter()
            .map(|(key, mut snapshot)| {
                // Parse the key into path parts (split by '.')
                let path_parts: Vec<String> = key.split('.').map(|s| s.to_string()).collect();

                // Set the path stack from the key
                snapshot.path_stack = Some(path_parts);
                snapshot.container_stack = None;
                snapshot.tensor_id = None;

                snapshot
            })
            .collect();

        // Apply remapping (but NOT filtering - that's done at apply time)
        snapshots = self.apply_remapping(snapshots);

        // Apply contiguous index mapping if enabled.
        // This must run after remapping so that the mapping operates on the remapped paths.
        if self.map_indices_contiguous {
            let (mapped, _) = map_indices_contiguous(snapshots);
            snapshots = mapped;
        }

        // Build cache as BTreeMap
        let cache: BTreeMap<String, TensorSnapshot> =
            snapshots.into_iter().map(|s| (s.full_path(), s)).collect();

        self.snapshots_cache = Some(cache);
        Ok(())
    }
}