tokitai-operator 0.1.0

Verified DL kernel compiler: formally-checked GEMM, p-adic, sheaf, contract-carrying ops. Paper-artifact grade.
Documentation
//! Checkpoint save/load (gated on `rocm-hip`).
//!
//! Phase 2.7 of the 0.7B MoE training project. The binary
//! format is hand-rolled, little-endian, and intentionally
//! trivial so it can be parsed in a few dozen lines and
//! round-tripped in unit tests without dragging in a
//! serialization dependency.
//!
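//! Binary layout (v1):
//!
//!   Header (16 bytes):
//!     [0..4]    magic: 4 bytes "TKP1"
//!     [4..8]    version: u32 LE
//!     [8..12]   step: u32 LE (checkpoint-level step)
//!     [12..16]  count: u32 LE (number of parameters)
//!   Config blob:
//!     config_len: u32 LE, followed by `config_len` UTF-8 bytes
//!   Per-parameter record (repeated `count` times):
//!     name_len: u32 LE, followed by `name_len` UTF-8 bytes
//!     weight, m, v: each a u32 LE element count followed by that
//!                   many fp32 LE values
//!     step: u32 LE (parameter-local step)
//!     has_grad: u8 (0 or 1); when 1, a grad vector follows, encoded
//!               like weight/m/v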
//!
// Checkpoint save/load (Phase 2.7).
//
// The binary format is hand-rolled, little-endian, and intentionally
// trivial so it can be parsed in a few dozen lines and round-tripped
// in unit tests without dragging in a serialization dependency. The
// header carries a 4-byte magic + version for forward compatibility.
//
// Binary layout (v1):
//   Header (16 bytes):
//     [0..4]    magic: 4 bytes "TKP1"
//     [4..8]    version: u32 LE
//     [8..12]   step: u32 LE  (checkpoint-level step)
//     [12..16]  count: u32 LE  (number of parameters)
//   Config blob:
//     [16..20]  config_len: u32 LE
//     [20..]    config: UTF-8 bytes
//   Per-parameter (repeated `count` times):
//     name_len: u32 LE
//     name: UTF-8 bytes
//     weight_len: u32 LE
//     weight: fp32 LE (weight_len elements)
//     m_len: u32 LE
//     m: fp32 LE (m_len elements)
//     v_len: u32 LE
//     v: fp32 LE (v_len elements)
//     step: u32 LE (parameter-local step)
//     has_grad: u8 (0 or 1)
//     [if has_grad] grad_len: u32 LE
//     [if has_grad] grad: fp32 LE (grad_len elements)

use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;

use crate::model::parameter::Parameter;
use crate::{Error, Result};

/// File magic: the first four bytes of every checkpoint. `save_checkpoint`
/// writes it and `load_checkpoint` rejects files that start with anything else.
const MAGIC: [u8; 4] = *b"TKP1";
/// Format version stored right after the magic as a little-endian `u32`.
/// `load_checkpoint` refuses any file whose version differs from this value.
const VERSION: u32 = 1;

/// Serialized form of a `Parameter`: name + fp32 weight/m/v vectors,
/// per-parameter step counter, and an optional gradient accumulator.
///
/// One value of this type corresponds to one per-parameter record in
/// the checkpoint file; `save_checkpoint` writes the fields in the
/// order name, weight, m, v, step, grad.
#[derive(Debug, Clone)]
pub struct ParameterSnapshot {
    /// Caller-chosen identifier; stored in the file as UTF-8 bytes.
    pub name: String,
    /// Copy of the parameter's `data.data` vector (the weights).
    pub weight: Vec<f32>,
    /// Copy of the parameter's `m.data` vector.
    /// NOTE(review): presumably the optimizer's first-moment buffer — confirm.
    pub m: Vec<f32>,
    /// Copy of the parameter's `v.data` vector.
    /// NOTE(review): presumably the optimizer's second-moment buffer — confirm.
    pub v: Vec<f32>,
    /// Parameter-local step counter (distinct from `Checkpoint::step`).
    pub step: u32,
    /// Optional gradient values. `snapshot_parameter` always leaves this
    /// `None`, and `restore_parameter` does not read it; it is only
    /// round-tripped through the file when a caller fills it in.
    pub grad: Option<Vec<f32>>,
}

/// A serializable checkpoint: a checkpoint-level step counter, an
/// arbitrary config string (e.g. JSON-encoded hyperparameters), and
/// an ordered list of parameter snapshots.
///
/// The config string is stored verbatim; this module never parses it.
#[derive(Debug, Clone)]
pub struct Checkpoint {
    /// Checkpoint-level step counter, stored in the file header.
    pub step: u32,
    /// Parameter snapshots, written to and read from the file in this order.
    pub params: Vec<ParameterSnapshot>,
    /// Opaque configuration text, stored as a length-prefixed UTF-8 blob.
    pub config: String,
}

/// Capture the current state of `param` as a [`ParameterSnapshot`].
///
/// The weight, `m` and `v` vectors are cloned, so the snapshot is
/// independent of later changes to `param`. `name` comes from the
/// caller because a parameter does not know where it sits in the
/// model graph. The snapshot's `grad` is always `None`.
pub fn snapshot_parameter(param: &Parameter, name: &str) -> ParameterSnapshot {
    let weight = param.data.data.clone();
    let m = param.m.data.clone();
    let v = param.v.data.clone();
    ParameterSnapshot {
        name: name.to_owned(),
        weight,
        m,
        v,
        step: param.step,
        // Gradients are not captured by this helper.
        grad: None,
    }
}

/// Restore a `Parameter` from a `ParameterSnapshot`. The destination's
/// data vectors are replaced in their entirety so the destination ends
/// up with the exact length and contents of the snapshot.
pub fn restore_parameter(param: &mut Parameter, snap: &ParameterSnapshot) {
    param.data.data = snap.weight.clone();
    param.m.data = snap.m.clone();
    param.v.data = snap.v.clone();
    param.step = snap.step;
}

/// Collect the names of all parameter snapshots in `ckpt`.
///
/// The returned slices borrow from `ckpt` and appear in the same
/// order as `ckpt.params`.
pub fn list_param_names(ckpt: &Checkpoint) -> Vec<&str> {
    let mut names = Vec::with_capacity(ckpt.params.len());
    for snapshot in &ckpt.params {
        names.push(snapshot.name.as_str());
    }
    names
}

/// Write a binary checkpoint to `path`. See the module-level comment
/// for the on-disk layout.
pub fn save_checkpoint(path: &Path, ckpt: &Checkpoint) -> Result<()> {
    let f = File::create(path).map_err(io_err("create"))?;
    let mut w = BufWriter::new(f);
    w.write_all(&MAGIC).map_err(io_err("write magic"))?;
    w.write_all(&VERSION.to_le_bytes())
        .map_err(io_err("write version"))?;
    w.write_all(&ckpt.step.to_le_bytes())
        .map_err(io_err("write step"))?;
    w.write_all(&(ckpt.params.len() as u32).to_le_bytes())
        .map_err(io_err("write count"))?;
    let cfg = ckpt.config.as_bytes();
    w.write_all(&(cfg.len() as u32).to_le_bytes())
        .map_err(io_err("write config_len"))?;
    w.write_all(cfg).map_err(io_err("write config"))?;
    for p in &ckpt.params {
        let nb = p.name.as_bytes();
        w.write_all(&(nb.len() as u32).to_le_bytes())
            .map_err(io_err("write name_len"))?;
        w.write_all(nb).map_err(io_err("write name"))?;
        write_f32_vec(&mut w, &p.weight)?;
        write_f32_vec(&mut w, &p.m)?;
        write_f32_vec(&mut w, &p.v)?;
        w.write_all(&p.step.to_le_bytes())
            .map_err(io_err("write param step"))?;
        match &p.grad {
            None => w.write_all(&[0u8]).map_err(io_err("write has_grad"))?,
            Some(g) => {
                w.write_all(&[1u8]).map_err(io_err("write has_grad"))?;
                write_f32_vec(&mut w, g)?;
            }
        }
    }
    w.flush().map_err(io_err("flush"))?;
    Ok(())
}

/// Read a binary checkpoint from `path`. The caller is responsible for
/// matching the loaded parameter names to live `Parameter`s and calling
/// [`restore_parameter`] for each one.
pub fn load_checkpoint(path: &Path) -> Result<Checkpoint> {
    let f = File::open(path).map_err(io_err("open"))?;
    let mut r = BufReader::new(f);
    let mut hdr = [0u8; 16];
    r.read_exact(&mut hdr).map_err(io_err("read header"))?;
    if &hdr[0..4] != &MAGIC {
        return Err(Error::backend(format!(
            "checkpoint: bad magic {:?}",
            String::from_utf8_lossy(&hdr[0..4])
        )));
    }
    let version = u32::from_le_bytes([hdr[4], hdr[5], hdr[6], hdr[7]]);
    if version != VERSION {
        return Err(Error::backend(format!(
            "checkpoint: unsupported version {version}"
        )));
    }
    let step = u32::from_le_bytes([hdr[8], hdr[9], hdr[10], hdr[11]]);
    let count = u32::from_le_bytes([hdr[12], hdr[13], hdr[14], hdr[15]]);
    let mut len_buf = [0u8; 4];
    r.read_exact(&mut len_buf)
        .map_err(io_err("read config_len"))?;
    let cfg_len = u32::from_le_bytes(len_buf) as usize;
    let mut cfg_bytes = vec![0u8; cfg_len];
    r.read_exact(&mut cfg_bytes)
        .map_err(io_err("read config"))?;
    let config = String::from_utf8(cfg_bytes)
        .map_err(|e| Error::backend(format!("checkpoint: config utf8: {e}")))?;
    let mut params = Vec::with_capacity(count as usize);
    for _ in 0..count {
        r.read_exact(&mut len_buf)
            .map_err(io_err("read name_len"))?;
        let name_len = u32::from_le_bytes(len_buf) as usize;
        let mut name_bytes = vec![0u8; name_len];
        r.read_exact(&mut name_bytes).map_err(io_err("read name"))?;
        let name = String::from_utf8(name_bytes)
            .map_err(|e| Error::backend(format!("checkpoint: param name utf8: {e}")))?;
        let weight = read_f32_vec(&mut r)?;
        let m = read_f32_vec(&mut r)?;
        let v = read_f32_vec(&mut r)?;
        let mut step_buf = [0u8; 4];
        r.read_exact(&mut step_buf)
            .map_err(io_err("read param step"))?;
        let pstep = u32::from_le_bytes(step_buf);
        let mut has_grad = [0u8; 1];
        r.read_exact(&mut has_grad)
            .map_err(io_err("read has_grad"))?;
        let grad = if has_grad[0] == 0 {
            None
        } else {
            Some(read_f32_vec(&mut r)?)
        };
        params.push(ParameterSnapshot {
            name,
            weight,
            m,
            v,
            step: pstep,
            grad,
        });
    }
    Ok(Checkpoint {
        step,
        params,
        config,
    })
}

/// Write a human-readable JSON summary of `ckpt` to `path`.
///
/// The document lists the checkpoint step, the parameter count, the
/// config string, one object per parameter (name, step, element and
/// byte counts, whether a gradient is present) and a `total_bytes`
/// figure that sums the config length, every name length and every
/// vector's payload at 4 bytes per element. The JSON is hand-rolled
/// (no `serde_json` at runtime) and is meant for inspection / diffing,
/// not for round-tripping — use [`save_checkpoint`] /
/// [`load_checkpoint`] for that.
pub fn save_json_summary(path: &Path, ckpt: &Checkpoint) -> Result<()> {
    let n_params = ckpt.params.len();
    let mut total_bytes: usize = ckpt.config.len();

    // One line per parameter; every line but the last ends with a comma.
    let mut rows = String::new();
    for (idx, snap) in ckpt.params.iter().enumerate() {
        let weight_bytes = snap.weight.len() * 4;
        let m_bytes = snap.m.len() * 4;
        let v_bytes = snap.v.len() * 4;
        let grad_bytes = match &snap.grad {
            Some(g) => g.len() * 4,
            None => 0,
        };
        total_bytes += weight_bytes + m_bytes + v_bytes + grad_bytes + snap.name.len();
        let separator = if idx + 1 < n_params { "," } else { "" };
        rows.push_str(&format!(
            "    {{\"name\": {}, \"step\": {}, \"weight_numel\": {}, \"weight_bytes\": {}, \"m_numel\": {}, \"m_bytes\": {}, \"v_numel\": {}, \"v_bytes\": {}, \"has_grad\": {}, \"grad_bytes\": {}}}{}\n",
            json_str(&snap.name),
            snap.step,
            snap.weight.len(),
            weight_bytes,
            snap.m.len(),
            m_bytes,
            snap.v.len(),
            v_bytes,
            snap.grad.is_some(),
            grad_bytes,
            separator
        ));
    }

    let document = format!(
        "{{\n  \"step\": {},\n  \"param_count\": {},\n  \"config\": {},\n  \"params\": [\n{}  ],\n  \"total_bytes\": {}\n}}\n",
        ckpt.step,
        n_params,
        json_str(&ckpt.config),
        rows,
        total_bytes
    );

    let mut out = File::create(path).map_err(io_err("create json"))?;
    out.write_all(document.as_bytes())
        .map_err(io_err("write json"))?;
    Ok(())
}

fn write_f32_vec<W: Write>(w: &mut W, v: &[f32]) -> Result<()> {
    w.write_all(&(v.len() as u32).to_le_bytes())
        .map_err(io_err("write vec_len"))?;
    for &f in v {
        w.write_all(&f.to_le_bytes()).map_err(io_err("write f32"))?;
    }
    Ok(())
}

/// Read a length-prefixed vector: a little-endian `u32` element count
/// followed by that many 4-byte little-endian `f32` values.
///
/// The element count comes from the input and is not trusted for
/// sizing: at most `MAX_PREALLOC_ELEMS` slots are reserved up front and
/// the vector grows as elements are actually decoded. A corrupt count
/// therefore ends in a read error once the input runs out, not in a
/// multi-gigabyte allocation attempt.
///
/// # Errors
///
/// Returns a backend error if the input ends before the count or any
/// announced element has been read, or if the reader fails.
fn read_f32_vec<R: Read>(r: &mut R) -> Result<Vec<f32>> {
    // 1 Mi elements = 4 MiB reserved at most before any data is seen.
    const MAX_PREALLOC_ELEMS: usize = 1 << 20;

    let mut len_buf = [0u8; 4];
    r.read_exact(&mut len_buf).map_err(io_err("read vec_len"))?;
    // u32 -> usize is lossless on 32- and 64-bit targets.
    let len = u32::from_le_bytes(len_buf) as usize;

    let mut out = Vec::with_capacity(len.min(MAX_PREALLOC_ELEMS));
    let mut f_buf = [0u8; 4];
    for _ in 0..len {
        r.read_exact(&mut f_buf).map_err(io_err("read f32"))?;
        out.push(f32::from_le_bytes(f_buf));
    }
    Ok(out)
}

/// Render `s` as a JSON string literal, including the surrounding
/// double quotes.
///
/// `"` and `\` are backslash-escaped; newline, carriage return and tab
/// use their short escapes; every other character below U+0020 is
/// written as `\u00XX` (lowercase hex). All remaining characters are
/// copied through unchanged.
fn json_str(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    s.chars().for_each(|ch| {
        let short_escape = match ch {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(esc) = short_escape {
            quoted.push_str(esc);
        } else if u32::from(ch) < 0x20 {
            // Remaining control characters have no short form.
            quoted.push_str(&format!("\\u{:04x}", u32::from(ch)));
        } else {
            quoted.push(ch);
        }
    });
    quoted.push('"');
    quoted
}

fn io_err(op: &'static str) -> impl FnOnce(std::io::Error) -> Error {
    move |e| Error::backend(format!("checkpoint {op}: {e}"))
}