Skip to main content

cjc_runtime/
tensor_snap.rs

1//! Deterministic binary serialization for tensors and tensor lists.
2//!
3//! This module provides a small, self-contained, byte-stable format for
4//! saving and loading f64 tensors. It lives in `cjc-runtime` (not
5//! `cjc-snap`) because `cjc-snap` already depends on `cjc-runtime`, so
6//! adding the reverse dependency would create a cycle.
7//!
8//! # Wire format
9//!
10//! All integers are little-endian, all floats are IEEE-754 f64 little-endian.
11//!
12//! ```text
13//! offset  bytes  field
14//! ------  -----  ------------------------------------------------------
15//!      0      4  magic = b"CJCT"
16//!      4      1  format version = 1
17//!      5      3  reserved = 0
18//!      8      8  n_tensors : u64
19//!     16      -  tensor[0], tensor[1], ...
20//!
21//! tensor layout:
22//!      0      8  ndim : u64
23//!      8  ndim*8  shape[0..ndim] : u64[]
24//!      *  numel*8  data : f64[]      (row-major, contiguous)
25//!
26//! after the last tensor:
27//!      0      8  footer_hash : u64  (SplitMix64 fold of all preceding bytes)
28//! ```
29//!
30//! The footer hash lets a reader cheaply detect corruption and lets tests
31//! assert byte-identity across executors without parsing the payload.
32//!
33//! Determinism properties:
34//! - Tensors are materialized with `to_vec()` so stride/offset views are
35//!   flattened to a canonical row-major order before serialization.
36//! - NaN bit patterns are preserved as-written (we do not canonicalize);
37//!   CJC-Lang produces NaN only through explicit `0.0 / 0.0`-style paths,
38//!   and the determinism contract never introduces spurious NaNs.
39//! - The footer hash is a SplitMix64 fold, which is order-sensitive and
40//!   deterministic across platforms (pure integer arithmetic, no FP).
41
42use crate::tensor::Tensor;
43
/// Four-byte signature that opens every tensor snap payload (ASCII "CJCT").
const MAGIC: &[u8; 4] = &[b'C', b'J', b'C', b'T'];
/// Wire-format revision stored in the byte that follows the magic.
const FORMAT_VERSION: u8 = 1;
/// Size in bytes of the fixed header that precedes the first tensor:
/// magic (4) + version (1) + reserved (3) + tensor count (8).
const HEADER_LEN: usize = 4 + 1 + 3 + 8;
47
/// Errors returned by the tensor snap codec.
///
/// Implements [`std::fmt::Display`] for human-readable messages and
/// [`std::error::Error`] so it can be boxed or propagated with `?` into
/// generic error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorSnapError {
    /// The input ends before a field the format requires (header, footer,
    /// shape words, or tensor data).
    TooShort,
    /// The first four bytes are not the `CJCT` signature.
    BadMagic,
    /// The version byte is not one this codec understands; carries the
    /// version that was found.
    BadVersion(u8),
    /// Bytes were left over after all declared tensors had been decoded.
    TrailingGarbage,
    /// The shape header is unusable: rank above the decoder's cap, element
    /// count overflow, a shape the tensor constructor rejects, or a payload
    /// that does not hold exactly one tensor when one was requested.
    BadShape,
    /// The stored footer hash (`expected`) differs from the hash recomputed
    /// over the payload (`actual`).
    BadHash { expected: u64, actual: u64 },
}

impl std::fmt::Display for TensorSnapError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::TooShort => write!(f, "tensor snap: input too short"),
            Self::BadMagic => write!(f, "tensor snap: bad magic (expected CJCT)"),
            Self::BadVersion(v) => write!(f, "tensor snap: unsupported version {v}"),
            Self::TrailingGarbage => write!(f, "tensor snap: trailing garbage after footer"),
            Self::BadShape => write!(f, "tensor snap: corrupt shape header"),
            Self::BadHash { expected, actual } => {
                write!(f, "tensor snap: hash mismatch (expected {expected:#x}, got {actual:#x})")
            }
        }
    }
}

// No underlying cause to expose, so the default `source()` (None) is correct.
impl std::error::Error for TensorSnapError {}
73
74/// SplitMix64 folding hash. Pure integer arithmetic, deterministic across
75/// platforms. Not cryptographic — used only for integrity/parity checks.
76#[inline]
77fn splitmix64_fold(bytes: &[u8]) -> u64 {
78    let mut state: u64 = 0x9e37_79b9_7f4a_7c15;
79    // Mix length first so `[]` hashes differently from `[0x00 * 0]`.
80    state ^= bytes.len() as u64;
81    state = mix64(state);
82
83    // Fold 8 bytes at a time.
84    let mut i = 0;
85    while i + 8 <= bytes.len() {
86        let mut chunk = [0u8; 8];
87        chunk.copy_from_slice(&bytes[i..i + 8]);
88        state ^= u64::from_le_bytes(chunk);
89        state = mix64(state);
90        i += 8;
91    }
92    // Fold the trailing tail (0..7 bytes).
93    if i < bytes.len() {
94        let mut chunk = [0u8; 8];
95        for (j, b) in bytes[i..].iter().enumerate() {
96            chunk[j] = *b;
97        }
98        state ^= u64::from_le_bytes(chunk);
99        state = mix64(state);
100    }
101    state
102}
103
/// One SplitMix64 step: add the golden-ratio increment to `z`, then run the
/// two xor-shift/multiply rounds and the final xor-shift. All arithmetic
/// wraps, so the function is total over `u64`.
#[inline]
fn mix64(z: u64) -> u64 {
    const GAMMA: u64 = 0x9e37_79b9_7f4a_7c15;
    const MUL_A: u64 = 0xbf58_476d_1ce4_e5b9;
    const MUL_B: u64 = 0x94d0_49bb_1331_11eb;

    let stepped = z.wrapping_add(GAMMA);
    let round_a = (stepped ^ (stepped >> 30)).wrapping_mul(MUL_A);
    let round_b = (round_a ^ (round_a >> 27)).wrapping_mul(MUL_B);
    round_b ^ (round_b >> 31)
}
111
112/// Encode a list of tensors into the wire format. Infallible for any
113/// well-formed `Tensor` (materializes views via `to_vec`).
114pub fn encode_list(tensors: &[Tensor]) -> Vec<u8> {
115    // Estimate capacity to avoid reallocation.
116    let mut cap = HEADER_LEN + 8; // + footer
117    for t in tensors {
118        cap += 8 + 8 * t.ndim() + 8 * t.shape().iter().product::<usize>();
119    }
120    let mut buf = Vec::with_capacity(cap);
121
122    // Header
123    buf.extend_from_slice(MAGIC);
124    buf.push(FORMAT_VERSION);
125    buf.extend_from_slice(&[0u8; 3]); // reserved
126    buf.extend_from_slice(&(tensors.len() as u64).to_le_bytes());
127
128    // Tensors
129    for t in tensors {
130        let shape = t.shape();
131        buf.extend_from_slice(&(shape.len() as u64).to_le_bytes());
132        for &d in shape {
133            buf.extend_from_slice(&(d as u64).to_le_bytes());
134        }
135        let data = t.to_vec();
136        for v in &data {
137            buf.extend_from_slice(&v.to_le_bytes());
138        }
139    }
140
141    // Footer hash over everything so far.
142    let hash = splitmix64_fold(&buf);
143    buf.extend_from_slice(&hash.to_le_bytes());
144    buf
145}
146
147/// Encode a single tensor (convenience wrapper).
148pub fn encode_one(tensor: &Tensor) -> Vec<u8> {
149    encode_list(std::slice::from_ref(tensor))
150}
151
152/// Decode a list of tensors from the wire format. Verifies magic, version,
153/// and footer hash.
154pub fn decode_list(bytes: &[u8]) -> Result<Vec<Tensor>, TensorSnapError> {
155    if bytes.len() < HEADER_LEN + 8 {
156        return Err(TensorSnapError::TooShort);
157    }
158    if &bytes[0..4] != MAGIC {
159        return Err(TensorSnapError::BadMagic);
160    }
161    let version = bytes[4];
162    if version != FORMAT_VERSION {
163        return Err(TensorSnapError::BadVersion(version));
164    }
165
166    // Verify footer hash over everything except the last 8 bytes.
167    let footer_start = bytes.len() - 8;
168    let expected_hash = u64::from_le_bytes([
169        bytes[footer_start],
170        bytes[footer_start + 1],
171        bytes[footer_start + 2],
172        bytes[footer_start + 3],
173        bytes[footer_start + 4],
174        bytes[footer_start + 5],
175        bytes[footer_start + 6],
176        bytes[footer_start + 7],
177    ]);
178    let actual_hash = splitmix64_fold(&bytes[..footer_start]);
179    if expected_hash != actual_hash {
180        return Err(TensorSnapError::BadHash {
181            expected: expected_hash,
182            actual: actual_hash,
183        });
184    }
185
186    // n_tensors
187    let n_tensors = read_u64(bytes, 8)? as usize;
188    let mut cursor = HEADER_LEN;
189    let mut out = Vec::with_capacity(n_tensors);
190
191    for _ in 0..n_tensors {
192        if cursor + 8 > footer_start {
193            return Err(TensorSnapError::TooShort);
194        }
195        let ndim = read_u64(bytes, cursor)? as usize;
196        cursor += 8;
197
198        // Cap ndim to a sane value to avoid pathological allocation.
199        if ndim > 16 {
200            return Err(TensorSnapError::BadShape);
201        }
202        if cursor + 8 * ndim > footer_start {
203            return Err(TensorSnapError::TooShort);
204        }
205
206        let mut shape = Vec::with_capacity(ndim);
207        for _ in 0..ndim {
208            let d = read_u64(bytes, cursor)? as usize;
209            shape.push(d);
210            cursor += 8;
211        }
212
213        // Guard against shape-overflow DoS.
214        let numel = shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
215            .ok_or(TensorSnapError::BadShape)?;
216
217        if cursor + 8 * numel > footer_start {
218            return Err(TensorSnapError::TooShort);
219        }
220
221        let mut data = Vec::with_capacity(numel);
222        for _ in 0..numel {
223            let mut chunk = [0u8; 8];
224            chunk.copy_from_slice(&bytes[cursor..cursor + 8]);
225            data.push(f64::from_le_bytes(chunk));
226            cursor += 8;
227        }
228
229        let t = Tensor::from_vec(data, &shape).map_err(|_| TensorSnapError::BadShape)?;
230        out.push(t);
231    }
232
233    if cursor != footer_start {
234        return Err(TensorSnapError::TrailingGarbage);
235    }
236    Ok(out)
237}
238
239/// Decode a single tensor from the wire format. Errors if the payload
240/// contains zero or more than one tensor.
241pub fn decode_one(bytes: &[u8]) -> Result<Tensor, TensorSnapError> {
242    let list = decode_list(bytes)?;
243    if list.len() != 1 {
244        return Err(TensorSnapError::BadShape);
245    }
246    Ok(list.into_iter().next().unwrap())
247}
248
249/// Deterministic content hash of a tensor list. Separate from the wire
250/// format's footer hash — this one hashes *logical* content (shape + data)
251/// only, so it is invariant under re-encoding.
252pub fn hash_list(tensors: &[Tensor]) -> u64 {
253    let mut state: u64 = 0x243F_6A88_85A3_08D3; // arbitrary constant
254    state ^= tensors.len() as u64;
255    state = mix64(state);
256    for t in tensors {
257        let shape = t.shape();
258        state ^= shape.len() as u64;
259        state = mix64(state);
260        for &d in shape {
261            state ^= d as u64;
262            state = mix64(state);
263        }
264        let data = t.to_vec();
265        // Fold the f64 bits (NaN-bit-preserving — determinism relies on the
266        // caller not producing garbage NaN bit patterns).
267        for v in &data {
268            state ^= v.to_bits();
269            state = mix64(state);
270        }
271    }
272    state
273}
274
275fn read_u64(bytes: &[u8], offset: usize) -> Result<u64, TensorSnapError> {
276    if offset + 8 > bytes.len() {
277        return Err(TensorSnapError::TooShort);
278    }
279    let mut chunk = [0u8; 8];
280    chunk.copy_from_slice(&bytes[offset..offset + 8]);
281    Ok(u64::from_le_bytes(chunk))
282}
283
284// =============================================================================
285// Unit tests
286// =============================================================================
287
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tensor fixture; panics if `Tensor::from_vec` rejects it.
    fn t(data: Vec<f64>, shape: &[usize]) -> Tensor {
        Tensor::from_vec(data, shape).unwrap()
    }

    /// Starts a hand-built payload: magic, version, reserved bytes and the
    /// declared tensor count.
    fn raw_header(n_tensors: u64) -> Vec<u8> {
        let mut payload = Vec::new();
        payload.extend_from_slice(MAGIC);
        payload.push(FORMAT_VERSION);
        payload.extend_from_slice(&[0u8; 3]);
        payload.extend_from_slice(&n_tensors.to_le_bytes());
        payload
    }

    /// Appends one little-endian 8-byte word to a hand-built payload.
    fn push_word(payload: &mut Vec<u8>, word: u64) {
        payload.extend_from_slice(&word.to_le_bytes());
    }

    /// Appends the footer hash so the payload gets past the hash check.
    fn seal(mut payload: Vec<u8>) -> Vec<u8> {
        let digest = splitmix64_fold(&payload);
        payload.extend_from_slice(&digest.to_le_bytes());
        payload
    }

    // ---- round trips -------------------------------------------------------

    #[test]
    fn empty_list_roundtrips() {
        let decoded = decode_list(&encode_list(&[])).unwrap();
        assert_eq!(decoded.len(), 0);
    }

    #[test]
    fn scalar_tensor_roundtrips() {
        let original = t(vec![42.0], &[1]);
        let restored = decode_one(&encode_one(&original)).unwrap();
        assert_eq!(restored.shape(), &[1]);
        assert_eq!(restored.to_vec(), vec![42.0]);
    }

    #[test]
    fn matrix_roundtrips() {
        let original = t(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let restored = decode_one(&encode_one(&original)).unwrap();
        assert_eq!(restored.shape(), &[2, 3]);
        assert_eq!(restored.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn multiple_tensors_roundtrip() {
        let inputs = [
            t(vec![1.0, 2.0], &[2]),
            t(vec![3.0, 4.0, 5.0, 6.0], &[2, 2]),
            t(vec![7.0], &[1, 1]),
        ];
        let decoded = decode_list(&encode_list(&inputs)).unwrap();
        assert_eq!(decoded.len(), 3);
        for (got, want) in decoded.iter().zip(inputs.iter()) {
            assert_eq!(got.to_vec(), want.to_vec());
        }
    }

    // ---- encoding properties -----------------------------------------------

    #[test]
    fn encoding_is_deterministic() {
        let tensor = t(vec![1.5, -2.5, 3.25], &[3]);
        let first = encode_one(&tensor);
        let second = encode_one(&tensor);
        assert_eq!(first, second, "encoding must be byte-identical for the same input");
    }

    #[test]
    fn different_tensors_produce_different_encodings() {
        let short = encode_one(&t(vec![1.0, 2.0], &[2]));
        let long = encode_one(&t(vec![1.0, 2.0, 3.0], &[3]));
        assert_ne!(short, long);
    }

    #[test]
    fn different_shapes_same_data_produce_different_encodings() {
        let flat = encode_one(&t(vec![1.0, 2.0, 3.0, 4.0], &[4]));
        let square = encode_one(&t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]));
        assert_ne!(flat, square);
    }

    // ---- rejection of damaged input ----------------------------------------

    #[test]
    fn bad_magic_is_rejected() {
        let mut bytes = encode_one(&t(vec![1.0], &[1]));
        bytes[0] = b'X';
        assert_eq!(decode_list(&bytes).err(), Some(TensorSnapError::BadMagic));
    }

    #[test]
    fn bad_version_is_rejected() {
        let mut bytes = encode_one(&t(vec![1.0], &[1]));
        bytes[4] = 99;
        assert_eq!(decode_list(&bytes).err(), Some(TensorSnapError::BadVersion(99)));
    }

    #[test]
    fn hash_mismatch_is_rejected() {
        let mut bytes = encode_one(&t(vec![1.0, 2.0, 3.0], &[3]));
        // First data byte sits after the header, the rank word and shape[0].
        let first_data_byte = HEADER_LEN + 8 + 8;
        bytes[first_data_byte] ^= 0xFF;
        assert!(matches!(decode_list(&bytes), Err(TensorSnapError::BadHash { .. })));
    }

    #[test]
    fn too_short_is_rejected() {
        assert_eq!(decode_list(&[]).err(), Some(TensorSnapError::TooShort));
        assert_eq!(decode_list(&[0u8; 10]).err(), Some(TensorSnapError::TooShort));
    }

    // ---- logical content hash ----------------------------------------------

    #[test]
    fn hash_list_is_deterministic() {
        let pair = [t(vec![1.0, 2.0, 3.0], &[3]), t(vec![4.0, 5.0], &[2])];
        assert_eq!(hash_list(&pair), hash_list(&pair.clone()));
    }

    #[test]
    fn hash_list_is_order_sensitive() {
        let a = t(vec![1.0, 2.0, 3.0], &[3]);
        let b = t(vec![4.0, 5.0], &[2]);
        let forward = hash_list(&[a.clone(), b.clone()]);
        let reversed = hash_list(&[b, a]);
        assert_ne!(forward, reversed, "hash must change when order changes");
    }

    #[test]
    fn hash_list_distinguishes_shapes_and_data() {
        let flat = hash_list(&[t(vec![1.0, 2.0, 3.0, 4.0], &[4])]);
        let square = hash_list(&[t(vec![1.0, 2.0, 3.0, 4.0], &[2, 2])]);
        let tweaked = hash_list(&[t(vec![1.0, 2.0, 3.0, 5.0], &[4])]);
        assert_ne!(flat, square);
        assert_ne!(flat, tweaked);
    }

    // ---- hostile shape headers ---------------------------------------------

    #[test]
    fn pathological_ndim_is_rejected() {
        // One declared tensor whose rank word claims 1000 dimensions; the
        // decoder must refuse it before attempting to allocate.
        let mut payload = raw_header(1);
        push_word(&mut payload, 1000);
        let bytes = seal(payload);
        assert_eq!(decode_list(&bytes).err(), Some(TensorSnapError::BadShape));
    }

    #[test]
    fn shape_overflow_is_rejected() {
        // rank 2 with shape [u64::MAX, 2]: the element count must overflow.
        let mut payload = raw_header(1);
        push_word(&mut payload, 2);
        push_word(&mut payload, u64::MAX);
        push_word(&mut payload, 2);
        let bytes = seal(payload);
        assert_eq!(decode_list(&bytes).err(), Some(TensorSnapError::BadShape));
    }
}
445}