alopex_sql/storage/codec.rs

1use std::convert::TryFrom;
2
3use crate::planner::ResolvedType;
4
5use super::error::{Result, StorageError};
6use super::value::SqlValue;
7
/// Largest `Text` or `Blob` payload, in bytes, accepted by the row decoder (16 MiB).
const MAX_INLINE_BYTES: usize = 1 << 24;
/// Largest `Vector` length, in `f32` elements, accepted by the row decoder
/// (4 Mi elements, i.e. 16 MiB of payload).
const MAX_VECTOR_LEN: usize = 1 << 22;
10
/// Converts between a row (`[SqlValue]`) and its compact binary encoding.
///
/// Layout:
/// ```text
/// [column_count: u16 LE]
/// [null_bitmap: ceil(column_count / 8) bytes]   // bit set => column is NULL
/// for each non-NULL column, in column order:
///     [type_tag: u8]
///     [payload: shape depends on type_tag]
/// ```
///
/// The NULL flag for column `i` is bit `i % 8` of bitmap byte `i / 8`
/// (least-significant bit first). NULL columns have no tag and no payload.
///
/// Payloads (all multi-byte integers little-endian):
///
/// | tag    | type        | payload                                  |
/// |--------|-------------|------------------------------------------|
/// | `0x01` | `Integer`   | `i32` (4 bytes)                          |
/// | `0x02` | `BigInt`    | `i64` (8 bytes)                          |
/// | `0x03` | `Float`     | `f32` bit pattern (4 bytes)              |
/// | `0x04` | `Double`    | `f64` bit pattern (8 bytes)              |
/// | `0x05` | `Text`      | `u32` byte length, then UTF-8 bytes      |
/// | `0x06` | `Blob`      | `u32` byte length, then raw bytes        |
/// | `0x07` | `Boolean`   | one byte, `0` or `1`                     |
/// | `0x08` | `Timestamp` | `i64` (8 bytes)                          |
/// | `0x09` | `Vector`    | `u32` element count, then `f32` bits     |
///
/// Only the variable-length types carry a length prefix, so this is a
/// tag-value encoding rather than strict TLV.
pub struct RowCodec;
22
23impl RowCodec {
24    /// Encode a row into binary form.
25    pub fn encode(row: &[SqlValue]) -> Vec<u8> {
26        let column_count =
27            u16::try_from(row.len()).expect("row column count exceeds u16::MAX (design limit)");
28        let null_bytes = (column_count as usize).div_ceil(8);
29
30        // Pre-allocate roughly: header + bitmap + average 8 bytes per column.
31        let mut buf = Vec::with_capacity(2 + null_bytes + row.len() * 8);
32        buf.extend_from_slice(&column_count.to_le_bytes());
33
34        let mut null_bitmap = vec![0u8; null_bytes];
35        for (idx, val) in row.iter().enumerate() {
36            if val.is_null() {
37                null_bitmap[idx / 8] |= 1 << (idx % 8);
38            }
39        }
40        buf.extend_from_slice(&null_bitmap);
41
42        for value in row {
43            if value.is_null() {
44                continue;
45            }
46            buf.push(value.type_tag());
47            encode_value(value, &mut buf);
48        }
49
50        buf
51    }
52
53    /// Decode a row from binary form.
54    pub fn decode(bytes: &[u8]) -> Result<Vec<SqlValue>> {
55        let mut cursor = 0;
56        if bytes.len() < 2 {
57            return Err(StorageError::CorruptedData {
58                reason: "missing column count".into(),
59            });
60        }
61
62        let column_count =
63            u16::from_le_bytes(bytes[cursor..cursor + 2].try_into().unwrap()) as usize;
64        cursor += 2;
65
66        let null_bytes = column_count.div_ceil(8);
67        if bytes.len() < cursor + null_bytes {
68            return Err(StorageError::CorruptedData {
69                reason: "missing null bitmap".into(),
70            });
71        }
72        let null_bitmap = &bytes[cursor..cursor + null_bytes];
73        cursor += null_bytes;
74
75        let mut values = Vec::with_capacity(column_count);
76        for idx in 0..column_count {
77            let is_null = (null_bitmap[idx / 8] & (1 << (idx % 8))) != 0;
78            if is_null {
79                values.push(SqlValue::Null);
80                continue;
81            }
82
83            if cursor >= bytes.len() {
84                return Err(StorageError::CorruptedData {
85                    reason: "missing type tag".into(),
86                });
87            }
88            let tag = bytes[cursor];
89            cursor += 1;
90
91            let value = decode_value(tag, bytes, &mut cursor)?;
92            values.push(value);
93        }
94
95        if cursor != bytes.len() {
96            return Err(StorageError::CorruptedData {
97                reason: "trailing bytes after decoding row".into(),
98            });
99        }
100
101        Ok(values)
102    }
103
104    /// Decode with schema validation.
105    pub fn decode_with_schema(bytes: &[u8], schema: &[ResolvedType]) -> Result<Vec<SqlValue>> {
106        let values = Self::decode(bytes)?;
107
108        if values.len() != schema.len() {
109            return Err(StorageError::CorruptedData {
110                reason: format!(
111                    "column count mismatch: encoded={}, expected={}",
112                    values.len(),
113                    schema.len()
114                ),
115            });
116        }
117
118        values
119            .into_iter()
120            .zip(schema.iter())
121            .map(|(value, ty)| ensure_type(value, ty))
122            .collect()
123    }
124}
125
126fn encode_value(value: &SqlValue, buf: &mut Vec<u8>) {
127    match value {
128        SqlValue::Null => {}
129        SqlValue::Integer(v) => buf.extend_from_slice(&v.to_le_bytes()),
130        SqlValue::BigInt(v) => buf.extend_from_slice(&v.to_le_bytes()),
131        SqlValue::Float(v) => buf.extend_from_slice(&v.to_bits().to_le_bytes()),
132        SqlValue::Double(v) => buf.extend_from_slice(&v.to_bits().to_le_bytes()),
133        SqlValue::Text(s) => {
134            let len = u32::try_from(s.len())
135                .expect("text length exceeds u32::MAX (design limit for row encoding)");
136            buf.extend_from_slice(&len.to_le_bytes());
137            buf.extend_from_slice(s.as_bytes());
138        }
139        SqlValue::Blob(bytes) => {
140            let len = u32::try_from(bytes.len())
141                .expect("blob length exceeds u32::MAX (design limit for row encoding)");
142            buf.extend_from_slice(&len.to_le_bytes());
143            buf.extend_from_slice(bytes);
144        }
145        SqlValue::Boolean(b) => buf.push(u8::from(*b)),
146        SqlValue::Timestamp(v) => buf.extend_from_slice(&v.to_le_bytes()),
147        SqlValue::Vector(values) => {
148            let len = u32::try_from(values.len())
149                .expect("vector length exceeds u32::MAX (design limit for row encoding)");
150            buf.extend_from_slice(&len.to_le_bytes());
151            for f in values {
152                buf.extend_from_slice(&f.to_bits().to_le_bytes());
153            }
154        }
155    }
156}
157
158fn decode_value(tag: u8, bytes: &[u8], cursor: &mut usize) -> Result<SqlValue> {
159    let mut take = |len: usize, reason: &'static str| -> Result<&[u8]> {
160        let end = cursor
161            .checked_add(len)
162            .ok_or_else(|| StorageError::CorruptedData {
163                reason: reason.to_string(),
164            })?;
165        if end > bytes.len() {
166            return Err(StorageError::CorruptedData {
167                reason: reason.to_string(),
168            });
169        }
170        let slice = &bytes[*cursor..end];
171        *cursor = end;
172        Ok(slice)
173    };
174
175    match tag {
176        0x00 => Ok(SqlValue::Null),
177        0x01 => {
178            let raw = take(4, "truncated Integer value")?;
179            Ok(SqlValue::Integer(i32::from_le_bytes(
180                raw.try_into().unwrap(),
181            )))
182        }
183        0x02 => {
184            let raw = take(8, "truncated BigInt value")?;
185            Ok(SqlValue::BigInt(i64::from_le_bytes(
186                raw.try_into().unwrap(),
187            )))
188        }
189        0x03 => {
190            let raw = take(4, "truncated Float value")?;
191            Ok(SqlValue::Float(f32::from_bits(u32::from_le_bytes(
192                raw.try_into().unwrap(),
193            ))))
194        }
195        0x04 => {
196            let raw = take(8, "truncated Double value")?;
197            Ok(SqlValue::Double(f64::from_bits(u64::from_le_bytes(
198                raw.try_into().unwrap(),
199            ))))
200        }
201        0x05 => {
202            let len_bytes = take(4, "truncated Text length")?;
203            let len = u32::from_le_bytes(len_bytes.try_into().unwrap()) as usize;
204            if len > MAX_INLINE_BYTES {
205                return Err(StorageError::CorruptedData {
206                    reason: format!("text length exceeds limit: {len}"),
207                });
208            }
209            let raw = take(len, "truncated Text payload")?;
210            let s = String::from_utf8(raw.to_vec()).map_err(|_| StorageError::CorruptedData {
211                reason: "invalid UTF-8 in Text".into(),
212            })?;
213            Ok(SqlValue::Text(s))
214        }
215        0x06 => {
216            let len_bytes = take(4, "truncated Blob length")?;
217            let len = u32::from_le_bytes(len_bytes.try_into().unwrap()) as usize;
218            if len > MAX_INLINE_BYTES {
219                return Err(StorageError::CorruptedData {
220                    reason: format!("blob length exceeds limit: {len}"),
221                });
222            }
223            let raw = take(len, "truncated Blob payload")?;
224            Ok(SqlValue::Blob(raw.to_vec()))
225        }
226        0x07 => {
227            let raw = take(1, "truncated Boolean")?[0];
228            match raw {
229                0 => Ok(SqlValue::Boolean(false)),
230                1 => Ok(SqlValue::Boolean(true)),
231                other => Err(StorageError::CorruptedData {
232                    reason: format!("invalid boolean value: {}", other),
233                }),
234            }
235        }
236        0x08 => {
237            let raw = take(8, "truncated Timestamp value")?;
238            Ok(SqlValue::Timestamp(i64::from_le_bytes(
239                raw.try_into().unwrap(),
240            )))
241        }
242        0x09 => {
243            let len_bytes = take(4, "truncated Vector length")?;
244            let len = u32::from_le_bytes(len_bytes.try_into().unwrap()) as usize;
245            if len > MAX_VECTOR_LEN {
246                return Err(StorageError::CorruptedData {
247                    reason: format!("vector length exceeds limit: {len}"),
248                });
249            }
250            let total = len
251                .checked_mul(4)
252                .ok_or_else(|| StorageError::CorruptedData {
253                    reason: "vector length overflow".into(),
254                })?;
255            let raw = take(total, "truncated Vector payload")?;
256
257            let mut values = Vec::with_capacity(len);
258            for chunk in raw.chunks_exact(4) {
259                values.push(f32::from_bits(u32::from_le_bytes(
260                    chunk.try_into().unwrap(),
261                )));
262            }
263            Ok(SqlValue::Vector(values))
264        }
265        other => Err(StorageError::CorruptedData {
266            reason: format!("unknown type tag: 0x{other:02x}"),
267        }),
268    }
269}
270
271fn ensure_type(value: SqlValue, expected: &ResolvedType) -> Result<SqlValue> {
272    use ResolvedType::*;
273    match (expected, value) {
274        (_, SqlValue::Null) => Ok(SqlValue::Null),
275        (Integer, SqlValue::Integer(v)) => Ok(SqlValue::Integer(v)),
276        (BigInt, SqlValue::BigInt(v)) => Ok(SqlValue::BigInt(v)),
277        (Float, SqlValue::Float(v)) => Ok(SqlValue::Float(v)),
278        (Double, SqlValue::Double(v)) => Ok(SqlValue::Double(v)),
279        (Text, SqlValue::Text(s)) => Ok(SqlValue::Text(s)),
280        (Blob, SqlValue::Blob(b)) => Ok(SqlValue::Blob(b)),
281        (Boolean, SqlValue::Boolean(v)) => Ok(SqlValue::Boolean(v)),
282        (Timestamp, SqlValue::Timestamp(v)) => Ok(SqlValue::Timestamp(v)),
283        (Vector { dimension, .. }, SqlValue::Vector(values)) => {
284            if values.len() as u32 == *dimension {
285                Ok(SqlValue::Vector(values))
286            } else {
287                Err(StorageError::TypeMismatch {
288                    expected: format!("Vector(dim={})", dimension),
289                    actual: format!("Vector(dim={})", values.len()),
290                })
291            }
292        }
293        (expected_ty, actual) => Err(StorageError::TypeMismatch {
294            expected: expected_ty.type_name().to_string(),
295            actual: actual.type_name().to_string(),
296        }),
297    }
298}
299
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Compare two values, treating floats (and vector elements) as equal only
    /// when their bit patterns match, so NaN round-trips compare equal.
    fn values_equal(a: &SqlValue, b: &SqlValue) -> bool {
        match (a, b) {
            (SqlValue::Float(x), SqlValue::Float(y)) => x.to_bits() == y.to_bits(),
            (SqlValue::Double(x), SqlValue::Double(y)) => x.to_bits() == y.to_bits(),
            (SqlValue::Vector(xs), SqlValue::Vector(ys)) => {
                xs.len() == ys.len()
                    && xs
                        .iter()
                        .zip(ys.iter())
                        .all(|(x, y)| x.to_bits() == y.to_bits())
            }
            _ => a == b,
        }
    }

    fn row_equal(a: &[SqlValue], b: &[SqlValue]) -> bool {
        a.len() == b.len()
            && a.iter()
                .zip(b.iter())
                .all(|(lhs, rhs)| values_equal(lhs, rhs))
    }

    fn sql_value_strategy() -> impl Strategy<Value = SqlValue> {
        let finite_f32 = any::<f32>();
        let finite_f64 = any::<f64>();
        prop_oneof![
            Just(SqlValue::Null),
            any::<i32>().prop_map(SqlValue::Integer),
            any::<i64>().prop_map(SqlValue::BigInt),
            finite_f32.prop_map(SqlValue::Float),
            finite_f64.prop_map(SqlValue::Double),
            ".*".prop_map(SqlValue::Text),
            proptest::collection::vec(any::<u8>(), 0..32).prop_map(SqlValue::Blob),
            any::<bool>().prop_map(SqlValue::Boolean),
            any::<i64>().prop_map(SqlValue::Timestamp),
            proptest::collection::vec(any::<f32>(), 0..8).prop_map(SqlValue::Vector),
        ]
    }

    /// Assert that decoding `bytes` fails with `CorruptedData`.
    fn assert_corrupted(bytes: &[u8]) {
        let err = RowCodec::decode(bytes).unwrap_err();
        assert!(matches!(err, StorageError::CorruptedData { .. }));
    }

    #[test]
    fn roundtrip_preserves_all_types() {
        let row = vec![
            SqlValue::Null,
            SqlValue::Integer(42),
            SqlValue::BigInt(-42),
            SqlValue::Float(1.5),
            SqlValue::Double(-2.5),
            SqlValue::Text("hello".into()),
            SqlValue::Blob(vec![0x01, 0x02]),
            SqlValue::Boolean(true),
            SqlValue::Timestamp(1_700_000_000),
            SqlValue::Vector(vec![0.1, 0.2, 0.3]),
        ];

        let encoded = RowCodec::encode(&row);
        let decoded = RowCodec::decode(&encoded).unwrap();

        assert!(row_equal(&row, &decoded));
    }

    #[test]
    fn null_bitmap_is_respected() {
        let row = vec![SqlValue::Integer(1), SqlValue::Null, SqlValue::Integer(2)];
        let encoded = RowCodec::encode(&row);
        let decoded = RowCodec::decode(&encoded).unwrap();
        assert!(matches!(decoded[1], SqlValue::Null));
    }

    #[test]
    fn empty_row_roundtrips() {
        let encoded = RowCodec::encode(&[]);
        assert_eq!(encoded, vec![0, 0]);
        assert!(RowCodec::decode(&encoded).unwrap().is_empty());
    }

    #[test]
    fn encoded_layout_matches_format() {
        let row = vec![SqlValue::Integer(1), SqlValue::Null];
        let encoded = RowCodec::encode(&row);
        // column_count=2, bitmap marks column 1 NULL, then Integer tag + i32 LE.
        assert_eq!(encoded, vec![2, 0, 0b0000_0010, 0x01, 1, 0, 0, 0]);
    }

    #[test]
    fn corruption_is_detected_for_truncated_payload() {
        let row = vec![SqlValue::Text("abc".into())];
        let mut encoded = RowCodec::encode(&row);
        encoded.pop(); // truncate
        let err = RowCodec::decode(&encoded).unwrap_err();
        assert!(matches!(err, StorageError::CorruptedData { .. }));
    }

    #[test]
    fn corruption_is_detected_for_unknown_tag() {
        // column_count=1, null_bitmap=0, tag=0xFF (invalid)
        let bytes = vec![1, 0, 0, 0xFF];
        let err = RowCodec::decode(&bytes).unwrap_err();
        assert!(matches!(err, StorageError::CorruptedData { .. }));
    }

    #[test]
    fn corruption_is_detected_for_trailing_bytes() {
        let mut encoded = RowCodec::encode(&[SqlValue::Integer(7)]);
        encoded.push(0);
        assert_corrupted(&encoded);
    }

    #[test]
    fn corruption_is_detected_for_missing_null_bitmap() {
        // column_count=9 requires two bitmap bytes; none are present.
        assert_corrupted(&[9, 0]);
    }

    #[test]
    fn corruption_is_detected_for_invalid_boolean() {
        // column_count=1, null_bitmap=0, Boolean tag, payload byte 2 (not 0/1)
        assert_corrupted(&[1, 0, 0, 0x07, 2]);
    }

    #[test]
    fn corruption_is_detected_for_invalid_utf8_text() {
        // column_count=1, null_bitmap=0, Text tag, len=1, payload 0xFF
        assert_corrupted(&[1, 0, 0, 0x05, 1, 0, 0, 0, 0xFF]);
    }

    #[test]
    fn oversized_lengths_are_rejected() {
        // Text length = MAX_INLINE_BYTES + 1 with no payload.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&(1u16).to_le_bytes()); // column count
        bytes.push(0); // null bitmap
        bytes.push(0x05); // text tag
        let too_large = (super::MAX_INLINE_BYTES as u32) + 1;
        bytes.extend_from_slice(&too_large.to_le_bytes());
        let err = RowCodec::decode(&bytes).unwrap_err();
        assert!(matches!(err, StorageError::CorruptedData { .. }));
    }

    #[test]
    fn oversized_vector_is_rejected() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&(1u16).to_le_bytes()); // column count
        bytes.push(0); // null bitmap
        bytes.push(0x09); // vector tag
        let too_large = (super::MAX_VECTOR_LEN as u32) + 1;
        bytes.extend_from_slice(&too_large.to_le_bytes());
        let err = RowCodec::decode(&bytes).unwrap_err();
        assert!(matches!(err, StorageError::CorruptedData { .. }));
    }

    #[test]
    fn decode_with_schema_validates_types() {
        let row = vec![SqlValue::Vector(vec![1.0, 2.0])];
        let encoded = RowCodec::encode(&row);
        let schema = vec![ResolvedType::Vector {
            dimension: 3,
            metric: crate::ast::ddl::VectorMetric::Cosine,
        }];
        let err = RowCodec::decode_with_schema(&encoded, &schema).unwrap_err();
        assert!(matches!(err, StorageError::TypeMismatch { .. }));
    }

    #[test]
    fn decode_with_schema_rejects_column_count_mismatch() {
        let encoded = RowCodec::encode(&[SqlValue::Integer(1)]);
        let schema = vec![ResolvedType::Integer, ResolvedType::Integer];
        let err = RowCodec::decode_with_schema(&encoded, &schema).unwrap_err();
        assert!(matches!(err, StorageError::CorruptedData { .. }));
    }

    proptest! {
        #[test]
        fn proptest_roundtrip(row in proptest::collection::vec(sql_value_strategy(), 0..16)) {
            let encoded = RowCodec::encode(&row);
            let decoded = RowCodec::decode(&encoded).unwrap();
            prop_assert!(row_equal(&row, &decoded));
        }

        #[test]
        fn decode_with_schema_matches_lengths(row in proptest::collection::vec(sql_value_strategy(), 1..5)) {
            let schema: Vec<ResolvedType> = row.iter().map(|v| v.resolved_type()).collect();
            let encoded = RowCodec::encode(&row);
            let decoded = RowCodec::decode_with_schema(&encoded, &schema).unwrap();
            prop_assert!(row_equal(&row, &decoded));
        }
    }
}