// grafeo_core/execution/spill/serializer.rs

1//! Binary serialization for Values without serde overhead.
2//!
3//! This module provides efficient serialization for spilling operator state
4//! to disk. The format is designed for:
5//! - Minimal overhead (no schema, direct binary encoding)
6//! - Fast serialization/deserialization
7//! - Compact representation
8
9use arcstr::ArcStr;
10use grafeo_common::types::Value;
11use std::collections::BTreeMap;
12use std::io::{Read, Write};
13use std::sync::Arc;
14
// Type tags for Value variants.
//
// Every encoded value starts with exactly one of these bytes; the decoder
// dispatches on it. The numeric values are part of the encoded format, so
// existing tags must keep their numbers and new tags must use unused ones.

/// Tag for `Value::Null` (no payload).
const TAG_NULL: u8 = 0;
/// Tag for `Value::Bool` (one payload byte).
const TAG_BOOL: u8 = 1;
/// Tag for `Value::Int64` (little-endian payload).
const TAG_INT64: u8 = 2;
/// Tag for `Value::Float64` (little-endian payload).
const TAG_FLOAT64: u8 = 3;
/// Tag for `Value::String` (length-prefixed UTF-8 bytes).
const TAG_STRING: u8 = 4;
/// Tag for `Value::Bytes` (length-prefixed raw bytes).
const TAG_BYTES: u8 = 5;
/// Tag for `Value::Timestamp` (microseconds, little-endian).
const TAG_TIMESTAMP: u8 = 6;
/// Tag for `Value::List` (count-prefixed nested values).
const TAG_LIST: u8 = 7;
/// Tag for `Value::Map` (count-prefixed key/value entries).
const TAG_MAP: u8 = 8;
/// Tag for `Value::Vector` (count-prefixed 4-byte floats).
const TAG_VECTOR: u8 = 9;
/// Tag for `Value::Date` (day count, little-endian).
const TAG_DATE: u8 = 10;
/// Tag for `Value::Time` (nanoseconds followed by an offset in seconds).
const TAG_TIME: u8 = 11;
/// Tag for `Value::Duration` (months, days, nanoseconds).
const TAG_DURATION: u8 = 12;
/// Tag for `Value::Path` (count-prefixed nodes, then count-prefixed edges).
const TAG_PATH: u8 = 13;
/// Tag for `Value::ZonedDatetime` (microseconds followed by an offset in seconds).
const TAG_ZONED_DATETIME: u8 = 14;
31
32/// Serializes a Value to bytes.
33///
34/// Returns the number of bytes written.
35///
36/// # Errors
37///
38/// Returns an error if writing fails.
39pub fn serialize_value<W: Write + ?Sized>(value: &Value, w: &mut W) -> std::io::Result<usize> {
40    match value {
41        Value::Null => {
42            w.write_all(&[TAG_NULL])?;
43            Ok(1)
44        }
45        Value::Bool(b) => {
46            w.write_all(&[TAG_BOOL, u8::from(*b)])?;
47            Ok(2)
48        }
49        Value::Int64(i) => {
50            w.write_all(&[TAG_INT64])?;
51            w.write_all(&i.to_le_bytes())?;
52            Ok(9)
53        }
54        Value::Float64(f) => {
55            w.write_all(&[TAG_FLOAT64])?;
56            w.write_all(&f.to_le_bytes())?;
57            Ok(9)
58        }
59        Value::String(s) => {
60            w.write_all(&[TAG_STRING])?;
61            let bytes = s.as_bytes();
62            w.write_all(&(bytes.len() as u64).to_le_bytes())?;
63            w.write_all(bytes)?;
64            Ok(1 + 8 + bytes.len())
65        }
66        Value::Bytes(b) => {
67            w.write_all(&[TAG_BYTES])?;
68            w.write_all(&(b.len() as u64).to_le_bytes())?;
69            w.write_all(b)?;
70            Ok(1 + 8 + b.len())
71        }
72        Value::Timestamp(t) => {
73            w.write_all(&[TAG_TIMESTAMP])?;
74            // Timestamp is internally an i64 (microseconds since epoch)
75            let micros = t.as_micros();
76            w.write_all(&micros.to_le_bytes())?;
77            Ok(9)
78        }
79        Value::List(items) => {
80            w.write_all(&[TAG_LIST])?;
81            w.write_all(&(items.len() as u64).to_le_bytes())?;
82            let mut total = 1 + 8;
83            for item in items.iter() {
84                total += serialize_value(item, w)?;
85            }
86            Ok(total)
87        }
88        Value::Map(map) => {
89            w.write_all(&[TAG_MAP])?;
90            w.write_all(&(map.len() as u64).to_le_bytes())?;
91            let mut total = 1 + 8;
92            for (key, val) in map.iter() {
93                // Serialize key as string
94                let key_bytes = key.as_str().as_bytes();
95                w.write_all(&(key_bytes.len() as u64).to_le_bytes())?;
96                w.write_all(key_bytes)?;
97                total += 8 + key_bytes.len();
98                // Serialize value
99                total += serialize_value(val, w)?;
100            }
101            Ok(total)
102        }
103        Value::Vector(v) => {
104            w.write_all(&[TAG_VECTOR])?;
105            w.write_all(&(v.len() as u64).to_le_bytes())?;
106            for &f in v.iter() {
107                w.write_all(&f.to_le_bytes())?;
108            }
109            Ok(1 + 8 + v.len() * 4)
110        }
111        Value::Date(d) => {
112            w.write_all(&[TAG_DATE])?;
113            w.write_all(&d.as_days().to_le_bytes())?;
114            Ok(5)
115        }
116        Value::Time(t) => {
117            w.write_all(&[TAG_TIME])?;
118            w.write_all(&t.as_nanos().to_le_bytes())?;
119            let offset = t.offset_seconds().unwrap_or(i32::MIN);
120            w.write_all(&offset.to_le_bytes())?;
121            Ok(13)
122        }
123        Value::Duration(d) => {
124            w.write_all(&[TAG_DURATION])?;
125            w.write_all(&d.months().to_le_bytes())?;
126            w.write_all(&d.days().to_le_bytes())?;
127            w.write_all(&d.nanos().to_le_bytes())?;
128            Ok(25)
129        }
130        Value::ZonedDatetime(zdt) => {
131            w.write_all(&[TAG_ZONED_DATETIME])?;
132            w.write_all(&zdt.as_timestamp().as_micros().to_le_bytes())?;
133            w.write_all(&zdt.offset_seconds().to_le_bytes())?;
134            Ok(13)
135        }
136        Value::Path { nodes, edges } => {
137            w.write_all(&[TAG_PATH])?;
138            w.write_all(&(nodes.len() as u64).to_le_bytes())?;
139            let mut total = 1 + 8;
140            for node in nodes.iter() {
141                total += serialize_value(node, w)?;
142            }
143            w.write_all(&(edges.len() as u64).to_le_bytes())?;
144            total += 8;
145            for edge in edges.iter() {
146                total += serialize_value(edge, w)?;
147            }
148            Ok(total)
149        }
150    }
151}
152
153/// Deserializes a Value from bytes.
154///
155/// # Errors
156///
157/// Returns an error if reading fails or the format is invalid.
158pub fn deserialize_value<R: Read + ?Sized>(r: &mut R) -> std::io::Result<Value> {
159    let mut tag = [0u8; 1];
160    r.read_exact(&mut tag)?;
161
162    match tag[0] {
163        TAG_NULL => Ok(Value::Null),
164        TAG_BOOL => {
165            let mut buf = [0u8; 1];
166            r.read_exact(&mut buf)?;
167            Ok(Value::Bool(buf[0] != 0))
168        }
169        TAG_INT64 => {
170            let mut buf = [0u8; 8];
171            r.read_exact(&mut buf)?;
172            Ok(Value::Int64(i64::from_le_bytes(buf)))
173        }
174        TAG_FLOAT64 => {
175            let mut buf = [0u8; 8];
176            r.read_exact(&mut buf)?;
177            Ok(Value::Float64(f64::from_le_bytes(buf)))
178        }
179        TAG_STRING => {
180            let mut len_buf = [0u8; 8];
181            r.read_exact(&mut len_buf)?;
182            let len = u64::from_le_bytes(len_buf) as usize;
183            let mut str_buf = vec![0u8; len];
184            r.read_exact(&mut str_buf)?;
185            let s = String::from_utf8(str_buf)
186                .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string()))?;
187            Ok(Value::String(ArcStr::from(s)))
188        }
189        TAG_BYTES => {
190            let mut len_buf = [0u8; 8];
191            r.read_exact(&mut len_buf)?;
192            let len = u64::from_le_bytes(len_buf) as usize;
193            let mut bytes_buf = vec![0u8; len];
194            r.read_exact(&mut bytes_buf)?;
195            Ok(Value::Bytes(Arc::from(bytes_buf)))
196        }
197        TAG_TIMESTAMP => {
198            let mut buf = [0u8; 8];
199            r.read_exact(&mut buf)?;
200            let micros = i64::from_le_bytes(buf);
201            Ok(Value::Timestamp(
202                grafeo_common::types::Timestamp::from_micros(micros),
203            ))
204        }
205        TAG_LIST => {
206            let mut len_buf = [0u8; 8];
207            r.read_exact(&mut len_buf)?;
208            let len = u64::from_le_bytes(len_buf) as usize;
209            let mut items = Vec::with_capacity(len);
210            for _ in 0..len {
211                items.push(deserialize_value(r)?);
212            }
213            Ok(Value::List(Arc::from(items)))
214        }
215        TAG_MAP => {
216            let mut len_buf = [0u8; 8];
217            r.read_exact(&mut len_buf)?;
218            let len = u64::from_le_bytes(len_buf) as usize;
219            let mut map = BTreeMap::new();
220            for _ in 0..len {
221                // Read key
222                let mut key_len_buf = [0u8; 8];
223                r.read_exact(&mut key_len_buf)?;
224                let key_len = u64::from_le_bytes(key_len_buf) as usize;
225                let mut key_buf = vec![0u8; key_len];
226                r.read_exact(&mut key_buf)?;
227                let key_str = String::from_utf8(key_buf).map_err(|e| {
228                    std::io::Error::new(std::io::ErrorKind::InvalidData, e.to_string())
229                })?;
230                // Read value
231                let val = deserialize_value(r)?;
232                map.insert(grafeo_common::types::PropertyKey::new(key_str), val);
233            }
234            Ok(Value::Map(Arc::new(map)))
235        }
236        TAG_VECTOR => {
237            let mut len_buf = [0u8; 8];
238            r.read_exact(&mut len_buf)?;
239            let len = u64::from_le_bytes(len_buf) as usize;
240            let mut floats = Vec::with_capacity(len);
241            let mut buf = [0u8; 4];
242            for _ in 0..len {
243                r.read_exact(&mut buf)?;
244                floats.push(f32::from_le_bytes(buf));
245            }
246            Ok(Value::Vector(Arc::from(floats)))
247        }
248        TAG_DATE => {
249            let mut buf = [0u8; 4];
250            r.read_exact(&mut buf)?;
251            Ok(Value::Date(grafeo_common::types::Date::from_days(
252                i32::from_le_bytes(buf),
253            )))
254        }
255        TAG_TIME => {
256            let mut nanos_buf = [0u8; 8];
257            r.read_exact(&mut nanos_buf)?;
258            let nanos = u64::from_le_bytes(nanos_buf);
259            let mut offset_buf = [0u8; 4];
260            r.read_exact(&mut offset_buf)?;
261            let offset = i32::from_le_bytes(offset_buf);
262            let time = grafeo_common::types::Time::from_nanos(nanos).ok_or_else(|| {
263                std::io::Error::new(std::io::ErrorKind::InvalidData, "invalid time nanos")
264            })?;
265            if offset == i32::MIN {
266                Ok(Value::Time(time))
267            } else {
268                Ok(Value::Time(time.with_offset(offset)))
269            }
270        }
271        TAG_DURATION => {
272            let mut buf = [0u8; 8];
273            r.read_exact(&mut buf)?;
274            let months = i64::from_le_bytes(buf);
275            r.read_exact(&mut buf)?;
276            let days = i64::from_le_bytes(buf);
277            r.read_exact(&mut buf)?;
278            let nanos = i64::from_le_bytes(buf);
279            Ok(Value::Duration(grafeo_common::types::Duration::new(
280                months, days, nanos,
281            )))
282        }
283        TAG_ZONED_DATETIME => {
284            let mut micros_buf = [0u8; 8];
285            r.read_exact(&mut micros_buf)?;
286            let micros = i64::from_le_bytes(micros_buf);
287            let mut offset_buf = [0u8; 4];
288            r.read_exact(&mut offset_buf)?;
289            let offset = i32::from_le_bytes(offset_buf);
290            Ok(Value::ZonedDatetime(
291                grafeo_common::types::ZonedDatetime::from_timestamp_offset(
292                    grafeo_common::types::Timestamp::from_micros(micros),
293                    offset,
294                ),
295            ))
296        }
297        TAG_PATH => {
298            let mut len_buf = [0u8; 8];
299            r.read_exact(&mut len_buf)?;
300            let nodes_len = u64::from_le_bytes(len_buf) as usize;
301            let mut nodes = Vec::with_capacity(nodes_len);
302            for _ in 0..nodes_len {
303                nodes.push(deserialize_value(r)?);
304            }
305            r.read_exact(&mut len_buf)?;
306            let edges_len = u64::from_le_bytes(len_buf) as usize;
307            let mut edges = Vec::with_capacity(edges_len);
308            for _ in 0..edges_len {
309                edges.push(deserialize_value(r)?);
310            }
311            Ok(Value::Path {
312                nodes: Arc::from(nodes),
313                edges: Arc::from(edges),
314            })
315        }
316        _ => Err(std::io::Error::new(
317            std::io::ErrorKind::InvalidData,
318            format!("Unknown value tag: {}", tag[0]),
319        )),
320    }
321}
322
323/// Serializes a row (slice of Values) to bytes.
324///
325/// Format: `[num_columns: u64][value1][value2]...`
326///
327/// Returns the number of bytes written.
328///
329/// # Errors
330///
331/// Returns an error if writing fails.
332pub fn serialize_row<W: Write + ?Sized>(row: &[Value], w: &mut W) -> std::io::Result<usize> {
333    w.write_all(&(row.len() as u64).to_le_bytes())?;
334    let mut total = 8;
335    for value in row {
336        total += serialize_value(value, w)?;
337    }
338    Ok(total)
339}
340
341/// Deserializes a row from bytes.
342///
343/// # Arguments
344///
345/// * `r` - Reader to read from
346/// * `expected_columns` - Expected number of columns (for validation, 0 to skip)
347///
348/// # Errors
349///
350/// Returns an error if reading fails or column count mismatches.
351pub fn deserialize_row<R: Read + ?Sized>(
352    r: &mut R,
353    expected_columns: usize,
354) -> std::io::Result<Vec<Value>> {
355    let mut len_buf = [0u8; 8];
356    r.read_exact(&mut len_buf)?;
357    let num_columns = u64::from_le_bytes(len_buf) as usize;
358
359    if expected_columns > 0 && num_columns != expected_columns {
360        return Err(std::io::Error::new(
361            std::io::ErrorKind::InvalidData,
362            format!(
363                "Column count mismatch: expected {}, got {}",
364                expected_columns, num_columns
365            ),
366        ));
367    }
368
369    let mut row = Vec::with_capacity(num_columns);
370    for _ in 0..num_columns {
371        row.push(deserialize_value(r)?);
372    }
373    Ok(row)
374}
375
#[cfg(test)]
mod tests {
    use super::*;
    use arcstr::ArcStr;
    use std::io::Cursor;

    /// Serializes `value` into a fresh buffer and decodes it again.
    ///
    /// Also checks that the byte count reported by `serialize_value` matches
    /// the bytes actually written, and that decoding consumes the whole buffer.
    fn roundtrip_value(value: Value) -> Value {
        let mut buf = Vec::new();
        let written = serialize_value(&value, &mut buf).unwrap();
        assert_eq!(written, buf.len(), "reported size must match bytes written");

        let total = buf.len() as u64;
        let mut cursor = Cursor::new(buf);
        let decoded = deserialize_value(&mut cursor).unwrap();
        assert_eq!(cursor.position(), total, "decoder must consume every byte");
        decoded
    }

    /// Asserts that `value` survives a serialize/deserialize round trip unchanged.
    fn assert_roundtrip(value: Value) {
        assert_eq!(roundtrip_value(value.clone()), value);
    }

    #[test]
    fn test_serialize_null() {
        let result = roundtrip_value(Value::Null);
        assert_eq!(result, Value::Null);
    }

    #[test]
    fn test_serialize_bool() {
        assert_eq!(roundtrip_value(Value::Bool(true)), Value::Bool(true));
        assert_eq!(roundtrip_value(Value::Bool(false)), Value::Bool(false));
    }

    #[test]
    fn test_serialize_int64() {
        assert_eq!(roundtrip_value(Value::Int64(0)), Value::Int64(0));
        assert_eq!(
            roundtrip_value(Value::Int64(i64::MAX)),
            Value::Int64(i64::MAX)
        );
        assert_eq!(
            roundtrip_value(Value::Int64(i64::MIN)),
            Value::Int64(i64::MIN)
        );
        assert_eq!(roundtrip_value(Value::Int64(-42)), Value::Int64(-42));
    }

    #[test]
    fn test_serialize_float64() {
        assert_eq!(roundtrip_value(Value::Float64(0.0)), Value::Float64(0.0));
        assert_eq!(
            roundtrip_value(Value::Float64(std::f64::consts::PI)),
            Value::Float64(std::f64::consts::PI)
        );
        // Note: NaN != NaN, so we test differently
        let nan_result = roundtrip_value(Value::Float64(f64::NAN));
        assert!(matches!(nan_result, Value::Float64(f) if f.is_nan()));
    }

    #[test]
    fn test_serialize_string() {
        let result = roundtrip_value(Value::String(ArcStr::from("hello world")));
        assert_eq!(result.as_str(), Some("hello world"));

        // Empty string
        let result = roundtrip_value(Value::String(ArcStr::from("")));
        assert_eq!(result.as_str(), Some(""));

        // Unicode
        let result = roundtrip_value(Value::String(ArcStr::from("héllo 世界 🌍")));
        assert_eq!(result.as_str(), Some("héllo 世界 🌍"));
    }

    #[test]
    fn test_serialize_bytes() {
        let data = vec![0u8, 1, 2, 255, 128];
        let result = roundtrip_value(Value::Bytes(Arc::from(data.clone())));
        assert_eq!(result.as_bytes(), Some(&data[..]));

        // Empty bytes
        let result = roundtrip_value(Value::Bytes(Arc::from(vec![])));
        assert_eq!(result.as_bytes(), Some(&[][..]));
    }

    #[test]
    fn test_serialize_timestamp() {
        let ts = grafeo_common::types::Timestamp::from_micros(1234567890);
        let result = roundtrip_value(Value::Timestamp(ts));
        assert_eq!(result.as_timestamp(), Some(ts));
    }

    #[test]
    fn test_serialize_list() {
        let list = Value::List(Arc::from(vec![
            Value::Int64(1),
            Value::String(ArcStr::from("two")),
            Value::Bool(true),
        ]));
        let result = roundtrip_value(list.clone());
        assert_eq!(result, list);

        // Nested list
        let nested = Value::List(Arc::from(vec![
            Value::List(Arc::from(vec![Value::Int64(1), Value::Int64(2)])),
            Value::List(Arc::from(vec![Value::Int64(3)])),
        ]));
        let result = roundtrip_value(nested.clone());
        assert_eq!(result, nested);

        // Empty list
        let empty = Value::List(Arc::from(vec![]));
        let result = roundtrip_value(empty.clone());
        assert_eq!(result, empty);
    }

    #[test]
    fn test_serialize_map() {
        let mut map = BTreeMap::new();
        map.insert(
            grafeo_common::types::PropertyKey::new("name"),
            Value::String(ArcStr::from("Alix")),
        );
        map.insert(
            grafeo_common::types::PropertyKey::new("age"),
            Value::Int64(30),
        );

        let value = Value::Map(Arc::new(map));
        let result = roundtrip_value(value.clone());
        assert_eq!(result, value);
    }

    #[test]
    fn test_serialize_vector() {
        assert_roundtrip(Value::Vector(Arc::from(vec![0.0f32, 1.5, -2.25, f32::MAX])));

        // Empty vector
        assert_roundtrip(Value::Vector(Arc::from(Vec::<f32>::new())));
    }

    #[test]
    fn test_serialize_date() {
        assert_roundtrip(Value::Date(grafeo_common::types::Date::from_days(0)));
        assert_roundtrip(Value::Date(grafeo_common::types::Date::from_days(19_000)));
        assert_roundtrip(Value::Date(grafeo_common::types::Date::from_days(-365)));
    }

    #[test]
    fn test_serialize_time() {
        // 01:00:00 expressed in nanoseconds since midnight.
        let one_hour_nanos = 3_600_000_000_000;

        // Without an offset (stored with the `i32::MIN` marker).
        let plain = grafeo_common::types::Time::from_nanos(one_hour_nanos)
            .expect("one hour is a valid time of day");
        assert_roundtrip(Value::Time(plain));

        // With positive and negative offsets.
        let east = grafeo_common::types::Time::from_nanos(one_hour_nanos)
            .expect("one hour is a valid time of day")
            .with_offset(3600);
        assert_roundtrip(Value::Time(east));

        let west = grafeo_common::types::Time::from_nanos(one_hour_nanos)
            .expect("one hour is a valid time of day")
            .with_offset(-18_000);
        assert_roundtrip(Value::Time(west));
    }

    #[test]
    fn test_serialize_duration() {
        assert_roundtrip(Value::Duration(grafeo_common::types::Duration::new(
            0, 0, 0,
        )));
        assert_roundtrip(Value::Duration(grafeo_common::types::Duration::new(
            14, 3, 1_500_000_000,
        )));
        assert_roundtrip(Value::Duration(grafeo_common::types::Duration::new(
            -1, -2, -3,
        )));
    }

    #[test]
    fn test_serialize_zoned_datetime() {
        let zdt = grafeo_common::types::ZonedDatetime::from_timestamp_offset(
            grafeo_common::types::Timestamp::from_micros(1_700_000_000_000_000),
            7200,
        );
        assert_roundtrip(Value::ZonedDatetime(zdt));
    }

    #[test]
    fn test_serialize_path() {
        assert_roundtrip(Value::Path {
            nodes: Arc::from(vec![Value::Int64(1), Value::Int64(2), Value::Int64(3)]),
            edges: Arc::from(vec![
                Value::String(ArcStr::from("KNOWS")),
                Value::String(ArcStr::from("LIKES")),
            ]),
        });

        // Empty path
        assert_roundtrip(Value::Path {
            nodes: Arc::from(vec![]),
            edges: Arc::from(vec![]),
        });
    }

    #[test]
    fn test_deserialize_unknown_tag() {
        let mut cursor = Cursor::new(vec![255u8]);
        let err = deserialize_value(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_deserialize_truncated_input() {
        // No bytes at all: even the tag is missing.
        let mut cursor = Cursor::new(Vec::<u8>::new());
        assert!(deserialize_value(&mut cursor).is_err());

        // String that declares 5 bytes but only provides 2.
        let mut buf = vec![TAG_STRING];
        buf.extend_from_slice(&5u64.to_le_bytes());
        buf.extend_from_slice(b"hi");
        let mut cursor = Cursor::new(buf);
        let err = deserialize_value(&mut cursor).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);

        // Int64 tag with a short payload.
        let mut cursor = Cursor::new(vec![TAG_INT64, 1, 2, 3]);
        assert!(deserialize_value(&mut cursor).is_err());
    }

    #[test]
    fn test_serialize_row() {
        let row = vec![
            Value::Int64(1),
            Value::String(ArcStr::from("test")),
            Value::Bool(true),
            Value::Null,
        ];

        let mut buf = Vec::new();
        let written = serialize_row(&row, &mut buf).unwrap();
        assert_eq!(written, buf.len());

        let mut cursor = Cursor::new(buf);
        let result = deserialize_row(&mut cursor, 4).unwrap();
        assert_eq!(result, row);
    }

    #[test]
    fn test_serialize_empty_row() {
        let mut buf = Vec::new();
        let written = serialize_row(&[], &mut buf).unwrap();
        // Only the u64 column count is written.
        assert_eq!(written, 8);
        assert_eq!(buf.len(), 8);

        let mut cursor = Cursor::new(buf);
        let result = deserialize_row(&mut cursor, 0).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn test_serialize_row_column_count_check() {
        let row = vec![Value::Int64(1), Value::Int64(2)];

        let mut buf = Vec::new();
        serialize_row(&row, &mut buf).unwrap();

        // Wrong expected column count
        let mut cursor = Cursor::new(buf.clone());
        let result = deserialize_row(&mut cursor, 3);
        assert!(result.is_err());

        // Skip check with 0
        let mut cursor = Cursor::new(buf);
        let result = deserialize_row(&mut cursor, 0).unwrap();
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn test_serialize_multiple_rows() {
        let rows = vec![
            vec![Value::Int64(1), Value::String(ArcStr::from("a"))],
            vec![Value::Int64(2), Value::String(ArcStr::from("b"))],
            vec![Value::Int64(3), Value::String(ArcStr::from("c"))],
        ];

        let mut buf = Vec::new();
        for row in &rows {
            serialize_row(row, &mut buf).unwrap();
        }

        let mut cursor = Cursor::new(buf);
        for expected in &rows {
            let result = deserialize_row(&mut cursor, 2).unwrap();
            assert_eq!(&result, expected);
        }
    }

    #[test]
    fn test_serialization_size() {
        // Verify expected sizes, both as written and as reported.
        let mut buf = Vec::new();

        // Null: 1 byte (tag only)
        let written = serialize_value(&Value::Null, &mut buf).unwrap();
        assert_eq!(buf.len(), 1);
        assert_eq!(written, 1);
        buf.clear();

        // Bool: 2 bytes (tag + value)
        let written = serialize_value(&Value::Bool(true), &mut buf).unwrap();
        assert_eq!(buf.len(), 2);
        assert_eq!(written, 2);
        buf.clear();

        // Int64: 9 bytes (tag + 8)
        let written = serialize_value(&Value::Int64(42), &mut buf).unwrap();
        assert_eq!(buf.len(), 9);
        assert_eq!(written, 9);
        buf.clear();

        // String "hi": 11 bytes (tag + 8 length + 2)
        let written = serialize_value(&Value::String(ArcStr::from("hi")), &mut buf).unwrap();
        assert_eq!(buf.len(), 11);
        assert_eq!(written, 11);
    }
}