pub struct PytorchStore { /* private fields */ }
PyTorch store for file-based storage only.
This store allows loading models from PyTorch checkpoint files (.pt/.pth)
with automatic weight transformation using PyTorchToBurnAdapter.
Linear weights are automatically transposed and normalization parameters
are renamed (gamma -> weight, beta -> bias).
Note that saving to PyTorch format is not yet supported.
Implementations
impl PytorchStore
pub fn with_top_level_key(self, key: impl Into<String>) -> Self
Set a top-level key to extract tensors from.
PyTorch files often contain nested dictionaries. Use this to extract tensors from a specific top-level key like “state_dict” or “model_state_dict”.
Example:
let store = PytorchStore::from_file("checkpoint.pth")
    .with_top_level_key("model_state_dict");

pub fn filter(self, filter: PathFilter) -> Self
Filter which tensors to load.
pub fn with_regex<S: AsRef<str>>(self, pattern: S) -> Self
Add a regex pattern to filter tensors.
Multiple patterns can be added. They are combined with OR logic, so a tensor is selected if it matches any one of them.
Example:
let store = PytorchStore::from_file("model.pth")
    .with_regex(r"^encoder\..*") // Match all encoder tensors
    .with_regex(r".*\.weight$"); // OR match any weight tensors

pub fn with_regexes<I, S>(self, patterns: I) -> Self
Add multiple regex patterns to filter tensors.
pub fn with_full_path<S: Into<String>>(self, path: S) -> Self
Add an exact full path to match.
Example:
let store = PytorchStore::from_file("model.pth")
    .with_full_path("encoder.layer1.weight")
    .with_full_path("decoder.output.bias");

pub fn with_full_paths<I, S>(self, paths: I) -> Self
Add multiple exact full paths to match.
pub fn with_predicate(self, predicate: fn(&str, &str) -> bool) -> Self
Add a predicate function for custom filtering logic.
The predicate receives the tensor path and container path.
Example:
let store = PytorchStore::from_file("model.pth")
    .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias"));

pub fn with_predicates<I>(self, predicates: I) -> Self
Add multiple predicate functions.
pub fn remap(self, remapper: KeyRemapper) -> Self
Remap tensor names during load.
pub fn with_key_remapping(
    self,
    from_pattern: impl AsRef<str>,
    to_pattern: impl Into<String>,
) -> Self
Add a regex pattern to remap tensor names during load.
Example:
let store = PytorchStore::from_file("model.pth")
    .with_key_remapping(r"^encoder\.", "transformer.encoder.") // encoder.X -> transformer.encoder.X
    .with_key_remapping(r"\.gamma$", ".weight"); // X.gamma -> X.weight

pub fn validate(self, validate: bool) -> Self
Set whether to validate tensors during loading (default: true).
pub fn allow_partial(self, allow: bool) -> Self
Allow partial loading of tensors (continue even if some tensors are missing).