PytorchStore

Struct PytorchStore

pub struct PytorchStore { /* private fields */ }

PyTorch store for file-based storage only.

This store loads models from PyTorch checkpoint files (.pt/.pth) and transforms the weights automatically using PyTorchToBurnAdapter: Linear weights are transposed, and normalization parameters are renamed (gamma -> weight, beta -> bias).
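
For reference, PyTorch's nn.Linear stores its weight as [out_features, in_features], while Burn's Linear uses [d_input, d_output], which is why the transposition is needed. The sketch below shows only the index arithmetic of that transposition on a flat row-major buffer. transpose_2d is a hypothetical helper written for illustration, not part of burn_store; the real adapter works on tensors, not on a Vec<f32>.

```rust
/// Transpose a row-major `rows x cols` matrix into a row-major `cols x rows` one.
/// Hypothetical helper: models the layout change from a PyTorch `nn.Linear`
/// weight (`[out_features, in_features]`) to Burn's `[d_input, d_output]`.
fn transpose_2d(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    assert_eq!(data.len(), rows * cols, "buffer length must equal rows * cols");
    let mut out = vec![0.0_f32; data.len()];
    for r in 0..rows {
        for c in 0..cols {
            // Element (r, c) of the input becomes element (c, r) of the output.
            out[c * rows + r] = data[r * cols + c];
        }
    }
    out
}

fn main() {
    // Weight of a PyTorch Linear(in_features = 3, out_features = 2): shape [2, 3].
    let torch_weight: [f32; 6] = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    // Transposed to shape [3, 2], the layout Burn's Linear expects.
    let burn_weight = transpose_2d(&torch_weight, 2, 3);
    println!("{burn_weight:?}"); // [1.0, 4.0, 2.0, 5.0, 3.0, 6.0]
}
```

Transposing twice with the dimensions swapped returns the original buffer, which is a quick sanity check for this kind of layout conversion.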

Note that saving to PyTorch format is not yet supported.

Implementations

impl PytorchStore

pub fn from_file(path: impl Into<PathBuf>) -> Self

Create a store for loading from a PyTorch file.

Arguments
  • path - Path to the PyTorch checkpoint file (.pt or .pth)
Example
use burn_store::PytorchStore;

let store = PytorchStore::from_file("model.pth");
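
Because path is declared as impl Into<PathBuf>, a &str, String, &Path or PathBuf can be passed directly. The stand-in below (FileStoreSketch is hypothetical, not the real PytorchStore) mirrors that signature. Its has_pytorch_extension check is illustrative only; this page does not say whether from_file inspects the file extension.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical stand-in that mirrors the `from_file(path: impl Into<PathBuf>)` signature.
struct FileStoreSketch {
    path: PathBuf,
}

impl FileStoreSketch {
    fn from_file(path: impl Into<PathBuf>) -> Self {
        Self { path: path.into() }
    }

    /// Illustrative only: true when the file name ends in `.pt` or `.pth`.
    fn has_pytorch_extension(&self) -> bool {
        matches!(
            self.path.extension().and_then(|ext| ext.to_str()),
            Some("pt") | Some("pth")
        )
    }
}

fn main() {
    // Each argument type below converts into a PathBuf.
    let a = FileStoreSketch::from_file("model.pth");
    let b = FileStoreSketch::from_file(String::from("weights/model.pt"));
    let c = FileStoreSketch::from_file(Path::new("model.safetensors"));
    let d = FileStoreSketch::from_file(PathBuf::from("ckpt.pth"));
    println!(
        "{} {} {} {}",
        a.has_pytorch_extension(),
        b.has_pytorch_extension(),
        c.has_pytorch_extension(),
        d.has_pytorch_extension()
    ); // true true false true
}
```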

pub fn with_top_level_key(self, key: impl Into<String>) -> Self

Set a top-level key to extract tensors from.

PyTorch checkpoint files often store the model weights in a nested dictionary alongside other entries. Use this to load tensors only from the dictionary under a specific top-level key, such as “state_dict” or “model_state_dict”.

Example
let store = PytorchStore::from_file("checkpoint.pth")
    .with_top_level_key("model_state_dict");
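
The idea behind with_top_level_key can be sketched with a simplified in-memory model of an unpickled checkpoint. Value, tensors_under, and the choice to skip non-tensor entries are assumptions made for illustration; this is not how burn_store parses .pt files.

```rust
use std::collections::BTreeMap;

/// Hypothetical, simplified model of an unpickled checkpoint.
#[derive(Debug)]
enum Value {
    Int(i64),
    Tensor(Vec<f32>),
    Dict(BTreeMap<String, Value>),
}

/// Collect the tensors stored under `top_level_key`, or at the root when no key
/// is given. Non-tensor entries are skipped. Returns `None` if the root or the
/// selected entry is not a dictionary.
fn tensors_under<'a>(
    root: &'a Value,
    top_level_key: Option<&str>,
) -> Option<BTreeMap<&'a str, &'a [f32]>> {
    let Value::Dict(root_dict) = root else { return None };
    let dict = match top_level_key {
        Some(key) => match root_dict.get(key)? {
            Value::Dict(inner) => inner,
            _ => return None,
        },
        None => root_dict,
    };
    Some(
        dict.iter()
            .filter_map(|(name, value)| match value {
                Value::Tensor(data) => Some((name.as_str(), data.as_slice())),
                _ => None,
            })
            .collect(),
    )
}

fn main() {
    let mut state_dict = BTreeMap::new();
    state_dict.insert("fc.weight".to_string(), Value::Tensor(vec![0.1, 0.2]));
    state_dict.insert("fc.bias".to_string(), Value::Tensor(vec![0.0]));

    // A typical training checkpoint: the state dict sits next to other entries.
    let mut root = BTreeMap::new();
    root.insert("epoch".to_string(), Value::Int(10));
    root.insert("model_state_dict".to_string(), Value::Dict(state_dict));
    let checkpoint = Value::Dict(root);

    let tensors = tensors_under(&checkpoint, Some("model_state_dict")).unwrap();
    println!("{:?}", tensors.keys().collect::<Vec<_>>()); // ["fc.bias", "fc.weight"]
}
```

Without a key, the sketch finds no tensors at the root of this checkpoint, which is the situation the top-level key is meant to resolve.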

pub fn filter(self, filter: PathFilter) -> Self

Use a PathFilter to select which tensors to load.

pub fn with_regex<S: AsRef<str>>(self, pattern: S) -> Self

Add a regex pattern to filter tensors.

Multiple patterns can be added; they are combined with OR logic, so a tensor is selected if its path matches any of them.

Example
let store = PytorchStore::from_file("model.pth")
    .with_regex(r"^encoder\..*")  // Match all encoder tensors
    .with_regex(r".*\.weight$");   // OR match any weight tensors
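
A minimal sketch of the OR semantics of multiple patterns. Rust's standard library has no regex engine, so the two example regexes are replaced by plain anchored tests: ^encoder\..* becomes a prefix check for "encoder." and .*\.weight$ a suffix check for ".weight". PathPattern and keep are hypothetical names, not burn_store API.

```rust
/// Stand-in for a compiled regex, limited to the two anchored shapes used in the example.
enum PathPattern {
    /// Models `^<prefix>.*`
    Prefix(&'static str),
    /// Models `.*<suffix>$`
    Suffix(&'static str),
}

impl PathPattern {
    fn is_match(&self, path: &str) -> bool {
        match self {
            PathPattern::Prefix(prefix) => path.starts_with(*prefix),
            PathPattern::Suffix(suffix) => path.ends_with(*suffix),
        }
    }
}

/// OR logic: a tensor path is kept when at least one pattern matches it.
fn keep(patterns: &[PathPattern], path: &str) -> bool {
    patterns.iter().any(|pattern| pattern.is_match(path))
}

fn main() {
    // Same intent as `.with_regex(r"^encoder\..*").with_regex(r".*\.weight$")`.
    let patterns = [PathPattern::Prefix("encoder."), PathPattern::Suffix(".weight")];
    for path in ["encoder.layer1.bias", "decoder.output.weight", "decoder.output.bias"] {
        // Prints true, true, false: only the last path matches neither pattern.
        println!("{path}: {}", keep(&patterns, path));
    }
}
```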

pub fn with_regexes<I, S>(self, patterns: I) -> Self
where I: IntoIterator<Item = S>, S: AsRef<str>,

Add multiple regex patterns to filter tensors.

pub fn with_full_path<S: Into<String>>(self, path: S) -> Self

Add an exact full path to match.

Example
let store = PytorchStore::from_file("model.pth")
    .with_full_path("encoder.layer1.weight")
    .with_full_path("decoder.output.bias");

pub fn with_full_paths<I, S>(self, paths: I) -> Self
where I: IntoIterator<Item = S>, S: Into<String>,

Add multiple exact full paths to match.
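
Exact-path filtering can be modeled as a set lookup. FullPathFilter below is a hypothetical stand-in whose builder methods mirror the with_full_path and with_full_paths signatures; it shows that a full path must match character for character, so a prefix such as "encoder.layer1" does not select "encoder.layer1.weight".

```rust
use std::collections::HashSet;

/// Hypothetical stand-in whose builder methods mirror the signatures of
/// `with_full_path` and `with_full_paths`.
#[derive(Default)]
struct FullPathFilter {
    paths: HashSet<String>,
}

impl FullPathFilter {
    fn with_full_path<S: Into<String>>(mut self, path: S) -> Self {
        self.paths.insert(path.into());
        self
    }

    fn with_full_paths<I, S>(mut self, paths: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        for path in paths {
            self.paths.insert(path.into());
        }
        self
    }

    /// Exact string equality only: no prefix, suffix, or wildcard matching.
    fn matches(&self, path: &str) -> bool {
        self.paths.contains(path)
    }
}

fn main() {
    let filter = FullPathFilter::default()
        .with_full_path("encoder.layer1.weight")
        .with_full_paths(["decoder.output.bias", "decoder.output.weight"]);
    println!("{}", filter.matches("encoder.layer1.weight")); // true
    println!("{}", filter.matches("encoder.layer1")); // false: prefixes do not match
}
```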

pub fn with_predicate(self, predicate: fn(&str, &str) -> bool) -> Self

Add a predicate function for custom filtering logic.

The predicate receives the tensor path and the container path, and returns true for tensors that should be loaded.

Example
let store = PytorchStore::from_file("model.pth")
    .with_predicate(|path, _| path.starts_with("encoder.") || path.ends_with(".bias"));

pub fn with_predicates<I>(self, predicates: I) -> Self
where I: IntoIterator<Item = fn(&str, &str) -> bool>,

Add multiple predicate functions.
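
Since the predicate parameter has the function-pointer type fn(&str, &str) -> bool, it accepts named functions and closures that capture nothing, but not closures that capture local variables. The sketch below illustrates this in plain Rust. Predicate, is_bias and any_predicate are hypothetical names, and combining several predicates with OR is an assumption carried over from with_regex, not something this page states.

```rust
/// The function-pointer type taken by `with_predicate`: (tensor path, container path) -> keep?
type Predicate = fn(&str, &str) -> bool;

/// A named function fits as long as its signature matches.
fn is_bias(path: &str, _container: &str) -> bool {
    path.ends_with(".bias")
}

/// Assumption: several predicates are combined with OR, like the regex patterns.
fn any_predicate(predicates: &[Predicate], path: &str, container: &str) -> bool {
    predicates.iter().any(|predicate| predicate(path, container))
}

fn main() {
    // A closure that captures nothing coerces to a plain `fn` pointer.
    let encoder_only: Predicate = |path, _| path.starts_with("encoder.");
    let predicates: Vec<Predicate> = vec![encoder_only, is_bias];

    // A capturing closure is rejected, because only non-capturing closures coerce to `fn`:
    // let prefix = String::from("encoder.");
    // let bad: Predicate = |path, _| path.starts_with(prefix.as_str());

    println!("{}", any_predicate(&predicates, "encoder.norm.weight", "")); // true
    println!("{}", any_predicate(&predicates, "decoder.fc.weight", "")); // false
}
```

If the filter needs runtime data such as a prefix read from a config file, a regex or full-path filter is the better fit, since a fn pointer cannot carry that state.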

pub fn match_all(self) -> Self

Set the filter to match all paths (disables filtering).

pub fn remap(self, remapper: KeyRemapper) -> Self

Remap tensor names during load.

pub fn with_key_remapping(
    self,
    from_pattern: impl AsRef<str>,
    to_pattern: impl Into<String>,
) -> Self

Add a regex rule that renames tensors during loading: the part of a tensor name matching from_pattern is replaced with to_pattern.

Example
let store = PytorchStore::from_file("model.pth")
    .with_key_remapping(r"^encoder\.", "transformer.encoder.")  // encoder.X -> transformer.encoder.X
    .with_key_remapping(r"\.gamma$", ".weight");               // X.gamma -> X.weight
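
A simplified model of how the two remapping rules in the example rewrite a key. Real rules are regular expressions; here only the ^prefix and suffix$ anchors are modeled with strip_prefix and strip_suffix, and applying rules in the order they were added, each on the previous result, is an assumption. Rule and remap are hypothetical names.

```rust
/// Stand-in for one `with_key_remapping(from, to)` rule. Real rules are regexes;
/// only the `^prefix` and `suffix$` anchors from the example are modeled here.
enum Rule {
    Prefix { from: &'static str, to: &'static str },
    Suffix { from: &'static str, to: &'static str },
}

impl Rule {
    /// Rewrite `key` if the rule matches, otherwise return it unchanged.
    fn apply(&self, key: &str) -> String {
        match self {
            Rule::Prefix { from, to } => match key.strip_prefix(*from) {
                Some(rest) => format!("{to}{rest}"),
                None => key.to_string(),
            },
            Rule::Suffix { from, to } => match key.strip_suffix(*from) {
                Some(rest) => format!("{rest}{to}"),
                None => key.to_string(),
            },
        }
    }
}

/// Assumption: rules run in the order they were added, each on the previous result.
fn remap(rules: &[Rule], key: &str) -> String {
    rules.iter().fold(key.to_string(), |acc, rule| rule.apply(&acc))
}

fn main() {
    let rules = [
        Rule::Prefix { from: "encoder.", to: "transformer.encoder." },
        Rule::Suffix { from: ".gamma", to: ".weight" },
    ];
    // encoder.norm.gamma -> transformer.encoder.norm.gamma -> transformer.encoder.norm.weight
    println!("{}", remap(&rules, "encoder.norm.gamma"));
}
```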

pub fn validate(self, validate: bool) -> Self

Set whether to validate tensors during loading (default: true).

pub fn allow_partial(self, allow: bool) -> Self

Allow partial loading of tensors (continue even if some tensors are missing).
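
The effect of allow_partial can be sketched as a comparison between the parameter paths a module expects and the keys present in the checkpoint. LoadOutcome, owned and load are hypothetical; how PytorchStore actually reports missing tensors (for example through ApplyResult or PytorchStoreError) is not described on this page.

```rust
use std::collections::BTreeSet;

/// Hypothetical outcome of matching a module's parameters against a checkpoint.
#[derive(Debug, PartialEq)]
enum LoadOutcome {
    /// Loading went ahead; `missing` can be non-empty only when partial loading is allowed.
    Loaded { applied: Vec<String>, missing: Vec<String> },
    /// Loading stopped because tensors were missing and partial loading was off.
    Failed { missing: Vec<String> },
}

fn owned(names: Vec<&str>) -> Vec<String> {
    names.into_iter().map(String::from).collect()
}

/// Split the module's parameter paths into found / missing, then apply the policy.
fn load(module_params: &[&str], checkpoint_keys: &BTreeSet<&str>, allow_partial: bool) -> LoadOutcome {
    let (applied, missing): (Vec<&str>, Vec<&str>) = module_params
        .iter()
        .copied()
        .partition(|name| checkpoint_keys.contains(name));
    if missing.is_empty() || allow_partial {
        LoadOutcome::Loaded { applied: owned(applied), missing: owned(missing) }
    } else {
        LoadOutcome::Failed { missing: owned(missing) }
    }
}

fn main() {
    let module_params = ["fc1.weight", "fc1.bias", "head.weight"];
    let checkpoint = BTreeSet::from(["fc1.weight", "fc1.bias"]);
    // Fails: head.weight is missing and partial loading is off.
    println!("{:?}", load(&module_params, &checkpoint, false));
    // Goes ahead and reports head.weight as missing.
    println!("{:?}", load(&module_params, &checkpoint, true));
}
```

A typical use of partial loading is fine-tuning: a pretrained backbone is loaded while a newly added head keeps its fresh initialization.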

Trait Implementations

impl ModuleStore for PytorchStore

type Error = PytorchStoreError

The error type that can be returned during storage operations.

fn collect_from<B: Backend, M: ModuleSnapshot<B>>(
    &mut self,
    _module: &M,
) -> Result<(), Self::Error>

Collect tensor data from a module and store it to storage.

fn apply_to<B: Backend, M: ModuleSnapshot<B>>(
    &mut self,
    module: &mut M,
) -> Result<ApplyResult, Self::Error>

Load stored tensor data and apply it to a module.

Auto Trait Implementations

Blanket Implementations

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> Same for T

type Output = T

Should always be Self

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V