Skip to main content

cjc_snap/
lib.rs

1//! Content-addressable serialization for CJC values.
2//!
3//! `cjc-snap` provides deterministic binary encoding of `Value` types with
4//! SHA-256 content hashing. Zero external dependencies -- the SHA-256
5//! implementation is hand-rolled following FIPS 180-4.
6//!
7//! # Overview
8//!
9//! ```text
10//!   Value ──snap_encode──> bytes ──sha256──> hash
11//!          ◄──snap_decode──
12//! ```
13//!
14//! The high-level API is `snap()` / `restore()`:
15//!
16//! ```ignore
17//! let blob = cjc_snap::snap(&value);
18//! let restored = cjc_snap::restore(&blob).unwrap();
19//! ```
20
21pub mod hash;
22pub mod encode;
23pub mod decode;
24pub mod json;
25pub mod persist;
26
27pub use hash::sha256;
28pub use encode::snap_encode;
29pub use encode::snap_encode_v2;
30pub use encode::{encode_typed_tensor, encode_chunked_tensor, encode_sparse_csr, encode_categorical, encode_schema, encode_dataframe};
31pub use encode::{DataFrameColumnData, DEFAULT_CHUNK_SIZE};
32pub use encode::{
33    TAG_TYPED_TENSOR, TAG_CHUNKED_TENSOR, TAG_SPARSE_CSR, TAG_CATEGORICAL,
34    TAG_SCHEMA, TAG_DATAFRAME, SNAP_MAGIC, SNAP_VERSION,
35    COL_TYPE_INT, COL_TYPE_FLOAT, COL_TYPE_STR, COL_TYPE_BOOL,
36    COL_TYPE_CATEGORICAL, COL_TYPE_DATETIME,
37};
38pub use decode::snap_decode;
39pub use decode::snap_decode_v2;
40pub use json::snap_to_json;
41pub use persist::{snap_save, snap_save_v2, snap_load};
42
43use cjc_runtime::Value;
44use std::fmt;
45
46// ---------------------------------------------------------------------------
47// SnapError
48// ---------------------------------------------------------------------------
49
/// Errors that can occur during snap decoding or integrity verification.
///
/// Returned by [`snap_decode`], [`snap_decode_v2`], [`restore`], and
/// [`restore_v2`] when the binary payload is malformed or tampered with.
///
/// Every variant carries only plain data (a byte or two 32-byte digests), so
/// the type is cheaply cloneable and can be compared with `==`. This is
/// useful in tests and when the same error has to be reported to several
/// callers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SnapError {
    /// The tag byte does not correspond to any known [`Value`] variant.
    ///
    /// Contains the unrecognized tag byte.
    InvalidTag(u8),
    /// The byte stream ended before the value was fully decoded.
    UnexpectedEof,
    /// A string field contained invalid UTF-8.
    Utf8Error,
    /// The SHA-256 hash of the data does not match the stored content hash.
    ///
    /// Indicates data corruption or tampering. Contains both the expected
    /// and actual 32-byte digests.
    HashMismatch {
        /// The hash stored in the [`SnapBlob`] header or chunk header.
        expected: [u8; 32],
        /// The hash recomputed from the actual data bytes.
        actual: [u8; 32],
    },
}
75
76impl fmt::Display for SnapError {
77    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78        match self {
79            SnapError::InvalidTag(tag) => write!(f, "invalid tag byte: 0x{:02x}", tag),
80            SnapError::UnexpectedEof => write!(f, "unexpected end of data"),
81            SnapError::Utf8Error => write!(f, "invalid UTF-8 in string field"),
82            SnapError::HashMismatch { expected, actual } => {
83                write!(
84                    f,
85                    "hash mismatch: expected {}, got {}",
86                    hash::hex_string(expected),
87                    hash::hex_string(actual)
88                )
89            }
90        }
91    }
92}
93
94// ---------------------------------------------------------------------------
95// SnapBlob
96// ---------------------------------------------------------------------------
97
/// A content-addressable blob: the canonical binary encoding of a [`Value`]
/// together with its SHA-256 digest.
///
/// The encoding is canonical, so two logically equal values produce
/// identical `SnapBlob`s. That makes `content_hash` usable as a cache key or
/// deduplication fingerprint.
///
/// `PartialEq`, `Eq` and `Hash` are derived over both fields (digest *and*
/// payload bytes). As a result, blobs can be stored directly in a `HashSet` or
/// used as `HashMap` keys for deduplication.
///
/// Construct via [`snap`] (v1) or [`snap_v2`] (v2 format), and decode back
/// via [`restore`] or [`restore_v2`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SnapBlob {
    /// SHA-256 hash of [`data`](Self::data), used for integrity verification
    /// and content-addressing.
    pub content_hash: [u8; 32],
    /// Canonical binary encoding of the value (v1 or v2 format).
    pub data: Vec<u8>,
}
114
115// ---------------------------------------------------------------------------
116// High-level API
117// ---------------------------------------------------------------------------
118
119/// Encode a [`Value`] into a content-addressable [`SnapBlob`] (v1 format).
120///
121/// The blob contains the canonical binary encoding produced by
122/// [`snap_encode`] and its SHA-256 hash computed by [`sha256`].
123/// Two values that are logically equal will always produce blobs with the
124/// same `content_hash`, regardless of `HashMap` iteration order or `Rc`
125/// identity.
126///
127/// # Arguments
128///
129/// * `value` - The [`Value`] to encode. Must be snap-encodable (see
130///   [`is_snappable`]).
131///
132/// # Panics
133///
134/// Panics if `value` contains a runtime-only variant that cannot be
135/// serialized (e.g., `Fn`, `Closure`).
136pub fn snap(value: &Value) -> SnapBlob {
137    let data = snap_encode(value);
138    let content_hash = sha256(&data);
139    SnapBlob { content_hash, data }
140}
141
142/// Encode a [`Value`] into a content-addressable [`SnapBlob`] (v2 format with header).
143///
144/// Uses the v2 format: `[MAGIC][version][flags][payload...]`.
145/// Supports all new tags (typed tensors, sparse CSR, categorical, chunked
146/// tensors, DataFrames, etc.).
147///
148/// # Arguments
149///
150/// * `value` - The [`Value`] to encode. Must be snap-encodable (see
151///   [`is_snappable`]).
152///
153/// # Panics
154///
155/// Panics if `value` contains a runtime-only variant that cannot be
156/// serialized.
157pub fn snap_v2(value: &Value) -> SnapBlob {
158    let data = snap_encode_v2(value);
159    let content_hash = sha256(&data);
160    SnapBlob { content_hash, data }
161}
162
163/// Restore a [`Value`] from a [`SnapBlob`], verifying data integrity.
164///
165/// Recomputes the SHA-256 of `blob.data` and compares it against the
166/// stored `blob.content_hash` before decoding. Uses the v1 decoder.
167///
168/// # Arguments
169///
170/// * `blob` - The [`SnapBlob`] to decode.
171///
172/// # Errors
173///
174/// Returns [`SnapError::HashMismatch`] if the SHA-256 of `blob.data` does
175/// not match `blob.content_hash`, or a decoding error if the payload is
176/// malformed.
177pub fn restore(blob: &SnapBlob) -> Result<Value, SnapError> {
178    // Verify integrity
179    let actual_hash = sha256(&blob.data);
180    if actual_hash != blob.content_hash {
181        return Err(SnapError::HashMismatch {
182            expected: blob.content_hash,
183            actual: actual_hash,
184        });
185    }
186    snap_decode(&blob.data)
187}
188
189/// Restore a [`Value`] from a [`SnapBlob`], auto-detecting v1 or v2 format.
190///
191/// Verifies SHA-256 integrity first, then inspects the payload for the v2
192/// magic header (`CJS\x01`). Falls back to v1 decoding when the magic is
193/// absent.
194///
195/// # Errors
196///
197/// Returns [`SnapError::HashMismatch`] if the SHA-256 of `blob.data` does
198/// not match `blob.content_hash`, or a decoding error if the payload is
199/// malformed.
200pub fn restore_v2(blob: &SnapBlob) -> Result<Value, SnapError> {
201    let actual_hash = sha256(&blob.data);
202    if actual_hash != blob.content_hash {
203        return Err(SnapError::HashMismatch {
204            expected: blob.content_hash,
205            actual: actual_hash,
206        });
207    }
208    snap_decode_v2(&blob.data)
209}
210
211/// Check whether a [`Value`] can be snap-encoded without panicking.
212///
213/// Returns `true` for data-bearing variants (scalars, tensors, arrays,
214/// structs, enums, maps, bytes, and sparse tensors) and recursively
215/// checks container contents. Returns `false` for runtime-only variants
216/// (`Fn`, `Closure`, `ClassRef`, `GradGraph`, `OptimizerState`, etc.)
217/// that cannot be meaningfully serialized.
218///
219/// # Arguments
220///
221/// * `value` - The [`Value`] to test for serializability.
222///
223/// # Returns
224///
225/// `true` when [`snap`] or [`snap_v2`] can encode the value without
226/// panicking; `false` otherwise.
227pub fn is_snappable(value: &Value) -> bool {
228    match value {
229        Value::Void | Value::Na | Value::Int(_) | Value::Float(_) | Value::Bool(_)
230        | Value::String(_) | Value::U8(_) | Value::Bytes(_)
231        | Value::ByteSlice(_) | Value::StrView(_) | Value::Bf16(_)
232        | Value::F16(_) | Value::Complex(_) | Value::Tensor(_)
233        | Value::SparseTensor(_) => true,
234        Value::Array(arr) => arr.iter().all(is_snappable),
235        Value::Tuple(elems) => elems.iter().all(is_snappable),
236        Value::Struct { fields, .. } => fields.values().all(is_snappable),
237        Value::Enum { fields, .. } => fields.iter().all(is_snappable),
238        Value::Map(m) => {
239            let map = m.borrow();
240            let result = map.iter().all(|(k, v)| is_snappable(k) && is_snappable(v));
241            result
242        }
243        // Runtime-only: not serializable
244        _ => false,
245    }
246}
247
#[cfg(test)]
mod tests {
    //! Unit tests for the high-level snap/restore API and the standalone
    //! encoders (chunked tensors, DataFrames, typed tensors, schema,
    //! categorical, sparse CSR).

    use super::*;
    use cjc_runtime::{Tensor, SparseCsr};
    use std::collections::BTreeMap;
    use std::rc::Rc;

    // -- Existing v1 tests --

    #[test]
    fn test_snap_restore_int() {
        let original = Value::Int(42);
        let blob = snap(&original);
        let restored = restore(&blob).unwrap();
        match restored {
            Value::Int(v) => assert_eq!(v, 42),
            _ => panic!("expected Int"),
        }
    }

    #[test]
    fn test_snap_restore_string() {
        let original = Value::String(Rc::new("hello".to_string()));
        let blob = snap(&original);
        let restored = restore(&blob).unwrap();
        match restored {
            Value::String(s) => assert_eq!(s.as_str(), "hello"),
            _ => panic!("expected String"),
        }
    }

    #[test]
    fn test_snap_restore_tensor() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let original = Value::Tensor(t);
        let blob = snap(&original);
        let restored = restore(&blob).unwrap();
        match restored {
            Value::Tensor(t) => {
                assert_eq!(t.shape(), &[2, 2]);
                assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0, 4.0]);
            }
            _ => panic!("expected Tensor"),
        }
    }

    #[test]
    fn test_snap_restore_nested() {
        let mut fields = BTreeMap::new();
        fields.insert("x".to_string(), Value::Int(1));
        fields.insert("data".to_string(), Value::Array(Rc::new(vec![
            Value::Float(1.0),
            Value::Float(2.0),
        ])));
        let original = Value::Struct {
            name: "Complex".to_string(),
            fields,
        };
        let blob = snap(&original);
        let restored = restore(&blob).unwrap();
        match restored {
            Value::Struct { name, fields } => {
                assert_eq!(name, "Complex");
                assert_eq!(fields.len(), 2);
                // Check the contents survived, not just the field count.
                assert!(matches!(fields.get("x"), Some(Value::Int(1))));
                match fields.get("data") {
                    Some(Value::Array(arr)) => {
                        assert_eq!(arr.len(), 2);
                        match &arr[0] {
                            Value::Float(f) => assert_eq!(*f, 1.0),
                            _ => panic!("expected Float"),
                        }
                        match &arr[1] {
                            Value::Float(f) => assert_eq!(*f, 2.0),
                            _ => panic!("expected Float"),
                        }
                    }
                    _ => panic!("expected data array"),
                }
            }
            _ => panic!("expected Struct"),
        }
    }

    #[test]
    fn test_content_addressable_same_value() {
        let v1 = Value::Int(42);
        let v2 = Value::Int(42);
        let blob1 = snap(&v1);
        let blob2 = snap(&v2);
        assert_eq!(blob1.content_hash, blob2.content_hash);
        assert_eq!(blob1.data, blob2.data);
    }

    #[test]
    fn test_content_addressable_different_values() {
        let v1 = Value::Int(42);
        let v2 = Value::Int(43);
        let blob1 = snap(&v1);
        let blob2 = snap(&v2);
        assert_ne!(blob1.content_hash, blob2.content_hash);
    }

    #[test]
    fn test_struct_determinism() {
        let mut f1 = BTreeMap::new();
        f1.insert("b".to_string(), Value::Int(2));
        f1.insert("a".to_string(), Value::Int(1));
        f1.insert("c".to_string(), Value::Int(3));

        let mut f2 = BTreeMap::new();
        f2.insert("c".to_string(), Value::Int(3));
        f2.insert("a".to_string(), Value::Int(1));
        f2.insert("b".to_string(), Value::Int(2));

        let blob1 = snap(&Value::Struct { name: "S".to_string(), fields: f1 });
        let blob2 = snap(&Value::Struct { name: "S".to_string(), fields: f2 });
        assert_eq!(blob1.content_hash, blob2.content_hash);
    }

    #[test]
    fn test_hash_mismatch_detection() {
        let blob = snap(&Value::Int(42));
        let tampered = SnapBlob {
            content_hash: blob.content_hash,
            data: snap_encode(&Value::Int(999)),
        };
        let result = restore(&tampered);
        assert!(matches!(result, Err(SnapError::HashMismatch { .. })));
    }

    #[test]
    fn test_snap_void() {
        let blob = snap(&Value::Void);
        assert_eq!(blob.data, vec![0x00]);
        let restored = restore(&blob).unwrap();
        assert!(matches!(restored, Value::Void));
    }

    #[test]
    fn test_hex_display() {
        let blob = snap(&Value::Int(0));
        let hex = hash::hex_string(&blob.content_hash);
        assert_eq!(hex.len(), 64, "SHA-256 hex should be 64 chars");
    }

    #[test]
    fn test_snap_error_display() {
        assert_eq!(SnapError::InvalidTag(0xab).to_string(), "invalid tag byte: 0xab");
        assert_eq!(SnapError::UnexpectedEof.to_string(), "unexpected end of data");
        assert_eq!(SnapError::Utf8Error.to_string(), "invalid UTF-8 in string field");
    }

    // -- v2 high-level API tests --

    #[test]
    fn test_snap_v2_roundtrip_int() {
        let blob = snap_v2(&Value::Int(123));
        // v2 data starts with magic header
        assert_eq!(&blob.data[0..4], SNAP_MAGIC);
        assert_eq!(blob.data[4], SNAP_VERSION);
        let restored = restore_v2(&blob).unwrap();
        assert!(matches!(restored, Value::Int(123)));
    }

    #[test]
    fn test_snap_v2_roundtrip_tensor() {
        let t = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let blob = snap_v2(&Value::Tensor(t));
        let restored = restore_v2(&blob).unwrap();
        match restored {
            Value::Tensor(t) => {
                assert_eq!(t.shape(), &[2, 3]);
                assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
            }
            _ => panic!("expected Tensor"),
        }
    }

    #[test]
    fn test_snap_v2_restore_falls_back_to_v1() {
        // v1 blob should be decodable by restore_v2
        let blob = snap(&Value::Int(42));
        let restored = restore_v2(&blob).unwrap();
        assert!(matches!(restored, Value::Int(42)));
    }

    #[test]
    fn test_restore_v2_hash_mismatch() {
        // restore_v2 must reject a payload whose hash does not match, and
        // report both the stored and the recomputed digest.
        let blob = snap_v2(&Value::Int(7));
        let tampered = SnapBlob {
            content_hash: blob.content_hash,
            data: snap_encode_v2(&Value::Int(8)),
        };
        match restore_v2(&tampered) {
            Err(SnapError::HashMismatch { expected, actual }) => {
                assert_eq!(expected, blob.content_hash);
                assert_eq!(actual, sha256(&tampered.data));
            }
            _ => panic!("expected HashMismatch"),
        }
    }

    // -- SparseTensor roundtrip tests --

    #[test]
    fn test_snap_sparse_tensor_roundtrip() {
        let sparse = SparseCsr {
            nrows: 3,
            ncols: 3,
            row_offsets: vec![0, 1, 2, 3],
            col_indices: vec![0, 1, 2],
            values: vec![1.0, 2.0, 3.0],
        };
        let blob = snap(&Value::SparseTensor(sparse));
        let restored = restore(&blob).unwrap();
        match restored {
            Value::SparseTensor(s) => {
                assert_eq!(s.nrows, 3);
                assert_eq!(s.ncols, 3);
                assert_eq!(s.row_offsets, vec![0, 1, 2, 3]);
                assert_eq!(s.col_indices, vec![0, 1, 2]);
                assert_eq!(s.values, vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("expected SparseTensor"),
        }
    }

    #[test]
    fn test_snap_sparse_empty() {
        let sparse = SparseCsr {
            nrows: 2,
            ncols: 2,
            row_offsets: vec![0, 0, 0],
            col_indices: vec![],
            values: vec![],
        };
        let blob = snap(&Value::SparseTensor(sparse));
        let restored = restore(&blob).unwrap();
        match restored {
            Value::SparseTensor(s) => {
                assert_eq!(s.nrows, 2);
                assert_eq!(s.ncols, 2);
                assert_eq!(s.values.len(), 0);
            }
            _ => panic!("expected SparseTensor"),
        }
    }

    #[test]
    fn test_is_snappable_sparse() {
        let sparse = SparseCsr {
            nrows: 1,
            ncols: 1,
            row_offsets: vec![0, 1],
            col_indices: vec![0],
            values: vec![1.0],
        };
        assert!(is_snappable(&Value::SparseTensor(sparse)));
    }

    #[test]
    fn test_is_snappable_containers() {
        // Nested containers of plain data must be reported as snappable.
        let arr = Value::Array(Rc::new(vec![
            Value::Int(1),
            Value::Float(2.0),
            Value::Bool(true),
        ]));
        assert!(is_snappable(&arr));
        let mut fields = BTreeMap::new();
        fields.insert("inner".to_string(), arr);
        assert!(is_snappable(&Value::Struct { name: "Outer".to_string(), fields }));
    }

    // -- Chunked tensor tests --

    #[test]
    fn test_chunked_tensor_roundtrip_small() {
        // Small tensor that fits in one chunk
        let raw_bytes: Vec<u8> = (0..24u8).collect(); // 3 f64s worth
        let mut buf = Vec::new();
        encode_chunked_tensor(0, &[3], &raw_bytes, 1024, &mut buf);
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Tensor(t) => {
                assert_eq!(t.shape(), &[3]);
                assert_eq!(t.to_vec().len(), 3);
            }
            _ => panic!("expected Tensor"),
        }
    }

    #[test]
    fn test_chunked_tensor_roundtrip_multi_chunk() {
        // Force multiple chunks with small chunk size
        let n = 100;
        let mut raw_bytes = Vec::with_capacity(n * 8);
        for i in 0..n {
            raw_bytes.extend_from_slice(&(i as f64).to_bits().to_le_bytes());
        }
        let mut buf = Vec::new();
        // 64 bytes per chunk = 8 f64s per chunk, so 13 chunks for 100 elements
        encode_chunked_tensor(0, &[n], &raw_bytes, 64, &mut buf);
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Tensor(t) => {
                assert_eq!(t.shape(), &[n]);
                // Materialize once instead of calling `to_vec()` per element.
                let values = t.to_vec();
                assert_eq!(values.len(), n);
                for (i, v) in values.iter().enumerate() {
                    assert_eq!(*v, i as f64);
                }
            }
            _ => panic!("expected Tensor"),
        }
    }

    #[test]
    fn test_chunked_tensor_empty() {
        let mut buf = Vec::new();
        encode_chunked_tensor(0, &[0], &[], 1024, &mut buf);
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Tensor(t) => {
                assert_eq!(t.shape(), &[0]);
                assert_eq!(t.to_vec().len(), 0);
            }
            _ => panic!("expected Tensor"),
        }
    }

    #[test]
    fn test_chunked_tensor_hash_integrity() {
        // Tamper with a chunk and verify hash mismatch is detected
        let raw_bytes: Vec<u8> = vec![0u8; 16]; // 2 f64s
        let mut buf = Vec::new();
        encode_chunked_tensor(0, &[2], &raw_bytes, 8, &mut buf);

        // Find and tamper with a data byte (after the chunk hash)
        // TAG(1) + dtype(1) + ndim(8) + shape(8) + chunk_size(8) + n_chunks(8) = 34
        // chunk 0: len(8) + hash(32) + data(8) = 48
        // Tamper with the data of chunk 0
        let data_start = 34 + 8 + 32; // skip tag+meta + first chunk header
        // Fail loudly if the layout assumption above no longer holds, rather
        // than silently skipping the tamper.
        assert!(
            buf.len() > data_start,
            "encoded buffer shorter than the assumed chunk layout"
        );
        buf[data_start] ^= 0xFF; // tamper: guaranteed to change the byte
        let result = snap_decode(&buf);
        assert!(matches!(result, Err(SnapError::HashMismatch { .. })));
    }

    #[test]
    fn test_chunked_tensor_deterministic() {
        let raw_bytes: Vec<u8> = (0..80u8).collect();
        let mut buf1 = Vec::new();
        let mut buf2 = Vec::new();
        encode_chunked_tensor(0, &[10], &raw_bytes, 32, &mut buf1);
        encode_chunked_tensor(0, &[10], &raw_bytes, 32, &mut buf2);
        assert_eq!(buf1, buf2, "chunked encoding must be deterministic");
    }

    // -- DataFrame encoding tests --

    #[test]
    fn test_dataframe_roundtrip_basic() {
        let int_data = vec![1i64, 2, 3];
        let float_data = vec![1.5f64, 2.5, 3.5];
        let str_data = vec!["a".to_string(), "b".to_string(), "c".to_string()];

        let mut buf = Vec::new();
        encode_dataframe(
            &["id", "value", "name"],
            &[COL_TYPE_INT, COL_TYPE_FLOAT, COL_TYPE_STR],
            &[
                DataFrameColumnData::Int(&int_data),
                DataFrameColumnData::Float(&float_data),
                DataFrameColumnData::Str(&str_data),
            ],
            3,
            &mut buf,
        );

        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Struct { name, fields } => {
                assert_eq!(name, "DataFrame");
                assert!(fields.contains_key("__nrows"));
                assert!(fields.contains_key("__columns"));
                match fields.get("__nrows") {
                    Some(Value::Int(n)) => assert_eq!(*n, 3),
                    _ => panic!("expected __nrows = 3"),
                }
                // Check id column
                match fields.get("id") {
                    Some(Value::Array(arr)) => {
                        assert_eq!(arr.len(), 3);
                        assert!(matches!(arr[0], Value::Int(1)));
                        assert!(matches!(arr[1], Value::Int(2)));
                    }
                    _ => panic!("expected id array"),
                }
                // Check value column
                match fields.get("value") {
                    Some(Value::Array(arr)) => {
                        assert_eq!(arr.len(), 3);
                        match &arr[0] {
                            Value::Float(f) => assert_eq!(*f, 1.5),
                            _ => panic!("expected Float"),
                        }
                    }
                    _ => panic!("expected value array"),
                }
                // Check name column
                match fields.get("name") {
                    Some(Value::Array(arr)) => {
                        assert_eq!(arr.len(), 3);
                        match &arr[0] {
                            Value::String(s) => assert_eq!(s.as_str(), "a"),
                            _ => panic!("expected String"),
                        }
                    }
                    _ => panic!("expected name array"),
                }
            }
            _ => panic!("expected Struct"),
        }
    }

    #[test]
    fn test_dataframe_bool_column() {
        let bool_data = vec![true, false, true];
        let mut buf = Vec::new();
        encode_dataframe(
            &["flag"],
            &[COL_TYPE_BOOL],
            &[DataFrameColumnData::Bool(&bool_data)],
            3,
            &mut buf,
        );
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Struct { fields, .. } => {
                match fields.get("flag") {
                    Some(Value::Array(arr)) => {
                        assert_eq!(arr.len(), 3);
                        assert!(matches!(arr[0], Value::Bool(true)));
                        assert!(matches!(arr[1], Value::Bool(false)));
                        assert!(matches!(arr[2], Value::Bool(true)));
                    }
                    _ => panic!("expected flag array"),
                }
            }
            _ => panic!("expected Struct"),
        }
    }

    #[test]
    fn test_dataframe_categorical_column() {
        let levels = vec!["cat".to_string(), "dog".to_string(), "fish".to_string()];
        let codes = vec![0u32, 1, 2, 0, 1];
        let mut buf = Vec::new();
        encode_dataframe(
            &["animal"],
            &[COL_TYPE_CATEGORICAL],
            &[DataFrameColumnData::Categorical { levels: &levels, codes: &codes }],
            5,
            &mut buf,
        );
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Struct { fields, .. } => {
                match fields.get("animal") {
                    Some(Value::Array(arr)) => {
                        assert_eq!(arr.len(), 5);
                        match &arr[0] {
                            Value::String(s) => assert_eq!(s.as_str(), "cat"),
                            _ => panic!("expected String"),
                        }
                        match &arr[2] {
                            Value::String(s) => assert_eq!(s.as_str(), "fish"),
                            _ => panic!("expected String"),
                        }
                    }
                    _ => panic!("expected animal array"),
                }
            }
            _ => panic!("expected Struct"),
        }
    }

    #[test]
    fn test_dataframe_datetime_column() {
        let dt_data = vec![1000i64, 2000, 3000];
        let mut buf = Vec::new();
        encode_dataframe(
            &["timestamp"],
            &[COL_TYPE_DATETIME],
            &[DataFrameColumnData::DateTime(&dt_data)],
            3,
            &mut buf,
        );
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Struct { fields, .. } => {
                match fields.get("timestamp") {
                    Some(Value::Array(arr)) => {
                        assert_eq!(arr.len(), 3);
                        assert!(matches!(arr[0], Value::Int(1000)));
                    }
                    _ => panic!("expected timestamp array"),
                }
            }
            _ => panic!("expected Struct"),
        }
    }

    #[test]
    fn test_dataframe_empty() {
        let mut buf = Vec::new();
        encode_dataframe(&[], &[], &[], 0, &mut buf);
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Struct { name, fields } => {
                assert_eq!(name, "DataFrame");
                match fields.get("__nrows") {
                    Some(Value::Int(0)) => {}
                    _ => panic!("expected __nrows = 0"),
                }
            }
            _ => panic!("expected Struct"),
        }
    }

    #[test]
    fn test_dataframe_deterministic() {
        let int_data = vec![1i64, 2, 3];
        let mut buf1 = Vec::new();
        let mut buf2 = Vec::new();
        encode_dataframe(
            &["x"],
            &[COL_TYPE_INT],
            &[DataFrameColumnData::Int(&int_data)],
            3,
            &mut buf1,
        );
        encode_dataframe(
            &["x"],
            &[COL_TYPE_INT],
            &[DataFrameColumnData::Int(&int_data)],
            3,
            &mut buf2,
        );
        assert_eq!(buf1, buf2, "dataframe encoding must be deterministic");
    }

    // -- Typed tensor standalone tests --

    #[test]
    fn test_typed_tensor_f64_roundtrip() {
        let raw: Vec<u8> = vec![1.0f64, 2.0, 3.0]
            .iter()
            .flat_map(|v| v.to_bits().to_le_bytes())
            .collect();
        let mut buf = Vec::new();
        encode_typed_tensor(0, &[3], &raw, &mut buf);
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Tensor(t) => {
                assert_eq!(t.to_vec(), vec![1.0, 2.0, 3.0]);
            }
            _ => panic!("expected Tensor"),
        }
    }

    #[test]
    fn test_typed_tensor_i32_roundtrip() {
        let raw: Vec<u8> = vec![10i32, 20, 30]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let mut buf = Vec::new();
        encode_typed_tensor(3, &[3], &raw, &mut buf); // dtype 3 = I32
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Tensor(t) => {
                assert_eq!(t.to_vec(), vec![10.0, 20.0, 30.0]);
            }
            _ => panic!("expected Tensor"),
        }
    }

    // -- Schema roundtrip test --

    #[test]
    fn test_schema_roundtrip() {
        let fields = vec![
            ("id".to_string(), 0x01u8),
            ("name".to_string(), 0x04u8),
            ("value".to_string(), 0x02u8),
        ];
        let mut buf = Vec::new();
        encode_schema(&fields, &mut buf);
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Struct { name, fields } => {
                assert_eq!(name, "Schema");
                assert_eq!(fields.len(), 3);
                assert!(matches!(fields.get("id"), Some(Value::Int(1))));
                assert!(matches!(fields.get("name"), Some(Value::Int(4))));
                assert!(matches!(fields.get("value"), Some(Value::Int(2))));
            }
            _ => panic!("expected Schema struct"),
        }
    }

    // -- Categorical standalone roundtrip --

    #[test]
    fn test_categorical_roundtrip() {
        let levels = vec!["a".to_string(), "b".to_string(), "c".to_string()];
        let codes = vec![0u32, 1, 2, 0];
        let mut buf = Vec::new();
        encode_categorical(&levels, &codes, &mut buf);
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::Array(arr) => {
                assert_eq!(arr.len(), 4);
                match &arr[0] { Value::String(s) => assert_eq!(s.as_str(), "a"), _ => panic!() }
                match &arr[1] { Value::String(s) => assert_eq!(s.as_str(), "b"), _ => panic!() }
                match &arr[2] { Value::String(s) => assert_eq!(s.as_str(), "c"), _ => panic!() }
                match &arr[3] { Value::String(s) => assert_eq!(s.as_str(), "a"), _ => panic!() }
            }
            _ => panic!("expected Array"),
        }
    }

    // -- SparseCsr standalone roundtrip --

    #[test]
    fn test_sparse_csr_standalone_roundtrip() {
        let mut buf = Vec::new();
        encode_sparse_csr(
            2, 3,
            &[0, 2, 3],
            &[0, 2, 1],
            &[1.0, 2.0, 3.0],
            &mut buf,
        );
        let decoded = snap_decode(&buf).unwrap();
        match decoded {
            Value::SparseTensor(s) => {
                assert_eq!(s.nrows, 2);
                assert_eq!(s.ncols, 3);
                assert_eq!(s.values, vec![1.0, 2.0, 3.0]);
                assert_eq!(s.col_indices, vec![0, 2, 1]);
                assert_eq!(s.row_offsets, vec![0, 2, 3]);
            }
            _ => panic!("expected SparseTensor"),
        }
    }
}