#![allow(unused_variables)]
use anyhow::{anyhow, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CheckpointFormat {
PyTorch,
TensorFlow,
JAX,
Trustformers,
SafeTensors,
ONNX,
}
impl CheckpointFormat {
pub fn extension(&self) -> &'static str {
match self {
Self::PyTorch => ".pt",
Self::TensorFlow => ".ckpt",
Self::JAX => ".jax",
Self::Trustformers => ".trust",
Self::SafeTensors => ".safetensors",
Self::ONNX => ".onnx",
}
}
pub fn from_path(path: &Path) -> Option<Self> {
let ext = path.extension()?.to_str()?;
match ext {
"pt" | "pth" | "bin" => Some(Self::PyTorch),
"ckpt" | "h5" | "pb" => Some(Self::TensorFlow),
"jax" | "msgpack" => Some(Self::JAX),
"trust" => Some(Self::Trustformers),
"safetensors" => Some(Self::SafeTensors),
"onnx" => Some(Self::ONNX),
_ => None,
}
}
}
pub trait Checkpoint: Send + Sync {
fn format(&self) -> CheckpointFormat;
fn weight_names(&self) -> Vec<String>;
fn get_weight(&self, name: &str) -> Result<WeightTensor>;
fn set_weight(&mut self, name: &str, weight: WeightTensor) -> Result<()>;
fn metadata(&self) -> &HashMap<String, String>;
fn save(&self, path: &Path) -> Result<()>;
fn load(path: &Path) -> Result<Self>
where
Self: Sized;
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, oxicode::Encode, oxicode::Decode)]
pub struct WeightTensor {
pub data: Vec<f32>,
pub shape: Vec<usize>,
pub dtype: DataType,
}
#[derive(
Debug,
Clone,
Copy,
PartialEq,
Eq,
serde::Serialize,
serde::Deserialize,
oxicode::Encode,
oxicode::Decode,
)]
pub enum DataType {
Float32,
Float16,
BFloat16,
Int8,
Int4,
}
impl WeightTensor {
pub fn new(data: Vec<f32>, shape: Vec<usize>) -> Self {
Self {
data,
shape,
dtype: DataType::Float32,
}
}
pub fn num_elements(&self) -> usize {
self.shape.iter().product()
}
pub fn reshape(&mut self, new_shape: Vec<usize>) -> Result<()> {
let new_size: usize = new_shape.iter().product();
if new_size != self.num_elements() {
return Err(anyhow!(
"Cannot reshape tensor from {:?} to {:?}",
self.shape,
new_shape
));
}
self.shape = new_shape;
Ok(())
}
pub fn transpose(&mut self, dims: &[usize]) -> Result<()> {
if dims.len() != self.shape.len() {
return Err(anyhow!("Transpose dimensions mismatch"));
}
let mut new_shape = vec![0; self.shape.len()];
for (i, &dim) in dims.iter().enumerate() {
if dim >= self.shape.len() {
return Err(anyhow!("Transpose dimension {} out of bounds", dim));
}
new_shape[i] = self.shape[dim];
}
let mut old_strides = vec![1; self.shape.len()];
for i in (0..self.shape.len() - 1).rev() {
old_strides[i] = old_strides[i + 1] * self.shape[i + 1];
}
let mut new_strides = vec![1; new_shape.len()];
for i in (0..new_shape.len() - 1).rev() {
new_strides[i] = new_strides[i + 1] * new_shape[i + 1];
}
let mut new_data = vec![0.0; self.data.len()];
let total_elements = self.data.len();
for i in 0..total_elements {
let mut old_indices = vec![0; self.shape.len()];
let mut temp = i;
for (j, &stride) in old_strides.iter().enumerate() {
old_indices[j] = temp / stride;
temp %= stride;
}
let mut new_indices = vec![0; new_shape.len()];
for (j, &dim) in dims.iter().enumerate() {
new_indices[j] = old_indices[dim];
}
let mut new_idx = 0;
for (j, &idx) in new_indices.iter().enumerate() {
new_idx += idx * new_strides[j];
}
new_data[new_idx] = self.data[i];
}
self.data = new_data;
self.shape = new_shape;
Ok(())
}
}
pub struct PyTorchCheckpoint {
weights: HashMap<String, WeightTensor>,
metadata: HashMap<String, String>,
}
impl Default for PyTorchCheckpoint {
fn default() -> Self {
Self::new()
}
}
impl PyTorchCheckpoint {
pub fn new() -> Self {
Self {
weights: HashMap::new(),
metadata: HashMap::new(),
}
}
}
impl Checkpoint for PyTorchCheckpoint {
fn format(&self) -> CheckpointFormat {
CheckpointFormat::PyTorch
}
fn weight_names(&self) -> Vec<String> {
self.weights.keys().cloned().collect()
}
fn get_weight(&self, name: &str) -> Result<WeightTensor> {
self.weights
.get(name)
.cloned()
.ok_or_else(|| anyhow!("Weight '{}' not found", name))
}
fn set_weight(&mut self, name: &str, weight: WeightTensor) -> Result<()> {
self.weights.insert(name.to_string(), weight);
Ok(())
}
fn metadata(&self) -> &HashMap<String, String> {
&self.metadata
}
fn save(&self, path: &Path) -> Result<()> {
let checkpoint_data = serde_json::json!({
"weights": self.weights.iter().map(|(key, tensor)| {
(key.clone(), serde_json::json!({
"data": tensor.data,
"shape": tensor.shape,
"dtype": format!("{:?}", tensor.dtype)
}))
}).collect::<HashMap<String, serde_json::Value>>(),
"metadata": self.metadata,
"format": "pytorch_compatible"
});
std::fs::write(path, serde_json::to_string_pretty(&checkpoint_data)?)?;
Ok(())
}
fn load(path: &Path) -> Result<Self> {
let content = std::fs::read_to_string(path)?;
let checkpoint_data: serde_json::Value = serde_json::from_str(&content)?;
let mut checkpoint = Self::new();
if let Some(weights) = checkpoint_data.get("weights").and_then(|w| w.as_object()) {
for (key, tensor_data) in weights {
if let (Some(data), Some(shape), Some(dtype)) = (
tensor_data.get("data").and_then(|d| d.as_array()),
tensor_data.get("shape").and_then(|s| s.as_array()),
tensor_data.get("dtype").and_then(|d| d.as_str()),
) {
let data: Vec<f32> =
data.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
let shape: Vec<usize> =
shape.iter().filter_map(|v| v.as_u64().map(|u| u as usize)).collect();
checkpoint.weights.insert(key.clone(), WeightTensor::new(data, shape));
}
}
}
if let Some(metadata) = checkpoint_data.get("metadata").and_then(|m| m.as_object()) {
for (key, value) in metadata {
if let Some(value_str) = value.as_str() {
checkpoint.metadata.insert(key.clone(), value_str.to_string());
}
}
}
Ok(checkpoint)
}
}
pub struct TensorFlowCheckpoint {
weights: HashMap<String, WeightTensor>,
metadata: HashMap<String, String>,
}
impl Default for TensorFlowCheckpoint {
fn default() -> Self {
Self::new()
}
}
impl TensorFlowCheckpoint {
pub fn new() -> Self {
Self {
weights: HashMap::new(),
metadata: HashMap::new(),
}
}
}
impl Checkpoint for TensorFlowCheckpoint {
fn format(&self) -> CheckpointFormat {
CheckpointFormat::TensorFlow
}
fn weight_names(&self) -> Vec<String> {
self.weights.keys().cloned().collect()
}
fn get_weight(&self, name: &str) -> Result<WeightTensor> {
self.weights
.get(name)
.cloned()
.ok_or_else(|| anyhow!("Weight '{}' not found", name))
}
fn set_weight(&mut self, name: &str, weight: WeightTensor) -> Result<()> {
self.weights.insert(name.to_string(), weight);
Ok(())
}
fn metadata(&self) -> &HashMap<String, String> {
&self.metadata
}
fn save(&self, path: &Path) -> Result<()> {
let checkpoint_data = serde_json::json!({
"weights": self.weights.iter().map(|(key, tensor)| {
(key.clone(), serde_json::json!({
"data": tensor.data,
"shape": tensor.shape,
"dtype": format!("{:?}", tensor.dtype)
}))
}).collect::<HashMap<String, serde_json::Value>>(),
"metadata": self.metadata,
"format": "tensorflow_compatible"
});
std::fs::write(path, serde_json::to_string_pretty(&checkpoint_data)?)?;
Ok(())
}
fn load(path: &Path) -> Result<Self> {
let content = std::fs::read_to_string(path)?;
let checkpoint_data: serde_json::Value = serde_json::from_str(&content)?;
let mut checkpoint = Self::new();
if let Some(weights) = checkpoint_data.get("weights").and_then(|w| w.as_object()) {
for (key, tensor_data) in weights {
if let (Some(data), Some(shape), Some(dtype)) = (
tensor_data.get("data").and_then(|d| d.as_array()),
tensor_data.get("shape").and_then(|s| s.as_array()),
tensor_data.get("dtype").and_then(|d| d.as_str()),
) {
let data: Vec<f32> =
data.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
let shape: Vec<usize> =
shape.iter().filter_map(|v| v.as_u64().map(|u| u as usize)).collect();
checkpoint.weights.insert(key.clone(), WeightTensor::new(data, shape));
}
}
}
if let Some(metadata) = checkpoint_data.get("metadata").and_then(|m| m.as_object()) {
for (key, value) in metadata {
if let Some(value_str) = value.as_str() {
checkpoint.metadata.insert(key.clone(), value_str.to_string());
}
}
}
Ok(checkpoint)
}
}
pub struct JaxCheckpoint {
weights: HashMap<String, WeightTensor>,
metadata: HashMap<String, String>,
}
impl Default for JaxCheckpoint {
fn default() -> Self {
Self::new()
}
}
impl JaxCheckpoint {
pub fn new() -> Self {
Self {
weights: HashMap::new(),
metadata: HashMap::new(),
}
}
}
impl Checkpoint for JaxCheckpoint {
fn format(&self) -> CheckpointFormat {
CheckpointFormat::JAX
}
fn weight_names(&self) -> Vec<String> {
self.weights.keys().cloned().collect()
}
fn get_weight(&self, name: &str) -> Result<WeightTensor> {
self.weights
.get(name)
.cloned()
.ok_or_else(|| anyhow!("Weight '{}' not found", name))
}
fn set_weight(&mut self, name: &str, weight: WeightTensor) -> Result<()> {
self.weights.insert(name.to_string(), weight);
Ok(())
}
fn metadata(&self) -> &HashMap<String, String> {
&self.metadata
}
fn save(&self, path: &Path) -> Result<()> {
let checkpoint_data = serde_json::json!({
"weights": self.weights.iter().map(|(key, tensor)| {
(key.clone(), serde_json::json!({
"data": tensor.data,
"shape": tensor.shape,
"dtype": format!("{:?}", tensor.dtype)
}))
}).collect::<HashMap<String, serde_json::Value>>(),
"metadata": self.metadata,
"format": "jax_compatible"
});
std::fs::write(path, serde_json::to_string_pretty(&checkpoint_data)?)?;
Ok(())
}
fn load(path: &Path) -> Result<Self> {
let content = std::fs::read_to_string(path)?;
let checkpoint_data: serde_json::Value = serde_json::from_str(&content)?;
let mut checkpoint = Self::new();
if let Some(weights) = checkpoint_data.get("weights").and_then(|w| w.as_object()) {
for (key, tensor_data) in weights {
if let (Some(data), Some(shape), Some(dtype)) = (
tensor_data.get("data").and_then(|d| d.as_array()),
tensor_data.get("shape").and_then(|s| s.as_array()),
tensor_data.get("dtype").and_then(|d| d.as_str()),
) {
let data: Vec<f32> =
data.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
let shape: Vec<usize> =
shape.iter().filter_map(|v| v.as_u64().map(|u| u as usize)).collect();
checkpoint.weights.insert(key.clone(), WeightTensor::new(data, shape));
}
}
}
if let Some(metadata) = checkpoint_data.get("metadata").and_then(|m| m.as_object()) {
for (key, value) in metadata {
if let Some(value_str) = value.as_str() {
checkpoint.metadata.insert(key.clone(), value_str.to_string());
}
}
}
Ok(checkpoint)
}
}
pub struct TrustformersCheckpoint {
weights: HashMap<String, WeightTensor>,
metadata: HashMap<String, String>,
version: String,
}
impl Default for TrustformersCheckpoint {
fn default() -> Self {
Self::new()
}
}
impl TrustformersCheckpoint {
pub fn new() -> Self {
Self {
weights: HashMap::new(),
metadata: HashMap::new(),
version: env!("CARGO_PKG_VERSION").to_string(),
}
}
}
impl Checkpoint for TrustformersCheckpoint {
fn format(&self) -> CheckpointFormat {
CheckpointFormat::Trustformers
}
fn weight_names(&self) -> Vec<String> {
self.weights.keys().cloned().collect()
}
fn get_weight(&self, name: &str) -> Result<WeightTensor> {
self.weights
.get(name)
.cloned()
.ok_or_else(|| anyhow!("Weight '{}' not found", name))
}
fn set_weight(&mut self, name: &str, weight: WeightTensor) -> Result<()> {
self.weights.insert(name.to_string(), weight);
Ok(())
}
fn metadata(&self) -> &HashMap<String, String> {
&self.metadata
}
fn save(&self, path: &Path) -> Result<()> {
let data = oxicode::serde::encode_to_vec(
&(&self.weights, &self.metadata, &self.version),
oxicode::config::standard(),
)?;
std::fs::write(path, data)?;
Ok(())
}
fn load(path: &Path) -> Result<Self> {
let data = std::fs::read(path)?;
type CheckpointData = (
HashMap<String, WeightTensor>,
HashMap<String, String>,
String,
);
let ((weights, metadata, version), _): (CheckpointData, usize) =
oxicode::serde::decode_from_slice(&data, oxicode::config::standard())?;
Ok(Self {
weights,
metadata,
version,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_format_detection() {
assert_eq!(
CheckpointFormat::from_path(Path::new("model.pt")),
Some(CheckpointFormat::PyTorch)
);
assert_eq!(
CheckpointFormat::from_path(Path::new("model.ckpt")),
Some(CheckpointFormat::TensorFlow)
);
assert_eq!(
CheckpointFormat::from_path(Path::new("model.jax")),
Some(CheckpointFormat::JAX)
);
}
#[test]
fn test_weight_tensor() {
let mut tensor = WeightTensor::new(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
assert_eq!(tensor.num_elements(), 4);
assert!(tensor.reshape(vec![4]).is_ok());
assert_eq!(tensor.shape, vec![4]);
assert!(tensor.reshape(vec![2, 3]).is_err());
}
}