Module pytorch_checkpoint

PyTorch model checkpoint support for ipfrs-tensorlogic.

This module provides functionality to load and work with PyTorch model checkpoints (`.pt`/`.pth` files). PyTorch checkpoints are serialized with Python's pickle protocol and typically contain a `state_dict` of model weights, optionally together with optimizer state.
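The on-disk layout differs by PyTorch version. Since 1.6, `torch.save` writes a ZIP archive that wraps the pickle stream together with the raw tensor data, while older versions write a bare pickle stream. A loader therefore usually has to tell the two apart before parsing. The sketch below does that from the leading bytes only. `CheckpointFormat` and `detect_format` are hypothetical names used for illustration, not part of the ipfrs-tensorlogic API. The sketch checks two things: the ZIP local-file-header signature, and the pickle `PROTO` opcode (`0x80`), which pickle protocol 2 and later emit first.

```rust
/// Rough classification of a checkpoint by its first bytes.
/// Hypothetical helper for illustration; not part of ipfrs-tensorlogic.
#[derive(Debug, PartialEq, Eq)]
enum CheckpointFormat {
    /// ZIP archive, as written by `torch.save` in PyTorch 1.6 and later.
    Zip,
    /// Bare pickle stream; `protocol` is the byte after the PROTO opcode.
    LegacyPickle { protocol: u8 },
    /// Neither signature matched.
    Unknown,
}

fn detect_format(header: &[u8]) -> CheckpointFormat {
    // ZIP local file header signature: "PK\x03\x04".
    if header.starts_with(b"PK\x03\x04") {
        CheckpointFormat::Zip
    } else if header.len() >= 2 && header[0] == 0x80 {
        // Pickle protocol 2+ starts with PROTO (0x80) followed by the version byte.
        CheckpointFormat::LegacyPickle { protocol: header[1] }
    } else {
        CheckpointFormat::Unknown
    }
}
```

In practice only the first few bytes of the file need to be read before choosing a parsing path.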

Safety and Security

The Python pickle format can execute arbitrary code during deserialization. This module therefore implements only a safe subset of pickle deserialization, focused on tensor data structures. For maximum security, consider converting PyTorch checkpoints to the Safetensors format.
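A common way to build such a restricted unpickler is an allowlist. Every `GLOBAL` or `STACK_GLOBAL` opcode names a `module.name` pair that a full unpickler would import. A restricted one refuses any pair it does not recognise, so a payload that references something like `os.system` never obtains a callable to invoke. The sketch below shows only that check. `ALLOWED_GLOBALS` and `check_global` are hypothetical. The listed names do occur in PyTorch checkpoints, but this is not the list ipfrs-tensorlogic actually uses.

```rust
/// Hypothetical allowlist of (module, name) pairs a restricted loader might accept.
/// Illustrative only; a real list depends on which checkpoints must load.
const ALLOWED_GLOBALS: &[(&str, &str)] = &[
    ("collections", "OrderedDict"),
    ("torch._utils", "_rebuild_tensor_v2"),
    ("torch._utils", "_rebuild_parameter"),
    ("torch", "FloatStorage"),
    ("torch", "HalfStorage"),
    ("torch", "LongStorage"),
];

/// Called when the unpickler meets a GLOBAL / STACK_GLOBAL opcode.
/// Anything not on the allowlist is rejected instead of being resolved.
fn check_global(module: &str, name: &str) -> Result<(), String> {
    if ALLOWED_GLOBALS.iter().any(|&(m, n)| m == module && n == name) {
        Ok(())
    } else {
        Err(format!("refusing to load global {module}.{name}"))
    }
}
```

The idea matches overriding `find_class` in Python's `pickle.Unpickler`, and it is also the approach behind `torch.load(weights_only=True)`.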

Example

```rust
use ipfrs_tensorlogic::pytorch_checkpoint::PyTorchCheckpoint;
use std::path::Path;

// `?` needs an enclosing function that returns `Result`; this assumes the
// crate's error type converts into `Box<dyn std::error::Error>`.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Load a PyTorch checkpoint
    let checkpoint = PyTorchCheckpoint::load(Path::new("model.pt"))?;

    // Extract metadata
    let metadata = checkpoint.metadata();
    println!("Model has {} parameters", metadata.total_parameters);
    println!("Layers: {:?}", metadata.layer_names);

    // Iterate over the named tensors in the state dict
    let state_dict = checkpoint.state_dict();
    for (key, tensor_info) in &state_dict.tensors {
        println!("{}: {:?}", key, tensor_info.shape);
    }

    Ok(())
}
```
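A parameter count like the one printed above can be derived from tensor shapes alone: each tensor contributes the product of its dimensions, and the totals are summed. The sketch below assumes shapes are available as a map from tensor name to `Vec<usize>`. `count_parameters` is a hypothetical helper, and this is not necessarily how `CheckpointMetadata::total_parameters` is computed.

```rust
use std::collections::BTreeMap;

/// Sum over tensors of the product of their dimensions.
/// An empty shape (a scalar) has product 1, so it counts as one element.
fn count_parameters(shapes: &BTreeMap<String, Vec<usize>>) -> usize {
    shapes
        .values()
        .map(|shape| shape.iter().product::<usize>())
        .sum()
}
```

A `state_dict` also holds buffers such as BatchNorm's `running_mean` and `num_batches_tracked`. A plain sum over every entry counts those as well, so a stricter count would need to filter them out by name.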

Structs

CheckpointMetadata
Checkpoint metadata for quick inspection.
OptimizerState
Optimizer state containing parameter state and hyperparameters.
ParamState
Per-parameter optimizer state (momentum, velocity, etc.).
PyTorchCheckpoint
PyTorch checkpoint structure.
StateDict
Model state dictionary containing named tensors.
TensorData
Tensor data with shape and values.
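A type holding "shape and values", like `TensorData`, usually stores the values as one flat buffer in row-major (C) order, which is PyTorch's default contiguous layout. The sketch below shows how a multi-dimensional index maps into such a buffer through strides. `DenseTensor` is a hypothetical stand-in, not the crate's `TensorData`. It assumes `f32` values and a contiguous layout.

```rust
/// Hypothetical dense tensor: a shape plus a flat, row-major value buffer.
struct DenseTensor {
    shape: Vec<usize>,
    values: Vec<f32>,
}

impl DenseTensor {
    /// Returns None unless the buffer holds exactly prod(shape) values.
    fn new(shape: Vec<usize>, values: Vec<f32>) -> Option<Self> {
        if shape.iter().product::<usize>() == values.len() {
            Some(Self { shape, values })
        } else {
            None
        }
    }

    /// Row-major strides: the last dimension is contiguous (stride 1).
    fn strides(&self) -> Vec<usize> {
        let mut strides = vec![1; self.shape.len()];
        for i in (0..self.shape.len().saturating_sub(1)).rev() {
            strides[i] = strides[i + 1] * self.shape[i + 1];
        }
        strides
    }

    /// Element at a multi-dimensional index, or None if the index is invalid.
    fn get(&self, index: &[usize]) -> Option<f32> {
        if index.len() != self.shape.len() {
            return None;
        }
        let strides = self.strides();
        let mut offset = 0;
        for ((&i, &dim), &stride) in index.iter().zip(&self.shape).zip(&strides) {
            if i >= dim {
                return None;
            }
            offset += i * stride;
        }
        Some(self.values[offset])
    }
}
```

A full loader would also need to honour the storage offset and strides recorded in the checkpoint, because PyTorch tensors can be saved as non-contiguous views of a shared storage.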