Skip to main content

citadel_sql/
encoding.rs

1//! Order-preserving key encoding and row encoding for non-PK column storage.
2
3use crate::error::{Result, SqlError};
4use crate::types::{CompactString, DataType, Value};
5
// Leading type tags for the order-preserving key encoding. Every encoded key
// component starts with one of these bytes, so when two keys are compared
// byte-wise, values of different types are ordered by their tag:
// NULL < BLOB < TEXT < BOOLEAN < INTEGER < REAL < TIME < DATE < TIMESTAMP < INTERVAL.

/// Tag for `Value::Null` (no payload).
const TAG_NULL: u8 = 0x00;
/// Tag for `Value::Blob` (null-escaped bytes, `0x00`-terminated).
const TAG_BLOB: u8 = 0x01;
/// Tag for `Value::Text` (null-escaped UTF-8 bytes, `0x00`-terminated).
const TAG_TEXT: u8 = 0x02;
/// Tag for `Value::Boolean` (one payload byte).
const TAG_BOOLEAN: u8 = 0x03;
/// Tag for `Value::Integer` (signed varint payload).
const TAG_INTEGER: u8 = 0x04;
/// Tag for `Value::Real` (8-byte sign-adjusted payload).
const TAG_REAL: u8 = 0x05;
/// Tag for `Value::Time` (signed varint payload).
const TAG_TIME: u8 = 0x06;
/// Tag for `Value::Date` (signed varint payload).
const TAG_DATE: u8 = 0x07;
/// Tag for `Value::Timestamp` (signed varint payload).
const TAG_TIMESTAMP: u8 = 0x08;
/// Tag for `Value::Interval` (16-byte fixed payload).
const TAG_INTERVAL: u8 = 0x09;
17
18/// Encode a single value into an order-preserving byte sequence.
19pub fn encode_key_value(value: &Value) -> Vec<u8> {
20    let mut buf = Vec::with_capacity(16);
21    encode_key_value_into(value, &mut buf);
22    buf
23}
24
25/// Encode a composite key (multiple values concatenated).
26pub fn encode_composite_key(values: &[Value]) -> Vec<u8> {
27    let mut buf = Vec::new();
28    for v in values {
29        buf.extend_from_slice(&encode_key_value(v));
30    }
31    buf
32}
33
34pub fn encode_composite_key_into(values: &[Value], buf: &mut Vec<u8>) {
35    buf.clear();
36    for v in values {
37        encode_key_value_into(v, buf);
38    }
39}
40
41pub fn encode_composite_key_from_indices(indices: &[u16], row: &[Value], buf: &mut Vec<u8>) {
42    buf.clear();
43    for &i in indices {
44        encode_key_value_into(&row[i as usize], buf);
45    }
46}
47
48#[inline]
49pub fn encode_int_key_into(val: i64, buf: &mut Vec<u8>) {
50    buf.clear();
51    encode_signed_varint(TAG_INTEGER, val, buf);
52}
53
54pub(crate) fn encode_key_value_into(value: &Value, buf: &mut Vec<u8>) {
55    match value {
56        Value::Null => buf.push(TAG_NULL),
57        Value::Boolean(b) => {
58            buf.push(TAG_BOOLEAN);
59            buf.push(if *b { 0x01 } else { 0x00 });
60        }
61        Value::Integer(i) => encode_integer_into(*i, buf),
62        Value::Real(r) => encode_real_into(*r, buf),
63        Value::Text(s) => encode_bytes_into(TAG_TEXT, s.as_bytes(), buf),
64        Value::Blob(b) => encode_bytes_into(TAG_BLOB, b, buf),
65        Value::Time(t) => encode_signed_varint(TAG_TIME, *t, buf),
66        Value::Date(d) => encode_signed_varint(TAG_DATE, i64::from(*d), buf),
67        Value::Timestamp(t) => encode_signed_varint(TAG_TIMESTAMP, *t, buf),
68        Value::Interval {
69            months,
70            days,
71            micros,
72        } => {
73            // 17 bytes: tag + (i32,i32,i64) BE with sign-flipped high byte per field.
74            buf.push(TAG_INTERVAL);
75            let mut mb = months.to_be_bytes();
76            mb[0] ^= 0x80;
77            buf.extend_from_slice(&mb);
78            let mut db = days.to_be_bytes();
79            db[0] ^= 0x80;
80            buf.extend_from_slice(&db);
81            let mut ub = micros.to_be_bytes();
82            ub[0] ^= 0x80;
83            buf.extend_from_slice(&ub);
84        }
85    }
86}
87
88fn encode_integer_into(val: i64, buf: &mut Vec<u8>) {
89    encode_signed_varint(TAG_INTEGER, val, buf);
90}
91
/// Appends `tag` followed by an order-preserving, variable-width encoding
/// of `val`.
///
/// Layout is `[tag][marker][payload]`. The payload is `|val|` in big-endian
/// with leading zero bytes removed (`n` bytes, 1..=8):
/// - `val == 0`: marker `0x80`, no payload;
/// - `val > 0`: marker `0x80 + n`, payload as-is;
/// - `val < 0`: marker `0x80 - n`, payload bitwise-inverted.
///
/// Comparing two encodings byte-wise agrees with comparing the integers. The
/// marker orders by sign and magnitude width, and the (possibly inverted)
/// payload orders values of equal width.
pub(crate) fn encode_signed_varint(tag: u8, val: i64, buf: &mut Vec<u8>) {
    buf.push(tag);
    // `unsigned_abs` is exact even for i64::MIN (magnitude 2^63).
    let magnitude = val.unsigned_abs();
    if magnitude == 0 {
        buf.push(0x80);
        return;
    }
    let width = 8 - (magnitude.leading_zeros() / 8) as usize;
    let be = magnitude.to_be_bytes();
    let payload = &be[8 - width..];
    if val > 0 {
        buf.push(0x80 + width as u8);
        buf.extend_from_slice(payload);
    } else {
        buf.push(0x80 - width as u8);
        buf.extend(payload.iter().map(|&b| !b));
    }
}
124
125fn encode_real_into(val: f64, buf: &mut Vec<u8>) {
126    buf.push(TAG_REAL);
127    let bits = val.to_bits();
128    let encoded = if val.is_sign_negative() {
129        !bits
130    } else {
131        bits ^ (1u64 << 63)
132    };
133    buf.extend_from_slice(&encoded.to_be_bytes());
134}
135
/// Appends `tag`, then `data` with each `0x00` escaped as `0x00 0xFF`, then a
/// single `0x00` terminator.
///
/// No raw `0x00` from `data` survives unescaped, so the first `0x00` that is
/// not followed by `0xFF` marks the end of the value.
fn encode_bytes_into(tag: u8, data: &[u8], buf: &mut Vec<u8>) {
    buf.push(tag);
    let mut rest = data;
    while let Some(zero_at) = rest.iter().position(|&b| b == 0x00) {
        buf.extend_from_slice(&rest[..zero_at]);
        buf.extend_from_slice(&[0x00, 0xFF]);
        rest = &rest[zero_at + 1..];
    }
    buf.extend_from_slice(rest);
    buf.push(0x00);
}
148
149/// Decode a single key value, returning the value and the number of bytes consumed.
150pub fn decode_key_value(data: &[u8]) -> Result<(Value, usize)> {
151    if data.is_empty() {
152        return Err(SqlError::InvalidValue("empty key data".into()));
153    }
154    match data[0] {
155        TAG_NULL => Ok((Value::Null, 1)),
156        TAG_BOOLEAN => {
157            if data.len() < 2 {
158                return Err(SqlError::InvalidValue("truncated boolean".into()));
159            }
160            Ok((Value::Boolean(data[1] != 0), 2))
161        }
162        TAG_INTEGER => decode_integer(&data[1..]).map(|(v, n)| (v, n + 1)),
163        TAG_REAL => decode_real(&data[1..]).map(|(v, n)| (v, n + 1)),
164        TAG_TIME => decode_signed_varint(&data[1..]).map(|(v, n)| (Value::Time(v), n + 1)),
165        TAG_DATE => decode_signed_varint(&data[1..]).map(|(v, n)| {
166            let d = v.clamp(i32::MIN as i64, i32::MAX as i64) as i32;
167            (Value::Date(d), n + 1)
168        }),
169        TAG_TIMESTAMP => {
170            decode_signed_varint(&data[1..]).map(|(v, n)| (Value::Timestamp(v), n + 1))
171        }
172        TAG_INTERVAL => {
173            if data.len() < 1 + 16 {
174                return Err(SqlError::InvalidValue("truncated interval".into()));
175            }
176            let mut mb: [u8; 4] = data[1..5].try_into().unwrap();
177            mb[0] ^= 0x80;
178            let mut db: [u8; 4] = data[5..9].try_into().unwrap();
179            db[0] ^= 0x80;
180            let mut ub: [u8; 8] = data[9..17].try_into().unwrap();
181            ub[0] ^= 0x80;
182            Ok((
183                Value::Interval {
184                    months: i32::from_be_bytes(mb),
185                    days: i32::from_be_bytes(db),
186                    micros: i64::from_be_bytes(ub),
187                },
188                17,
189            ))
190        }
191        TAG_TEXT => {
192            let (bytes, n) = decode_null_escaped(&data[1..])?;
193            let s = String::from_utf8(bytes)
194                .map_err(|_| SqlError::InvalidValue("invalid UTF-8 in key".into()))?;
195            Ok((Value::Text(CompactString::from(s)), n + 1))
196        }
197        TAG_BLOB => {
198            let (bytes, n) = decode_null_escaped(&data[1..])?;
199            Ok((Value::Blob(bytes), n + 1))
200        }
201        tag => Err(SqlError::InvalidValue(format!("unknown key tag: {tag:#x}"))),
202    }
203}
204
205/// Decode a composite key into multiple values.
206pub fn decode_composite_key(data: &[u8], count: usize) -> Result<Vec<Value>> {
207    let mut values = Vec::with_capacity(count);
208    let mut pos = 0;
209    for _ in 0..count {
210        let (v, n) = decode_key_value(&data[pos..])?;
211        values.push(v);
212        pos += n;
213    }
214    Ok(values)
215}
216
217fn decode_integer(data: &[u8]) -> Result<(Value, usize)> {
218    let (v, n) = decode_signed_varint(data)?;
219    Ok((Value::Integer(v), n))
220}
221
222/// Decode the variable-width codec emitted by `encode_signed_varint` (tag byte already consumed).
223pub(crate) fn decode_signed_varint(data: &[u8]) -> Result<(i64, usize)> {
224    if data.is_empty() {
225        return Err(SqlError::InvalidValue("truncated integer".into()));
226    }
227    let marker = data[0];
228    if marker == 0x80 {
229        return Ok((0, 1));
230    }
231    if marker > 0x80 {
232        let byte_count = (marker - 0x80) as usize;
233        if data.len() < 1 + byte_count {
234            return Err(SqlError::InvalidValue("truncated positive integer".into()));
235        }
236        let mut bytes = [0u8; 8];
237        bytes[8 - byte_count..].copy_from_slice(&data[1..1 + byte_count]);
238        let val = i64::from_be_bytes(bytes);
239        Ok((val, 1 + byte_count))
240    } else {
241        let byte_count = (0x80 - marker) as usize;
242        if data.len() < 1 + byte_count {
243            return Err(SqlError::InvalidValue("truncated negative integer".into()));
244        }
245        let mut bytes = [0u8; 8];
246        for i in 0..byte_count {
247            bytes[8 - byte_count + i] = !data[1 + i];
248        }
249        let abs_val = u64::from_be_bytes(bytes);
250        let val = (-(abs_val as i128)) as i64;
251        Ok((val, 1 + byte_count))
252    }
253}
254
255fn decode_real(data: &[u8]) -> Result<(Value, usize)> {
256    if data.len() < 8 {
257        return Err(SqlError::InvalidValue("truncated real".into()));
258    }
259    let encoded = u64::from_be_bytes(data[..8].try_into().unwrap());
260    let bits = if encoded & (1u64 << 63) != 0 {
261        // Was positive: undo sign bit flip
262        encoded ^ (1u64 << 63)
263    } else {
264        // Was negative: undo full inversion
265        !encoded
266    };
267    let val = f64::from_bits(bits);
268    Ok((Value::Real(val), 8))
269}
270
271/// Decode null-escaped bytes. Returns (decoded bytes, bytes consumed including terminator).
272fn decode_null_escaped(data: &[u8]) -> Result<(Vec<u8>, usize)> {
273    let mut result = Vec::new();
274    let mut i = 0;
275    while i < data.len() {
276        if data[i] == 0x00 {
277            if i + 1 < data.len() && data[i + 1] == 0xFF {
278                result.push(0x00);
279                i += 2;
280            } else {
281                return Ok((result, i + 1)); // terminator consumed
282            }
283        } else {
284            result.push(data[i]);
285            i += 1;
286        }
287    }
288    Err(SqlError::InvalidValue(
289        "unterminated null-escaped string".into(),
290    ))
291}
292
293fn encode_cell_v2(v: &Value, buf: &mut Vec<u8>) {
294    match v {
295        Value::Integer(val) => {
296            buf.push(DataType::Integer.type_tag());
297            buf.extend_from_slice(&val.to_le_bytes());
298        }
299        Value::Real(r) => {
300            buf.push(DataType::Real.type_tag());
301            buf.extend_from_slice(&r.to_le_bytes());
302        }
303        Value::Boolean(b) => {
304            buf.push(DataType::Boolean.type_tag());
305            buf.push(if *b { 1 } else { 0 });
306        }
307        Value::Text(s) => {
308            let bytes = s.as_bytes();
309            buf.push(DataType::Text.type_tag());
310            buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
311            buf.extend_from_slice(bytes);
312        }
313        Value::Blob(data) => {
314            buf.push(DataType::Blob.type_tag());
315            buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
316            buf.extend_from_slice(data);
317        }
318        Value::Time(t) => {
319            buf.push(DataType::Time.type_tag());
320            buf.extend_from_slice(&t.to_le_bytes());
321        }
322        Value::Date(d) => {
323            buf.push(DataType::Date.type_tag());
324            buf.extend_from_slice(&d.to_le_bytes());
325        }
326        Value::Timestamp(t) => {
327            buf.push(DataType::Timestamp.type_tag());
328            buf.extend_from_slice(&t.to_le_bytes());
329        }
330        Value::Interval {
331            months,
332            days,
333            micros,
334        } => {
335            buf.push(DataType::Interval.type_tag());
336            buf.extend_from_slice(&months.to_le_bytes());
337            buf.extend_from_slice(&days.to_le_bytes());
338            buf.extend_from_slice(&micros.to_le_bytes());
339        }
340        Value::Null => unreachable!(),
341    }
342}
343
344pub fn encode_row(values: &[Value]) -> Vec<u8> {
345    let mut buf = Vec::new();
346    encode_row_into(values, &mut buf);
347    buf
348}
349
350pub fn encode_row_into(values: &[Value], buf: &mut Vec<u8>) {
351    buf.clear();
352    let col_count = values.len();
353    let bitmap_bytes = col_count.div_ceil(8);
354
355    let header = (col_count as u16) | V2_FLAG;
356    buf.extend_from_slice(&header.to_le_bytes());
357
358    let bitmap_start = buf.len();
359    buf.resize(buf.len() + bitmap_bytes, 0);
360
361    for (i, v) in values.iter().enumerate() {
362        if v.is_null() {
363            buf[bitmap_start + i / 8] |= 1 << (i % 8);
364            continue;
365        }
366        encode_cell_v2(v, buf);
367    }
368}
369
/// A pre-encoded V2 row for a fixed layout in which every non-NULL column is
/// an INTEGER.
///
/// Built by `build_int_row_template` and filled in by
/// `encode_int_row_with_template`.
#[derive(Debug, Clone)]
pub struct IntRowTemplate {
    /// Row header, NULL bitmap, and INTEGER cells with zeroed 8-byte payloads.
    pub template: Vec<u8>,
    /// `(slot, offset)` pairs: for each non-NULL slot, the byte offset in
    /// `template` where that slot's little-endian `i64` payload begins.
    pub slot_offsets: Vec<(usize, usize)>,
}
374
375pub fn build_int_row_template(phys_count: usize, null_slots: &[usize]) -> IntRowTemplate {
376    let bitmap_bytes = phys_count.div_ceil(8);
377    let mut template = Vec::with_capacity(2 + bitmap_bytes + phys_count * 9);
378    let header = (phys_count as u16) | V2_FLAG;
379    template.extend_from_slice(&header.to_le_bytes());
380    let bitmap_start = template.len();
381    template.resize(bitmap_start + bitmap_bytes, 0);
382    for &i in null_slots {
383        template[bitmap_start + i / 8] |= 1 << (i % 8);
384    }
385    let mut slot_offsets = Vec::with_capacity(phys_count.saturating_sub(null_slots.len()));
386    for slot in 0..phys_count {
387        if null_slots.contains(&slot) {
388            continue;
389        }
390        template.push(DataType::Integer.type_tag());
391        let value_offset = template.len();
392        template.extend_from_slice(&[0u8; 8]);
393        slot_offsets.push((slot, value_offset));
394    }
395    IntRowTemplate {
396        template,
397        slot_offsets,
398    }
399}
400
401/// Caller must guarantee every non-NULL `values[slot]` is `Value::Integer`.
402#[inline]
403pub fn encode_int_row_with_template(
404    tmpl: &IntRowTemplate,
405    values: &[Value],
406    buf: &mut Vec<u8>,
407) -> Result<()> {
408    buf.clear();
409    buf.extend_from_slice(&tmpl.template);
410    for &(slot, off) in &tmpl.slot_offsets {
411        match &values[slot] {
412            Value::Integer(v) => buf[off..off + 8].copy_from_slice(&v.to_le_bytes()),
413            other => {
414                return Err(SqlError::TypeMismatch {
415                    expected: "Integer".into(),
416                    got: other.data_type().to_string(),
417                });
418            }
419        }
420    }
421    Ok(())
422}
423
424fn decode_value(type_tag: u8, data: &[u8]) -> Result<Value> {
425    match DataType::from_tag(type_tag) {
426        Some(DataType::Integer) => Ok(Value::Integer(i64::from_le_bytes(
427            data[..8].try_into().unwrap(),
428        ))),
429        Some(DataType::Real) => Ok(Value::Real(f64::from_le_bytes(
430            data[..8].try_into().unwrap(),
431        ))),
432        Some(DataType::Boolean) => Ok(Value::Boolean(data[0] != 0)),
433        Some(DataType::Text) => {
434            let s = std::str::from_utf8(data)
435                .map_err(|_| SqlError::InvalidValue("invalid UTF-8 in column".into()))?;
436            Ok(Value::Text(CompactString::from(s)))
437        }
438        Some(DataType::Blob) => Ok(Value::Blob(data.to_vec())),
439        Some(DataType::Time) => Ok(Value::Time(i64::from_le_bytes(
440            data[..8].try_into().unwrap(),
441        ))),
442        Some(DataType::Date) => Ok(Value::Date(i32::from_le_bytes(
443            data[..4].try_into().unwrap(),
444        ))),
445        Some(DataType::Timestamp) => Ok(Value::Timestamp(i64::from_le_bytes(
446            data[..8].try_into().unwrap(),
447        ))),
448        Some(DataType::Interval) => {
449            if data.len() < 16 {
450                return Err(SqlError::InvalidValue("truncated interval".into()));
451            }
452            let months = i32::from_le_bytes(data[0..4].try_into().unwrap());
453            let days = i32::from_le_bytes(data[4..8].try_into().unwrap());
454            let micros = i64::from_le_bytes(data[8..16].try_into().unwrap());
455            Ok(Value::Interval {
456                months,
457                days,
458                micros,
459            })
460        }
461        _ => Err(SqlError::InvalidValue(format!(
462            "unknown column type tag: {type_tag}"
463        ))),
464    }
465}
466
/// On-disk row format revision, selected by the high bit (`V2_FLAG`) of the
/// row's `u16` column-count header.
///
/// * `V1`: every cell is `[tag:u8][len:u32 LE][body]`.
/// * `V2`: fixed-width types omit `len` (`[tag][body]`); TEXT and BLOB keep it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum RowVersion {
    /// Every cell carries a length prefix.
    V1,
    /// Only variable-width cells carry a length prefix.
    V2,
}
474
/// High bit of the row header's little-endian `u16`; set for V2 rows.
pub(crate) const V2_FLAG: u16 = 0x8000;
/// Low 15 bits of the row header: the stored column count.
pub(crate) const COL_COUNT_MASK: u16 = !V2_FLAG;
477
478#[inline]
479pub(crate) fn fixed_width_size(type_tag: u8) -> Option<usize> {
480    match DataType::from_tag(type_tag)? {
481        DataType::Integer | DataType::Real | DataType::Time | DataType::Timestamp => Some(8),
482        DataType::Date => Some(4),
483        DataType::Boolean => Some(1),
484        DataType::Interval => Some(16),
485        DataType::Text | DataType::Blob | DataType::Null => None,
486    }
487}
488
489#[inline]
490fn read_cell(data: &[u8], pos: usize, version: RowVersion) -> Result<(u8, &[u8], usize)> {
491    if pos >= data.len() {
492        return Err(SqlError::InvalidValue("truncated column data".into()));
493    }
494    let type_tag = data[pos];
495    let after_tag = pos + 1;
496    let (data_len, body_pos) = match version {
497        RowVersion::V2 => match fixed_width_size(type_tag) {
498            Some(n) => (n, after_tag),
499            None => {
500                if after_tag + 4 > data.len() {
501                    return Err(SqlError::InvalidValue("truncated column data".into()));
502                }
503                let len = u32::from_le_bytes([
504                    data[after_tag],
505                    data[after_tag + 1],
506                    data[after_tag + 2],
507                    data[after_tag + 3],
508                ]) as usize;
509                (len, after_tag + 4)
510            }
511        },
512        RowVersion::V1 => {
513            if after_tag + 4 > data.len() {
514                return Err(SqlError::InvalidValue("truncated column data".into()));
515            }
516            let len = u32::from_le_bytes([
517                data[after_tag],
518                data[after_tag + 1],
519                data[after_tag + 2],
520                data[after_tag + 3],
521            ]) as usize;
522            (len, after_tag + 4)
523        }
524    };
525    if body_pos + data_len > data.len() {
526        return Err(SqlError::InvalidValue("truncated column value".into()));
527    }
528    Ok((
529        type_tag,
530        &data[body_pos..body_pos + data_len],
531        body_pos + data_len,
532    ))
533}
534
535#[inline]
536fn skip_cell(data: &[u8], pos: usize, version: RowVersion) -> Result<usize> {
537    let (_, _, next) = read_cell(data, pos, version)?;
538    Ok(next)
539}
540
541fn copy_cell_to_v2(
542    data: &[u8],
543    pos: usize,
544    version: RowVersion,
545    out: &mut Vec<u8>,
546) -> Result<usize> {
547    let (tag, body, next) = read_cell(data, pos, version)?;
548    out.push(tag);
549    if fixed_width_size(tag).is_none() {
550        out.extend_from_slice(&(body.len() as u32).to_le_bytes());
551    }
552    out.extend_from_slice(body);
553    Ok(next)
554}
555
556fn parse_row_header(data: &[u8]) -> Result<(RowVersion, usize, &[u8], usize)> {
557    if data.len() < 2 {
558        return Err(SqlError::InvalidValue("row data too short".into()));
559    }
560    let raw = u16::from_le_bytes([data[0], data[1]]);
561    let version = if raw & V2_FLAG != 0 {
562        RowVersion::V2
563    } else {
564        RowVersion::V1
565    };
566    let col_count = (raw & COL_COUNT_MASK) as usize;
567    let bitmap_bytes = col_count.div_ceil(8);
568    let pos = 2;
569    if data.len() < pos + bitmap_bytes {
570        return Err(SqlError::InvalidValue("truncated null bitmap".into()));
571    }
572    Ok((
573        version,
574        col_count,
575        &data[pos..pos + bitmap_bytes],
576        pos + bitmap_bytes,
577    ))
578}
579
580pub fn decode_row(data: &[u8]) -> Result<Vec<Value>> {
581    let (version, col_count, bitmap, mut pos) = parse_row_header(data)?;
582
583    let mut values = Vec::with_capacity(col_count);
584    for i in 0..col_count {
585        if bitmap[i / 8] & (1 << (i % 8)) != 0 {
586            values.push(Value::Null);
587            continue;
588        }
589        let (type_tag, body, next) = read_cell(data, pos, version)?;
590        values.push(decode_value(type_tag, body)?);
591        pos = next;
592    }
593
594    Ok(values)
595}
596
597/// Returns the number of non-PK columns stored in a row value blob.
598#[inline]
599pub fn row_non_pk_count(data: &[u8]) -> usize {
600    (u16::from_le_bytes([data[0], data[1]]) & COL_COUNT_MASK) as usize
601}
602
603pub fn decode_row_into(data: &[u8], out: &mut [Value], col_mapping: &[usize]) -> Result<()> {
604    let (version, col_count, bitmap, mut pos) = parse_row_header(data)?;
605
606    for i in 0..col_count {
607        if bitmap[i / 8] & (1 << (i % 8)) != 0 {
608            continue;
609        }
610        let (type_tag, body, next) = read_cell(data, pos, version)?;
611        if i < col_mapping.len() && col_mapping[i] != usize::MAX {
612            out[col_mapping[i]] = decode_value(type_tag, body)?;
613        }
614        pos = next;
615    }
616
617    Ok(())
618}
619
620pub fn decode_pk_into(
621    key: &[u8],
622    count: usize,
623    out: &mut [Value],
624    pk_mapping: &[usize],
625) -> Result<()> {
626    let mut pos = 0;
627    for i in 0..count {
628        let (v, n) = decode_key_value(&key[pos..])?;
629        if i < pk_mapping.len() {
630            out[pk_mapping[i]] = v;
631        }
632        pos += n;
633    }
634    Ok(())
635}
636
637pub fn decode_columns(data: &[u8], targets: &[usize]) -> Result<Vec<Value>> {
638    if targets.is_empty() {
639        return Ok(Vec::new());
640    }
641    let (version, col_count, bitmap, mut pos) = parse_row_header(data)?;
642
643    let mut results = Vec::with_capacity(targets.len());
644    let mut ti = 0;
645
646    for col in 0..col_count {
647        if ti >= targets.len() {
648            break;
649        }
650        let is_null = bitmap[col / 8] & (1 << (col % 8)) != 0;
651
652        if col == targets[ti] {
653            if is_null {
654                results.push(Value::Null);
655            } else {
656                let (type_tag, body, next) = read_cell(data, pos, version)?;
657                results.push(decode_value(type_tag, body)?);
658                pos = next;
659            }
660            ti += 1;
661        } else if !is_null {
662            pos = skip_cell(data, pos, version)?;
663        }
664    }
665
666    while ti < targets.len() {
667        results.push(Value::Null);
668        ti += 1;
669    }
670
671    Ok(results)
672}
673
674pub fn decode_columns_into(
675    data: &[u8],
676    targets: &[usize],
677    schema_cols: &[usize],
678    row: &mut [Value],
679) -> Result<()> {
680    if targets.is_empty() {
681        return Ok(());
682    }
683    let (version, col_count, bitmap, mut pos) = parse_row_header(data)?;
684
685    let mut ti = 0;
686    for col in 0..col_count {
687        if ti >= targets.len() {
688            break;
689        }
690        let is_null = bitmap[col / 8] & (1 << (col % 8)) != 0;
691
692        if col == targets[ti] {
693            if is_null {
694                row[schema_cols[ti]] = Value::Null;
695            } else {
696                let (type_tag, body, next) = read_cell(data, pos, version)?;
697                row[schema_cols[ti]] = decode_value(type_tag, body)?;
698                pos = next;
699            }
700            ti += 1;
701        } else if !is_null {
702            pos = skip_cell(data, pos, version)?;
703        }
704    }
705
706    Ok(())
707}
708
/// A decoded column that borrows its TEXT and BLOB payloads from the source
/// bytes instead of owning a copy.
///
/// Use `RawColumn::to_value` to obtain an owned `Value`.
#[derive(Debug, Clone, Copy)]
pub enum RawColumn<'a> {
    /// SQL NULL.
    Null,
    /// 64-bit signed integer.
    Integer(i64),
    /// 64-bit float.
    Real(f64),
    /// Boolean.
    Boolean(bool),
    /// UTF-8 text borrowed from the row bytes.
    Text(&'a str),
    /// Raw bytes borrowed from the row bytes.
    Blob(&'a [u8]),
    /// TIME stored as an `i64`.
    Time(i64),
    /// DATE stored as an `i32`.
    Date(i32),
    /// TIMESTAMP stored as an `i64`.
    Timestamp(i64),
    /// INTERVAL split into months, days and microseconds.
    Interval { months: i32, days: i32, micros: i64 },
}
722
723impl<'a> RawColumn<'a> {
724    pub fn to_value(self) -> Value {
725        match self {
726            RawColumn::Null => Value::Null,
727            RawColumn::Integer(i) => Value::Integer(i),
728            RawColumn::Real(r) => Value::Real(r),
729            RawColumn::Boolean(b) => Value::Boolean(b),
730            RawColumn::Text(s) => Value::Text(CompactString::from(s)),
731            RawColumn::Blob(b) => Value::Blob(b.to_vec()),
732            RawColumn::Time(t) => Value::Time(t),
733            RawColumn::Date(d) => Value::Date(d),
734            RawColumn::Timestamp(t) => Value::Timestamp(t),
735            RawColumn::Interval {
736                months,
737                days,
738                micros,
739            } => Value::Interval {
740                months,
741                days,
742                micros,
743            },
744        }
745    }
746
747    pub fn cmp_value(&self, other: &Value) -> Option<std::cmp::Ordering> {
748        use std::cmp::Ordering;
749        match (self, other) {
750            (RawColumn::Null, Value::Null) => Some(Ordering::Equal),
751            (RawColumn::Null, _) | (_, Value::Null) => None,
752            (RawColumn::Integer(a), Value::Integer(b)) => Some(a.cmp(b)),
753            (RawColumn::Integer(a), Value::Real(b)) => (*a as f64).partial_cmp(b),
754            (RawColumn::Real(a), Value::Real(b)) => a.partial_cmp(b),
755            (RawColumn::Real(a), Value::Integer(b)) => a.partial_cmp(&(*b as f64)),
756            (RawColumn::Text(a), Value::Text(b)) => Some((*a).cmp(b.as_str())),
757            (RawColumn::Blob(a), Value::Blob(b)) => Some((*a).cmp(b.as_slice())),
758            (RawColumn::Boolean(a), Value::Boolean(b)) => Some(a.cmp(b)),
759            (RawColumn::Time(a), Value::Time(b)) => Some(a.cmp(b)),
760            (RawColumn::Date(a), Value::Date(b)) => Some(a.cmp(b)),
761            (RawColumn::Timestamp(a), Value::Timestamp(b)) => Some(a.cmp(b)),
762            (
763                RawColumn::Interval {
764                    months: am,
765                    days: ad,
766                    micros: au,
767                },
768                Value::Interval {
769                    months: bm,
770                    days: bd,
771                    micros: bu,
772                },
773            ) => Some(am.cmp(bm).then(ad.cmp(bd)).then(au.cmp(bu))),
774            _ => None,
775        }
776    }
777
778    pub fn eq_value(&self, other: &Value) -> bool {
779        match (self, other) {
780            (RawColumn::Null, Value::Null) => true,
781            (RawColumn::Integer(a), Value::Integer(b)) => a == b,
782            (RawColumn::Integer(a), Value::Real(b)) => (*a as f64) == *b,
783            (RawColumn::Real(a), Value::Real(b)) => a == b,
784            (RawColumn::Real(a), Value::Integer(b)) => *a == (*b as f64),
785            (RawColumn::Text(a), Value::Text(b)) => *a == b.as_str(),
786            (RawColumn::Blob(a), Value::Blob(b)) => *a == b.as_slice(),
787            (RawColumn::Boolean(a), Value::Boolean(b)) => a == b,
788            (RawColumn::Time(a), Value::Time(b)) => a == b,
789            (RawColumn::Date(a), Value::Date(b)) => a == b,
790            (RawColumn::Timestamp(a), Value::Timestamp(b)) => a == b,
791            (
792                RawColumn::Interval {
793                    months: am,
794                    days: ad,
795                    micros: au,
796                },
797                Value::Interval {
798                    months: bm,
799                    days: bd,
800                    micros: bu,
801                },
802            ) => am == bm && ad == bd && au == bu,
803            _ => false,
804        }
805    }
806
807    pub fn as_f64(&self) -> Option<f64> {
808        match self {
809            RawColumn::Integer(i) => Some(*i as f64),
810            RawColumn::Real(r) => Some(*r),
811            _ => None,
812        }
813    }
814
815    pub fn as_i64(&self) -> Option<i64> {
816        match self {
817            RawColumn::Integer(i) => Some(*i),
818            RawColumn::Time(t) => Some(*t),
819            RawColumn::Date(d) => Some(*d as i64),
820            RawColumn::Timestamp(t) => Some(*t),
821            _ => None,
822        }
823    }
824}
825
826fn decode_value_raw(type_tag: u8, data: &[u8]) -> Result<RawColumn<'_>> {
827    match DataType::from_tag(type_tag) {
828        Some(DataType::Integer) => Ok(RawColumn::Integer(i64::from_le_bytes(
829            data[..8].try_into().unwrap(),
830        ))),
831        Some(DataType::Real) => Ok(RawColumn::Real(f64::from_le_bytes(
832            data[..8].try_into().unwrap(),
833        ))),
834        Some(DataType::Boolean) => Ok(RawColumn::Boolean(data[0] != 0)),
835        Some(DataType::Text) => {
836            let s = std::str::from_utf8(data)
837                .map_err(|_| SqlError::InvalidValue("invalid UTF-8 in column".into()))?;
838            Ok(RawColumn::Text(s))
839        }
840        Some(DataType::Blob) => Ok(RawColumn::Blob(data)),
841        Some(DataType::Time) => Ok(RawColumn::Time(i64::from_le_bytes(
842            data[..8].try_into().unwrap(),
843        ))),
844        Some(DataType::Date) => Ok(RawColumn::Date(i32::from_le_bytes(
845            data[..4].try_into().unwrap(),
846        ))),
847        Some(DataType::Timestamp) => Ok(RawColumn::Timestamp(i64::from_le_bytes(
848            data[..8].try_into().unwrap(),
849        ))),
850        Some(DataType::Interval) => {
851            if data.len() < 16 {
852                return Err(SqlError::InvalidValue("truncated interval".into()));
853            }
854            let months = i32::from_le_bytes(data[0..4].try_into().unwrap());
855            let days = i32::from_le_bytes(data[4..8].try_into().unwrap());
856            let micros = i64::from_le_bytes(data[8..16].try_into().unwrap());
857            Ok(RawColumn::Interval {
858                months,
859                days,
860                micros,
861            })
862        }
863        _ => Err(SqlError::InvalidValue(format!(
864            "unknown column type tag: {type_tag}"
865        ))),
866    }
867}
868
869/// Patch column in-place if value size unchanged. Ok(false) = size mismatch, use `patch_row_column`.
870pub fn patch_column_in_place(data: &mut [u8], target: usize, new_val: &Value) -> Result<bool> {
871    let (version, col_count, bitmap, mut pos) = parse_row_header(data)?;
872    if target >= col_count || new_val.is_null() {
873        return Ok(false);
874    }
875    let was_null = bitmap[target / 8] & (1 << (target % 8)) != 0;
876    if was_null {
877        return Ok(false);
878    }
879    for col in 0..target {
880        let is_null = bitmap[col / 8] & (1 << (col % 8)) != 0;
881        if !is_null {
882            pos = skip_cell(data, pos, version)?;
883        }
884    }
885    let type_tag = data[pos];
886    let (old_data_len, val_start) = match version {
887        RowVersion::V2 => match fixed_width_size(type_tag) {
888            Some(n) => (n, pos + 1),
889            None => {
890                if pos + 5 > data.len() {
891                    return Err(SqlError::InvalidValue("truncated column data".into()));
892                }
893                let len = u32::from_le_bytes(data[pos + 1..pos + 5].try_into().unwrap()) as usize;
894                (len, pos + 5)
895            }
896        },
897        RowVersion::V1 => {
898            if pos + 5 > data.len() {
899                return Err(SqlError::InvalidValue("truncated column data".into()));
900            }
901            let len = u32::from_le_bytes(data[pos + 1..pos + 5].try_into().unwrap()) as usize;
902            (len, pos + 5)
903        }
904    };
905    let new_data_len = match new_val {
906        Value::Integer(_) | Value::Real(_) | Value::Time(_) | Value::Timestamp(_) => 8,
907        Value::Date(_) => 4,
908        Value::Interval { .. } => 16,
909        Value::Boolean(_) => 1,
910        Value::Text(s) => s.len(),
911        Value::Blob(b) => b.len(),
912        Value::Null => return Ok(false),
913    };
914    if new_data_len != old_data_len {
915        return Ok(false);
916    }
917    data[pos] = new_val.data_type().type_tag();
918    match new_val {
919        Value::Integer(v) => data[val_start..val_start + 8].copy_from_slice(&v.to_le_bytes()),
920        Value::Real(r) => data[val_start..val_start + 8].copy_from_slice(&r.to_le_bytes()),
921        Value::Boolean(b) => data[val_start] = if *b { 1 } else { 0 },
922        Value::Text(s) => data[val_start..val_start + s.len()].copy_from_slice(s.as_bytes()),
923        Value::Blob(d) => data[val_start..val_start + d.len()].copy_from_slice(d),
924        Value::Time(t) => data[val_start..val_start + 8].copy_from_slice(&t.to_le_bytes()),
925        Value::Date(d) => data[val_start..val_start + 4].copy_from_slice(&d.to_le_bytes()),
926        Value::Timestamp(t) => data[val_start..val_start + 8].copy_from_slice(&t.to_le_bytes()),
927        Value::Interval {
928            months,
929            days,
930            micros,
931        } => {
932            data[val_start..val_start + 4].copy_from_slice(&months.to_le_bytes());
933            data[val_start + 4..val_start + 8].copy_from_slice(&days.to_le_bytes());
934            data[val_start + 8..val_start + 16].copy_from_slice(&micros.to_le_bytes());
935        }
936        Value::Null => unreachable!(),
937    }
938    Ok(true)
939}
940
941/// Patch a single column in encoded row, writing result into `out`. Copies others unchanged.
942pub fn patch_row_column(
943    data: &[u8],
944    target: usize,
945    new_val: &Value,
946    out: &mut Vec<u8>,
947) -> Result<()> {
948    let (version, col_count, bitmap, header_end) = parse_row_header(data)?;
949
950    let new_col_count = if target >= col_count {
951        target + 1
952    } else {
953        col_count
954    };
955    let new_bitmap_bytes = new_col_count.div_ceil(8);
956    let bitmap_bytes = col_count.div_ceil(8);
957    out.clear();
958
959    let header = (new_col_count as u16) | V2_FLAG;
960    out.extend_from_slice(&header.to_le_bytes());
961    let bitmap_start = out.len();
962    out.extend_from_slice(&data[2..2 + bitmap_bytes]);
963    for _ in bitmap_bytes..new_bitmap_bytes {
964        out.push(0xFF);
965    }
966    if new_val.is_null() {
967        out[bitmap_start + target / 8] |= 1 << (target % 8);
968    } else {
969        out[bitmap_start + target / 8] &= !(1 << (target % 8));
970    }
971
972    let mut pos = header_end;
973    for col in 0..new_col_count {
974        let was_null = if col < col_count {
975            bitmap[col / 8] & (1 << (col % 8)) != 0
976        } else {
977            true
978        };
979
980        if col == target {
981            if !was_null {
982                pos = skip_cell(data, pos, version)?;
983            }
984            if !new_val.is_null() {
985                encode_cell_v2(new_val, out);
986            }
987        } else if !was_null {
988            pos = copy_cell_to_v2(data, pos, version, out)?;
989        }
990    }
991    Ok(())
992}
993
994pub fn decode_column_raw(data: &[u8], target: usize) -> Result<RawColumn<'_>> {
995    let (version, col_count, bitmap, mut pos) = parse_row_header(data)?;
996    if target >= col_count {
997        return Ok(RawColumn::Null);
998    }
999
1000    for col in 0..=target {
1001        let is_null = bitmap[col / 8] & (1 << (col % 8)) != 0;
1002
1003        if col == target {
1004            if is_null {
1005                return Ok(RawColumn::Null);
1006            }
1007            let (type_tag, body, _) = read_cell(data, pos, version)?;
1008            return decode_value_raw(type_tag, body);
1009        } else if !is_null {
1010            pos = skip_cell(data, pos, version)?;
1011        }
1012    }
1013
1014    unreachable!()
1015}
1016
1017/// Like `decode_column_raw` but also returns the byte offset (usize::MAX if NULL).
1018pub fn decode_column_with_offset(data: &[u8], target: usize) -> Result<(RawColumn<'_>, usize)> {
1019    let (version, col_count, bitmap, mut pos) = parse_row_header(data)?;
1020    if target >= col_count {
1021        return Ok((RawColumn::Null, usize::MAX));
1022    }
1023
1024    for col in 0..=target {
1025        let is_null = bitmap[col / 8] & (1 << (col % 8)) != 0;
1026
1027        if col == target {
1028            if is_null {
1029                return Ok((RawColumn::Null, usize::MAX));
1030            }
1031            let tag_offset = pos;
1032            let (type_tag, body, _) = read_cell(data, pos, version)?;
1033            let raw = decode_value_raw(type_tag, body)?;
1034            return Ok((raw, tag_offset));
1035        } else if !is_null {
1036            pos = skip_cell(data, pos, version)?;
1037        }
1038    }
1039
1040    unreachable!()
1041}
1042
1043/// Patch at a known byte offset. Ok(false) if size mismatch or NULL offset.
1044pub fn patch_at_offset(data: &mut [u8], offset: usize, new_val: &Value) -> Result<bool> {
1045    if offset == usize::MAX || new_val.is_null() {
1046        return Ok(false);
1047    }
1048    if data.len() < 2 || offset >= data.len() {
1049        return Err(SqlError::InvalidValue("truncated column data".into()));
1050    }
1051    let version = if u16::from_le_bytes([data[0], data[1]]) & V2_FLAG != 0 {
1052        RowVersion::V2
1053    } else {
1054        RowVersion::V1
1055    };
1056    let type_tag = data[offset];
1057    let (old_data_len, val_start) = match version {
1058        RowVersion::V2 => match fixed_width_size(type_tag) {
1059            Some(n) => (n, offset + 1),
1060            None => {
1061                if offset + 5 > data.len() {
1062                    return Err(SqlError::InvalidValue("truncated column data".into()));
1063                }
1064                let len =
1065                    u32::from_le_bytes(data[offset + 1..offset + 5].try_into().unwrap()) as usize;
1066                (len, offset + 5)
1067            }
1068        },
1069        RowVersion::V1 => {
1070            if offset + 5 > data.len() {
1071                return Err(SqlError::InvalidValue("truncated column data".into()));
1072            }
1073            let len = u32::from_le_bytes(data[offset + 1..offset + 5].try_into().unwrap()) as usize;
1074            (len, offset + 5)
1075        }
1076    };
1077    let new_data_len = match new_val {
1078        Value::Integer(_) | Value::Real(_) | Value::Time(_) | Value::Timestamp(_) => 8,
1079        Value::Date(_) => 4,
1080        Value::Interval { .. } => 16,
1081        Value::Boolean(_) => 1,
1082        Value::Text(s) => s.len(),
1083        Value::Blob(b) => b.len(),
1084        Value::Null => return Ok(false),
1085    };
1086    if new_data_len != old_data_len {
1087        return Ok(false);
1088    }
1089    data[offset] = new_val.data_type().type_tag();
1090    match new_val {
1091        Value::Integer(v) => data[val_start..val_start + 8].copy_from_slice(&v.to_le_bytes()),
1092        Value::Real(r) => data[val_start..val_start + 8].copy_from_slice(&r.to_le_bytes()),
1093        Value::Boolean(b) => data[val_start] = if *b { 1 } else { 0 },
1094        Value::Text(s) => data[val_start..val_start + s.len()].copy_from_slice(s.as_bytes()),
1095        Value::Blob(d) => data[val_start..val_start + d.len()].copy_from_slice(d),
1096        Value::Time(t) => data[val_start..val_start + 8].copy_from_slice(&t.to_le_bytes()),
1097        Value::Date(d) => data[val_start..val_start + 4].copy_from_slice(&d.to_le_bytes()),
1098        Value::Timestamp(t) => data[val_start..val_start + 8].copy_from_slice(&t.to_le_bytes()),
1099        Value::Interval {
1100            months,
1101            days,
1102            micros,
1103        } => {
1104            data[val_start..val_start + 4].copy_from_slice(&months.to_le_bytes());
1105            data[val_start + 4..val_start + 8].copy_from_slice(&days.to_le_bytes());
1106            data[val_start + 8..val_start + 16].copy_from_slice(&micros.to_le_bytes());
1107        }
1108        Value::Null => unreachable!(),
1109    }
1110    Ok(true)
1111}
1112
1113pub fn decode_pk_integer(key: &[u8]) -> Result<i64> {
1114    if key.is_empty() || key[0] != TAG_INTEGER {
1115        return Err(SqlError::InvalidValue("not an integer key".into()));
1116    }
1117    let (val, _) = decode_integer(&key[1..])?;
1118    match val {
1119        Value::Integer(i) => Ok(i),
1120        _ => unreachable!(),
1121    }
1122}
1123
1124#[cfg(test)]
1125#[path = "encoding_tests.rs"]
1126mod tests;