use cjc_runtime::Value;
use crate::{snap, snap_v2, restore_v2};
/// File signature at the start of every snapshot file: the ASCII bytes `"CJCS"`.
pub const MAGIC: [u8; 4] = *b"CJCS";
/// Version tag written by [`snap_save`].
pub const VERSION_V1: u32 = 1;
/// Version tag written by [`snap_save_v2`].
pub const VERSION_V2: u32 = 2;
/// Alias for [`VERSION_V1`].
/// NOTE(review): presumably kept as the "default" version for older callers — confirm.
pub const VERSION: u32 = VERSION_V1;
/// Fixed header length in bytes:
/// magic (4) + version `u32` (4) + content hash (32) + payload length `u64` (8).
const HEADER_SIZE: usize = 4 + 4 + 32 + 8;
/// Encode `value` with `snap` and write it to `path` as a snapshot file tagged
/// [`VERSION_V1`].
///
/// # Errors
/// Returns the `snap_save:`-prefixed message from `write_snap_file` if the file
/// cannot be written.
pub fn snap_save(value: &Value, path: &str) -> Result<(), String> {
    let encoded = snap(value);
    let version = VERSION_V1;
    write_snap_file(&encoded.content_hash, &encoded.data, version, path)
}
/// Encode `value` with `snap_v2` and write it to `path` as a snapshot file tagged
/// [`VERSION_V2`].
///
/// # Errors
/// Returns the `snap_save:`-prefixed message from `write_snap_file` if the file
/// cannot be written.
pub fn snap_save_v2(value: &Value, path: &str) -> Result<(), String> {
    let encoded = snap_v2(value);
    let version = VERSION_V2;
    write_snap_file(&encoded.content_hash, &encoded.data, version, path)
}
fn write_snap_file(
content_hash: &[u8; 32],
data: &[u8],
version: u32,
path: &str,
) -> Result<(), String> {
let data_len = data.len() as u64;
let mut file_bytes = Vec::with_capacity(HEADER_SIZE + data.len());
file_bytes.extend_from_slice(&MAGIC);
file_bytes.extend_from_slice(&version.to_le_bytes());
file_bytes.extend_from_slice(content_hash);
file_bytes.extend_from_slice(&data_len.to_le_bytes());
file_bytes.extend_from_slice(data);
std::fs::write(path, &file_bytes)
.map_err(|e| format!("snap_save: {}", e))
}
/// Read a snapshot file written by [`snap_save`] or [`snap_save_v2`] and restore the
/// stored value.
///
/// The header is validated before the payload is handed to `restore_v2`, which is
/// used for both [`VERSION_V1`] and [`VERSION_V2`] files. Validation covers:
/// - the file size
/// - the magic bytes
/// - the version
/// - the declared payload length
///
/// Bytes after the declared payload are ignored.
///
/// # Errors
/// Returns a `snap_load:`-prefixed message when:
/// - the file cannot be read;
/// - the file is shorter than the header;
/// - the magic bytes or version are wrong;
/// - the header declares more payload bytes than the file contains;
/// - `restore_v2` rejects the payload.
///
/// Malformed input never panics.
pub fn snap_load(path: &str) -> Result<Value, String> {
    let file_bytes = std::fs::read(path).map_err(|e| format!("snap_load: {}", e))?;
    if file_bytes.len() < HEADER_SIZE {
        return Err(format!(
            "snap_load: file too small ({} bytes, need at least {})",
            file_bytes.len(),
            HEADER_SIZE
        ));
    }
    let (header, body) = file_bytes.split_at(HEADER_SIZE);

    if header[0..4] != MAGIC {
        return Err(format!(
            "snap_load: invalid magic bytes {:02x}{:02x}{:02x}{:02x} (expected CJCS)",
            header[0], header[1], header[2], header[3]
        ));
    }

    let version = u32::from_le_bytes(header[4..8].try_into().expect("header slice is 4 bytes"));
    if version != VERSION_V1 && version != VERSION_V2 {
        return Err(format!(
            "snap_load: unsupported version {} (expected {} or {})",
            version, VERSION_V1, VERSION_V2
        ));
    }

    let mut content_hash = [0u8; 32];
    content_hash.copy_from_slice(&header[8..40]);

    // The length field is untrusted. Compare it in u64 space against the bytes actually
    // present: casting it to usize first could truncate on 32-bit targets, and adding
    // HEADER_SIZE to it could overflow. Either would lead to a panic instead of an Err.
    let declared_len =
        u64::from_le_bytes(header[40..48].try_into().expect("header slice is 8 bytes"));
    if declared_len > body.len() as u64 {
        return Err(format!(
            "snap_load: truncated file (header says {} data bytes, file has {})",
            declared_len,
            body.len()
        ));
    }
    // Lossless: declared_len <= body.len(), which is a usize.
    let data = body[..declared_len as usize].to_vec();

    let blob = crate::SnapBlob { content_hash, data };
    restore_v2(&blob).map_err(|e| format!("snap_load: {}", e))
}
#[cfg(test)]
mod tests {
    use super::*;
    use cjc_runtime::{Tensor, SparseCsr};
    use std::rc::Rc;

    /// Scratch snapshot path under the system temp directory, deleted on drop.
    ///
    /// Cleanup runs in `Drop`, so the file is removed even when an assertion in the
    /// test panics. The process id in the name keeps concurrent test runs apart.
    struct TestFile(String);

    impl TestFile {
        fn new(name: &str) -> Self {
            let path = std::env::temp_dir().join(format!(
                "__test_persist_{}_{}.snap",
                std::process::id(),
                name
            ));
            TestFile(path.to_string_lossy().into_owned())
        }

        fn path(&self) -> &str {
            &self.0
        }
    }

    impl Drop for TestFile {
        fn drop(&mut self) {
            // Best-effort: the file may never have been created (e.g. missing-file test).
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn test_save_load_int() {
        let file = TestFile::new("int");
        snap_save(&Value::Int(42), file.path()).unwrap();
        let loaded = snap_load(file.path()).unwrap();
        assert!(matches!(loaded, Value::Int(42)));
    }

    #[test]
    fn test_save_load_string() {
        let file = TestFile::new("string");
        snap_save(&Value::String(Rc::new("hello CJC".into())), file.path()).unwrap();
        let loaded = snap_load(file.path()).unwrap();
        match loaded {
            Value::String(s) => assert_eq!(s.as_str(), "hello CJC"),
            _ => panic!("expected String"),
        }
    }

    #[test]
    fn test_save_load_tensor() {
        let file = TestFile::new("tensor");
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        snap_save(&Value::Tensor(t), file.path()).unwrap();
        let loaded = snap_load(file.path()).unwrap();
        match loaded {
            Value::Tensor(t) => {
                assert_eq!(t.shape(), &[2, 3]);
                assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
            }
            _ => panic!("expected Tensor"),
        }
    }

    #[test]
    fn test_save_load_v2_int() {
        let file = TestFile::new("v2_int");
        snap_save_v2(&Value::Int(99), file.path()).unwrap();
        let loaded = snap_load(file.path()).unwrap();
        assert!(matches!(loaded, Value::Int(99)));
    }

    #[test]
    fn test_save_load_v2_tensor() {
        let file = TestFile::new("v2_tensor");
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        snap_save_v2(&Value::Tensor(t), file.path()).unwrap();
        let loaded = snap_load(file.path()).unwrap();
        match loaded {
            Value::Tensor(t) => {
                assert_eq!(t.shape(), &[3]);
                assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("expected Tensor"),
        }
    }

    #[test]
    fn test_save_load_v2_sparse() {
        let file = TestFile::new("v2_sparse");
        let sparse = SparseCsr {
            nrows: 2,
            ncols: 3,
            row_offsets: vec![0, 1, 3],
            col_indices: vec![0, 1, 2],
            values: vec![1.0, 2.0, 3.0],
        };
        snap_save_v2(&Value::SparseTensor(sparse), file.path()).unwrap();
        let loaded = snap_load(file.path()).unwrap();
        match loaded {
            Value::SparseTensor(s) => {
                assert_eq!(s.nrows, 2);
                assert_eq!(s.ncols, 3);
                assert_eq!(s.values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("expected SparseTensor"),
        }
    }

    #[test]
    fn test_bad_magic() {
        let file = TestFile::new("bad_magic");
        let mut bytes = vec![0u8; HEADER_SIZE];
        bytes[0..4].copy_from_slice(b"XXXX");
        std::fs::write(file.path(), &bytes).unwrap();
        let result = snap_load(file.path());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("invalid magic"));
    }

    #[test]
    fn test_truncated_file() {
        let file = TestFile::new("truncated");
        std::fs::write(file.path(), b"CJC").unwrap();
        let result = snap_load(file.path());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("too small"));
    }

    #[test]
    fn test_bad_version() {
        let file = TestFile::new("bad_version");
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&MAGIC);
        bytes.extend_from_slice(&99u32.to_le_bytes());
        // Pad the rest of the header with zeros.
        bytes.resize(HEADER_SIZE, 0);
        std::fs::write(file.path(), &bytes).unwrap();
        let result = snap_load(file.path());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("unsupported version"));
    }

    #[test]
    fn test_missing_file() {
        // Never written, so the path does not exist.
        let file = TestFile::new("missing");
        let result = snap_load(file.path());
        assert!(result.is_err());
    }

    #[test]
    fn test_roundtrip_array() {
        let file = TestFile::new("array");
        let val = Value::Array(Rc::new(vec![
            Value::Int(1),
            Value::Float(2.5),
            Value::Bool(true),
        ]));
        snap_save(&val, file.path()).unwrap();
        let loaded = snap_load(file.path()).unwrap();
        match loaded {
            Value::Array(items) => match items.as_slice() {
                [Value::Int(1), Value::Float(f), Value::Bool(true)] => assert_eq!(*f, 2.5),
                _ => panic!("array contents did not round-trip"),
            },
            _ => panic!("expected Array"),
        }
    }
}