use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read};
use std::path::Path;
use anyhow::{bail, Context, Result};
use serde::{Deserialize, Serialize};
use crate::safetensors_support::SafetensorsWriter;
/// A deserialized PyTorch-style checkpoint: model weights plus optional
/// optimizer state and training bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PyTorchCheckpoint {
    /// Model parameters keyed by layer/parameter name.
    pub state_dict: StateDict,
    /// Optimizer state, when the checkpoint contained one.
    pub optimizer_state: Option<OptimizerState>,
    /// Training epoch recorded in the checkpoint, if any.
    pub epoch: Option<usize>,
    /// Recorded loss values, if any.
    pub loss_history: Option<Vec<f32>>,
    /// Remaining top-level string-valued entries from the checkpoint dict.
    pub metadata: HashMap<String, String>,
}
/// Named tensors of a model, mirroring PyTorch's `state_dict`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateDict {
    /// Tensors keyed by fully-qualified layer/parameter name.
    pub tensors: HashMap<String, TensorData>,
}
/// A single tensor: raw little-endian bytes plus shape/dtype metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorData {
    /// Dimensions; an empty shape denotes a scalar.
    pub shape: Vec<usize>,
    /// Element type name, e.g. "float32" or "float64".
    pub dtype: String,
    /// Raw element bytes, little-endian (see `as_f32`/`as_f64`).
    pub data: Vec<u8>,
    /// Whether gradient tracking was enabled for this tensor.
    pub requires_grad: bool,
}
/// Optimizer state captured alongside the model weights.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerState {
    /// Optimizer name; the current parser stub always yields "Unknown".
    pub optimizer_type: String,
    /// Per-parameter optimizer buffers keyed by parameter name.
    pub param_state: HashMap<String, ParamState>,
    /// Scalar hyperparameters (e.g. learning rate) keyed by name.
    pub hyperparameters: HashMap<String, f64>,
}
/// Per-parameter optimizer buffers, kept as raw bytes (encoding is not
/// interpreted at this layer).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParamState {
    /// Momentum buffer, if present.
    pub momentum: Option<Vec<u8>>,
    /// Velocity buffer, if present.
    pub velocity: Option<Vec<u8>>,
    /// Step count, if tracked.
    pub step: Option<usize>,
    /// Any additional optimizer-specific buffers.
    pub custom: HashMap<String, Vec<u8>>,
}
/// Summary statistics derived from a checkpoint by
/// [`PyTorchCheckpoint::metadata`].
#[derive(Debug, Clone)]
pub struct CheckpointMetadata {
    /// Sum of element counts (product of dims) across all tensors.
    pub total_parameters: usize,
    /// Names of all tensors in the state dict.
    pub layer_names: Vec<String>,
    /// Sum of raw data buffer lengths, in bytes.
    pub total_size_bytes: usize,
    /// Number of tensors per dtype string.
    pub dtypes: HashMap<String, usize>,
    /// Whether the checkpoint carried optimizer state.
    pub has_optimizer_state: bool,
    /// Epoch recorded in the checkpoint, if any.
    pub epoch: Option<usize>,
}
impl PyTorchCheckpoint {
#[allow(dead_code)]
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let file = File::open(path.as_ref()).context("Failed to open checkpoint file")?;
let mut reader = BufReader::new(file);
let mut bytes = Vec::new();
reader
.read_to_end(&mut bytes)
.context("Failed to read checkpoint file")?;
Self::from_pickle_bytes(&bytes)
}
/// Decodes a raw pickle byte stream into a checkpoint structure.
fn from_pickle_bytes(bytes: &[u8]) -> Result<Self> {
    serde_pickle::from_slice(bytes, Default::default())
        .context("Failed to deserialize pickle data")
        .and_then(Self::parse_pickle_value)
}
/// Interprets a decoded pickle value as a checkpoint dictionary.
///
/// Recognizes the conventional top-level keys
/// (`state_dict`/`model_state_dict`, `optimizer_state_dict`/`optimizer`,
/// `epoch`, `loss_history`); any other string-valued entry is kept as
/// free-form metadata. If no state-dict key exists at all, the root
/// dictionary itself is treated as the state dict (e.g. a file that holds
/// a bare state dict).
///
/// # Errors
/// Fails when the root pickle value is not a dictionary, or when
/// state-dict parsing fails.
fn parse_pickle_value(value: serde_pickle::Value) -> Result<Self> {
    use serde_pickle::{HashableValue, Value};
    let dict = match value {
        Value::Dict(d) => d,
        _ => bail!("Expected dictionary at root of checkpoint"),
    };
    let mut state_dict_tensors = HashMap::new();
    let mut optimizer_state = None;
    let mut epoch = None;
    let mut loss_history = None;
    let mut metadata = HashMap::new();
    // Remember whether an explicit state-dict key exists so we can fall
    // back to treating the whole root dict as the state dict otherwise.
    let has_state_dict_key = dict.iter().any(|(k, _)| {
        matches!(k, HashableValue::String(ref s) if s == "state_dict" || s == "model_state_dict")
    });
    for (key, val) in &dict {
        let key_str = match key {
            HashableValue::String(s) => s.clone(),
            HashableValue::Bytes(b) => String::from_utf8_lossy(b).to_string(),
            // Non-string keys carry no recognizable meaning here.
            _ => continue,
        };
        match key_str.as_str() {
            "state_dict" | "model_state_dict" => {
                if let Value::Dict(sd) = val {
                    state_dict_tensors = Self::parse_state_dict(sd.clone())?;
                }
            }
            "optimizer_state_dict" | "optimizer" => {
                // Best effort: a malformed optimizer entry is dropped, not fatal.
                optimizer_state = Self::parse_optimizer_state(val.clone()).ok();
            }
            "epoch" => {
                if let Value::I64(e) = val {
                    // Reject negative epochs instead of silently wrapping
                    // to a huge value via `as usize`.
                    epoch = usize::try_from(*e).ok();
                }
            }
            "loss_history" => {
                loss_history = Self::parse_loss_history(val.clone()).ok();
            }
            _ => {
                // Keep unrecognized string entries as free-form metadata.
                if let Value::String(s) = val {
                    metadata.insert(key_str, s.clone());
                }
            }
        }
    }
    // No explicit state-dict key: assume the root dict *is* the state dict.
    if state_dict_tensors.is_empty() && !has_state_dict_key {
        state_dict_tensors = Self::parse_state_dict(dict)?;
    }
    Ok(PyTorchCheckpoint {
        state_dict: StateDict {
            tensors: state_dict_tensors,
        },
        optimizer_state,
        epoch,
        loss_history,
        metadata,
    })
}
/// Converts a pickle dictionary into named tensors. Entries whose key is
/// not string-like, or whose value cannot be parsed as a tensor, are
/// silently skipped (best-effort parsing).
fn parse_state_dict(
    dict: std::collections::BTreeMap<serde_pickle::HashableValue, serde_pickle::Value>,
) -> Result<HashMap<String, TensorData>> {
    use serde_pickle::HashableValue;
    let tensors = dict
        .into_iter()
        .filter_map(|(key, val)| {
            let name = match key {
                HashableValue::String(s) => s,
                HashableValue::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
                _ => return None,
            };
            Self::parse_tensor_value(val).ok().map(|tensor| (name, tensor))
        })
        .collect();
    Ok(tensors)
}
/// Attempts to interpret a pickle value as tensor data.
///
/// Two encodings are accepted:
/// * a dict with `shape`/`size`, `data`/`storage`, and optional
///   `dtype` / `requires_grad` entries;
/// * a bare byte string, assumed to be a flat little-endian float32
///   buffer (float32 is the default dtype used throughout this module).
///
/// # Errors
/// Fails when shape or data is missing/empty, when raw bytes are not a
/// multiple of 4, or when the value has an unsupported pickle type.
fn parse_tensor_value(value: serde_pickle::Value) -> Result<TensorData> {
    use serde_pickle::{HashableValue, Value};
    match value {
        Value::Dict(d) => {
            let mut shape = Vec::new();
            let mut data = Vec::new();
            let mut dtype = "float32".to_string();
            let mut requires_grad = false;
            for (k, v) in d {
                let key = match k {
                    HashableValue::String(s) => s,
                    HashableValue::Bytes(b) => String::from_utf8_lossy(&b).to_string(),
                    _ => continue,
                };
                match key.as_str() {
                    "shape" | "size" => {
                        // Python sizes frequently pickle as tuples, not
                        // lists — accept both.
                        let dims = match v {
                            Value::List(items) | Value::Tuple(items) => items,
                            _ => continue,
                        };
                        shape = dims
                            .into_iter()
                            .filter_map(|v| match v {
                                // try_from drops (invalid) negative dims
                                // rather than wrapping via `as usize`.
                                Value::I64(i) => usize::try_from(i).ok(),
                                _ => None,
                            })
                            .collect();
                    }
                    "data" | "storage" => {
                        if let Value::Bytes(b) = v {
                            data = b;
                        }
                    }
                    "dtype" => {
                        if let Value::String(s) = v {
                            dtype = s;
                        }
                    }
                    "requires_grad" => {
                        if let Value::Bool(b) = v {
                            requires_grad = b;
                        }
                    }
                    _ => {}
                }
            }
            if !shape.is_empty() && !data.is_empty() {
                Ok(TensorData {
                    shape,
                    dtype,
                    data,
                    requires_grad,
                })
            } else {
                bail!("Incomplete tensor data")
            }
        }
        Value::Bytes(data) => {
            // Bare bytes: assume a flat little-endian float32 buffer.
            // A length that is not a multiple of 4 would silently
            // truncate the inferred shape, so reject it instead.
            if !data.len().is_multiple_of(4) {
                bail!("Raw tensor bytes are not a multiple of 4 (assumed float32)");
            }
            Ok(TensorData {
                shape: vec![data.len() / 4],
                dtype: "float32".to_string(),
                data,
                requires_grad: false,
            })
        }
        _ => bail!("Unsupported tensor value type"),
    }
}
#[allow(dead_code)]
// Placeholder: the pickle value is ignored and an empty "Unknown"
// optimizer state is returned. This keeps `has_optimizer_state == true`
// visible to callers when the key was present, without decoding buffers.
// TODO: parse the real optimizer param groups/state.
fn parse_optimizer_state(_value: serde_pickle::Value) -> Result<OptimizerState> {
    Ok(OptimizerState {
        optimizer_type: "Unknown".to_string(),
        param_state: HashMap::new(),
        hyperparameters: HashMap::new(),
    })
}
/// Parses a pickled list of loss values into `f32`s.
///
/// Lists produced in Python may mix floats and integers (e.g. a loss of
/// exactly `0` pickles as an int), so both are accepted; any other
/// element type is skipped.
///
/// # Errors
/// Fails when the value is not a list.
#[allow(dead_code)]
fn parse_loss_history(value: serde_pickle::Value) -> Result<Vec<f32>> {
    use serde_pickle::Value;
    match value {
        Value::List(list) => Ok(list
            .into_iter()
            .filter_map(|v| match v {
                Value::F64(f) => Some(f as f32),
                // Previously dropped: whole-number losses pickled as ints.
                Value::I64(i) => Some(i as f32),
                _ => None,
            })
            .collect()),
        _ => bail!("Expected list for loss history"),
    }
}
/// Summarizes the checkpoint: parameter/byte totals, per-dtype tensor
/// counts, and layer names.
///
/// Layer names are sorted so the result is deterministic — `HashMap`
/// iteration order varies run to run.
pub fn metadata(&self) -> CheckpointMetadata {
    let mut total_parameters = 0;
    let mut layer_names = Vec::with_capacity(self.state_dict.tensors.len());
    let mut total_size_bytes = 0;
    let mut dtypes = HashMap::new();
    for (name, tensor) in &self.state_dict.tensors {
        layer_names.push(name.clone());
        // Product of dims; an empty (scalar) shape contributes 1.
        total_parameters += tensor.shape.iter().product::<usize>();
        total_size_bytes += tensor.data.len();
        *dtypes.entry(tensor.dtype.clone()).or_insert(0) += 1;
    }
    layer_names.sort_unstable();
    CheckpointMetadata {
        total_parameters,
        layer_names,
        total_size_bytes,
        dtypes,
        has_optimizer_state: self.optimizer_state.is_some(),
        epoch: self.epoch,
    }
}
/// Borrows the model state dict.
pub fn state_dict(&self) -> &StateDict {
    &self.state_dict
}
/// Converts the state dict to safetensors bytes.
///
/// Tensors are written in sorted name order so the output is
/// deterministic (`HashMap` iteration order is not). Only float32 and
/// float64 tensors are supported.
///
/// # Errors
/// Fails when a tensor's byte length does not match its dtype width,
/// when a dtype is unsupported, or when the writer cannot serialize.
pub fn to_safetensors(&self) -> Result<Vec<u8>> {
    let mut writer = SafetensorsWriter::new();
    // Sort by name for deterministic output.
    let mut entries: Vec<_> = self.state_dict.tensors.iter().collect();
    entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
    for (name, tensor) in entries {
        let shape = tensor.shape.clone();
        match tensor.dtype.as_str() {
            "float32" | "Float" => {
                if !tensor.data.len().is_multiple_of(4) {
                    bail!("Invalid float32 data length for tensor {}", name);
                }
                // Decode via the shared checked helper instead of
                // duplicating the little-endian chunking here.
                let float_data = tensor.as_f32()?;
                writer.add_f32(name, shape, &float_data);
            }
            "float64" | "Double" => {
                if !tensor.data.len().is_multiple_of(8) {
                    bail!("Invalid float64 data length for tensor {}", name);
                }
                let float_data = tensor.as_f64()?;
                writer.add_f64(name, shape, &float_data);
            }
            _ => {
                bail!("Unsupported dtype: {}", tensor.dtype);
            }
        }
    }
    writer
        .serialize()
        .context("Failed to serialize to safetensors")
}
#[allow(dead_code)]
/// Writes the checkpoint to `path` as a pickle stream.
///
/// NOTE(review): `to_pickle_bytes` serializes tensor shapes/dtypes and
/// byte *lengths* only — not the tensor payload — so a file written here
/// cannot be loaded back into full tensors.
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
    let bytes = self.to_pickle_bytes()?;
    std::fs::write(path, bytes).context("Failed to write checkpoint file")?;
    Ok(())
}
/// Serializes a *lossy* summary of the checkpoint to pickle bytes.
///
/// Only tensor shape, dtype, and the byte *length* of the data buffer
/// (`data_len`) are written; the tensor payload itself is omitted, so
/// this output cannot be round-tripped back into a full checkpoint.
/// Optimizer state is not serialized either.
fn to_pickle_bytes(&self) -> Result<Vec<u8>> {
    use serde_pickle::ser;
    // Local mirror structs control exactly what is serialized and let
    // serde skip absent optional fields.
    #[derive(Serialize)]
    struct CheckpointSer {
        state_dict: HashMap<String, TensorSer>,
        #[serde(skip_serializing_if = "Option::is_none")]
        epoch: Option<usize>,
        #[serde(skip_serializing_if = "Option::is_none")]
        loss_history: Option<Vec<f32>>,
        metadata: HashMap<String, String>,
    }
    #[derive(Serialize)]
    struct TensorSer {
        shape: Vec<usize>,
        dtype: String,
        // Length of the raw buffer in bytes — not the data itself.
        data_len: usize,
    }
    let state_dict_ser: HashMap<String, TensorSer> = self
        .state_dict
        .tensors
        .iter()
        .map(|(name, tensor)| {
            (
                name.clone(),
                TensorSer {
                    shape: tensor.shape.clone(),
                    dtype: tensor.dtype.clone(),
                    data_len: tensor.data.len(),
                },
            )
        })
        .collect();
    let checkpoint_ser = CheckpointSer {
        state_dict: state_dict_ser,
        epoch: self.epoch,
        loss_history: self.loss_history.clone(),
        metadata: self.metadata.clone(),
    };
    ser::to_vec(&checkpoint_ser, Default::default()).context("Failed to serialize to pickle")
}
#[allow(dead_code)]
// Unused placeholder: always returns an empty map. Presumably intended to
// mirror a tensor as a pickle-ready dict — TODO implement or remove.
fn tensor_to_pickle_value(_tensor: &TensorData) -> HashMap<String, String> {
    HashMap::new()
}
pub fn new() -> Self {
PyTorchCheckpoint {
state_dict: StateDict {
tensors: HashMap::new(),
},
optimizer_state: None,
epoch: None,
loss_history: None,
metadata: HashMap::new(),
}
}
/// Inserts or replaces a tensor under `name`.
pub fn add_tensor(&mut self, name: String, tensor: TensorData) {
    self.state_dict.tensors.insert(name, tensor);
}
/// Records the training epoch.
pub fn set_epoch(&mut self, epoch: usize) {
    self.epoch = Some(epoch);
}
/// Adds a free-form metadata entry, replacing any existing value for `key`.
pub fn add_metadata(&mut self, key: String, value: String) {
    self.metadata.insert(key, value);
}
}
impl Default for PyTorchCheckpoint {
fn default() -> Self {
Self::new()
}
}
impl StateDict {
    /// Looks up a tensor by its fully-qualified name.
    pub fn get(&self, name: &str) -> Option<&TensorData> {
        self.tensors.get(name)
    }
    /// Iterates over `(name, tensor)` pairs in unspecified order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &TensorData)> {
        self.tensors.iter()
    }
    /// Number of tensors held.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }
    /// True when the state dict holds no tensors.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }
}
impl TensorData {
    /// Builds a float32 tensor from host values, stored little-endian.
    pub fn from_f32(shape: Vec<usize>, data: &[f32]) -> Self {
        // `flat_map(...).collect()` loses the size hint, causing repeated
        // reallocation; reserve the exact byte count up front instead.
        let mut bytes = Vec::with_capacity(data.len() * 4);
        for &f in data {
            bytes.extend_from_slice(&f.to_le_bytes());
        }
        TensorData {
            shape,
            dtype: "float32".to_string(),
            data: bytes,
            requires_grad: false,
        }
    }
    /// Builds a float64 tensor from host values, stored little-endian.
    pub fn from_f64(shape: Vec<usize>, data: &[f64]) -> Self {
        let mut bytes = Vec::with_capacity(data.len() * 8);
        for &f in data {
            bytes.extend_from_slice(&f.to_le_bytes());
        }
        TensorData {
            shape,
            dtype: "float64".to_string(),
            data: bytes,
            requires_grad: false,
        }
    }
    /// Decodes the byte buffer as little-endian f32 values.
    ///
    /// # Errors
    /// Fails when `dtype` is not float32/Float or the buffer length is
    /// not a multiple of 4.
    pub fn as_f32(&self) -> Result<Vec<f32>> {
        if self.dtype != "float32" && self.dtype != "Float" {
            bail!("Expected float32 dtype, got {}", self.dtype);
        }
        if !self.data.len().is_multiple_of(4) {
            bail!("Invalid data length for float32");
        }
        Ok(self
            .data
            .chunks_exact(4)
            .map(|chunk| {
                // chunks_exact guarantees 4-byte chunks.
                let bytes: [u8; 4] = chunk.try_into().expect("chunk is 4 bytes");
                f32::from_le_bytes(bytes)
            })
            .collect())
    }
    /// Decodes the byte buffer as little-endian f64 values.
    ///
    /// # Errors
    /// Fails when `dtype` is not float64/Double or the buffer length is
    /// not a multiple of 8.
    pub fn as_f64(&self) -> Result<Vec<f64>> {
        if self.dtype != "float64" && self.dtype != "Double" {
            bail!("Expected float64 dtype, got {}", self.dtype);
        }
        if !self.data.len().is_multiple_of(8) {
            bail!("Invalid data length for float64");
        }
        Ok(self
            .data
            .chunks_exact(8)
            .map(|chunk| {
                // chunks_exact guarantees 8-byte chunks.
                let bytes: [u8; 8] = chunk.try_into().expect("chunk is 8 bytes");
                f64::from_le_bytes(bytes)
            })
            .collect())
    }
    /// Total element count implied by the shape (1 for an empty shape).
    pub fn num_elements(&self) -> usize {
        self.shape.iter().product()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_checkpoint_creation() {
        let mut checkpoint = PyTorchCheckpoint::new();
        let tensor = TensorData::from_f32(vec![2, 3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        checkpoint.add_tensor("layer1.weight".to_string(), tensor);
        checkpoint.set_epoch(10);
        checkpoint.add_metadata("model_type".to_string(), "CNN".to_string());
        assert_eq!(checkpoint.state_dict().len(), 1);
        assert_eq!(checkpoint.epoch, Some(10));
        assert_eq!(checkpoint.metadata.get("model_type").unwrap(), "CNN");
    }

    #[test]
    fn test_tensor_data_f32() {
        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = TensorData::from_f32(vec![2, 2], &data);
        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.dtype, "float32");
        assert_eq!(tensor.num_elements(), 4);
        let recovered = tensor.as_f32().unwrap();
        assert_eq!(recovered, data);
    }

    #[test]
    fn test_tensor_data_f64() {
        let data = vec![1.0f64, 2.0, 3.0, 4.0];
        let tensor = TensorData::from_f64(vec![2, 2], &data);
        assert_eq!(tensor.shape, vec![2, 2]);
        assert_eq!(tensor.dtype, "float64");
        let recovered = tensor.as_f64().unwrap();
        assert_eq!(recovered, data);
    }

    // Decoding with the wrong dtype must fail rather than reinterpret bytes.
    #[test]
    fn test_dtype_mismatch_errors() {
        let t32 = TensorData::from_f32(vec![2], &[1.0, 2.0]);
        assert!(t32.as_f64().is_err());
        let t64 = TensorData::from_f64(vec![2], &[1.0, 2.0]);
        assert!(t64.as_f32().is_err());
    }

    #[test]
    fn test_metadata_extraction() {
        let mut checkpoint = PyTorchCheckpoint::new();
        checkpoint.add_tensor(
            "layer1.weight".to_string(),
            TensorData::from_f32(vec![10, 10], &[0.0; 100]),
        );
        checkpoint.add_tensor(
            "layer1.bias".to_string(),
            TensorData::from_f32(vec![10], &[0.0; 10]),
        );
        checkpoint.add_tensor(
            "layer2.weight".to_string(),
            TensorData::from_f64(vec![5, 10], &[0.0; 50]),
        );
        let metadata = checkpoint.metadata();
        assert_eq!(metadata.total_parameters, 160);
        assert_eq!(metadata.layer_names.len(), 3);
        // 100 f32 * 4 + 10 f32 * 4 + 50 f64 * 8 bytes.
        assert_eq!(metadata.total_size_bytes, 840);
        assert_eq!(metadata.dtypes.get("float32"), Some(&2));
        assert_eq!(metadata.dtypes.get("float64"), Some(&1));
    }

    #[test]
    fn test_state_dict_access() {
        let mut checkpoint = PyTorchCheckpoint::new();
        let tensor = TensorData::from_f32(vec![3], &[1.0, 2.0, 3.0]);
        checkpoint.add_tensor("test".to_string(), tensor);
        let state_dict = checkpoint.state_dict();
        assert_eq!(state_dict.len(), 1);
        assert!(!state_dict.is_empty());
        let retrieved = state_dict.get("test").unwrap();
        assert_eq!(retrieved.shape, vec![3]);
    }

    #[test]
    fn test_checkpoint_serialization() -> Result<()> {
        let mut checkpoint = PyTorchCheckpoint::new();
        checkpoint.add_tensor(
            "weight".to_string(),
            TensorData::from_f32(vec![2, 2], &[1.0, 2.0, 3.0, 4.0]),
        );
        checkpoint.set_epoch(5);
        checkpoint.add_metadata("arch".to_string(), "ResNet".to_string());
        let bytes = checkpoint.to_pickle_bytes()?;
        assert!(!bytes.is_empty());
        Ok(())
    }

    #[test]
    fn test_to_safetensors() -> Result<()> {
        let mut checkpoint = PyTorchCheckpoint::new();
        checkpoint.add_tensor(
            "layer1.weight".to_string(),
            TensorData::from_f32(vec![3, 3], &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]),
        );
        checkpoint.add_tensor(
            "layer1.bias".to_string(),
            TensorData::from_f32(vec![3], &[0.1, 0.2, 0.3]),
        );
        let safetensors_bytes = checkpoint.to_safetensors()?;
        assert!(!safetensors_bytes.is_empty());
        Ok(())
    }

    // Non-float dtypes are not supported by the safetensors conversion.
    #[test]
    fn test_to_safetensors_rejects_unknown_dtype() {
        let mut checkpoint = PyTorchCheckpoint::new();
        let mut tensor = TensorData::from_f32(vec![1], &[1.0]);
        tensor.dtype = "int8".to_string();
        checkpoint.add_tensor("t".to_string(), tensor);
        assert!(checkpoint.to_safetensors().is_err());
    }
}