pub struct PytorchReader { /* private fields */ }
PyTorch checkpoint reader
This is the main interface for reading PyTorch checkpoint files (.pt/.pth). It supports multiple PyTorch formats: the modern ZIP-based format (PyTorch 1.6+), the legacy format (0.1.10-1.5), and simple pickle files.
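The format families listed above can often be told apart by their leading bytes. The following is a minimal sketch of that idea, not the reader's actual detection logic: `ContainerKind` and `sniff_container` are hypothetical names that do not come from this crate. It relies only on two well-known signatures: ZIP archives begin with `PK\x03\x04`, and pickle streams of protocol 2 or later begin with the opcode byte `0x80`. Legacy PyTorch files are themselves pickle streams, so this sketch does not try to separate them from plain pickle files.

```rust
/// Hypothetical classification used only for this sketch.
#[derive(Debug, PartialEq)]
enum ContainerKind {
    /// ZIP archive (local file header signature `PK\x03\x04`).
    Zip,
    /// Pickle stream, protocol 2 or later (PROTO opcode 0x80).
    Pickle,
    /// Anything else.
    Unknown,
}

/// Classify a file by inspecting its first bytes.
fn sniff_container(header: &[u8]) -> ContainerKind {
    if header.starts_with(b"PK\x03\x04") {
        ContainerKind::Zip
    } else if header.first() == Some(&0x80) {
        ContainerKind::Pickle
    } else {
        ContainerKind::Unknown
    }
}

fn main() {
    println!("{:?}", sniff_container(b"PK\x03\x04rest"));
    println!("{:?}", sniff_container(&[0x80, 0x02, b'}']));
    println!("{:?}", sniff_container(b"hello"));
}
```

In practice the format actually detected for a loaded file is reported through `reader.metadata().format_type`, as the example below shows.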
§Example
// Load a checkpoint file
let reader = PytorchReader::new("model.pt")?;
// Get tensor names
let keys = reader.keys();
// Access a specific tensor
if let Some(tensor) = reader.get("conv1.weight") {
    let data = tensor.to_data(); // Materializes the tensor
}
// Check file metadata
println!("Format: {:?}", reader.metadata().format_type);
println!("Tensor count: {}", reader.metadata().tensor_count);
Implementations
impl PytorchReader
pub fn with_top_level_key<P: AsRef<Path>>( path: P, key: &str, ) -> Result<Self, PytorchError>
Load a PyTorch checkpoint with a specific top-level key
Many PyTorch checkpoints store the model weights under a specific key like “state_dict”, “model”, or “model_state_dict”.
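To illustrate what selecting a top-level key means, here is a minimal sketch that models a decoded checkpoint as a nested dictionary and picks out the sub-dictionary stored under one key. `Value`, `select_top_level`, and `example_checkpoint` are hypothetical stand-ins written for this illustration; they are not types or functions of this crate, and the reader's internal representation may differ.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a decoded checkpoint value.
#[derive(Debug, Clone, PartialEq)]
enum Value {
    /// A scalar entry such as an epoch counter.
    Int(i64),
    /// A tensor leaf, represented here only by its shape.
    Tensor(Vec<usize>),
    /// A nested dictionary.
    Dict(HashMap<String, Value>),
}

/// Return the dictionary stored under `key`, or `None` if the key is
/// missing or does not hold a dictionary.
fn select_top_level<'a>(
    root: &'a HashMap<String, Value>,
    key: &str,
) -> Option<&'a HashMap<String, Value>> {
    match root.get(key) {
        Some(Value::Dict(inner)) => Some(inner),
        _ => None,
    }
}

/// Build a small checkpoint: weights under "state_dict", plus bookkeeping.
fn example_checkpoint() -> HashMap<String, Value> {
    let mut weights = HashMap::new();
    weights.insert("conv1.weight".to_string(), Value::Tensor(vec![16, 3, 3, 3]));
    let mut root = HashMap::new();
    root.insert("state_dict".to_string(), Value::Dict(weights));
    root.insert("epoch".to_string(), Value::Int(10));
    root
}

fn main() {
    let root = example_checkpoint();
    let weights = select_top_level(&root, "state_dict").expect("missing state_dict");
    println!("{:?}", weights.keys().collect::<Vec<_>>());
}
```

With such a layout, tensor names like `conv1.weight` are only reachable after descending into the chosen key, which is why the key is passed at load time.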
§Arguments
path - Path to the PyTorch file
key - Top-level key to extract (e.g., “state_dict”)
§Example
let reader = PytorchReader::with_top_level_key("checkpoint.pt", "state_dict")?;
pub fn from_reader<R: Read>( reader: R, top_level_key: Option<&str>, ) -> Result<Self, PytorchError>
Load from a reader
This method is useful when loading from non-file sources like memory buffers. Note: Metadata detection is limited when loading from a reader.
§Arguments
reader - Any type implementing Read
top_level_key - Optional key to extract
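For readers unfamiliar with in-memory `Read` sources, the sketch below shows two standard-library types that satisfy the `R: Read` bound: `std::io::Cursor` over an owned buffer, and a plain byte slice. `load_len` is a hypothetical stand-in for a loader; it only drains the stream and reports its length, and it is not part of this crate.

```rust
use std::io::{self, Cursor, Read};

/// Hypothetical stand-in for a loader that accepts any `Read` source.
/// It buffers the whole stream and returns the number of bytes read.
fn load_len<R: Read>(mut reader: R) -> io::Result<usize> {
    let mut buf = Vec::new();
    reader.read_to_end(&mut buf)?;
    Ok(buf.len())
}

fn main() -> io::Result<()> {
    // Bytes already in memory, e.g. downloaded or embedded with `include_bytes!`.
    let bytes: Vec<u8> = vec![0x80, 0x02, b'}', b'.'];

    // `Cursor` wraps an owned buffer and implements `Read`.
    println!("{}", load_len(Cursor::new(bytes.clone()))?);

    // A byte slice implements `Read` directly.
    println!("{}", load_len(bytes.as_slice())?);
    Ok(())
}
```

Going by the signature above, either value could presumably be passed as the first argument of `PytorchReader::from_reader`, with `None` or `Some("state_dict")` as the second.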
pub fn get(&self, name: &str) -> Option<&TensorSnapshot>
Get a tensor by name
pub fn tensors(&self) -> &HashMap<String, TensorSnapshot>
Get all tensors
pub fn into_tensors(self) -> HashMap<String, TensorSnapshot>
Take ownership of all tensors
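The difference between the borrowing accessor and the consuming one follows from their receivers (`&self` versus `self`). The sketch below mirrors those two signatures on a hypothetical `DemoReader` whose values are plain `Vec<f32>` rather than `TensorSnapshot`; it is meant only to show the ownership pattern, not this crate's implementation.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a reader; values are plain `Vec<f32>` here.
struct DemoReader {
    tensors: HashMap<String, Vec<f32>>,
}

impl DemoReader {
    /// Borrow the map; the reader stays usable afterwards.
    fn tensors(&self) -> &HashMap<String, Vec<f32>> {
        &self.tensors
    }

    /// Consume the reader and hand the map to the caller without cloning.
    fn into_tensors(self) -> HashMap<String, Vec<f32>> {
        self.tensors
    }
}

/// Build a reader holding a single entry.
fn demo_reader() -> DemoReader {
    let mut tensors = HashMap::new();
    tensors.insert("fc.bias".to_string(), vec![0.0, 1.0]);
    DemoReader { tensors }
}

fn main() {
    let reader = demo_reader();

    // Borrowing: inspect the map while the reader remains valid.
    let count = reader.tensors().len();
    println!("{} tensor(s)", count);

    // Ownership: move entries out, e.g. to rename keys before loading.
    let mut owned = reader.into_tensors();
    let bias = owned.remove("fc.bias");
    println!("{:?}", bias);
    // `reader` has been moved and cannot be used past this point.
}
```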
pub fn metadata(&self) -> &PytorchMetadata
Get metadata about the loaded file
Provides information about the file format, version, endianness, etc.
pub fn read_pickle_data<P: AsRef<Path>>( path: P, top_level_key: Option<&str>, ) -> Result<PickleValue, PytorchError>
Read raw pickle data from a PyTorch file
This is useful for extracting configuration or metadata that isn’t tensor data. Returns a simplified JSON-like structure that can be easily converted to other formats.
§Arguments
path - Path to the PyTorch file
top_level_key - Optional key to extract from the top-level dictionary
§Returns
A PickleValue representing the pickle data structure
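To make "JSON-like structure" concrete, the following sketch defines a small tree type and renders it as compact JSON text. `Node`, `to_json`, `quote`, and `sample` are hypothetical and written only for this illustration; the variants of the real `PickleValue` are not shown on this page and may differ. A `BTreeMap` is used so that keys are emitted in sorted order.

```rust
use std::collections::BTreeMap;

/// Hypothetical JSON-like tree; the real `PickleValue` variants may differ.
#[derive(Debug, Clone, PartialEq)]
enum Node {
    Null,
    Bool(bool),
    Int(i64),
    Str(String),
    List(Vec<Node>),
    Dict(BTreeMap<String, Node>),
}

/// Wrap a string in double quotes, escaping only backslashes and quotes.
fn quote(s: &str) -> String {
    format!("\"{}\"", s.replace('\\', "\\\\").replace('"', "\\\""))
}

/// Render a node as compact JSON text.
fn to_json(node: &Node) -> String {
    match node {
        Node::Null => "null".to_string(),
        Node::Bool(b) => b.to_string(),
        Node::Int(i) => i.to_string(),
        Node::Str(s) => quote(s),
        Node::List(items) => {
            let parts: Vec<String> = items.iter().map(to_json).collect();
            format!("[{}]", parts.join(","))
        }
        Node::Dict(map) => {
            let parts: Vec<String> = map
                .iter()
                .map(|(k, v)| format!("{}:{}", quote(k), to_json(v)))
                .collect();
            format!("{{{}}}", parts.join(","))
        }
    }
}

/// Build a small configuration-like dictionary.
fn sample() -> Node {
    let mut config = BTreeMap::new();
    config.insert("hidden_size".to_string(), Node::Int(256));
    config.insert("name".to_string(), Node::Str("tiny".to_string()));
    config.insert(
        "layers".to_string(),
        Node::List(vec![Node::Int(1), Node::Int(2)]),
    );
    Node::Dict(config)
}

fn main() {
    // Prints {"hidden_size":256,"layers":[1,2],"name":"tiny"}
    println!("{}", to_json(&sample()));
}
```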
pub fn load_config<D, P>( path: P, top_level_key: Option<&str>, ) -> Result<D, PytorchError>
Load and deserialize configuration data from a PyTorch file
This method reads configuration or metadata stored in PyTorch checkpoint files and deserializes it into the specified type. It’s particularly useful for extracting model configurations that might be saved alongside model weights.
§Arguments
path - Path to the PyTorch file (.pt or .pth)
top_level_key - Optional key to extract specific data within the pickle file. If None, the entire content is deserialized.
§Type Parameters
D - The target type to deserialize into. Must implement DeserializeOwned.
§Returns
A Result containing the deserialized configuration data, or a PytorchError if reading or deserialization fails.
§Example
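use serde::Deserialize; // assumed: the Deserialize derive comes from serde

#[derive(Debug, Deserialize)]
struct ModelConfig {
    hidden_size: usize,
    num_layers: usize,
}

// Deserialize the value stored under the "config" key into ModelConfig
let config: ModelConfig = PytorchReader::load_config("model.pth", Some("config"))?;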