1use crate::error::{Result, SqlError};
5use crate::types::{DataType, Value};
6
// Type tags that prefix every encoded key value. Because the tag is the first
// byte, byte-wise comparison orders values of different types by tag first:
// NULL < BLOB < TEXT < BOOLEAN < INTEGER < REAL.
// NOTE(review): INTEGER and REAL have distinct tags, so every integer key sorts
// before every real key regardless of numeric value — confirm this is intended
// for columns that may mix the two.
// NOTE(review): presumably these bytes are persisted inside index keys, so
// changing any value would invalidate existing keys — confirm before editing.
const TAG_NULL: u8 = 0x00;
const TAG_BLOB: u8 = 0x01;
const TAG_TEXT: u8 = 0x02;
const TAG_BOOLEAN: u8 = 0x03;
const TAG_INTEGER: u8 = 0x04;
const TAG_REAL: u8 = 0x05;
16
17pub fn encode_key_value(value: &Value) -> Vec<u8> {
19 match value {
20 Value::Null => vec![TAG_NULL],
21 Value::Boolean(b) => vec![TAG_BOOLEAN, if *b { 0x01 } else { 0x00 }],
22 Value::Integer(i) => encode_integer(*i),
23 Value::Real(r) => encode_real(*r),
24 Value::Text(s) => encode_bytes(TAG_TEXT, s.as_bytes()),
25 Value::Blob(b) => encode_bytes(TAG_BLOB, b),
26 }
27}
28
29pub fn encode_composite_key(values: &[Value]) -> Vec<u8> {
31 let mut buf = Vec::new();
32 for v in values {
33 buf.extend_from_slice(&encode_key_value(v));
34 }
35 buf
36}
37
38pub fn decode_key_value(data: &[u8]) -> Result<(Value, usize)> {
40 if data.is_empty() {
41 return Err(SqlError::InvalidValue("empty key data".into()));
42 }
43 match data[0] {
44 TAG_NULL => Ok((Value::Null, 1)),
45 TAG_BOOLEAN => {
46 if data.len() < 2 {
47 return Err(SqlError::InvalidValue("truncated boolean".into()));
48 }
49 Ok((Value::Boolean(data[1] != 0), 2))
50 }
51 TAG_INTEGER => decode_integer(&data[1..]).map(|(v, n)| (v, n + 1)),
52 TAG_REAL => decode_real(&data[1..]).map(|(v, n)| (v, n + 1)),
53 TAG_TEXT => {
54 let (bytes, n) = decode_null_escaped(&data[1..])?;
55 let s = String::from_utf8(bytes)
56 .map_err(|_| SqlError::InvalidValue("invalid UTF-8 in key".into()))?;
57 Ok((Value::Text(s), n + 1))
58 }
59 TAG_BLOB => {
60 let (bytes, n) = decode_null_escaped(&data[1..])?;
61 Ok((Value::Blob(bytes), n + 1))
62 }
63 tag => Err(SqlError::InvalidValue(format!("unknown key tag: {tag:#x}"))),
64 }
65}
66
67pub fn decode_composite_key(data: &[u8], count: usize) -> Result<Vec<Value>> {
69 let mut values = Vec::with_capacity(count);
70 let mut pos = 0;
71 for _ in 0..count {
72 let (v, n) = decode_key_value(&data[pos..])?;
73 values.push(v);
74 pos += n;
75 }
76 Ok(values)
77}
78
79fn encode_integer(val: i64) -> Vec<u8> {
82 let mut buf = vec![TAG_INTEGER];
83 if val == 0 {
84 buf.push(0x80);
85 return buf;
86 }
87 if val > 0 {
88 let bytes = val.to_be_bytes();
89 let start = bytes.iter().position(|&b| b != 0).unwrap();
91 let byte_count = (8 - start) as u8;
92 buf.push(0x80 + byte_count);
93 buf.extend_from_slice(&bytes[start..]);
94 } else {
95 let abs_val = if val == i64::MIN {
97 u64::MAX / 2 + 1
99 } else {
100 (-val) as u64
101 };
102 let bytes = abs_val.to_be_bytes();
103 let start = bytes.iter().position(|&b| b != 0).unwrap();
104 let byte_count = (8 - start) as u8;
105 buf.push(0x80 - byte_count);
106 for &b in &bytes[start..] {
108 buf.push(!b);
109 }
110 }
111 buf
112}
113
114fn decode_integer(data: &[u8]) -> Result<(Value, usize)> {
115 if data.is_empty() {
116 return Err(SqlError::InvalidValue("truncated integer".into()));
117 }
118 let marker = data[0];
119 if marker == 0x80 {
120 return Ok((Value::Integer(0), 1));
121 }
122 if marker > 0x80 {
123 let byte_count = (marker - 0x80) as usize;
125 if data.len() < 1 + byte_count {
126 return Err(SqlError::InvalidValue("truncated positive integer".into()));
127 }
128 let mut bytes = [0u8; 8];
129 bytes[8 - byte_count..].copy_from_slice(&data[1..1 + byte_count]);
130 let val = i64::from_be_bytes(bytes);
131 Ok((Value::Integer(val), 1 + byte_count))
132 } else {
133 let byte_count = (0x80 - marker) as usize;
135 if data.len() < 1 + byte_count {
136 return Err(SqlError::InvalidValue("truncated negative integer".into()));
137 }
138 let mut bytes = [0u8; 8];
139 for i in 0..byte_count {
140 bytes[8 - byte_count + i] = !data[1 + i];
141 }
142 let abs_val = u64::from_be_bytes(bytes);
143 let val = (-(abs_val as i128)) as i64;
145 Ok((Value::Integer(val), 1 + byte_count))
146 }
147}
148
149fn encode_real(val: f64) -> Vec<u8> {
152 let mut buf = vec![TAG_REAL];
153 let bits = val.to_bits();
154 let encoded = if val.is_sign_negative() {
155 !bits
157 } else {
158 bits ^ (1u64 << 63)
160 };
161 buf.extend_from_slice(&encoded.to_be_bytes());
162 buf
163}
164
165fn decode_real(data: &[u8]) -> Result<(Value, usize)> {
166 if data.len() < 8 {
167 return Err(SqlError::InvalidValue("truncated real".into()));
168 }
169 let encoded = u64::from_be_bytes(data[..8].try_into().unwrap());
170 let bits = if encoded & (1u64 << 63) != 0 {
171 encoded ^ (1u64 << 63)
173 } else {
174 !encoded
176 };
177 let val = f64::from_bits(bits);
178 Ok((Value::Real(val), 8))
179}
180
/// Writes `tag` followed by `data` in a NUL-escaped, NUL-terminated form.
///
/// Every `0x00` in `data` becomes the pair `0x00 0xFF`, and a single `0x00`
/// terminates the fragment. The terminator compares below any continuation
/// byte, including the escape pair, so a string sorts before every longer
/// string it is a prefix of.
fn encode_bytes(tag: u8, data: &[u8]) -> Vec<u8> {
    let mut out = Vec::with_capacity(data.len() + 2);
    out.push(tag);
    // Splitting on NUL drops the NULs; re-emit each one as the escape pair
    // between consecutive pieces.
    for (index, piece) in data.split(|&byte| byte == 0x00).enumerate() {
        if index > 0 {
            out.extend_from_slice(&[0x00, 0xFF]);
        }
        out.extend_from_slice(piece);
    }
    out.push(0x00);
    out
}
198
199fn decode_null_escaped(data: &[u8]) -> Result<(Vec<u8>, usize)> {
201 let mut result = Vec::new();
202 let mut i = 0;
203 while i < data.len() {
204 if data[i] == 0x00 {
205 if i + 1 < data.len() && data[i + 1] == 0xFF {
206 result.push(0x00);
207 i += 2;
208 } else {
209 return Ok((result, i + 1)); }
211 } else {
212 result.push(data[i]);
213 i += 1;
214 }
215 }
216 Err(SqlError::InvalidValue("unterminated null-escaped string".into()))
217}
218
219pub fn encode_row(values: &[Value]) -> Vec<u8> {
224 let col_count = values.len();
225 let bitmap_bytes = (col_count + 7) / 8;
226 let mut buf = Vec::new();
227
228 buf.extend_from_slice(&(col_count as u16).to_le_bytes());
230
231 let mut bitmap = vec![0u8; bitmap_bytes];
233 for (i, v) in values.iter().enumerate() {
234 if v.is_null() {
235 bitmap[i / 8] |= 1 << (i % 8);
236 }
237 }
238 buf.extend_from_slice(&bitmap);
239
240 for v in values {
242 if v.is_null() {
243 continue;
244 }
245 match v {
246 Value::Integer(i) => {
247 buf.push(DataType::Integer.type_tag());
248 buf.extend_from_slice(&8u32.to_le_bytes());
249 buf.extend_from_slice(&i.to_le_bytes());
250 }
251 Value::Real(r) => {
252 buf.push(DataType::Real.type_tag());
253 buf.extend_from_slice(&8u32.to_le_bytes());
254 buf.extend_from_slice(&r.to_le_bytes());
255 }
256 Value::Boolean(b) => {
257 buf.push(DataType::Boolean.type_tag());
258 buf.extend_from_slice(&1u32.to_le_bytes());
259 buf.push(if *b { 1 } else { 0 });
260 }
261 Value::Text(s) => {
262 let bytes = s.as_bytes();
263 buf.push(DataType::Text.type_tag());
264 buf.extend_from_slice(&(bytes.len() as u32).to_le_bytes());
265 buf.extend_from_slice(bytes);
266 }
267 Value::Blob(data) => {
268 buf.push(DataType::Blob.type_tag());
269 buf.extend_from_slice(&(data.len() as u32).to_le_bytes());
270 buf.extend_from_slice(data);
271 }
272 Value::Null => unreachable!(),
273 }
274 }
275
276 buf
277}
278
279pub fn decode_row(data: &[u8]) -> Result<Vec<Value>> {
281 if data.len() < 2 {
282 return Err(SqlError::InvalidValue("row data too short".into()));
283 }
284 let col_count = u16::from_le_bytes([data[0], data[1]]) as usize;
285 let bitmap_bytes = (col_count + 7) / 8;
286 let mut pos = 2;
287
288 if data.len() < pos + bitmap_bytes {
289 return Err(SqlError::InvalidValue("truncated null bitmap".into()));
290 }
291 let bitmap = &data[pos..pos + bitmap_bytes];
292 pos += bitmap_bytes;
293
294 let mut values = Vec::with_capacity(col_count);
295 for i in 0..col_count {
296 let is_null = bitmap[i / 8] & (1 << (i % 8)) != 0;
297 if is_null {
298 values.push(Value::Null);
299 continue;
300 }
301
302 if pos + 5 > data.len() {
303 return Err(SqlError::InvalidValue("truncated column data".into()));
304 }
305 let type_tag = data[pos];
306 pos += 1;
307 let data_len = u32::from_le_bytes([data[pos], data[pos + 1], data[pos + 2], data[pos + 3]]) as usize;
308 pos += 4;
309
310 if pos + data_len > data.len() {
311 return Err(SqlError::InvalidValue("truncated column value".into()));
312 }
313
314 let value = match DataType::from_tag(type_tag) {
315 Some(DataType::Integer) => {
316 let i = i64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
317 Value::Integer(i)
318 }
319 Some(DataType::Real) => {
320 let r = f64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
321 Value::Real(r)
322 }
323 Some(DataType::Boolean) => Value::Boolean(data[pos] != 0),
324 Some(DataType::Text) => {
325 let s = String::from_utf8_lossy(&data[pos..pos + data_len]).into_owned();
326 Value::Text(s)
327 }
328 Some(DataType::Blob) => {
329 Value::Blob(data[pos..pos + data_len].to_vec())
330 }
331 _ => return Err(SqlError::InvalidValue(format!("unknown column type tag: {type_tag}"))),
332 };
333 pos += data_len;
334 values.push(value);
335 }
336
337 Ok(values)
338}
339
// Unit tests for the key and row codecs. Key tests check both round-tripping
// and that byte-wise comparison of encodings follows the intended value order.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Order-preserving key encoding ----

    #[test]
    fn key_null() {
        let encoded = encode_key_value(&Value::Null);
        let (decoded, n) = decode_key_value(&encoded).unwrap();
        // NULL is encoded as the bare tag byte.
        assert_eq!(n, 1);
        assert_eq!(decoded, Value::Null);
    }

    #[test]
    fn key_boolean() {
        let f_enc = encode_key_value(&Value::Boolean(false));
        let t_enc = encode_key_value(&Value::Boolean(true));
        assert!(f_enc < t_enc);

        let (f_dec, _) = decode_key_value(&f_enc).unwrap();
        let (t_dec, _) = decode_key_value(&t_enc).unwrap();
        assert_eq!(f_dec, Value::Boolean(false));
        assert_eq!(t_dec, Value::Boolean(true));
    }

    #[test]
    fn key_integer_roundtrip() {
        // Values straddle the 1-/2-byte magnitude boundaries and the i64 extremes.
        let test_values = [
            i64::MIN, -1_000_000, -256, -1, 0, 1, 127, 128, 255, 256,
            65535, 1_000_000, i64::MAX,
        ];
        for &v in &test_values {
            let encoded = encode_key_value(&Value::Integer(v));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            assert_eq!(decoded, Value::Integer(v), "roundtrip failed for {v}");
        }
    }

    #[test]
    fn key_integer_sort_order() {
        let values: Vec<i64> = vec![
            i64::MIN, -1_000_000, -1, 0, 1, 1_000_000, i64::MAX,
        ];
        let encoded: Vec<Vec<u8>> = values.iter()
            .map(|&v| encode_key_value(&Value::Integer(v)))
            .collect();

        // Adjacent pairs must be strictly increasing as byte strings.
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "sort order broken: {} vs {}",
                values[i], values[i + 1]
            );
        }
    }

    #[test]
    fn key_real_roundtrip() {
        let test_values = [
            f64::NEG_INFINITY, -1e100, -1.0, -f64::MIN_POSITIVE, -0.0,
            0.0, f64::MIN_POSITIVE, 0.5, 1.0, 1e100, f64::INFINITY,
        ];
        for &v in &test_values {
            let encoded = encode_key_value(&Value::Real(v));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            match decoded {
                Value::Real(r) => {
                    // Compare bit patterns: `==` would treat -0.0 and 0.0 as equal.
                    assert!(
                        v.to_bits() == r.to_bits(),
                        "roundtrip failed for {v}: got {r}"
                    );
                }
                _ => panic!("expected Real"),
            }
        }
    }

    #[test]
    fn key_real_sort_order() {
        let values = vec![
            f64::NEG_INFINITY, -100.0, -1.0, -0.0,
            0.0, 1.0, 100.0, f64::INFINITY,
        ];
        let encoded: Vec<Vec<u8>> = values.iter()
            .map(|&v| encode_key_value(&Value::Real(v)))
            .collect();

        // Non-strict check; the encoding actually places -0.0 strictly below 0.0.
        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] <= encoded[i + 1],
                "sort order broken: {} vs {}",
                values[i], values[i + 1]
            );
        }
    }

    #[test]
    fn key_text_roundtrip() {
        // Includes embedded NULs to exercise the 0x00 0xFF escape.
        let test_values = ["", "hello", "world", "hello\0world", "\0\0\0"];
        for &v in &test_values {
            let encoded = encode_key_value(&Value::Text(v.into()));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            assert_eq!(decoded, Value::Text(v.into()), "roundtrip failed for {v:?}");
        }
    }

    #[test]
    fn key_text_sort_order() {
        // Prefixes ("a" < "ab") must sort before their extensions.
        let values = vec!["", "a", "ab", "b", "ba", "z"];
        let encoded: Vec<Vec<u8>> = values.iter()
            .map(|&v| encode_key_value(&Value::Text(v.into())))
            .collect();

        for i in 0..encoded.len() - 1 {
            assert!(
                encoded[i] < encoded[i + 1],
                "sort order broken: {:?} vs {:?}",
                values[i], values[i + 1]
            );
        }
    }

    #[test]
    fn key_blob_roundtrip() {
        // 0x00 and 0xFF combinations probe the escape/terminator ambiguity.
        let test_values: Vec<Vec<u8>> = vec![
            vec![], vec![0x00], vec![0x00, 0xFF], vec![0xFF, 0x00],
            vec![0x00, 0x00, 0x00],
        ];
        for v in &test_values {
            let encoded = encode_key_value(&Value::Blob(v.clone()));
            let (decoded, _) = decode_key_value(&encoded).unwrap();
            assert_eq!(decoded, Value::Blob(v.clone()));
        }
    }

    #[test]
    fn key_composite_roundtrip() {
        let values = vec![
            Value::Integer(42),
            Value::Text("hello".into()),
            Value::Boolean(true),
        ];
        let encoded = encode_composite_key(&values);
        let decoded = decode_composite_key(&encoded, 3).unwrap();
        assert_eq!(decoded[0], Value::Integer(42));
        assert_eq!(decoded[1], Value::Text("hello".into()));
        assert_eq!(decoded[2], Value::Boolean(true));
    }

    #[test]
    fn key_composite_sort_order() {
        // The first column dominates; the second breaks ties.
        let k1 = encode_composite_key(&[Value::Integer(1), Value::Text("b".into())]);
        let k2 = encode_composite_key(&[Value::Integer(1), Value::Text("c".into())]);
        let k3 = encode_composite_key(&[Value::Integer(2), Value::Text("a".into())]);
        assert!(k1 < k2);
        assert!(k2 < k3);
    }

    #[test]
    fn key_cross_type_ordering() {
        // Ordering follows the tag bytes; REAL (the highest tag) is not covered here.
        let null = encode_key_value(&Value::Null);
        let bool_val = encode_key_value(&Value::Boolean(false));
        let int = encode_key_value(&Value::Integer(0));
        let text = encode_key_value(&Value::Text("".into()));
        let blob = encode_key_value(&Value::Blob(vec![]));

        assert!(null < blob);
        assert!(blob < text);
        assert!(text < bool_val);
        assert!(bool_val < int);
    }

    // ---- Row encoding ----

    #[test]
    fn row_roundtrip_simple() {
        let values = vec![
            Value::Integer(42),
            Value::Text("hello".into()),
            Value::Boolean(true),
        ];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded[0], Value::Integer(42));
        assert_eq!(decoded[1], Value::Text("hello".into()));
        assert_eq!(decoded[2], Value::Boolean(true));
    }

    #[test]
    fn row_roundtrip_with_nulls() {
        // NULLs live only in the bitmap; surrounding columns must stay aligned.
        let values = vec![
            Value::Integer(1),
            Value::Null,
            Value::Text("test".into()),
            Value::Null,
        ];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert_eq!(decoded.len(), 4);
        assert_eq!(decoded[0], Value::Integer(1));
        assert!(decoded[1].is_null());
        assert_eq!(decoded[2], Value::Text("test".into()));
        assert!(decoded[3].is_null());
    }

    #[test]
    fn row_roundtrip_empty() {
        let values: Vec<Value> = vec![];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert!(decoded.is_empty());
    }

    #[test]
    fn row_roundtrip_all_types() {
        let values = vec![
            Value::Integer(-100),
            Value::Real(3.14),
            Value::Text("hello world".into()),
            Value::Blob(vec![0xDE, 0xAD, 0xBE, 0xEF]),
            Value::Boolean(false),
            Value::Null,
        ];
        let encoded = encode_row(&values);
        let decoded = decode_row(&encoded).unwrap();
        assert_eq!(decoded.len(), 6);
        assert_eq!(decoded[0], Value::Integer(-100));
        assert_eq!(decoded[1], Value::Real(3.14));
        assert_eq!(decoded[2], Value::Text("hello world".into()));
        assert_eq!(decoded[3], Value::Blob(vec![0xDE, 0xAD, 0xBE, 0xEF]));
        assert_eq!(decoded[4], Value::Boolean(false));
        assert!(decoded[5].is_null());
    }

    // ---- Additional key edge cases ----

    #[test]
    fn null_escaped_with_embedded_nulls() {
        let text = "before\0after";
        let encoded = encode_key_value(&Value::Text(text.into()));
        let (decoded, _) = decode_key_value(&encoded).unwrap();
        assert_eq!(decoded, Value::Text(text.into()));
    }

    #[test]
    fn key_integer_edge_cases() {
        for v in [i64::MIN, i64::MIN + 1, -1, 0, 1, i64::MAX - 1, i64::MAX] {
            let encoded = encode_key_value(&Value::Integer(v));
            let (decoded, n) = decode_key_value(&encoded).unwrap();
            // The decoder must report consuming exactly the whole encoding.
            assert_eq!(n, encoded.len());
            assert_eq!(decoded, Value::Integer(v), "edge case failed for {v}");
        }
    }
}