Skip to main content

citadel_sql/
encoding.rs

//! Order-preserving key encoding (tuple layer) and compact row encoding
//! for storing non-PK column values.
3
4use crate::error::{Result, SqlError};
5use crate::types::{DataType, Value};
6
// ── Key encoding (order-preserving) ─────────────────────────────────

/// Type tag bytes for key encoding. Every encoded key value starts with one of
/// these bytes, so values of different types compare by tag first:
/// NULL < BLOB < TEXT < BOOLEAN < INTEGER < REAL.
///
/// Because INTEGER and REAL use distinct tags, every integer key sorts before
/// every real key regardless of numeric value.
const TAG_NULL: u8 = 0x00;
const TAG_BLOB: u8 = 0x01;
const TAG_TEXT: u8 = 0x02;
const TAG_BOOLEAN: u8 = 0x03;
const TAG_INTEGER: u8 = 0x04;
const TAG_REAL: u8 = 0x05;
16
17/// Encode a single value into an order-preserving byte sequence.
18pub fn encode_key_value(value: &Value) -> Vec<u8> {
19    match value {
20        Value::Null => vec![TAG_NULL],
21        Value::Boolean(b) => vec![TAG_BOOLEAN, if *b { 0x01 } else { 0x00 }],
22        Value::Integer(i) => encode_integer(*i),
23        Value::Real(r) => encode_real(*r),
24        Value::Text(s) => encode_bytes(TAG_TEXT, s.as_bytes()),
25        Value::Blob(b) => encode_bytes(TAG_BLOB, b),
26    }
27}
28
29/// Encode a composite key (multiple values concatenated).
30pub fn encode_composite_key(values: &[Value]) -> Vec<u8> {
31    let mut buf = Vec::new();
32    for v in values {
33        buf.extend_from_slice(&encode_key_value(v));
34    }
35    buf
36}
37
38/// Decode a single key value, returning the value and the number of bytes consumed.
39pub fn decode_key_value(data: &[u8]) -> Result<(Value, usize)> {
40    if data.is_empty() {
41        return Err(SqlError::InvalidValue("empty key data".into()));
42    }
43    match data[0] {
44        TAG_NULL => Ok((Value::Null, 1)),
45        TAG_BOOLEAN => {
46            if data.len() < 2 {
47                return Err(SqlError::InvalidValue("truncated boolean".into()));
48            }
49            Ok((Value::Boolean(data[1] != 0), 2))
50        }
51        TAG_INTEGER => decode_integer(&data[1..]).map(|(v, n)| (v, n + 1)),
52        TAG_REAL => decode_real(&data[1..]).map(|(v, n)| (v, n + 1)),
53        TAG_TEXT => {
54            let (bytes, n) = decode_null_escaped(&data[1..])?;
55            let s = String::from_utf8(bytes)
56                .map_err(|_| SqlError::InvalidValue("invalid UTF-8 in key".into()))?;
57            Ok((Value::Text(s), n + 1))
58        }
59        TAG_BLOB => {
60            let (bytes, n) = decode_null_escaped(&data[1..])?;
61            Ok((Value::Blob(bytes), n + 1))
62        }
63        tag => Err(SqlError::InvalidValue(format!("unknown key tag: {tag:#x}"))),
64    }
65}
66
67/// Decode a composite key into multiple values.
68pub fn decode_composite_key(data: &[u8], count: usize) -> Result<Vec<Value>> {
69    let mut values = Vec::with_capacity(count);
70    let mut pos = 0;
71    for _ in 0..count {
72        let (v, n) = decode_key_value(&data[pos..])?;
73        values.push(v);
74        pos += n;
75    }
76    Ok(values)
77}
78
79// ── Integer encoding (variable-width) ───────────────────────────────
80
81fn encode_integer(val: i64) -> Vec<u8> {
82    let mut buf = vec![TAG_INTEGER];
83    if val == 0 {
84        buf.push(0x80);
85        return buf;
86    }
87    if val > 0 {
88        let bytes = val.to_be_bytes();
89        // Find first non-zero byte
90        let start = bytes.iter().position(|&b| b != 0).unwrap();
91        let byte_count = (8 - start) as u8;
92        buf.push(0x80 + byte_count);
93        buf.extend_from_slice(&bytes[start..]);
94    } else {
95        // Negative: one's complement of absolute value
96        let abs_val = if val == i64::MIN {
97            // Special case: |i64::MIN| doesn't fit in i64
98            u64::MAX / 2 + 1
99        } else {
100            (-val) as u64
101        };
102        let bytes = abs_val.to_be_bytes();
103        let start = bytes.iter().position(|&b| b != 0).unwrap();
104        let byte_count = (8 - start) as u8;
105        buf.push(0x80 - byte_count);
106        // One's complement: invert all bits
107        for &b in &bytes[start..] {
108            buf.push(!b);
109        }
110    }
111    buf
112}
113
114fn decode_integer(data: &[u8]) -> Result<(Value, usize)> {
115    if data.is_empty() {
116        return Err(SqlError::InvalidValue("truncated integer".into()));
117    }
118    let marker = data[0];
119    if marker == 0x80 {
120        return Ok((Value::Integer(0), 1));
121    }
122    if marker > 0x80 {
123        // Positive
124        let byte_count = (marker - 0x80) as usize;
125        if data.len() < 1 + byte_count {
126            return Err(SqlError::InvalidValue("truncated positive integer".into()));
127        }
128        let mut bytes = [0u8; 8];
129        bytes[8 - byte_count..].copy_from_slice(&data[1..1 + byte_count]);
130        let val = i64::from_be_bytes(bytes);
131        Ok((Value::Integer(val), 1 + byte_count))
132    } else {
133        // Negative
134        let byte_count = (0x80 - marker) as usize;
135        if data.len() < 1 + byte_count {
136            return Err(SqlError::InvalidValue("truncated negative integer".into()));
137        }
138        let mut bytes = [0u8; 8];
139        for i in 0..byte_count {
140            bytes[8 - byte_count + i] = !data[1 + i];
141        }
142        let abs_val = u64::from_be_bytes(bytes);
143        // Use wrapping negation to handle i64::MIN correctly
144        let val = (-(abs_val as i128)) as i64;
145        Ok((Value::Integer(val), 1 + byte_count))
146    }
147}
148
149// ── Real encoding (IEEE 754 sign-bit manipulation) ──────────────────
150
151fn encode_real(val: f64) -> Vec<u8> {
152    let mut buf = vec![TAG_REAL];
153    let bits = val.to_bits();
154    let encoded = if val.is_sign_negative() {
155        // Negative (including -0.0): flip ALL bits
156        !bits
157    } else {
158        // Positive (including +0.0): flip sign bit only
159        bits ^ (1u64 << 63)
160    };
161    buf.extend_from_slice(&encoded.to_be_bytes());
162    buf
163}
164
165fn decode_real(data: &[u8]) -> Result<(Value, usize)> {
166    if data.len() < 8 {
167        return Err(SqlError::InvalidValue("truncated real".into()));
168    }
169    let encoded = u64::from_be_bytes(data[..8].try_into().unwrap());
170    let bits = if encoded & (1u64 << 63) != 0 {
171        // Was positive: undo sign bit flip
172        encoded ^ (1u64 << 63)
173    } else {
174        // Was negative: undo full inversion
175        !encoded
176    };
177    let val = f64::from_bits(bits);
178    Ok((Value::Real(val), 8))
179}
180
// ── Null-escaped byte encoding ──────────────────────────────────────

/// Write `tag` followed by `data` in null-escaped form.
///
/// Every 0x00 content byte becomes the pair 0x00 0xFF, and a bare 0x00
/// terminates the sequence. The terminator is therefore unambiguous, and a
/// string sorts before any longer string that it is a prefix of.
fn encode_bytes(tag: u8, data: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(data.len() + 2);
    out.push(tag);

    // Runs of non-zero bytes are copied verbatim; each boundary between runs
    // was a 0x00 byte and is re-emitted as the escape pair.
    let mut runs = data.split(|&b| b == 0x00);
    if let Some(first) = runs.next() {
        out.extend_from_slice(first);
    }
    for run in runs {
        out.extend_from_slice(&[0x00, 0xFF]);
        out.extend_from_slice(run);
    }

    out.push(0x00); // terminator
    out
}
198
199/// Decode null-escaped bytes. Returns (decoded bytes, bytes consumed including terminator).
200fn decode_null_escaped(data: &[u8]) -> Result<(Vec<u8>, usize)> {
201    let mut result = Vec::new();
202    let mut i = 0;
203    while i < data.len() {
204        if data[i] == 0x00 {
205            if i + 1 < data.len() && data[i + 1] == 0xFF {
206                result.push(0x00);
207                i += 2;
208            } else {
209                return Ok((result, i + 1)); // terminator consumed
210            }
211        } else {
212            result.push(data[i]);
213            i += 1;
214        }
215    }
216    Err(SqlError::InvalidValue(
217        "unterminated null-escaped string".into(),
218    ))
219}
220
221// ── Row encoding (for B+ tree values — non-PK columns) ─────────────
222
223/// Encode non-PK column values into a row.
224/// Format: [col_count: u16][null_bitmap][per-column: data_type(u8) + data_len(u32) + data]
225pub fn encode_row(values: &[Value]) -> Vec<u8> {
226    let col_count = values.len();
227    let bitmap_bytes = col_count.div_ceil(8);
228    let mut buf = Vec::new();
229
230    // Column count
231    buf.extend_from_slice(&(col_count as u16).to_le_bytes());
232
233    // Null bitmap
234    let mut bitmap = vec![0u8; bitmap_bytes];
235    for (i, v) in values.iter().enumerate() {
236        if v.is_null() {
237            bitmap[i / 8] |= 1 << (i % 8);
238        }
239    }
240    buf.extend_from_slice(&bitmap);
241
242    // Column data
243    for v in values {
244        if v.is_null() {
245            continue;
246        }
247        match v {
248            Value::Integer(i) => {
249                buf.push(DataType::Integer.type_tag());
250                buf.extend_from_slice(&8u32.to_le_bytes());
251                buf.extend_from_slice(&i.to_le_bytes());
252            }
253            Value::Real(r) => {
254                buf.push(DataType::Real.type_tag());
255                buf.extend_from_slice(&8u32.to_le_bytes());
256                buf.extend_from_slice(&r.to_le_bytes());
257            }
258            Value::Boolean(b) => {
259                buf.push(DataType::Boolean.type_tag());
260                buf.extend_from_slice(&1u32.to_le_bytes());
261                buf.push(if *b { 1 } else { 0 });
262            }
263            Value::Text(s) => {
264                let bytes = s.as_bytes();
265                buf.push(DataType::Text.type_tag());
266                buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
267                buf.extend_from_slice(bytes);
268            }
269            Value::Blob(data) => {
270                buf.push(DataType::Blob.type_tag());
271                buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
272                buf.extend_from_slice(data);
273            }
274            Value::Null => unreachable!(),
275        }
276    }
277
278    buf
279}
280
281/// Decode a row into non-PK column values.
282pub fn decode_row(data: &[u8]) -> Result<Vec<Value>> {
283    if data.len() < 2 {
284        return Err(SqlError::InvalidValue("row data too short".into()));
285    }
286    let col_count = u16::from_le_bytes([data[0], data[1]]) as usize;
287    let bitmap_bytes = col_count.div_ceil(8);
288    let mut pos = 2;
289
290    if data.len() < pos + bitmap_bytes {
291        return Err(SqlError::InvalidValue("truncated null bitmap".into()));
292    }
293    let bitmap = &data[pos..pos + bitmap_bytes];
294    pos += bitmap_bytes;
295
296    let mut values = Vec::with_capacity(col_count);
297    for i in 0..col_count {
298        let is_null = bitmap[i / 8] & (1 << (i % 8)) != 0;
299        if is_null {
300            values.push(Value::Null);
301            continue;
302        }
303
304        if pos + 5 > data.len() {
305            return Err(SqlError::InvalidValue("truncated column data".into()));
306        }
307        let type_tag = data[pos];
308        pos += 1;
309        let data_len =
310            u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
311        pos += 4;
312
313        if pos + data_len > data.len() {
314            return Err(SqlError::InvalidValue("truncated column value".into()));
315        }
316
317        let value = match DataType::from_tag(type_tag) {
318            Some(DataType::Integer) => {
319                let i = i64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
320                Value::Integer(i)
321            }
322            Some(DataType::Real) => {
323                let r = f64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
324                Value::Real(r)
325            }
326            Some(DataType::Boolean) => Value::Boolean(data[pos] != 0),
327            Some(DataType::Text) => {
328                let s = String::from_utf8_lossy(&data[pos..pos + data_len]).into_owned();
329                Value::Text(s)
330            }
331            Some(DataType::Blob) => Value::Blob(data[pos..pos + data_len].to_vec()),
332            _ => {
333                return Err(SqlError::InvalidValue(format!(
334                    "unknown column type tag: {type_tag}"
335                )))
336            }
337        };
338        pos += data_len;
339        values.push(value);
340    }
341
342    Ok(values)
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    // ── Key encoding tests ──────────────────────────────────────────

    #[test]
    fn key_null() {
        let encoded = encode_key_value(&Value::Null);
        let (decoded, n) = decode_key_value(&encoded).unwrap();
        assert_eq!(n, 1);
        assert_eq!(decoded, Value::Null);
    }

    #[test]
    fn key_boolean() {
        let f_enc = encode_key_value(&Value::Boolean(false));
        let t_enc = encode_key_value(&Value::Boolean(true));
        assert!(f_enc < t_enc);

        let (f_dec, _) = decode_key_value(&f_enc).unwrap();
        let (t_dec, _) = decode_key_value(&t_enc).unwrap();
        assert_eq!(f_dec, Value::Boolean(false));
        assert_eq!(t_dec, Value::Boolean(true));
    }

    #[test]
    fn key_integer_roundtrip() {
        let test_values = [
            i64::MIN,
            -1_000_000,
            -256,
            -1,
            0,
            1,
            127,
            128,
            255,
            256,
            65535,
            1_000_000,
            i64::MAX,
        ];
        for &v in &test_values {
            let encoded = encode_key_value(&Value::Integer(v));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            assert_eq!(decoded, Value::Integer(v), "roundtrip failed for {v}");
        }
    }

    #[test]
    fn key_integer_sort_order() {
        // Includes the 1-byte / 2-byte magnitude boundaries on both sides of
        // zero, where the length marker changes.
        let values: Vec<i64> = vec![
            i64::MIN,
            -1_000_000,
            -256,
            -255,
            -1,
            0,
            1,
            255,
            256,
            1_000_000,
            i64::MAX,
        ];
        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|&v| encode_key_value(&Value::Integer(v)))
            .collect();

        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "sort order broken: {} vs {}",
                values[i],
                values[i + 1]
            );
        }
    }

    #[test]
    fn key_real_roundtrip() {
        let test_values = [
            f64::NEG_INFINITY,
            -1e100,
            -1.0,
            -f64::MIN_POSITIVE,
            -0.0,
            0.0,
            f64::MIN_POSITIVE,
            0.5,
            1.0,
            1e100,
            f64::INFINITY,
        ];
        for &v in &test_values {
            let encoded = encode_key_value(&Value::Real(v));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            match decoded {
                Value::Real(r) => {
                    assert!(
                        v.to_bits() == r.to_bits(),
                        "roundtrip failed for {v}: got {r}"
                    );
                }
                _ => panic!("expected Real"),
            }
        }
    }

    #[test]
    fn key_real_sort_order() {
        let values = [
            f64::NEG_INFINITY,
            -100.0,
            -1.0,
            -0.0,
            0.0,
            1.0,
            100.0,
            f64::INFINITY,
        ];
        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|&v| encode_key_value(&Value::Real(v)))
            .collect();

        // Every value here has a distinct bit pattern (-0.0 and 0.0 included),
        // so the encoded order must be strict.
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "sort order broken: {} vs {}",
                values[i],
                values[i + 1]
            );
        }
    }

    #[test]
    fn key_text_roundtrip() {
        let test_values = ["", "hello", "world", "hello\0world", "\0\0\0"];
        for &v in &test_values {
            let encoded = encode_key_value(&Value::Text(v.into()));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            assert_eq!(decoded, Value::Text(v.into()), "roundtrip failed for {v:?}");
        }
    }

    #[test]
    fn key_text_sort_order() {
        let values = ["", "a", "ab", "b", "ba", "z"];
        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|&v| encode_key_value(&Value::Text(v.into())))
            .collect();

        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "sort order broken: {:?} vs {:?}",
                values[i],
                values[i + 1]
            );
        }
    }

    #[test]
    fn key_blob_roundtrip() {
        let test_values: Vec<Vec<u8>> = vec![
            vec![],
            vec![0x00],
            vec![0x00, 0xFF],
            vec![0xFF, 0x00],
            vec![0x00, 0x00, 0x00],
        ];
        for v in &test_values {
            let encoded = encode_key_value(&Value::Blob(v.clone()));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            assert_eq!(decoded, Value::Blob(v.clone()));
        }
    }

    #[test]
    fn key_composite_roundtrip() {
        let values = vec![
            Value::Integer(42),
            Value::Text("hello".into()),
            Value::Boolean(true),
        ];
        let encoded = encode_composite_key(&values);
        let decoded = decode_composite_key(&encoded, 3).unwrap();
        assert_eq!(decoded[0], Value::Integer(42));
        assert_eq!(decoded[1], Value::Text("hello".into()));
        assert_eq!(decoded[2], Value::Boolean(true));
    }

    #[test]
    fn key_composite_sort_order() {
        // Composite keys: (1, "b") < (1, "c") < (2, "a")
        let k1 = encode_composite_key(&[Value::Integer(1), Value::Text("b".into())]);
        let k2 = encode_composite_key(&[Value::Integer(1), Value::Text("c".into())]);
        let k3 = encode_composite_key(&[Value::Integer(2), Value::Text("a".into())]);
        assert!(k1 < k2);
        assert!(k2 < k3);
    }

    #[test]
    fn key_composite_decode_missing_component() {
        let encoded = encode_composite_key(&[Value::Integer(1), Value::Text("a".into())]);
        assert!(decode_composite_key(&encoded, 3).is_err());
    }

    #[test]
    fn key_cross_type_ordering() {
        let null = encode_key_value(&Value::Null);
        let bool_val = encode_key_value(&Value::Boolean(false));
        let int = encode_key_value(&Value::Integer(0));
        let real = encode_key_value(&Value::Real(0.0));
        let text = encode_key_value(&Value::Text("".into()));
        let blob = encode_key_value(&Value::Blob(vec![]));

        assert!(null < blob);
        assert!(blob < text);
        assert!(text < bool_val);
        assert!(bool_val < int);
        assert!(int < real);
    }

    #[test]
    fn key_decode_rejects_malformed_input() {
        assert!(decode_key_value(&[]).is_err());
        assert!(decode_key_value(&[0x06]).is_err());
        assert!(decode_key_value(&[TAG_BOOLEAN]).is_err());
        assert!(decode_key_value(&[TAG_INTEGER, 0x82, 0x01]).is_err());
        assert!(decode_key_value(&[TAG_REAL, 0x80]).is_err());
        assert!(decode_key_value(&[TAG_TEXT, b'a']).is_err());
        assert!(decode_key_value(&[TAG_TEXT, 0xFF, 0x00]).is_err());
    }

    // ── Row encoding tests ──────────────────────────────────────────

    #[test]
    fn row_roundtrip_simple() {
        let values = vec![
            Value::Integer(42),
            Value::Text("hello".into()),
            Value::Boolean(true),
        ];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded[0], Value::Integer(42));
        assert_eq!(decoded[1], Value::Text("hello".into()));
        assert_eq!(decoded[2], Value::Boolean(true));
    }

    #[test]
    fn row_roundtrip_with_nulls() {
        let values = vec![
            Value::Integer(1),
            Value::Null,
            Value::Text("test".into()),
            Value::Null,
        ];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert_eq!(decoded.len(), 4);
        assert_eq!(decoded[0], Value::Integer(1));
        assert!(decoded[1].is_null());
        assert_eq!(decoded[2], Value::Text("test".into()));
        assert!(decoded[3].is_null());
    }

    #[test]
    fn row_roundtrip_empty() {
        let values: Vec<Value> = vec![];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn row_roundtrip_all_types() {
        let values = vec![
            Value::Integer(-100),
            Value::Real(3.15),
            Value::Text("hello world".into()),
            Value::Blob(vec![0xDE, 0xAD, 0xBE, 0xEF]),
            Value::Boolean(false),
            Value::Null,
        ];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert_eq!(decoded.len(), 6);
        assert_eq!(decoded[0], Value::Integer(-100));
        assert_eq!(decoded[1], Value::Real(3.15));
        assert_eq!(decoded[2], Value::Text("hello world".into()));
        assert_eq!(decoded[3], Value::Blob(vec![0xDE, 0xAD, 0xBE, 0xEF]));
        assert_eq!(decoded[4], Value::Boolean(false));
        assert!(decoded[5].is_null());
    }

    #[test]
    fn row_roundtrip_multi_byte_bitmap() {
        // 20 columns need a 3-byte null bitmap; NULLs land in every byte.
        let values: Vec<Value> = (0..20)
            .map(|i| if i % 3 == 0 { Value::Null } else { Value::Integer(i) })
            .collect();
        let decoded = decode_row(&encode_row(&values)).unwrap();
        assert_eq!(decoded, values);
    }

    #[test]
    fn row_decode_rejects_truncated_input() {
        assert!(decode_row(&[]).is_err());
        assert!(decode_row(&[9, 0]).is_err());
        assert!(decode_row(&[1, 0, 0]).is_err());

        let mut encoded = encode_row(&[Value::Text("hello".into())]);
        encoded.pop();
        assert!(decode_row(&encoded).is_err());
    }

    #[test]
    fn null_escaped_with_embedded_nulls() {
        let text = "before\0after";
        let encoded = encode_key_value(&Value::Text(text.into()));
        let (decoded, _) = decode_key_value(&encoded).unwrap();
        assert_eq!(decoded, Value::Text(text.into()));
    }

    #[test]
    fn key_integer_edge_cases() {
        for v in [i64::MIN, i64::MIN + 1, -1, 0, 1, i64::MAX - 1, i64::MAX] {
            let encoded = encode_key_value(&Value::Integer(v));
            let (decoded, n) = decode_key_value(&encoded).unwrap();
            assert_eq!(n, encoded.len());
            assert_eq!(decoded, Value::Integer(v), "edge case failed for {v}");
        }
    }
}