Struct PytorchStore

pub struct PytorchStore { /* private fields */ }

PyTorch store for file-based storage only.

This store loads models from PyTorch checkpoint files (.pt/.pth), applying automatic weight transformation via PyTorchToBurnAdapter: linear weights are transposed and normalization parameters are renamed (gamma -> weight, beta -> bias).

Note that saving to PyTorch format is not yet supported.

Implementations§

impl PytorchStore

pub fn from_file(path: impl Into<PathBuf>) -> Self

Create a store for loading from a PyTorch file.

§Arguments
  • path - Path to the PyTorch checkpoint file (.pt or .pth)
§Example
use burn_store::PytorchStore;

let store = PytorchStore::from_file("model.pth");

pub fn with_top_level_key(self, key: impl Into<String>) -> Self

Set a top-level key to extract tensors from.

PyTorch files often contain nested dictionaries. Use this to extract tensors from a specific top-level key like “state_dict” or “model_state_dict”.

§Example
let store = PytorchStore::from_file("checkpoint.pth")
    .with_top_level_key("model_state_dict");

pub fn filter(self, filter: PathFilter) -> Self

Filter which tensors to load.

pub fn with_regex<S: AsRef<str>>(self, pattern: S) -> Self

Add a regex pattern to filter tensors.

Multiple patterns can be added; they are combined with OR logic, so a tensor is loaded if any pattern matches.

§Example
let store = PytorchStore::from_file("model.pth")
    .with_regex(r"^encoder\..*")  // Match all encoder tensors
    .with_regex(r".*\.weight$");   // OR match any weight tensors
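
The OR semantics above can be sketched without the real PathFilter type. The following standalone Rust sketch (hypothetical `Pattern` alias, not the crate's API) uses plain prefix/suffix checks in place of the two regexes:

```rust
// Sketch of OR filter semantics (hypothetical types; the real
// implementation lives in burn_store's PathFilter).
// Each "pattern" is a closure; a path is kept if ANY pattern matches.

type Pattern = Box<dyn Fn(&str) -> bool>;

fn matches_any(patterns: &[Pattern], path: &str) -> bool {
    patterns.iter().any(|p| p(path))
}

fn main() {
    // Stand-ins for ^encoder\..* and .*\.weight$ from the example above.
    let patterns: Vec<Pattern> = vec![
        Box::new(|p: &str| p.starts_with("encoder.")),
        Box::new(|p: &str| p.ends_with(".weight")),
    ];

    assert!(matches_any(&patterns, "encoder.layer1.bias")); // first pattern
    assert!(matches_any(&patterns, "decoder.out.weight"));  // second pattern
    assert!(!matches_any(&patterns, "decoder.out.bias"));   // neither
}
```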

pub fn with_regexes<I, S>(self, patterns: I) -> Self
where I: IntoIterator<Item = S>, S: AsRef<str>,

Add multiple regex patterns to filter tensors.

pub fn with_full_path<S: Into<String>>(self, path: S) -> Self

Add an exact full path to match.

§Example
let store = PytorchStore::from_file("model.pth")
    .with_full_path("encoder.layer1.weight")
    .with_full_path("decoder.output.bias");

pub fn with_full_paths<I, S>(self, paths: I) -> Self
where I: IntoIterator<Item = S>, S: Into<String>,

Add multiple exact full paths to match.

pub fn with_predicate(self, predicate: fn(&str, &str) -> bool) -> Self

Add a predicate function for custom filtering logic.

The predicate receives the tensor path and container path.

§Example
let store = PytorchStore::from_file("model.pth")
    .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias"));

pub fn with_predicates<I>(self, predicates: I) -> Self
where I: IntoIterator<Item = fn(&str, &str) -> bool>,

Add multiple predicate functions.

pub fn match_all(self) -> Self

Set the filter to match all paths (disables filtering).

pub fn remap(self, remapper: KeyRemapper) -> Self

Remap tensor names during load.

pub fn with_key_remapping(self, from_pattern: impl AsRef<str>, to_pattern: impl Into<String>) -> Self

Add a regex pattern to remap tensor names during load.

§Example
let store = PytorchStore::from_file("model.pth")
    .with_key_remapping(r"^encoder\.", "transformer.encoder.")  // encoder.X -> transformer.encoder.X
    .with_key_remapping(r"\.gamma$", ".weight");               // X.gamma -> X.weight

pub fn validate(self, validate: bool) -> Self

Set whether to validate tensors during loading (default: true).

pub fn allow_partial(self, allow: bool) -> Self

Allow partial loading of tensors (continue even if some tensors are missing).

pub fn skip_enum_variants(self, skip: bool) -> Self

Skip enum variant names when matching tensor paths (default: true).

When enabled, tensor paths from PyTorch that don’t include enum variants can be matched against Burn module paths that do include them. For example, PyTorch path “feature.weight” can match Burn path “feature.BaseConv.weight”.

This defaults to true for PytorchStore since PyTorch models never include enum variant names in their parameter paths.

§Example
// Disable enum variant skipping (not typical)
let store = PytorchStore::from_file("model.pth")
    .skip_enum_variants(false);
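
The matching idea can be illustrated with a rough standalone sketch (this is NOT the store's actual algorithm; it assumes variant segments are CamelCase and simply drops segments starting with an uppercase letter before comparing):

```rust
// Rough sketch of enum-variant skipping. Assumption: variant names
// like "BaseConv" are CamelCase, while parameter path segments are
// lowercase, so uppercase-initial segments are dropped before comparing.

fn strip_variant_segments(path: &str) -> String {
    path.split('.')
        .filter(|seg| !seg.chars().next().map_or(false, |c| c.is_ascii_uppercase()))
        .collect::<Vec<_>>()
        .join(".")
}

fn paths_match(pytorch_path: &str, burn_path: &str) -> bool {
    pytorch_path == burn_path || pytorch_path == strip_variant_segments(burn_path)
}

fn main() {
    // The example from the docs: "feature.weight" vs "feature.BaseConv.weight"
    assert!(paths_match("feature.weight", "feature.BaseConv.weight"));
    assert!(paths_match("feature.weight", "feature.weight"));
    assert!(!paths_match("feature.weight", "other.BaseConv.weight"));
}
```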

pub fn map_indices_contiguous(self, map: bool) -> Self

Enable or disable automatic contiguous mapping of layer indices (default: true).

When enabled, non-contiguous numeric indices in tensor paths are renumbered to be contiguous. This is useful when loading PyTorch models that have gaps in layer numbering, such as when using nn.Sequential with mixed layer types (e.g., Conv2d layers at indices 0, 2, 4 with ReLU layers at 1, 3, 5).

§Example

With index mapping enabled (default):

  • fc.0.weight -> fc.0.weight (unchanged)
  • fc.2.weight -> fc.1.weight (gap filled)
  • fc.4.weight -> fc.2.weight (gap filled)
§Arguments
  • map - true to enable contiguous index mapping, false to disable
§Example
// Disable contiguous index mapping if your model already has contiguous indices
let store = PytorchStore::from_file("model.pth")
    .map_indices_contiguous(false);
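
The renumbering described above can be sketched in standalone Rust (an illustration of the idea, not the store's actual implementation): for each parent prefix, the distinct numeric indices are sorted and renumbered 0, 1, 2, ...

```rust
use std::collections::BTreeMap;

// Sketch of contiguous index remapping. Per parent prefix ("fc"),
// collect the distinct indices seen, sort them, and replace each
// index with its position in the sorted list.

fn remap_contiguous(paths: &[&str]) -> Vec<String> {
    // Distinct indices under each parent prefix.
    let mut indices: BTreeMap<String, Vec<usize>> = BTreeMap::new();
    for path in paths {
        if let Some((prefix, idx, _rest)) = split_indexed(path) {
            let entry = indices.entry(prefix).or_default();
            if !entry.contains(&idx) {
                entry.push(idx);
            }
        }
    }
    for v in indices.values_mut() {
        v.sort_unstable();
    }
    // Rewrite each path with the index's position in the sorted list.
    paths
        .iter()
        .map(|path| match split_indexed(path) {
            Some((prefix, idx, rest)) => {
                let new_idx = indices[&prefix].iter().position(|&i| i == idx).unwrap();
                format!("{prefix}.{new_idx}.{rest}")
            }
            None => path.to_string(),
        })
        .collect()
}

// Split "fc.2.weight" into ("fc", 2, "weight"); None if no numeric segment.
fn split_indexed(path: &str) -> Option<(String, usize, String)> {
    let segs: Vec<&str> = path.split('.').collect();
    let pos = segs.iter().position(|s| s.parse::<usize>().is_ok())?;
    Some((
        segs[..pos].join("."),
        segs[pos].parse().unwrap(),
        segs[pos + 1..].join("."),
    ))
}

fn main() {
    let remapped = remap_contiguous(&["fc.0.weight", "fc.2.weight", "fc.4.weight"]);
    assert_eq!(remapped, ["fc.0.weight", "fc.1.weight", "fc.2.weight"]);
}
```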

Trait Implementations§

impl ModuleStore for PytorchStore

type Error = PytorchStoreError

The error type that can be returned during storage operations.

fn collect_from<B: Backend, M: ModuleSnapshot<B>>(&mut self, _module: &M) -> Result<(), Self::Error>

Collect tensor data from a module and store it to storage.

fn apply_to<B: Backend, M: ModuleSnapshot<B>>(&mut self, module: &mut M) -> Result<ApplyResult, Self::Error>

Load stored tensor data and apply it to a module.

fn get_snapshot(&mut self, name: &str) -> Result<Option<&TensorSnapshot>, Self::Error>

Get a single tensor snapshot by name.

fn get_all_snapshots(&mut self) -> Result<&BTreeMap<String, TensorSnapshot>, Self::Error>

Get all tensor snapshots from storage as an ordered map.

fn keys(&mut self) -> Result<Vec<String>, Self::Error>

Get all tensor names/keys in storage.

Auto Trait Implementations§

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self). That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> Same for T

type Output = T

Should always be Self

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V