Skip to main content

oxibonsai_model/
checkpoint.rs

1//! Model checkpoint format for saving and restoring training state.
2//!
3//! # Binary Format (version 1)
4//!
5//! ## Header
6//! ```text
7//! magic:        b"OXCK"   (4 bytes)
8//! version:      u32 LE    (= 1)
9//! flags:        u64 LE    (reserved, must be 0 on write; ignored on read)
10//! num_tensors:  u64 LE
11//! metadata_len: u32 LE
12//! metadata:     UTF-8 JSON string (metadata_len bytes)
13//! ```
14//!
15//! ## Per tensor
16//! ```text
17//! name_len:  u32 LE
18//! name:      UTF-8 (name_len bytes)
19//! ndim:      u32 LE
20//! shape:     [u64 LE; ndim]
21//! data_len:  u64 LE  (number of f32 elements)
22//! data:      [f32 LE; data_len]
23//! ```
24//!
//! Metadata is serialised as a simple `{"key":"val",...}` JSON object
//! without nesting; any `"` or `\` inside a key or value is written with a
//! leading `\` escape and decoded again on read.
27
28use std::collections::HashMap;
29use std::fs::File;
30use std::io::{BufReader, BufWriter, Read, Write};
31use std::path::Path;
32
33// ─────────────────────────────────────────────────────────────────────────────
34// Public types
35// ─────────────────────────────────────────────────────────────────────────────
36
/// Checkpoint metadata key-value pairs.
///
/// Both keys and values are plain strings; the checkpoint format stores them
/// as a flat JSON object of string pairs.
pub type CheckpointMetadata = HashMap<String, String>;
39
/// A serialized model checkpoint containing metadata and named tensors.
///
/// Build one in memory with [`Checkpoint::new`] and the mutator methods, or
/// decode one from a byte stream with [`Checkpoint::read_from`].
#[derive(Debug)]
pub struct Checkpoint {
    /// Format version. [`Checkpoint::new`] sets this to 1 and
    /// [`Checkpoint::read_from`] rejects any other value.
    pub version: u32,
    /// Arbitrary key-value metadata (e.g. step, loss, lr).
    pub metadata: CheckpointMetadata,
    /// Tensor entries, kept in insertion order; they are written and read
    /// back in this order.
    pub tensors: Vec<CheckpointTensor>,
}
50
/// A single tensor entry in the checkpoint.
///
/// NOTE: the relationship between `shape` and `data` described below is a
/// convention; neither [`CheckpointTensor::new`] nor the checkpoint reader
/// verifies it.
#[derive(Debug, Clone)]
pub struct CheckpointTensor {
    /// Tensor name within the checkpoint (e.g. `"layer.0.weight"`); lookups
    /// by name return the first entry that matches.
    pub name: String,
    /// N-dimensional shape; product must equal `data.len()`.
    pub shape: Vec<u64>,
    /// Raw `f32` data in row-major order.
    pub data: Vec<f32>,
}
61
62// ─────────────────────────────────────────────────────────────────────────────
63// CheckpointTensor
64// ─────────────────────────────────────────────────────────────────────────────
65
66impl CheckpointTensor {
67    /// Construct a checkpoint tensor.
68    pub fn new(name: impl Into<String>, data: Vec<f32>, shape: Vec<u64>) -> Self {
69        Self {
70            name: name.into(),
71            shape,
72            data,
73        }
74    }
75
76    /// Total number of scalar elements: product of all shape dimensions.
77    pub fn element_count(&self) -> u64 {
78        if self.shape.is_empty() {
79            return 0;
80        }
81        self.shape.iter().product()
82    }
83
84    /// Size of the tensor data in bytes (`element_count * 4`).
85    pub fn size_bytes(&self) -> usize {
86        self.element_count() as usize * 4
87    }
88
89    /// Convert from a [`crate::model_merge::WeightTensor`].
90    ///
91    /// The `usize` shape dimensions are widened to `u64`.
92    pub fn from_weight_tensor(wt: &crate::model_merge::WeightTensor) -> Self {
93        Self {
94            name: wt.name.clone(),
95            shape: wt.shape.iter().map(|&d| d as u64).collect(),
96            data: wt.data.clone(),
97        }
98    }
99
100    /// Convert back to a [`crate::model_merge::WeightTensor`].
101    ///
102    /// The `u64` shape dimensions are narrowed to `usize`; values that do not
103    /// fit in `usize` are clamped to `usize::MAX` (a safeguard — real models
104    /// never have dimensions that large).
105    pub fn to_weight_tensor(&self) -> crate::model_merge::WeightTensor {
106        let shape: Vec<usize> = self
107            .shape
108            .iter()
109            .map(|&d| usize::try_from(d).unwrap_or(usize::MAX))
110            .collect();
111        crate::model_merge::WeightTensor::new(self.name.clone(), self.data.clone(), shape)
112    }
113}
114
115// ─────────────────────────────────────────────────────────────────────────────
116// Checkpoint
117// ─────────────────────────────────────────────────────────────────────────────
118
119impl Checkpoint {
120    /// Create an empty checkpoint (version 1, no metadata, no tensors).
121    pub fn new() -> Self {
122        Self {
123            version: 1,
124            metadata: CheckpointMetadata::new(),
125            tensors: Vec::new(),
126        }
127    }
128
129    /// Append a tensor to the checkpoint.
130    pub fn add_tensor(&mut self, tensor: CheckpointTensor) {
131        self.tensors.push(tensor);
132    }
133
134    /// Insert or replace a metadata key-value pair.
135    pub fn set_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
136        self.metadata.insert(key.into(), value.into());
137    }
138
139    /// Look up a metadata value by key.
140    pub fn get_metadata(&self, key: &str) -> Option<&str> {
141        self.metadata.get(key).map(|s| s.as_str())
142    }
143
144    /// Find a tensor by name (linear scan; checkpoints are small).
145    pub fn get_tensor(&self, name: &str) -> Option<&CheckpointTensor> {
146        self.tensors.iter().find(|t| t.name == name)
147    }
148
149    /// Total bytes occupied by all tensor data (`sum of size_bytes()`).
150    pub fn total_bytes(&self) -> usize {
151        self.tensors.iter().map(|t| t.size_bytes()).sum()
152    }
153
154    /// Total number of `f32` parameters across all tensors.
155    pub fn num_params(&self) -> u64 {
156        self.tensors.iter().map(|t| t.element_count()).sum()
157    }
158
159    // ── file I/O ──────────────────────────────────────────────────────────────
160
161    /// Save the checkpoint to `path`, creating or truncating the file.
162    pub fn save(&self, path: &Path) -> Result<(), CheckpointError> {
163        let file = File::create(path)?;
164        let mut writer = BufWriter::new(file);
165        self.write_to(&mut writer)
166    }
167
168    /// Load a checkpoint from `path`.
169    pub fn load(path: &Path) -> Result<Self, CheckpointError> {
170        let file = File::open(path)?;
171        let mut reader = BufReader::new(file);
172        Self::read_from(&mut reader)
173    }
174
175    // ── streaming I/O ─────────────────────────────────────────────────────────
176
177    /// Serialise the checkpoint into `writer`.
178    ///
179    /// The writer is NOT flushed; callers that need it (e.g. `BufWriter`) must
180    /// flush themselves, or use [`save`](Self::save) which wraps a `BufWriter`
181    /// and flushes on drop.
182    pub fn write_to<W: Write>(&self, writer: &mut W) -> Result<(), CheckpointError> {
183        // ── header ──
184        writer.write_all(b"OXCK")?;
185        write_u32_le(writer, 1u32)?; // version
186        write_u64_le(writer, 0u64)?; // flags (reserved)
187        write_u64_le(writer, self.tensors.len() as u64)?;
188
189        // metadata
190        let meta_str = serialize_metadata(&self.metadata);
191        let meta_bytes = meta_str.as_bytes();
192        write_u32_le(writer, meta_bytes.len() as u32)?;
193        writer.write_all(meta_bytes)?;
194
195        // ── tensors ──
196        for tensor in &self.tensors {
197            let name_bytes = tensor.name.as_bytes();
198            if name_bytes.len() > 65535 {
199                return Err(CheckpointError::NameTooLong(name_bytes.len()));
200            }
201            write_u32_le(writer, name_bytes.len() as u32)?;
202            writer.write_all(name_bytes)?;
203
204            write_u32_le(writer, tensor.shape.len() as u32)?;
205            for &dim in &tensor.shape {
206                write_u64_le(writer, dim)?;
207            }
208
209            write_u64_le(writer, tensor.data.len() as u64)?;
210            for &f in &tensor.data {
211                writer.write_all(&f.to_le_bytes())?;
212            }
213        }
214
215        Ok(())
216    }
217
218    /// Deserialise a checkpoint from `reader`.
219    pub fn read_from<R: Read>(reader: &mut R) -> Result<Self, CheckpointError> {
220        // ── magic ──
221        let mut magic = [0u8; 4];
222        read_exact(reader, &mut magic)?;
223        if &magic != b"OXCK" {
224            return Err(CheckpointError::InvalidMagic(magic.to_vec()));
225        }
226
227        // ── version ──
228        let version = read_u32_le(reader)?;
229        if version == 0 || version > 1 {
230            return Err(CheckpointError::UnsupportedVersion(version));
231        }
232
233        // ── flags (reserved) ──
234        let _flags = read_u64_le(reader)?;
235
236        // ── tensor count ──
237        let num_tensors = read_u64_le(reader)? as usize;
238
239        // ── metadata ──
240        let meta_len = read_u32_le(reader)? as usize;
241        let mut meta_bytes = vec![0u8; meta_len];
242        read_exact(reader, &mut meta_bytes)?;
243        let meta_str = std::str::from_utf8(&meta_bytes)
244            .map_err(|e| CheckpointError::MetadataParse(e.to_string()))?;
245        let metadata = deserialize_metadata(meta_str)?;
246
247        // ── tensors ──
248        let mut tensors = Vec::with_capacity(num_tensors);
249        for _ in 0..num_tensors {
250            // name
251            let name_len = read_u32_le(reader)? as usize;
252            let mut name_bytes = vec![0u8; name_len];
253            read_exact(reader, &mut name_bytes)?;
254            let name = String::from_utf8(name_bytes)
255                .map_err(|e| CheckpointError::MetadataParse(e.to_string()))?;
256
257            // shape
258            let ndim = read_u32_le(reader)? as usize;
259            let mut shape = Vec::with_capacity(ndim);
260            for _ in 0..ndim {
261                shape.push(read_u64_le(reader)?);
262            }
263
264            // data
265            let data_len = read_u64_le(reader)? as usize;
266            let mut data = Vec::with_capacity(data_len);
267            for _ in 0..data_len {
268                let mut buf = [0u8; 4];
269                read_exact(reader, &mut buf)?;
270                data.push(f32::from_le_bytes(buf));
271            }
272
273            tensors.push(CheckpointTensor { name, shape, data });
274        }
275
276        Ok(Self {
277            version,
278            metadata,
279            tensors,
280        })
281    }
282}
283
284impl Default for Checkpoint {
285    fn default() -> Self {
286        Self::new()
287    }
288}
289
290// ─────────────────────────────────────────────────────────────────────────────
291// Metadata serialization (no serde)
292// ─────────────────────────────────────────────────────────────────────────────
293
294/// Serialize `metadata` as `{"key1":"val1","key2":"val2"}`.
295///
296/// Keys and values must not contain `"` or `\`; if they do, those characters
297/// are escaped with `\` so the round-trip is still correct for typical
298/// training metadata (step numbers, loss strings, etc.).
299fn serialize_metadata(meta: &CheckpointMetadata) -> String {
300    // Deterministic order for reproducibility.
301    let mut pairs: Vec<(&String, &String)> = meta.iter().collect();
302    pairs.sort_by_key(|(k, _)| k.as_str());
303
304    let mut out = String::from('{');
305    for (i, (k, v)) in pairs.iter().enumerate() {
306        if i > 0 {
307            out.push(',');
308        }
309        out.push('"');
310        push_escaped(&mut out, k);
311        out.push_str("\":\"");
312        push_escaped(&mut out, v);
313        out.push('"');
314    }
315    out.push('}');
316    out
317}
318
/// Append `s` to `out`, prefixing every `"` and `\` with a backslash.
///
/// All other characters are copied through unchanged.
fn push_escaped(out: &mut String, s: &str) {
    for ch in s.chars() {
        if ch == '"' || ch == '\\' {
            out.push('\\');
        }
        out.push(ch);
    }
}
329
330/// Deserialize a simple `{"key":"val",...}` JSON object.
331///
332/// This is a purposely minimal state machine — it does not handle nested
333/// objects or arrays.  Its sole purpose is to decode the metadata written by
334/// [`serialize_metadata`].
335fn deserialize_metadata(s: &str) -> Result<CheckpointMetadata, CheckpointError> {
336    let s = s.trim();
337    if s.is_empty() {
338        return Ok(CheckpointMetadata::new());
339    }
340
341    // Allow both `{}` and plain empty strings as "no metadata".
342    if s == "{}" {
343        return Ok(CheckpointMetadata::new());
344    }
345
346    let bytes = s.as_bytes();
347    if bytes.first() != Some(&b'{') || bytes.last() != Some(&b'}') {
348        return Err(CheckpointError::MetadataParse(format!(
349            "expected JSON object, got: {s}"
350        )));
351    }
352
353    // Strip outer braces.
354    let inner = &s[1..s.len() - 1];
355    let mut map = CheckpointMetadata::new();
356
357    if inner.trim().is_empty() {
358        return Ok(map);
359    }
360
361    // Parse "key":"value" pairs separated by commas.
362    // We use a simple char-by-char scanner that handles `\"` escapes.
363    let chars: Vec<char> = inner.chars().collect();
364    let mut pos = 0usize;
365
366    loop {
367        // Skip optional whitespace and commas between pairs.
368        while pos < chars.len() && (chars[pos] == ',' || chars[pos].is_whitespace()) {
369            pos += 1;
370        }
371        if pos >= chars.len() {
372            break;
373        }
374
375        // Expect opening `"` of key.
376        if chars[pos] != '"' {
377            return Err(CheckpointError::MetadataParse(format!(
378                "expected '\"' at position {pos}, got '{}'",
379                chars[pos]
380            )));
381        }
382        pos += 1;
383
384        let (key, new_pos) = parse_json_string(&chars, pos)?;
385        pos = new_pos;
386
387        // Expect `:`
388        skip_ws(&chars, &mut pos);
389        if pos >= chars.len() || chars[pos] != ':' {
390            return Err(CheckpointError::MetadataParse(format!(
391                "expected ':' after key '{key}'"
392            )));
393        }
394        pos += 1;
395        skip_ws(&chars, &mut pos);
396
397        // Expect opening `"` of value.
398        if pos >= chars.len() || chars[pos] != '"' {
399            return Err(CheckpointError::MetadataParse(format!(
400                "expected '\"' for value of key '{key}'"
401            )));
402        }
403        pos += 1;
404
405        let (value, new_pos) = parse_json_string(&chars, pos)?;
406        pos = new_pos;
407
408        map.insert(key, value);
409    }
410
411    Ok(map)
412}
413
414/// Parse a JSON string body starting at `pos` (after the opening `"`).
415///
416/// Returns `(string, position_after_closing_quote)`.
417fn parse_json_string(chars: &[char], mut pos: usize) -> Result<(String, usize), CheckpointError> {
418    let mut s = String::new();
419    while pos < chars.len() {
420        match chars[pos] {
421            '"' => {
422                pos += 1; // consume closing quote
423                return Ok((s, pos));
424            }
425            '\\' => {
426                pos += 1;
427                if pos >= chars.len() {
428                    return Err(CheckpointError::MetadataParse(
429                        "unexpected end after backslash".into(),
430                    ));
431                }
432                match chars[pos] {
433                    '"' => s.push('"'),
434                    '\\' => s.push('\\'),
435                    'n' => s.push('\n'),
436                    'r' => s.push('\r'),
437                    't' => s.push('\t'),
438                    other => {
439                        return Err(CheckpointError::MetadataParse(format!(
440                            "unknown escape '\\{other}'"
441                        )))
442                    }
443                }
444                pos += 1;
445            }
446            ch => {
447                s.push(ch);
448                pos += 1;
449            }
450        }
451    }
452    Err(CheckpointError::MetadataParse("unterminated string".into()))
453}
454
/// Move `pos` forward to the next non-whitespace character in `chars`.
///
/// Whitespace is decided by `char::is_whitespace`, so Unicode white space
/// counts as well as ASCII. `pos` is left untouched when it is already at or
/// beyond the end of `chars`.
fn skip_ws(chars: &[char], pos: &mut usize) {
    while chars.get(*pos).map_or(false, |c| c.is_whitespace()) {
        *pos += 1;
    }
}
461
462// ─────────────────────────────────────────────────────────────────────────────
463// Low-level I/O helpers
464// ─────────────────────────────────────────────────────────────────────────────
465
466fn write_u32_le<W: Write>(w: &mut W, v: u32) -> Result<(), CheckpointError> {
467    w.write_all(&v.to_le_bytes())?;
468    Ok(())
469}
470
471fn write_u64_le<W: Write>(w: &mut W, v: u64) -> Result<(), CheckpointError> {
472    w.write_all(&v.to_le_bytes())?;
473    Ok(())
474}
475
476fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<(), CheckpointError> {
477    let expected = buf.len();
478    let mut total_read = 0usize;
479    while total_read < expected {
480        match r.read(&mut buf[total_read..]) {
481            Ok(0) => {
482                return Err(CheckpointError::TruncatedData {
483                    expected,
484                    got: total_read,
485                })
486            }
487            Ok(n) => total_read += n,
488            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
489            Err(e) => return Err(CheckpointError::Io(e)),
490        }
491    }
492    Ok(())
493}
494
495fn read_u32_le<R: Read>(r: &mut R) -> Result<u32, CheckpointError> {
496    let mut buf = [0u8; 4];
497    read_exact(r, &mut buf)?;
498    Ok(u32::from_le_bytes(buf))
499}
500
501fn read_u64_le<R: Read>(r: &mut R) -> Result<u64, CheckpointError> {
502    let mut buf = [0u8; 8];
503    read_exact(r, &mut buf)?;
504    Ok(u64::from_le_bytes(buf))
505}
506
507// ─────────────────────────────────────────────────────────────────────────────
508// Error type
509// ─────────────────────────────────────────────────────────────────────────────
510
/// Errors that can occur during checkpoint I/O.
///
/// `std::io::Error` converts into [`CheckpointError::Io`] automatically, so
/// `?` can be used directly on I/O results.
#[derive(Debug, thiserror::Error)]
pub enum CheckpointError {
    /// Wraps any [`std::io::Error`] from the underlying reader/writer.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),

    /// The file does not begin with the expected `b"OXCK"` magic bytes.
    /// Carries the four bytes that were read instead.
    #[error("invalid magic bytes: expected OXCK, got {0:?}")]
    InvalidMagic(Vec<u8>),

    /// The checkpoint was written with a version this library cannot read.
    /// Carries the version number found in the header.
    #[error("unsupported checkpoint version: {0}")]
    UnsupportedVersion(u32),

    /// The metadata block could not be parsed as a key-value JSON object.
    /// The reader also uses this variant when the metadata block or a tensor
    /// name is not valid UTF-8.
    #[error("metadata parse error: {0}")]
    MetadataParse(String),

    /// The byte stream ended before the expected number of bytes were read.
    /// `expected` is the size of the read that failed; `got` is how many of
    /// those bytes arrived before the stream ended.
    #[error("truncated data: expected {expected} bytes, got {got}")]
    TruncatedData { expected: usize, got: usize },

    /// A tensor name exceeds 65 535 bytes. This is a limit enforced by the
    /// writer; the on-disk `name_len` field itself is a `u32`. Carries the
    /// offending length in bytes.
    #[error("tensor name too long: {0} bytes (max 65535)")]
    NameTooLong(usize),
}
537}