Skip to main content

citadel_sql/
types.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::hash::{Hash, Hasher};
4
5pub use compact_str::CompactString;
6
/// SQL data types.
///
/// Each variant is the type of the corresponding [`Value`] variant.
/// [`DataType::type_tag`] gives the byte written for a type when a table
/// schema is serialized.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// Type of SQL `NULL`.
    Null,
    /// Type of [`Value::Integer`].
    Integer,
    /// Type of [`Value::Real`].
    Real,
    /// Type of [`Value::Text`].
    Text,
    /// Type of [`Value::Blob`].
    Blob,
    /// Type of [`Value::Boolean`].
    Boolean,
}

impl DataType {
    /// Returns the one-byte tag that identifies this type.
    ///
    /// The numbering is independent of the variant declaration order and is
    /// the exact inverse of [`DataType::from_tag`]; the tag is written into
    /// serialized schemas, so the values must stay stable.
    pub fn type_tag(self) -> u8 {
        match self {
            DataType::Null => 0,
            DataType::Blob => 1,
            DataType::Text => 2,
            DataType::Boolean => 3,
            DataType::Integer => 4,
            DataType::Real => 5,
        }
    }

    /// Decodes a tag produced by [`DataType::type_tag`].
    ///
    /// Returns `None` for any byte that is not a known tag.
    pub fn from_tag(tag: u8) -> Option<Self> {
        let data_type = match tag {
            0 => DataType::Null,
            1 => DataType::Blob,
            2 => DataType::Text,
            3 => DataType::Boolean,
            4 => DataType::Integer,
            5 => DataType::Real,
            _ => return None,
        };
        Some(data_type)
    }
}

impl fmt::Display for DataType {
    /// Writes the upper-case SQL keyword for the type.
    ///
    /// The keyword goes through [`fmt::Formatter::pad`] so that width, fill
    /// and alignment flags (for example `{:<10}`) are honored; with a plain
    /// `{}` the output is just the keyword.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            DataType::Null => "NULL",
            DataType::Integer => "INTEGER",
            DataType::Real => "REAL",
            DataType::Text => "TEXT",
            DataType::Blob => "BLOB",
            DataType::Boolean => "BOOLEAN",
        };
        f.pad(keyword)
    }
}
55
56/// SQL value.
57#[derive(Debug, Clone, Default)]
58pub enum Value {
59    #[default]
60    Null,
61    Integer(i64),
62    Real(f64),
63    Text(CompactString),
64    Blob(Vec<u8>),
65    Boolean(bool),
66}
67
68impl Value {
69    pub fn data_type(&self) -> DataType {
70        match self {
71            Value::Null => DataType::Null,
72            Value::Integer(_) => DataType::Integer,
73            Value::Real(_) => DataType::Real,
74            Value::Text(_) => DataType::Text,
75            Value::Blob(_) => DataType::Blob,
76            Value::Boolean(_) => DataType::Boolean,
77        }
78    }
79
80    pub fn is_null(&self) -> bool {
81        matches!(self, Value::Null)
82    }
83
84    /// Attempt to coerce this value to the target type.
85    pub fn coerce_to(&self, target: DataType) -> Option<Value> {
86        match (self, target) {
87            (_, DataType::Null) => Some(Value::Null),
88            (Value::Null, _) => Some(Value::Null),
89            (Value::Integer(i), DataType::Integer) => Some(Value::Integer(*i)),
90            (Value::Integer(i), DataType::Real) => Some(Value::Real(*i as f64)),
91            (Value::Real(r), DataType::Real) => Some(Value::Real(*r)),
92            (Value::Real(r), DataType::Integer) => Some(Value::Integer(*r as i64)),
93            (Value::Text(s), DataType::Text) => Some(Value::Text(s.clone())),
94            (Value::Blob(b), DataType::Blob) => Some(Value::Blob(b.clone())),
95            (Value::Boolean(b), DataType::Boolean) => Some(Value::Boolean(*b)),
96            (Value::Boolean(b), DataType::Integer) => Some(Value::Integer(if *b { 1 } else { 0 })),
97            (Value::Integer(i), DataType::Boolean) => Some(Value::Boolean(*i != 0)),
98            _ => None,
99        }
100    }
101
102    pub fn coerce_into(self, target: DataType) -> Option<Value> {
103        if self.is_null() || target == DataType::Null {
104            return Some(Value::Null);
105        }
106        if self.data_type() == target {
107            return Some(self);
108        }
109        match (self, target) {
110            (Value::Integer(i), DataType::Real) => Some(Value::Real(i as f64)),
111            (Value::Real(r), DataType::Integer) => Some(Value::Integer(r as i64)),
112            (Value::Boolean(b), DataType::Integer) => Some(Value::Integer(if b { 1 } else { 0 })),
113            (Value::Integer(i), DataType::Boolean) => Some(Value::Boolean(i != 0)),
114            _ => None,
115        }
116    }
117
118    /// Numeric ordering for Integer and Real values (promotes to f64 for mixed).
119    fn numeric_cmp(&self, other: &Value) -> Option<Ordering> {
120        match (self, other) {
121            (Value::Integer(a), Value::Integer(b)) => Some(a.cmp(b)),
122            (Value::Real(a), Value::Real(b)) => a.partial_cmp(b),
123            (Value::Integer(a), Value::Real(b)) => (*a as f64).partial_cmp(b),
124            (Value::Real(a), Value::Integer(b)) => a.partial_cmp(&(*b as f64)),
125            _ => None,
126        }
127    }
128}
129
130impl PartialEq for Value {
131    fn eq(&self, other: &Self) -> bool {
132        match (self, other) {
133            (Value::Null, Value::Null) => true,
134            (Value::Integer(a), Value::Integer(b)) => a == b,
135            (Value::Real(a), Value::Real(b)) => a == b,
136            (Value::Integer(a), Value::Real(b)) => (*a as f64) == *b,
137            (Value::Real(a), Value::Integer(b)) => *a == (*b as f64),
138            (Value::Text(a), Value::Text(b)) => a == b,
139            (Value::Blob(a), Value::Blob(b)) => a == b,
140            (Value::Boolean(a), Value::Boolean(b)) => a == b,
141            _ => false,
142        }
143    }
144}
145
146impl Eq for Value {}
147
148impl Hash for Value {
149    fn hash<H: Hasher>(&self, state: &mut H) {
150        match self {
151            Value::Null => 0u8.hash(state),
152            Value::Integer(i) => {
153                // Hash via f64 bits so Integer(n) and Real(n.0) produce the same hash,
154                // matching the cross-type PartialEq contract.
155                1u8.hash(state);
156                (*i as f64).to_bits().hash(state);
157            }
158            Value::Real(r) => {
159                1u8.hash(state);
160                r.to_bits().hash(state);
161            }
162            Value::Text(s) => {
163                2u8.hash(state);
164                s.hash(state);
165            }
166            Value::Blob(b) => {
167                3u8.hash(state);
168                b.hash(state);
169            }
170            Value::Boolean(b) => {
171                4u8.hash(state);
172                b.hash(state);
173            }
174        }
175    }
176}
177
178impl PartialOrd for Value {
179    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
180        Some(self.cmp(other))
181    }
182}
183
184impl Ord for Value {
185    fn cmp(&self, other: &Self) -> Ordering {
186        // NULL < BOOLEAN < INTEGER/REAL (numeric) < TEXT < BLOB
187        match (self, other) {
188            (Value::Null, Value::Null) => Ordering::Equal,
189            (Value::Null, _) => Ordering::Less,
190            (_, Value::Null) => Ordering::Greater,
191
192            (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
193            (Value::Boolean(_), _) => Ordering::Less,
194            (_, Value::Boolean(_)) => Ordering::Greater,
195
196            // Numeric: Integer and Real are comparable
197            (Value::Integer(_) | Value::Real(_), Value::Integer(_) | Value::Real(_)) => {
198                self.numeric_cmp(other).unwrap_or(Ordering::Equal)
199            }
200            (Value::Integer(_) | Value::Real(_), _) => Ordering::Less,
201            (_, Value::Integer(_) | Value::Real(_)) => Ordering::Greater,
202
203            (Value::Text(a), Value::Text(b)) => a.cmp(b),
204            (Value::Text(_), _) => Ordering::Less,
205            (_, Value::Text(_)) => Ordering::Greater,
206
207            (Value::Blob(a), Value::Blob(b)) => a.cmp(b),
208        }
209    }
210}
211
212impl fmt::Display for Value {
213    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
214        match self {
215            Value::Null => write!(f, "NULL"),
216            Value::Integer(i) => write!(f, "{i}"),
217            Value::Real(r) => {
218                if r.fract() == 0.0 && r.is_finite() {
219                    write!(f, "{r:.1}")
220                } else {
221                    write!(f, "{r}")
222                }
223            }
224            Value::Text(s) => write!(f, "{s}"),
225            Value::Blob(b) => write!(f, "X'{}'", hex_encode(b)),
226            Value::Boolean(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
227        }
228    }
229}
230
/// Encodes `data` as upper-case hexadecimal, two digits per byte.
///
/// Digits are looked up in a table and pushed straight into the
/// pre-sized output, so no temporary `String` is allocated per byte.
fn hex_encode(data: &[u8]) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    let mut encoded = String::with_capacity(data.len() * 2);
    for &byte in data {
        // High nibble first, then low nibble.
        encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
        encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
    }
    encoded
}
238
239/// Column definition.
240#[derive(Debug, Clone)]
241pub struct ColumnDef {
242    pub name: String,
243    pub data_type: DataType,
244    pub nullable: bool,
245    pub position: u16,
246}
247
/// Index definition stored as part of the table schema.
#[derive(Debug, Clone)]
pub struct IndexDef {
    /// Index name.
    pub name: String,
    /// The indexed columns, as `u16` column numbers.
    // NOTE(review): these look like indices into the owning table's column
    // list — confirm against the index-building code.
    pub columns: Vec<u16>,
    /// Uniqueness flag of the index.
    pub unique: bool,
}
255
256/// Table schema stored in the _schema table.
257#[derive(Debug, Clone)]
258pub struct TableSchema {
259    pub name: String,
260    pub columns: Vec<ColumnDef>,
261    pub primary_key_columns: Vec<u16>,
262    pub indices: Vec<IndexDef>,
263    pk_idx_cache: Vec<usize>,
264    non_pk_idx_cache: Vec<usize>,
265}
266
267impl TableSchema {
268    pub fn new(
269        name: String,
270        columns: Vec<ColumnDef>,
271        primary_key_columns: Vec<u16>,
272        indices: Vec<IndexDef>,
273    ) -> Self {
274        let pk_idx_cache: Vec<usize> = primary_key_columns.iter().map(|&i| i as usize).collect();
275        let non_pk_idx_cache: Vec<usize> = (0..columns.len())
276            .filter(|i| !primary_key_columns.contains(&(*i as u16)))
277            .collect();
278        Self {
279            name,
280            columns,
281            primary_key_columns,
282            indices,
283            pk_idx_cache,
284            non_pk_idx_cache,
285        }
286    }
287}
288
289const SCHEMA_VERSION: u8 = 2;
290
291impl TableSchema {
292    pub fn serialize(&self) -> Vec<u8> {
293        let mut buf = Vec::new();
294        buf.push(SCHEMA_VERSION);
295
296        // Table name
297        let name_bytes = self.name.as_bytes();
298        buf.extend_from_slice(&(name_bytes.len() as u16).to_le_bytes());
299        buf.extend_from_slice(name_bytes);
300
301        // Column count
302        buf.extend_from_slice(&(self.columns.len() as u16).to_le_bytes());
303
304        // Columns
305        for col in &self.columns {
306            let col_name = col.name.as_bytes();
307            buf.extend_from_slice(&(col_name.len() as u16).to_le_bytes());
308            buf.extend_from_slice(col_name);
309            buf.push(col.data_type.type_tag());
310            buf.push(if col.nullable { 1 } else { 0 });
311            buf.extend_from_slice(&col.position.to_le_bytes());
312        }
313
314        // Primary key columns
315        buf.extend_from_slice(&(self.primary_key_columns.len() as u16).to_le_bytes());
316        for &pk_idx in &self.primary_key_columns {
317            buf.extend_from_slice(&pk_idx.to_le_bytes());
318        }
319
320        // Indices (v2+)
321        buf.extend_from_slice(&(self.indices.len() as u16).to_le_bytes());
322        for idx in &self.indices {
323            let idx_name = idx.name.as_bytes();
324            buf.extend_from_slice(&(idx_name.len() as u16).to_le_bytes());
325            buf.extend_from_slice(idx_name);
326            buf.extend_from_slice(&(idx.columns.len() as u16).to_le_bytes());
327            for &col_idx in &idx.columns {
328                buf.extend_from_slice(&col_idx.to_le_bytes());
329            }
330            buf.push(if idx.unique { 1 } else { 0 });
331        }
332
333        buf
334    }
335
336    pub fn deserialize(data: &[u8]) -> crate::error::Result<Self> {
337        let mut pos = 0;
338
339        if data.is_empty() || (data[0] != SCHEMA_VERSION && data[0] != 1) {
340            return Err(crate::error::SqlError::InvalidValue(
341                "invalid schema version".into(),
342            ));
343        }
344        let version = data[0];
345        pos += 1;
346
347        // Table name
348        let name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
349        pos += 2;
350        let name = String::from_utf8_lossy(&data[pos..pos + name_len]).into_owned();
351        pos += name_len;
352
353        // Column count
354        let col_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
355        pos += 2;
356
357        let mut columns = Vec::with_capacity(col_count);
358        for _ in 0..col_count {
359            let col_name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
360            pos += 2;
361            let col_name = String::from_utf8_lossy(&data[pos..pos + col_name_len]).into_owned();
362            pos += col_name_len;
363            let data_type = DataType::from_tag(data[pos]).ok_or_else(|| {
364                crate::error::SqlError::InvalidValue("unknown data type tag".into())
365            })?;
366            pos += 1;
367            let nullable = data[pos] != 0;
368            pos += 1;
369            let position = u16::from_le_bytes([data[pos], data[pos + 1]]);
370            pos += 2;
371            columns.push(ColumnDef {
372                name: col_name,
373                data_type,
374                nullable,
375                position,
376            });
377        }
378
379        // Primary key columns
380        let pk_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
381        pos += 2;
382        let mut primary_key_columns = Vec::with_capacity(pk_count);
383        for _ in 0..pk_count {
384            let pk_idx = u16::from_le_bytes([data[pos], data[pos + 1]]);
385            pos += 2;
386            primary_key_columns.push(pk_idx);
387        }
388        let indices = if version >= 2 && pos + 2 <= data.len() {
389            let idx_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
390            pos += 2;
391            let mut idxs = Vec::with_capacity(idx_count);
392            for _ in 0..idx_count {
393                let idx_name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
394                pos += 2;
395                let idx_name = String::from_utf8_lossy(&data[pos..pos + idx_name_len]).into_owned();
396                pos += idx_name_len;
397                let col_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
398                pos += 2;
399                let mut cols = Vec::with_capacity(col_count);
400                for _ in 0..col_count {
401                    let col_idx = u16::from_le_bytes([data[pos], data[pos + 1]]);
402                    pos += 2;
403                    cols.push(col_idx);
404                }
405                let unique = data[pos] != 0;
406                pos += 1;
407                idxs.push(IndexDef {
408                    name: idx_name,
409                    columns: cols,
410                    unique,
411                });
412            }
413            idxs
414        } else {
415            vec![]
416        };
417        let _ = pos;
418
419        Ok(Self::new(name, columns, primary_key_columns, indices))
420    }
421
422    /// Get column index by name (case-insensitive).
423    pub fn column_index(&self, name: &str) -> Option<usize> {
424        self.columns
425            .iter()
426            .position(|c| c.name.eq_ignore_ascii_case(name))
427    }
428
429    /// Get indices of non-PK columns (columns stored in the B+ tree value).
430    pub fn non_pk_indices(&self) -> &[usize] {
431        &self.non_pk_idx_cache
432    }
433
434    /// Get the PK column indices as usize.
435    pub fn pk_indices(&self) -> &[usize] {
436        &self.pk_idx_cache
437    }
438
439    /// Get index definition by name (case-insensitive).
440    pub fn index_by_name(&self, name: &str) -> Option<&IndexDef> {
441        let lower = name.to_ascii_lowercase();
442        self.indices.iter().find(|i| i.name == lower)
443    }
444
445    /// Get the KV table name for an index.
446    pub fn index_table_name(table_name: &str, index_name: &str) -> Vec<u8> {
447        format!("__idx_{table_name}_{index_name}").into_bytes()
448    }
449}
450
451/// Result of executing a SQL statement.
452#[derive(Debug)]
453pub enum ExecutionResult {
454    RowsAffected(u64),
455    Query(QueryResult),
456    Ok,
457}
458
459/// Result of a SELECT query.
460#[derive(Debug)]
461pub struct QueryResult {
462    pub columns: Vec<String>,
463    pub rows: Vec<Vec<Value>>,
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a `ColumnDef` fixture.
    fn column(name: &str, data_type: DataType, nullable: bool, position: u16) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            data_type,
            nullable,
            position,
        }
    }

    /// Shorthand for building an `IndexDef` fixture.
    fn index(name: &str, columns: Vec<u16>, unique: bool) -> IndexDef {
        IndexDef {
            name: name.into(),
            columns,
            unique,
        }
    }

    #[test]
    fn value_ordering() {
        // Each pair is (smaller, larger).
        let ascending_pairs = vec![
            (Value::Null, Value::Boolean(false)),
            (Value::Boolean(false), Value::Boolean(true)),
            (Value::Boolean(true), Value::Integer(0)),
            (Value::Integer(-1), Value::Integer(0)),
            (Value::Integer(0), Value::Real(0.5)),
            (Value::Real(1.0), Value::Text("".into())),
            (Value::Text("a".into()), Value::Text("b".into())),
            (Value::Text("z".into()), Value::Blob(vec![])),
            (Value::Blob(vec![0]), Value::Blob(vec![1])),
        ];
        for (smaller, larger) in ascending_pairs {
            assert!(smaller < larger, "{smaller:?} should sort before {larger:?}");
        }
    }

    #[test]
    fn value_numeric_mixed() {
        assert_eq!(Value::Integer(1), Value::Real(1.0));
        assert!(Value::Integer(1) < Value::Real(1.5));
        assert!(Value::Real(0.5) < Value::Integer(1));
    }

    #[test]
    fn value_display() {
        let cases = vec![
            (Value::Null, "NULL"),
            (Value::Integer(42), "42"),
            (Value::Real(3.15), "3.15"),
            (Value::Real(1.0), "1.0"),
            (Value::Text("hello".into()), "hello"),
            (Value::Blob(vec![0xDE, 0xAD]), "X'DEAD'"),
            (Value::Boolean(true), "TRUE"),
            (Value::Boolean(false), "FALSE"),
        ];
        for (value, expected) in cases {
            assert_eq!(value.to_string(), expected);
        }
    }

    #[test]
    fn value_coerce() {
        let int_to_real = Value::Integer(42).coerce_to(DataType::Real);
        assert_eq!(int_to_real, Some(Value::Real(42.0)));

        let bool_to_int = Value::Boolean(true).coerce_to(DataType::Integer);
        assert_eq!(bool_to_int, Some(Value::Integer(1)));

        assert_eq!(Value::Null.coerce_to(DataType::Integer), Some(Value::Null));
        assert_eq!(Value::Text("x".into()).coerce_to(DataType::Integer), None);
    }

    #[test]
    fn schema_roundtrip() {
        let schema = TableSchema::new(
            "users".into(),
            vec![
                column("id", DataType::Integer, false, 0),
                column("name", DataType::Text, true, 1),
                column("active", DataType::Boolean, false, 2),
            ],
            vec![0],
            vec![],
        );

        let restored = TableSchema::deserialize(&schema.serialize()).unwrap();

        assert_eq!(restored.name, "users");
        assert_eq!(restored.columns.len(), 3);

        let id = &restored.columns[0];
        assert_eq!(id.name, "id");
        assert_eq!(id.data_type, DataType::Integer);
        assert!(!id.nullable);

        let name = &restored.columns[1];
        assert_eq!(name.name, "name");
        assert_eq!(name.data_type, DataType::Text);
        assert!(name.nullable);

        let active = &restored.columns[2];
        assert_eq!(active.name, "active");
        assert_eq!(active.data_type, DataType::Boolean);

        assert_eq!(restored.primary_key_columns, vec![0]);
    }

    #[test]
    fn schema_roundtrip_with_indices() {
        let schema = TableSchema::new(
            "orders".into(),
            vec![
                column("id", DataType::Integer, false, 0),
                column("customer", DataType::Text, false, 1),
                column("amount", DataType::Real, true, 2),
            ],
            vec![0],
            vec![
                index("idx_customer", vec![1], false),
                index("idx_amount_uniq", vec![2], true),
            ],
        );

        let restored = TableSchema::deserialize(&schema.serialize()).unwrap();

        assert_eq!(restored.indices.len(), 2);

        let by_customer = &restored.indices[0];
        assert_eq!(by_customer.name, "idx_customer");
        assert_eq!(by_customer.columns, vec![1]);
        assert!(!by_customer.unique);

        let by_amount = &restored.indices[1];
        assert_eq!(by_amount.name, "idx_amount_uniq");
        assert_eq!(by_amount.columns, vec![2]);
        assert!(by_amount.unique);
    }

    #[test]
    fn schema_v1_backward_compat() {
        let old_schema = TableSchema::new(
            "test".into(),
            vec![column("id", DataType::Integer, false, 0)],
            vec![0],
            vec![],
        );
        let mut data = old_schema.serialize();
        // Turn the payload into v1 form: set the version byte to 1 and drop
        // the trailing two bytes, which hold the (zero) index count.
        data[0] = 1;
        data.truncate(data.len() - 2);

        let restored = TableSchema::deserialize(&data).unwrap();
        assert_eq!(restored.name, "test");
        assert!(restored.indices.is_empty());
    }

    #[test]
    fn data_type_display() {
        assert_eq!(DataType::Integer.to_string(), "INTEGER");
        assert_eq!(DataType::Text.to_string(), "TEXT");
        assert_eq!(DataType::Boolean.to_string(), "BOOLEAN");
    }
}