use std::io::{Read, Write};
use crate::tensor::{Device, DType, Result, Tensor, TensorError};
use super::buffer::Buffer;
use super::parameter::Parameter;
/// Four-byte signature at the start of every checkpoint stream (ASCII `FDLC`).
pub(crate) const MAGIC: [u8; 4] = [b'F', b'D', b'L', b'C'];
/// Checkpoint format version written right after [`MAGIC`].
pub(crate) const VERSION: u32 = 1;
/// Size in bytes of the structural-hash header field (256 bits); its hex form is twice as long.
pub(crate) const HASH_LEN: usize = 256 / 8;
/// Outcome of loading a checkpoint: entry names grouped by what happened to them.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct LoadReport {
    /// Model parameters and buffers found in the checkpoint and restored.
    pub loaded: Vec<String>,
    /// Checkpoint entries that have no matching parameter or buffer in the model.
    pub skipped: Vec<String>,
    /// Model parameters and buffers with no checkpoint entry; they are left unchanged.
    pub missing: Vec<String>,
}
pub fn save_checkpoint<W: Write>(
w: &mut W,
params: &[(String, Parameter)],
buffers: &[(String, Buffer)],
structural_hash: Option<&str>,
) -> Result<()> {
w.write_all(&MAGIC).map_err(io_err)?;
w.write_all(&VERSION.to_le_bytes()).map_err(io_err)?;
let hash_bytes = match structural_hash {
Some(hex) => hex_to_bytes(hex)?,
None => [0u8; HASH_LEN],
};
w.write_all(&hash_bytes).map_err(io_err)?;
let total = (params.len() + buffers.len()) as u32;
w.write_all(&total.to_le_bytes()).map_err(io_err)?;
for (name, p) in params {
let name_bytes = name.as_bytes();
w.write_all(&(name_bytes.len() as u32).to_le_bytes()).map_err(io_err)?;
w.write_all(name_bytes).map_err(io_err)?;
write_tensor_data(w, &p.variable.data())?;
}
for (name, b) in buffers {
let name_bytes = name.as_bytes();
w.write_all(&(name_bytes.len() as u32).to_le_bytes()).map_err(io_err)?;
w.write_all(name_bytes).map_err(io_err)?;
write_tensor_data(w, &b.get())?;
}
Ok(())
}
pub fn load_checkpoint<R: Read>(
r: &mut R,
params: &[(String, Parameter)],
buffers: &[(String, Buffer)],
structural_hash: Option<&str>,
) -> Result<LoadReport> {
let mut magic = [0u8; 4];
r.read_exact(&mut magic).map_err(io_err)?;
if magic != MAGIC {
return Err(TensorError::new(
"invalid checkpoint: bad magic (expected .fdl checkpoint)"
));
}
let version = read_u32(r)?;
if version != 1 {
return Err(TensorError::new(&format!(
"unsupported checkpoint version {} (want 1)", version
)));
}
let mut file_hash = [0u8; HASH_LEN];
r.read_exact(&mut file_hash).map_err(io_err)?;
let file_nonzero = file_hash.iter().any(|&b| b != 0);
if let Some(expected_hex) = structural_hash {
let expected = hex_to_bytes(expected_hex)?;
let expected_nonzero = expected.iter().any(|&b| b != 0);
if file_nonzero && expected_nonzero && file_hash != expected {
return Err(TensorError::new(&format!(
"checkpoint architecture mismatch: file={} model={}",
bytes_to_hex(&file_hash),
expected_hex,
)));
}
}
let count = read_u32(r)? as usize;
let mut ckpt: std::collections::HashMap<String, (Vec<i64>, DType, Vec<u8>)> =
std::collections::HashMap::with_capacity(count);
for _ in 0..count {
let name_len = read_u32(r)? as usize;
let mut name_bytes = vec![0u8; name_len];
r.read_exact(&mut name_bytes).map_err(io_err)?;
let name = String::from_utf8_lossy(&name_bytes).into_owned();
let ndim = read_u32(r)? as usize;
let mut shape = vec![0i64; ndim];
for s in &mut shape { *s = read_i64(r)?; }
let mut tag = [0u8; 1];
r.read_exact(&mut tag).map_err(io_err)?;
let dtype = dtype_from_tag(tag[0])?;
let byte_count = read_u64(r)? as usize;
let mut raw = vec![0u8; byte_count];
r.read_exact(&mut raw).map_err(io_err)?;
ckpt.insert(name, (shape, dtype, raw));
}
let mut loaded = Vec::new();
let mut missing = Vec::new();
for (name, p) in params {
if let Some((shape, dtype, raw)) = ckpt.remove(name) {
let model_shape = p.variable.shape();
if shape != model_shape {
return Err(TensorError::new(&format!(
"parameter {:?}: shape mismatch: checkpoint={:?} model={:?}",
name, shape, model_shape
)));
}
let t = tensor_from_raw_bytes(&raw, &shape, dtype)?;
let model_dtype = p.variable.data().dtype();
let t = if t.dtype() != model_dtype { t.to_dtype(model_dtype)? } else { t };
let dev = p.variable.data().device();
if dev != Device::CPU {
p.variable.set_data(t.to_device(dev)?);
} else {
p.variable.set_data(t);
}
loaded.push(name.clone());
} else {
missing.push(name.clone());
}
}
for (name, b) in buffers {
if let Some((shape, dtype, raw)) = ckpt.remove(name) {
let model_shape = b.shape();
if shape != model_shape {
return Err(TensorError::new(&format!(
"buffer {:?}: shape mismatch: checkpoint={:?} model={:?}",
name, shape, model_shape
)));
}
let t = tensor_from_raw_bytes(&raw, &shape, dtype)?;
let model_dtype = b.get().dtype();
let t = if t.dtype() != model_dtype { t.to_dtype(model_dtype)? } else { t };
let dev = b.device();
if dev != Device::CPU {
b.set(t.to_device(dev)?);
} else {
b.set(t);
}
loaded.push(name.clone());
} else {
missing.push(name.clone());
}
}
let skipped: Vec<String> = ckpt.into_keys().collect();
Ok(LoadReport { loaded, skipped, missing })
}
/// Saves a checkpoint to `path`, gzip-compressing it when `path` ends in `.gz`.
///
/// All output goes through a `BufWriter` that is flushed explicitly before
/// returning. `BufWriter`'s `Drop` ignores flush errors, so relying on it could
/// report success for a truncated file (e.g. disk full on the final write).
///
/// # Errors
/// Returns an error in these cases:
/// - the file cannot be created;
/// - `save_checkpoint` fails;
/// - finishing the gzip stream fails;
/// - flushing the buffered writer fails.
pub fn save_checkpoint_file(
    path: &str,
    params: &[(String, Parameter)],
    buffers: &[(String, Buffer)],
    structural_hash: Option<&str>,
) -> Result<()> {
    let file = std::fs::File::create(path).map_err(io_err)?;
    let mut w = std::io::BufWriter::new(file);
    if path.ends_with(".gz") {
        let mut gz = flate2::write::GzEncoder::new(&mut w, flate2::Compression::default());
        save_checkpoint(&mut gz, params, buffers, structural_hash)?;
        // `finish` writes the gzip trailer and surfaces its errors.
        gz.finish().map_err(io_err)?;
    } else {
        save_checkpoint(&mut w, params, buffers, structural_hash)?;
    }
    w.flush().map_err(io_err)
}
/// Loads a checkpoint from `path` into `params` and `buffers`.
///
/// A path ending in `.gz` is read through a gzip decoder; any other path is
/// read through a `BufReader`.
///
/// # Errors
/// Returns an error if the file cannot be opened or if `load_checkpoint` fails.
pub fn load_checkpoint_file(
    path: &str,
    params: &[(String, Parameter)],
    buffers: &[(String, Buffer)],
    structural_hash: Option<&str>,
) -> Result<LoadReport> {
    let file = std::fs::File::open(path).map_err(io_err)?;
    if !path.ends_with(".gz") {
        let mut reader = std::io::BufReader::new(file);
        return load_checkpoint(&mut reader, params, buffers, structural_hash);
    }
    let mut decoder = flate2::read::GzDecoder::new(file);
    load_checkpoint(&mut decoder, params, buffers, structural_hash)
}
/// Writes an optional tensor.
///
/// The output is a presence byte (`1` for `Some`, `0` for `None`). When present,
/// it is followed by the tensor record from `write_tensor_data`.
pub(crate) fn write_tensor_state<W: Write>(w: &mut W, t: Option<&Tensor>) -> Result<()> {
    w.write_all(&[u8::from(t.is_some())]).map_err(io_err)?;
    if let Some(tensor) = t {
        write_tensor_data(w, tensor)?;
    }
    Ok(())
}
/// Reads an optional tensor in the layout written by `write_tensor_state`.
///
/// A zero presence byte yields `None`. Any other value is followed by a tensor
/// record, which is decoded on the CPU and then moved to `device` unless that
/// is the CPU.
pub(crate) fn read_tensor_state<R: Read>(r: &mut R, device: Device) -> Result<Option<Tensor>> {
    let mut flag = [0u8; 1];
    r.read_exact(&mut flag).map_err(io_err)?;
    match flag[0] {
        0 => Ok(None),
        _ => {
            let cpu = read_tensor_data(r)?;
            let placed = if device == Device::CPU { cpu } else { cpu.to_device(device)? };
            Ok(Some(placed))
        }
    }
}
/// On-disk tag for `dtype`.
///
/// These values are part of the file format and must stay in sync with
/// `dtype_from_tag`.
fn dtype_tag(dtype: DType) -> u8 {
    match dtype {
        DType::Int64 => 6,
        DType::Int32 => 5,
        DType::Float64 => 4,
        DType::Float32 => 3,
        DType::BFloat16 => 2,
        DType::Float16 => 1,
    }
}
/// Inverse of `dtype_tag`: maps an on-disk tag back to its dtype.
///
/// # Errors
/// Returns an error for any tag outside `1..=6`.
fn dtype_from_tag(tag: u8) -> Result<DType> {
    let dtype = match tag {
        1 => DType::Float16,
        2 => DType::BFloat16,
        3 => DType::Float32,
        4 => DType::Float64,
        5 => DType::Int32,
        6 => DType::Int64,
        other => return Err(TensorError::new(&format!("unknown dtype tag: {}", other))),
    };
    Ok(dtype)
}
/// Writes one tensor record.
///
/// The fields, in order (all little-endian):
/// - rank (u32);
/// - each dimension (i64);
/// - the dtype tag (u8);
/// - the payload length in bytes (u64), equal to numel × element size;
/// - the payload, copied out of the tensor by `copy_raw_bytes`.
pub(crate) fn write_tensor_data<W: Write>(w: &mut W, t: &Tensor) -> Result<()> {
    let dims = t.shape();
    let rank = dims.len() as u32;
    w.write_all(&rank.to_le_bytes()).map_err(io_err)?;
    dims.iter()
        .try_for_each(|d| w.write_all(&d.to_le_bytes()))
        .map_err(io_err)?;

    let dtype = t.dtype();
    w.write_all(&[dtype_tag(dtype)]).map_err(io_err)?;

    let payload_len = t.numel() as usize * dtype.element_size();
    let payload = copy_raw_bytes(t, payload_len)?;
    w.write_all(&(payload_len as u64).to_le_bytes()).map_err(io_err)?;
    w.write_all(&payload).map_err(io_err)
}
fn read_tensor_data<R: Read>(r: &mut R) -> Result<Tensor> {
let ndim = read_u32(r)? as usize;
let mut shape = vec![0i64; ndim];
for s in &mut shape {
*s = read_i64(r)?;
}
let mut tag = [0u8; 1];
r.read_exact(&mut tag).map_err(io_err)?;
let dtype = dtype_from_tag(tag[0])?;
let byte_count = read_u64(r)? as usize;
let mut raw = vec![0u8; byte_count];
r.read_exact(&mut raw).map_err(io_err)?;
tensor_from_raw_bytes(&raw, &shape, dtype)
}
/// Copies `byte_count` bytes of `t`'s data into a new host buffer via `flodl_copy_data`.
///
/// # Errors
/// Propagates the error string returned by the native call.
fn copy_raw_bytes(t: &Tensor, byte_count: usize) -> Result<Vec<u8>> {
    let mut out = vec![0u8; byte_count];
    let dst = out.as_mut_ptr().cast::<std::ffi::c_void>();
    // SAFETY: `out` owns exactly `byte_count` writable bytes and outlives the call.
    // NOTE(review): assumes flodl_copy_data writes at most `byte_count` bytes and
    // does not retain `dst` after returning — confirm in flodl_sys.
    let status = unsafe { flodl_sys::flodl_copy_data(t.raw(), dst, byte_count as i64) };
    check_err_raw(status).map(|()| out)
}
/// Builds a CPU tensor of `shape` and `dtype` from little-endian `raw` bytes.
///
/// The payload length is validated against `shape` and `dtype.element_size()`
/// before any decoding. Without this check, the `chunks_exact` decoders would
/// silently drop a trailing partial element. The `flodl_from_blob` path would
/// also hand native code a pointer with no length, so a short `raw` could be
/// read out of bounds.
///
/// # Errors
/// Returns an error in these cases:
/// - a dimension is negative;
/// - the element count overflows;
/// - `raw.len()` differs from the size implied by `shape` and `dtype`;
/// - tensor construction fails.
fn tensor_from_raw_bytes(raw: &[u8], shape: &[i64], dtype: DType) -> Result<Tensor> {
    let expected_len = shape
        .iter()
        .try_fold(1u64, |acc, &dim| if dim < 0 { None } else { acc.checked_mul(dim as u64) })
        .and_then(|numel| numel.checked_mul(dtype.element_size() as u64));
    if expected_len != Some(raw.len() as u64) {
        return Err(crate::tensor::TensorError::new(&format!(
            "tensor data size mismatch: {} bytes for shape {:?}",
            raw.len(),
            shape
        )));
    }
    match dtype {
        DType::Float32 => {
            let data: Vec<f32> = raw.chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect();
            Tensor::from_f32(&data, shape, Device::CPU)
        }
        DType::Float64 => {
            let data: Vec<f64> = raw.chunks_exact(8)
                .map(|c| f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
                .collect();
            Tensor::from_f64(&data, shape, Device::CPU)
        }
        DType::Int64 => {
            let data: Vec<i64> = raw.chunks_exact(8)
                .map(|c| i64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]))
                .collect();
            Tensor::from_i64(&data, shape, Device::CPU)
        }
        DType::Float16 | DType::BFloat16 | DType::Int32 => {
            let mut shape_v = shape.to_vec();
            let mut handle: flodl_sys::FlodlTensor = std::ptr::null_mut();
            let (dev_type, dev_idx) = crate::tensor::Device::CPU.to_ffi();
            // SAFETY: `raw` is valid for `raw.len()` bytes, which was checked above
            // to equal the size implied by `shape` and `dtype`.
            // NOTE(review): the `*mut` cast assumes flodl_from_blob never writes
            // through the pointer. Because `raw` is only borrowed for this call,
            // this also relies on the binding copying the data into the new
            // tensor — confirm in flodl_sys.
            let err = unsafe {
                flodl_sys::flodl_from_blob(
                    raw.as_ptr() as *mut std::ffi::c_void,
                    shape_v.as_mut_ptr(),
                    shape_v.len() as i32,
                    dtype as i32,
                    dev_type, dev_idx,
                    &mut handle,
                )
            };
            check_err_raw(err)?;
            debug_assert!(!handle.is_null());
            Ok(unsafe { Tensor::from_raw_handle(handle) })
        }
    }
}
/// Wraps any displayable I/O failure in a `TensorError` whose message is prefixed with `io: `.
pub(crate) fn io_err(e: impl std::fmt::Display) -> TensorError {
    let message = format!("io: {}", e);
    TensorError::new(&message)
}
/// Converts the error-string return convention of flodl_sys calls into a `Result`.
///
/// A null pointer means success. Otherwise the message is copied into a Rust
/// `String`, the native string is released with `flodl_free_string`, and the
/// message is returned as an error.
fn check_err_raw(err: *mut i8) -> Result<()> {
    if err.is_null() {
        return Ok(());
    }
    // SAFETY: `err` is non-null here. NOTE(review): assumes flodl_sys returns a
    // NUL-terminated string that stays valid until it is freed below — confirm.
    let message = unsafe { std::ffi::CStr::from_ptr(err) }.to_string_lossy().into_owned();
    // SAFETY: `message` owns a copy, so the native string is not used after this.
    unsafe { flodl_sys::flodl_free_string(err) };
    Err(TensorError::new(&message))
}
/// Reads a little-endian `u32`.
fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
    let mut le = [0u8; 4];
    r.read_exact(&mut le)
        .map(|()| u32::from_le_bytes(le))
        .map_err(io_err)
}
/// Reads a little-endian `u64`.
fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
    let mut le = [0u8; 8];
    r.read_exact(&mut le)
        .map(|()| u64::from_le_bytes(le))
        .map_err(io_err)
}
/// Reads a little-endian `i64`.
fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
    let mut le = [0u8; 8];
    r.read_exact(&mut le)
        .map(|()| i64::from_le_bytes(le))
        .map_err(io_err)
}
/// Reads a little-endian `f64`.
pub(crate) fn read_f64_le<R: Read>(r: &mut R) -> Result<f64> {
    let mut le = [0u8; 8];
    r.read_exact(&mut le)
        .map(|()| f64::from_le_bytes(le))
        .map_err(io_err)
}
/// Writes `v` as 8 little-endian bytes.
pub(crate) fn write_f64_le<W: Write>(w: &mut W, v: f64) -> Result<()> {
    let bytes = v.to_le_bytes();
    w.write_all(&bytes).map_err(io_err)
}
/// Writes `v` as 4 little-endian bytes.
pub(crate) fn write_u32_le<W: Write>(w: &mut W, v: u32) -> Result<()> {
    let bytes = v.to_le_bytes();
    w.write_all(&bytes).map_err(io_err)
}
/// Writes `v` as 8 little-endian bytes.
pub(crate) fn write_i64_le<W: Write>(w: &mut W, v: i64) -> Result<()> {
    let bytes = v.to_le_bytes();
    w.write_all(&bytes).map_err(io_err)
}
pub(crate) fn read_u32_le<R: Read>(r: &mut R) -> Result<u32> {
read_u32(r)
}
pub(crate) fn read_i64_le<R: Read>(r: &mut R) -> Result<i64> {
read_i64(r)
}
/// Parses a hex string of exactly `HASH_LEN * 2` characters (either case) into `HASH_LEN` bytes.
///
/// # Errors
/// Returns an error if the byte length is wrong or any character is not a hex digit.
fn hex_to_bytes(hex: &str) -> Result<[u8; HASH_LEN]> {
    let digits = hex.as_bytes();
    if digits.len() != HASH_LEN * 2 {
        return Err(TensorError::new(&format!(
            "expected {} hex chars, got {}",
            HASH_LEN * 2,
            hex.len()
        )));
    }
    let mut bytes = [0u8; HASH_LEN];
    for (slot, pair) in bytes.iter_mut().zip(digits.chunks_exact(2)) {
        // High nibble is parsed (and can fail) before the low nibble.
        *slot = (hex_nibble(pair[0])? << 4) | hex_nibble(pair[1])?;
    }
    Ok(bytes)
}
/// Value of a single ASCII hex digit (`0-9`, `a-f`, `A-F`).
///
/// # Errors
/// Returns an error for any other byte.
fn hex_nibble(b: u8) -> Result<u8> {
    char::from(b)
        .to_digit(16)
        .map(|d| d as u8)
        .ok_or_else(|| TensorError::new(&format!("invalid hex byte: {}", b)))
}
/// Lowercase hex encoding of `bytes`, two characters per byte.
fn bytes_to_hex(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut hex = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        hex.push(char::from(DIGITS[usize::from(byte >> 4)]));
        hex.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    hex
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::TensorOptions;

    /// Builds float32 parameters named `layer_{i}/weight` with the given 2-D shapes.
    fn make_named_params(sizes: &[(i64, i64)]) -> Vec<(String, Parameter)> {
        sizes.iter().enumerate().map(|(i, &(rows, cols))| {
            let t = Tensor::randn(&[rows, cols], TensorOptions {
                dtype: DType::Float32,
                device: crate::tensor::test_device(),
            }).unwrap();
            let name = format!("layer_{}/weight", i);
            (name, Parameter::new(t, "weight"))
        }).collect()
    }

    /// Builds float32 buffers named `bn_{i}/running_mean` with the given 1-D sizes.
    fn make_named_buffers(sizes: &[i64]) -> Vec<(String, Buffer)> {
        sizes.iter().enumerate().map(|(i, &features)| {
            let t = Tensor::randn(&[features], TensorOptions {
                dtype: DType::Float32,
                device: crate::tensor::test_device(),
            }).unwrap();
            let name = format!("bn_{}/running_mean", i);
            (name, Buffer::new(t, "running_mean"))
        }).collect()
    }

    #[test]
    fn test_named_roundtrip() {
        let params = make_named_params(&[(4, 8), (8, 2)]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &[], None).unwrap();
        let load_params = make_named_params(&[(4, 8), (8, 2)]);
        let mut cursor = std::io::Cursor::new(&buf);
        let report = load_checkpoint(&mut cursor, &load_params, &[], None).unwrap();
        assert_eq!(report.loaded.len(), 2);
        assert!(report.skipped.is_empty());
        assert!(report.missing.is_empty());
        for ((_, src), (_, dst)) in params.iter().zip(load_params.iter()) {
            let src_data = src.variable.data().to_f32_vec().unwrap();
            let dst_data = dst.variable.data().to_f32_vec().unwrap();
            assert_eq!(src_data, dst_data);
        }
    }

    #[test]
    fn test_buffer_roundtrip() {
        let params = make_named_params(&[(4, 8)]);
        let buffers = make_named_buffers(&[8]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &buffers, None).unwrap();
        let load_params = make_named_params(&[(4, 8)]);
        let load_buffers = make_named_buffers(&[8]);
        let mut cursor = std::io::Cursor::new(&buf);
        let report = load_checkpoint(&mut cursor, &load_params, &load_buffers, None).unwrap();
        assert_eq!(report.loaded.len(), 2);
        assert!(report.skipped.is_empty());
        assert!(report.missing.is_empty());
        let src_data = buffers[0].1.get().to_f32_vec().unwrap();
        let dst_data = load_buffers[0].1.get().to_f32_vec().unwrap();
        assert_eq!(src_data, dst_data);
    }

    #[test]
    fn test_named_partial_load() {
        let params_3 = make_named_params(&[(4, 8), (8, 4), (4, 2)]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params_3, &[], None).unwrap();
        let mut params_4 = make_named_params(&[(4, 8), (8, 4), (4, 2), (2, 1)]);
        params_4[3].0 = "extra/weight".to_string();
        let before_extra = params_4[3].1.variable.data().to_f32_vec().unwrap();
        let mut cursor = std::io::Cursor::new(&buf);
        let report = load_checkpoint(&mut cursor, &params_4, &[], None).unwrap();
        assert_eq!(report.loaded.len(), 3);
        assert_eq!(report.missing.len(), 1);
        assert_eq!(report.missing[0], "extra/weight");
        assert!(report.skipped.is_empty());
        let after_extra = params_4[3].1.variable.data().to_f32_vec().unwrap();
        assert_eq!(before_extra, after_extra);
    }

    #[test]
    fn test_named_skipped_checkpoint_params() {
        let params = make_named_params(&[(4, 8), (8, 2)]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &[], None).unwrap();
        let model = vec![params[0].clone()];
        let mut cursor = std::io::Cursor::new(&buf);
        let report = load_checkpoint(&mut cursor, &model, &[], None).unwrap();
        assert_eq!(report.loaded.len(), 1);
        assert_eq!(report.skipped.len(), 1);
        assert!(report.missing.is_empty());
    }

    #[test]
    fn test_named_shape_mismatch_error() {
        let params = make_named_params(&[(4, 8)]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &[], None).unwrap();
        let wrong_shape = vec![(
            "layer_0/weight".to_string(),
            Parameter::new(
                Tensor::randn(&[4, 4], TensorOptions {
                    dtype: DType::Float32,
                    device: crate::tensor::test_device(),
                }).unwrap(),
                "weight",
            ),
        )];
        let mut cursor = std::io::Cursor::new(&buf);
        let result = load_checkpoint(&mut cursor, &wrong_shape, &[], None);
        assert!(result.is_err(), "shape mismatch should be an error");
        let err_msg = format!("{}", result.unwrap_err());
        assert!(err_msg.contains("shape mismatch"), "error should mention shape: {}", err_msg);
    }

    #[test]
    fn test_buffer_shape_mismatch_error() {
        let buffers = make_named_buffers(&[8]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &[], &buffers, None).unwrap();
        let wrong_buffers = vec![(
            "bn_0/running_mean".to_string(),
            Buffer::new(
                Tensor::zeros(&[4], crate::tensor::test_opts()).unwrap(),
                "running_mean",
            ),
        )];
        let mut cursor = std::io::Cursor::new(&buf);
        let result = load_checkpoint(&mut cursor, &[], &wrong_buffers, None);
        assert!(result.is_err());
        assert!(format!("{}", result.unwrap_err()).contains("shape mismatch"));
    }

    #[test]
    fn test_compressed_roundtrip() {
        let params = make_named_params(&[(16, 32), (32, 8)]);
        let buffers = make_named_buffers(&[32]);
        let dir = std::env::temp_dir();
        // Per-process names so concurrent test runs do not clobber each other's files.
        let pid = std::process::id();
        let gz_path = dir.join(format!("test_ckpt_v2_{}.fdl.gz", pid));
        let plain_path = dir.join(format!("test_ckpt_v2_{}.fdl", pid));
        let gz = gz_path.to_str().unwrap();
        let plain = plain_path.to_str().unwrap();
        save_checkpoint_file(gz, &params, &buffers, None).unwrap();
        save_checkpoint_file(plain, &params, &buffers, None).unwrap();
        let gz_size = std::fs::metadata(gz).unwrap().len();
        let plain_size = std::fs::metadata(plain).unwrap().len();
        assert!(gz_size < plain_size, "gz={} should be < plain={}", gz_size, plain_size);
        let load_params = make_named_params(&[(16, 32), (32, 8)]);
        let load_buffers = make_named_buffers(&[32]);
        let report = load_checkpoint_file(gz, &load_params, &load_buffers, None).unwrap();
        assert_eq!(report.loaded.len(), 3);
        for ((_, src), (_, dst)) in params.iter().zip(load_params.iter()) {
            assert_eq!(src.variable.data().to_f32_vec().unwrap(),
                dst.variable.data().to_f32_vec().unwrap());
        }
        let src_buf = buffers[0].1.get().to_f32_vec().unwrap();
        let dst_buf = load_buffers[0].1.get().to_f32_vec().unwrap();
        assert_eq!(src_buf, dst_buf);
        std::fs::remove_file(gz).ok();
        std::fs::remove_file(plain).ok();
    }

    #[test]
    fn test_hash_roundtrip() {
        let params = make_named_params(&[(4, 8)]);
        let hash = "a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2";
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &[], Some(hash)).unwrap();
        let load_params = make_named_params(&[(4, 8)]);
        let mut cursor = std::io::Cursor::new(&buf);
        let report = load_checkpoint(&mut cursor, &load_params, &[], Some(hash)).unwrap();
        assert_eq!(report.loaded.len(), 1);
    }

    #[test]
    fn test_hash_mismatch_error() {
        let params = make_named_params(&[(4, 8)]);
        let hash_a = "a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2a3b4c5d6a7b8c9d0e1f2";
        let hash_b = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff";
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &[], Some(hash_a)).unwrap();
        let load_params = make_named_params(&[(4, 8)]);
        let mut cursor = std::io::Cursor::new(&buf);
        let result = load_checkpoint(&mut cursor, &load_params, &[], Some(hash_b));
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("architecture mismatch"), "error: {}", msg);
    }

    #[test]
    fn test_zero_hash_skips_validation() {
        let params = make_named_params(&[(4, 8)]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &[], None).unwrap();
        let hash = "ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff";
        let load_params = make_named_params(&[(4, 8)]);
        let mut cursor = std::io::Cursor::new(&buf);
        let report = load_checkpoint(&mut cursor, &load_params, &[], Some(hash)).unwrap();
        assert_eq!(report.loaded.len(), 1);
        let mut buf2 = Vec::new();
        save_checkpoint(&mut buf2, &params, &[], Some(hash)).unwrap();
        let load_params2 = make_named_params(&[(4, 8)]);
        let mut cursor2 = std::io::Cursor::new(&buf2);
        let report2 = load_checkpoint(&mut cursor2, &load_params2, &[], None).unwrap();
        assert_eq!(report2.loaded.len(), 1);
    }

    #[test]
    fn test_bad_magic_error() {
        let params = make_named_params(&[(2, 2)]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &[], None).unwrap();
        buf[0] = b'X';
        let load_params = make_named_params(&[(2, 2)]);
        let mut cursor = std::io::Cursor::new(&buf);
        let result = load_checkpoint(&mut cursor, &load_params, &[], None);
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("bad magic"), "error: {}", msg);
    }

    #[test]
    fn test_unsupported_version_error() {
        let params = make_named_params(&[(2, 2)]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &[], None).unwrap();
        // Version is the u32 immediately after the 4-byte magic.
        buf[4..8].copy_from_slice(&2u32.to_le_bytes());
        let load_params = make_named_params(&[(2, 2)]);
        let mut cursor = std::io::Cursor::new(&buf);
        let result = load_checkpoint(&mut cursor, &load_params, &[], None);
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("unsupported checkpoint version 2"), "error: {}", msg);
    }

    #[test]
    fn test_truncated_checkpoint_error() {
        let params = make_named_params(&[(4, 8)]);
        let mut buf = Vec::new();
        save_checkpoint(&mut buf, &params, &[], None).unwrap();
        buf.truncate(buf.len() - 1);
        let load_params = make_named_params(&[(4, 8)]);
        let mut cursor = std::io::Cursor::new(&buf);
        assert!(load_checkpoint(&mut cursor, &load_params, &[], None).is_err());
    }
}