use std::collections::HashMap;
use std::io::Write;
use crate::nn::checkpoint::{MAGIC, VERSION, HASH_LEN, write_tensor_data, io_err};
use crate::tensor::{Device, Result, Tensor};
use super::Graph;
/// A host-side (CPU) copy of a model's state, detached from the autograd
/// graph, suitable for serialization or for handing off to another thread.
pub struct ModelSnapshot {
    /// Trainable parameter tensors, keyed by qualified name.
    pub params: HashMap<String, Tensor>,
    /// Non-trainable buffer tensors, keyed by qualified name.
    pub buffers: HashMap<String, Tensor>,
    /// Latest scalar metrics recorded on the graph at snapshot time.
    pub metrics: HashMap<String, f64>,
    /// Flush count of the graph when the snapshot was taken.
    pub epoch: usize,
}
impl ModelSnapshot {
    /// Serialize the snapshot in the checkpoint wire format:
    /// `MAGIC`, little-endian `VERSION`, a zero-filled `HASH_LEN`-byte hash
    /// slot, a little-endian `u32` entry count, then one record per tensor
    /// (params first, then buffers): `u32` name length, name bytes, payload
    /// via `write_tensor_data`.
    ///
    /// # Errors
    /// Returns an error if any write to `w` fails.
    pub fn save<W: Write>(&self, w: &mut W) -> Result<()> {
        w.write_all(&MAGIC).map_err(io_err)?;
        w.write_all(&VERSION.to_le_bytes()).map_err(io_err)?;
        // Hash slot is zero-filled; this writer does not compute an
        // integrity hash. NOTE(review): confirm loaders accept a zeroed hash.
        w.write_all(&[0u8; HASH_LEN]).map_err(io_err)?;
        let total = (self.params.len() + self.buffers.len()) as u32;
        w.write_all(&total.to_le_bytes()).map_err(io_err)?;
        // Params and buffers share the same record layout, so write them
        // with one chained loop instead of two duplicated ones. Iteration
        // order follows HashMap order (nondeterministic); records carry
        // their names, presumably matched by name on load.
        for (name, t) in self.params.iter().chain(self.buffers.iter()) {
            let name_bytes = name.as_bytes();
            w.write_all(&(name_bytes.len() as u32).to_le_bytes())
                .map_err(io_err)?;
            w.write_all(name_bytes).map_err(io_err)?;
            write_tensor_data(w, t)?;
        }
        Ok(())
    }

    /// Write the snapshot to `path`. A `.gz` suffix selects gzip
    /// compression; otherwise the raw format is written through a
    /// `BufWriter`.
    ///
    /// # Errors
    /// Returns an error if the file cannot be created or any write,
    /// flush, or compressor finalization fails.
    pub fn save_file(&self, path: &str) -> Result<()> {
        let f = std::fs::File::create(path).map_err(io_err)?;
        if path.ends_with(".gz") {
            let mut w = flate2::write::GzEncoder::new(f, flate2::Compression::default());
            self.save(&mut w)?;
            // finish() drains the compressor and surfaces pending I/O
            // errors; merely dropping the encoder would swallow them.
            w.finish().map_err(io_err)?;
            Ok(())
        } else {
            let mut w = std::io::BufWriter::new(f);
            self.save(&mut w)?;
            // BufWriter's Drop silently discards flush errors, so flush
            // explicitly to report a failed final write to the caller.
            w.flush().map_err(io_err)
        }
    }
}
impl Graph {
    /// Capture a CPU-resident, detached copy of this graph's state:
    /// parameters, buffers, the latest recorded metrics, and the current
    /// flush count (used as the epoch).
    ///
    /// # Errors
    /// Returns an error if any device transfer or detach fails.
    pub fn snapshot_cpu(&self) -> Result<ModelSnapshot> {
        let named_params = self.named_parameters();
        let mut params = HashMap::with_capacity(named_params.len());
        for (key, param) in named_params {
            // Copy to host and detach so the snapshot carries no autograd
            // history and no device-resident storage.
            let host = param.variable.data().to_device(Device::CPU)?.detach()?;
            params.insert(key, host);
        }

        let named_bufs = self.named_buffers();
        let mut buffers = HashMap::with_capacity(named_bufs.len());
        for (key, buf) in named_bufs {
            buffers.insert(key, buf.get().to_device(Device::CPU)?);
        }

        Ok(ModelSnapshot {
            params,
            buffers,
            metrics: self.latest_metrics().into_iter().collect(),
            epoch: self.flush_count(),
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::autograd::Variable;
    use crate::nn::{Linear, Module};
    use crate::graph::FlowBuilder;
    use crate::tensor::{test_device, Tensor, TensorOptions};
    use crate::worker::CpuWorker;

    /// Builds a minimal graph: one 2->3 Linear tagged "encoder", on the
    /// configured test device.
    fn build_test_graph() -> Result<Graph> {
        let g = FlowBuilder::from(Linear::on_device(2, 3, test_device())?)
            .tag("encoder")
            .build()?;
        Ok(g)
    }

    #[test]
    fn snapshot_captures_params_on_cpu() {
        let g = build_test_graph().unwrap();
        let snap = g.snapshot_cpu().unwrap();
        assert!(!snap.params.is_empty(), "should have params");
        // Every captured parameter must land on the CPU regardless of the
        // device the graph was built on.
        for (name, t) in &snap.params {
            assert_eq!(t.device(), Device::CPU, "param {} should be on CPU", name);
        }
    }

    #[test]
    fn snapshot_captures_correct_names() {
        let g = build_test_graph().unwrap();
        let snap = g.snapshot_cpu().unwrap();
        let names: Vec<&String> = snap.params.keys().collect();
        // The FlowBuilder tag should appear in the qualified param names.
        assert!(names.iter().any(|n| n.contains("encoder")),
            "param names should include tag prefix, got: {:?}", names);
    }

    #[test]
    fn snapshot_captures_metrics_and_epoch() {
        let g = build_test_graph().unwrap();
        // Two record/flush cycles: epoch should equal the flush count, and
        // the metrics map should hold the most recently recorded value.
        g.record_scalar("loss", 0.5);
        g.flush(&[]);
        g.record_scalar("loss", 0.3);
        g.flush(&[]);
        let snap = g.snapshot_cpu().unwrap();
        assert_eq!(snap.epoch, 2);
        assert!(snap.metrics.contains_key("loss"));
        assert!((snap.metrics["loss"] - 0.3).abs() < 1e-6);
    }

    #[test]
    fn snapshot_is_send() {
        // Compile-time check: snapshots must be movable across threads.
        fn assert_send<T: Send>() {}
        assert_send::<ModelSnapshot>();
    }

    #[test]
    fn snapshot_save_roundtrip() {
        let g = build_test_graph().unwrap();
        let snap = g.snapshot_cpu().unwrap();
        let mut buf = Vec::new();
        snap.save(&mut buf).unwrap();
        let load_params: Vec<(String, crate::nn::Parameter)> = g.named_parameters();
        let load_buffers: Vec<(String, crate::nn::Buffer)> = g.named_buffers();
        let mut cursor = std::io::Cursor::new(&buf);
        let report = crate::nn::load_checkpoint(
            &mut cursor, &load_params, &load_buffers, None,
        ).unwrap();
        // Everything written must load back, with nothing missing or skipped.
        assert_eq!(report.loaded.len(), snap.params.len() + snap.buffers.len());
        assert!(report.missing.is_empty());
        assert!(report.skipped.is_empty());
    }

    #[test]
    fn snapshot_save_file_gz() {
        let g = build_test_graph().unwrap();
        let snap = g.snapshot_cpu().unwrap();
        // Make the temp path unique per process so concurrent test
        // invocations sharing the temp dir cannot clobber or delete each
        // other's file (the previous fixed name raced).
        let path = std::env::temp_dir()
            .join(format!("test_snapshot_{}.fdl.gz", std::process::id()));
        let path_str = path.to_str().unwrap();
        snap.save_file(path_str).unwrap();
        let meta = std::fs::metadata(path_str).unwrap();
        assert!(meta.len() > 0);
        let load_params: Vec<(String, crate::nn::Parameter)> = g.named_parameters();
        let report = crate::nn::load_checkpoint_file(
            path_str, &load_params, &[], None,
        ).unwrap();
        assert_eq!(report.loaded.len(), snap.params.len());
        // Best-effort cleanup; failure to remove is not a test failure.
        std::fs::remove_file(path_str).ok();
    }

    #[test]
    fn snapshot_in_cpu_worker() {
        let g = build_test_graph().unwrap();
        let x = Variable::new(
            Tensor::randn(&[1, 2], TensorOptions {
                dtype: crate::tensor::DType::Float32,
                device: test_device(),
            }).unwrap(),
            false,
        );
        // Run a forward pass first so the snapshot reflects a used graph.
        let _ = g.forward(&x).unwrap();
        let snap = g.snapshot_cpu().unwrap();
        // Move the snapshot into the worker; the Release store paired with
        // the Acquire load below makes the closure's effects visible here.
        let done = std::sync::Arc::new(std::sync::atomic::AtomicBool::new(false));
        let done2 = done.clone();
        let mut worker = CpuWorker::new();
        worker.submit(move || {
            assert!(!snap.params.is_empty());
            for t in snap.params.values() {
                assert_eq!(t.device(), Device::CPU);
            }
            done2.store(true, std::sync::atomic::Ordering::Release);
        });
        worker.finish();
        assert!(done.load(std::sync::atomic::Ordering::Acquire));
    }
}