PyTorch model checkpoint support for ipfrs-tensorlogic.
This module provides functionality to load and inspect PyTorch model checkpoints (.pt/.pth files). PyTorch checkpoints are serialized with Python's pickle format and typically contain a state_dict, which maps parameter names to model weight tensors, optionally together with optimizer state.
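PyTorch state_dict keys are dotted module paths that end in a parameter or buffer name, such as encoder.0.weight. The Rust sketch below shows one plausible way to derive layer names like those reported in CheckpointMetadata: drop the last dotted component of each key and keep the first-seen order. The layer_names function is hypothetical and is not part of the crate's API; how ipfrs-tensorlogic actually groups keys is not stated on this page.

```rust
use std::collections::HashSet;

/// Hypothetical helper: derive layer names from state_dict keys by dropping
/// the final dotted component ("encoder.0.weight" -> "encoder.0").
/// Layers are returned in the order they first appear.
fn layer_names<'a>(keys: impl IntoIterator<Item = &'a str>) -> Vec<String> {
    let mut seen = HashSet::new();
    let mut layers = Vec::new();
    for key in keys {
        // A key without a dot (a top-level parameter) is treated as its own layer.
        let layer = key.rsplit_once('.').map_or(key, |(prefix, _param)| prefix);
        // HashSet::insert returns true only the first time a name is seen.
        if seen.insert(layer) {
            layers.push(layer.to_string());
        }
    }
    layers
}

fn main() {
    let keys = [
        "encoder.0.weight",
        "encoder.0.bias",
        "encoder.2.weight",
        "encoder.2.bias",
        "head.weight",
    ];
    let layers = layer_names(keys);
    println!("{layers:?}");
    assert_eq!(layers, vec!["encoder.0", "encoder.2", "head"]);
}
```

Buffers such as bn.running_mean are grouped under their module (bn) the same way as weights and biases.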
§Safety and Security
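The Python pickle format can execute arbitrary code during deserialization. To limit this risk, this module deserializes only a restricted subset of pickle, focused on the data structures that hold tensors. For maximum security, consider converting PyTorch checkpoints to the Safetensors format, which is designed to store only tensor data and metadata, so loading it does not execute code.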
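A common way to build a restricted unpickler is to refuse every import of a Python global except an allowlist of known-safe (module, name) pairs, since code-execution payloads typically import something like os.system. The sketch below assumes that approach: parse_global_arg reads the argument of the pickle GLOBAL opcode, which pickle protocols 0 to 3 encode as the module name, a newline, the global name and another newline; check_global then tests the pair against an illustrative allowlist. Both functions and the allowlist contents are hypothetical; this page does not say how ipfrs-tensorlogic decides what is safe.

```rust
/// Illustrative allowlist of (module, name) globals that tensor-only
/// PyTorch pickles commonly reference. An assumption, not the crate's list.
const ALLOWED_GLOBALS: &[(&str, &str)] = &[
    ("collections", "OrderedDict"),
    ("torch._utils", "_rebuild_tensor_v2"),
    ("torch._utils", "_rebuild_parameter"),
    ("torch", "FloatStorage"),
    ("torch", "HalfStorage"),
    ("torch", "LongStorage"),
];

#[derive(Debug, PartialEq)]
enum GlobalCheck {
    Allowed,
    Rejected(String),
}

/// Parse the argument of the pickle GLOBAL opcode (byte 'c'), which is
/// "module\nname\n". Returns (module, name, bytes consumed), or None if
/// the buffer is truncated or not valid UTF-8.
fn parse_global_arg(buf: &[u8]) -> Option<(&str, &str, usize)> {
    let first = buf.iter().position(|&b| b == b'\n')?;
    let rest = &buf[first + 1..];
    let second = rest.iter().position(|&b| b == b'\n')?;
    let module = std::str::from_utf8(&buf[..first]).ok()?;
    let name = std::str::from_utf8(&rest[..second]).ok()?;
    Some((module, name, first + 1 + second + 1))
}

/// Refuse any global that is not explicitly allowlisted.
fn check_global(module: &str, name: &str) -> GlobalCheck {
    if ALLOWED_GLOBALS.iter().any(|&(m, n)| m == module && n == name) {
        GlobalCheck::Allowed
    } else {
        GlobalCheck::Rejected(format!("{module}.{name}"))
    }
}

fn main() {
    let arg = b"torch._utils\n_rebuild_tensor_v2\n";
    let (module, name, used) = parse_global_arg(arg).expect("well-formed GLOBAL argument");
    assert_eq!(used, arg.len());
    assert_eq!(check_global(module, name), GlobalCheck::Allowed);

    // Malicious pickles classically import os.system; it is refused.
    assert_eq!(
        check_global("os", "system"),
        GlobalCheck::Rejected("os.system".to_string())
    );
}
```

An allowlist fails closed: a global the list does not know about is rejected rather than executed, at the cost of refusing some legitimate but unusual checkpoints.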
§Example
```rust
use ipfrs_tensorlogic::pytorch_checkpoint::{PyTorchCheckpoint, CheckpointMetadata};
use std::path::Path;

// Load a PyTorch checkpoint
let checkpoint = PyTorchCheckpoint::load(Path::new("model.pt"))?;

// Extract metadata
let metadata = checkpoint.metadata();
println!("Model has {} parameters", metadata.total_parameters);
println!("Layers: {:?}", metadata.layer_names);

// Get state dict
let state_dict = checkpoint.state_dict();
for (key, tensor_info) in &state_dict.tensors {
    println!("{}: {:?}", key, tensor_info.shape);
}
```
Structs§
- CheckpointMetadata - Checkpoint metadata for quick inspection.
- OptimizerState - Optimizer state containing parameter state and hyperparameters.
- ParamState - Per-parameter optimizer state (momentum, velocity, etc.).
- PyTorchCheckpoint - PyTorch checkpoint structure.
- StateDict - Model state dictionary containing named tensors.
- TensorData - Tensor data with shape and values.
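The structs listed above suggest a straightforward way to compute a total parameter count: each tensor contributes the product of its shape dimensions. The sketch below uses simplified, hypothetical stand-ins for TensorData, StateDict and CheckpointMetadata. Only the names tensors, shape and total_parameters come from this page; num_tensors, numel, is_consistent and the f32 value type are assumptions. It also checks that each flat value buffer matches its declared shape.

```rust
use std::collections::BTreeMap;

// Hypothetical, simplified mirrors of the crate's structs.
#[derive(Debug, Clone)]
struct TensorData {
    shape: Vec<usize>,
    values: Vec<f32>,
}

impl TensorData {
    /// Number of elements implied by the shape; an empty shape is a scalar (1).
    fn numel(&self) -> usize {
        self.shape.iter().product()
    }

    /// True if the flat value buffer matches the declared shape.
    fn is_consistent(&self) -> bool {
        self.values.len() == self.numel()
    }
}

#[derive(Debug, Default)]
struct StateDict {
    tensors: BTreeMap<String, TensorData>,
}

#[derive(Debug, PartialEq)]
struct CheckpointMetadata {
    total_parameters: usize,
    num_tensors: usize,
}

impl StateDict {
    /// Summarize the state dict without copying any tensor values.
    fn metadata(&self) -> CheckpointMetadata {
        CheckpointMetadata {
            total_parameters: self.tensors.values().map(TensorData::numel).sum(),
            num_tensors: self.tensors.len(),
        }
    }
}

fn main() {
    let mut sd = StateDict::default();
    // A 3 -> 2 linear layer: weight is [out_features, in_features].
    sd.tensors.insert(
        "fc.weight".to_string(),
        TensorData { shape: vec![2, 3], values: vec![0.0; 6] },
    );
    sd.tensors.insert(
        "fc.bias".to_string(),
        TensorData { shape: vec![2], values: vec![0.0; 2] },
    );
    assert!(sd.tensors.values().all(TensorData::is_consistent));

    let meta = sd.metadata();
    println!("{meta:?}");
    assert_eq!(meta.total_parameters, 8);
    assert_eq!(meta.num_tensors, 2);
}
```

Summing every entry this way also counts buffers such as BatchNorm running statistics; whether the crate's total_parameters includes them is not stated on this page.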