Skip to main content

nodedb_strict/
encode.rs

1//! Binary Tuple encoder: schema + values → compact byte representation.
2//!
3//! Layout:
4//! ```text
5//! [schema_version: u16 LE]
6//! [null_bitmap: ceil(N/8) bytes, bit=1 means NULL]
7//! [fixed_fields: concatenated, zeroed when null]
8//! [offset_table: (N_var + 1) × u32 LE]
9//! [variable_data: concatenated variable-length bytes]
10//! ```
11
12use nodedb_types::columnar::{ColumnType, StrictSchema};
13use nodedb_types::value::Value;
14
15use crate::error::StrictError;
16
17/// Encodes rows into Binary Tuples according to a fixed schema.
18///
19/// Reusable: create once per schema, encode many rows. Internal buffers
20/// are reused across calls to minimize allocation.
21pub struct TupleEncoder {
22    schema: StrictSchema,
23    /// Precomputed: byte offset of each fixed-size column within the fixed section.
24    /// Variable-length columns get `None`.
25    fixed_offsets: Vec<Option<usize>>,
26    /// Total size of the fixed-fields section.
27    fixed_section_size: usize,
28    /// Indices of variable-length columns in schema order.
29    var_indices: Vec<usize>,
30    /// Size of the tuple header: 2 (version) + null_bitmap_size.
31    header_size: usize,
32}
33
34impl TupleEncoder {
35    /// Create an encoder for the given schema.
36    pub fn new(schema: &StrictSchema) -> Self {
37        let mut fixed_offsets = Vec::with_capacity(schema.columns.len());
38        let mut var_indices = Vec::new();
39        let mut fixed_offset = 0usize;
40
41        for (i, col) in schema.columns.iter().enumerate() {
42            if let Some(size) = col.column_type.fixed_size() {
43                fixed_offsets.push(Some(fixed_offset));
44                fixed_offset += size;
45            } else {
46                fixed_offsets.push(None);
47                var_indices.push(i);
48            }
49        }
50
51        let header_size = 2 + schema.null_bitmap_size();
52
53        Self {
54            schema: schema.clone(),
55            fixed_offsets,
56            fixed_section_size: fixed_offset,
57            var_indices,
58            header_size,
59        }
60    }
61
62    /// Encode a row of values into a Binary Tuple.
63    ///
64    /// `values` must have exactly `schema.len()` entries. A `Value::Null` is
65    /// allowed only if the corresponding column is nullable.
66    pub fn encode(&self, values: &[Value]) -> Result<Vec<u8>, StrictError> {
67        let n_cols = self.schema.columns.len();
68        if values.len() != n_cols {
69            return Err(StrictError::ValueCountMismatch {
70                expected: n_cols,
71                got: values.len(),
72            });
73        }
74
75        // Pre-size: header + fixed + offset_table. Variable data appended later.
76        let offset_table_size = (self.var_indices.len() + 1) * 4;
77        let base_size = self.header_size + self.fixed_section_size + offset_table_size;
78        let mut buf = vec![0u8; base_size];
79
80        // 1. Schema version.
81        buf[0..2].copy_from_slice(&self.schema.version.to_le_bytes());
82
83        // 2. Null bitmap + fixed fields + type validation.
84        let bitmap_start = 2;
85        let fixed_start = self.header_size;
86
87        for (i, (col, val)) in self.schema.columns.iter().zip(values.iter()).enumerate() {
88            let is_null = matches!(val, Value::Null);
89
90            if is_null {
91                if !col.nullable {
92                    return Err(StrictError::NullViolation(col.name.clone()));
93                }
94                // Set null bit: byte = i / 8, bit = i % 8.
95                buf[bitmap_start + i / 8] |= 1 << (i % 8);
96                // Fixed fields remain zeroed; no variable data emitted.
97                continue;
98            }
99
100            // Type check (with coercion).
101            if !col.column_type.accepts(val) {
102                return Err(StrictError::TypeMismatch {
103                    column: col.name.clone(),
104                    expected: col.column_type.clone(),
105                });
106            }
107
108            // Write fixed-size value.
109            if let Some(offset) = self.fixed_offsets[i] {
110                let dst = fixed_start + offset;
111                encode_fixed(&mut buf[dst..], &col.column_type, val);
112            }
113            // Variable-length values are handled in the offset table pass below.
114        }
115
116        // 3. Variable-length fields: build offset table + variable data.
117        let offset_table_start = self.header_size + self.fixed_section_size;
118        let mut var_data: Vec<u8> = Vec::new();
119
120        for (var_idx, &col_idx) in self.var_indices.iter().enumerate() {
121            // Write current offset.
122            let offset = var_data.len() as u32;
123            let table_pos = offset_table_start + var_idx * 4;
124            buf[table_pos..table_pos + 4].copy_from_slice(&offset.to_le_bytes());
125
126            let val = &values[col_idx];
127            if !matches!(val, Value::Null) {
128                encode_variable(
129                    &mut var_data,
130                    &self.schema.columns[col_idx].column_type,
131                    val,
132                );
133            }
134            // If null: offset stays the same as next entry → zero length.
135        }
136
137        // Final sentinel offset (marks end of last variable field).
138        let sentinel = var_data.len() as u32;
139        let sentinel_pos = offset_table_start + self.var_indices.len() * 4;
140        buf[sentinel_pos..sentinel_pos + 4].copy_from_slice(&sentinel.to_le_bytes());
141
142        // 4. Append variable data.
143        buf.extend_from_slice(&var_data);
144
145        Ok(buf)
146    }
147
148    /// Access the schema this encoder was built for.
149    pub fn schema(&self) -> &StrictSchema {
150        &self.schema
151    }
152}
153
154/// Encode a fixed-size value into the buffer at the given position.
155///
156/// Handles both native Value types and SQL coercion sources.
157fn encode_fixed(dst: &mut [u8], col_type: &ColumnType, value: &Value) {
158    match (col_type, value) {
159        // Int64: native.
160        (ColumnType::Int64, Value::Integer(v)) => {
161            dst[..8].copy_from_slice(&v.to_le_bytes());
162        }
163        // Float64: native + Int64→Float64 coercion.
164        (ColumnType::Float64, Value::Float(v)) => {
165            dst[..8].copy_from_slice(&v.to_le_bytes());
166        }
167        (ColumnType::Float64, Value::Integer(v)) => {
168            dst[..8].copy_from_slice(&(*v as f64).to_le_bytes());
169        }
170        // Bool: native.
171        (ColumnType::Bool, Value::Bool(v)) => {
172            dst[0] = *v as u8;
173        }
174        // Timestamp: native DateTime + Integer (micros) + String (ISO 8601 parse).
175        (ColumnType::Timestamp, Value::DateTime(dt)) => {
176            dst[..8].copy_from_slice(&dt.micros.to_le_bytes());
177        }
178        (ColumnType::Timestamp, Value::Integer(micros)) => {
179            dst[..8].copy_from_slice(&micros.to_le_bytes());
180        }
181        (ColumnType::Timestamp, Value::String(s)) => {
182            let micros = nodedb_types::NdbDateTime::parse(s)
183                .map(|dt| dt.micros)
184                .unwrap_or(0);
185            dst[..8].copy_from_slice(&micros.to_le_bytes());
186        }
187        // Decimal: native Decimal + String/Float/Integer coercion.
188        (ColumnType::Decimal, Value::Decimal(d)) => {
189            dst[..16].copy_from_slice(&d.serialize());
190        }
191        (ColumnType::Decimal, Value::String(s)) => {
192            let d: rust_decimal::Decimal = s.parse().unwrap_or_default();
193            dst[..16].copy_from_slice(&d.serialize());
194        }
195        (ColumnType::Decimal, Value::Float(f)) => {
196            let d = rust_decimal::Decimal::try_from(*f).unwrap_or_default();
197            dst[..16].copy_from_slice(&d.serialize());
198        }
199        (ColumnType::Decimal, Value::Integer(i)) => {
200            let d = rust_decimal::Decimal::from(*i);
201            dst[..16].copy_from_slice(&d.serialize());
202        }
203        // Uuid: native Uuid string + String coercion.
204        (ColumnType::Uuid, Value::Uuid(s) | Value::String(s)) => {
205            if let Ok(parsed) = uuid::Uuid::parse_str(s) {
206                dst[..16].copy_from_slice(parsed.as_bytes());
207            }
208        }
209        // Vector: Array of floats + Bytes (packed f32).
210        (ColumnType::Vector(dim), Value::Array(arr)) => {
211            let d = *dim as usize;
212            for (i, v) in arr.iter().take(d).enumerate() {
213                let f = match v {
214                    Value::Float(f) => *f as f32,
215                    Value::Integer(n) => *n as f32,
216                    _ => 0.0,
217                };
218                dst[i * 4..(i + 1) * 4].copy_from_slice(&f.to_le_bytes());
219            }
220        }
221        (ColumnType::Vector(dim), Value::Bytes(b)) => {
222            let byte_len = (*dim as usize) * 4;
223            let copy_len = b.len().min(byte_len);
224            dst[..copy_len].copy_from_slice(&b[..copy_len]);
225        }
226        _ => {} // Type mismatch caught earlier by accepts().
227    }
228}
229
230/// Encode a variable-length value, appending to the data buffer.
231///
232/// Handles both native Value types and SQL coercion sources.
233fn encode_variable(var_data: &mut Vec<u8>, col_type: &ColumnType, value: &Value) {
234    match (col_type, value) {
235        (ColumnType::String, Value::String(s)) => {
236            var_data.extend_from_slice(s.as_bytes());
237        }
238        (ColumnType::Bytes, Value::Bytes(b)) => {
239            var_data.extend_from_slice(b);
240        }
241        // Geometry: native Geometry (JSON-serialized) + String (WKT/GeoJSON passthrough).
242        (ColumnType::Geometry, Value::Geometry(g)) => {
243            if let Ok(json) = sonic_rs::to_vec(g) {
244                var_data.extend_from_slice(&json);
245            }
246        }
247        (ColumnType::Geometry, Value::String(s)) => {
248            var_data.extend_from_slice(s.as_bytes());
249        }
250        (ColumnType::Json, Value::String(s)) => {
251            // String input for JSON column: parse as JSON, then serialize as MessagePack.
252            // This handles VALUES ('{"key":"val"}') where the SQL planner passes a string literal.
253            let parsed = sonic_rs::from_str::<serde_json::Value>(s)
254                .ok()
255                .map(nodedb_types::Value::from);
256            let to_encode = parsed.as_ref().unwrap_or(value);
257            if let Ok(bytes) = nodedb_types::value_to_msgpack(to_encode) {
258                var_data.extend_from_slice(&bytes);
259            }
260        }
261        (ColumnType::Json, value) => {
262            // Non-string input (Object, Array, etc.): serialize directly as MessagePack.
263            if let Ok(bytes) = nodedb_types::value_to_msgpack(value) {
264                var_data.extend_from_slice(&bytes);
265            }
266        }
267        _ => {}
268    }
269}
270
#[cfg(test)]
mod tests {
    use nodedb_types::columnar::ColumnDef;
    use nodedb_types::datetime::NdbDateTime;

    use super::*;

    /// Five-column schema mixing fixed (Int64, Decimal, Bool) and variable
    /// (String) columns, two of them nullable.
    fn crm_schema() -> StrictSchema {
        StrictSchema::new(vec![
            ColumnDef::required("id", ColumnType::Int64).with_primary_key(),
            ColumnDef::required("name", ColumnType::String),
            ColumnDef::nullable("email", ColumnType::String),
            ColumnDef::required("balance", ColumnType::Decimal),
            ColumnDef::nullable("active", ColumnType::Bool),
        ])
        .unwrap()
    }

    /// Reads a little-endian u32 starting at byte `at`.
    fn read_u32(tuple: &[u8], at: usize) -> u32 {
        u32::from_le_bytes(tuple[at..at + 4].try_into().unwrap())
    }

    #[test]
    fn encode_basic_row() {
        let schema = crm_schema();
        let encoder = TupleEncoder::new(&schema);

        let balance = rust_decimal::Decimal::new(5000, 2);
        let values = vec![
            Value::Integer(42),
            Value::String("Alice".into()),
            Value::String("alice@example.com".into()),
            Value::Decimal(balance),
            Value::Bool(true),
        ];

        let tuple = encoder.encode(&values).unwrap();

        // Header: 2 (version) + 1 (null bitmap for 5 cols) = 3 bytes
        assert_eq!(tuple[0], 1); // schema version 1
        assert_eq!(tuple[1], 0); // version high byte
        assert_eq!(tuple[2], 0); // null bitmap: no nulls

        // Fixed section: Int64(8) + Decimal(16) + Bool(1) = 25 bytes
        // Starting at offset 3
        let id_bytes = &tuple[3..11];
        assert_eq!(i64::from_le_bytes(id_bytes.try_into().unwrap()), 42);
        assert_eq!(&tuple[11..27], &balance.serialize()[..]);
        assert_eq!(tuple[27], 1); // active = true

        // Offset table: 2 variable columns + sentinel = 3 × u32, at 28..40.
        assert_eq!(read_u32(&tuple, 28), 0); // name starts at 0
        assert_eq!(read_u32(&tuple, 32), 5); // email starts after "Alice"
        assert_eq!(read_u32(&tuple, 36), 22); // sentinel: 5 + 17

        // Variable data: name and email back to back.
        assert_eq!(&tuple[40..], &b"Alicealice@example.com"[..]);
    }

    #[test]
    fn encode_with_nulls() {
        let schema = crm_schema();
        let encoder = TupleEncoder::new(&schema);

        let values = vec![
            Value::Integer(1),
            Value::String("Bob".into()),
            Value::Null, // email is nullable
            Value::Decimal(rust_decimal::Decimal::ZERO),
            Value::Null, // active is nullable
        ];

        let tuple = encoder.encode(&values).unwrap();

        // Null bitmap: bit 2 (email) and bit 4 (active) set.
        // Bit 2 = 0b00000100 = 4, bit 4 = 0b00010000 = 16. Combined = 20.
        assert_eq!(tuple[2], 0b00010100);

        // A null fixed field stays zeroed.
        assert_eq!(tuple[27], 0);

        // A null variable field has zero length: its offset equals the next one.
        assert_eq!(read_u32(&tuple, 28), 0); // name
        assert_eq!(read_u32(&tuple, 32), 3); // email (null)
        assert_eq!(read_u32(&tuple, 36), 3); // sentinel
        assert_eq!(&tuple[40..], &b"Bob"[..]);
    }

    #[test]
    fn encode_null_violation() {
        let schema = crm_schema();
        let encoder = TupleEncoder::new(&schema);

        let values = vec![
            Value::Null, // id is NOT NULL
            Value::String("x".into()),
            Value::Null,
            Value::Decimal(rust_decimal::Decimal::ZERO),
            Value::Null,
        ];

        let err = encoder.encode(&values).unwrap_err();
        assert!(matches!(err, StrictError::NullViolation(ref s) if s == "id"));
    }

    #[test]
    fn encode_type_mismatch() {
        let schema = crm_schema();
        let encoder = TupleEncoder::new(&schema);

        let values = vec![
            Value::String("not_an_int".into()), // id expects Int64
            Value::String("x".into()),
            Value::Null,
            Value::Decimal(rust_decimal::Decimal::ZERO),
            Value::Null,
        ];

        let err = encoder.encode(&values).unwrap_err();
        assert!(matches!(err, StrictError::TypeMismatch { .. }));
    }

    #[test]
    fn encode_value_count_mismatch() {
        let schema = crm_schema();
        let encoder = TupleEncoder::new(&schema);

        let err = encoder.encode(&[Value::Integer(1)]).unwrap_err();
        assert!(matches!(err, StrictError::ValueCountMismatch { .. }));
    }

    #[test]
    fn encode_int_to_float_coercion() {
        let schema =
            StrictSchema::new(vec![ColumnDef::required("val", ColumnType::Float64)]).unwrap();
        let encoder = TupleEncoder::new(&schema);

        // Int64 → Float64 coercion should work.
        let tuple = encoder.encode(&[Value::Integer(42)]).unwrap();
        // Header: 2 (version) + 1 (bitmap) = 3. Fixed: 8 bytes Float64.
        let f = f64::from_le_bytes(tuple[3..11].try_into().unwrap());
        assert_eq!(f, 42.0);
    }

    #[test]
    fn encode_timestamp() {
        let schema =
            StrictSchema::new(vec![ColumnDef::required("ts", ColumnType::Timestamp)]).unwrap();
        let encoder = TupleEncoder::new(&schema);

        let dt = NdbDateTime::from_micros(1_700_000_000_000_000);
        let tuple = encoder.encode(&[Value::DateTime(dt)]).unwrap();
        let micros = i64::from_le_bytes(tuple[3..11].try_into().unwrap());
        assert_eq!(micros, 1_700_000_000_000_000);
    }

    #[test]
    fn encode_decode_json_column() {
        let schema = StrictSchema::new(vec![
            ColumnDef::required("id", ColumnType::Int64).with_primary_key(),
            ColumnDef::nullable("metadata", ColumnType::Json),
        ])
        .unwrap();
        let encoder = TupleEncoder::new(&schema);

        let metadata = Value::Object(std::collections::HashMap::from([
            ("source".to_string(), Value::String("web".to_string())),
            ("priority".to_string(), Value::Integer(3)),
        ]));
        let values = vec![Value::Integer(1), metadata.clone()];
        let tuple = encoder.encode(&values).unwrap();

        // Tuple must be longer than just the header + fixed section.
        // Header: 2 (version) + 1 (bitmap) = 3. Fixed: 8 (Int64). Offset table: 8 (2 entries × u32).
        // Variable data must be non-empty (MessagePack of the object).
        let min_size = 3 + 8 + 8;
        assert!(tuple.len() > min_size, "tuple should contain variable data");

        // Decode and verify the value roundtrips correctly.
        let decoder = crate::decode::TupleDecoder::new(&schema);
        let decoded = decoder.extract_all(&tuple).unwrap();
        assert_eq!(decoded[0], Value::Integer(1));
        assert_eq!(decoded[1], metadata);
    }

    #[test]
    fn encode_json_null() {
        let schema = StrictSchema::new(vec![
            ColumnDef::required("id", ColumnType::Int64).with_primary_key(),
            ColumnDef::nullable("data", ColumnType::Json),
        ])
        .unwrap();
        let encoder = TupleEncoder::new(&schema);
        let tuple = encoder.encode(&[Value::Integer(1), Value::Null]).unwrap();
        // Null bitmap byte (index 2): bit 1 (column 1) should be set → 0b00000010 = 2.
        assert_eq!(tuple[2] & 0b10, 0b10);
    }

    #[test]
    fn encode_vector() {
        let schema =
            StrictSchema::new(vec![ColumnDef::required("emb", ColumnType::Vector(3))]).unwrap();
        let encoder = TupleEncoder::new(&schema);

        let vals = vec![Value::Array(vec![
            Value::Float(1.0),
            Value::Float(2.0),
            Value::Float(3.0),
        ])];
        let tuple = encoder.encode(&vals).unwrap();
        // Header: 3 bytes. Fixed: 12 bytes (3 × f32).
        let f0 = f32::from_le_bytes(tuple[3..7].try_into().unwrap());
        let f1 = f32::from_le_bytes(tuple[7..11].try_into().unwrap());
        let f2 = f32::from_le_bytes(tuple[11..15].try_into().unwrap());
        assert_eq!((f0, f1, f2), (1.0, 2.0, 3.0));
    }
}