pub struct PytorchStore { /* private fields */ }
PyTorch store for file-based storage only.
This store allows loading models from PyTorch checkpoint files (.pt/.pth)
with automatic weight transformation using PyTorchToBurnAdapter.
Linear weights are automatically transposed and normalization parameters
are renamed (gamma -> weight, beta -> bias).
Note that saving to PyTorch format is not yet supported.
Implementations§
impl PytorchStore
pub fn with_top_level_key(self, key: impl Into<String>) -> Self
Set a top-level key to extract tensors from.
PyTorch files often contain nested dictionaries. Use this to extract tensors from a specific top-level key like “state_dict” or “model_state_dict”.
§Example
let store = PytorchStore::from_file("checkpoint.pth")
.with_top_level_key("model_state_dict");
pub fn filter(self, filter: PathFilter) -> Self
Filter which tensors to load.
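§Example
A sketch of building the filter up front and handing it to the store; it assumes PathFilter offers a builder-style new()/with_regex(), mirroring the with_regex shortcut below:
// Keep only encoder tensors via a pre-built filter (PathFilter construction assumed)
let filter = PathFilter::new()
.with_regex(r"^encoder\..*");
let store = PytorchStore::from_file("model.pth")
.filter(filter);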
pub fn with_regex<S: AsRef<str>>(self, pattern: S) -> Self
Add a regex pattern to filter tensors.
Multiple patterns can be added; a tensor is loaded if it matches any of them (OR logic).
§Example
let store = PytorchStore::from_file("model.pth")
.with_regex(r"^encoder\..*") // Match all encoder tensors
.with_regex(r".*\.weight$"); // OR match any weight tensorsSourcepub fn with_regexes<I, S>(self, patterns: I) -> Self
pub fn with_regexes<I, S>(self, patterns: I) -> Self
Add multiple regex patterns to filter tensors.
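§Example
A sketch passing several patterns in one call, equivalent to chaining with_regex:
let store = PytorchStore::from_file("model.pth")
.with_regexes([r"^encoder\..*", r".*\.weight$"]);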
pub fn with_full_path<S: Into<String>>(self, path: S) -> Self
Add an exact full path to match.
§Example
let store = PytorchStore::from_file("model.pth")
.with_full_path("encoder.layer1.weight")
.with_full_path("decoder.output.bias");Sourcepub fn with_full_paths<I, S>(self, paths: I) -> Self
pub fn with_full_paths<I, S>(self, paths: I) -> Self
Add multiple exact full paths to match.
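§Example
A sketch passing several exact paths in one call, equivalent to chaining with_full_path:
let store = PytorchStore::from_file("model.pth")
.with_full_paths(["encoder.layer1.weight", "decoder.output.bias"]);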
pub fn with_predicate(self, predicate: fn(&str, &str) -> bool) -> Self
Add a predicate function for custom filtering logic.
The predicate receives the tensor path and container path.
§Example
let store = PytorchStore::from_file("model.pth")
.with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias"));
pub fn with_predicates<I>(self, predicates: I) -> Self
Add multiple predicate functions.
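§Example
A sketch passing predicates as plain function pointers; the explicit cast assumes the iterator item type is fn(&str, &str) -> bool, matching with_predicate above:
// Hypothetical predicate: keep everything under the encoder
fn in_encoder(path: &str, _container: &str) -> bool {
path.starts_with("encoder.")
}

let store = PytorchStore::from_file("model.pth")
.with_predicates([in_encoder as fn(&str, &str) -> bool]);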
pub fn remap(self, remapper: KeyRemapper) -> Self
Remap tensor names during load.
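§Example
A sketch of building the remapper separately; add_pattern is assumed here as a builder method on KeyRemapper, while with_key_remapping below covers the common case directly:
// Prefix encoder tensors with "transformer." (KeyRemapper builder API assumed)
let remapper = KeyRemapper::new()
.add_pattern(r"^encoder\.", "transformer.encoder.");
let store = PytorchStore::from_file("model.pth")
.remap(remapper);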
pub fn with_key_remapping(
    self,
    from_pattern: impl AsRef<str>,
    to_pattern: impl Into<String>,
) -> Self
Add a regex pattern to remap tensor names during load.
§Example
let store = PytorchStore::from_file("model.pth")
.with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X
.with_key_remapping(r"\.gamma$", ".weight"); // X.gamma -> X.weightSourcepub fn validate(self, validate: bool) -> Self
pub fn validate(self, validate: bool) -> Self
Set whether to validate tensors during loading (default: true).
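§Example
For example, validation can be turned off when loading a checkpoint you already trust:
// Skip per-tensor validation for a trusted checkpoint
let store = PytorchStore::from_file("model.pth")
.validate(false);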
pub fn allow_partial(self, allow: bool) -> Self
Allow partial loading of tensors (continue even if some tensors are missing).
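§Example
For example, when the checkpoint only covers part of the module (such as a backbone without its head):
// Continue loading even if some module tensors are absent from the file
let store = PytorchStore::from_file("model.pth")
.allow_partial(true);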
pub fn skip_enum_variants(self, skip: bool) -> Self
Skip enum variant names when matching tensor paths (default: true).
When enabled, tensor paths from PyTorch that don’t include enum variants can be matched against Burn module paths that do include them. For example, PyTorch path “feature.weight” can match Burn path “feature.BaseConv.weight”.
This defaults to true for PytorchStore since PyTorch models never include
enum variant names in their parameter paths.
§Example
// Disable enum variant skipping (not typical)
let store = PytorchStore::from_file("model.pth")
.skip_enum_variants(false);
pub fn map_indices_contiguous(self, map: bool) -> Self
Enable or disable automatic contiguous mapping of layer indices (default: true).
When enabled, non-contiguous numeric indices in tensor paths are renumbered
to be contiguous. This is useful when loading PyTorch models that have gaps
in layer numbering, such as when using nn.Sequential with mixed layer types
(e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5).
§Example
With index mapping enabled (default):
fc.0.weight → fc.0.weight
fc.2.weight → fc.1.weight (gap filled)
fc.4.weight → fc.2.weight (gap filled)
§Arguments
map - true to enable contiguous index mapping, false to disable
§Example
// Disable contiguous index mapping if your model already has contiguous indices
let store = PytorchStore::from_file("model.pth")
.map_indices_contiguous(false);