Skip to main content

flodl/nn/
checkpoint.rs

1use std::io::{Read, Write};
2
3use crate::tensor::{Device, DType, Result, Tensor, TensorError};
4
5use super::buffer::Buffer;
6use super::parameter::Parameter;
7
/// Magic bytes for `.fdl` checkpoint files.
///
/// Written as the first four bytes of every checkpoint; the readers in this
/// module reject any stream that does not start with them.
pub(crate) const MAGIC: [u8; 4] = *b"FDLC";
/// Current checkpoint format version (the value written by the savers).
/// v1 = flodl 0.1.x naming, v2 = flodl 0.2.0+ naming (identical binary layout).
pub(crate) const VERSION: u32 = 2;
/// Maximum checkpoint version we can read.
///
/// Readers accept versions `1..=MAX_VERSION`; version 0 is rejected.
const MAX_VERSION: u32 = 2;
/// Size in bytes of the structural hash field in the checkpoint header.
///
/// Callers pass the hash as a hex string of `2 * HASH_LEN` characters.
pub(crate) const HASH_LEN: usize = 32;
17
/// Report from a checkpoint load: what was loaded, skipped, or missing.
///
/// Returned by `load_checkpoint`; every name is the qualified entry name
/// used for matching.
#[derive(Debug, Clone)]
pub struct LoadReport {
    /// Entries matched by name and loaded successfully.
    pub loaded: Vec<String>,
    /// Checkpoint entries with no matching model parameter or buffer (ignored).
    pub skipped: Vec<String>,
    /// Model parameters/buffers with no matching checkpoint entry (kept at init values).
    pub missing: Vec<String>,
}
28
29/// Save parameters and buffers to a binary checkpoint.
30///
31/// Both params and buffers are stored as named tensors in the same flat list.
32/// The format is: `MAGIC(4) | VERSION(u32=1) | hash(32 bytes) | num_entries(u32) | entries...`
33///
34/// Pass `structural_hash` from `Graph::structural_hash()` to embed architecture
35/// identity. Pass `None` to write 32 zero bytes (hash validation skipped on load).
36pub fn save_checkpoint<W: Write>(
37    w: &mut W,
38    params: &[(String, Parameter)],
39    buffers: &[(String, Buffer)],
40    structural_hash: Option<&str>,
41) -> Result<()> {
42    w.write_all(&MAGIC).map_err(io_err)?;
43    w.write_all(&VERSION.to_le_bytes()).map_err(io_err)?;
44
45    // Write 32-byte hash (or zeros)
46    let hash_bytes = match structural_hash {
47        Some(hex) => hex_to_bytes(hex)?,
48        None => [0u8; HASH_LEN],
49    };
50    w.write_all(&hash_bytes).map_err(io_err)?;
51
52    let total = (params.len() + buffers.len()) as u32;
53    w.write_all(&total.to_le_bytes()).map_err(io_err)?;
54
55    for (name, p) in params {
56        let name_bytes = name.as_bytes();
57        w.write_all(&(name_bytes.len() as u32).to_le_bytes()).map_err(io_err)?;
58        w.write_all(name_bytes).map_err(io_err)?;
59        write_tensor_data(w, &p.variable.data())?;
60    }
61
62    for (name, b) in buffers {
63        let name_bytes = name.as_bytes();
64        w.write_all(&(name_bytes.len() as u32).to_le_bytes()).map_err(io_err)?;
65        w.write_all(name_bytes).map_err(io_err)?;
66        write_tensor_data(w, &b.get())?;
67    }
68
69    Ok(())
70}
71
72/// Load a checkpoint, matching entries by qualified name against both
73/// parameters and buffers.
74///
75/// Returns a `LoadReport` describing what was matched, skipped, and missing.
76/// Shape mismatches on a matched name are errors (not silent skips).
77///
78/// Pass `structural_hash` from `Graph::structural_hash()` to validate that the
79/// checkpoint was saved from the same architecture. Pass `None` to skip validation.
80/// If both the file hash and expected hash are non-zero and they differ, returns an error.
81pub fn load_checkpoint<R: Read>(
82    r: &mut R,
83    params: &[(String, Parameter)],
84    buffers: &[(String, Buffer)],
85    structural_hash: Option<&str>,
86) -> Result<LoadReport> {
87    let mut magic = [0u8; 4];
88    r.read_exact(&mut magic).map_err(io_err)?;
89    if magic != MAGIC {
90        return Err(TensorError::new(
91            "invalid checkpoint: bad magic (expected .fdl checkpoint)"
92        ));
93    }
94
95    let version = read_u32(r)?;
96    if version == 0 || version > MAX_VERSION {
97        return Err(TensorError::new(&format!(
98            "unsupported checkpoint version {} (this build supports 1..={})",
99            version, MAX_VERSION,
100        )));
101    }
102
103    // Read and validate structural hash
104    let mut file_hash = [0u8; HASH_LEN];
105    r.read_exact(&mut file_hash).map_err(io_err)?;
106
107    let file_nonzero = file_hash.iter().any(|&b| b != 0);
108    if let Some(expected_hex) = structural_hash {
109        let expected = hex_to_bytes(expected_hex)?;
110        let expected_nonzero = expected.iter().any(|&b| b != 0);
111        if file_nonzero && expected_nonzero && file_hash != expected {
112            return Err(TensorError::new(&format!(
113                "checkpoint architecture mismatch: file={} model={}",
114                bytes_to_hex(&file_hash),
115                expected_hex,
116            )));
117        }
118    }
119
120    let count = read_u32(r)? as usize;
121
122    // Read all checkpoint entries into a map
123    let mut ckpt: std::collections::HashMap<String, (Vec<i64>, DType, Vec<u8>)> =
124        std::collections::HashMap::with_capacity(count);
125
126    for _ in 0..count {
127        let name_len = read_u32(r)? as usize;
128        let mut name_bytes = vec![0u8; name_len];
129        r.read_exact(&mut name_bytes).map_err(io_err)?;
130        let name = String::from_utf8_lossy(&name_bytes).into_owned();
131
132        let ndim = read_u32(r)? as usize;
133        let mut shape = vec![0i64; ndim];
134        for s in &mut shape { *s = read_i64(r)?; }
135        let mut tag = [0u8; 1];
136        r.read_exact(&mut tag).map_err(io_err)?;
137        let dtype = dtype_from_tag(tag[0])?;
138        let byte_count = read_u64(r)? as usize;
139        let mut raw = vec![0u8; byte_count];
140        r.read_exact(&mut raw).map_err(io_err)?;
141        ckpt.insert(name, (shape, dtype, raw));
142    }
143
144    let mut loaded = Vec::new();
145    let mut missing = Vec::new();
146
147    // Match parameters
148    for (name, p) in params {
149        if let Some((shape, dtype, raw)) = ckpt.remove(name) {
150            let model_shape = p.variable.shape();
151            if shape != model_shape {
152                return Err(TensorError::new(&format!(
153                    "parameter {:?}: shape mismatch: checkpoint={:?} model={:?}",
154                    name, shape, model_shape
155                )));
156            }
157            let t = tensor_from_raw_bytes(&raw, &shape, dtype)?;
158            let model_dtype = p.variable.data().dtype();
159            let t = if t.dtype() != model_dtype { t.to_dtype(model_dtype)? } else { t };
160            let dev = p.variable.data().device();
161            if dev != Device::CPU {
162                p.variable.set_data(t.to_device(dev)?);
163            } else {
164                p.variable.set_data(t);
165            }
166            loaded.push(name.clone());
167        } else {
168            missing.push(name.clone());
169        }
170    }
171
172    // Match buffers
173    for (name, b) in buffers {
174        if let Some((shape, dtype, raw)) = ckpt.remove(name) {
175            let model_shape = b.shape();
176            if shape != model_shape {
177                return Err(TensorError::new(&format!(
178                    "buffer {:?}: shape mismatch: checkpoint={:?} model={:?}",
179                    name, shape, model_shape
180                )));
181            }
182            let t = tensor_from_raw_bytes(&raw, &shape, dtype)?;
183            let model_dtype = b.get().dtype();
184            let t = if t.dtype() != model_dtype { t.to_dtype(model_dtype)? } else { t };
185            let dev = b.device();
186            if dev != Device::CPU {
187                b.set(t.to_device(dev)?);
188            } else {
189                b.set(t);
190            }
191            loaded.push(name.clone());
192        } else {
193            missing.push(name.clone());
194        }
195    }
196
197    let skipped: Vec<String> = ckpt.into_keys().collect();
198
199    Ok(LoadReport { loaded, skipped, missing })
200}
201
202/// Save checkpoint to a file path. Uses gzip compression if path ends with `.gz`.
203pub fn save_checkpoint_file(
204    path: &str,
205    params: &[(String, Parameter)],
206    buffers: &[(String, Buffer)],
207    structural_hash: Option<&str>,
208) -> Result<()> {
209    let f = std::fs::File::create(path).map_err(io_err)?;
210    if path.ends_with(".gz") {
211        let mut w = flate2::write::GzEncoder::new(f, flate2::Compression::default());
212        save_checkpoint(&mut w, params, buffers, structural_hash)?;
213        w.finish().map_err(io_err)?;
214        Ok(())
215    } else {
216        let mut w = std::io::BufWriter::new(f);
217        save_checkpoint(&mut w, params, buffers, structural_hash)
218    }
219}
220
221/// Load checkpoint from a file path. Detects gzip from `.gz` extension.
222pub fn load_checkpoint_file(
223    path: &str,
224    params: &[(String, Parameter)],
225    buffers: &[(String, Buffer)],
226    structural_hash: Option<&str>,
227) -> Result<LoadReport> {
228    let f = std::fs::File::open(path).map_err(io_err)?;
229    if path.ends_with(".gz") {
230        let mut r = flate2::read::GzDecoder::new(f);
231        load_checkpoint(&mut r, params, buffers, structural_hash)
232    } else {
233        let mut r = std::io::BufReader::new(f);
234        load_checkpoint(&mut r, params, buffers, structural_hash)
235    }
236}
237
238/// Peek at the version number of a checkpoint file without reading the full contents.
239///
240/// Read just the parameter and buffer names from a `.fdl` checkpoint
241/// without loading any tensor data.
242///
243/// Useful when a caller needs to introspect the checkpoint's shape — for
244/// example, to detect optional sub-modules (a pooler, a task head, …)
245/// before constructing the matching graph. Reading is bounded by the
246/// header's entry count, so a malformed file errors at parse rather
247/// than allocating unbounded memory.
248///
249/// Detects gzip from a `.gz` extension. The structural-hash field is
250/// read but not validated — pair this with `load_checkpoint_file` once
251/// the matching graph is built if you need hash validation.
252pub fn checkpoint_keys(path: &str) -> Result<Vec<String>> {
253    let f = std::fs::File::open(path).map_err(io_err)?;
254    let mut r: Box<dyn Read> = if path.ends_with(".gz") {
255        Box::new(flate2::read::GzDecoder::new(f))
256    } else {
257        Box::new(std::io::BufReader::new(f))
258    };
259
260    let mut magic = [0u8; 4];
261    r.read_exact(&mut magic).map_err(io_err)?;
262    if magic != MAGIC {
263        return Err(TensorError::new(
264            "invalid checkpoint: bad magic (expected .fdl checkpoint)",
265        ));
266    }
267    let version = read_u32(&mut r)?;
268    if version == 0 || version > MAX_VERSION {
269        return Err(TensorError::new(&format!(
270            "unsupported checkpoint version {} (this build supports 1..={})",
271            version, MAX_VERSION,
272        )));
273    }
274    // Skip the 32-byte structural hash.
275    let mut _hash = [0u8; HASH_LEN];
276    r.read_exact(&mut _hash).map_err(io_err)?;
277
278    let count = read_u32(&mut r)? as usize;
279    let mut keys = Vec::with_capacity(count);
280    for _ in 0..count {
281        let name_len = read_u32(&mut r)? as usize;
282        let mut name_bytes = vec![0u8; name_len];
283        r.read_exact(&mut name_bytes).map_err(io_err)?;
284        keys.push(String::from_utf8_lossy(&name_bytes).into_owned());
285        // Skip ndim, shape, dtype tag, byte_count + raw payload.
286        let ndim = read_u32(&mut r)? as usize;
287        for _ in 0..ndim {
288            let _ = read_i64(&mut r)?;
289        }
290        let mut tag = [0u8; 1];
291        r.read_exact(&mut tag).map_err(io_err)?;
292        let byte_count = read_u64(&mut r)? as usize;
293        // Skip payload.
294        std::io::copy(&mut r.by_ref().take(byte_count as u64), &mut std::io::sink())
295            .map_err(io_err)?;
296    }
297    Ok(keys)
298}
299
300/// Returns the version field (1 for flodl 0.1.x, 2 for flodl 0.2.0+).
301/// Useful to decide whether a checkpoint needs migration before loading.
302pub fn checkpoint_version(path: &str) -> Result<u32> {
303    let f = std::fs::File::open(path).map_err(io_err)?;
304    let mut r: Box<dyn Read> = if path.ends_with(".gz") {
305        Box::new(flate2::read::GzDecoder::new(f))
306    } else {
307        Box::new(std::io::BufReader::new(f))
308    };
309    let mut magic = [0u8; 4];
310    r.read_exact(&mut magic).map_err(io_err)?;
311    if magic != MAGIC {
312        return Err(TensorError::new(
313            "invalid checkpoint: bad magic (expected .fdl checkpoint)"
314        ));
315    }
316    read_u32(&mut r)
317}
318
319// --- Tensor state helpers for optimizer save/load ---
320
321/// Write an optional tensor (for optimizer buffers that may not be initialized).
322/// Uses native dtype — same format as v2 parameters.
323pub(crate) fn write_tensor_state<W: Write>(w: &mut W, t: Option<&Tensor>) -> Result<()> {
324    match t {
325        None => {
326            w.write_all(&[0u8]).map_err(io_err)?;
327        }
328        Some(t) => {
329            w.write_all(&[1u8]).map_err(io_err)?;
330            write_tensor_data(w, t)?;
331        }
332    }
333    Ok(())
334}
335
336/// Read an optional tensor (returns None if the tensor was nil when saved).
337pub(crate) fn read_tensor_state<R: Read>(r: &mut R, device: Device) -> Result<Option<Tensor>> {
338    let mut present = [0u8; 1];
339    r.read_exact(&mut present).map_err(io_err)?;
340    if present[0] == 0 {
341        return Ok(None);
342    }
343
344    let t = read_tensor_data(r)?;
345    if device != Device::CPU {
346        Ok(Some(t.to_device(device)?))
347    } else {
348        Ok(Some(t))
349    }
350}
351
352// --- Internal: dtype-aware tensor serialization ---
353
354/// DType tag byte for checkpoint format.
355fn dtype_tag(dtype: DType) -> u8 {
356    match dtype {
357        DType::Float16  => 1,
358        DType::BFloat16 => 2,
359        DType::Float32  => 3,
360        DType::Float64  => 4,
361        DType::Int32    => 5,
362        DType::Int64    => 6,
363    }
364}
365
366fn dtype_from_tag(tag: u8) -> Result<DType> {
367    match tag {
368        1 => Ok(DType::Float16),
369        2 => Ok(DType::BFloat16),
370        3 => Ok(DType::Float32),
371        4 => Ok(DType::Float64),
372        5 => Ok(DType::Int32),
373        6 => Ok(DType::Int64),
374        _ => Err(TensorError::new(&format!("unknown dtype tag: {}", tag))),
375    }
376}
377
378/// Write tensor data in native dtype: shape + dtype tag + raw bytes.
379pub(crate) fn write_tensor_data<W: Write>(w: &mut W, t: &Tensor) -> Result<()> {
380    let shape = t.shape();
381    w.write_all(&(shape.len() as u32).to_le_bytes()).map_err(io_err)?;
382    for &s in &shape {
383        w.write_all(&s.to_le_bytes()).map_err(io_err)?;
384    }
385
386    let dtype = t.dtype();
387    w.write_all(&[dtype_tag(dtype)]).map_err(io_err)?;
388
389    let numel = t.numel() as usize;
390    let elem_size = dtype.element_size();
391    let byte_count = numel * elem_size;
392
393    // Copy raw bytes from tensor (handles any dtype)
394    let raw = copy_raw_bytes(t, byte_count)?;
395    w.write_all(&(byte_count as u64).to_le_bytes()).map_err(io_err)?;
396    w.write_all(&raw).map_err(io_err)?;
397
398    Ok(())
399}
400
401/// Read tensor data written by write_tensor_data.
402fn read_tensor_data<R: Read>(r: &mut R) -> Result<Tensor> {
403    let ndim = read_u32(r)? as usize;
404    let mut shape = vec![0i64; ndim];
405    for s in &mut shape {
406        *s = read_i64(r)?;
407    }
408
409    let mut tag = [0u8; 1];
410    r.read_exact(&mut tag).map_err(io_err)?;
411    let dtype = dtype_from_tag(tag[0])?;
412
413    let byte_count = read_u64(r)? as usize;
414    let mut raw = vec![0u8; byte_count];
415    r.read_exact(&mut raw).map_err(io_err)?;
416
417    tensor_from_raw_bytes(&raw, &shape, dtype)
418}
419
420/// Copy raw bytes from a tensor (any dtype). Moves to CPU if needed.
421fn copy_raw_bytes(t: &Tensor, byte_count: usize) -> Result<Vec<u8>> {
422    let mut buf = vec![0u8; byte_count];
423    let err = unsafe {
424        flodl_sys::flodl_copy_data(
425            t.raw(),
426            buf.as_mut_ptr() as *mut std::ffi::c_void,
427            byte_count as i64,
428        )
429    };
430    check_err_raw(err)?;
431    Ok(buf)
432}
433
434/// Construct a tensor from raw bytes + shape + dtype.
435fn tensor_from_raw_bytes(raw: &[u8], shape: &[i64], dtype: DType) -> Result<Tensor> {
436    // Route through the typed constructors to get a proper owned tensor
437    match dtype {
438        DType::Float32 => {
439            let data: Vec<f32> = raw.chunks_exact(4)
440                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
441                .collect();
442            Tensor::from_f32(&data, shape, Device::CPU)
443        }
444        DType::Float64 => {
445            let data: Vec<f64> = raw.chunks_exact(8)
446                .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
447                .collect();
448            Tensor::from_f64(&data, shape, Device::CPU)
449        }
450        DType::Int64 => {
451            let data: Vec<i64> = raw.chunks_exact(8)
452                .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
453                .collect();
454            Tensor::from_i64(&data, shape, Device::CPU)
455        }
456        DType::Float16 | DType::BFloat16 | DType::Int32 => {
457            // For f16/bf16/i32: load raw bytes via from_blob directly.
458            let mut shape_v = shape.to_vec();
459            let mut handle: flodl_sys::FlodlTensor = std::ptr::null_mut();
460            let (dev_type, dev_idx) = crate::tensor::Device::CPU.to_ffi();
461            let err = unsafe {
462                flodl_sys::flodl_from_blob(
463                    raw.as_ptr() as *mut std::ffi::c_void,
464                    shape_v.as_mut_ptr(),
465                    shape_v.len() as i32,
466                    dtype as i32,
467                    dev_type, dev_idx,
468                    &mut handle,
469                )
470            };
471            check_err_raw(err)?;
472            debug_assert!(!handle.is_null());
473            // Safety: from_blob clones the data in the shim, so handle is independent
474            Ok(unsafe { Tensor::from_raw_handle(handle) })
475        }
476    }
477}
478
479// --- Checkpoint migration ---
480
/// Outcome of a checkpoint migration, grouped by what happened to each name.
#[derive(Debug, Clone)]
pub struct MigrateReport {
    /// Entries that kept their original name (exact match in old and new model).
    pub unchanged: Vec<String>,
    /// Entries remapped by shape+dtype matching: `(old_name, new_name)`.
    pub remapped: Vec<(String, String)>,
    /// Checkpoint entries with no matching model parameter/buffer (not migrated).
    pub dropped: Vec<String>,
    /// Model parameters/buffers with no matching checkpoint entry (will use init values).
    pub missing: Vec<String>,
}

impl MigrateReport {
    /// True when nothing was dropped from the checkpoint and no model entry
    /// was left without a source.
    pub fn is_complete(&self) -> bool {
        let leftovers = self.dropped.len() + self.missing.len();
        leftovers == 0
    }
}

impl std::fmt::Display for MigrateReport {
    /// Prints each non-empty category as a `title (count):` header followed
    /// by one indented line per entry; empty categories are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Emit one plain-name section, or nothing when `names` is empty.
        fn name_section(
            f: &mut std::fmt::Formatter<'_>,
            title: &str,
            names: &[String],
        ) -> std::fmt::Result {
            if names.is_empty() {
                return Ok(());
            }
            writeln!(f, "{} ({}):", title, names.len())?;
            names.iter().try_for_each(|entry| writeln!(f, "  {}", entry))
        }

        name_section(f, "unchanged", &self.unchanged)?;
        if !self.remapped.is_empty() {
            writeln!(f, "remapped ({}):", self.remapped.len())?;
            for pair in &self.remapped {
                writeln!(f, "  {} -> {}", pair.0, pair.1)?;
            }
        }
        name_section(f, "dropped", &self.dropped)?;
        name_section(f, "missing", &self.missing)
    }
}
522
/// Raw checkpoint entry for migration (not loaded into a live Tensor).
///
/// Holds one entry as it appears in the file so it can be rewritten
/// under a different name without decoding the payload.
struct RawEntry {
    /// Entry name as stored in the source checkpoint.
    name: String,
    /// Tensor dimensions as stored in the file.
    shape: Vec<i64>,
    /// Element type decoded from the entry's dtype tag.
    dtype: DType,
    /// Undecoded tensor payload bytes.
    raw: Vec<u8>,
}
530
531/// Read checkpoint header and all raw entries without constructing tensors.
532fn read_raw_checkpoint<R: Read>(r: &mut R) -> Result<Vec<RawEntry>> {
533    let mut magic = [0u8; 4];
534    r.read_exact(&mut magic).map_err(io_err)?;
535    if magic != MAGIC {
536        return Err(TensorError::new(
537            "invalid checkpoint: bad magic (expected .fdl checkpoint)"
538        ));
539    }
540    let version = read_u32(r)?;
541    if version == 0 || version > MAX_VERSION {
542        return Err(TensorError::new(&format!(
543            "unsupported checkpoint version {} (this build supports 1..={})",
544            version, MAX_VERSION,
545        )));
546    }
547    // Skip structural hash
548    let mut _hash = [0u8; HASH_LEN];
549    r.read_exact(&mut _hash).map_err(io_err)?;
550
551    let count = read_u32(r)? as usize;
552    let mut entries = Vec::with_capacity(count);
553
554    for _ in 0..count {
555        let name_len = read_u32(r)? as usize;
556        let mut name_bytes = vec![0u8; name_len];
557        r.read_exact(&mut name_bytes).map_err(io_err)?;
558        let name = String::from_utf8_lossy(&name_bytes).into_owned();
559
560        let ndim = read_u32(r)? as usize;
561        let mut shape = vec![0i64; ndim];
562        for s in &mut shape { *s = read_i64(r)?; }
563        let mut tag = [0u8; 1];
564        r.read_exact(&mut tag).map_err(io_err)?;
565        let dtype = dtype_from_tag(tag[0])?;
566        let byte_count = read_u64(r)? as usize;
567        let mut raw = vec![0u8; byte_count];
568        r.read_exact(&mut raw).map_err(io_err)?;
569
570        entries.push(RawEntry { name, shape, dtype, raw });
571    }
572
573    Ok(entries)
574}
575
576/// Write a single raw entry (name + tensor data) into a checkpoint stream.
577fn write_raw_entry<W: Write>(w: &mut W, name: &str, e: &RawEntry) -> Result<()> {
578    let name_bytes = name.as_bytes();
579    w.write_all(&(name_bytes.len() as u32).to_le_bytes()).map_err(io_err)?;
580    w.write_all(name_bytes).map_err(io_err)?;
581    w.write_all(&(e.shape.len() as u32).to_le_bytes()).map_err(io_err)?;
582    for &s in &e.shape {
583        w.write_all(&s.to_le_bytes()).map_err(io_err)?;
584    }
585    w.write_all(&[dtype_tag(e.dtype)]).map_err(io_err)?;
586    w.write_all(&(e.raw.len() as u64).to_le_bytes()).map_err(io_err)?;
587    w.write_all(&e.raw).map_err(io_err)?;
588    Ok(())
589}
590
/// Migrate a checkpoint to match a model's current parameter and buffer naming.
///
/// Reads the source checkpoint and matches each entry against the model's
/// `named_parameters` and `named_buffers`:
///
/// 1. **Exact name match** — entries whose name and shape match a model target
///    are passed through unchanged. Dtype is not compared in this pass.
/// 2. **Shape+dtype match** — remaining entries are matched to remaining model
///    targets by shape and dtype, in checkpoint order. This handles the common
///    case where only tag/node prefixes changed between versions.
///
/// The migrated checkpoint is written with a zeroed structural hash so it can
/// be loaded without architecture validation. Only matched entries are
/// written; dropped checkpoint entries are omitted from the output. The whole
/// source is read and matched before anything is written to `w`.
///
/// # Errors
///
/// Returns an error if the source cannot be parsed as a checkpoint or if
/// writing to `w` fails.
///
/// # Example
///
/// ```ignore
/// let graph = FlowBuilder::from(input)
///     .through(encoder).tag("encoder")
///     .build()?;
///
/// let report = migrate_checkpoint(
///     &mut src_reader,
///     &mut dst_writer,
///     &graph.named_parameters(),
///     &graph.named_buffers(),
/// )?;
/// println!("{}", report);
/// ```
pub fn migrate_checkpoint<R: Read, W: Write>(
    r: &mut R,
    w: &mut W,
    params: &[(String, Parameter)],
    buffers: &[(String, Buffer)],
) -> Result<MigrateReport> {
    let entries = read_raw_checkpoint(r)?;

    // Build model expectations in order: params then buffers
    let mut targets: Vec<(String, Vec<i64>, DType)> = Vec::with_capacity(
        params.len() + buffers.len()
    );
    for (name, p) in params {
        targets.push((name.clone(), p.variable.shape(), p.variable.data().dtype()));
    }
    for (name, b) in buffers {
        targets.push((name.clone(), b.shape(), b.get().dtype()));
    }

    let mut unchanged = Vec::new();
    let mut remapped = Vec::new();
    let mut missing = Vec::new();
    // used[i] is true once checkpoint entry i has been assigned to a target.
    let mut used = vec![false; entries.len()];

    // output: (new_name, checkpoint_index). Pass-1 (exact) matches come first,
    // then pass-2 (remapped) matches; each group is in model order.
    let mut output: Vec<(String, usize)> = Vec::new();

    // Index checkpoint entries by name for O(1) exact lookup.
    // If the checkpoint repeats a name, the last occurrence wins here.
    let name_index: std::collections::HashMap<&str, usize> =
        entries.iter().enumerate().map(|(i, e)| (e.name.as_str(), i)).collect();

    // Indices of model targets not yet matched
    let mut unmatched: Vec<usize> = Vec::new();

    // Pass 1: exact name + shape match
    for (mi, (name, shape, _)) in targets.iter().enumerate() {
        if let Some(&ci) = name_index.get(name.as_str()) {
            if !used[ci] && entries[ci].shape == *shape {
                unchanged.push(name.clone());
                used[ci] = true;
                output.push((name.clone(), ci));
                continue;
            }
        }
        // Same name but different shape also lands here and may be remapped.
        unmatched.push(mi);
    }

    // Pass 2: shape+dtype matching in checkpoint order
    for &mi in &unmatched {
        let (name, shape, dtype) = &targets[mi];

        // First still-unused checkpoint entry with identical shape and dtype.
        let found = entries.iter().enumerate()
            .find(|(ci, e)| !used[*ci] && e.shape == *shape && e.dtype == *dtype)
            .map(|(ci, _)| ci);

        if let Some(ci) = found {
            remapped.push((entries[ci].name.clone(), name.clone()));
            used[ci] = true;
            output.push((name.clone(), ci));
        } else {
            missing.push(name.clone());
        }
    }

    // Checkpoint entries that no target claimed in either pass.
    let dropped: Vec<String> = entries.iter().enumerate()
        .filter(|(i, _)| !used[*i])
        .map(|(_, e)| e.name.clone())
        .collect();

    // Write migrated checkpoint with zeroed structural hash
    w.write_all(&MAGIC).map_err(io_err)?;
    w.write_all(&VERSION.to_le_bytes()).map_err(io_err)?;
    w.write_all(&[0u8; HASH_LEN]).map_err(io_err)?;
    w.write_all(&(output.len() as u32).to_le_bytes()).map_err(io_err)?;

    for (name, ci) in &output {
        write_raw_entry(w, name, &entries[*ci])?;
    }

    Ok(MigrateReport { unchanged, remapped, dropped, missing })
}
701
702/// Migrate a checkpoint file. Detects gzip from `.gz` extension on both paths.
703///
704/// Source and destination must be different paths.
705pub fn migrate_checkpoint_file(
706    src: &str,
707    dst: &str,
708    params: &[(String, Parameter)],
709    buffers: &[(String, Buffer)],
710) -> Result<MigrateReport> {
711    let sf = std::fs::File::open(src).map_err(io_err)?;
712    let df = std::fs::File::create(dst).map_err(io_err)?;
713
714    match (src.ends_with(".gz"), dst.ends_with(".gz")) {
715        (true, true) => {
716            let mut r = flate2::read::GzDecoder::new(sf);
717            let mut w = flate2::write::GzEncoder::new(df, flate2::Compression::default());
718            let report = migrate_checkpoint(&mut r, &mut w, params, buffers)?;
719            w.finish().map_err(io_err)?;
720            Ok(report)
721        }
722        (true, false) => {
723            let mut r = flate2::read::GzDecoder::new(sf);
724            let mut w = std::io::BufWriter::new(df);
725            migrate_checkpoint(&mut r, &mut w, params, buffers)
726        }
727        (false, true) => {
728            let mut r = std::io::BufReader::new(sf);
729            let mut w = flate2::write::GzEncoder::new(df, flate2::Compression::default());
730            let report = migrate_checkpoint(&mut r, &mut w, params, buffers)?;
731            w.finish().map_err(io_err)?;
732            Ok(report)
733        }
734        (false, false) => {
735            let mut r = std::io::BufReader::new(sf);
736            let mut w = std::io::BufWriter::new(df);
737            migrate_checkpoint(&mut r, &mut w, params, buffers)
738        }
739    }
740}
741
742// --- Shared helpers ---
743
744pub(crate) fn io_err(e: impl std::fmt::Display) -> TensorError {
745    TensorError::new(&format!("io: {}", e))
746}
747
748fn check_err_raw(err: *mut i8) -> Result<()> {
749    if err.is_null() {
750        Ok(())
751    } else {
752        let msg = unsafe { std::ffi::CStr::from_ptr(err) }
753            .to_string_lossy()
754            .into_owned();
755        unsafe { flodl_sys::flodl_free_string(err) };
756        Err(TensorError::new(&msg))
757    }
758}
759
760fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
761    let mut buf = [0u8; 4];
762    r.read_exact(&mut buf).map_err(io_err)?;
763    Ok(u32::from_le_bytes(buf))
764}
765
766fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
767    let mut buf = [0u8; 8];
768    r.read_exact(&mut buf).map_err(io_err)?;
769    Ok(u64::from_le_bytes(buf))
770}
771
772fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
773    let mut buf = [0u8; 8];
774    r.read_exact(&mut buf).map_err(io_err)?;
775    Ok(i64::from_le_bytes(buf))
776}
777
778// Pub(crate) helpers for optimizer state serialization
779pub(crate) fn read_f64_le<R: Read>(r: &mut R) -> Result<f64> {
780    let mut buf = [0u8; 8];
781    r.read_exact(&mut buf).map_err(io_err)?;
782    Ok(f64::from_le_bytes(buf))
783}
784pub(crate) fn write_f64_le<W: Write>(w: &mut W, v: f64) -> Result<()> {
785    w.write_all(&v.to_le_bytes()).map_err(io_err)?;
786    Ok(())
787}
788pub(crate) fn write_u32_le<W: Write>(w: &mut W, v: u32) -> Result<()> {
789    w.write_all(&v.to_le_bytes()).map_err(io_err)?;
790    Ok(())
791}
792pub(crate) fn write_i64_le<W: Write>(w: &mut W, v: i64) -> Result<()> {
793    w.write_all(&v.to_le_bytes()).map_err(io_err)?;
794    Ok(())
795}
796pub(crate) fn read_u32_le<R: Read>(r: &mut R) -> Result<u32> {
797    read_u32(r)
798}
799pub(crate) fn read_i64_le<R: Read>(r: &mut R) -> Result<i64> {
800    read_i64(r)
801}
802
803/// Decode a hex string to a 32-byte array.
804fn hex_to_bytes(hex: &str) -> Result<[u8; HASH_LEN]> {
805    if hex.len() != HASH_LEN * 2 {
806        return Err(TensorError::new(&format!(
807            "expected {} hex chars, got {}",
808            HASH_LEN * 2,
809            hex.len()
810        )));
811    }
812    let mut out = [0u8; HASH_LEN];
813    for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
814        let hi = hex_nibble(chunk[0])?;
815        let lo = hex_nibble(chunk[1])?;
816        out[i] = (hi << 4) | lo;
817    }
818    Ok(out)
819}
820
821fn hex_nibble(b: u8) -> Result<u8> {
822    match b {
823        b'0'..=b'9' => Ok(b - b'0'),
824        b'a'..=b'f' => Ok(b - b'a' + 10),
825        b'A'..=b'F' => Ok(b - b'A' + 10),
826        _ => Err(TensorError::new(&format!("invalid hex byte: {}", b))),
827    }
828}
829
/// Render `bytes` as lowercase hexadecimal, two characters per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        // High nibble first, then low nibble.
        hex.push(DIGITS[usize::from(byte >> 4)] as char);
        hex.push(DIGITS[usize::from(byte & 0x0f)] as char);
    }
    hex
}
839
840#[cfg(test)]
841mod tests {
842    use super::*;
843    use crate::tensor::TensorOptions;
844
845    fn make_named_params(sizes: &[(i64, i64)]) -> Vec<(String, Parameter)> {
846        sizes.iter().enumerate().map(|(i, &(rows, cols))| {
847            let t = Tensor::randn(&[rows, cols], TensorOptions {
848                dtype: DType::Float32,
849                device: crate::tensor::test_device(),
850            }).unwrap();
851            let name = format!("layer_{}/weight", i);
852            (name.clone(), Parameter::new(t, "weight"))
853        }).collect()
854    }
855
856    fn make_named_buffers(sizes: &[i64]) -> Vec<(String, Buffer)> {
857        sizes.iter().enumerate().map(|(i, &features)| {
858            let t = Tensor::randn(&[features], TensorOptions {
859                dtype: DType::Float32,
860                device: crate::tensor::test_device(),
861            }).unwrap();
862            let name = format!("bn_{}/running_mean", i);
863            (name.clone(), Buffer::new(t, "running_mean"))
864        }).collect()
865    }
866
867    #[test]
868    fn test_named_roundtrip() {
869        let params = make_named_params(&[(4, 8), (8, 2)]);
870
871        let mut buf = Vec::new();
872        save_checkpoint(&mut buf, &params, &[], None).unwrap();
873
874        let load_params = make_named_params(&[(4, 8), (8, 2)]);
875        let mut cursor = std::io::Cursor::new(&buf);
876        let report = load_checkpoint(&mut cursor, &load_params, &[], None).unwrap();
877
878        assert_eq!(report.loaded.len(), 2);
879        assert!(report.skipped.is_empty());
880        assert!(report.missing.is_empty());
881
882        for ((_, src), (_, dst)) in params.iter().zip(load_params.iter()) {
883            let src_data = src.variable.data().to_f32_vec().unwrap();
884            let dst_data = dst.variable.data().to_f32_vec().unwrap();
885            assert_eq!(src_data, dst_data);
886        }
887    }
888
889    #[test]
890    fn test_buffer_roundtrip() {
891        let params = make_named_params(&[(4, 8)]);
892        let buffers = make_named_buffers(&[8]);
893
894        let mut buf = Vec::new();
895        save_checkpoint(&mut buf, &params, &buffers, None).unwrap();
896
897        // Fresh model with same structure
898        let load_params = make_named_params(&[(4, 8)]);
899        let load_buffers = make_named_buffers(&[8]);
900        let mut cursor = std::io::Cursor::new(&buf);
901        let report = load_checkpoint(&mut cursor, &load_params, &load_buffers, None).unwrap();
902
903        assert_eq!(report.loaded.len(), 2); // 1 param + 1 buffer
904        assert!(report.skipped.is_empty());
905        assert!(report.missing.is_empty());
906
907        // Verify buffer data matches
908        let src_data = buffers[0].1.get().to_f32_vec().unwrap();
909        let dst_data = load_buffers[0].1.get().to_f32_vec().unwrap();
910        assert_eq!(src_data, dst_data);
911    }
912
913    #[test]
914    fn test_named_partial_load() {
915        let params_3 = make_named_params(&[(4, 8), (8, 4), (4, 2)]);
916
917        let mut buf = Vec::new();
918        save_checkpoint(&mut buf, &params_3, &[], None).unwrap();
919
920        let mut params_4 = make_named_params(&[(4, 8), (8, 4), (4, 2), (2, 1)]);
921        params_4[3].0 = "extra/weight".to_string();
922
923        let before_extra = params_4[3].1.variable.data().to_f32_vec().unwrap();
924
925        let mut cursor = std::io::Cursor::new(&buf);
926        let report = load_checkpoint(&mut cursor, &params_4, &[], None).unwrap();
927
928        assert_eq!(report.loaded.len(), 3);
929        assert_eq!(report.missing.len(), 1);
930        assert_eq!(report.missing[0], "extra/weight");
931        assert!(report.skipped.is_empty());
932
933        let after_extra = params_4[3].1.variable.data().to_f32_vec().unwrap();
934        assert_eq!(before_extra, after_extra);
935    }
936
937    #[test]
938    fn test_named_skipped_checkpoint_params() {
939        let params = make_named_params(&[(4, 8), (8, 2)]);
940
941        let mut buf = Vec::new();
942        save_checkpoint(&mut buf, &params, &[], None).unwrap();
943
944        let model = vec![params[0].clone()];
945        let mut cursor = std::io::Cursor::new(&buf);
946        let report = load_checkpoint(&mut cursor, &model, &[], None).unwrap();
947
948        assert_eq!(report.loaded.len(), 1);
949        assert_eq!(report.skipped.len(), 1);
950        assert!(report.missing.is_empty());
951    }
952
953    #[test]
954    fn test_named_shape_mismatch_error() {
955        let params = make_named_params(&[(4, 8)]);
956
957        let mut buf = Vec::new();
958        save_checkpoint(&mut buf, &params, &[], None).unwrap();
959
960        let wrong_shape = vec![(
961            "layer_0/weight".to_string(),
962            Parameter::new(
963                Tensor::randn(&[4, 4], TensorOptions {
964                    dtype: DType::Float32,
965                    device: crate::tensor::test_device(),
966                }).unwrap(),
967                "weight",
968            ),
969        )];
970        let mut cursor = std::io::Cursor::new(&buf);
971        let result = load_checkpoint(&mut cursor, &wrong_shape, &[], None);
972        assert!(result.is_err(), "shape mismatch should be an error");
973        let err_msg = format!("{}", result.unwrap_err());
974        assert!(err_msg.contains("shape mismatch"), "error should mention shape: {}", err_msg);
975    }
976
977    #[test]
978    fn test_buffer_shape_mismatch_error() {
979        let buffers = make_named_buffers(&[8]);
980
981        let mut buf = Vec::new();
982        save_checkpoint(&mut buf, &[], &buffers, None).unwrap();
983
984        let wrong_buffers = vec![(
985            "bn_0/running_mean".to_string(),
986            Buffer::new(
987                Tensor::zeros(&[4], crate::tensor::test_opts()).unwrap(),
988                "running_mean",
989            ),
990        )];
991        let mut cursor = std::io::Cursor::new(&buf);
992        let result = load_checkpoint(&mut cursor, &[], &wrong_buffers, None);
993        assert!(result.is_err());
994        assert!(format!("{}", result.unwrap_err()).contains("shape mismatch"));
995    }
996
997    #[test]
998    fn test_compressed_roundtrip() {
999        let params = make_named_params(&[(16, 32), (32, 8)]);
1000        let buffers = make_named_buffers(&[32]);
1001
1002        let dir = std::env::temp_dir();
1003        let gz_path = dir.join("test_ckpt_v2.fdl.gz");
1004        let plain_path = dir.join("test_ckpt_v2.fdl");
1005        let gz = gz_path.to_str().unwrap();
1006        let plain = plain_path.to_str().unwrap();
1007
1008        save_checkpoint_file(gz, &params, &buffers, None).unwrap();
1009        save_checkpoint_file(plain, &params, &buffers, None).unwrap();
1010
1011        // Compressed should be smaller
1012        let gz_size = std::fs::metadata(gz).unwrap().len();
1013        let plain_size = std::fs::metadata(plain).unwrap().len();
1014        assert!(gz_size < plain_size, "gz={} should be < plain={}", gz_size, plain_size);
1015
1016        // Load from compressed and verify
1017        let load_params = make_named_params(&[(16, 32), (32, 8)]);
1018        let load_buffers = make_named_buffers(&[32]);
1019        let report = load_checkpoint_file(gz, &load_params, &load_buffers, None).unwrap();
1020        assert_eq!(report.loaded.len(), 3); // 2 params + 1 buffer
1021
1022        for ((_, src), (_, dst)) in params.iter().zip(load_params.iter()) {
1023            assert_eq!(src.variable.data().to_f32_vec().unwrap(),
1024                       dst.variable.data().to_f32_vec().unwrap());
1025        }
1026
1027        let src_buf = buffers[0].1.get().to_f32_vec().unwrap();
1028        let dst_buf = load_buffers[0].1.get().to_f32_vec().unwrap();
1029        assert_eq!(src_buf, dst_buf);
1030
1031        std::fs::remove_file(gz).ok();
1032        std::fs::remove_file(plain).ok();
1033    }
1034
1035    #[test]
1036    fn test_hash_roundtrip() {
1037        let params = make_named_params(&[(4, 8)]);
1038        // Use a known 64-char hex hash
1039        let hash = "a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2";
1040
1041        let mut buf = Vec::new();
1042        save_checkpoint(&mut buf, &params, &[], Some(hash)).unwrap();
1043
1044        let load_params = make_named_params(&[(4, 8)]);
1045        let mut cursor = std::io::Cursor::new(&buf);
1046        // Same hash: should succeed
1047        let report = load_checkpoint(&mut cursor, &load_params, &[], Some(hash)).unwrap();
1048        assert_eq!(report.loaded.len(), 1);
1049    }
1050
1051    #[test]
1052    fn test_hash_mismatch_error() {
1053        let params = make_named_params(&[(4, 8)]);
1054        let hash_a = "a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2";
1055        let hash_b = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff";
1056
1057        let mut buf = Vec::new();
1058        save_checkpoint(&mut buf, &params, &[], Some(hash_a)).unwrap();
1059
1060        let load_params = make_named_params(&[(4, 8)]);
1061        let mut cursor = std::io::Cursor::new(&buf);
1062        let result = load_checkpoint(&mut cursor, &load_params, &[], Some(hash_b));
1063        assert!(result.is_err());
1064        let msg = format!("{}", result.unwrap_err());
1065        assert!(msg.contains("architecture mismatch"), "error: {}", msg);
1066    }
1067
1068    #[test]
1069    fn test_zero_hash_skips_validation() {
1070        let params = make_named_params(&[(4, 8)]);
1071
1072        // Save with no hash (zero bytes)
1073        let mut buf = Vec::new();
1074        save_checkpoint(&mut buf, &params, &[], None).unwrap();
1075
1076        // Load with a hash expectation — should still succeed (file has zeros)
1077        let hash = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff";
1078        let load_params = make_named_params(&[(4, 8)]);
1079        let mut cursor = std::io::Cursor::new(&buf);
1080        let report = load_checkpoint(&mut cursor, &load_params, &[], Some(hash)).unwrap();
1081        assert_eq!(report.loaded.len(), 1);
1082
1083        // Save with hash, load with None — should succeed (no expected hash)
1084        let mut buf2 = Vec::new();
1085        save_checkpoint(&mut buf2, &params, &[], Some(hash)).unwrap();
1086        let load_params2 = make_named_params(&[(4, 8)]);
1087        let mut cursor2 = std::io::Cursor::new(&buf2);
1088        let report2 = load_checkpoint(&mut cursor2, &load_params2, &[], None).unwrap();
1089        assert_eq!(report2.loaded.len(), 1);
1090    }
1091
1092    /// Write a checkpoint with an explicit version byte (for testing v1 migration).
1093    fn save_checkpoint_versioned<W: std::io::Write>(
1094        w: &mut W,
1095        version: u32,
1096        params: &[(String, Parameter)],
1097        buffers: &[(String, Buffer)],
1098    ) {
1099        w.write_all(&MAGIC).unwrap();
1100        w.write_all(&version.to_le_bytes()).unwrap();
1101        w.write_all(&[0u8; HASH_LEN]).unwrap();
1102        let total = (params.len() + buffers.len()) as u32;
1103        w.write_all(&total.to_le_bytes()).unwrap();
1104        for (name, p) in params {
1105            let name_bytes = name.as_bytes();
1106            w.write_all(&(name_bytes.len() as u32).to_le_bytes()).unwrap();
1107            w.write_all(name_bytes).unwrap();
1108            write_tensor_data(w, &p.variable.data()).unwrap();
1109        }
1110        for (name, b) in buffers {
1111            let name_bytes = name.as_bytes();
1112            w.write_all(&(name_bytes.len() as u32).to_le_bytes()).unwrap();
1113            w.write_all(name_bytes).unwrap();
1114            write_tensor_data(w, &b.get()).unwrap();
1115        }
1116    }
1117
1118    #[test]
1119    fn test_migrate_all_renamed() {
1120        // Simulate v1 checkpoint with old-style names
1121        let old_params = vec![
1122            ("linear_0/weight".to_string(), Parameter::new(
1123                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1124            ("linear_1/weight".to_string(), Parameter::new(
1125                Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1126        ];
1127        let mut ckpt = Vec::new();
1128        save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1129
1130        // New model with renamed tags
1131        let new_params = vec![
1132            ("encoder/weight".to_string(), Parameter::new(
1133                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1134            ("decoder/weight".to_string(), Parameter::new(
1135                Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1136        ];
1137
1138        let mut out = Vec::new();
1139        let report = migrate_checkpoint(
1140            &mut std::io::Cursor::new(&ckpt), &mut out,
1141            &new_params, &[],
1142        ).unwrap();
1143
1144        assert!(report.unchanged.is_empty());
1145        assert_eq!(report.remapped.len(), 2);
1146        assert!(report.dropped.is_empty());
1147        assert!(report.missing.is_empty());
1148        assert!(report.is_complete());
1149
1150        // Verify the migrated checkpoint loads correctly
1151        let verify_params = vec![
1152            ("encoder/weight".to_string(), Parameter::new(
1153                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1154            ("decoder/weight".to_string(), Parameter::new(
1155                Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1156        ];
1157        let mut cursor = std::io::Cursor::new(&out);
1158        let load_report = load_checkpoint(&mut cursor, &verify_params, &[], None).unwrap();
1159        assert_eq!(load_report.loaded.len(), 2);
1160        assert!(load_report.missing.is_empty());
1161
1162        // Verify data preserved: old param data matches loaded data
1163        for (i, (_, vp)) in verify_params.iter().enumerate() {
1164            let expected = old_params[i].1.variable.data().to_f32_vec().unwrap();
1165            let got = vp.variable.data().to_f32_vec().unwrap();
1166            assert_eq!(expected, got, "data mismatch for param {}", i);
1167        }
1168    }
1169
1170    #[test]
1171    fn test_migrate_partial_rename() {
1172        // Some names match, some don't
1173        let old_params = vec![
1174            ("shared/weight".to_string(), Parameter::new(
1175                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1176            ("linear_0/weight".to_string(), Parameter::new(
1177                Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1178        ];
1179        let mut ckpt = Vec::new();
1180        save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1181
1182        let new_params = vec![
1183            ("shared/weight".to_string(), Parameter::new(
1184                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1185            ("encoder/weight".to_string(), Parameter::new(
1186                Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1187        ];
1188
1189        let mut out = Vec::new();
1190        let report = migrate_checkpoint(
1191            &mut std::io::Cursor::new(&ckpt), &mut out,
1192            &new_params, &[],
1193        ).unwrap();
1194
1195        assert_eq!(report.unchanged, vec!["shared/weight"]);
1196        assert_eq!(report.remapped.len(), 1);
1197        assert_eq!(report.remapped[0], ("linear_0/weight".to_string(), "encoder/weight".to_string()));
1198        assert!(report.is_complete());
1199    }
1200
1201    #[test]
1202    fn test_migrate_with_buffers() {
1203        let old_params = vec![
1204            ("linear_0/weight".to_string(), Parameter::new(
1205                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1206        ];
1207        let old_buffers = vec![
1208            ("bn_0/running_mean".to_string(), Buffer::new(
1209                Tensor::zeros(&[8], crate::tensor::test_opts()).unwrap(), "running_mean")),
1210        ];
1211        let mut ckpt = Vec::new();
1212        save_checkpoint_versioned(&mut ckpt, 1, &old_params, &old_buffers);
1213
1214        let new_params = vec![
1215            ("encoder/weight".to_string(), Parameter::new(
1216                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1217        ];
1218        let new_buffers = vec![
1219            ("norm/running_mean".to_string(), Buffer::new(
1220                Tensor::zeros(&[8], crate::tensor::test_opts()).unwrap(), "running_mean")),
1221        ];
1222
1223        let mut out = Vec::new();
1224        let report = migrate_checkpoint(
1225            &mut std::io::Cursor::new(&ckpt), &mut out,
1226            &new_params, &new_buffers,
1227        ).unwrap();
1228
1229        assert_eq!(report.remapped.len(), 2);
1230        assert!(report.is_complete());
1231
1232        // Verify migrated checkpoint loads with new names
1233        let vp = vec![
1234            ("encoder/weight".to_string(), Parameter::new(
1235                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1236        ];
1237        let vb = vec![
1238            ("norm/running_mean".to_string(), Buffer::new(
1239                Tensor::zeros(&[8], crate::tensor::test_opts()).unwrap(), "running_mean")),
1240        ];
1241        let mut cursor = std::io::Cursor::new(&out);
1242        let load_report = load_checkpoint(&mut cursor, &vp, &vb, None).unwrap();
1243        assert_eq!(load_report.loaded.len(), 2);
1244    }
1245
1246    #[test]
1247    fn test_migrate_dropped_and_missing() {
1248        let old_params = vec![
1249            ("old/weight".to_string(), Parameter::new(
1250                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1251            ("removed/weight".to_string(), Parameter::new(
1252                Tensor::randn(&[16, 16], crate::tensor::test_opts()).unwrap(), "weight")),
1253        ];
1254        let mut ckpt = Vec::new();
1255        save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1256
1257        // New model: one matching shape, one entirely new
1258        let new_params = vec![
1259            ("new/weight".to_string(), Parameter::new(
1260                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1261            ("added/weight".to_string(), Parameter::new(
1262                Tensor::randn(&[32, 32], crate::tensor::test_opts()).unwrap(), "weight")),
1263        ];
1264
1265        let mut out = Vec::new();
1266        let report = migrate_checkpoint(
1267            &mut std::io::Cursor::new(&ckpt), &mut out,
1268            &new_params, &[],
1269        ).unwrap();
1270
1271        assert_eq!(report.remapped.len(), 1);
1272        assert_eq!(report.dropped, vec!["removed/weight"]);
1273        assert_eq!(report.missing, vec!["added/weight"]);
1274        assert!(!report.is_complete());
1275    }
1276
1277    #[test]
1278    fn test_migrate_positional_disambiguation() {
1279        // Two params with identical shape — must match by position
1280        let old_params = vec![
1281            ("linear_0/weight".to_string(), Parameter::new(
1282                Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1283            ("linear_1/weight".to_string(), Parameter::new(
1284                Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1285        ];
1286        let mut ckpt = Vec::new();
1287        save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1288
1289        let new_params = vec![
1290            ("encoder/weight".to_string(), Parameter::new(
1291                Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1292            ("decoder/weight".to_string(), Parameter::new(
1293                Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1294        ];
1295
1296        let mut out = Vec::new();
1297        let report = migrate_checkpoint(
1298            &mut std::io::Cursor::new(&ckpt), &mut out,
1299            &new_params, &[],
1300        ).unwrap();
1301
1302        assert_eq!(report.remapped.len(), 2);
1303        // Positional: first old → first new, second old → second new
1304        assert_eq!(report.remapped[0].0, "linear_0/weight");
1305        assert_eq!(report.remapped[0].1, "encoder/weight");
1306        assert_eq!(report.remapped[1].0, "linear_1/weight");
1307        assert_eq!(report.remapped[1].1, "decoder/weight");
1308
1309        // Verify correct data assignment
1310        let vp = vec![
1311            ("encoder/weight".to_string(), Parameter::new(
1312                Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1313            ("decoder/weight".to_string(), Parameter::new(
1314                Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1315        ];
1316        let mut cursor = std::io::Cursor::new(&out);
1317        load_checkpoint(&mut cursor, &vp, &[], None).unwrap();
1318
1319        // encoder/weight should have linear_0's data, decoder/weight should have linear_1's data
1320        let enc_data = vp[0].1.variable.data().to_f32_vec().unwrap();
1321        let dec_data = vp[1].1.variable.data().to_f32_vec().unwrap();
1322        let old_0 = old_params[0].1.variable.data().to_f32_vec().unwrap();
1323        let old_1 = old_params[1].1.variable.data().to_f32_vec().unwrap();
1324        assert_eq!(enc_data, old_0);
1325        assert_eq!(dec_data, old_1);
1326    }
1327
1328    #[test]
1329    fn test_migrate_v1_writes_v2() {
1330        let old_params = vec![
1331            ("x/weight".to_string(), Parameter::new(
1332                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1333        ];
1334        let mut ckpt = Vec::new();
1335        save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1336
1337        // Confirm source is v1
1338        let mut peek = std::io::Cursor::new(&ckpt);
1339        let mut magic = [0u8; 4];
1340        std::io::Read::read_exact(&mut peek, &mut magic).unwrap();
1341        let mut vbuf = [0u8; 4];
1342        std::io::Read::read_exact(&mut peek, &mut vbuf).unwrap();
1343        assert_eq!(u32::from_le_bytes(vbuf), 1);
1344
1345        let new_params = vec![
1346            ("y/weight".to_string(), Parameter::new(
1347                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1348        ];
1349
1350        let mut out = Vec::new();
1351        migrate_checkpoint(
1352            &mut std::io::Cursor::new(&ckpt), &mut out,
1353            &new_params, &[],
1354        ).unwrap();
1355
1356        // Confirm output is v2
1357        let mut peek2 = std::io::Cursor::new(&out);
1358        std::io::Read::read_exact(&mut peek2, &mut magic).unwrap();
1359        assert_eq!(&magic, b"FDLC");
1360        std::io::Read::read_exact(&mut peek2, &mut vbuf).unwrap();
1361        assert_eq!(u32::from_le_bytes(vbuf), VERSION); // should be 2
1362    }
1363
1364    #[test]
1365    fn test_migrate_file_roundtrip() {
1366        let old_params = vec![
1367            ("old/weight".to_string(), Parameter::new(
1368                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1369        ];
1370        let dir = std::env::temp_dir();
1371        let src = dir.join("test_migrate_src.fdl");
1372        let dst = dir.join("test_migrate_dst.fdl");
1373
1374        // Write v1 checkpoint to file
1375        {
1376            let f = std::fs::File::create(&src).unwrap();
1377            let mut w = std::io::BufWriter::new(f);
1378            save_checkpoint_versioned(&mut w, 1, &old_params, &[]);
1379        }
1380
1381        let new_params = vec![
1382            ("new/weight".to_string(), Parameter::new(
1383                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1384        ];
1385
1386        let report = migrate_checkpoint_file(
1387            src.to_str().unwrap(),
1388            dst.to_str().unwrap(),
1389            &new_params, &[],
1390        ).unwrap();
1391        assert_eq!(report.remapped.len(), 1);
1392        assert!(report.is_complete());
1393
1394        // Load migrated file
1395        let vp = vec![
1396            ("new/weight".to_string(), Parameter::new(
1397                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1398        ];
1399        let load_report = load_checkpoint_file(
1400            dst.to_str().unwrap(), &vp, &[], None,
1401        ).unwrap();
1402        assert_eq!(load_report.loaded.len(), 1);
1403
1404        // Verify data preserved
1405        let expected = old_params[0].1.variable.data().to_f32_vec().unwrap();
1406        let got = vp[0].1.variable.data().to_f32_vec().unwrap();
1407        assert_eq!(expected, got);
1408
1409        std::fs::remove_file(src).ok();
1410        std::fs::remove_file(dst).ok();
1411    }
1412
1413    #[test]
1414    fn test_migrate_display() {
1415        let report = MigrateReport {
1416            unchanged: vec!["shared/weight".to_string()],
1417            remapped: vec![("old/bias".to_string(), "new/bias".to_string())],
1418            dropped: vec!["removed/weight".to_string()],
1419            missing: vec!["added/weight".to_string()],
1420        };
1421        let text = format!("{}", report);
1422        assert!(text.contains("unchanged (1)"));
1423        assert!(text.contains("remapped (1)"));
1424        assert!(text.contains("old/bias -> new/bias"));
1425        assert!(text.contains("dropped (1)"));
1426        assert!(text.contains("missing (1)"));
1427    }
1428
1429    #[test]
1430    fn test_checkpoint_version_peek() {
1431        let params = make_named_params(&[(4, 8)]);
1432        let dir = std::env::temp_dir();
1433        let path = dir.join("test_version_peek.fdl");
1434        save_checkpoint_file(path.to_str().unwrap(), &params, &[], None).unwrap();
1435
1436        let v = checkpoint_version(path.to_str().unwrap()).unwrap();
1437        assert_eq!(v, VERSION);
1438
1439        std::fs::remove_file(path).ok();
1440    }
1441
1442    #[test]
1443    fn test_load_accepts_v1() {
1444        // v1 checkpoints must still load in v2 builds
1445        let params = vec![
1446            ("x/weight".to_string(), Parameter::new(
1447                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1448        ];
1449        let mut ckpt = Vec::new();
1450        save_checkpoint_versioned(&mut ckpt, 1, &params, &[]);
1451
1452        let load_params = vec![
1453            ("x/weight".to_string(), Parameter::new(
1454                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1455        ];
1456        let mut cursor = std::io::Cursor::new(&ckpt);
1457        let report = load_checkpoint(&mut cursor, &load_params, &[], None).unwrap();
1458        assert_eq!(report.loaded.len(), 1);
1459
1460        let expected = params[0].1.variable.data().to_f32_vec().unwrap();
1461        let got = load_params[0].1.variable.data().to_f32_vec().unwrap();
1462        assert_eq!(expected, got);
1463    }
1464
1465    // --- Edge case / corruption tests ---
1466
1467    #[test]
1468    fn test_truncated_checkpoint_header_only() {
1469        // Write valid header but truncate before any entry data
1470        let mut buf = Vec::new();
1471        buf.extend_from_slice(&MAGIC);
1472        buf.extend_from_slice(&VERSION.to_le_bytes());
1473        buf.extend_from_slice(&[0u8; HASH_LEN]);
1474        // Claim 5 entries, but provide none
1475        buf.extend_from_slice(&5u32.to_le_bytes());
1476
1477        let params = make_named_params(&[(4, 8)]);
1478        let mut cursor = std::io::Cursor::new(&buf);
1479        let result = load_checkpoint(&mut cursor, &params, &[], None);
1480        assert!(result.is_err(), "truncated checkpoint should return Err, not panic");
1481        let msg = format!("{}", result.unwrap_err());
1482        assert!(msg.contains("io:"), "should be an IO error: {}", msg);
1483    }
1484
1485    #[test]
1486    fn test_truncated_checkpoint_mid_entry() {
1487        // Save a valid checkpoint, then truncate in the middle of the first entry
1488        let params = make_named_params(&[(4, 8)]);
1489        let mut full = Vec::new();
1490        save_checkpoint(&mut full, &params, &[], None).unwrap();
1491
1492        // Header = 4 (magic) + 4 (version) + 32 (hash) + 4 (count) = 44
1493        // Truncate partway through the first entry (e.g., keep only 50 bytes)
1494        let truncated = full[..50.min(full.len())].to_vec();
1495
1496        let load_params = make_named_params(&[(4, 8)]);
1497        let mut cursor = std::io::Cursor::new(&truncated);
1498        let result = load_checkpoint(&mut cursor, &load_params, &[], None);
1499        assert!(result.is_err(), "truncated mid-entry should return Err");
1500    }
1501
1502    #[test]
1503    fn test_empty_file() {
1504        // Zero bytes: read_exact for magic should fail
1505        let buf: Vec<u8> = Vec::new();
1506        let params = make_named_params(&[(4, 8)]);
1507        let mut cursor = std::io::Cursor::new(&buf);
1508        let result = load_checkpoint(&mut cursor, &params, &[], None);
1509        assert!(result.is_err(), "empty file should return Err");
1510    }
1511
1512    #[test]
1513    fn test_invalid_magic_bytes() {
1514        let mut buf = Vec::new();
1515        buf.extend_from_slice(b"JUNK"); // wrong magic
1516        buf.extend_from_slice(&VERSION.to_le_bytes());
1517        buf.extend_from_slice(&[0u8; HASH_LEN]);
1518        buf.extend_from_slice(&0u32.to_le_bytes());
1519
1520        let params = make_named_params(&[(4, 8)]);
1521        let mut cursor = std::io::Cursor::new(&buf);
1522        let result = load_checkpoint(&mut cursor, &params, &[], None);
1523        assert!(result.is_err());
1524        let msg = format!("{}", result.unwrap_err());
1525        assert!(msg.contains("bad magic"), "error should mention bad magic: {}", msg);
1526    }
1527
1528    #[test]
1529    fn test_invalid_magic_checkpoint_version() {
1530        // checkpoint_version() should also reject bad magic
1531        let dir = std::env::temp_dir();
1532        let path = dir.join("test_bad_magic_version.fdl");
1533        std::fs::write(&path, b"NOT_FDLC_data").unwrap();
1534
1535        let result = checkpoint_version(path.to_str().unwrap());
1536        assert!(result.is_err());
1537        let msg = format!("{}", result.unwrap_err());
1538        assert!(msg.contains("bad magic"), "error: {}", msg);
1539
1540        std::fs::remove_file(path).ok();
1541    }
1542
1543    #[test]
1544    fn test_unsupported_version_high() {
1545        let mut buf = Vec::new();
1546        buf.extend_from_slice(&MAGIC);
1547        buf.extend_from_slice(&99u32.to_le_bytes()); // version 99
1548        buf.extend_from_slice(&[0u8; HASH_LEN]);
1549        buf.extend_from_slice(&0u32.to_le_bytes());
1550
1551        let params = make_named_params(&[(4, 8)]);
1552        let mut cursor = std::io::Cursor::new(&buf);
1553        let result = load_checkpoint(&mut cursor, &params, &[], None);
1554        assert!(result.is_err());
1555        let msg = format!("{}", result.unwrap_err());
1556        assert!(msg.contains("unsupported checkpoint version"), "error: {}", msg);
1557        assert!(msg.contains("99"), "should mention version 99: {}", msg);
1558    }
1559
1560    #[test]
1561    fn test_unsupported_version_zero() {
1562        // Version 0 is also rejected (valid range is 1..=MAX_VERSION)
1563        let mut buf = Vec::new();
1564        buf.extend_from_slice(&MAGIC);
1565        buf.extend_from_slice(&0u32.to_le_bytes()); // version 0
1566        buf.extend_from_slice(&[0u8; HASH_LEN]);
1567        buf.extend_from_slice(&0u32.to_le_bytes());
1568
1569        let params = make_named_params(&[(4, 8)]);
1570        let mut cursor = std::io::Cursor::new(&buf);
1571        let result = load_checkpoint(&mut cursor, &params, &[], None);
1572        assert!(result.is_err());
1573        let msg = format!("{}", result.unwrap_err());
1574        assert!(msg.contains("unsupported checkpoint version"), "error: {}", msg);
1575    }
1576
1577    #[test]
1578    fn test_hash_mismatch_both_nonzero() {
1579        // Both file and expected have nonzero hashes that differ
1580        let params = make_named_params(&[(4, 8)]);
1581        let hash_a = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
1582        let hash_b = "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210";
1583
1584        let mut buf = Vec::new();
1585        save_checkpoint(&mut buf, &params, &[], Some(hash_a)).unwrap();
1586
1587        let load_params = make_named_params(&[(4, 8)]);
1588        let mut cursor = std::io::Cursor::new(&buf);
1589        let result = load_checkpoint(&mut cursor, &load_params, &[], Some(hash_b));
1590        assert!(result.is_err());
1591        let msg = format!("{}", result.unwrap_err());
1592        assert!(msg.contains("architecture mismatch"), "error: {}", msg);
1593        // Error message should include both hashes for diagnostics
1594        assert!(msg.contains(hash_b), "should show expected hash: {}", msg);
1595    }
1596
1597    #[test]
1598    fn test_zero_entries_empty_model() {
1599        // Save a checkpoint with no parameters and no buffers
1600        let mut buf = Vec::new();
1601        save_checkpoint(&mut buf, &[], &[], None).unwrap();
1602
1603        // Load into an empty model
1604        let mut cursor = std::io::Cursor::new(&buf);
1605        let report = load_checkpoint(&mut cursor, &[], &[], None).unwrap();
1606        assert!(report.loaded.is_empty());
1607        assert!(report.skipped.is_empty());
1608        assert!(report.missing.is_empty());
1609    }
1610
1611    #[test]
1612    fn test_zero_entries_nonempty_model() {
1613        // Save empty checkpoint, load into model that expects params
1614        let mut buf = Vec::new();
1615        save_checkpoint(&mut buf, &[], &[], None).unwrap();
1616
1617        let load_params = make_named_params(&[(4, 8)]);
1618        let mut cursor = std::io::Cursor::new(&buf);
1619        let report = load_checkpoint(&mut cursor, &load_params, &[], None).unwrap();
1620        assert!(report.loaded.is_empty());
1621        assert!(report.skipped.is_empty());
1622        assert_eq!(report.missing.len(), 1, "model param should be reported as missing");
1623    }
1624
1625    #[test]
1626    fn test_shape_mismatch_transposed() {
1627        // Save [4, 8], try to load into [8, 4] (transposed, same numel)
1628        let params = vec![
1629            ("layer/weight".to_string(), Parameter::new(
1630                Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1631        ];
1632        let mut buf = Vec::new();
1633        save_checkpoint(&mut buf, &params, &[], None).unwrap();
1634
1635        let wrong_params = vec![
1636            ("layer/weight".to_string(), Parameter::new(
1637                Tensor::randn(&[8, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1638        ];
1639        let mut cursor = std::io::Cursor::new(&buf);
1640        let result = load_checkpoint(&mut cursor, &wrong_params, &[], None);
1641        assert!(result.is_err(), "transposed shape should be a mismatch error");
1642        let msg = format!("{}", result.unwrap_err());
1643        assert!(msg.contains("shape mismatch"), "error: {}", msg);
1644        assert!(msg.contains("[4, 8]"), "should show checkpoint shape: {}", msg);
1645        assert!(msg.contains("[8, 4]"), "should show model shape: {}", msg);
1646    }
1647
1648    #[test]
1649    fn test_dtype_mismatch_auto_cast() {
1650        // Save as f32, load into f64 parameter. The code does to_dtype() automatically.
1651        let f32_param = vec![
1652            ("layer/weight".to_string(), Parameter::new(
1653                Tensor::ones(&[2, 3], crate::tensor::test_opts()).unwrap(), "weight")),
1654        ];
1655        let mut buf = Vec::new();
1656        save_checkpoint(&mut buf, &f32_param, &[], None).unwrap();
1657
1658        // Create f64 parameter with same shape
1659        let f64_param = vec![
1660            ("layer/weight".to_string(), Parameter::new(
1661                Tensor::zeros(&[2, 3], TensorOptions {
1662                    dtype: DType::Float64,
1663                    device: crate::tensor::test_device(),
1664                }).unwrap(), "weight")),
1665        ];
1666        let mut cursor = std::io::Cursor::new(&buf);
1667        let report = load_checkpoint(&mut cursor, &f64_param, &[], None).unwrap();
1668        assert_eq!(report.loaded.len(), 1, "dtype auto-cast should succeed");
1669
1670        // Verify the loaded data is correct and in f64
1671        let loaded = f64_param[0].1.variable.data();
1672        assert_eq!(loaded.dtype(), DType::Float64);
1673        let vals = loaded.to_f64_vec().unwrap();
1674        for v in vals {
1675            assert!((v - 1.0).abs() < 1e-6, "expected ~1.0, got {}", v);
1676        }
1677    }
1678
1679    #[test]
1680    fn test_dtype_mismatch_buffer_auto_cast() {
1681        // Same auto-cast test for buffers
1682        let f32_buffers = vec![
1683            ("norm/running_mean".to_string(), Buffer::new(
1684                Tensor::ones(&[8], crate::tensor::test_opts()).unwrap(), "running_mean")),
1685        ];
1686        let mut buf = Vec::new();
1687        save_checkpoint(&mut buf, &[], &f32_buffers, None).unwrap();
1688
1689        let f64_buffers = vec![
1690            ("norm/running_mean".to_string(), Buffer::new(
1691                Tensor::zeros(&[8], TensorOptions {
1692                    dtype: DType::Float64,
1693                    device: crate::tensor::test_device(),
1694                }).unwrap(), "running_mean")),
1695        ];
1696        let mut cursor = std::io::Cursor::new(&buf);
1697        let report = load_checkpoint(&mut cursor, &[], &f64_buffers, None).unwrap();
1698        assert_eq!(report.loaded.len(), 1);
1699        assert_eq!(f64_buffers[0].1.get().dtype(), DType::Float64);
1700        let vals = f64_buffers[0].1.get().to_f64_vec().unwrap();
1701        for v in vals {
1702            assert!((v - 1.0).abs() < 1e-6);
1703        }
1704    }
1705
1706    #[test]
1707    fn test_compressed_roundtrip_with_hash() {
1708        // Test gz compression with structural hash validation
1709        let params = make_named_params(&[(8, 16)]);
1710        let hash = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
1711
1712        let dir = std::env::temp_dir();
1713        let gz_path = dir.join("test_ckpt_hash_gz.fdl.gz");
1714        let path_str = gz_path.to_str().unwrap();
1715
1716        save_checkpoint_file(path_str, &params, &[], Some(hash)).unwrap();
1717
1718        // Load with matching hash
1719        let load_params = make_named_params(&[(8, 16)]);
1720        let report = load_checkpoint_file(path_str, &load_params, &[], Some(hash)).unwrap();
1721        assert_eq!(report.loaded.len(), 1);
1722
1723        // Load with wrong hash should fail
1724        let bad_hash = "1111111111111111111111111111111111111111111111111111111111111111";
1725        let load_params2 = make_named_params(&[(8, 16)]);
1726        let result = load_checkpoint_file(path_str, &load_params2, &[], Some(bad_hash));
1727        assert!(result.is_err());
1728
1729        std::fs::remove_file(gz_path).ok();
1730    }
1731
1732    #[test]
1733    fn test_corrupted_gz_file() {
1734        // Write valid gz header then garbage: should produce an error
1735        let dir = std::env::temp_dir();
1736        let path = dir.join("test_corrupt.fdl.gz");
1737        // Write some garbage that is not valid gzip
1738        std::fs::write(&path, b"\x1f\x8b\x08\x00GARBAGE_NOT_VALID_GZ").unwrap();
1739
1740        let params = make_named_params(&[(4, 8)]);
1741        let result = load_checkpoint_file(path.to_str().unwrap(), &params, &[], None);
1742        assert!(result.is_err(), "corrupted gz should return Err");
1743
1744        std::fs::remove_file(path).ok();
1745    }
1746
1747    #[test]
1748    fn test_unknown_dtype_tag() {
1749        // Manually craft a checkpoint with an invalid dtype tag byte
1750        let mut buf = Vec::new();
1751        buf.extend_from_slice(&MAGIC);
1752        buf.extend_from_slice(&VERSION.to_le_bytes());
1753        buf.extend_from_slice(&[0u8; HASH_LEN]);
1754        buf.extend_from_slice(&1u32.to_le_bytes()); // 1 entry
1755
1756        // Entry name
1757        let name = b"layer/weight";
1758        buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
1759        buf.extend_from_slice(name);
1760
1761        // ndim = 1, shape = [4]
1762        buf.extend_from_slice(&1u32.to_le_bytes());
1763        buf.extend_from_slice(&4i64.to_le_bytes());
1764
1765        // Invalid dtype tag (255)
1766        buf.push(255);
1767
1768        // byte_count = 16 (4 * f32), then dummy data
1769        buf.extend_from_slice(&16u64.to_le_bytes());
1770        buf.extend_from_slice(&[0u8; 16]);
1771
1772        let params = vec![
1773            ("layer/weight".to_string(), Parameter::new(
1774                Tensor::zeros(&[4], crate::tensor::test_opts()).unwrap(), "weight")),
1775        ];
1776        let mut cursor = std::io::Cursor::new(&buf);
1777        let result = load_checkpoint(&mut cursor, &params, &[], None);
1778        assert!(result.is_err());
1779        let msg = format!("{}", result.unwrap_err());
1780        assert!(msg.contains("unknown dtype tag"), "error: {}", msg);
1781    }
1782
1783    #[test]
1784    fn test_checkpoint_keys_peeks_names_without_loading_data() {
1785        let params = vec![
1786            (
1787                "encoder/layer/weight".to_string(),
1788                Parameter::new(
1789                    Tensor::ones(&[4, 8], crate::tensor::test_opts()).unwrap(),
1790                    "weight",
1791                ),
1792            ),
1793            (
1794                "pooler/dense/weight".to_string(),
1795                Parameter::new(
1796                    Tensor::ones(&[8, 8], crate::tensor::test_opts()).unwrap(),
1797                    "weight",
1798                ),
1799            ),
1800        ];
1801        let buffers = vec![(
1802            "encoder/layer/running_mean".to_string(),
1803            Buffer::new(
1804                Tensor::zeros(&[8], crate::tensor::test_opts()).unwrap(),
1805                "running_mean",
1806            ),
1807        )];
1808
1809        let dir = std::env::temp_dir();
1810        let path = dir.join("test_checkpoint_keys_peek.fdl");
1811        let path_str = path.to_str().unwrap();
1812
1813        save_checkpoint_file(path_str, &params, &buffers, None).unwrap();
1814        let keys = checkpoint_keys(path_str).unwrap();
1815        assert_eq!(
1816            keys,
1817            vec![
1818                "encoder/layer/weight".to_string(),
1819                "pooler/dense/weight".to_string(),
1820                "encoder/layer/running_mean".to_string(),
1821            ],
1822            "params first then buffers, in declaration order",
1823        );
1824
1825        std::fs::remove_file(path_str).ok();
1826    }
1827
1828    #[test]
1829    fn test_checkpoint_keys_handles_gzip() {
1830        let params = vec![(
1831            "x/w".to_string(),
1832            Parameter::new(
1833                Tensor::ones(&[2, 2], crate::tensor::test_opts()).unwrap(),
1834                "w",
1835            ),
1836        )];
1837        let dir = std::env::temp_dir();
1838        let path = dir.join("test_checkpoint_keys_gz.fdl.gz");
1839        let path_str = path.to_str().unwrap();
1840
1841        save_checkpoint_file(path_str, &params, &[], None).unwrap();
1842        let keys = checkpoint_keys(path_str).unwrap();
1843        assert_eq!(keys, vec!["x/w".to_string()]);
1844
1845        std::fs::remove_file(path_str).ok();
1846    }
1847
1848    #[test]
1849    fn test_checkpoint_keys_rejects_bad_magic() {
1850        let dir = std::env::temp_dir();
1851        let path = dir.join("test_checkpoint_keys_bad.fdl");
1852        std::fs::write(&path, b"NOPEnotacheckpoint").unwrap();
1853        let err = checkpoint_keys(path.to_str().unwrap()).unwrap_err();
1854        assert!(format!("{err}").contains("bad magic"), "got: {err}");
1855        std::fs::remove_file(path).ok();
1856    }
1857}