Struct PytorchReader 

pub struct PytorchReader { /* private fields */ }

PyTorch checkpoint reader

This is the main interface for reading PyTorch checkpoint files (.pt/.pth). It supports three on-disk formats: the modern ZIP-based format (PyTorch 1.6+), the legacy format (PyTorch 0.1.10 to 1.5), and plain pickle files.

§Example

// Load a checkpoint file
let reader = PytorchReader::new("model.pt")?;

// Get tensor names
let keys = reader.keys();

// Access a specific tensor
if let Some(tensor) = reader.get("conv1.weight") {
    let data = tensor.to_data(); // Materializes the tensor
}

// Check file metadata
println!("Format: {:?}", reader.metadata().format_type);
println!("Tensor count: {}", reader.metadata().tensor_count);
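The formats named in the description differ in their first bytes, which gives a rough idea of what a value like format_type has to distinguish. The following is a minimal sketch using only the standard library. SniffedFormat and sniff_format are hypothetical names introduced here for illustration; this is not how PytorchReader is documented to detect formats, and the first bytes alone cannot separate the legacy format from a plain pickle file.

```rust
/// Coarse container kinds distinguishable from the first bytes of a file.
/// Hypothetical type, not part of the documented API.
#[derive(Debug, PartialEq, Eq)]
pub enum SniffedFormat {
    /// Starts with a ZIP signature (the container used by the modern format).
    Zip,
    /// Starts with the pickle PROTO opcode (legacy format or plain pickle).
    Pickle,
    /// Neither signature matched.
    Unknown,
}

/// Hypothetical helper: classify a file by its leading bytes.
pub fn sniff_format(header: &[u8]) -> SniffedFormat {
    // ZIP archives start with the local file header signature "PK\x03\x04",
    // or with "PK\x05\x06" when the archive is empty.
    if header.starts_with(b"PK\x03\x04") || header.starts_with(b"PK\x05\x06") {
        return SniffedFormat::Zip;
    }
    // Pickle protocol 2 and later start with the PROTO opcode (0x80)
    // followed by the protocol number.
    match header {
        [0x80, proto, ..] if (2..=5).contains(proto) => SniffedFormat::Pickle,
        _ => SniffedFormat::Unknown,
    }
}
```

Reading the first few bytes of a file and passing them to sniff_format is enough for this check; the authoritative answer for a loaded file remains reader.metadata().format_type.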

Implementations§

impl PytorchReader

pub fn new<P: AsRef<Path>>(path: P) -> Result<Self, PytorchError>

Load a PyTorch checkpoint file

§Arguments
  • path - Path to the PyTorch file (.pt or .pth)
§Returns

A PytorchReader with lazy-loaded tensors and metadata

pub fn with_top_level_key<P: AsRef<Path>>(path: P, key: &str) -> Result<Self, PytorchError>

Load a PyTorch checkpoint with a specific top-level key

Many PyTorch checkpoints store the model weights under a specific key like “state_dict”, “model”, or “model_state_dict”.

§Arguments
  • path - Path to the PyTorch file
  • key - Top-level key to extract (e.g., “state_dict”)
§Example
let reader = PytorchReader::with_top_level_key("checkpoint.pt", "state_dict")?;
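When the key used by a checkpoint is not known in advance, one option is to try the common keys in order. The sketch below is generic over the loader so that it stands alone. first_loadable_key is a hypothetical helper, and it assumes that a missing key surfaces as an Err from the loader, which this page does not state explicitly for with_top_level_key.

```rust
/// Hypothetical helper: try each candidate key in order and return the first
/// one for which `load` succeeds, together with the loaded value.
pub fn first_loadable_key<'a, T, E>(
    candidates: &[&'a str],
    mut load: impl FnMut(&str) -> Result<T, E>,
) -> Option<(&'a str, T)> {
    candidates
        .iter()
        // Errors are discarded here; a real caller may want to keep the last one.
        .find_map(|&key| load(key).ok().map(|value| (key, value)))
}

// Intended use with the documented constructor (not compiled here):
//
// let found = first_loadable_key(
//     &["state_dict", "model", "model_state_dict"],
//     |key| PytorchReader::with_top_level_key("checkpoint.pt", key),
// );
```

The helper stops at the first success, so the order of the candidate list decides which key wins when a file contains more than one of them.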

pub fn from_reader<R: Read>(reader: R, top_level_key: Option<&str>) -> Result<Self, PytorchError>

Load from a reader

This method is useful when loading from non-file sources like memory buffers. Note: Metadata detection is limited when loading from a reader.

§Arguments
  • reader - Any type implementing Read
  • top_level_key - Optional key to extract
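As a sketch of what "any type implementing Read" covers for in-memory data, the standard library already provides Read for std::io::Cursor and for byte slices. count_bytes below is a hypothetical stand-in for a function with the same R: Read bound as from_reader; it only drains the reader and does no parsing.

```rust
use std::io::{self, Cursor, Read};

/// Hypothetical stand-in for a function that, like `from_reader`, accepts
/// any `R: Read`. It drains the reader and reports how many bytes it saw.
pub fn count_bytes<R: Read>(mut reader: R) -> io::Result<usize> {
    let mut buf = Vec::new();
    reader.read_to_end(&mut buf)
}

/// Wrap an owned buffer so it can be passed wherever `R: Read` is expected.
pub fn reader_from_bytes(bytes: Vec<u8>) -> Cursor<Vec<u8>> {
    Cursor::new(bytes)
}

// Intended use with the documented method (not compiled here):
//
// let reader = PytorchReader::from_reader(reader_from_bytes(bytes), None)?;
```

A borrowed &[u8] can be passed directly as well, since byte slices implement Read; the Cursor form is useful when the buffer should be owned by the reader.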

pub fn keys(&self) -> Vec<String>

Get all tensor names

pub fn get(&self, name: &str) -> Option<&TensorSnapshot>

Get a tensor by name

pub fn tensors(&self) -> &HashMap<String, TensorSnapshot>

Get all tensors
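Because tensors() returns a HashMap, iteration order is unspecified. For stable listings (logs, diffs between checkpoints) the names can be sorted first. The sketch below is generic over the value type so that it compiles on its own; sorted_keys is a hypothetical helper, and with a reader the argument would be reader.tensors().

```rust
use std::collections::HashMap;

/// Hypothetical helper: return the map's keys in lexicographic order.
pub fn sorted_keys<V>(tensors: &HashMap<String, V>) -> Vec<&str> {
    // Borrow the names rather than cloning them.
    let mut names: Vec<&str> = tensors.keys().map(String::as_str).collect();
    names.sort_unstable();
    names
}
```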

pub fn into_tensors(self) -> HashMap<String, TensorSnapshot>

Take ownership of all tensors

pub fn metadata(&self) -> &PytorchMetadata

Get metadata about the loaded file

Provides information about the file format, version, endianness, etc.

pub fn len(&self) -> usize

Get the number of tensors in the file

pub fn is_empty(&self) -> bool

Check if the file contains no tensors

pub fn read_pickle_data<P: AsRef<Path>>(path: P, top_level_key: Option<&str>) -> Result<PickleValue, PytorchError>

Read raw pickle data from a PyTorch file

This is useful for extracting configuration or metadata that isn’t tensor data. Returns a simplified JSON-like structure that can be easily converted to other formats.

§Arguments
  • path - Path to the PyTorch file
  • top_level_key - Optional key to extract from the top-level dictionary
§Returns

A PickleValue representing the pickle data structure

pub fn load_config<D, P>(path: P, top_level_key: Option<&str>) -> Result<D, PytorchError>
where D: DeserializeOwned, P: AsRef<Path>,

Load and deserialize configuration data from a PyTorch file

This method reads configuration or metadata stored in PyTorch checkpoint files and deserializes it into the specified type. It’s particularly useful for extracting model configurations that might be saved alongside model weights.

§Arguments
  • path - Path to the PyTorch file (.pt or .pth)
  • top_level_key - Optional key to extract specific data within the pickle file. If None, the entire content is deserialized.
§Type Parameters
  • D - The target type to deserialize into. Must implement DeserializeOwned.
§Returns

A Result containing the deserialized configuration data, or a PytorchError if reading or deserialization fails.

§Example
use serde::Deserialize;

#[derive(Debug, Deserialize)]
struct ModelConfig {
    hidden_size: usize,
    num_layers: usize,
}

let config: ModelConfig = PytorchReader::load_config("model.pth", Some("config"))?;

Blanket Implementations§

impl<T> Any for T
where T: 'static + ?Sized,

fn type_id(&self) -> TypeId

Gets the TypeId of self.

impl<T> Borrow<T> for T
where T: ?Sized,

fn borrow(&self) -> &T

Immutably borrows from an owned value.

impl<T> BorrowMut<T> for T
where T: ?Sized,

fn borrow_mut(&mut self) -> &mut T

Mutably borrows from an owned value.

impl<T> From<T> for T

fn from(t: T) -> T

Returns the argument unchanged.

impl<T, U> Into<U> for T
where U: From<T>,

fn into(self) -> U

Calls U::from(self).

That is, this conversion is whatever the implementation of From<T> for U chooses to do.

impl<T> Same for T

type Output = T

Should always be Self

impl<T, U> TryFrom<U> for T
where U: Into<T>,

type Error = Infallible

The type returned in the event of a conversion error.

fn try_from(value: U) -> Result<T, <T as TryFrom<U>>::Error>

Performs the conversion.

impl<T, U> TryInto<U> for T
where U: TryFrom<T>,

type Error = <U as TryFrom<T>>::Error

The type returned in the event of a conversion error.

fn try_into(self) -> Result<U, <U as TryFrom<T>>::Error>

Performs the conversion.

impl<V, T> VZip<V> for T
where V: MultiLane<T>,

fn vzip(self) -> V