//! PyTorch format support for burn-store.
//!
//! This module provides comprehensive support for loading PyTorch model files (.pth, .pt)
//! into Burn, with automatic weight transformation and flexible configuration options.
//!
//! ## Features
//!
//! - **Direct .pth/.pt file loading**: Load PyTorch checkpoint and state dict files
//! - **Automatic weight transformation**: `PyTorchToBurnAdapter` is applied by
//!   default (illustrated after this list):
//!   - Linear layer weights are automatically transposed
//!   - Normalization parameters are renamed (gamma → weight, beta → bias)
//!   - Conv2d weights keep their original layout (no transformation needed)
//! - **Flexible filtering**: Load only specific layers or parameters
//! - **Key remapping**: Rename tensors during loading to match your model structure
//! - **Partial loading**: Continue even when some tensors are missing
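//!
//! As a quick illustration of the default adapter's effect (a sketch of the
//! key/shape mapping, not the adapter's actual code; a PyTorch `nn.Linear`
//! stores its weight as `[out_features, in_features]`, while Burn's `Linear`
//! expects `[in_features, out_features]`):
//!
//! ```rust,ignore
//! // PyTorch checkpoint                 after PyTorchToBurnAdapter
//! // fc.weight: [out=8, in=4]    ->     fc.weight: [4, 8]  (transposed)
//! // bn.gamma:  [8]              ->     bn.weight: [8]     (renamed)
//! // bn.beta:   [8]              ->     bn.bias:   [8]     (renamed)
//! ```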
//!
//! ## Example
//!
//! ```rust,ignore
//! use burn_store::PytorchStore;
//!
//! // Load a PyTorch model (PyTorchToBurnAdapter is applied automatically)
//! let mut store = PytorchStore::from_file("model.pth")
//!     .with_top_level_key("state_dict")        // Access nested state dict
//!     .with_regex(r"^encoder\..*")             // Only load encoder layers
//!     .with_key_remapping(r"^fc\.", "linear.") // Rename fc -> linear
//!     .allow_partial(true);                    // Skip missing tensors
//!
//! let mut model = MyModel::new(&device);
//! let result = model.load_from(&mut store)?;
//!
//! println!("Loaded {} tensors", result.applied.len());
//! if !result.missing.is_empty() {
//! println!("Missing tensors: {:?}", result.missing);
//! }
//! ```
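//!
//! The `MyModel` above is whatever Burn module you are loading into. As a
//! minimal, hypothetical sketch (the field name `linear` matches the key
//! produced by the `fc` -> `linear` remapping; the 784/10 dimensions are
//! placeholders and must match the checkpoint's weights):
//!
//! ```rust,ignore
//! use burn::{
//!     module::Module,
//!     nn::{Linear, LinearConfig},
//!     tensor::backend::Backend,
//! };
//!
//! #[derive(Module, Debug)]
//! struct MyModel<B: Backend> {
//!     linear: Linear<B>,
//! }
//!
//! impl<B: Backend> MyModel<B> {
//!     fn new(device: &B::Device) -> Self {
//!         Self {
//!             // Dimensions must match the checkpoint's (transposed) weight.
//!             linear: LinearConfig::new(784, 10).init(device),
//!         }
//!     }
//! }
//! ```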
// Main public interface
// NOTE: the re-export paths were truncated in the source; the `store` module
// name and the error re-export below are assumptions based on the documented API.
pub use store::PytorchStore;
pub use store::PytorchStoreError;