1use std::io::{Read, Write};
2
3use crate::tensor::{Device, DType, Result, Tensor, TensorError};
4
5use super::buffer::Buffer;
6use super::parameter::Parameter;
7
/// Four-byte magic that opens every checkpoint stream.
pub(crate) const MAGIC: [u8; 4] = *b"FDLC";
/// Format version stamped into checkpoints written by this module.
pub(crate) const VERSION: u32 = 2;
/// Highest format version the readers in this module accept.
const MAX_VERSION: u32 = 2;
/// Size in bytes of the structural-hash field that follows the version.
pub(crate) const HASH_LEN: usize = 32;
17
/// Outcome of [`load_checkpoint`], listing every name it encountered.
#[derive(Debug, Clone)]
pub struct LoadReport {
    /// Model parameters/buffers that were found in the checkpoint and overwritten.
    pub loaded: Vec<String>,
    /// Checkpoint entries that no model parameter or buffer claimed.
    pub skipped: Vec<String>,
    /// Model parameters/buffers that have no entry in the checkpoint.
    pub missing: Vec<String>,
}
28
29pub fn save_checkpoint<W: Write>(
37 w: &mut W,
38 params: &[(String, Parameter)],
39 buffers: &[(String, Buffer)],
40 structural_hash: Option<&str>,
41) -> Result<()> {
42 w.write_all(&MAGIC).map_err(io_err)?;
43 w.write_all(&VERSION.to_le_bytes()).map_err(io_err)?;
44
45 let hash_bytes = match structural_hash {
47 Some(hex) => hex_to_bytes(hex)?,
48 None => [0u8; HASH_LEN],
49 };
50 w.write_all(&hash_bytes).map_err(io_err)?;
51
52 let total = (params.len() + buffers.len()) as u32;
53 w.write_all(&total.to_le_bytes()).map_err(io_err)?;
54
55 for (name, p) in params {
56 let name_bytes = name.as_bytes();
57 w.write_all(&(name_bytes.len() as u32).to_le_bytes()).map_err(io_err)?;
58 w.write_all(name_bytes).map_err(io_err)?;
59 write_tensor_data(w, &p.variable.data())?;
60 }
61
62 for (name, b) in buffers {
63 let name_bytes = name.as_bytes();
64 w.write_all(&(name_bytes.len() as u32).to_le_bytes()).map_err(io_err)?;
65 w.write_all(name_bytes).map_err(io_err)?;
66 write_tensor_data(w, &b.get())?;
67 }
68
69 Ok(())
70}
71
72pub fn load_checkpoint<R: Read>(
82 r: &mut R,
83 params: &[(String, Parameter)],
84 buffers: &[(String, Buffer)],
85 structural_hash: Option<&str>,
86) -> Result<LoadReport> {
87 let mut magic = [0u8; 4];
88 r.read_exact(&mut magic).map_err(io_err)?;
89 if magic != MAGIC {
90 return Err(TensorError::new(
91 "invalid checkpoint: bad magic (expected .fdl checkpoint)"
92 ));
93 }
94
95 let version = read_u32(r)?;
96 if version == 0 || version > MAX_VERSION {
97 return Err(TensorError::new(&format!(
98 "unsupported checkpoint version {} (this build supports 1..={})",
99 version, MAX_VERSION,
100 )));
101 }
102
103 let mut file_hash = [0u8; HASH_LEN];
105 r.read_exact(&mut file_hash).map_err(io_err)?;
106
107 let file_nonzero = file_hash.iter().any(|&b| b != 0);
108 if let Some(expected_hex) = structural_hash {
109 let expected = hex_to_bytes(expected_hex)?;
110 let expected_nonzero = expected.iter().any(|&b| b != 0);
111 if file_nonzero && expected_nonzero && file_hash != expected {
112 return Err(TensorError::new(&format!(
113 "checkpoint architecture mismatch: file={} model={}",
114 bytes_to_hex(&file_hash),
115 expected_hex,
116 )));
117 }
118 }
119
120 let count = read_u32(r)? as usize;
121
122 let mut ckpt: std::collections::HashMap<String, (Vec<i64>, DType, Vec<u8>)> =
124 std::collections::HashMap::with_capacity(count);
125
126 for _ in 0..count {
127 let name_len = read_u32(r)? as usize;
128 let mut name_bytes = vec![0u8; name_len];
129 r.read_exact(&mut name_bytes).map_err(io_err)?;
130 let name = String::from_utf8_lossy(&name_bytes).into_owned();
131
132 let ndim = read_u32(r)? as usize;
133 let mut shape = vec![0i64; ndim];
134 for s in &mut shape { *s = read_i64(r)?; }
135 let mut tag = [0u8; 1];
136 r.read_exact(&mut tag).map_err(io_err)?;
137 let dtype = dtype_from_tag(tag[0])?;
138 let byte_count = read_u64(r)? as usize;
139 let mut raw = vec![0u8; byte_count];
140 r.read_exact(&mut raw).map_err(io_err)?;
141 ckpt.insert(name, (shape, dtype, raw));
142 }
143
144 let mut loaded = Vec::new();
145 let mut missing = Vec::new();
146
147 for (name, p) in params {
149 if let Some((shape, dtype, raw)) = ckpt.remove(name) {
150 let model_shape = p.variable.shape();
151 if shape != model_shape {
152 return Err(TensorError::new(&format!(
153 "parameter {:?}: shape mismatch: checkpoint={:?} model={:?}",
154 name, shape, model_shape
155 )));
156 }
157 let t = tensor_from_raw_bytes(&raw, &shape, dtype)?;
158 let model_dtype = p.variable.data().dtype();
159 let t = if t.dtype() != model_dtype { t.to_dtype(model_dtype)? } else { t };
160 let dev = p.variable.data().device();
161 if dev != Device::CPU {
162 p.variable.set_data(t.to_device(dev)?);
163 } else {
164 p.variable.set_data(t);
165 }
166 loaded.push(name.clone());
167 } else {
168 missing.push(name.clone());
169 }
170 }
171
172 for (name, b) in buffers {
174 if let Some((shape, dtype, raw)) = ckpt.remove(name) {
175 let model_shape = b.shape();
176 if shape != model_shape {
177 return Err(TensorError::new(&format!(
178 "buffer {:?}: shape mismatch: checkpoint={:?} model={:?}",
179 name, shape, model_shape
180 )));
181 }
182 let t = tensor_from_raw_bytes(&raw, &shape, dtype)?;
183 let model_dtype = b.get().dtype();
184 let t = if t.dtype() != model_dtype { t.to_dtype(model_dtype)? } else { t };
185 let dev = b.device();
186 if dev != Device::CPU {
187 b.set(t.to_device(dev)?);
188 } else {
189 b.set(t);
190 }
191 loaded.push(name.clone());
192 } else {
193 missing.push(name.clone());
194 }
195 }
196
197 let skipped: Vec<String> = ckpt.into_keys().collect();
198
199 Ok(LoadReport { loaded, skipped, missing })
200}
201
202pub fn save_checkpoint_file(
204 path: &str,
205 params: &[(String, Parameter)],
206 buffers: &[(String, Buffer)],
207 structural_hash: Option<&str>,
208) -> Result<()> {
209 let f = std::fs::File::create(path).map_err(io_err)?;
210 if path.ends_with(".gz") {
211 let mut w = flate2::write::GzEncoder::new(f, flate2::Compression::default());
212 save_checkpoint(&mut w, params, buffers, structural_hash)?;
213 w.finish().map_err(io_err)?;
214 Ok(())
215 } else {
216 let mut w = std::io::BufWriter::new(f);
217 save_checkpoint(&mut w, params, buffers, structural_hash)
218 }
219}
220
221pub fn load_checkpoint_file(
223 path: &str,
224 params: &[(String, Parameter)],
225 buffers: &[(String, Buffer)],
226 structural_hash: Option<&str>,
227) -> Result<LoadReport> {
228 let f = std::fs::File::open(path).map_err(io_err)?;
229 if path.ends_with(".gz") {
230 let mut r = flate2::read::GzDecoder::new(f);
231 load_checkpoint(&mut r, params, buffers, structural_hash)
232 } else {
233 let mut r = std::io::BufReader::new(f);
234 load_checkpoint(&mut r, params, buffers, structural_hash)
235 }
236}
237
238pub fn checkpoint_keys(path: &str) -> Result<Vec<String>> {
253 let f = std::fs::File::open(path).map_err(io_err)?;
254 let mut r: Box<dyn Read> = if path.ends_with(".gz") {
255 Box::new(flate2::read::GzDecoder::new(f))
256 } else {
257 Box::new(std::io::BufReader::new(f))
258 };
259
260 let mut magic = [0u8; 4];
261 r.read_exact(&mut magic).map_err(io_err)?;
262 if magic != MAGIC {
263 return Err(TensorError::new(
264 "invalid checkpoint: bad magic (expected .fdl checkpoint)",
265 ));
266 }
267 let version = read_u32(&mut r)?;
268 if version == 0 || version > MAX_VERSION {
269 return Err(TensorError::new(&format!(
270 "unsupported checkpoint version {} (this build supports 1..={})",
271 version, MAX_VERSION,
272 )));
273 }
274 let mut _hash = [0u8; HASH_LEN];
276 r.read_exact(&mut _hash).map_err(io_err)?;
277
278 let count = read_u32(&mut r)? as usize;
279 let mut keys = Vec::with_capacity(count);
280 for _ in 0..count {
281 let name_len = read_u32(&mut r)? as usize;
282 let mut name_bytes = vec![0u8; name_len];
283 r.read_exact(&mut name_bytes).map_err(io_err)?;
284 keys.push(String::from_utf8_lossy(&name_bytes).into_owned());
285 let ndim = read_u32(&mut r)? as usize;
287 for _ in 0..ndim {
288 let _ = read_i64(&mut r)?;
289 }
290 let mut tag = [0u8; 1];
291 r.read_exact(&mut tag).map_err(io_err)?;
292 let byte_count = read_u64(&mut r)? as usize;
293 std::io::copy(&mut r.by_ref().take(byte_count as u64), &mut std::io::sink())
295 .map_err(io_err)?;
296 }
297 Ok(keys)
298}
299
300pub fn checkpoint_version(path: &str) -> Result<u32> {
303 let f = std::fs::File::open(path).map_err(io_err)?;
304 let mut r: Box<dyn Read> = if path.ends_with(".gz") {
305 Box::new(flate2::read::GzDecoder::new(f))
306 } else {
307 Box::new(std::io::BufReader::new(f))
308 };
309 let mut magic = [0u8; 4];
310 r.read_exact(&mut magic).map_err(io_err)?;
311 if magic != MAGIC {
312 return Err(TensorError::new(
313 "invalid checkpoint: bad magic (expected .fdl checkpoint)"
314 ));
315 }
316 read_u32(&mut r)
317}
318
319pub(crate) fn write_tensor_state<W: Write>(w: &mut W, t: Option<&Tensor>) -> Result<()> {
324 match t {
325 None => {
326 w.write_all(&[0u8]).map_err(io_err)?;
327 }
328 Some(t) => {
329 w.write_all(&[1u8]).map_err(io_err)?;
330 write_tensor_data(w, t)?;
331 }
332 }
333 Ok(())
334}
335
336pub(crate) fn read_tensor_state<R: Read>(r: &mut R, device: Device) -> Result<Option<Tensor>> {
338 let mut present = [0u8; 1];
339 r.read_exact(&mut present).map_err(io_err)?;
340 if present[0] == 0 {
341 return Ok(None);
342 }
343
344 let t = read_tensor_data(r)?;
345 if device != Device::CPU {
346 Ok(Some(t.to_device(device)?))
347 } else {
348 Ok(Some(t))
349 }
350}
351
352fn dtype_tag(dtype: DType) -> u8 {
356 match dtype {
357 DType::Float16 => 1,
358 DType::BFloat16 => 2,
359 DType::Float32 => 3,
360 DType::Float64 => 4,
361 DType::Int32 => 5,
362 DType::Int64 => 6,
363 }
364}
365
366fn dtype_from_tag(tag: u8) -> Result<DType> {
367 match tag {
368 1 => Ok(DType::Float16),
369 2 => Ok(DType::BFloat16),
370 3 => Ok(DType::Float32),
371 4 => Ok(DType::Float64),
372 5 => Ok(DType::Int32),
373 6 => Ok(DType::Int64),
374 _ => Err(TensorError::new(&format!("unknown dtype tag: {}", tag))),
375 }
376}
377
378pub(crate) fn write_tensor_data<W: Write>(w: &mut W, t: &Tensor) -> Result<()> {
380 let shape = t.shape();
381 w.write_all(&(shape.len() as u32).to_le_bytes()).map_err(io_err)?;
382 for &s in &shape {
383 w.write_all(&s.to_le_bytes()).map_err(io_err)?;
384 }
385
386 let dtype = t.dtype();
387 w.write_all(&[dtype_tag(dtype)]).map_err(io_err)?;
388
389 let numel = t.numel() as usize;
390 let elem_size = dtype.element_size();
391 let byte_count = numel * elem_size;
392
393 let raw = copy_raw_bytes(t, byte_count)?;
395 w.write_all(&(byte_count as u64).to_le_bytes()).map_err(io_err)?;
396 w.write_all(&raw).map_err(io_err)?;
397
398 Ok(())
399}
400
401fn read_tensor_data<R: Read>(r: &mut R) -> Result<Tensor> {
403 let ndim = read_u32(r)? as usize;
404 let mut shape = vec![0i64; ndim];
405 for s in &mut shape {
406 *s = read_i64(r)?;
407 }
408
409 let mut tag = [0u8; 1];
410 r.read_exact(&mut tag).map_err(io_err)?;
411 let dtype = dtype_from_tag(tag[0])?;
412
413 let byte_count = read_u64(r)? as usize;
414 let mut raw = vec![0u8; byte_count];
415 r.read_exact(&mut raw).map_err(io_err)?;
416
417 tensor_from_raw_bytes(&raw, &shape, dtype)
418}
419
420fn copy_raw_bytes(t: &Tensor, byte_count: usize) -> Result<Vec<u8>> {
422 let mut buf = vec![0u8; byte_count];
423 let err = unsafe {
424 flodl_sys::flodl_copy_data(
425 t.raw(),
426 buf.as_mut_ptr() as *mut std::ffi::c_void,
427 byte_count as i64,
428 )
429 };
430 check_err_raw(err)?;
431 Ok(buf)
432}
433
434fn tensor_from_raw_bytes(raw: &[u8], shape: &[i64], dtype: DType) -> Result<Tensor> {
436 match dtype {
438 DType::Float32 => {
439 let data: Vec<f32> = raw.chunks_exact(4)
440 .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
441 .collect();
442 Tensor::from_f32(&data, shape, Device::CPU)
443 }
444 DType::Float64 => {
445 let data: Vec<f64> = raw.chunks_exact(8)
446 .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
447 .collect();
448 Tensor::from_f64(&data, shape, Device::CPU)
449 }
450 DType::Int64 => {
451 let data: Vec<i64> = raw.chunks_exact(8)
452 .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
453 .collect();
454 Tensor::from_i64(&data, shape, Device::CPU)
455 }
456 DType::Float16 | DType::BFloat16 | DType::Int32 => {
457 let mut shape_v = shape.to_vec();
459 let mut handle: flodl_sys::FlodlTensor = std::ptr::null_mut();
460 let (dev_type, dev_idx) = crate::tensor::Device::CPU.to_ffi();
461 let err = unsafe {
462 flodl_sys::flodl_from_blob(
463 raw.as_ptr() as *mut std::ffi::c_void,
464 shape_v.as_mut_ptr(),
465 shape_v.len() as i32,
466 dtype as i32,
467 dev_type, dev_idx,
468 &mut handle,
469 )
470 };
471 check_err_raw(err)?;
472 debug_assert!(!handle.is_null());
473 Ok(unsafe { Tensor::from_raw_handle(handle) })
475 }
476 }
477}
478
/// Outcome of a checkpoint migration: how each model name and each
/// checkpoint entry was handled.
#[derive(Debug, Clone)]
pub struct MigrateReport {
    /// Names found in the checkpoint under the same name with the same shape.
    pub unchanged: Vec<String>,
    /// `(old checkpoint name, new model name)` pairs matched by shape and dtype.
    pub remapped: Vec<(String, String)>,
    /// Checkpoint entries that were not used for any model name.
    pub dropped: Vec<String>,
    /// Model names for which no checkpoint entry could be found.
    pub missing: Vec<String>,
}

impl MigrateReport {
    /// `true` when nothing was dropped and nothing is missing.
    pub fn is_complete(&self) -> bool {
        self.dropped.iter().chain(&self.missing).next().is_none()
    }
}

impl std::fmt::Display for MigrateReport {
    /// Prints each non-empty category as a `title (count):` header followed
    /// by one line per entry.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Emits one category of plain names; prints nothing when empty.
        fn section(
            f: &mut std::fmt::Formatter<'_>,
            title: &str,
            names: &[String],
        ) -> std::fmt::Result {
            if names.is_empty() {
                return Ok(());
            }
            writeln!(f, "{} ({}):", title, names.len())?;
            names.iter().try_for_each(|name| writeln!(f, " {}", name))
        }

        section(f, "unchanged", &self.unchanged)?;
        if !self.remapped.is_empty() {
            writeln!(f, "remapped ({}):", self.remapped.len())?;
            for (from, to) in &self.remapped {
                writeln!(f, " {} -> {}", from, to)?;
            }
        }
        section(f, "dropped", &self.dropped)?;
        section(f, "missing", &self.missing)
    }
}
522
/// One checkpoint entry kept in its on-disk form (payload not decoded).
struct RawEntry {
    /// Entry name as stored in the checkpoint (lossily decoded as UTF-8).
    name: String,
    /// Tensor dimensions as stored.
    shape: Vec<i64>,
    /// Element type decoded from the dtype tag.
    dtype: DType,
    /// Raw payload bytes, exactly as read from the file.
    raw: Vec<u8>,
}
530
531fn read_raw_checkpoint<R: Read>(r: &mut R) -> Result<Vec<RawEntry>> {
533 let mut magic = [0u8; 4];
534 r.read_exact(&mut magic).map_err(io_err)?;
535 if magic != MAGIC {
536 return Err(TensorError::new(
537 "invalid checkpoint: bad magic (expected .fdl checkpoint)"
538 ));
539 }
540 let version = read_u32(r)?;
541 if version == 0 || version > MAX_VERSION {
542 return Err(TensorError::new(&format!(
543 "unsupported checkpoint version {} (this build supports 1..={})",
544 version, MAX_VERSION,
545 )));
546 }
547 let mut _hash = [0u8; HASH_LEN];
549 r.read_exact(&mut _hash).map_err(io_err)?;
550
551 let count = read_u32(r)? as usize;
552 let mut entries = Vec::with_capacity(count);
553
554 for _ in 0..count {
555 let name_len = read_u32(r)? as usize;
556 let mut name_bytes = vec![0u8; name_len];
557 r.read_exact(&mut name_bytes).map_err(io_err)?;
558 let name = String::from_utf8_lossy(&name_bytes).into_owned();
559
560 let ndim = read_u32(r)? as usize;
561 let mut shape = vec![0i64; ndim];
562 for s in &mut shape { *s = read_i64(r)?; }
563 let mut tag = [0u8; 1];
564 r.read_exact(&mut tag).map_err(io_err)?;
565 let dtype = dtype_from_tag(tag[0])?;
566 let byte_count = read_u64(r)? as usize;
567 let mut raw = vec![0u8; byte_count];
568 r.read_exact(&mut raw).map_err(io_err)?;
569
570 entries.push(RawEntry { name, shape, dtype, raw });
571 }
572
573 Ok(entries)
574}
575
576fn write_raw_entry<W: Write>(w: &mut W, name: &str, e: &RawEntry) -> Result<()> {
578 let name_bytes = name.as_bytes();
579 w.write_all(&(name_bytes.len() as u32).to_le_bytes()).map_err(io_err)?;
580 w.write_all(name_bytes).map_err(io_err)?;
581 w.write_all(&(e.shape.len() as u32).to_le_bytes()).map_err(io_err)?;
582 for &s in &e.shape {
583 w.write_all(&s.to_le_bytes()).map_err(io_err)?;
584 }
585 w.write_all(&[dtype_tag(e.dtype)]).map_err(io_err)?;
586 w.write_all(&(e.raw.len() as u64).to_le_bytes()).map_err(io_err)?;
587 w.write_all(&e.raw).map_err(io_err)?;
588 Ok(())
589}
590
591pub fn migrate_checkpoint<R: Read, W: Write>(
621 r: &mut R,
622 w: &mut W,
623 params: &[(String, Parameter)],
624 buffers: &[(String, Buffer)],
625) -> Result<MigrateReport> {
626 let entries = read_raw_checkpoint(r)?;
627
628 let mut targets: Vec<(String, Vec<i64>, DType)> = Vec::with_capacity(
630 params.len() + buffers.len()
631 );
632 for (name, p) in params {
633 targets.push((name.clone(), p.variable.shape(), p.variable.data().dtype()));
634 }
635 for (name, b) in buffers {
636 targets.push((name.clone(), b.shape(), b.get().dtype()));
637 }
638
639 let mut unchanged = Vec::new();
640 let mut remapped = Vec::new();
641 let mut missing = Vec::new();
642 let mut used = vec![false; entries.len()];
643
644 let mut output: Vec<(String, usize)> = Vec::new();
646
647 let name_index: std::collections::HashMap<&str, usize> =
649 entries.iter().enumerate().map(|(i, e)| (e.name.as_str(), i)).collect();
650
651 let mut unmatched: Vec<usize> = Vec::new();
653
654 for (mi, (name, shape, _)) in targets.iter().enumerate() {
656 if let Some(&ci) = name_index.get(name.as_str()) {
657 if !used[ci] && entries[ci].shape == *shape {
658 unchanged.push(name.clone());
659 used[ci] = true;
660 output.push((name.clone(), ci));
661 continue;
662 }
663 }
664 unmatched.push(mi);
665 }
666
667 for &mi in &unmatched {
669 let (name, shape, dtype) = &targets[mi];
670
671 let found = entries.iter().enumerate()
672 .find(|(ci, e)| !used[*ci] && e.shape == *shape && e.dtype == *dtype)
673 .map(|(ci, _)| ci);
674
675 if let Some(ci) = found {
676 remapped.push((entries[ci].name.clone(), name.clone()));
677 used[ci] = true;
678 output.push((name.clone(), ci));
679 } else {
680 missing.push(name.clone());
681 }
682 }
683
684 let dropped: Vec<String> = entries.iter().enumerate()
685 .filter(|(i, _)| !used[*i])
686 .map(|(_, e)| e.name.clone())
687 .collect();
688
689 w.write_all(&MAGIC).map_err(io_err)?;
691 w.write_all(&VERSION.to_le_bytes()).map_err(io_err)?;
692 w.write_all(&[0u8; HASH_LEN]).map_err(io_err)?;
693 w.write_all(&(output.len() as u32).to_le_bytes()).map_err(io_err)?;
694
695 for (name, ci) in &output {
696 write_raw_entry(w, name, &entries[*ci])?;
697 }
698
699 Ok(MigrateReport { unchanged, remapped, dropped, missing })
700}
701
702pub fn migrate_checkpoint_file(
706 src: &str,
707 dst: &str,
708 params: &[(String, Parameter)],
709 buffers: &[(String, Buffer)],
710) -> Result<MigrateReport> {
711 let sf = std::fs::File::open(src).map_err(io_err)?;
712 let df = std::fs::File::create(dst).map_err(io_err)?;
713
714 match (src.ends_with(".gz"), dst.ends_with(".gz")) {
715 (true, true) => {
716 let mut r = flate2::read::GzDecoder::new(sf);
717 let mut w = flate2::write::GzEncoder::new(df, flate2::Compression::default());
718 let report = migrate_checkpoint(&mut r, &mut w, params, buffers)?;
719 w.finish().map_err(io_err)?;
720 Ok(report)
721 }
722 (true, false) => {
723 let mut r = flate2::read::GzDecoder::new(sf);
724 let mut w = std::io::BufWriter::new(df);
725 migrate_checkpoint(&mut r, &mut w, params, buffers)
726 }
727 (false, true) => {
728 let mut r = std::io::BufReader::new(sf);
729 let mut w = flate2::write::GzEncoder::new(df, flate2::Compression::default());
730 let report = migrate_checkpoint(&mut r, &mut w, params, buffers)?;
731 w.finish().map_err(io_err)?;
732 Ok(report)
733 }
734 (false, false) => {
735 let mut r = std::io::BufReader::new(sf);
736 let mut w = std::io::BufWriter::new(df);
737 migrate_checkpoint(&mut r, &mut w, params, buffers)
738 }
739 }
740}
741
742pub(crate) fn io_err(e: impl std::fmt::Display) -> TensorError {
745 TensorError::new(&format!("io: {}", e))
746}
747
748fn check_err_raw(err: *mut i8) -> Result<()> {
749 if err.is_null() {
750 Ok(())
751 } else {
752 let msg = unsafe { std::ffi::CStr::from_ptr(err) }
753 .to_string_lossy()
754 .into_owned();
755 unsafe { flodl_sys::flodl_free_string(err) };
756 Err(TensorError::new(&msg))
757 }
758}
759
760fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
761 let mut buf = [0u8; 4];
762 r.read_exact(&mut buf).map_err(io_err)?;
763 Ok(u32::from_le_bytes(buf))
764}
765
766fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
767 let mut buf = [0u8; 8];
768 r.read_exact(&mut buf).map_err(io_err)?;
769 Ok(u64::from_le_bytes(buf))
770}
771
772fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
773 let mut buf = [0u8; 8];
774 r.read_exact(&mut buf).map_err(io_err)?;
775 Ok(i64::from_le_bytes(buf))
776}
777
778pub(crate) fn read_f64_le<R: Read>(r: &mut R) -> Result<f64> {
780 let mut buf = [0u8; 8];
781 r.read_exact(&mut buf).map_err(io_err)?;
782 Ok(f64::from_le_bytes(buf))
783}
784pub(crate) fn write_f64_le<W: Write>(w: &mut W, v: f64) -> Result<()> {
785 w.write_all(&v.to_le_bytes()).map_err(io_err)?;
786 Ok(())
787}
788pub(crate) fn write_u32_le<W: Write>(w: &mut W, v: u32) -> Result<()> {
789 w.write_all(&v.to_le_bytes()).map_err(io_err)?;
790 Ok(())
791}
792pub(crate) fn write_i64_le<W: Write>(w: &mut W, v: i64) -> Result<()> {
793 w.write_all(&v.to_le_bytes()).map_err(io_err)?;
794 Ok(())
795}
796pub(crate) fn read_u32_le<R: Read>(r: &mut R) -> Result<u32> {
797 read_u32(r)
798}
799pub(crate) fn read_i64_le<R: Read>(r: &mut R) -> Result<i64> {
800 read_i64(r)
801}
802
803fn hex_to_bytes(hex: &str) -> Result<[u8; HASH_LEN]> {
805 if hex.len() != HASH_LEN * 2 {
806 return Err(TensorError::new(&format!(
807 "expected {} hex chars, got {}",
808 HASH_LEN * 2,
809 hex.len()
810 )));
811 }
812 let mut out = [0u8; HASH_LEN];
813 for (i, chunk) in hex.as_bytes().chunks(2).enumerate() {
814 let hi = hex_nibble(chunk[0])?;
815 let lo = hex_nibble(chunk[1])?;
816 out[i] = (hi << 4) | lo;
817 }
818 Ok(out)
819}
820
821fn hex_nibble(b: u8) -> Result<u8> {
822 match b {
823 b'0'..=b'9' => Ok(b - b'0'),
824 b'a'..=b'f' => Ok(b - b'a' + 10),
825 b'A'..=b'F' => Ok(b - b'A' + 10),
826 _ => Err(TensorError::new(&format!("invalid hex byte: {}", b))),
827 }
828}
829
/// Renders `bytes` as lower-case hex, two characters per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut text = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        text.push(char::from(DIGITS[usize::from(byte >> 4)]));
        text.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    text
}
839
840#[cfg(test)]
841mod tests {
842 use super::*;
843 use crate::tensor::TensorOptions;
844
845 fn make_named_params(sizes: &[(i64, i64)]) -> Vec<(String, Parameter)> {
846 sizes.iter().enumerate().map(|(i, &(rows, cols))| {
847 let t = Tensor::randn(&[rows, cols], TensorOptions {
848 dtype: DType::Float32,
849 device: crate::tensor::test_device(),
850 }).unwrap();
851 let name = format!("layer_{}/weight", i);
852 (name.clone(), Parameter::new(t, "weight"))
853 }).collect()
854 }
855
856 fn make_named_buffers(sizes: &[i64]) -> Vec<(String, Buffer)> {
857 sizes.iter().enumerate().map(|(i, &features)| {
858 let t = Tensor::randn(&[features], TensorOptions {
859 dtype: DType::Float32,
860 device: crate::tensor::test_device(),
861 }).unwrap();
862 let name = format!("bn_{}/running_mean", i);
863 (name.clone(), Buffer::new(t, "running_mean"))
864 }).collect()
865 }
866
867 #[test]
868 fn test_named_roundtrip() {
869 let params = make_named_params(&[(4, 8), (8, 2)]);
870
871 let mut buf = Vec::new();
872 save_checkpoint(&mut buf, ¶ms, &[], None).unwrap();
873
874 let load_params = make_named_params(&[(4, 8), (8, 2)]);
875 let mut cursor = std::io::Cursor::new(&buf);
876 let report = load_checkpoint(&mut cursor, &load_params, &[], None).unwrap();
877
878 assert_eq!(report.loaded.len(), 2);
879 assert!(report.skipped.is_empty());
880 assert!(report.missing.is_empty());
881
882 for ((_, src), (_, dst)) in params.iter().zip(load_params.iter()) {
883 let src_data = src.variable.data().to_f32_vec().unwrap();
884 let dst_data = dst.variable.data().to_f32_vec().unwrap();
885 assert_eq!(src_data, dst_data);
886 }
887 }
888
889 #[test]
890 fn test_buffer_roundtrip() {
891 let params = make_named_params(&[(4, 8)]);
892 let buffers = make_named_buffers(&[8]);
893
894 let mut buf = Vec::new();
895 save_checkpoint(&mut buf, ¶ms, &buffers, None).unwrap();
896
897 let load_params = make_named_params(&[(4, 8)]);
899 let load_buffers = make_named_buffers(&[8]);
900 let mut cursor = std::io::Cursor::new(&buf);
901 let report = load_checkpoint(&mut cursor, &load_params, &load_buffers, None).unwrap();
902
903 assert_eq!(report.loaded.len(), 2); assert!(report.skipped.is_empty());
905 assert!(report.missing.is_empty());
906
907 let src_data = buffers[0].1.get().to_f32_vec().unwrap();
909 let dst_data = load_buffers[0].1.get().to_f32_vec().unwrap();
910 assert_eq!(src_data, dst_data);
911 }
912
913 #[test]
914 fn test_named_partial_load() {
915 let params_3 = make_named_params(&[(4, 8), (8, 4), (4, 2)]);
916
917 let mut buf = Vec::new();
918 save_checkpoint(&mut buf, ¶ms_3, &[], None).unwrap();
919
920 let mut params_4 = make_named_params(&[(4, 8), (8, 4), (4, 2), (2, 1)]);
921 params_4[3].0 = "extra/weight".to_string();
922
923 let before_extra = params_4[3].1.variable.data().to_f32_vec().unwrap();
924
925 let mut cursor = std::io::Cursor::new(&buf);
926 let report = load_checkpoint(&mut cursor, ¶ms_4, &[], None).unwrap();
927
928 assert_eq!(report.loaded.len(), 3);
929 assert_eq!(report.missing.len(), 1);
930 assert_eq!(report.missing[0], "extra/weight");
931 assert!(report.skipped.is_empty());
932
933 let after_extra = params_4[3].1.variable.data().to_f32_vec().unwrap();
934 assert_eq!(before_extra, after_extra);
935 }
936
937 #[test]
938 fn test_named_skipped_checkpoint_params() {
939 let params = make_named_params(&[(4, 8), (8, 2)]);
940
941 let mut buf = Vec::new();
942 save_checkpoint(&mut buf, ¶ms, &[], None).unwrap();
943
944 let model = vec![params[0].clone()];
945 let mut cursor = std::io::Cursor::new(&buf);
946 let report = load_checkpoint(&mut cursor, &model, &[], None).unwrap();
947
948 assert_eq!(report.loaded.len(), 1);
949 assert_eq!(report.skipped.len(), 1);
950 assert!(report.missing.is_empty());
951 }
952
953 #[test]
954 fn test_named_shape_mismatch_error() {
955 let params = make_named_params(&[(4, 8)]);
956
957 let mut buf = Vec::new();
958 save_checkpoint(&mut buf, ¶ms, &[], None).unwrap();
959
960 let wrong_shape = vec![(
961 "layer_0/weight".to_string(),
962 Parameter::new(
963 Tensor::randn(&[4, 4], TensorOptions {
964 dtype: DType::Float32,
965 device: crate::tensor::test_device(),
966 }).unwrap(),
967 "weight",
968 ),
969 )];
970 let mut cursor = std::io::Cursor::new(&buf);
971 let result = load_checkpoint(&mut cursor, &wrong_shape, &[], None);
972 assert!(result.is_err(), "shape mismatch should be an error");
973 let err_msg = format!("{}", result.unwrap_err());
974 assert!(err_msg.contains("shape mismatch"), "error should mention shape: {}", err_msg);
975 }
976
977 #[test]
978 fn test_buffer_shape_mismatch_error() {
979 let buffers = make_named_buffers(&[8]);
980
981 let mut buf = Vec::new();
982 save_checkpoint(&mut buf, &[], &buffers, None).unwrap();
983
984 let wrong_buffers = vec![(
985 "bn_0/running_mean".to_string(),
986 Buffer::new(
987 Tensor::zeros(&[4], crate::tensor::test_opts()).unwrap(),
988 "running_mean",
989 ),
990 )];
991 let mut cursor = std::io::Cursor::new(&buf);
992 let result = load_checkpoint(&mut cursor, &[], &wrong_buffers, None);
993 assert!(result.is_err());
994 assert!(format!("{}", result.unwrap_err()).contains("shape mismatch"));
995 }
996
997 #[test]
998 fn test_compressed_roundtrip() {
999 let params = make_named_params(&[(16, 32), (32, 8)]);
1000 let buffers = make_named_buffers(&[32]);
1001
1002 let dir = std::env::temp_dir();
1003 let gz_path = dir.join("test_ckpt_v2.fdl.gz");
1004 let plain_path = dir.join("test_ckpt_v2.fdl");
1005 let gz = gz_path.to_str().unwrap();
1006 let plain = plain_path.to_str().unwrap();
1007
1008 save_checkpoint_file(gz, ¶ms, &buffers, None).unwrap();
1009 save_checkpoint_file(plain, ¶ms, &buffers, None).unwrap();
1010
1011 let gz_size = std::fs::metadata(gz).unwrap().len();
1013 let plain_size = std::fs::metadata(plain).unwrap().len();
1014 assert!(gz_size < plain_size, "gz={} should be < plain={}", gz_size, plain_size);
1015
1016 let load_params = make_named_params(&[(16, 32), (32, 8)]);
1018 let load_buffers = make_named_buffers(&[32]);
1019 let report = load_checkpoint_file(gz, &load_params, &load_buffers, None).unwrap();
1020 assert_eq!(report.loaded.len(), 3); for ((_, src), (_, dst)) in params.iter().zip(load_params.iter()) {
1023 assert_eq!(src.variable.data().to_f32_vec().unwrap(),
1024 dst.variable.data().to_f32_vec().unwrap());
1025 }
1026
1027 let src_buf = buffers[0].1.get().to_f32_vec().unwrap();
1028 let dst_buf = load_buffers[0].1.get().to_f32_vec().unwrap();
1029 assert_eq!(src_buf, dst_buf);
1030
1031 std::fs::remove_file(gz).ok();
1032 std::fs::remove_file(plain).ok();
1033 }
1034
1035 #[test]
1036 fn test_hash_roundtrip() {
1037 let params = make_named_params(&[(4, 8)]);
1038 let hash = "a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2";
1040
1041 let mut buf = Vec::new();
1042 save_checkpoint(&mut buf, ¶ms, &[], Some(hash)).unwrap();
1043
1044 let load_params = make_named_params(&[(4, 8)]);
1045 let mut cursor = std::io::Cursor::new(&buf);
1046 let report = load_checkpoint(&mut cursor, &load_params, &[], Some(hash)).unwrap();
1048 assert_eq!(report.loaded.len(), 1);
1049 }
1050
1051 #[test]
1052 fn test_hash_mismatch_error() {
1053 let params = make_named_params(&[(4, 8)]);
1054 let hash_a = "a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2";
1055 let hash_b = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff";
1056
1057 let mut buf = Vec::new();
1058 save_checkpoint(&mut buf, ¶ms, &[], Some(hash_a)).unwrap();
1059
1060 let load_params = make_named_params(&[(4, 8)]);
1061 let mut cursor = std::io::Cursor::new(&buf);
1062 let result = load_checkpoint(&mut cursor, &load_params, &[], Some(hash_b));
1063 assert!(result.is_err());
1064 let msg = format!("{}", result.unwrap_err());
1065 assert!(msg.contains("architecture mismatch"), "error: {}", msg);
1066 }
1067
1068 #[test]
1069 fn test_zero_hash_skips_validation() {
1070 let params = make_named_params(&[(4, 8)]);
1071
1072 let mut buf = Vec::new();
1074 save_checkpoint(&mut buf, ¶ms, &[], None).unwrap();
1075
1076 let hash = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff";
1078 let load_params = make_named_params(&[(4, 8)]);
1079 let mut cursor = std::io::Cursor::new(&buf);
1080 let report = load_checkpoint(&mut cursor, &load_params, &[], Some(hash)).unwrap();
1081 assert_eq!(report.loaded.len(), 1);
1082
1083 let mut buf2 = Vec::new();
1085 save_checkpoint(&mut buf2, ¶ms, &[], Some(hash)).unwrap();
1086 let load_params2 = make_named_params(&[(4, 8)]);
1087 let mut cursor2 = std::io::Cursor::new(&buf2);
1088 let report2 = load_checkpoint(&mut cursor2, &load_params2, &[], None).unwrap();
1089 assert_eq!(report2.loaded.len(), 1);
1090 }
1091
1092 fn save_checkpoint_versioned<W: std::io::Write>(
1094 w: &mut W,
1095 version: u32,
1096 params: &[(String, Parameter)],
1097 buffers: &[(String, Buffer)],
1098 ) {
1099 w.write_all(&MAGIC).unwrap();
1100 w.write_all(&version.to_le_bytes()).unwrap();
1101 w.write_all(&[0u8; HASH_LEN]).unwrap();
1102 let total = (params.len() + buffers.len()) as u32;
1103 w.write_all(&total.to_le_bytes()).unwrap();
1104 for (name, p) in params {
1105 let name_bytes = name.as_bytes();
1106 w.write_all(&(name_bytes.len() as u32).to_le_bytes()).unwrap();
1107 w.write_all(name_bytes).unwrap();
1108 write_tensor_data(w, &p.variable.data()).unwrap();
1109 }
1110 for (name, b) in buffers {
1111 let name_bytes = name.as_bytes();
1112 w.write_all(&(name_bytes.len() as u32).to_le_bytes()).unwrap();
1113 w.write_all(name_bytes).unwrap();
1114 write_tensor_data(w, &b.get()).unwrap();
1115 }
1116 }
1117
1118 #[test]
1119 fn test_migrate_all_renamed() {
1120 let old_params = vec![
1122 ("linear_0/weight".to_string(), Parameter::new(
1123 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1124 ("linear_1/weight".to_string(), Parameter::new(
1125 Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1126 ];
1127 let mut ckpt = Vec::new();
1128 save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1129
1130 let new_params = vec![
1132 ("encoder/weight".to_string(), Parameter::new(
1133 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1134 ("decoder/weight".to_string(), Parameter::new(
1135 Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1136 ];
1137
1138 let mut out = Vec::new();
1139 let report = migrate_checkpoint(
1140 &mut std::io::Cursor::new(&ckpt), &mut out,
1141 &new_params, &[],
1142 ).unwrap();
1143
1144 assert!(report.unchanged.is_empty());
1145 assert_eq!(report.remapped.len(), 2);
1146 assert!(report.dropped.is_empty());
1147 assert!(report.missing.is_empty());
1148 assert!(report.is_complete());
1149
1150 let verify_params = vec![
1152 ("encoder/weight".to_string(), Parameter::new(
1153 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1154 ("decoder/weight".to_string(), Parameter::new(
1155 Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1156 ];
1157 let mut cursor = std::io::Cursor::new(&out);
1158 let load_report = load_checkpoint(&mut cursor, &verify_params, &[], None).unwrap();
1159 assert_eq!(load_report.loaded.len(), 2);
1160 assert!(load_report.missing.is_empty());
1161
1162 for (i, (_, vp)) in verify_params.iter().enumerate() {
1164 let expected = old_params[i].1.variable.data().to_f32_vec().unwrap();
1165 let got = vp.variable.data().to_f32_vec().unwrap();
1166 assert_eq!(expected, got, "data mismatch for param {}", i);
1167 }
1168 }
1169
1170 #[test]
1171 fn test_migrate_partial_rename() {
1172 let old_params = vec![
1174 ("shared/weight".to_string(), Parameter::new(
1175 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1176 ("linear_0/weight".to_string(), Parameter::new(
1177 Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1178 ];
1179 let mut ckpt = Vec::new();
1180 save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1181
1182 let new_params = vec![
1183 ("shared/weight".to_string(), Parameter::new(
1184 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1185 ("encoder/weight".to_string(), Parameter::new(
1186 Tensor::randn(&[8, 2], crate::tensor::test_opts()).unwrap(), "weight")),
1187 ];
1188
1189 let mut out = Vec::new();
1190 let report = migrate_checkpoint(
1191 &mut std::io::Cursor::new(&ckpt), &mut out,
1192 &new_params, &[],
1193 ).unwrap();
1194
1195 assert_eq!(report.unchanged, vec!["shared/weight"]);
1196 assert_eq!(report.remapped.len(), 1);
1197 assert_eq!(report.remapped[0], ("linear_0/weight".to_string(), "encoder/weight".to_string()));
1198 assert!(report.is_complete());
1199 }
1200
1201 #[test]
1202 fn test_migrate_with_buffers() {
1203 let old_params = vec![
1204 ("linear_0/weight".to_string(), Parameter::new(
1205 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1206 ];
1207 let old_buffers = vec![
1208 ("bn_0/running_mean".to_string(), Buffer::new(
1209 Tensor::zeros(&[8], crate::tensor::test_opts()).unwrap(), "running_mean")),
1210 ];
1211 let mut ckpt = Vec::new();
1212 save_checkpoint_versioned(&mut ckpt, 1, &old_params, &old_buffers);
1213
1214 let new_params = vec![
1215 ("encoder/weight".to_string(), Parameter::new(
1216 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1217 ];
1218 let new_buffers = vec![
1219 ("norm/running_mean".to_string(), Buffer::new(
1220 Tensor::zeros(&[8], crate::tensor::test_opts()).unwrap(), "running_mean")),
1221 ];
1222
1223 let mut out = Vec::new();
1224 let report = migrate_checkpoint(
1225 &mut std::io::Cursor::new(&ckpt), &mut out,
1226 &new_params, &new_buffers,
1227 ).unwrap();
1228
1229 assert_eq!(report.remapped.len(), 2);
1230 assert!(report.is_complete());
1231
1232 let vp = vec![
1234 ("encoder/weight".to_string(), Parameter::new(
1235 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1236 ];
1237 let vb = vec![
1238 ("norm/running_mean".to_string(), Buffer::new(
1239 Tensor::zeros(&[8], crate::tensor::test_opts()).unwrap(), "running_mean")),
1240 ];
1241 let mut cursor = std::io::Cursor::new(&out);
1242 let load_report = load_checkpoint(&mut cursor, &vp, &vb, None).unwrap();
1243 assert_eq!(load_report.loaded.len(), 2);
1244 }
1245
1246 #[test]
1247 fn test_migrate_dropped_and_missing() {
1248 let old_params = vec![
1249 ("old/weight".to_string(), Parameter::new(
1250 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1251 ("removed/weight".to_string(), Parameter::new(
1252 Tensor::randn(&[16, 16], crate::tensor::test_opts()).unwrap(), "weight")),
1253 ];
1254 let mut ckpt = Vec::new();
1255 save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1256
1257 let new_params = vec![
1259 ("new/weight".to_string(), Parameter::new(
1260 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1261 ("added/weight".to_string(), Parameter::new(
1262 Tensor::randn(&[32, 32], crate::tensor::test_opts()).unwrap(), "weight")),
1263 ];
1264
1265 let mut out = Vec::new();
1266 let report = migrate_checkpoint(
1267 &mut std::io::Cursor::new(&ckpt), &mut out,
1268 &new_params, &[],
1269 ).unwrap();
1270
1271 assert_eq!(report.remapped.len(), 1);
1272 assert_eq!(report.dropped, vec!["removed/weight"]);
1273 assert_eq!(report.missing, vec!["added/weight"]);
1274 assert!(!report.is_complete());
1275 }
1276
1277 #[test]
1278 fn test_migrate_positional_disambiguation() {
1279 let old_params = vec![
1281 ("linear_0/weight".to_string(), Parameter::new(
1282 Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1283 ("linear_1/weight".to_string(), Parameter::new(
1284 Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1285 ];
1286 let mut ckpt = Vec::new();
1287 save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1288
1289 let new_params = vec![
1290 ("encoder/weight".to_string(), Parameter::new(
1291 Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1292 ("decoder/weight".to_string(), Parameter::new(
1293 Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1294 ];
1295
1296 let mut out = Vec::new();
1297 let report = migrate_checkpoint(
1298 &mut std::io::Cursor::new(&ckpt), &mut out,
1299 &new_params, &[],
1300 ).unwrap();
1301
1302 assert_eq!(report.remapped.len(), 2);
1303 assert_eq!(report.remapped[0].0, "linear_0/weight");
1305 assert_eq!(report.remapped[0].1, "encoder/weight");
1306 assert_eq!(report.remapped[1].0, "linear_1/weight");
1307 assert_eq!(report.remapped[1].1, "decoder/weight");
1308
1309 let vp = vec![
1311 ("encoder/weight".to_string(), Parameter::new(
1312 Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1313 ("decoder/weight".to_string(), Parameter::new(
1314 Tensor::randn(&[4, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1315 ];
1316 let mut cursor = std::io::Cursor::new(&out);
1317 load_checkpoint(&mut cursor, &vp, &[], None).unwrap();
1318
1319 let enc_data = vp[0].1.variable.data().to_f32_vec().unwrap();
1321 let dec_data = vp[1].1.variable.data().to_f32_vec().unwrap();
1322 let old_0 = old_params[0].1.variable.data().to_f32_vec().unwrap();
1323 let old_1 = old_params[1].1.variable.data().to_f32_vec().unwrap();
1324 assert_eq!(enc_data, old_0);
1325 assert_eq!(dec_data, old_1);
1326 }
1327
1328 #[test]
1329 fn test_migrate_v1_writes_v2() {
1330 let old_params = vec![
1331 ("x/weight".to_string(), Parameter::new(
1332 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1333 ];
1334 let mut ckpt = Vec::new();
1335 save_checkpoint_versioned(&mut ckpt, 1, &old_params, &[]);
1336
1337 let mut peek = std::io::Cursor::new(&ckpt);
1339 let mut magic = [0u8; 4];
1340 std::io::Read::read_exact(&mut peek, &mut magic).unwrap();
1341 let mut vbuf = [0u8; 4];
1342 std::io::Read::read_exact(&mut peek, &mut vbuf).unwrap();
1343 assert_eq!(u32::from_le_bytes(vbuf), 1);
1344
1345 let new_params = vec![
1346 ("y/weight".to_string(), Parameter::new(
1347 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1348 ];
1349
1350 let mut out = Vec::new();
1351 migrate_checkpoint(
1352 &mut std::io::Cursor::new(&ckpt), &mut out,
1353 &new_params, &[],
1354 ).unwrap();
1355
1356 let mut peek2 = std::io::Cursor::new(&out);
1358 std::io::Read::read_exact(&mut peek2, &mut magic).unwrap();
1359 assert_eq!(&magic, b"FDLC");
1360 std::io::Read::read_exact(&mut peek2, &mut vbuf).unwrap();
1361 assert_eq!(u32::from_le_bytes(vbuf), VERSION); }
1363
1364 #[test]
1365 fn test_migrate_file_roundtrip() {
1366 let old_params = vec![
1367 ("old/weight".to_string(), Parameter::new(
1368 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1369 ];
1370 let dir = std::env::temp_dir();
1371 let src = dir.join("test_migrate_src.fdl");
1372 let dst = dir.join("test_migrate_dst.fdl");
1373
1374 {
1376 let f = std::fs::File::create(&src).unwrap();
1377 let mut w = std::io::BufWriter::new(f);
1378 save_checkpoint_versioned(&mut w, 1, &old_params, &[]);
1379 }
1380
1381 let new_params = vec![
1382 ("new/weight".to_string(), Parameter::new(
1383 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1384 ];
1385
1386 let report = migrate_checkpoint_file(
1387 src.to_str().unwrap(),
1388 dst.to_str().unwrap(),
1389 &new_params, &[],
1390 ).unwrap();
1391 assert_eq!(report.remapped.len(), 1);
1392 assert!(report.is_complete());
1393
1394 let vp = vec![
1396 ("new/weight".to_string(), Parameter::new(
1397 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1398 ];
1399 let load_report = load_checkpoint_file(
1400 dst.to_str().unwrap(), &vp, &[], None,
1401 ).unwrap();
1402 assert_eq!(load_report.loaded.len(), 1);
1403
1404 let expected = old_params[0].1.variable.data().to_f32_vec().unwrap();
1406 let got = vp[0].1.variable.data().to_f32_vec().unwrap();
1407 assert_eq!(expected, got);
1408
1409 std::fs::remove_file(src).ok();
1410 std::fs::remove_file(dst).ok();
1411 }
1412
1413 #[test]
1414 fn test_migrate_display() {
1415 let report = MigrateReport {
1416 unchanged: vec!["shared/weight".to_string()],
1417 remapped: vec![("old/bias".to_string(), "new/bias".to_string())],
1418 dropped: vec!["removed/weight".to_string()],
1419 missing: vec!["added/weight".to_string()],
1420 };
1421 let text = format!("{}", report);
1422 assert!(text.contains("unchanged (1)"));
1423 assert!(text.contains("remapped (1)"));
1424 assert!(text.contains("old/bias -> new/bias"));
1425 assert!(text.contains("dropped (1)"));
1426 assert!(text.contains("missing (1)"));
1427 }
1428
1429 #[test]
1430 fn test_checkpoint_version_peek() {
1431 let params = make_named_params(&[(4, 8)]);
1432 let dir = std::env::temp_dir();
1433 let path = dir.join("test_version_peek.fdl");
1434 save_checkpoint_file(path.to_str().unwrap(), ¶ms, &[], None).unwrap();
1435
1436 let v = checkpoint_version(path.to_str().unwrap()).unwrap();
1437 assert_eq!(v, VERSION);
1438
1439 std::fs::remove_file(path).ok();
1440 }
1441
1442 #[test]
1443 fn test_load_accepts_v1() {
1444 let params = vec![
1446 ("x/weight".to_string(), Parameter::new(
1447 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1448 ];
1449 let mut ckpt = Vec::new();
1450 save_checkpoint_versioned(&mut ckpt, 1, ¶ms, &[]);
1451
1452 let load_params = vec![
1453 ("x/weight".to_string(), Parameter::new(
1454 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1455 ];
1456 let mut cursor = std::io::Cursor::new(&ckpt);
1457 let report = load_checkpoint(&mut cursor, &load_params, &[], None).unwrap();
1458 assert_eq!(report.loaded.len(), 1);
1459
1460 let expected = params[0].1.variable.data().to_f32_vec().unwrap();
1461 let got = load_params[0].1.variable.data().to_f32_vec().unwrap();
1462 assert_eq!(expected, got);
1463 }
1464
1465 #[test]
1468 fn test_truncated_checkpoint_header_only() {
1469 let mut buf = Vec::new();
1471 buf.extend_from_slice(&MAGIC);
1472 buf.extend_from_slice(&VERSION.to_le_bytes());
1473 buf.extend_from_slice(&[0u8; HASH_LEN]);
1474 buf.extend_from_slice(&5u32.to_le_bytes());
1476
1477 let params = make_named_params(&[(4, 8)]);
1478 let mut cursor = std::io::Cursor::new(&buf);
1479 let result = load_checkpoint(&mut cursor, ¶ms, &[], None);
1480 assert!(result.is_err(), "truncated checkpoint should return Err, not panic");
1481 let msg = format!("{}", result.unwrap_err());
1482 assert!(msg.contains("io:"), "should be an IO error: {}", msg);
1483 }
1484
1485 #[test]
1486 fn test_truncated_checkpoint_mid_entry() {
1487 let params = make_named_params(&[(4, 8)]);
1489 let mut full = Vec::new();
1490 save_checkpoint(&mut full, ¶ms, &[], None).unwrap();
1491
1492 let truncated = full[..50.min(full.len())].to_vec();
1495
1496 let load_params = make_named_params(&[(4, 8)]);
1497 let mut cursor = std::io::Cursor::new(&truncated);
1498 let result = load_checkpoint(&mut cursor, &load_params, &[], None);
1499 assert!(result.is_err(), "truncated mid-entry should return Err");
1500 }
1501
1502 #[test]
1503 fn test_empty_file() {
1504 let buf: Vec<u8> = Vec::new();
1506 let params = make_named_params(&[(4, 8)]);
1507 let mut cursor = std::io::Cursor::new(&buf);
1508 let result = load_checkpoint(&mut cursor, ¶ms, &[], None);
1509 assert!(result.is_err(), "empty file should return Err");
1510 }
1511
1512 #[test]
1513 fn test_invalid_magic_bytes() {
1514 let mut buf = Vec::new();
1515 buf.extend_from_slice(b"JUNK"); buf.extend_from_slice(&VERSION.to_le_bytes());
1517 buf.extend_from_slice(&[0u8; HASH_LEN]);
1518 buf.extend_from_slice(&0u32.to_le_bytes());
1519
1520 let params = make_named_params(&[(4, 8)]);
1521 let mut cursor = std::io::Cursor::new(&buf);
1522 let result = load_checkpoint(&mut cursor, ¶ms, &[], None);
1523 assert!(result.is_err());
1524 let msg = format!("{}", result.unwrap_err());
1525 assert!(msg.contains("bad magic"), "error should mention bad magic: {}", msg);
1526 }
1527
1528 #[test]
1529 fn test_invalid_magic_checkpoint_version() {
1530 let dir = std::env::temp_dir();
1532 let path = dir.join("test_bad_magic_version.fdl");
1533 std::fs::write(&path, b"NOT_FDLC_data").unwrap();
1534
1535 let result = checkpoint_version(path.to_str().unwrap());
1536 assert!(result.is_err());
1537 let msg = format!("{}", result.unwrap_err());
1538 assert!(msg.contains("bad magic"), "error: {}", msg);
1539
1540 std::fs::remove_file(path).ok();
1541 }
1542
1543 #[test]
1544 fn test_unsupported_version_high() {
1545 let mut buf = Vec::new();
1546 buf.extend_from_slice(&MAGIC);
1547 buf.extend_from_slice(&99u32.to_le_bytes()); buf.extend_from_slice(&[0u8; HASH_LEN]);
1549 buf.extend_from_slice(&0u32.to_le_bytes());
1550
1551 let params = make_named_params(&[(4, 8)]);
1552 let mut cursor = std::io::Cursor::new(&buf);
1553 let result = load_checkpoint(&mut cursor, ¶ms, &[], None);
1554 assert!(result.is_err());
1555 let msg = format!("{}", result.unwrap_err());
1556 assert!(msg.contains("unsupported checkpoint version"), "error: {}", msg);
1557 assert!(msg.contains("99"), "should mention version 99: {}", msg);
1558 }
1559
1560 #[test]
1561 fn test_unsupported_version_zero() {
1562 let mut buf = Vec::new();
1564 buf.extend_from_slice(&MAGIC);
1565 buf.extend_from_slice(&0u32.to_le_bytes()); buf.extend_from_slice(&[0u8; HASH_LEN]);
1567 buf.extend_from_slice(&0u32.to_le_bytes());
1568
1569 let params = make_named_params(&[(4, 8)]);
1570 let mut cursor = std::io::Cursor::new(&buf);
1571 let result = load_checkpoint(&mut cursor, ¶ms, &[], None);
1572 assert!(result.is_err());
1573 let msg = format!("{}", result.unwrap_err());
1574 assert!(msg.contains("unsupported checkpoint version"), "error: {}", msg);
1575 }
1576
1577 #[test]
1578 fn test_hash_mismatch_both_nonzero() {
1579 let params = make_named_params(&[(4, 8)]);
1581 let hash_a = "0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";
1582 let hash_b = "fedcba9876543210fedcba9876543210fedcba9876543210fedcba9876543210";
1583
1584 let mut buf = Vec::new();
1585 save_checkpoint(&mut buf, ¶ms, &[], Some(hash_a)).unwrap();
1586
1587 let load_params = make_named_params(&[(4, 8)]);
1588 let mut cursor = std::io::Cursor::new(&buf);
1589 let result = load_checkpoint(&mut cursor, &load_params, &[], Some(hash_b));
1590 assert!(result.is_err());
1591 let msg = format!("{}", result.unwrap_err());
1592 assert!(msg.contains("architecture mismatch"), "error: {}", msg);
1593 assert!(msg.contains(hash_b), "should show expected hash: {}", msg);
1595 }
1596
1597 #[test]
1598 fn test_zero_entries_empty_model() {
1599 let mut buf = Vec::new();
1601 save_checkpoint(&mut buf, &[], &[], None).unwrap();
1602
1603 let mut cursor = std::io::Cursor::new(&buf);
1605 let report = load_checkpoint(&mut cursor, &[], &[], None).unwrap();
1606 assert!(report.loaded.is_empty());
1607 assert!(report.skipped.is_empty());
1608 assert!(report.missing.is_empty());
1609 }
1610
1611 #[test]
1612 fn test_zero_entries_nonempty_model() {
1613 let mut buf = Vec::new();
1615 save_checkpoint(&mut buf, &[], &[], None).unwrap();
1616
1617 let load_params = make_named_params(&[(4, 8)]);
1618 let mut cursor = std::io::Cursor::new(&buf);
1619 let report = load_checkpoint(&mut cursor, &load_params, &[], None).unwrap();
1620 assert!(report.loaded.is_empty());
1621 assert!(report.skipped.is_empty());
1622 assert_eq!(report.missing.len(), 1, "model param should be reported as missing");
1623 }
1624
1625 #[test]
1626 fn test_shape_mismatch_transposed() {
1627 let params = vec![
1629 ("layer/weight".to_string(), Parameter::new(
1630 Tensor::randn(&[4, 8], crate::tensor::test_opts()).unwrap(), "weight")),
1631 ];
1632 let mut buf = Vec::new();
1633 save_checkpoint(&mut buf, ¶ms, &[], None).unwrap();
1634
1635 let wrong_params = vec![
1636 ("layer/weight".to_string(), Parameter::new(
1637 Tensor::randn(&[8, 4], crate::tensor::test_opts()).unwrap(), "weight")),
1638 ];
1639 let mut cursor = std::io::Cursor::new(&buf);
1640 let result = load_checkpoint(&mut cursor, &wrong_params, &[], None);
1641 assert!(result.is_err(), "transposed shape should be a mismatch error");
1642 let msg = format!("{}", result.unwrap_err());
1643 assert!(msg.contains("shape mismatch"), "error: {}", msg);
1644 assert!(msg.contains("[4, 8]"), "should show checkpoint shape: {}", msg);
1645 assert!(msg.contains("[8, 4]"), "should show model shape: {}", msg);
1646 }
1647
1648 #[test]
1649 fn test_dtype_mismatch_auto_cast() {
1650 let f32_param = vec![
1652 ("layer/weight".to_string(), Parameter::new(
1653 Tensor::ones(&[2, 3], crate::tensor::test_opts()).unwrap(), "weight")),
1654 ];
1655 let mut buf = Vec::new();
1656 save_checkpoint(&mut buf, &f32_param, &[], None).unwrap();
1657
1658 let f64_param = vec![
1660 ("layer/weight".to_string(), Parameter::new(
1661 Tensor::zeros(&[2, 3], TensorOptions {
1662 dtype: DType::Float64,
1663 device: crate::tensor::test_device(),
1664 }).unwrap(), "weight")),
1665 ];
1666 let mut cursor = std::io::Cursor::new(&buf);
1667 let report = load_checkpoint(&mut cursor, &f64_param, &[], None).unwrap();
1668 assert_eq!(report.loaded.len(), 1, "dtype auto-cast should succeed");
1669
1670 let loaded = f64_param[0].1.variable.data();
1672 assert_eq!(loaded.dtype(), DType::Float64);
1673 let vals = loaded.to_f64_vec().unwrap();
1674 for v in vals {
1675 assert!((v - 1.0).abs() < 1e-6, "expected ~1.0, got {}", v);
1676 }
1677 }
1678
1679 #[test]
1680 fn test_dtype_mismatch_buffer_auto_cast() {
1681 let f32_buffers = vec![
1683 ("norm/running_mean".to_string(), Buffer::new(
1684 Tensor::ones(&[8], crate::tensor::test_opts()).unwrap(), "running_mean")),
1685 ];
1686 let mut buf = Vec::new();
1687 save_checkpoint(&mut buf, &[], &f32_buffers, None).unwrap();
1688
1689 let f64_buffers = vec![
1690 ("norm/running_mean".to_string(), Buffer::new(
1691 Tensor::zeros(&[8], TensorOptions {
1692 dtype: DType::Float64,
1693 device: crate::tensor::test_device(),
1694 }).unwrap(), "running_mean")),
1695 ];
1696 let mut cursor = std::io::Cursor::new(&buf);
1697 let report = load_checkpoint(&mut cursor, &[], &f64_buffers, None).unwrap();
1698 assert_eq!(report.loaded.len(), 1);
1699 assert_eq!(f64_buffers[0].1.get().dtype(), DType::Float64);
1700 let vals = f64_buffers[0].1.get().to_f64_vec().unwrap();
1701 for v in vals {
1702 assert!((v - 1.0).abs() < 1e-6);
1703 }
1704 }
1705
1706 #[test]
1707 fn test_compressed_roundtrip_with_hash() {
1708 let params = make_named_params(&[(8, 16)]);
1710 let hash = "abcdef0123456789abcdef0123456789abcdef0123456789abcdef0123456789";
1711
1712 let dir = std::env::temp_dir();
1713 let gz_path = dir.join("test_ckpt_hash_gz.fdl.gz");
1714 let path_str = gz_path.to_str().unwrap();
1715
1716 save_checkpoint_file(path_str, ¶ms, &[], Some(hash)).unwrap();
1717
1718 let load_params = make_named_params(&[(8, 16)]);
1720 let report = load_checkpoint_file(path_str, &load_params, &[], Some(hash)).unwrap();
1721 assert_eq!(report.loaded.len(), 1);
1722
1723 let bad_hash = "1111111111111111111111111111111111111111111111111111111111111111";
1725 let load_params2 = make_named_params(&[(8, 16)]);
1726 let result = load_checkpoint_file(path_str, &load_params2, &[], Some(bad_hash));
1727 assert!(result.is_err());
1728
1729 std::fs::remove_file(gz_path).ok();
1730 }
1731
1732 #[test]
1733 fn test_corrupted_gz_file() {
1734 let dir = std::env::temp_dir();
1736 let path = dir.join("test_corrupt.fdl.gz");
1737 std::fs::write(&path, b"\x1f\x8b\x08\x00GARBAGE_NOT_VALID_GZ").unwrap();
1739
1740 let params = make_named_params(&[(4, 8)]);
1741 let result = load_checkpoint_file(path.to_str().unwrap(), ¶ms, &[], None);
1742 assert!(result.is_err(), "corrupted gz should return Err");
1743
1744 std::fs::remove_file(path).ok();
1745 }
1746
1747 #[test]
1748 fn test_unknown_dtype_tag() {
1749 let mut buf = Vec::new();
1751 buf.extend_from_slice(&MAGIC);
1752 buf.extend_from_slice(&VERSION.to_le_bytes());
1753 buf.extend_from_slice(&[0u8; HASH_LEN]);
1754 buf.extend_from_slice(&1u32.to_le_bytes()); let name = b"layer/weight";
1758 buf.extend_from_slice(&(name.len() as u32).to_le_bytes());
1759 buf.extend_from_slice(name);
1760
1761 buf.extend_from_slice(&1u32.to_le_bytes());
1763 buf.extend_from_slice(&4i64.to_le_bytes());
1764
1765 buf.push(255);
1767
1768 buf.extend_from_slice(&16u64.to_le_bytes());
1770 buf.extend_from_slice(&[0u8; 16]);
1771
1772 let params = vec![
1773 ("layer/weight".to_string(), Parameter::new(
1774 Tensor::zeros(&[4], crate::tensor::test_opts()).unwrap(), "weight")),
1775 ];
1776 let mut cursor = std::io::Cursor::new(&buf);
1777 let result = load_checkpoint(&mut cursor, ¶ms, &[], None);
1778 assert!(result.is_err());
1779 let msg = format!("{}", result.unwrap_err());
1780 assert!(msg.contains("unknown dtype tag"), "error: {}", msg);
1781 }
1782
1783 #[test]
1784 fn test_checkpoint_keys_peeks_names_without_loading_data() {
1785 let params = vec![
1786 (
1787 "encoder/layer/weight".to_string(),
1788 Parameter::new(
1789 Tensor::ones(&[4, 8], crate::tensor::test_opts()).unwrap(),
1790 "weight",
1791 ),
1792 ),
1793 (
1794 "pooler/dense/weight".to_string(),
1795 Parameter::new(
1796 Tensor::ones(&[8, 8], crate::tensor::test_opts()).unwrap(),
1797 "weight",
1798 ),
1799 ),
1800 ];
1801 let buffers = vec![(
1802 "encoder/layer/running_mean".to_string(),
1803 Buffer::new(
1804 Tensor::zeros(&[8], crate::tensor::test_opts()).unwrap(),
1805 "running_mean",
1806 ),
1807 )];
1808
1809 let dir = std::env::temp_dir();
1810 let path = dir.join("test_checkpoint_keys_peek.fdl");
1811 let path_str = path.to_str().unwrap();
1812
1813 save_checkpoint_file(path_str, ¶ms, &buffers, None).unwrap();
1814 let keys = checkpoint_keys(path_str).unwrap();
1815 assert_eq!(
1816 keys,
1817 vec![
1818 "encoder/layer/weight".to_string(),
1819 "pooler/dense/weight".to_string(),
1820 "encoder/layer/running_mean".to_string(),
1821 ],
1822 "params first then buffers, in declaration order",
1823 );
1824
1825 std::fs::remove_file(path_str).ok();
1826 }
1827
1828 #[test]
1829 fn test_checkpoint_keys_handles_gzip() {
1830 let params = vec![(
1831 "x/w".to_string(),
1832 Parameter::new(
1833 Tensor::ones(&[2, 2], crate::tensor::test_opts()).unwrap(),
1834 "w",
1835 ),
1836 )];
1837 let dir = std::env::temp_dir();
1838 let path = dir.join("test_checkpoint_keys_gz.fdl.gz");
1839 let path_str = path.to_str().unwrap();
1840
1841 save_checkpoint_file(path_str, ¶ms, &[], None).unwrap();
1842 let keys = checkpoint_keys(path_str).unwrap();
1843 assert_eq!(keys, vec!["x/w".to_string()]);
1844
1845 std::fs::remove_file(path_str).ok();
1846 }
1847
1848 #[test]
1849 fn test_checkpoint_keys_rejects_bad_magic() {
1850 let dir = std::env::temp_dir();
1851 let path = dir.join("test_checkpoint_keys_bad.fdl");
1852 std::fs::write(&path, b"NOPEnotacheckpoint").unwrap();
1853 let err = checkpoint_keys(path.to_str().unwrap()).unwrap_err();
1854 assert!(format!("{err}").contains("bad magic"), "got: {err}");
1855 std::fs::remove_file(path).ok();
1856 }
1857}