#![allow(unused_variables)]
use anyhow::{anyhow, Result};
use std::fs::File;
use std::io::{BufReader, Read};
use std::path::Path;
use crate::checkpoint::formats::{
Checkpoint, CheckpointFormat, JaxCheckpoint, PyTorchCheckpoint, TensorFlowCheckpoint,
TrustformersCheckpoint,
};
/// Determines the checkpoint format of the file at `path`.
///
/// The file extension is tried first via `CheckpointFormat::from_path`. When that yields
/// nothing, up to the first 16 bytes of the file are compared against known prefixes:
///
/// * `80 02 8a` → `PyTorch`
/// * `89 'H' 'D' 'F'` → `TensorFlow`
/// * `82 a4` → `JAX`
/// * `TRUST` → `Trustformers`
/// * `{"` → `SafeTensors`
///
/// Files shorter than 16 bytes are still inspected; only the bytes actually present are
/// compared.
///
/// # Errors
///
/// Fails if the file cannot be opened or read, or if neither the extension nor the leading
/// bytes match a known format.
pub fn detect_format(path: &Path) -> Result<CheckpointFormat> {
    // Number of leading bytes inspected; every recognised prefix is shorter than this.
    const MAGIC_LEN: u64 = 16;

    if let Some(format) = CheckpointFormat::from_path(path) {
        return Ok(format);
    }

    // Read *up to* MAGIC_LEN bytes. `read_exact` into a fixed buffer would fail with
    // `UnexpectedEof` on shorter files, even ones starting with a valid prefix such as
    // `TRUST` or `{"`.
    let mut magic = Vec::with_capacity(MAGIC_LEN as usize);
    BufReader::new(File::open(path)?)
        .take(MAGIC_LEN)
        .read_to_end(&mut magic)?;

    if magic.starts_with(b"\x80\x02\x8a") {
        // NOTE(review): this matches a pickle-style header only; checkpoints written in the
        // newer zip-based layout (starting with `PK`) are not recognised here — confirm
        // whether that is intended.
        Ok(CheckpointFormat::PyTorch)
    } else if magic.starts_with(b"\x89HDF") {
        Ok(CheckpointFormat::TensorFlow)
    } else if magic.starts_with(b"\x82\xa4") {
        Ok(CheckpointFormat::JAX)
    } else if magic.starts_with(b"TRUST") {
        Ok(CheckpointFormat::Trustformers)
    } else if magic.starts_with(b"{\"") {
        Ok(CheckpointFormat::SafeTensors)
    } else {
        Err(anyhow!(
            "Unable to detect checkpoint format from file content"
        ))
    }
}
/// Loads the checkpoint at `path`, choosing the loader from [`detect_format`].
///
/// # Errors
///
/// Propagates detection and loader errors. Formats that `detect_format` can report but
/// that have no loader here (for example `SafeTensors`) are rejected with an
/// "Unsupported checkpoint format" error.
pub fn load_checkpoint(path: &Path) -> Result<Box<dyn Checkpoint>> {
    let checkpoint: Box<dyn Checkpoint> = match detect_format(path)? {
        CheckpointFormat::PyTorch => Box::new(PyTorchCheckpoint::load(path)?),
        CheckpointFormat::TensorFlow => Box::new(TensorFlowCheckpoint::load(path)?),
        CheckpointFormat::JAX => Box::new(JaxCheckpoint::load(path)?),
        CheckpointFormat::Trustformers => Box::new(TrustformersCheckpoint::load(path)?),
        unsupported => {
            return Err(anyhow!("Unsupported checkpoint format: {:?}", unsupported));
        }
    };
    Ok(checkpoint)
}
/// Saves `checkpoint` to `path` after confirming it is already in the requested `format`.
///
/// No conversion is performed: the checkpoint's own `save` implementation writes the file.
///
/// # Errors
///
/// Returns a "Checkpoint format mismatch" error when `checkpoint.format()` differs from
/// `format`; otherwise propagates any error from `save`.
pub fn save_checkpoint(
    checkpoint: &dyn Checkpoint,
    path: &Path,
    format: CheckpointFormat,
) -> Result<()> {
    match checkpoint.format() {
        actual if actual != format => Err(anyhow!(
            "Checkpoint format mismatch: {:?} != {:?}",
            actual,
            format
        )),
        _ => checkpoint.save(path),
    }
}
/// Gathers cheap, file-level information about the checkpoint at `path`.
///
/// Only the format (via [`detect_format`]) and the on-disk size are filled in. The file is
/// not loaded as a checkpoint, so `weight_count` is always `None` and `metadata` is empty.
///
/// # Errors
///
/// Propagates failures from format detection and from reading the file's metadata.
pub fn get_checkpoint_info(path: &Path) -> Result<CheckpointInfo> {
    let format = detect_format(path)?;
    let file_size_bytes = std::fs::metadata(path)?.len();
    let info = CheckpointInfo {
        format,
        file_size_bytes,
        weight_count: None,
        metadata: Default::default(),
    };
    Ok(info)
}
/// File-level summary of a checkpoint, as returned by `get_checkpoint_info`.
#[derive(Debug)]
pub struct CheckpointInfo {
/// Format detected for the file.
pub format: CheckpointFormat,
/// Size of the file on disk, in bytes.
pub file_size_bytes: u64,
/// Number of weights, when known (`get_checkpoint_info` always sets this to `None`).
pub weight_count: Option<usize>,
/// Free-form string metadata (`get_checkpoint_info` always leaves this empty).
pub metadata: std::collections::HashMap<String, String>,
}
/// Checks that the checkpoint at `path` loads and that every weight it lists can be read.
///
/// Each weight is fetched once and immediately discarded.
///
/// # Errors
///
/// Returns the first failure from loading the checkpoint or from `get_weight`. The function
/// never yields `Ok(false)`: every successful return is `Ok(true)`.
pub fn validate_checkpoint(path: &Path) -> Result<bool> {
    let checkpoint = load_checkpoint(path)?;
    checkpoint
        .weight_names()
        .into_iter()
        .try_for_each(|name| checkpoint.get_weight(&name).map(|_weight| ()))?;
    Ok(true)
}
/// Loads two checkpoints and compares their weight names and shapes.
///
/// Every weight present in both checkpoints is fetched from each side so its shape can be
/// compared. Weights found in only one checkpoint are listed but not fetched.
///
/// The name lists in the result (`common_weights`, `only_in_first`, `only_in_second`) are
/// sorted, and `shape_mismatches` follows the order of `common_weights`. Repeated comparisons
/// of the same files therefore produce identical output.
///
/// # Errors
///
/// Propagates any failure from [`load_checkpoint`] or from `get_weight`.
pub fn compare_checkpoints(path1: &Path, path2: &Path) -> Result<CheckpointComparison> {
    let checkpoint1 = load_checkpoint(path1)?;
    let checkpoint2 = load_checkpoint(path2)?;

    // `BTreeSet` instead of `HashSet`: `HashSet` iteration order is randomised per process,
    // which made every list in the comparison come out in a different order on each run.
    let names1: std::collections::BTreeSet<_> = checkpoint1.weight_names().into_iter().collect();
    let names2: std::collections::BTreeSet<_> = checkpoint2.weight_names().into_iter().collect();

    let common_weights: Vec<_> = names1.intersection(&names2).cloned().collect();
    let only_in_first: Vec<_> = names1.difference(&names2).cloned().collect();
    let only_in_second: Vec<_> = names2.difference(&names1).cloned().collect();

    let mut shape_mismatches = Vec::new();
    for name in &common_weights {
        let weight1 = checkpoint1.get_weight(name)?;
        let weight2 = checkpoint2.get_weight(name)?;
        if weight1.shape != weight2.shape {
            shape_mismatches.push(ShapeMismatch {
                weight_name: name.clone(),
                shape1: weight1.shape,
                shape2: weight2.shape,
            });
        }
    }

    Ok(CheckpointComparison {
        format1: checkpoint1.format(),
        format2: checkpoint2.format(),
        common_weights,
        only_in_first,
        only_in_second,
        shape_mismatches,
    })
}
/// Result of `compare_checkpoints`: how the weight names and shapes of two checkpoints relate.
#[derive(Debug)]
pub struct CheckpointComparison {
/// Format of the first checkpoint.
pub format1: CheckpointFormat,
/// Format of the second checkpoint.
pub format2: CheckpointFormat,
/// Weight names present in both checkpoints.
pub common_weights: Vec<String>,
/// Weight names present only in the first checkpoint.
pub only_in_first: Vec<String>,
/// Weight names present only in the second checkpoint.
pub only_in_second: Vec<String>,
/// Common weights whose shapes differ between the two checkpoints.
pub shape_mismatches: Vec<ShapeMismatch>,
}
/// A weight present in both compared checkpoints whose shapes differ.
#[derive(Debug)]
pub struct ShapeMismatch {
/// Name of the weight.
pub weight_name: String,
/// Shape of the weight in the first checkpoint.
pub shape1: Vec<usize>,
/// Shape of the weight in the second checkpoint.
pub shape2: Vec<usize>,
}
impl CheckpointComparison {
    /// Returns `true` when both checkpoints have exactly the same weight names and no
    /// shape mismatches were recorded.
    pub fn is_compatible(&self) -> bool {
        let same_names = self.only_in_first.is_empty() && self.only_in_second.is_empty();
        same_names && self.shape_mismatches.is_empty()
    }

    /// Renders a short multi-line, human-readable report of the comparison.
    ///
    /// The report shows both formats (using their `Debug` form), then the number of common
    /// weights, the weights unique to each side, and the shape mismatches.
    pub fn summary(&self) -> String {
        let common = self.common_weights.len();
        let first_only = self.only_in_first.len();
        let second_only = self.only_in_second.len();
        let mismatches = self.shape_mismatches.len();
        format!(
            "Checkpoint Comparison:\n\
             - Formats: {:?} vs {:?}\n\
             - Common weights: {}\n\
             - Only in first: {}\n\
             - Only in second: {}\n\
             - Shape mismatches: {}",
            self.format1, self.format2, common, first_only, second_only, mismatches
        )
    }
}
/// Merges the weights of several checkpoints and saves the result to `output_path`.
///
/// The first checkpoint in `paths` is the base. Every weight from each later checkpoint
/// is then passed to the base's `set_weight`, in path order. For duplicate names the last
/// path therefore wins (this assumes `set_weight` replaces existing entries — TODO confirm).
///
/// `format` is the format the caller expects for the output. No conversion is performed, so
/// it must match the format of the first checkpoint. This is the same rule that
/// `save_checkpoint` enforces.
///
/// # Errors
///
/// * `paths` is empty.
/// * The first checkpoint's format differs from `format`. This is checked before the
///   remaining checkpoints are loaded.
/// * Any load, `get_weight`, `set_weight`, or `save` failure.
pub fn merge_checkpoints(
    paths: &[&Path],
    output_path: &Path,
    format: CheckpointFormat,
) -> Result<()> {
    let (first, rest) = paths
        .split_first()
        .ok_or_else(|| anyhow!("No checkpoints to merge"))?;
    let mut merged = load_checkpoint(first)?;

    // `format` used to be ignored, so the output was silently written in the first
    // checkpoint's format whatever the caller requested. Reject the mismatch up front.
    let merged_format = merged.format();
    if merged_format != format {
        return Err(anyhow!(
            "Checkpoint format mismatch: {:?} != {:?}",
            merged_format,
            format
        ));
    }

    for path in rest {
        let checkpoint = load_checkpoint(path)?;
        for name in checkpoint.weight_names() {
            let weight = checkpoint.get_weight(&name)?;
            merged.set_weight(&name, weight)?;
        }
    }
    merged.save(output_path)
}
pub fn shard_checkpoint(
path: &Path,
output_dir: &Path,
max_shard_size_mb: usize,
) -> Result<Vec<String>> {
let checkpoint = load_checkpoint(path)?;
let weight_names = checkpoint.weight_names();
let mut shards = Vec::new();
let mut current_shard = TrustformersCheckpoint::new();
let mut current_size = 0usize;
let max_size = max_shard_size_mb * 1024 * 1024;
for name in weight_names {
let weight = checkpoint.get_weight(&name)?;
let weight_size = weight.data.len() * std::mem::size_of::<f32>();
if current_size + weight_size > max_size && current_size > 0 {
let shard_path = output_dir.join(format!("shard_{}.trust", shards.len()));
current_shard.save(&shard_path)?;
shards.push(shard_path.to_string_lossy().to_string());
current_shard = TrustformersCheckpoint::new();
current_size = 0;
}
current_shard.set_weight(&name, weight)?;
current_size += weight_size;
}
if current_size > 0 {
let shard_path = output_dir.join(format!("shard_{}.trust", shards.len()));
current_shard.save(&shard_path)?;
shards.push(shard_path.to_string_lossy().to_string());
}
Ok(shards)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_format_detection_by_extension() {
        let cases = [
            ("model.pt", CheckpointFormat::PyTorch),
            ("model.ckpt", CheckpointFormat::TensorFlow),
        ];
        for (file_name, expected) in cases {
            assert_eq!(
                CheckpointFormat::from_path(Path::new(file_name)),
                Some(expected)
            );
        }
    }

    #[test]
    fn test_checkpoint_comparison() {
        // Builds owned weight-name lists from string literals.
        fn names(items: &[&str]) -> Vec<String> {
            items.iter().map(|item| item.to_string()).collect()
        }

        let mismatch = ShapeMismatch {
            weight_name: String::from("weight1"),
            shape1: vec![512, 768],
            shape2: vec![768, 512],
        };
        let comparison = CheckpointComparison {
            format1: CheckpointFormat::PyTorch,
            format2: CheckpointFormat::TensorFlow,
            common_weights: names(&["weight1", "weight2"]),
            only_in_first: names(&["extra1"]),
            only_in_second: names(&["extra2"]),
            shape_mismatches: vec![mismatch],
        };

        assert!(!comparison.is_compatible());

        let summary = comparison.summary();
        assert!(summary.contains("Common weights: 2"));
        assert!(summary.contains("Shape mismatches: 1"));
    }
}