use crate::transformer::{QCTConfig, QCT};
use std::io::{Read, Write};
/// Four-byte tag written at the very start of every serialized checkpoint;
/// the loaders reject input that does not begin with it.
const MAGIC: &[u8; 4] = b"QTSR";
/// Format version, stored as a little-endian `u32` immediately after
/// [`MAGIC`]; the loaders reject any other value.
const VERSION: u32 = 1;
/// Computes the BLAKE3 digest of all model parameters.
///
/// Every value returned by `model.all_params()` is fed to the hasher, in
/// order, as its little-endian byte representation. This is the same digest
/// that the checkpoint writers append to the serialized form and that the
/// loaders recompute to detect corruption.
pub fn model_digest(model: &QCT) -> [u8; 32] {
    let params = model.all_params();
    let mut hasher = blake3::Hasher::new();
    for &p in &params {
        hasher.update(&p.to_le_bytes());
    }
    *hasher.finalize().as_bytes()
}
/// Writes `model` to `path` as a checkpoint file.
///
/// Layout (all integers little-endian `u32`):
/// magic `QTSR`, version, `dim`, `num_blocks`, `vocab_size`, parameter count,
/// the parameters as little-endian 4-byte values, then the 32-byte BLAKE3
/// digest produced by [`model_digest`].
///
/// # Errors
///
/// Returns [`std::io::ErrorKind::InvalidInput`] if a header field does not
/// fit in a `u32` (the format cannot represent it), and otherwise any I/O
/// error raised while creating, writing or flushing the file.
pub fn save_checkpoint(model: &QCT, path: &std::path::Path) -> std::io::Result<()> {
    // Encodes one header field, refusing values the 32-bit field cannot hold
    // instead of silently truncating them.
    fn header_field(value: usize, name: &str) -> std::io::Result<[u8; 4]> {
        u32::try_from(value).map(u32::to_le_bytes).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                format!("{name} ({value}) does not fit in a u32 checkpoint header field"),
            )
        })
    }

    let params = model.all_params();
    let digest = model_digest(model);

    // Validate the whole header before touching the filesystem so that an
    // unrepresentable model never truncates an existing checkpoint.
    let dim = header_field(model.config.dim, "dim")?;
    let num_blocks = header_field(model.config.num_blocks, "num_blocks")?;
    let vocab_size = header_field(model.config.vocab_size, "vocab_size")?;
    let param_count = header_field(params.len(), "parameter count")?;

    // Buffer the writes: the parameter loop emits 4 bytes at a time, which
    // would otherwise be one write call per parameter.
    let mut out = std::io::BufWriter::new(std::fs::File::create(path)?);
    out.write_all(MAGIC)?;
    out.write_all(&VERSION.to_le_bytes())?;
    out.write_all(&dim)?;
    out.write_all(&num_blocks)?;
    out.write_all(&vocab_size)?;
    out.write_all(&param_count)?;
    for &p in &params {
        out.write_all(&p.to_le_bytes())?;
    }
    out.write_all(&digest)?;
    // Flush explicitly: dropping a `BufWriter` discards flush errors.
    out.flush()
}
/// Reads a checkpoint file from `path` and returns its configuration and
/// parameter values.
///
/// The expected layout is: magic `QTSR`, then little-endian `u32` fields for
/// version, `dim`, `num_blocks`, `vocab_size` and parameter count, then that
/// many little-endian `f32` values, then a 32-byte BLAKE3 digest of the
/// parameters. Bytes after the digest are ignored.
///
/// The returned [`QCTConfig`] always has `seed: 0`, because the seed is not
/// part of the file.
///
/// # Errors
///
/// * [`std::io::ErrorKind::InvalidData`] if the magic or version is wrong, or
///   if the stored digest does not match the parameters read.
/// * Any I/O error from opening or reading the file, including
///   [`std::io::ErrorKind::UnexpectedEof`] when the file is shorter than its
///   header claims.
pub fn load_checkpoint(path: &std::path::Path) -> std::io::Result<(QCTConfig, Vec<f32>)> {
    use std::io::{Error, ErrorKind};

    // Upper bound on the capacity reserved up front. The parameter count is
    // read from the file before anything has been verified, so it must not be
    // able to trigger a multi-gigabyte allocation on its own; the vector still
    // grows past this bound if the data is really there.
    const MAX_PREALLOCATED_PARAMS: usize = 1 << 20;

    // Reads one little-endian `u32` header field.
    fn read_u32(reader: &mut impl Read) -> std::io::Result<u32> {
        let mut word = [0u8; 4];
        reader.read_exact(&mut word)?;
        Ok(u32::from_le_bytes(word))
    }

    // Buffer the reads: parameters are consumed 4 bytes at a time.
    let mut reader = std::io::BufReader::new(std::fs::File::open(path)?);

    let mut magic = [0u8; 4];
    reader.read_exact(&mut magic)?;
    if &magic != MAGIC {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "not a .quantumtensor checkpoint (expected magic QTSR)",
        ));
    }

    let version = read_u32(&mut reader)?;
    if version != VERSION {
        return Err(Error::new(
            ErrorKind::InvalidData,
            format!("unsupported version {version}"),
        ));
    }

    let dim = read_u32(&mut reader)? as usize;
    let blocks = read_u32(&mut reader)? as usize;
    let vocab = read_u32(&mut reader)? as usize;
    let param_count = read_u32(&mut reader)? as usize;

    // Decode and hash in a single pass. Each value is hashed via its own
    // little-endian bytes, matching how `model_digest` hashes parameters.
    let mut params = Vec::with_capacity(param_count.min(MAX_PREALLOCATED_PARAMS));
    let mut hasher = blake3::Hasher::new();
    let mut word = [0u8; 4];
    for _ in 0..param_count {
        reader.read_exact(&mut word)?;
        let value = f32::from_le_bytes(word);
        hasher.update(&value.to_le_bytes());
        params.push(value);
    }

    let mut stored_digest = [0u8; 32];
    reader.read_exact(&mut stored_digest)?;
    let computed = *hasher.finalize().as_bytes();
    if stored_digest != computed {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "BLAKE3 digest mismatch — checkpoint corrupted or tampered",
        ));
    }

    let config = QCTConfig {
        vocab_size: vocab,
        dim,
        num_blocks: blocks,
        // The seed is not stored in the checkpoint.
        seed: 0,
    };
    Ok((config, params))
}
/// Serializes `model` into an in-memory checkpoint.
///
/// The bytes use the same layout as the on-disk checkpoint: magic `QTSR`,
/// then little-endian `u32` fields for version, `dim`, `num_blocks`,
/// `vocab_size` and parameter count, then the parameters as little-endian
/// 4-byte values, then the 32-byte BLAKE3 digest from [`model_digest`].
pub fn save_checkpoint_to_vec(model: &QCT) -> Vec<u8> {
    // magic + version + dim + num_blocks + vocab_size + param count, 4 bytes each.
    const HEADER_LEN: usize = 6 * 4;
    const DIGEST_LEN: usize = 32;

    let params = model.all_params();
    let digest = model_digest(model);

    let mut buf = Vec::with_capacity(HEADER_LEN + params.len() * 4 + DIGEST_LEN);
    buf.extend_from_slice(MAGIC);
    buf.extend_from_slice(&VERSION.to_le_bytes());
    // NOTE(review): these `as u32` casts truncate values above `u32::MAX`.
    // This function has no error channel, so the casts are kept as they were.
    buf.extend_from_slice(&(model.config.dim as u32).to_le_bytes());
    buf.extend_from_slice(&(model.config.num_blocks as u32).to_le_bytes());
    buf.extend_from_slice(&(model.config.vocab_size as u32).to_le_bytes());
    buf.extend_from_slice(&(params.len() as u32).to_le_bytes());
    for &p in &params {
        buf.extend_from_slice(&p.to_le_bytes());
    }
    buf.extend_from_slice(&digest);
    buf
}
/// Parses an in-memory checkpoint and returns its configuration and
/// parameter values.
///
/// The expected layout is: magic `QTSR`, then little-endian `u32` fields for
/// version, `dim`, `num_blocks`, `vocab_size` and parameter count, then that
/// many little-endian `f32` values, then a 32-byte BLAKE3 digest of the
/// parameters. Bytes after the digest are ignored.
///
/// The returned [`QCTConfig`] always has `seed: 0`, because the seed is not
/// part of the serialized form.
///
/// # Errors
///
/// * [`std::io::ErrorKind::InvalidData`] if the magic or version is wrong, or
///   if the stored digest does not match the parameters.
/// * [`std::io::ErrorKind::UnexpectedEof`] if `data` is shorter than its
///   header claims.
pub fn load_checkpoint_from_bytes(data: &[u8]) -> std::io::Result<(QCTConfig, Vec<f32>)> {
    use std::io::{Error, ErrorKind};

    const DIGEST_LEN: usize = 32;

    // Reads one little-endian `u32` header field.
    fn read_u32(reader: &mut impl Read) -> std::io::Result<u32> {
        let mut word = [0u8; 4];
        reader.read_exact(&mut word)?;
        Ok(u32::from_le_bytes(word))
    }

    // `&[u8]` implements `Read` by advancing the slice itself, so after each
    // read `rest` holds exactly the bytes not yet consumed.
    let mut rest: &[u8] = data;

    let mut magic = [0u8; 4];
    rest.read_exact(&mut magic)?;
    if &magic != MAGIC {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "not a .quantumtensor checkpoint (expected magic QTSR)",
        ));
    }

    let version = read_u32(&mut rest)?;
    if version != VERSION {
        return Err(Error::new(
            ErrorKind::InvalidData,
            format!("unsupported version {version}"),
        ));
    }

    let dim = read_u32(&mut rest)? as usize;
    let blocks = read_u32(&mut rest)? as usize;
    let vocab = read_u32(&mut rest)? as usize;
    let param_count = read_u32(&mut rest)? as usize;

    // The parameter count is untrusted: confirm the payload is really present
    // before allocating for it. Checked arithmetic guards against overflow
    // where `usize` is 32 bits wide.
    let needed = param_count
        .checked_mul(4)
        .and_then(|n| n.checked_add(DIGEST_LEN));
    if !matches!(needed, Some(n) if n <= rest.len()) {
        return Err(Error::new(
            ErrorKind::UnexpectedEof,
            "checkpoint truncated: fewer bytes than the header's parameter count requires",
        ));
    }

    let (param_bytes, tail) = rest.split_at(param_count * 4);
    let params: Vec<f32> = param_bytes
        .chunks_exact(4)
        .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
        .collect();

    // Hash each value via its own little-endian bytes, matching how
    // `model_digest` hashes parameters.
    let mut hasher = blake3::Hasher::new();
    for &p in &params {
        hasher.update(&p.to_le_bytes());
    }
    let computed = *hasher.finalize().as_bytes();

    let mut stored_digest = [0u8; DIGEST_LEN];
    stored_digest.copy_from_slice(&tail[..DIGEST_LEN]);
    if stored_digest != computed {
        return Err(Error::new(
            ErrorKind::InvalidData,
            "BLAKE3 digest mismatch — checkpoint corrupted or tampered",
        ));
    }

    let config = QCTConfig {
        vocab_size: vocab,
        dim,
        num_blocks: blocks,
        // The seed is not stored in the checkpoint.
        seed: 0,
    };
    Ok((config, params))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small, fixed configuration shared by every test.
    fn test_config() -> QCTConfig {
        QCTConfig {
            vocab_size: 10,
            dim: 4,
            num_blocks: 1,
            seed: 42,
        }
    }

    /// Saving to a file and loading it back preserves the configuration, the
    /// parameters and the model digest.
    #[test]
    fn checkpoint_roundtrip() {
        let config = test_config();
        let model = QCT::new(config.clone());
        let original_params = model.all_params();
        let original_digest = model_digest(&model);

        // Include the process id so concurrent test runs do not clobber each
        // other's file.
        let path = std::env::temp_dir().join(format!(
            "qsr_test_checkpoint_{}.quantumtensor",
            std::process::id()
        ));
        save_checkpoint(&model, &path).expect("save failed");
        let loaded = load_checkpoint(&path);
        // Clean up before asserting so a failure does not leave the file behind.
        let _ = std::fs::remove_file(&path);
        let (loaded_config, loaded_params) = loaded.expect("load failed");

        assert_eq!(loaded_config.dim, config.dim);
        assert_eq!(loaded_config.num_blocks, config.num_blocks);
        assert_eq!(loaded_config.vocab_size, config.vocab_size);
        assert_eq!(loaded_params.len(), original_params.len());
        for (a, b) in original_params.iter().zip(loaded_params.iter()) {
            assert!((a - b).abs() < 1e-7, "param mismatch");
        }

        let mut model2 = QCT::new(loaded_config);
        model2.set_all_params(&loaded_params);
        assert_eq!(model_digest(&model2), original_digest);
    }

    /// Hashing the same model twice yields the same digest.
    #[test]
    fn digest_deterministic() {
        let model = QCT::new(test_config());
        assert_eq!(model_digest(&model), model_digest(&model));
    }

    /// Serializing to bytes and parsing them back preserves the configuration
    /// and the parameters.
    #[test]
    fn bytes_roundtrip() {
        let config = test_config();
        let model = QCT::new(config.clone());
        let original_params = model.all_params();

        let bytes = save_checkpoint_to_vec(&model);
        assert_eq!(&bytes[..4], b"QTSR", "should start with QTSR magic");

        let (loaded_config, loaded_params) =
            load_checkpoint_from_bytes(&bytes).expect("bytes roundtrip failed");
        assert_eq!(loaded_config.dim, config.dim);
        assert_eq!(loaded_config.num_blocks, config.num_blocks);
        assert_eq!(loaded_config.vocab_size, config.vocab_size);
        assert_eq!(loaded_params.len(), original_params.len());
        for (a, b) in original_params.iter().zip(loaded_params.iter()) {
            assert!((a - b).abs() < 1e-7, "param mismatch in bytes roundtrip");
        }
    }

    /// Flipping a byte inside the parameter region is caught by the digest check.
    #[test]
    fn bytes_corrupt_detected() {
        let model = QCT::new(test_config());
        let mut bytes = save_checkpoint_to_vec(&model);
        // The header is 24 bytes, so offset 30 falls inside the second parameter.
        assert!(
            bytes.len() > 30,
            "checkpoint too small to corrupt at offset 30"
        );
        bytes[30] ^= 0xFF;
        assert!(
            load_checkpoint_from_bytes(&bytes).is_err(),
            "corrupted bytes should fail"
        );
    }

    /// Input that does not start with the magic tag is rejected as invalid data.
    #[test]
    fn bytes_bad_magic_rejected() {
        let model = QCT::new(test_config());
        let mut bytes = save_checkpoint_to_vec(&model);
        bytes[0] ^= 0xFF;
        match load_checkpoint_from_bytes(&bytes) {
            Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::InvalidData),
            Ok(_) => panic!("bad magic should fail"),
        }
    }

    /// Input that ends before the full digest is rejected as an unexpected EOF.
    #[test]
    fn bytes_truncated_rejected() {
        let model = QCT::new(test_config());
        let mut bytes = save_checkpoint_to_vec(&model);
        bytes.pop();
        match load_checkpoint_from_bytes(&bytes) {
            Err(e) => assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof),
            Ok(_) => panic!("truncated bytes should fail"),
        }
    }
}