Skip to main content

citadeldb_sql/
encoding.rs

1//! Order-preserving key encoding (tuple layer)
2//! and row encoding for non-PK column storage.
3
4use crate::error::{Result, SqlError};
5use crate::types::{DataType, Value};
6
7// ── Key encoding (order-preserving) ─────────────────────────────────
8
// Type tag bytes for key encoding. Ordering: NULL < BLOB < TEXT < BOOLEAN < INTEGER < REAL
//
// The tag is the first byte of every encoded key value, so when encoded keys are
// compared byte-wise, values of different types order by tag alone (e.g. any INTEGER
// key sorts before any REAL key, whatever the numeric values).

/// Key tag for SQL NULL; the whole encoding is this single byte.
const TAG_NULL: u8 = 0x00;
/// Key tag for BLOB values; followed by null-escaped bytes and a 0x00 terminator.
const TAG_BLOB: u8 = 0x01;
/// Key tag for TEXT values; followed by null-escaped UTF-8 bytes and a 0x00 terminator.
const TAG_TEXT: u8 = 0x02;
/// Key tag for BOOLEAN values; followed by one byte (0x00 = false, 0x01 = true).
const TAG_BOOLEAN: u8 = 0x03;
/// Key tag for INTEGER values; followed by a length marker and the magnitude bytes.
const TAG_INTEGER: u8 = 0x04;
/// Key tag for REAL values; followed by 8 order-preserving bytes derived from the IEEE 754 bits.
const TAG_REAL: u8 = 0x05;
16
17/// Encode a single value into an order-preserving byte sequence.
18pub fn encode_key_value(value: &Value) -> Vec<u8> {
19    match value {
20        Value::Null => vec![TAG_NULL],
21        Value::Boolean(b) => vec![TAG_BOOLEAN, if *b { 0x01 } else { 0x00 }],
22        Value::Integer(i) => encode_integer(*i),
23        Value::Real(r) => encode_real(*r),
24        Value::Text(s) => encode_bytes(TAG_TEXT, s.as_bytes()),
25        Value::Blob(b) => encode_bytes(TAG_BLOB, b),
26    }
27}
28
29/// Encode a composite key (multiple values concatenated).
30pub fn encode_composite_key(values: &[Value]) -> Vec<u8> {
31    let mut buf = Vec::new();
32    for v in values {
33        buf.extend_from_slice(&encode_key_value(v));
34    }
35    buf
36}
37
38/// Decode a single key value, returning the value and the number of bytes consumed.
39pub fn decode_key_value(data: &[u8]) -> Result<(Value, usize)> {
40    if data.is_empty() {
41        return Err(SqlError::InvalidValue("empty key data".into()));
42    }
43    match data[0] {
44        TAG_NULL => Ok((Value::Null, 1)),
45        TAG_BOOLEAN => {
46            if data.len() < 2 {
47                return Err(SqlError::InvalidValue("truncated boolean".into()));
48            }
49            Ok((Value::Boolean(data[1] != 0), 2))
50        }
51        TAG_INTEGER => decode_integer(&data[1..]).map(|(v, n)| (v, n + 1)),
52        TAG_REAL => decode_real(&data[1..]).map(|(v, n)| (v, n + 1)),
53        TAG_TEXT => {
54            let (bytes, n) = decode_null_escaped(&data[1..])?;
55            let s = String::from_utf8(bytes)
56                .map_err(|_| SqlError::InvalidValue("invalid UTF-8 in key".into()))?;
57            Ok((Value::Text(s), n + 1))
58        }
59        TAG_BLOB => {
60            let (bytes, n) = decode_null_escaped(&data[1..])?;
61            Ok((Value::Blob(bytes), n + 1))
62        }
63        tag => Err(SqlError::InvalidValue(format!("unknown key tag: {tag:#x}"))),
64    }
65}
66
67/// Decode a composite key into multiple values.
68pub fn decode_composite_key(data: &[u8], count: usize) -> Result<Vec<Value>> {
69    let mut values = Vec::with_capacity(count);
70    let mut pos = 0;
71    for _ in 0..count {
72        let (v, n) = decode_key_value(&data[pos..])?;
73        values.push(v);
74        pos += n;
75    }
76    Ok(values)
77}
78
79// ── Integer encoding (variable-width) ───────────────────────────────
80
81fn encode_integer(val: i64) -> Vec<u8> {
82    let mut buf = vec![TAG_INTEGER];
83    if val == 0 {
84        buf.push(0x80);
85        return buf;
86    }
87    if val > 0 {
88        let bytes = val.to_be_bytes();
89        // Find first non-zero byte
90        let start = bytes.iter().position(|&b| b != 0).unwrap();
91        let byte_count = (8 - start) as u8;
92        buf.push(0x80 + byte_count);
93        buf.extend_from_slice(&bytes[start..]);
94    } else {
95        // Negative: one's complement of absolute value
96        let abs_val = if val == i64::MIN {
97            // Special case: |i64::MIN| doesn't fit in i64
98            u64::MAX / 2 + 1
99        } else {
100            (-val) as u64
101        };
102        let bytes = abs_val.to_be_bytes();
103        let start = bytes.iter().position(|&b| b != 0).unwrap();
104        let byte_count = (8 - start) as u8;
105        buf.push(0x80 - byte_count);
106        // One's complement: invert all bits
107        for &b in &bytes[start..] {
108            buf.push(!b);
109        }
110    }
111    buf
112}
113
114fn decode_integer(data: &[u8]) -> Result<(Value, usize)> {
115    if data.is_empty() {
116        return Err(SqlError::InvalidValue("truncated integer".into()));
117    }
118    let marker = data[0];
119    if marker == 0x80 {
120        return Ok((Value::Integer(0), 1));
121    }
122    if marker > 0x80 {
123        // Positive
124        let byte_count = (marker - 0x80) as usize;
125        if data.len() < 1 + byte_count {
126            return Err(SqlError::InvalidValue("truncated positive integer".into()));
127        }
128        let mut bytes = [0u8; 8];
129        bytes[8 - byte_count..].copy_from_slice(&data[1..1 + byte_count]);
130        let val = i64::from_be_bytes(bytes);
131        Ok((Value::Integer(val), 1 + byte_count))
132    } else {
133        // Negative
134        let byte_count = (0x80 - marker) as usize;
135        if data.len() < 1 + byte_count {
136            return Err(SqlError::InvalidValue("truncated negative integer".into()));
137        }
138        let mut bytes = [0u8; 8];
139        for i in 0..byte_count {
140            bytes[8 - byte_count + i] = !data[1 + i];
141        }
142        let abs_val = u64::from_be_bytes(bytes);
143        // Use wrapping negation to handle i64::MIN correctly
144        let val = (-(abs_val as i128)) as i64;
145        Ok((Value::Integer(val), 1 + byte_count))
146    }
147}
148
149// ── Real encoding (IEEE 754 sign-bit manipulation) ──────────────────
150
151fn encode_real(val: f64) -> Vec<u8> {
152    let mut buf = vec![TAG_REAL];
153    let bits = val.to_bits();
154    let encoded = if val.is_sign_negative() {
155        // Negative (including -0.0): flip ALL bits
156        !bits
157    } else {
158        // Positive (including +0.0): flip sign bit only
159        bits ^ (1u64 << 63)
160    };
161    buf.extend_from_slice(&encoded.to_be_bytes());
162    buf
163}
164
165fn decode_real(data: &[u8]) -> Result<(Value, usize)> {
166    if data.len() < 8 {
167        return Err(SqlError::InvalidValue("truncated real".into()));
168    }
169    let encoded = u64::from_be_bytes(data[..8].try_into().unwrap());
170    let bits = if encoded & (1u64 << 63) != 0 {
171        // Was positive: undo sign bit flip
172        encoded ^ (1u64 << 63)
173    } else {
174        // Was negative: undo full inversion
175        !encoded
176    };
177    let val = f64::from_bits(bits);
178    Ok((Value::Real(val), 8))
179}
180
181// ── Null-escaped byte encoding ──────────────────────────────────────
182
/// Write `tag`, then `data` with every 0x00 escaped as 0x00 0xFF, then a bare 0x00 terminator.
///
/// The terminator makes a string sort before any longer string it prefixes. The
/// 0xFF escape keeps an embedded NUL sorting after the end-of-string marker.
fn encode_bytes(tag: u8, data: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(data.len() + 2);
    out.push(tag);
    for &byte in data {
        out.push(byte);
        if byte == 0x00 {
            out.push(0xFF);
        }
    }
    out.push(0x00);
    out
}
198
199/// Decode null-escaped bytes. Returns (decoded bytes, bytes consumed including terminator).
200fn decode_null_escaped(data: &[u8]) -> Result<(Vec<u8>, usize)> {
201    let mut result = Vec::new();
202    let mut i = 0;
203    while i < data.len() {
204        if data[i] == 0x00 {
205            if i + 1 < data.len() && data[i + 1] == 0xFF {
206                result.push(0x00);
207                i += 2;
208            } else {
209                return Ok((result, i + 1)); // terminator consumed
210            }
211        } else {
212            result.push(data[i]);
213            i += 1;
214        }
215    }
216    Err(SqlError::InvalidValue("unterminated null-escaped string".into()))
217}
218
219// ── Row encoding (for B+ tree values — non-PK columns) ─────────────
220
221/// Encode non-PK column values into a row.
222/// Format: [col_count: u16][null_bitmap][per-column: data_type(u8) + data_len(u32) + data]
223pub fn encode_row(values: &[Value]) -> Vec<u8> {
224    let col_count = values.len();
225    let bitmap_bytes = (col_count + 7) / 8;
226    let mut buf = Vec::new();
227
228    // Column count
229    buf.extend_from_slice(&(col_count as u16).to_le_bytes());
230
231    // Null bitmap
232    let mut bitmap = vec![0u8; bitmap_bytes];
233    for (i, v) in values.iter().enumerate() {
234        if v.is_null() {
235            bitmap[i / 8] |= 1 << (i % 8);
236        }
237    }
238    buf.extend_from_slice(&bitmap);
239
240    // Column data
241    for v in values {
242        if v.is_null() {
243            continue;
244        }
245        match v {
246            Value::Integer(i) => {
247                buf.push(DataType::Integer.type_tag());
248                buf.extend_from_slice(&8u32.to_le_bytes());
249                buf.extend_from_slice(&i.to_le_bytes());
250            }
251            Value::Real(r) => {
252                buf.push(DataType::Real.type_tag());
253                buf.extend_from_slice(&8u32.to_le_bytes());
254                buf.extend_from_slice(&r.to_le_bytes());
255            }
256            Value::Boolean(b) => {
257                buf.push(DataType::Boolean.type_tag());
258                buf.extend_from_slice(&1u32.to_le_bytes());
259                buf.push(if *b { 1 } else { 0 });
260            }
261            Value::Text(s) => {
262                let bytes = s.as_bytes();
263                buf.push(DataType::Text.type_tag());
264                buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
265                buf.extend_from_slice(bytes);
266            }
267            Value::Blob(data) => {
268                buf.push(DataType::Blob.type_tag());
269                buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
270                buf.extend_from_slice(data);
271            }
272            Value::Null => unreachable!(),
273        }
274    }
275
276    buf
277}
278
279/// Decode a row into non-PK column values.
280pub fn decode_row(data: &[u8]) -> Result<Vec<Value>> {
281    if data.len() < 2 {
282        return Err(SqlError::InvalidValue("row data too short".into()));
283    }
284    let col_count = u16::from_le_bytes([data[0], data[1]]) as usize;
285    let bitmap_bytes = (col_count + 7) / 8;
286    let mut pos = 2;
287
288    if data.len() < pos + bitmap_bytes {
289        return Err(SqlError::InvalidValue("truncated null bitmap".into()));
290    }
291    let bitmap = &data[pos..pos + bitmap_bytes];
292    pos += bitmap_bytes;
293
294    let mut values = Vec::with_capacity(col_count);
295    for i in 0..col_count {
296        let is_null = bitmap[i / 8] & (1 << (i % 8)) != 0;
297        if is_null {
298            values.push(Value::Null);
299            continue;
300        }
301
302        if pos + 5 > data.len() {
303            return Err(SqlError::InvalidValue("truncated column data".into()));
304        }
305        let type_tag = data[pos];
306        pos += 1;
307        let data_len = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
308        pos += 4;
309
310        if pos + data_len > data.len() {
311            return Err(SqlError::InvalidValue("truncated column value".into()));
312        }
313
314        let value = match DataType::from_tag(type_tag) {
315            Some(DataType::Integer) => {
316                let i = i64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
317                Value::Integer(i)
318            }
319            Some(DataType::Real) => {
320                let r = f64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
321                Value::Real(r)
322            }
323            Some(DataType::Boolean) => Value::Boolean(data[pos] != 0),
324            Some(DataType::Text) => {
325                let s = String::from_utf8_lossy(&data[pos..pos + data_len]).into_owned();
326                Value::Text(s)
327            }
328            Some(DataType::Blob) => {
329                Value::Blob(data[pos..pos + data_len].to_vec())
330            }
331            _ => return Err(SqlError::InvalidValue(format!("unknown column type tag: {type_tag}"))),
332        };
333        pos += data_len;
334        values.push(value);
335    }
336
337    Ok(values)
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    // ── Key encoding tests ──────────────────────────────────────────

    #[test]
    fn key_null() {
        let encoded = encode_key_value(&Value::Null);
        let (decoded, n) = decode_key_value(&encoded).unwrap();
        assert_eq!(n, 1);
        assert_eq!(decoded, Value::Null);
    }

    #[test]
    fn key_boolean() {
        let f_enc = encode_key_value(&Value::Boolean(false));
        let t_enc = encode_key_value(&Value::Boolean(true));
        assert!(f_enc < t_enc);

        let (f_dec, _) = decode_key_value(&f_enc).unwrap();
        let (t_dec, _) = decode_key_value(&t_enc).unwrap();
        assert_eq!(f_dec, Value::Boolean(false));
        assert_eq!(t_dec, Value::Boolean(true));
    }

    #[test]
    fn key_integer_roundtrip() {
        let test_values = [
            i64::MIN, -1_000_000, -256, -1, 0, 1, 127, 128, 255, 256,
            65535, 1_000_000, i64::MAX,
        ];
        for &v in &test_values {
            let encoded = encode_key_value(&Value::Integer(v));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            assert_eq!(decoded, Value::Integer(v), "roundtrip failed for {v}");
        }
    }

    #[test]
    fn key_integer_sort_order() {
        let values = [i64::MIN, -1_000_000, -1, 0, 1, 1_000_000, i64::MAX];
        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|&v| encode_key_value(&Value::Integer(v)))
            .collect();

        for (keys, vals) in encoded.windows(2).zip(values.windows(2)) {
            assert!(keys[0] < keys[1], "sort order broken: {} vs {}", vals[0], vals[1]);
        }
    }

    #[test]
    fn key_real_roundtrip() {
        let test_values = [
            f64::NEG_INFINITY, -1e100, -1.0, -f64::MIN_POSITIVE, -0.0,
            0.0, f64::MIN_POSITIVE, 0.5, 1.0, 1e100, f64::INFINITY,
        ];
        for &v in &test_values {
            let encoded = encode_key_value(&Value::Real(v));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            match decoded {
                Value::Real(r) => {
                    assert!(
                        v.to_bits() == r.to_bits(),
                        "roundtrip failed for {v}: got {r}"
                    );
                }
                _ => panic!("expected Real"),
            }
        }
    }

    #[test]
    fn key_real_sort_order() {
        let values = [
            f64::NEG_INFINITY, -100.0, -1.0, -0.0,
            0.0, 1.0, 100.0, f64::INFINITY,
        ];
        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|&v| encode_key_value(&Value::Real(v)))
            .collect();

        for (keys, vals) in encoded.windows(2).zip(values.windows(2)) {
            assert!(keys[0] <= keys[1], "sort order broken: {} vs {}", vals[0], vals[1]);
        }
    }

    #[test]
    fn key_text_roundtrip() {
        let test_values = ["", "hello", "world", "hello\0world", "\0\0\0"];
        for &v in &test_values {
            let encoded = encode_key_value(&Value::Text(v.into()));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            assert_eq!(decoded, Value::Text(v.into()), "roundtrip failed for {v:?}");
        }
    }

    #[test]
    fn key_text_sort_order() {
        let values = ["", "a", "ab", "b", "ba", "z"];
        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|&v| encode_key_value(&Value::Text(v.into())))
            .collect();

        for (keys, vals) in encoded.windows(2).zip(values.windows(2)) {
            assert!(keys[0] < keys[1], "sort order broken: {:?} vs {:?}", vals[0], vals[1]);
        }
    }

    #[test]
    fn key_text_embedded_nul_sort_order() {
        // Escaping 0x00 as 0x00 0xFF keeps "a" < "a\0" < "a\x01".
        let values = ["a", "a\0", "a\x01"];
        let encoded: Vec<Vec<u8>> = values
            .iter()
            .map(|&v| encode_key_value(&Value::Text(v.into())))
            .collect();

        for (keys, vals) in encoded.windows(2).zip(values.windows(2)) {
            assert!(keys[0] < keys[1], "sort order broken: {:?} vs {:?}", vals[0], vals[1]);
        }
    }

    #[test]
    fn key_blob_roundtrip() {
        let test_values: Vec<Vec<u8>> = vec![
            vec![], vec![0x00], vec![0x00, 0xFF], vec![0xFF, 0x00],
            vec![0x00, 0x00, 0x00],
        ];
        for v in &test_values {
            let encoded = encode_key_value(&Value::Blob(v.clone()));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            assert_eq!(decoded, Value::Blob(v.clone()));
        }
    }

    #[test]
    fn key_composite_roundtrip() {
        let values = [
            Value::Integer(42),
            Value::Text("hello".into()),
            Value::Boolean(true),
        ];
        let encoded = encode_composite_key(&values);
        let decoded = decode_composite_key(&encoded, 3).unwrap();
        assert_eq!(decoded[0], Value::Integer(42));
        assert_eq!(decoded[1], Value::Text("hello".into()));
        assert_eq!(decoded[2], Value::Boolean(true));
    }

    #[test]
    fn key_composite_roundtrip_all_types() {
        let values = [
            Value::Null,
            Value::Blob(vec![0x00, 0x01]),
            Value::Real(-2.5),
            Value::Integer(-300),
            Value::Boolean(false),
            Value::Text("x\0y".into()),
        ];
        let encoded = encode_composite_key(&values);
        let decoded = decode_composite_key(&encoded, values.len()).unwrap();
        assert_eq!(decoded, values);
    }

    #[test]
    fn key_composite_sort_order() {
        // Composite keys: (1, "b") < (1, "c") < (2, "a")
        let k1 = encode_composite_key(&[Value::Integer(1), Value::Text("b".into())]);
        let k2 = encode_composite_key(&[Value::Integer(1), Value::Text("c".into())]);
        let k3 = encode_composite_key(&[Value::Integer(2), Value::Text("a".into())]);
        assert!(k1 < k2);
        assert!(k2 < k3);
    }

    #[test]
    fn key_composite_text_prefix_sort_order() {
        // The 0x00 terminator makes a shorter text component sort first no matter
        // what follows it: ("a", 5) < ("a\0", 0) < ("ab", 0).
        let k1 = encode_composite_key(&[Value::Text("a".into()), Value::Integer(5)]);
        let k2 = encode_composite_key(&[Value::Text("a\0".into()), Value::Integer(0)]);
        let k3 = encode_composite_key(&[Value::Text("ab".into()), Value::Integer(0)]);
        assert!(k1 < k2);
        assert!(k2 < k3);
    }

    #[test]
    fn key_cross_type_ordering() {
        let null = encode_key_value(&Value::Null);
        let bool_val = encode_key_value(&Value::Boolean(false));
        let int = encode_key_value(&Value::Integer(0));
        let big_int = encode_key_value(&Value::Integer(i64::MAX));
        let real = encode_key_value(&Value::Real(f64::NEG_INFINITY));
        let text = encode_key_value(&Value::Text("".into()));
        let blob = encode_key_value(&Value::Blob(vec![]));

        assert!(null < blob);
        assert!(blob < text);
        assert!(text < bool_val);
        assert!(bool_val < int);
        // Type tag dominates: every INTEGER sorts before every REAL.
        assert!(int < real);
        assert!(big_int < real);
    }

    #[test]
    fn key_decode_rejects_malformed_input() {
        assert!(decode_key_value(&[]).is_err());
        assert!(decode_key_value(&[0x06]).is_err()); // unknown tag
        assert!(decode_key_value(&[TAG_BOOLEAN]).is_err()); // missing payload byte
        assert!(decode_key_value(&[TAG_REAL, 0x00]).is_err()); // fewer than 8 bytes
        assert!(decode_key_value(&[TAG_TEXT, b'a']).is_err()); // no terminator
        assert!(decode_key_value(&[TAG_TEXT, 0xFF, 0x00]).is_err()); // invalid UTF-8
    }

    // ── Row encoding tests ──────────────────────────────────────────

    #[test]
    fn row_roundtrip_simple() {
        let values = [
            Value::Integer(42),
            Value::Text("hello".into()),
            Value::Boolean(true),
        ];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded[0], Value::Integer(42));
        assert_eq!(decoded[1], Value::Text("hello".into()));
        assert_eq!(decoded[2], Value::Boolean(true));
    }

    #[test]
    fn row_roundtrip_with_nulls() {
        let values = [
            Value::Integer(1),
            Value::Null,
            Value::Text("test".into()),
            Value::Null,
        ];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert_eq!(decoded.len(), 4);
        assert_eq!(decoded[0], Value::Integer(1));
        assert!(decoded[1].is_null());
        assert_eq!(decoded[2], Value::Text("test".into()));
        assert!(decoded[3].is_null());
    }

    #[test]
    fn row_roundtrip_multi_byte_bitmap() {
        // More than 8 columns so the null bitmap spans two bytes.
        let values: Vec<Value> = (0..10)
            .map(|i| if i % 4 == 0 { Value::Null } else { Value::Integer(i) })
            .collect();
        let encoded = encode_row(&values);
        assert_eq!(decode_row(&encoded).unwrap(), values);
    }

    #[test]
    fn row_roundtrip_empty() {
        let encoded = encode_row(&[]);
        let decoded = decode_row(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn row_roundtrip_all_types() {
        let values = [
            Value::Integer(-100),
            Value::Real(2.5),
            Value::Text("hello world".into()),
            Value::Blob(vec![0xDE, 0xAD, 0xBE, 0xEF]),
            Value::Boolean(false),
            Value::Null,
        ];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert_eq!(decoded.len(), 6);
        assert_eq!(decoded[0], Value::Integer(-100));
        assert_eq!(decoded[1], Value::Real(2.5));
        assert_eq!(decoded[2], Value::Text("hello world".into()));
        assert_eq!(decoded[3], Value::Blob(vec![0xDE, 0xAD, 0xBE, 0xEF]));
        assert_eq!(decoded[4], Value::Boolean(false));
        assert!(decoded[5].is_null());
    }

    #[test]
    fn row_decode_rejects_malformed_input() {
        assert!(decode_row(&[]).is_err());
        assert!(decode_row(&[0x01]).is_err()); // shorter than the column count
        assert!(decode_row(&[0x02, 0x00]).is_err()); // missing null bitmap
        assert!(decode_row(&[0x01, 0x00, 0x00]).is_err()); // missing column header
    }

    #[test]
    fn null_escaped_with_embedded_nulls() {
        let text = "before\0after";
        let encoded = encode_key_value(&Value::Text(text.into()));
        let (decoded, _) = decode_key_value(&encoded).unwrap();
        assert_eq!(decoded, Value::Text(text.into()));
    }

    #[test]
    fn key_integer_edge_cases() {
        for v in [i64::MIN, i64::MIN + 1, -1, 0, 1, i64::MAX - 1, i64::MAX] {
            let encoded = encode_key_value(&Value::Integer(v));
            let (decoded, n) = decode_key_value(&encoded).unwrap();
            assert_eq!(n, encoded.len());
            assert_eq!(decoded, Value::Integer(v), "edge case failed for {v}");
        }
    }
}