use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
/// String-to-string metadata attached to a [`Checkpoint`]. On disk it is stored as a flat
/// JSON object whose keys are written in sorted order.
pub type CheckpointMetadata = HashMap<String, String>;
/// An in-memory checkpoint: a format version, string metadata, and an ordered list of
/// named `f32` tensors.
///
/// Serialized with [`Checkpoint::write_to`] / [`Checkpoint::save`] and parsed with
/// [`Checkpoint::read_from`] / [`Checkpoint::load`]. Every field is cloneable, so the whole
/// checkpoint is too (a deep copy of all tensor data).
#[derive(Debug, Clone)]
pub struct Checkpoint {
    /// Format version. `new` sets 1 and `read_from` accepts only 1; `write_to` always
    /// writes 1 regardless of this field.
    pub version: u32,
    /// Free-form key/value metadata.
    pub metadata: CheckpointMetadata,
    /// Tensors in serialization order. Names are not required to be unique.
    pub tensors: Vec<CheckpointTensor>,
}
/// A single named tensor stored in a [`Checkpoint`].
///
/// `data` and `shape` are stored independently; nothing in this type checks that
/// `data.len()` matches the product of `shape`.
#[derive(Debug, Clone, PartialEq)]
pub struct CheckpointTensor {
    /// Tensor name. `Checkpoint::write_to` rejects names longer than 65535 bytes.
    pub name: String,
    /// Dimension sizes.
    pub shape: Vec<u64>,
    /// Element values.
    pub data: Vec<f32>,
}

impl CheckpointTensor {
    /// Creates a tensor from its name, element data, and shape.
    ///
    /// Note the argument order: `data` comes before `shape`.
    pub fn new(name: impl Into<String>, data: Vec<f32>, shape: Vec<u64>) -> Self {
        Self {
            name: name.into(),
            shape,
            data,
        }
    }

    /// Number of elements implied by `shape`: the product of its dimensions.
    ///
    /// An empty `shape` yields `0` (not `1`), and any zero dimension yields `0`. Shapes loaded
    /// by `Checkpoint::read_from` are not validated, so when the product does not fit in a
    /// `u64` the result saturates at `u64::MAX` instead of panicking (debug builds) or
    /// silently wrapping (release builds).
    pub fn element_count(&self) -> u64 {
        if self.shape.is_empty() || self.shape.contains(&0) {
            return 0;
        }
        self.shape
            .iter()
            .try_fold(1u64, |acc, &dim| acc.checked_mul(dim))
            .unwrap_or(u64::MAX)
    }

    /// Size in bytes of the `f32` payload implied by `shape` (4 bytes per element).
    ///
    /// Saturates at `usize::MAX` instead of truncating or overflowing.
    pub fn size_bytes(&self) -> usize {
        usize::try_from(self.element_count())
            .unwrap_or(usize::MAX)
            .saturating_mul(std::mem::size_of::<f32>())
    }

    /// Builds a checkpoint tensor from a `model_merge::WeightTensor`, copying its data.
    pub fn from_weight_tensor(wt: &crate::model_merge::WeightTensor) -> Self {
        Self {
            name: wt.name.clone(),
            // usize -> u64 is lossless on every target Rust supports (pointers are <= 64 bits).
            shape: wt.shape.iter().map(|&d| d as u64).collect(),
            data: wt.data.clone(),
        }
    }

    /// Converts back into a `model_merge::WeightTensor`, copying the data.
    ///
    /// A dimension that does not fit in `usize` (only possible when `usize` is narrower than
    /// 64 bits) is clamped to `usize::MAX` rather than rejected.
    pub fn to_weight_tensor(&self) -> crate::model_merge::WeightTensor {
        let shape: Vec<usize> = self
            .shape
            .iter()
            .map(|&d| usize::try_from(d).unwrap_or(usize::MAX))
            .collect();
        crate::model_merge::WeightTensor::new(self.name.clone(), self.data.clone(), shape)
    }
}
impl Checkpoint {
    /// Longest tensor name, in bytes, that [`Checkpoint::write_to`] will serialize.
    const MAX_NAME_LEN: usize = 65535;
    /// Cap on up-front `Vec` reservations whose size is taken from the input stream.
    const MAX_PREALLOC_ITEMS: usize = 1 << 16;
    /// Bytes read per step when loading a length-prefixed byte string.
    const READ_CHUNK_BYTES: usize = 1 << 16;
    /// `f32` values encoded or decoded per step when writing or reading tensor data.
    const IO_CHUNK_FLOATS: usize = 1 << 14;

    /// Creates an empty checkpoint: `version` 1, no metadata, no tensors.
    pub fn new() -> Self {
        Self {
            version: 1,
            metadata: CheckpointMetadata::new(),
            tensors: Vec::new(),
        }
    }

    /// Appends `tensor`. Names are not checked for uniqueness.
    pub fn add_tensor(&mut self, tensor: CheckpointTensor) {
        self.tensors.push(tensor);
    }

    /// Inserts the metadata entry `key` = `value`, replacing any previous value for `key`.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.metadata.insert(key.into(), value.into());
    }

    /// Returns the metadata value stored under `key`, if any.
    pub fn get_metadata(&self, key: &str) -> Option<&str> {
        self.metadata.get(key).map(String::as_str)
    }

    /// Returns the first tensor named `name` (linear scan), if any.
    pub fn get_tensor(&self, name: &str) -> Option<&CheckpointTensor> {
        self.tensors.iter().find(|t| t.name == name)
    }

    /// Sum of [`CheckpointTensor::size_bytes`] over all tensors, saturating at `usize::MAX`.
    pub fn total_bytes(&self) -> usize {
        self.tensors
            .iter()
            .fold(0usize, |acc, t| acc.saturating_add(t.size_bytes()))
    }

    /// Sum of [`CheckpointTensor::element_count`] over all tensors, saturating at `u64::MAX`.
    pub fn num_params(&self) -> u64 {
        self.tensors
            .iter()
            .fold(0u64, |acc, t| acc.saturating_add(t.element_count()))
    }

    /// Writes the checkpoint to `path`, creating the file or truncating an existing one.
    ///
    /// # Errors
    /// Any error from [`Checkpoint::write_to`], plus I/O errors from creating the file or
    /// flushing the final buffered bytes.
    pub fn save(&self, path: &Path) -> Result<(), CheckpointError> {
        let mut writer = BufWriter::new(File::create(path)?);
        self.write_to(&mut writer)?;
        // Dropping a BufWriter flushes it but discards any error; flush explicitly so a failed
        // final write (e.g. disk full) is reported instead of leaving a truncated file behind.
        writer.flush()?;
        Ok(())
    }

    /// Reads a checkpoint from the file at `path`.
    ///
    /// # Errors
    /// I/O errors from opening the file, plus any error from [`Checkpoint::read_from`].
    pub fn load(path: &Path) -> Result<Self, CheckpointError> {
        let file = File::open(path)?;
        let mut reader = BufReader::new(file);
        Self::read_from(&mut reader)
    }

    /// Serializes the checkpoint in the `OXCK` format. All integers are little-endian.
    ///
    /// Layout:
    /// - magic `OXCK`, version `u32` (always 1), flags `u64` (always 0), tensor count `u64`;
    /// - metadata byte length `u32` followed by the metadata JSON bytes;
    /// - per tensor: name length `u32` + UTF-8 name, rank `u32` + one `u64` per dimension,
    ///   element count `u64` + that many `f32` values.
    ///
    /// Names, ranks and the metadata length are validated before anything is written, so a
    /// validation error leaves `writer` untouched.
    ///
    /// # Errors
    /// - `NameTooLong` if a tensor name exceeds 65535 bytes.
    /// - `Io` (kind `InvalidInput`) if the metadata length or a tensor rank does not fit in a `u32`.
    /// - `Io` for any write failure; bytes already written are not rolled back.
    pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<(), CheckpointError> {
        let meta_str = serialize_metadata(&self.metadata);
        let meta_bytes = meta_str.as_bytes();
        let meta_len = Self::len_to_u32(meta_bytes.len(), "metadata")?;
        for tensor in &self.tensors {
            if tensor.name.len() > Self::MAX_NAME_LEN {
                return Err(CheckpointError::NameTooLong(tensor.name.len()));
            }
            Self::len_to_u32(tensor.shape.len(), "tensor rank")?;
        }

        writer.write_all(b"OXCK")?;
        // Version 1 is the only version `read_from` accepts, so it is written regardless of
        // `self.version`.
        write_u32_le(writer, 1)?;
        // Flags word: always zero; `read_from` ignores it.
        write_u64_le(writer, 0)?;
        write_u64_le(writer, self.tensors.len() as u64)?;
        write_u32_le(writer, meta_len)?;
        writer.write_all(meta_bytes)?;

        // Reused scratch buffer so tensor data goes out in large writes, not one per float.
        let mut scratch: Vec<u8> = Vec::new();
        for tensor in &self.tensors {
            let name_bytes = tensor.name.as_bytes();
            // Both lengths were validated above, so these casts cannot truncate.
            write_u32_le(writer, name_bytes.len() as u32)?;
            writer.write_all(name_bytes)?;
            write_u32_le(writer, tensor.shape.len() as u32)?;
            for &dim in &tensor.shape {
                write_u64_le(writer, dim)?;
            }
            write_u64_le(writer, tensor.data.len() as u64)?;
            for chunk in tensor.data.chunks(Self::IO_CHUNK_FLOATS) {
                scratch.clear();
                for value in chunk {
                    scratch.extend_from_slice(&value.to_le_bytes());
                }
                writer.write_all(&scratch)?;
            }
        }
        Ok(())
    }

    /// Parses a checkpoint in the format produced by [`Checkpoint::write_to`].
    ///
    /// Counts and lengths are read from the input and treated as untrusted. Buffers grow in
    /// bounded steps as data actually arrives, so a corrupt header yields `TruncatedData`
    /// rather than a huge up-front allocation.
    ///
    /// # Errors
    /// - `InvalidMagic` / `UnsupportedVersion` for a bad header.
    /// - `MetadataParse` if the metadata or a tensor name is not valid UTF-8, or the metadata
    ///   JSON is malformed.
    /// - `TruncatedData` if the input ends early. For byte strings, `expected`/`got` cover the
    ///   whole string; for tensor data, they cover that tensor's whole `f32` payload.
    /// - `Io` for read failures, or (kind `InvalidData`) a `u64` count that does not fit in `usize`.
    pub fn read_from<R: Read>(reader: &mut R) -> Result<Self, CheckpointError> {
        let mut magic = [0u8; 4];
        read_exact(reader, &mut magic)?;
        if &magic != b"OXCK" {
            return Err(CheckpointError::InvalidMagic(magic.to_vec()));
        }
        let version = read_u32_le(reader)?;
        if version != 1 {
            return Err(CheckpointError::UnsupportedVersion(version));
        }
        // Flags word: currently unused.
        let _flags = read_u64_le(reader)?;
        let num_tensors = Self::count_to_usize(read_u64_le(reader)?, "tensor count")?;
        let meta_len = read_u32_le(reader)? as usize;
        let meta_bytes = Self::read_bytes(reader, meta_len)?;
        let meta_str = std::str::from_utf8(&meta_bytes)
            .map_err(|e| CheckpointError::MetadataParse(e.to_string()))?;
        let metadata = deserialize_metadata(meta_str)?;

        let mut tensors = Vec::with_capacity(num_tensors.min(Self::MAX_PREALLOC_ITEMS));
        for _ in 0..num_tensors {
            let name_len = read_u32_le(reader)? as usize;
            let name_bytes = Self::read_bytes(reader, name_len)?;
            let name = String::from_utf8(name_bytes)
                .map_err(|e| CheckpointError::MetadataParse(e.to_string()))?;
            let ndim = read_u32_le(reader)? as usize;
            let mut shape = Vec::with_capacity(ndim.min(Self::MAX_PREALLOC_ITEMS));
            for _ in 0..ndim {
                shape.push(read_u64_le(reader)?);
            }
            let data_len = Self::count_to_usize(read_u64_le(reader)?, "tensor data length")?;
            let data = Self::read_f32s(reader, data_len)?;
            tensors.push(CheckpointTensor { name, shape, data });
        }
        Ok(Self {
            version,
            metadata,
            tensors,
        })
    }

    /// Converts an in-memory length to the `u32` used by the on-disk format.
    fn len_to_u32(len: usize, what: &str) -> Result<u32, CheckpointError> {
        u32::try_from(len).map_err(|_| {
            CheckpointError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("{what} length {len} does not fit in u32"),
            ))
        })
    }

    /// Converts a `u64` count read from the stream to `usize`, rejecting it instead of
    /// silently truncating when it does not fit on this platform.
    fn count_to_usize(count: u64, what: &str) -> Result<usize, CheckpointError> {
        usize::try_from(count).map_err(|_| {
            CheckpointError::Io(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                format!("{what} {count} does not fit in usize"),
            ))
        })
    }

    /// Reads exactly `len` bytes, growing the buffer at most `READ_CHUNK_BYTES` at a time.
    fn read_bytes<R: Read>(reader: &mut R, len: usize) -> Result<Vec<u8>, CheckpointError> {
        let mut buf = Vec::with_capacity(len.min(Self::READ_CHUNK_BYTES));
        while buf.len() < len {
            let start = buf.len();
            buf.resize(start + (len - start).min(Self::READ_CHUNK_BYTES), 0);
            read_exact(reader, &mut buf[start..]).map_err(|err| match err {
                // Report truncation relative to the whole string, not the current step.
                CheckpointError::TruncatedData { got, .. } => CheckpointError::TruncatedData {
                    expected: len,
                    got: start + got,
                },
                other => other,
            })?;
        }
        Ok(buf)
    }

    /// Reads `count` little-endian `f32` values in chunks of `IO_CHUNK_FLOATS`.
    fn read_f32s<R: Read>(reader: &mut R, count: usize) -> Result<Vec<f32>, CheckpointError> {
        let mut data = Vec::with_capacity(count.min(Self::MAX_PREALLOC_ITEMS));
        let mut bytes = vec![0u8; count.min(Self::IO_CHUNK_FLOATS) * 4];
        while data.len() < count {
            let n = (count - data.len()).min(Self::IO_CHUNK_FLOATS);
            let done_bytes = data.len() * 4;
            let chunk = &mut bytes[..n * 4];
            read_exact(reader, chunk).map_err(|err| match err {
                CheckpointError::TruncatedData { got, .. } => CheckpointError::TruncatedData {
                    expected: count.saturating_mul(4),
                    got: done_bytes + got,
                },
                other => other,
            })?;
            data.extend(
                chunk
                    .chunks_exact(4)
                    .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
            );
        }
        Ok(data)
    }
}
impl Default for Checkpoint {
fn default() -> Self {
Self::new()
}
}
/// Renders metadata as a single-line JSON object of string pairs.
///
/// Entries are emitted in ascending key order, so equal metadata always serializes to
/// identical bytes. Keys and values are escaped with `push_escaped`.
fn serialize_metadata(meta: &CheckpointMetadata) -> String {
    let mut entries: Vec<_> = meta.iter().collect();
    // Map keys are unique, so an unstable sort yields exactly the stable-sort order.
    entries.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));

    let mut json = String::from("{");
    let mut separator = "";
    for (key, value) in entries {
        json.push_str(separator);
        separator = ",";
        json.push('"');
        push_escaped(&mut json, key);
        json.push_str("\":\"");
        push_escaped(&mut json, value);
        json.push('"');
    }
    json.push('}');
    json
}
/// Appends `s` to `out`, escaped for use inside a double-quoted JSON string.
///
/// Escapes `"` and `\`, plus newline, carriage return and tab as `\n`, `\r`, `\t`. Raw
/// control characters are not valid inside JSON strings, and `parse_json_string` decodes
/// all three escapes, so round-tripping is preserved. Other characters are emitted unchanged.
fn push_escaped(out: &mut String, s: &str) {
    for ch in s.chars() {
        match ch {
            '"' => out.push_str("\\\""),
            '\\' => out.push_str("\\\\"),
            '\n' => out.push_str("\\n"),
            '\r' => out.push_str("\\r"),
            '\t' => out.push_str("\\t"),
            other => out.push(other),
        }
    }
}
/// Parses the flat `{"key":"value",...}` object produced by `serialize_metadata`.
///
/// A blank string, `{}`, or braces containing only whitespace give empty metadata. The
/// scanner is lenient about separators: any run of commas and whitespace between entries is
/// skipped, so commas are optional and trailing commas are accepted. Keys and values must
/// both be quoted strings; a repeated key keeps its last value.
///
/// # Errors
/// `CheckpointError::MetadataParse` when the text is not brace-delimited, an entry does not
/// start with `"`, a `:` is missing after a key, a value is not a string, or a string is
/// malformed.
fn deserialize_metadata(s: &str) -> Result<CheckpointMetadata, CheckpointError> {
    let text = s.trim();
    if text.is_empty() || text == "{}" {
        return Ok(CheckpointMetadata::new());
    }
    let body = match text.strip_prefix('{').and_then(|rest| rest.strip_suffix('}')) {
        Some(body) => body,
        None => {
            return Err(CheckpointError::MetadataParse(format!(
                "expected JSON object, got: {text}"
            )))
        }
    };

    let mut map = CheckpointMetadata::new();
    if body.trim().is_empty() {
        return Ok(map);
    }

    let chars: Vec<char> = body.chars().collect();
    let mut cursor = 0usize;
    loop {
        // Separators between entries: any mix of commas and whitespace.
        while matches!(chars.get(cursor), Some(&c) if c == ',' || c.is_whitespace()) {
            cursor += 1;
        }
        let opening = match chars.get(cursor) {
            Some(&c) => c,
            None => break,
        };
        if opening != '"' {
            return Err(CheckpointError::MetadataParse(format!(
                "expected '\"' at position {cursor}, got '{}'",
                opening
            )));
        }

        let (key, after_key) = parse_json_string(&chars, cursor + 1)?;
        cursor = after_key;
        skip_ws(&chars, &mut cursor);
        if chars.get(cursor) != Some(&':') {
            return Err(CheckpointError::MetadataParse(format!(
                "expected ':' after key '{key}'"
            )));
        }
        cursor += 1;
        skip_ws(&chars, &mut cursor);
        if chars.get(cursor) != Some(&'"') {
            return Err(CheckpointError::MetadataParse(format!(
                "expected '\"' for value of key '{key}'"
            )));
        }

        let (value, after_value) = parse_json_string(&chars, cursor + 1)?;
        cursor = after_value;
        map.insert(key, value);
    }
    Ok(map)
}
/// Decodes the body of a JSON string that starts at `pos` (the index just after the opening
/// `"`). Returns the decoded text plus the index just past the closing `"`.
///
/// Supports every JSON escape: `\"`, `\\`, `\/`, `\b`, `\f`, `\n`, `\r`, `\t` and `\uXXXX`,
/// including UTF-16 surrogate pairs. Metadata written by standard JSON tools therefore loads.
/// Unescaped characters, including raw control characters, are copied through as-is.
///
/// # Errors
/// `CheckpointError::MetadataParse` on an unterminated string, a dangling backslash, an
/// unknown escape, or a malformed or unpaired `\u` escape.
fn parse_json_string(chars: &[char], mut pos: usize) -> Result<(String, usize), CheckpointError> {
    let mut s = String::new();
    while pos < chars.len() {
        let ch = chars[pos];
        pos += 1;
        match ch {
            '"' => return Ok((s, pos)),
            '\\' => {
                if pos >= chars.len() {
                    return Err(CheckpointError::MetadataParse(
                        "unexpected end after backslash".into(),
                    ));
                }
                let escape = chars[pos];
                pos += 1;
                match escape {
                    '"' => s.push('"'),
                    '\\' => s.push('\\'),
                    '/' => s.push('/'),
                    'b' => s.push('\u{8}'),
                    'f' => s.push('\u{c}'),
                    'n' => s.push('\n'),
                    'r' => s.push('\r'),
                    't' => s.push('\t'),
                    'u' => {
                        let (decoded, next) = parse_unicode_escape(chars, pos)?;
                        s.push(decoded);
                        pos = next;
                    }
                    other => {
                        return Err(CheckpointError::MetadataParse(format!(
                            "unknown escape '\\{other}'"
                        )))
                    }
                }
            }
            other => s.push(other),
        }
    }
    Err(CheckpointError::MetadataParse("unterminated string".into()))
}

/// Decodes a `\uXXXX` escape whose four hex digits start at `pos`. When the first code unit
/// is a UTF-16 high surrogate, the following `\uXXXX` low surrogate is consumed as well.
/// Returns the decoded `char` and the index just past the consumed digits.
fn parse_unicode_escape(chars: &[char], pos: usize) -> Result<(char, usize), CheckpointError> {
    let first = parse_hex4(chars, pos)?;
    let mut next = pos + 4;
    let code = if (0xD800..0xDC00).contains(&first) {
        let has_pair = chars.get(next) == Some(&'\\') && chars.get(next + 1) == Some(&'u');
        let second = if has_pair { parse_hex4(chars, next + 2)? } else { 0 };
        if !(0xDC00..0xE000).contains(&second) {
            return Err(CheckpointError::MetadataParse(format!(
                "unpaired surrogate '\\u{first:04X}'"
            )));
        }
        next += 6;
        0x10000 + ((first - 0xD800) << 10) + (second - 0xDC00)
    } else {
        first
    };
    // Rejects lone low surrogates, which are not valid scalar values.
    char::from_u32(code).map(|c| (c, next)).ok_or_else(|| {
        CheckpointError::MetadataParse(format!("invalid unicode escape '\\u{first:04X}'"))
    })
}

/// Parses exactly four hexadecimal digits starting at `pos`.
fn parse_hex4(chars: &[char], pos: usize) -> Result<u32, CheckpointError> {
    chars
        .get(pos..pos + 4)
        .and_then(|digits| {
            digits
                .iter()
                .try_fold(0u32, |acc, c| c.to_digit(16).map(|d| acc * 16 + d))
        })
        .ok_or_else(|| {
            CheckpointError::MetadataParse("invalid '\\u' escape: expected 4 hex digits".into())
        })
}
/// Advances `*pos` past a run of whitespace in `chars`. It stays unchanged when it already
/// points at a non-whitespace character or at/after the end of the slice.
fn skip_ws(chars: &[char], pos: &mut usize) {
    let run = chars
        .get(*pos..)
        .map_or(0, |rest| rest.iter().take_while(|c| c.is_whitespace()).count());
    *pos += run;
}
fn write_u32_le<W: Write>(w: &mut W, v: u32) -> Result<(), CheckpointError> {
w.write_all(&v.to_le_bytes())?;
Ok(())
}
fn write_u64_le<W: Write>(w: &mut W, v: u64) -> Result<(), CheckpointError> {
w.write_all(&v.to_le_bytes())?;
Ok(())
}
/// Fills `buf` completely from `r`, retrying reads that fail with `ErrorKind::Interrupted`.
///
/// Unlike `std::io::Read::read_exact`, running out of input is reported as
/// `CheckpointError::TruncatedData` carrying the requested byte count and the count actually
/// read. Any other read error is returned as `CheckpointError::Io`.
fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<(), CheckpointError> {
    let wanted = buf.len();
    let mut filled = 0usize;
    while filled < wanted {
        let n = match r.read(&mut buf[filled..]) {
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(CheckpointError::Io(e)),
            Ok(0) => {
                return Err(CheckpointError::TruncatedData {
                    expected: wanted,
                    got: filled,
                })
            }
            Ok(n) => n,
        };
        filled += n;
    }
    Ok(())
}
/// Reads 4 bytes from `r` and decodes them as a little-endian `u32`.
fn read_u32_le<R: Read>(r: &mut R) -> Result<u32, CheckpointError> {
    let mut raw = [0u8; std::mem::size_of::<u32>()];
    read_exact(r, &mut raw).map(|()| u32::from_le_bytes(raw))
}
/// Reads 8 bytes from `r` and decodes them as a little-endian `u64`.
fn read_u64_le<R: Read>(r: &mut R) -> Result<u64, CheckpointError> {
    let mut raw = [0u8; std::mem::size_of::<u64>()];
    read_exact(r, &mut raw).map(|()| u64::from_le_bytes(raw))
}
/// Errors produced while saving, loading, or (de)serializing a [`Checkpoint`].
#[derive(Debug, thiserror::Error)]
pub enum CheckpointError {
/// Underlying I/O failure from the reader, writer, or file system.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// The input did not start with the 4-byte `OXCK` magic; holds the bytes actually read.
#[error("invalid magic bytes: expected OXCK, got {0:?}")]
InvalidMagic(Vec<u8>),
/// The header's version field was not 1, the only version `read_from` accepts.
#[error("unsupported checkpoint version: {0}")]
UnsupportedVersion(u32),
/// Metadata was not valid UTF-8 or not a parseable flat JSON object of strings.
/// `read_from` also uses this variant when a tensor name is not valid UTF-8.
#[error("metadata parse error: {0}")]
MetadataParse(String),
/// The input ended early: `expected` bytes were requested but only `got` were available.
#[error("truncated data: expected {expected} bytes, got {got}")]
TruncatedData { expected: usize, got: usize },
/// `write_to` was given a tensor name longer than 65535 bytes; holds the name's byte length.
#[error("tensor name too long: {0} bytes (max 65535)")]
NameTooLong(usize),
}