Skip to main content

citadel_sql/
types.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::hash::{Hash, Hasher};
4
/// SQL data types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    Null,
    Integer,
    Real,
    Text,
    Blob,
    Boolean,
}

impl DataType {
    /// Stable one-byte tag for this type, as written by schema serialization.
    ///
    /// The tag values are persisted, so they must never be renumbered.
    pub fn type_tag(self) -> u8 {
        match self {
            DataType::Null => 0,
            DataType::Blob => 1,
            DataType::Text => 2,
            DataType::Boolean => 3,
            DataType::Integer => 4,
            DataType::Real => 5,
        }
    }

    /// Inverse of [`DataType::type_tag`]; returns `None` for an unknown tag.
    pub fn from_tag(tag: u8) -> Option<Self> {
        match tag {
            0 => Some(DataType::Null),
            1 => Some(DataType::Blob),
            2 => Some(DataType::Text),
            3 => Some(DataType::Boolean),
            4 => Some(DataType::Integer),
            5 => Some(DataType::Real),
            _ => None,
        }
    }
}

impl fmt::Display for DataType {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            DataType::Null => write!(f, "NULL"),
            DataType::Integer => write!(f, "INTEGER"),
            DataType::Real => write!(f, "REAL"),
            DataType::Text => write!(f, "TEXT"),
            DataType::Blob => write!(f, "BLOB"),
            DataType::Boolean => write!(f, "BOOLEAN"),
        }
    }
}

/// SQL value.
///
/// Equality, ordering and hashing treat `Integer` and `Real` as one numeric
/// domain compared by exact mathematical value (so `Integer(1) == Real(1.0)`
/// and `Real(0.0) == Real(-0.0)`). A `Real` NaN equals any other NaN and sorts
/// above every other number. This keeps `Eq`, `Ord` and `Hash` mutually
/// consistent and makes `Ord` a total order.
#[derive(Debug, Clone)]
pub enum Value {
    Null,
    Integer(i64),
    Real(f64),
    Text(String),
    Blob(Vec<u8>),
    Boolean(bool),
}

impl Value {
    /// The [`DataType`] corresponding to this value's variant.
    pub fn data_type(&self) -> DataType {
        match self {
            Value::Null => DataType::Null,
            Value::Integer(_) => DataType::Integer,
            Value::Real(_) => DataType::Real,
            Value::Text(_) => DataType::Text,
            Value::Blob(_) => DataType::Blob,
            Value::Boolean(_) => DataType::Boolean,
        }
    }

    /// `true` only for [`Value::Null`].
    pub fn is_null(&self) -> bool {
        matches!(self, Value::Null)
    }

    /// Attempt to coerce this value to the target type.
    ///
    /// - Any value coerced to `DataType::Null`, and `Null` coerced to anything,
    ///   yields `Some(Value::Null)`.
    /// - `Real -> Integer` uses Rust `as` semantics: truncation toward zero,
    ///   saturation at the `i64` bounds, and NaN becomes 0.
    /// - `Boolean -> Integer` gives 1/0; `Integer -> Boolean` is `n != 0`.
    ///
    /// Returns `None` for any other combination (e.g. `Text -> Integer`).
    pub fn coerce_to(&self, target: DataType) -> Option<Value> {
        match (self, target) {
            (_, DataType::Null) => Some(Value::Null),
            (Value::Null, _) => Some(Value::Null),
            (Value::Integer(i), DataType::Integer) => Some(Value::Integer(*i)),
            (Value::Integer(i), DataType::Real) => Some(Value::Real(*i as f64)),
            (Value::Real(r), DataType::Real) => Some(Value::Real(*r)),
            (Value::Real(r), DataType::Integer) => Some(Value::Integer(*r as i64)),
            (Value::Text(s), DataType::Text) => Some(Value::Text(s.clone())),
            (Value::Blob(b), DataType::Blob) => Some(Value::Blob(b.clone())),
            (Value::Boolean(b), DataType::Boolean) => Some(Value::Boolean(*b)),
            (Value::Boolean(b), DataType::Integer) => Some(Value::Integer(if *b { 1 } else { 0 })),
            (Value::Integer(i), DataType::Boolean) => Some(Value::Boolean(*i != 0)),
            _ => None,
        }
    }

    /// Exact total ordering between two numeric (`Integer`/`Real`) values.
    ///
    /// Returns `None` if either side is not numeric; for every numeric pair it
    /// returns `Some`. Mixed comparisons are exact (no lossy `i64 -> f64`
    /// rounding), and NaN sorts above all other numbers and equals NaN.
    fn numeric_cmp(&self, other: &Value) -> Option<Ordering> {
        match (self, other) {
            (Value::Integer(a), Value::Integer(b)) => Some(a.cmp(b)),
            (Value::Real(a), Value::Real(b)) => Some(cmp_reals(*a, *b)),
            (Value::Integer(a), Value::Real(b)) => Some(cmp_int_real(*a, *b)),
            (Value::Real(a), Value::Integer(b)) => Some(cmp_int_real(*b, *a).reverse()),
            _ => None,
        }
    }
}

/// Total order on `f64`: IEEE order for non-NaN values (so `-0.0 == 0.0`),
/// with NaN greater than every non-NaN value and equal to any other NaN.
fn cmp_reals(a: f64, b: f64) -> Ordering {
    // partial_cmp is None only when at least one side is NaN.
    a.partial_cmp(&b)
        .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
}

/// Exact comparison of an `i64` with an `f64`, consistent with [`cmp_reals`].
fn cmp_int_real(i: i64, r: f64) -> Ordering {
    // 2^63 is exactly representable as f64, so these bounds are exact.
    const TWO_POW_63: f64 = 9_223_372_036_854_775_808.0;
    if r.is_nan() {
        return Ordering::Less; // NaN sorts above every number.
    }
    if r >= TWO_POW_63 {
        return Ordering::Less; // Above i64::MAX (includes +inf).
    }
    if r < -TWO_POW_63 {
        return Ordering::Greater; // Below i64::MIN (includes -inf).
    }
    let whole = r.trunc();
    // `whole` is integral and within [-2^63, 2^63), so this cast is exact.
    i.cmp(&(whole as i64)).then_with(|| {
        // Same integer part: the sign of r's fractional part decides.
        0.0_f64.partial_cmp(&(r - whole)).unwrap_or(Ordering::Equal)
    })
}

/// Bit pattern used for hashing a numeric value: all zeros hash like `+0.0`
/// and all NaNs hash like `f64::NAN`, matching the equality rules above.
fn canonical_real_bits(r: f64) -> u64 {
    if r.is_nan() {
        f64::NAN.to_bits()
    } else if r == 0.0 {
        0.0_f64.to_bits()
    } else {
        r.to_bits()
    }
}

impl PartialEq for Value {
    fn eq(&self, other: &Self) -> bool {
        match (self, other) {
            (Value::Null, Value::Null) => true,
            (Value::Text(a), Value::Text(b)) => a == b,
            (Value::Blob(a), Value::Blob(b)) => a == b,
            (Value::Boolean(a), Value::Boolean(b)) => a == b,
            // Numeric pairs compare exactly; every other mix is unequal
            // (numeric_cmp returns None for non-numeric operands).
            _ => self.numeric_cmp(other) == Some(Ordering::Equal),
        }
    }
}

// Sound because `eq` is reflexive (NaN == NaN), symmetric and transitive.
impl Eq for Value {}

impl Hash for Value {
    fn hash<H: Hasher>(&self, state: &mut H) {
        match self {
            Value::Null => 0u8.hash(state),
            Value::Integer(i) => {
                // Hash via f64 bits so Integer(n) and Real(n.0) produce the same
                // hash. Equality is exact, so equal values always convert to the
                // same bits; distinct large ints may collide, which is allowed.
                1u8.hash(state);
                canonical_real_bits(*i as f64).hash(state);
            }
            Value::Real(r) => {
                1u8.hash(state);
                canonical_real_bits(*r).hash(state);
            }
            Value::Text(s) => {
                2u8.hash(state);
                s.hash(state);
            }
            Value::Blob(b) => {
                3u8.hash(state);
                b.hash(state);
            }
            Value::Boolean(b) => {
                4u8.hash(state);
                b.hash(state);
            }
        }
    }
}

impl PartialOrd for Value {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Value {
    fn cmp(&self, other: &Self) -> Ordering {
        // NULL < BOOLEAN < INTEGER/REAL (numeric, NaN last) < TEXT < BLOB
        match (self, other) {
            (Value::Null, Value::Null) => Ordering::Equal,
            (Value::Null, _) => Ordering::Less,
            (_, Value::Null) => Ordering::Greater,

            (Value::Boolean(a), Value::Boolean(b)) => a.cmp(b),
            (Value::Boolean(_), _) => Ordering::Less,
            (_, Value::Boolean(_)) => Ordering::Greater,

            // Numeric: Integer and Real are comparable
            (Value::Integer(_) | Value::Real(_), Value::Integer(_) | Value::Real(_)) => self
                .numeric_cmp(other)
                .expect("numeric_cmp is total for Integer/Real pairs"),
            (Value::Integer(_) | Value::Real(_), _) => Ordering::Less,
            (_, Value::Integer(_) | Value::Real(_)) => Ordering::Greater,

            (Value::Text(a), Value::Text(b)) => a.cmp(b),
            (Value::Text(_), _) => Ordering::Less,
            (_, Value::Text(_)) => Ordering::Greater,

            (Value::Blob(a), Value::Blob(b)) => a.cmp(b),
        }
    }
}
192
193impl fmt::Display for Value {
194    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195        match self {
196            Value::Null => write!(f, "NULL"),
197            Value::Integer(i) => write!(f, "{i}"),
198            Value::Real(r) => {
199                if r.fract() == 0.0 && r.is_finite() {
200                    write!(f, "{r:.1}")
201                } else {
202                    write!(f, "{r}")
203                }
204            }
205            Value::Text(s) => write!(f, "{s}"),
206            Value::Blob(b) => write!(f, "X'{}'", hex_encode(b)),
207            Value::Boolean(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
208        }
209    }
210}
211
/// Encode `data` as upper-case hexadecimal, two characters per byte.
///
/// Uses a nibble lookup table so no temporary `String` is allocated per byte.
fn hex_encode(data: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let mut s = String::with_capacity(data.len() * 2);
    for &byte in data {
        s.push(char::from(DIGITS[usize::from(byte >> 4)]));
        s.push(char::from(DIGITS[usize::from(byte & 0x0F)]));
    }
    s
}
219
220/// Column definition.
221#[derive(Debug, Clone)]
222pub struct ColumnDef {
223    pub name: String,
224    pub data_type: DataType,
225    pub nullable: bool,
226    pub position: u16,
227}
228
/// Secondary index definition, persisted inside its owning table's schema.
#[derive(Debug, Clone)]
pub struct IndexDef {
    /// Index name.
    pub name: String,
    /// Indexed columns as `u16` column numbers — presumably indices into
    /// `TableSchema::columns`; TODO confirm against the executor.
    pub columns: Vec<u16>,
    /// Whether the index was declared UNIQUE.
    pub unique: bool,
}
236
237/// Table schema stored in the _schema table.
238#[derive(Debug, Clone)]
239pub struct TableSchema {
240    pub name: String,
241    pub columns: Vec<ColumnDef>,
242    pub primary_key_columns: Vec<u16>,
243    pub indices: Vec<IndexDef>,
244}
245
246const SCHEMA_VERSION: u8 = 2;
247
248impl TableSchema {
249    pub fn serialize(&self) -> Vec<u8> {
250        let mut buf = Vec::new();
251        buf.push(SCHEMA_VERSION);
252
253        // Table name
254        let name_bytes = self.name.as_bytes();
255        buf.extend_from_slice(&(name_bytes.len() as u16).to_le_bytes());
256        buf.extend_from_slice(name_bytes);
257
258        // Column count
259        buf.extend_from_slice(&(self.columns.len() as u16).to_le_bytes());
260
261        // Columns
262        for col in &self.columns {
263            let col_name = col.name.as_bytes();
264            buf.extend_from_slice(&(col_name.len() as u16).to_le_bytes());
265            buf.extend_from_slice(col_name);
266            buf.push(col.data_type.type_tag());
267            buf.push(if col.nullable { 1 } else { 0 });
268            buf.extend_from_slice(&col.position.to_le_bytes());
269        }
270
271        // Primary key columns
272        buf.extend_from_slice(&(self.primary_key_columns.len() as u16).to_le_bytes());
273        for &pk_idx in &self.primary_key_columns {
274            buf.extend_from_slice(&pk_idx.to_le_bytes());
275        }
276
277        // Indices (v2+)
278        buf.extend_from_slice(&(self.indices.len() as u16).to_le_bytes());
279        for idx in &self.indices {
280            let idx_name = idx.name.as_bytes();
281            buf.extend_from_slice(&(idx_name.len() as u16).to_le_bytes());
282            buf.extend_from_slice(idx_name);
283            buf.extend_from_slice(&(idx.columns.len() as u16).to_le_bytes());
284            for &col_idx in &idx.columns {
285                buf.extend_from_slice(&col_idx.to_le_bytes());
286            }
287            buf.push(if idx.unique { 1 } else { 0 });
288        }
289
290        buf
291    }
292
293    pub fn deserialize(data: &[u8]) -> crate::error::Result<Self> {
294        let mut pos = 0;
295
296        if data.is_empty() || (data[0] != SCHEMA_VERSION && data[0] != 1) {
297            return Err(crate::error::SqlError::InvalidValue(
298                "invalid schema version".into(),
299            ));
300        }
301        let version = data[0];
302        pos += 1;
303
304        // Table name
305        let name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
306        pos += 2;
307        let name = String::from_utf8_lossy(&data[pos..pos + name_len]).into_owned();
308        pos += name_len;
309
310        // Column count
311        let col_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
312        pos += 2;
313
314        let mut columns = Vec::with_capacity(col_count);
315        for _ in 0..col_count {
316            let col_name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
317            pos += 2;
318            let col_name = String::from_utf8_lossy(&data[pos..pos + col_name_len]).into_owned();
319            pos += col_name_len;
320            let data_type = DataType::from_tag(data[pos]).ok_or_else(|| {
321                crate::error::SqlError::InvalidValue("unknown data type tag".into())
322            })?;
323            pos += 1;
324            let nullable = data[pos] != 0;
325            pos += 1;
326            let position = u16::from_le_bytes([data[pos], data[pos + 1]]);
327            pos += 2;
328            columns.push(ColumnDef {
329                name: col_name,
330                data_type,
331                nullable,
332                position,
333            });
334        }
335
336        // Primary key columns
337        let pk_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
338        pos += 2;
339        let mut primary_key_columns = Vec::with_capacity(pk_count);
340        for _ in 0..pk_count {
341            let pk_idx = u16::from_le_bytes([data[pos], data[pos + 1]]);
342            pos += 2;
343            primary_key_columns.push(pk_idx);
344        }
345        let indices = if version >= 2 && pos + 2 <= data.len() {
346            let idx_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
347            pos += 2;
348            let mut idxs = Vec::with_capacity(idx_count);
349            for _ in 0..idx_count {
350                let idx_name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
351                pos += 2;
352                let idx_name = String::from_utf8_lossy(&data[pos..pos + idx_name_len]).into_owned();
353                pos += idx_name_len;
354                let col_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
355                pos += 2;
356                let mut cols = Vec::with_capacity(col_count);
357                for _ in 0..col_count {
358                    let col_idx = u16::from_le_bytes([data[pos], data[pos + 1]]);
359                    pos += 2;
360                    cols.push(col_idx);
361                }
362                let unique = data[pos] != 0;
363                pos += 1;
364                idxs.push(IndexDef {
365                    name: idx_name,
366                    columns: cols,
367                    unique,
368                });
369            }
370            idxs
371        } else {
372            vec![]
373        };
374        let _ = pos;
375
376        Ok(Self {
377            name,
378            columns,
379            primary_key_columns,
380            indices,
381        })
382    }
383
384    /// Get column index by name (case-insensitive).
385    pub fn column_index(&self, name: &str) -> Option<usize> {
386        let lower = name.to_ascii_lowercase();
387        self.columns
388            .iter()
389            .position(|c| c.name.to_ascii_lowercase() == lower)
390    }
391
392    /// Get indices of non-PK columns (columns stored in the B+ tree value).
393    pub fn non_pk_indices(&self) -> Vec<usize> {
394        (0..self.columns.len())
395            .filter(|i| !self.primary_key_columns.contains(&(*i as u16)))
396            .collect()
397    }
398
399    /// Get the PK column indices as usize.
400    pub fn pk_indices(&self) -> Vec<usize> {
401        self.primary_key_columns
402            .iter()
403            .map(|&i| i as usize)
404            .collect()
405    }
406
407    /// Get index definition by name (case-insensitive).
408    pub fn index_by_name(&self, name: &str) -> Option<&IndexDef> {
409        let lower = name.to_ascii_lowercase();
410        self.indices.iter().find(|i| i.name == lower)
411    }
412
413    /// Get the KV table name for an index.
414    pub fn index_table_name(table_name: &str, index_name: &str) -> Vec<u8> {
415        format!("__idx_{table_name}_{index_name}").into_bytes()
416    }
417}
418
419/// Result of executing a SQL statement.
420#[derive(Debug)]
421pub enum ExecutionResult {
422    RowsAffected(u64),
423    Query(QueryResult),
424    Ok,
425}
426
427/// Result of a SELECT query.
428#[derive(Debug)]
429pub struct QueryResult {
430    pub columns: Vec<String>,
431    pub rows: Vec<Vec<Value>>,
432}
433
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ColumnDef` without the struct-literal boilerplate.
    fn column(name: &str, data_type: DataType, nullable: bool, position: u16) -> ColumnDef {
        ColumnDef {
            name: name.to_string(),
            data_type,
            nullable,
            position,
        }
    }

    /// Builds an `IndexDef` without the struct-literal boilerplate.
    fn index(name: &str, columns: &[u16], unique: bool) -> IndexDef {
        IndexDef {
            name: name.to_string(),
            columns: columns.to_vec(),
            unique,
        }
    }

    #[test]
    fn value_ordering() {
        // Each (lower, higher) pair must compare strictly.
        let pairs = [
            (Value::Null, Value::Boolean(false)),
            (Value::Boolean(false), Value::Boolean(true)),
            (Value::Boolean(true), Value::Integer(0)),
            (Value::Integer(-1), Value::Integer(0)),
            (Value::Integer(0), Value::Real(0.5)),
            (Value::Real(1.0), Value::Text(String::new())),
            (Value::Text("a".into()), Value::Text("b".into())),
            (Value::Text("z".into()), Value::Blob(Vec::new())),
            (Value::Blob(vec![0]), Value::Blob(vec![1])),
        ];
        for (lower, higher) in &pairs {
            assert!(lower < higher, "expected {lower:?} < {higher:?}");
        }
    }

    #[test]
    fn value_numeric_mixed() {
        let one = Value::Integer(1);
        assert_eq!(one, Value::Real(1.0));
        assert!(one < Value::Real(1.5));
        assert!(Value::Real(0.5) < one);
    }

    #[test]
    fn value_display() {
        let cases = [
            (Value::Null, "NULL"),
            (Value::Integer(42), "42"),
            (Value::Real(3.15), "3.15"),
            (Value::Real(1.0), "1.0"),
            (Value::Text("hello".into()), "hello"),
            (Value::Blob(vec![0xDE, 0xAD]), "X'DEAD'"),
            (Value::Boolean(true), "TRUE"),
            (Value::Boolean(false), "FALSE"),
        ];
        for (value, expected) in &cases {
            assert_eq!(value.to_string(), *expected);
        }
    }

    #[test]
    fn value_coerce() {
        assert_eq!(
            Value::Integer(42).coerce_to(DataType::Real),
            Some(Value::Real(42.0))
        );
        assert_eq!(
            Value::Boolean(true).coerce_to(DataType::Integer),
            Some(Value::Integer(1))
        );
        assert_eq!(Value::Null.coerce_to(DataType::Integer), Some(Value::Null));
        assert!(Value::Text("x".into())
            .coerce_to(DataType::Integer)
            .is_none());
    }

    #[test]
    fn schema_roundtrip() {
        let schema = TableSchema {
            name: "users".into(),
            columns: vec![
                column("id", DataType::Integer, false, 0),
                column("name", DataType::Text, true, 1),
                column("active", DataType::Boolean, false, 2),
            ],
            primary_key_columns: vec![0],
            indices: Vec::new(),
        };

        let restored = TableSchema::deserialize(&schema.serialize()).unwrap();
        let cols = &restored.columns;

        assert_eq!(restored.name, "users");
        assert_eq!(cols.len(), 3);
        assert_eq!((cols[0].name.as_str(), cols[0].data_type), ("id", DataType::Integer));
        assert!(!cols[0].nullable);
        assert_eq!((cols[1].name.as_str(), cols[1].data_type), ("name", DataType::Text));
        assert!(cols[1].nullable);
        assert_eq!(
            (cols[2].name.as_str(), cols[2].data_type),
            ("active", DataType::Boolean)
        );
        assert_eq!(restored.primary_key_columns, vec![0]);
    }

    #[test]
    fn schema_roundtrip_with_indices() {
        let schema = TableSchema {
            name: "orders".into(),
            columns: vec![
                column("id", DataType::Integer, false, 0),
                column("customer", DataType::Text, false, 1),
                column("amount", DataType::Real, true, 2),
            ],
            primary_key_columns: vec![0],
            indices: vec![
                index("idx_customer", &[1], false),
                index("idx_amount_uniq", &[2], true),
            ],
        };

        let restored = TableSchema::deserialize(&schema.serialize()).unwrap();
        let idx = &restored.indices;

        assert_eq!(idx.len(), 2);
        assert_eq!(idx[0].name, "idx_customer");
        assert_eq!(idx[0].columns, vec![1]);
        assert!(!idx[0].unique);
        assert_eq!(idx[1].name, "idx_amount_uniq");
        assert_eq!(idx[1].columns, vec![2]);
        assert!(idx[1].unique);
    }

    #[test]
    fn schema_v1_backward_compat() {
        let old_schema = TableSchema {
            name: "test".into(),
            columns: vec![column("id", DataType::Integer, false, 0)],
            primary_key_columns: vec![0],
            indices: Vec::new(),
        };

        // Rewrite a v2 record into v1 shape: version byte 1 and no trailing
        // index-count field (the final 2 bytes encode an index count of 0).
        let mut data = old_schema.serialize();
        data[0] = 1;
        let v1_len = data.len() - 2;
        data.truncate(v1_len);

        let restored = TableSchema::deserialize(&data).unwrap();
        assert_eq!(restored.name, "test");
        assert!(restored.indices.is_empty());
    }

    #[test]
    fn data_type_display() {
        let cases = [
            (DataType::Integer, "INTEGER"),
            (DataType::Text, "TEXT"),
            (DataType::Boolean, "BOOLEAN"),
        ];
        for (data_type, expected) in &cases {
            assert_eq!(data_type.to_string(), *expected);
        }
    }
}