use std::io::{self, Read, Write};
/// Four-byte file magic identifying the save-tensor format ("APRT").
pub const MAGIC: [u8; 4] = *b"APRT";
/// Fixed header length in bytes: 4-byte magic + 4-byte layer + 4-byte dim product.
pub const HEADER_SIZE: usize = 12;
/// Sentinel layer index marking a tensor that belongs to the whole model
/// rather than a single layer (see `TensorHeader::is_whole_model`).
pub const WHOLE_MODEL_LAYER: u32 = 0xFFFFFFFF;
/// Parsed save-tensor header: which layer the tensor belongs to and how many
/// f32 elements its body holds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TensorHeader {
/// Layer index, or `WHOLE_MODEL_LAYER` for a whole-model tensor.
pub layer: u32,
/// Number of f32 values in the body (product of the tensor's dimensions).
pub dim_product: u32,
}
impl TensorHeader {
/// Returns `true` when this header carries the `WHOLE_MODEL_LAYER` sentinel.
#[must_use]
pub fn is_whole_model(&self) -> bool {
self.layer == WHOLE_MODEL_LAYER
}
/// Total on-disk file size in bytes: the 12-byte header plus 4 bytes per
/// f32 value.
///
/// NOTE(review): `dim_product as usize * 4` assumes a 64-bit `usize`; on a
/// 32-bit target a near-`u32::MAX` dim_product would overflow — confirm
/// target assumptions.
#[must_use]
pub fn total_file_size(&self) -> usize {
HEADER_SIZE + (self.dim_product as usize) * 4
}
}
/// Validation failures for the fixed 12-byte save-tensor header
/// (see `parse_header`).
#[derive(Debug, thiserror::Error)]
pub enum HeaderError {
/// Fewer than `HEADER_SIZE` bytes were supplied.
#[error("save-tensor header truncated: got {got} bytes, need {HEADER_SIZE}")]
Truncated {
/// Number of bytes actually available.
got: usize,
},
/// The leading four bytes did not match `MAGIC`.
#[error("save-tensor magic mismatch: got {got:?}, expected {MAGIC:?}")]
BadMagic {
/// The four bytes found where the magic was expected.
got: [u8; 4],
},
}
/// Errors produced while reading a save-tensor file (see `read_tensor_file`).
#[derive(Debug, thiserror::Error)]
pub enum ReadError {
/// Underlying I/O failure; a truncated body surfaces here as
/// `io::ErrorKind::UnexpectedEof` from `read_exact`.
#[error("save-tensor I/O error: {0}")]
Io(#[from] io::Error),
/// The 12-byte header failed validation.
#[error("save-tensor header: {0}")]
Header(#[from] HeaderError),
/// Body held a different number of values than the header promised.
///
/// NOTE(review): the message says "bytes" but the construction site in
/// `read_tensor_file` passes element counts — confirm which unit is intended.
#[error("save-tensor body length mismatch: header says {expected} bytes, got {got}")]
BodyLengthMismatch {
/// Count promised by the header.
expected: usize,
/// Count actually decoded.
got: usize,
},
}
/// Serializes a save-tensor header: the `MAGIC` tag followed by `layer` and
/// `dim_product`, each encoded as a little-endian u32.
#[must_use]
pub fn write_header(layer: u32, dim_product: u32) -> [u8; HEADER_SIZE] {
    let mut out = [0u8; HEADER_SIZE];
    // Lay the three 4-byte fields down in order: magic, layer, dim_product.
    let fields = MAGIC
        .iter()
        .copied()
        .chain(layer.to_le_bytes())
        .chain(dim_product.to_le_bytes());
    for (slot, byte) in out.iter_mut().zip(fields) {
        *slot = byte;
    }
    out
}
/// Parses and validates a save-tensor header from the front of `bytes`.
///
/// Bytes beyond the 12-byte header are ignored, so a whole file (or any
/// prefix-containing buffer) may be passed directly.
///
/// # Errors
/// `HeaderError::Truncated` when fewer than `HEADER_SIZE` bytes are supplied;
/// `HeaderError::BadMagic` when the leading tag is not `MAGIC`.
pub fn parse_header(bytes: &[u8]) -> Result<TensorHeader, HeaderError> {
    if bytes.len() < HEADER_SIZE {
        return Err(HeaderError::Truncated { got: bytes.len() });
    }
    // Copy the 4-byte field starting at `offset` into a fixed array; every
    // offset used below is in-bounds thanks to the length check above.
    let word_at = |offset: usize| -> [u8; 4] {
        let mut word = [0u8; 4];
        word.copy_from_slice(&bytes[offset..offset + 4]);
        word
    };
    let magic = word_at(0);
    if magic != MAGIC {
        return Err(HeaderError::BadMagic { got: magic });
    }
    Ok(TensorHeader {
        layer: u32::from_le_bytes(word_at(4)),
        dim_product: u32::from_le_bytes(word_at(8)),
    })
}
/// Writes a complete save-tensor file (12-byte header followed by each value
/// as little-endian f32) to `w`.
///
/// # Errors
/// Returns `io::ErrorKind::InvalidInput` when `values.len()` does not fit in a
/// `u32` (the header stores the element count as a u32), or any error from the
/// underlying writer.
pub fn write_tensor_file<W: Write>(w: &mut W, layer: u32, values: &[f32]) -> io::Result<()> {
    // The function already returns io::Result, so an oversized input is
    // reported as a recoverable error rather than a panic.
    let dim_product = u32::try_from(values.len()).map_err(|_| {
        io::Error::new(
            io::ErrorKind::InvalidInput,
            "save-tensor: values.len() exceeds u32::MAX",
        )
    })?;
    w.write_all(&write_header(layer, dim_product))?;
    // Serialize the body into one buffer so unbuffered writers see a single
    // write instead of one 4-byte write (potentially one syscall) per element.
    let mut body = Vec::with_capacity(values.len() * 4);
    for &v in values {
        body.extend_from_slice(&v.to_le_bytes());
    }
    w.write_all(&body)?;
    Ok(())
}
/// Reads a complete save-tensor file (header + little-endian f32 body) from `r`.
///
/// # Errors
/// - `ReadError::Io` on any read failure, including a truncated body
///   (`read_exact` reports `UnexpectedEof`).
/// - `ReadError::Header` when the 12-byte header fails validation.
pub fn read_tensor_file<R: Read>(r: &mut R) -> Result<(TensorHeader, Vec<f32>), ReadError> {
    let mut header_bytes = [0u8; HEADER_SIZE];
    r.read_exact(&mut header_bytes)?;
    let header = parse_header(&header_bytes)?;
    let n = header.dim_product as usize;
    // NOTE(review): `n * 4` assumes a 64-bit `usize`; on a 32-bit target a
    // header claiming near-u32::MAX elements would overflow this multiply.
    let mut body = vec![0u8; n * 4];
    // A short body fails here with Io(UnexpectedEof). On success `body` is
    // exactly n * 4 bytes, so `chunks_exact(4)` yields exactly n values — the
    // previous post-hoc length check (BodyLengthMismatch) was unreachable and
    // has been removed.
    r.read_exact(&mut body)?;
    let values: Vec<f32> = body
        .chunks_exact(4)
        .map(|chunk| {
            let mut word = [0u8; 4];
            word.copy_from_slice(chunk);
            f32::from_le_bytes(word)
        })
        .collect();
    Ok((header, values))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn magic_is_aprt() {
        assert_eq!(&MAGIC, b"APRT");
    }

    #[test]
    fn header_size_is_twelve() {
        assert_eq!(HEADER_SIZE, 12);
    }

    #[test]
    fn falsify_apr_trace_save_004_header_format_layer_zero() {
        let header = write_header(0, 7);
        assert_eq!(&header[0..4], b"APRT");
        assert_eq!(&header[4..8], &0u32.to_le_bytes());
        assert_eq!(&header[8..12], &7u32.to_le_bytes());
    }

    #[test]
    fn falsify_apr_trace_save_004_header_format_arbitrary_layer() {
        let header = write_header(3, 3584);
        assert_eq!(&header[0..4], b"APRT");
        // Compare the raw little-endian words directly against the encodings.
        assert_eq!(header[4..8], 3u32.to_le_bytes());
        assert_eq!(header[8..12], 3584u32.to_le_bytes());
    }

    #[test]
    fn header_roundtrip() {
        let encoded = write_header(42, 1024);
        let parsed = parse_header(&encoded).expect("parse must succeed");
        assert_eq!(
            parsed,
            TensorHeader {
                layer: 42,
                dim_product: 1024,
            }
        );
    }

    #[test]
    fn header_roundtrip_whole_model() {
        let encoded = write_header(WHOLE_MODEL_LAYER, 151_936);
        let parsed = parse_header(&encoded).expect("parse must succeed");
        assert_eq!(
            parsed,
            TensorHeader {
                layer: WHOLE_MODEL_LAYER,
                dim_product: 151_936,
            }
        );
        assert!(parsed.is_whole_model());
    }

    #[test]
    fn parse_header_rejects_short_input() {
        // One byte shy of a full header.
        let short = [0u8; 11];
        let outcome = parse_header(&short);
        assert!(matches!(outcome, Err(HeaderError::Truncated { got: 11 })));
    }

    #[test]
    fn parse_header_rejects_bad_magic() {
        let mut raw = [0u8; 12];
        raw[..4].copy_from_slice(b"GGUF");
        let outcome = parse_header(&raw);
        assert!(matches!(outcome, Err(HeaderError::BadMagic { got: g }) if &g == b"GGUF"));
    }

    #[test]
    fn parse_header_ignores_trailing_body() {
        let mut data = write_header(0, 2).to_vec();
        for v in [1.0_f32, 2.0_f32] {
            data.extend_from_slice(&v.to_le_bytes());
        }
        let parsed = parse_header(&data).expect("parse must succeed");
        assert_eq!(parsed.layer, 0);
        assert_eq!(parsed.dim_product, 2);
    }

    #[test]
    fn total_file_size_per_layer() {
        let header = TensorHeader {
            layer: 0,
            dim_product: 100,
        };
        assert_eq!(header.total_file_size(), 12 + 400);
    }

    #[test]
    fn total_file_size_empty() {
        let header = TensorHeader {
            layer: 0,
            dim_product: 0,
        };
        assert_eq!(header.total_file_size(), 12);
    }

    #[test]
    fn write_and_read_tensor_file_roundtrip() {
        let original = [1.0_f32, 2.0, 3.0, -4.0, 0.0, 5.5];
        let mut encoded = Vec::new();
        write_tensor_file(&mut encoded, 5, &original).expect("write must succeed");
        assert_eq!(encoded.len(), HEADER_SIZE + original.len() * 4);
        let (header, decoded) =
            read_tensor_file(&mut std::io::Cursor::new(&encoded)).expect("read must succeed");
        assert_eq!(header.layer, 5);
        assert_eq!(header.dim_product as usize, original.len());
        assert_eq!(decoded, original);
    }

    #[test]
    fn write_tensor_file_empty() {
        let mut encoded = Vec::new();
        write_tensor_file(&mut encoded, 0, &[]).expect("write must succeed");
        assert_eq!(encoded.len(), HEADER_SIZE);
        assert_eq!(&encoded[0..4], b"APRT");
    }

    #[test]
    fn write_preserves_nan_verbatim() {
        let specials = [f32::NAN, 1.0, f32::INFINITY, f32::NEG_INFINITY];
        let mut encoded = Vec::new();
        write_tensor_file(&mut encoded, 0, &specials).expect("write must succeed");
        let (_, decoded) =
            read_tensor_file(&mut std::io::Cursor::new(&encoded)).expect("read must succeed");
        assert!(decoded[0].is_nan(), "NaN must be preserved");
        assert_eq!(decoded[1], 1.0);
        assert!(decoded[2].is_infinite() && decoded[2].is_sign_positive());
        assert!(decoded[3].is_infinite() && decoded[3].is_sign_negative());
    }

    #[test]
    fn read_tensor_file_truncated_body() {
        // Header promises 3 values but only 2 follow.
        let mut data = write_header(0, 3).to_vec();
        data.extend_from_slice(&1.0_f32.to_le_bytes());
        data.extend_from_slice(&2.0_f32.to_le_bytes());
        let outcome = read_tensor_file(&mut std::io::Cursor::new(&data));
        assert!(outcome.is_err(), "must error on truncated body");
    }

    #[test]
    fn write_layer_index_max_u32_minus_one() {
        // Largest layer index that is NOT the whole-model sentinel.
        let edge = u32::MAX - 1;
        let parsed = parse_header(&write_header(edge, 1)).expect("parse must succeed");
        assert_eq!(parsed.layer, edge);
        assert!(!parsed.is_whole_model());
    }
}