use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use crate::model::parameter::Parameter;
use crate::{Error, Result};
/// File signature: the first four bytes `save_checkpoint` writes and `load_checkpoint` checks.
const MAGIC: [u8; 4] = *b"TKP1";
/// Format version, written as a little-endian `u32` right after the magic.
/// `load_checkpoint` rejects files carrying any other value.
const VERSION: u32 = 1;
/// Owned copy of one parameter's buffers and step counter, as stored in a checkpoint.
#[derive(Debug, Clone)]
pub struct ParameterSnapshot {
/// Name the parameter is stored under (serialized as length-prefixed UTF-8).
pub name: String,
/// Parameter values; `snapshot_parameter` copies them from `param.data.data`.
pub weight: Vec<f32>,
/// Copy of `param.m.data`.
/// NOTE(review): presumably the optimizer's first-moment buffer — confirm against `Parameter`.
pub m: Vec<f32>,
/// Copy of `param.v.data`.
/// NOTE(review): presumably the optimizer's second-moment buffer — confirm against `Parameter`.
pub v: Vec<f32>,
/// Per-parameter step counter, copied from `param.step`.
pub step: u32,
/// Optional gradient values. `snapshot_parameter` always leaves this `None`,
/// and `restore_parameter` does not read it.
pub grad: Option<Vec<f32>>,
}
/// In-memory form of a checkpoint file: a header step, the parameter snapshots,
/// and a configuration string.
#[derive(Debug, Clone)]
pub struct Checkpoint {
/// Checkpoint-level step counter, stored in the file header.
pub step: u32,
/// Parameter snapshots, in the order they are written to and read from the file.
pub params: Vec<ParameterSnapshot>,
/// Configuration text, serialized as length-prefixed UTF-8. This file treats the
/// content as opaque; its format is defined elsewhere.
pub config: String,
}
/// Builds a [`ParameterSnapshot`] named `name` from the current state of `param`.
///
/// The weight, `m` and `v` buffers are cloned, the step counter is copied, and
/// `grad` is always left as `None`.
pub fn snapshot_parameter(param: &Parameter, name: &str) -> ParameterSnapshot {
    let weight = param.data.data.clone();
    let m = param.m.data.clone();
    let v = param.v.data.clone();
    let step = param.step;
    ParameterSnapshot {
        name: name.to_owned(),
        weight,
        m,
        v,
        step,
        grad: None,
    }
}
/// Overwrites `param`'s weight, `m` and `v` buffers and its step counter with the
/// values held in `snap`.
///
/// `snap.grad` is not applied. No check is made that the snapshot's vector lengths
/// match the buffers they replace; the buffers simply take on the snapshot's lengths.
pub fn restore_parameter(param: &mut Parameter, snap: &ParameterSnapshot) {
    // `clone_from` copies into the existing allocations where their capacity allows,
    // rather than allocating three fresh vectors and dropping the old ones.
    param.data.data.clone_from(&snap.weight);
    param.m.data.clone_from(&snap.m);
    param.v.data.clone_from(&snap.v);
    param.step = snap.step;
}
/// Returns the name of every parameter in `ckpt`, in stored order.
///
/// The returned string slices borrow from `ckpt`.
pub fn list_param_names(ckpt: &Checkpoint) -> Vec<&str> {
    let mut names = Vec::with_capacity(ckpt.params.len());
    for snapshot in &ckpt.params {
        names.push(snapshot.name.as_str());
    }
    names
}
pub fn save_checkpoint(path: &Path, ckpt: &Checkpoint) -> Result<()> {
let f = File::create(path).map_err(io_err("create"))?;
let mut w = BufWriter::new(f);
w.write_all(&MAGIC).map_err(io_err("write magic"))?;
w.write_all(&VERSION.to_le_bytes())
.map_err(io_err("write version"))?;
w.write_all(&ckpt.step.to_le_bytes())
.map_err(io_err("write step"))?;
w.write_all(&(ckpt.params.len() as u32).to_le_bytes())
.map_err(io_err("write count"))?;
let cfg = ckpt.config.as_bytes();
w.write_all(&(cfg.len() as u32).to_le_bytes())
.map_err(io_err("write config_len"))?;
w.write_all(cfg).map_err(io_err("write config"))?;
for p in &ckpt.params {
let nb = p.name.as_bytes();
w.write_all(&(nb.len() as u32).to_le_bytes())
.map_err(io_err("write name_len"))?;
w.write_all(nb).map_err(io_err("write name"))?;
write_f32_vec(&mut w, &p.weight)?;
write_f32_vec(&mut w, &p.m)?;
write_f32_vec(&mut w, &p.v)?;
w.write_all(&p.step.to_le_bytes())
.map_err(io_err("write param step"))?;
match &p.grad {
None => w.write_all(&[0u8]).map_err(io_err("write has_grad"))?,
Some(g) => {
w.write_all(&[1u8]).map_err(io_err("write has_grad"))?;
write_f32_vec(&mut w, g)?;
}
}
}
w.flush().map_err(io_err("flush"))?;
Ok(())
}
/// Reads a checkpoint written by `save_checkpoint` from `path`.
///
/// The 16-byte header (magic, version, step, parameter count) is validated first,
/// then the config string and each parameter record are read in order. Any bytes
/// following the last parameter record are left unread.
///
/// # Errors
///
/// Returns an error if the file cannot be opened or ends early, if the magic or
/// version does not match, if the config or a parameter name is not valid UTF-8,
/// or if a gradient flag byte is neither `0` nor `1`.
pub fn load_checkpoint(path: &Path) -> Result<Checkpoint> {
    /// Cap on slots reserved from the header's (untrusted) parameter count.
    const MAX_PREALLOC_PARAMS: usize = 1024;

    /// Reads exactly `len` bytes. The buffer grows as data actually arrives, so a
    /// corrupt length prefix cannot trigger a multi-gigabyte up-front allocation.
    fn read_bytes<R: Read>(r: &mut R, len: u32, op: &'static str) -> Result<Vec<u8>> {
        let mut buf = Vec::new();
        Read::take(&mut *r, u64::from(len))
            .read_to_end(&mut buf)
            .map_err(io_err(op))?;
        if buf.len() != len as usize {
            let eof = std::io::Error::from(std::io::ErrorKind::UnexpectedEof);
            return Err(io_err(op)(eof));
        }
        Ok(buf)
    }

    let f = File::open(path).map_err(io_err("open"))?;
    let mut r = BufReader::new(f);

    let mut hdr = [0u8; 16];
    r.read_exact(&mut hdr).map_err(io_err("read header"))?;
    if hdr[0..4] != MAGIC {
        return Err(Error::backend(format!(
            "checkpoint: bad magic {:?}",
            String::from_utf8_lossy(&hdr[0..4])
        )));
    }
    let version = u32::from_le_bytes([hdr[4], hdr[5], hdr[6], hdr[7]]);
    if version != VERSION {
        return Err(Error::backend(format!(
            "checkpoint: unsupported version {version}"
        )));
    }
    let step = u32::from_le_bytes([hdr[8], hdr[9], hdr[10], hdr[11]]);
    let count = u32::from_le_bytes([hdr[12], hdr[13], hdr[14], hdr[15]]);

    let mut len_buf = [0u8; 4];
    r.read_exact(&mut len_buf)
        .map_err(io_err("read config_len"))?;
    let cfg_bytes = read_bytes(&mut r, u32::from_le_bytes(len_buf), "read config")?;
    let config = String::from_utf8(cfg_bytes)
        .map_err(|e| Error::backend(format!("checkpoint: config utf8: {e}")))?;

    // Reserve at most a bounded number of slots; the vector still grows to `count`
    // if the file really contains that many records.
    let mut params = Vec::with_capacity((count as usize).min(MAX_PREALLOC_PARAMS));
    for _ in 0..count {
        r.read_exact(&mut len_buf)
            .map_err(io_err("read name_len"))?;
        let name_bytes = read_bytes(&mut r, u32::from_le_bytes(len_buf), "read name")?;
        let name = String::from_utf8(name_bytes)
            .map_err(|e| Error::backend(format!("checkpoint: param name utf8: {e}")))?;

        let weight = read_f32_vec(&mut r)?;
        let m = read_f32_vec(&mut r)?;
        let v = read_f32_vec(&mut r)?;

        let mut step_buf = [0u8; 4];
        r.read_exact(&mut step_buf)
            .map_err(io_err("read param step"))?;
        let pstep = u32::from_le_bytes(step_buf);

        let mut has_grad = [0u8; 1];
        r.read_exact(&mut has_grad)
            .map_err(io_err("read has_grad"))?;
        // The writer only ever emits 0 or 1; anything else means the stream is
        // corrupt or misaligned, so stop rather than parse garbage as a vector.
        let grad = match has_grad[0] {
            0 => None,
            1 => Some(read_f32_vec(&mut r)?),
            other => {
                return Err(Error::backend(format!(
                    "checkpoint: invalid has_grad flag {other} for param {name:?}"
                )));
            }
        };

        params.push(ParameterSnapshot {
            name,
            weight,
            m,
            v,
            step: pstep,
            grad,
        });
    }

    Ok(Checkpoint {
        step,
        params,
        config,
    })
}
/// Writes a JSON summary of `ckpt` (sizes and counts only, no tensor data) to `path`.
///
/// The document holds the checkpoint step, parameter count, config string, one
/// object per parameter with element and byte counts, and `total_bytes`.
/// `total_bytes` is the config's byte length plus, for each parameter, its name's
/// byte length and 4 bytes per element of `weight`, `m`, `v` and `grad` (if present).
///
/// # Errors
///
/// Returns an error if the file cannot be created or written.
pub fn save_json_summary(path: &Path, ckpt: &Checkpoint) -> Result<()> {
    let preamble = [
        String::from("{\n"),
        format!(" \"step\": {},\n", ckpt.step),
        format!(" \"param_count\": {},\n", ckpt.params.len()),
        format!(" \"config\": {},\n", json_str(&ckpt.config)),
        String::from(" \"params\": [\n"),
    ];
    let mut json = preamble.concat();

    let mut grand_total: usize = ckpt.config.len();
    let mut remaining = ckpt.params.iter().peekable();
    while let Some(snap) = remaining.next() {
        let weight_bytes = snap.weight.len() * 4;
        let m_bytes = snap.m.len() * 4;
        let v_bytes = snap.v.len() * 4;
        let grad_bytes = match &snap.grad {
            Some(values) => values.len() * 4,
            None => 0,
        };
        grand_total += weight_bytes + m_bytes + v_bytes + grad_bytes + snap.name.len();

        let fields = [
            format!("\"name\": {}, ", json_str(&snap.name)),
            format!("\"step\": {}, ", snap.step),
            format!("\"weight_numel\": {}, ", snap.weight.len()),
            format!("\"weight_bytes\": {}, ", weight_bytes),
            format!("\"m_numel\": {}, ", snap.m.len()),
            format!("\"m_bytes\": {}, ", m_bytes),
            format!("\"v_numel\": {}, ", snap.v.len()),
            format!("\"v_bytes\": {}, ", v_bytes),
            format!("\"has_grad\": {}, ", snap.grad.is_some()),
            format!("\"grad_bytes\": {}", grad_bytes),
        ];
        json.push_str(" {");
        json.push_str(&fields.concat());
        json.push('}');
        // Every entry except the last is followed by a comma.
        if remaining.peek().is_some() {
            json.push(',');
        }
        json.push('\n');
    }

    json.push_str(" ],\n");
    json.push_str(&format!(" \"total_bytes\": {}\n", grand_total));
    json.push_str("}\n");

    let mut out = File::create(path).map_err(io_err("create json"))?;
    out.write_all(json.as_bytes())
        .map_err(io_err("write json"))?;
    Ok(())
}
fn write_f32_vec<W: Write>(w: &mut W, v: &[f32]) -> Result<()> {
w.write_all(&(v.len() as u32).to_le_bytes())
.map_err(io_err("write vec_len"))?;
for &f in v {
w.write_all(&f.to_le_bytes()).map_err(io_err("write f32"))?;
}
Ok(())
}
/// Reads a vector stored as a little-endian `u32` element count followed by that
/// many little-endian `f32` values.
///
/// Exactly `4 + 4 * count` bytes are consumed from `r` on success.
///
/// # Errors
///
/// Returns an error if `r` ends before the count or all elements have been read.
fn read_f32_vec<R: Read>(r: &mut R) -> Result<Vec<f32>> {
    /// Elements decoded per `read_exact` call.
    const CHUNK_ELEMS: usize = 1024;
    /// Cap on the up-front reservation taken from the (untrusted) length prefix.
    const MAX_PREALLOC_ELEMS: usize = 1 << 16;

    let mut len_buf = [0u8; 4];
    r.read_exact(&mut len_buf).map_err(io_err("read vec_len"))?;
    let len = u32::from_le_bytes(len_buf) as usize;

    // A corrupt prefix can claim up to u32::MAX elements (~16 GiB). Reserve only a
    // bounded amount; the vector grows as real data arrives.
    let mut out = Vec::with_capacity(len.min(MAX_PREALLOC_ELEMS));
    let mut chunk = [0u8; CHUNK_ELEMS * 4];
    let mut remaining = len;
    while remaining > 0 {
        let take = remaining.min(CHUNK_ELEMS);
        let bytes = &mut chunk[..take * 4];
        r.read_exact(bytes).map_err(io_err("read f32"))?;
        out.extend(
            bytes
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
        );
        remaining -= take;
    }
    Ok(out)
}
/// Renders `s` as a double-quoted JSON string literal.
///
/// Quote, backslash, newline, carriage return and tab get their two-character
/// escapes; any other character below U+0020 becomes a `\u00XX` escape; everything
/// else is copied through unchanged.
fn json_str(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for ch in s.chars() {
        let short_escape = match ch {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(escape) = short_escape {
            quoted.push_str(escape);
        } else if u32::from(ch) < 0x20 {
            quoted.push_str(&format!("\\u{:04x}", u32::from(ch)));
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
fn io_err(op: &'static str) -> impl FnOnce(std::io::Error) -> Error {
move |e| Error::backend(format!("checkpoint {op}: {e}"))
}