1use crate::tensor::Tensor;
43
/// Four-byte tag that opens every encoded snapshot.
const MAGIC: &[u8; 4] = b"CJCT";

/// Format revision written right after the magic; decoding rejects any other value.
const FORMAT_VERSION: u8 = 1;

/// Size in bytes of the fixed header: magic (4) + version (1) + zero padding (3)
/// + tensor count as a little-endian `u64` (8).
const HEADER_LEN: usize = 16;

/// Reasons a byte buffer can be rejected while decoding a tensor snapshot.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TensorSnapError {
    /// The buffer is smaller than header + footer, or a declared field or
    /// payload extends past the bytes that are available.
    TooShort,
    /// The first four bytes are not the `CJCT` magic.
    BadMagic,
    /// The version byte is not the supported format version; carries the byte found.
    BadVersion(u8),
    /// Bytes are left over once the declared number of tensors has been read.
    TrailingGarbage,
    /// The shape information is unusable: too many dimensions, an element
    /// count that overflows, a shape the tensor constructor rejects, or a
    /// tensor count other than one where exactly one tensor was requested.
    BadShape,
    /// The stored footer hash (`expected`) differs from the hash recomputed
    /// over the received bytes (`actual`).
    BadHash { expected: u64, actual: u64 },
}
58
59impl std::fmt::Display for TensorSnapError {
60 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
61 match self {
62 Self::TooShort => write!(f, "tensor snap: input too short"),
63 Self::BadMagic => write!(f, "tensor snap: bad magic (expected CJCT)"),
64 Self::BadVersion(v) => write!(f, "tensor snap: unsupported version {v}"),
65 Self::TrailingGarbage => write!(f, "tensor snap: trailing garbage after footer"),
66 Self::BadShape => write!(f, "tensor snap: corrupt shape header"),
67 Self::BadHash { expected, actual } => {
68 write!(f, "tensor snap: hash mismatch (expected {expected:#x}, got {actual:#x})")
69 }
70 }
71 }
72}
73
74#[inline]
77fn splitmix64_fold(bytes: &[u8]) -> u64 {
78 let mut state: u64 = 0x9e37_79b9_7f4a_7c15;
79 state ^= bytes.len() as u64;
81 state = mix64(state);
82
83 let mut i = 0;
85 while i + 8 <= bytes.len() {
86 let mut chunk = [0u8; 8];
87 chunk.copy_from_slice(&bytes[i..i + 8]);
88 state ^= u64::from_le_bytes(chunk);
89 state = mix64(state);
90 i += 8;
91 }
92 if i < bytes.len() {
94 let mut chunk = [0u8; 8];
95 for (j, b) in bytes[i..].iter().enumerate() {
96 chunk[j] = *b;
97 }
98 state ^= u64::from_le_bytes(chunk);
99 state = mix64(state);
100 }
101 state
102}
103
/// One SplitMix64 step: add the golden-ratio increment, then apply the
/// xor-shift / multiply finalizer. All arithmetic wraps modulo 2^64.
#[inline]
fn mix64(z: u64) -> u64 {
    const INCREMENT: u64 = 0x9e37_79b9_7f4a_7c15;
    const MUL_A: u64 = 0xbf58_476d_1ce4_e5b9;
    const MUL_B: u64 = 0x94d0_49bb_1331_11eb;

    let z = z.wrapping_add(INCREMENT);
    let z = (z ^ (z >> 30)).wrapping_mul(MUL_A);
    let z = (z ^ (z >> 27)).wrapping_mul(MUL_B);
    z ^ (z >> 31)
}
111
112pub fn encode_list(tensors: &[Tensor]) -> Vec<u8> {
115 let mut cap = HEADER_LEN + 8; for t in tensors {
118 cap += 8 + 8 * t.ndim() + 8 * t.shape().iter().product::<usize>();
119 }
120 let mut buf = Vec::with_capacity(cap);
121
122 buf.extend_from_slice(MAGIC);
124 buf.push(FORMAT_VERSION);
125 buf.extend_from_slice(&[0u8; 3]); buf.extend_from_slice(&(tensors.len() as u64).to_le_bytes());
127
128 for t in tensors {
130 let shape = t.shape();
131 buf.extend_from_slice(&(shape.len() as u64).to_le_bytes());
132 for &d in shape {
133 buf.extend_from_slice(&(d as u64).to_le_bytes());
134 }
135 let data = t.to_vec();
136 for v in &data {
137 buf.extend_from_slice(&v.to_le_bytes());
138 }
139 }
140
141 let hash = splitmix64_fold(&buf);
143 buf.extend_from_slice(&hash.to_le_bytes());
144 buf
145}
146
147pub fn encode_one(tensor: &Tensor) -> Vec<u8> {
149 encode_list(std::slice::from_ref(tensor))
150}
151
152pub fn decode_list(bytes: &[u8]) -> Result<Vec<Tensor>, TensorSnapError> {
155 if bytes.len() < HEADER_LEN + 8 {
156 return Err(TensorSnapError::TooShort);
157 }
158 if &bytes[0..4] != MAGIC {
159 return Err(TensorSnapError::BadMagic);
160 }
161 let version = bytes[4];
162 if version != FORMAT_VERSION {
163 return Err(TensorSnapError::BadVersion(version));
164 }
165
166 let footer_start = bytes.len() - 8;
168 let expected_hash = u64::from_le_bytes([
169 bytes[footer_start],
170 bytes[footer_start + 1],
171 bytes[footer_start + 2],
172 bytes[footer_start + 3],
173 bytes[footer_start + 4],
174 bytes[footer_start + 5],
175 bytes[footer_start + 6],
176 bytes[footer_start + 7],
177 ]);
178 let actual_hash = splitmix64_fold(&bytes[..footer_start]);
179 if expected_hash != actual_hash {
180 return Err(TensorSnapError::BadHash {
181 expected: expected_hash,
182 actual: actual_hash,
183 });
184 }
185
186 let n_tensors = read_u64(bytes, 8)? as usize;
188 let mut cursor = HEADER_LEN;
189 let mut out = Vec::with_capacity(n_tensors);
190
191 for _ in 0..n_tensors {
192 if cursor + 8 > footer_start {
193 return Err(TensorSnapError::TooShort);
194 }
195 let ndim = read_u64(bytes, cursor)? as usize;
196 cursor += 8;
197
198 if ndim > 16 {
200 return Err(TensorSnapError::BadShape);
201 }
202 if cursor + 8 * ndim > footer_start {
203 return Err(TensorSnapError::TooShort);
204 }
205
206 let mut shape = Vec::with_capacity(ndim);
207 for _ in 0..ndim {
208 let d = read_u64(bytes, cursor)? as usize;
209 shape.push(d);
210 cursor += 8;
211 }
212
213 let numel = shape.iter().try_fold(1usize, |acc, &d| acc.checked_mul(d))
215 .ok_or(TensorSnapError::BadShape)?;
216
217 if cursor + 8 * numel > footer_start {
218 return Err(TensorSnapError::TooShort);
219 }
220
221 let mut data = Vec::with_capacity(numel);
222 for _ in 0..numel {
223 let mut chunk = [0u8; 8];
224 chunk.copy_from_slice(&bytes[cursor..cursor + 8]);
225 data.push(f64::from_le_bytes(chunk));
226 cursor += 8;
227 }
228
229 let t = Tensor::from_vec(data, &shape).map_err(|_| TensorSnapError::BadShape)?;
230 out.push(t);
231 }
232
233 if cursor != footer_start {
234 return Err(TensorSnapError::TrailingGarbage);
235 }
236 Ok(out)
237}
238
239pub fn decode_one(bytes: &[u8]) -> Result<Tensor, TensorSnapError> {
242 let list = decode_list(bytes)?;
243 if list.len() != 1 {
244 return Err(TensorSnapError::BadShape);
245 }
246 Ok(list.into_iter().next().unwrap())
247}
248
249pub fn hash_list(tensors: &[Tensor]) -> u64 {
253 let mut state: u64 = 0x243F_6A88_85A3_08D3; state ^= tensors.len() as u64;
255 state = mix64(state);
256 for t in tensors {
257 let shape = t.shape();
258 state ^= shape.len() as u64;
259 state = mix64(state);
260 for &d in shape {
261 state ^= d as u64;
262 state = mix64(state);
263 }
264 let data = t.to_vec();
265 for v in &data {
268 state ^= v.to_bits();
269 state = mix64(state);
270 }
271 }
272 state
273}
274
275fn read_u64(bytes: &[u8], offset: usize) -> Result<u64, TensorSnapError> {
276 if offset + 8 > bytes.len() {
277 return Err(TensorSnapError::TooShort);
278 }
279 let mut chunk = [0u8; 8];
280 chunk.copy_from_slice(&bytes[offset..offset + 8]);
281 Ok(u64::from_le_bytes(chunk))
282}
283
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tensor from values and a shape, panicking if they disagree.
    fn tensor(values: Vec<f64>, dims: &[usize]) -> Tensor {
        Tensor::from_vec(values, dims).unwrap()
    }

    /// Hand-assembles a frame: a valid header declaring `count` tensors, the
    /// given raw little-endian words as the body, and a matching footer hash.
    /// This lets tests feed the decoder well-hashed but malformed records.
    fn sealed_frame(count: u64, words: &[u64]) -> Vec<u8> {
        let mut raw = Vec::new();
        raw.extend_from_slice(MAGIC);
        raw.push(FORMAT_VERSION);
        raw.extend_from_slice(&[0u8; 3]);
        raw.extend_from_slice(&count.to_le_bytes());
        for word in words {
            raw.extend_from_slice(&word.to_le_bytes());
        }
        let digest = splitmix64_fold(&raw);
        raw.extend_from_slice(&digest.to_le_bytes());
        raw
    }

    /// Decodes `bytes` and returns only the error, if any.
    fn decode_error(bytes: &[u8]) -> Option<TensorSnapError> {
        decode_list(bytes).err()
    }

    #[test]
    fn empty_list_roundtrips() {
        let decoded = decode_list(&encode_list(&[])).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn scalar_tensor_roundtrips() {
        let original = tensor(vec![42.0], &[1]);
        let restored = decode_one(&encode_one(&original)).unwrap();
        assert_eq!(restored.shape(), &[1]);
        assert_eq!(restored.to_vec(), vec![42.0]);
    }

    #[test]
    fn matrix_roundtrips() {
        let original = tensor(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let restored = decode_one(&encode_one(&original)).unwrap();
        assert_eq!(restored.shape(), &[2, 3]);
        assert_eq!(restored.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn multiple_tensors_roundtrip() {
        let originals = [
            tensor(vec![1.0, 2.0], &[2]),
            tensor(vec![3.0, 4.0, 5.0, 6.0], &[2, 2]),
            tensor(vec![7.0], &[1, 1]),
        ];
        let decoded = decode_list(&encode_list(&originals)).unwrap();
        assert_eq!(decoded.len(), 3);
        for (got, want) in decoded.iter().zip(originals.iter()) {
            assert_eq!(got.to_vec(), want.to_vec());
        }
    }

    #[test]
    fn encoding_is_deterministic() {
        let input = tensor(vec![1.5, -2.5, 3.25], &[3]);
        assert_eq!(
            encode_one(&input),
            encode_one(&input),
            "encoding must be byte-identical for the same input"
        );
    }

    #[test]
    fn different_tensors_produce_different_encodings() {
        let shorter = tensor(vec![1.0, 2.0], &[2]);
        let longer = tensor(vec![1.0, 2.0, 3.0], &[3]);
        assert_ne!(encode_one(&shorter), encode_one(&longer));
    }

    #[test]
    fn different_shapes_same_data_produce_different_encodings() {
        let flat = tensor(vec![1.0, 2.0, 3.0, 4.0], &[4]);
        let square = tensor(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
        assert_ne!(encode_one(&flat), encode_one(&square));
    }

    #[test]
    fn bad_magic_is_rejected() {
        let mut frame = encode_one(&tensor(vec![1.0], &[1]));
        frame[0] = b'X';
        assert_eq!(decode_error(&frame), Some(TensorSnapError::BadMagic));
    }

    #[test]
    fn bad_version_is_rejected() {
        let mut frame = encode_one(&tensor(vec![1.0], &[1]));
        frame[4] = 99;
        assert_eq!(decode_error(&frame), Some(TensorSnapError::BadVersion(99)));
    }

    #[test]
    fn hash_mismatch_is_rejected() {
        let mut frame = encode_one(&tensor(vec![1.0, 2.0, 3.0], &[3]));
        // Skip the header, the dimension-count word and the single dimension
        // word, then corrupt the first byte of the payload.
        let first_payload_byte = HEADER_LEN + 8 + 8;
        frame[first_payload_byte] ^= 0xFF;
        assert!(matches!(
            decode_error(&frame),
            Some(TensorSnapError::BadHash { .. })
        ));
    }

    #[test]
    fn too_short_is_rejected() {
        assert_eq!(decode_error(&[]), Some(TensorSnapError::TooShort));
        assert_eq!(decode_error(&[0u8; 10]), Some(TensorSnapError::TooShort));
    }

    #[test]
    fn hash_list_is_deterministic() {
        let pair = [
            tensor(vec![1.0, 2.0, 3.0], &[3]),
            tensor(vec![4.0, 5.0], &[2]),
        ];
        assert_eq!(hash_list(&pair), hash_list(&pair));
    }

    #[test]
    fn hash_list_is_order_sensitive() {
        let first = tensor(vec![1.0, 2.0, 3.0], &[3]);
        let second = tensor(vec![4.0, 5.0], &[2]);
        let forward = hash_list(&[first.clone(), second.clone()]);
        let reversed = hash_list(&[second, first]);
        assert_ne!(forward, reversed, "hash must change when order changes");
    }

    #[test]
    fn hash_list_distinguishes_shapes_and_data() {
        let flat = tensor(vec![1.0, 2.0, 3.0, 4.0], &[4]);
        let square = tensor(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]);
        let tweaked = tensor(vec![1.0, 2.0, 3.0, 5.0], &[4]);
        let baseline = hash_list(std::slice::from_ref(&flat));
        assert_ne!(baseline, hash_list(&[square]));
        assert_ne!(baseline, hash_list(&[tweaked]));
    }

    #[test]
    fn pathological_ndim_is_rejected() {
        // One tensor that claims 1000 dimensions.
        let frame = sealed_frame(1, &[1000]);
        assert_eq!(decode_error(&frame), Some(TensorSnapError::BadShape));
    }

    #[test]
    fn shape_overflow_is_rejected() {
        // One 2-D tensor whose dimensions multiply past `usize::MAX`.
        let frame = sealed_frame(1, &[2, u64::MAX, 2]);
        assert_eq!(decode_error(&frame), Some(TensorSnapError::BadShape));
    }
}