Skip to main content

citadeldb_sql/
types.rs

1use std::cmp::Ordering;
2use std::fmt;
3use std::hash::{Hash, Hasher};
4
/// SQL data types.
///
/// Each type has a one-byte tag (see [`DataType::type_tag`]) that is written
/// into serialized table schemas, so the tag assignment must stay stable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataType {
    Null,
    Integer,
    Real,
    Text,
    Blob,
    Boolean,
}

impl DataType {
    /// Types indexed by their serialized tag: `BY_TAG[t.type_tag()] == t`.
    const BY_TAG: [DataType; 6] = [
        DataType::Null,
        DataType::Blob,
        DataType::Text,
        DataType::Boolean,
        DataType::Integer,
        DataType::Real,
    ];

    /// Returns the on-disk tag for this type (a value in `0..=5`).
    pub fn type_tag(self) -> u8 {
        match self {
            Self::Null => 0,
            Self::Blob => 1,
            Self::Text => 2,
            Self::Boolean => 3,
            Self::Integer => 4,
            Self::Real => 5,
        }
    }

    /// Inverse of [`DataType::type_tag`]; returns `None` for any tag above 5.
    pub fn from_tag(tag: u8) -> Option<Self> {
        Self::BY_TAG.get(usize::from(tag)).copied()
    }
}

impl fmt::Display for DataType {
    /// Writes the SQL keyword for the type, e.g. `INTEGER`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let keyword = match self {
            Self::Null => "NULL",
            Self::Integer => "INTEGER",
            Self::Real => "REAL",
            Self::Text => "TEXT",
            Self::Blob => "BLOB",
            Self::Boolean => "BOOLEAN",
        };
        f.write_str(keyword)
    }
}
53
54/// SQL value.
55#[derive(Debug, Clone)]
56pub enum Value {
57    Null,
58    Integer(i64),
59    Real(f64),
60    Text(String),
61    Blob(Vec<u8>),
62    Boolean(bool),
63}
64
65impl Value {
66    pub fn data_type(&self) -> DataType {
67        match self {
68            Value::Null => DataType::Null,
69            Value::Integer(_) => DataType::Integer,
70            Value::Real(_) => DataType::Real,
71            Value::Text(_) => DataType::Text,
72            Value::Blob(_) => DataType::Blob,
73            Value::Boolean(_) => DataType::Boolean,
74        }
75    }
76
77    pub fn is_null(&self) -> bool {
78        matches!(self, Value::Null)
79    }
80
81    /// Attempt to coerce this value to the target type.
82    pub fn coerce_to(&self, target: DataType) -> Option<Value> {
83        match (self, target) {
84            (_, DataType::Null) => Some(Value::Null),
85            (Value::Null, _) => Some(Value::Null),
86            (Value::Integer(i), DataType::Integer) => Some(Value::Integer(*i)),
87            (Value::Integer(i), DataType::Real) => Some(Value::Real(*i as f64)),
88            (Value::Real(r), DataType::Real) => Some(Value::Real(*r)),
89            (Value::Real(r), DataType::Integer) => Some(Value::Integer(*r as i64)),
90            (Value::Text(s), DataType::Text) => Some(Value::Text(s.clone())),
91            (Value::Blob(b), DataType::Blob) => Some(Value::Blob(b.clone())),
92            (Value::Boolean(b), DataType::Boolean) => Some(Value::Boolean(*b)),
93            (Value::Boolean(b), DataType::Integer) => Some(Value::Integer(if *b { 1 } else { 0 })),
94            (Value::Integer(i), DataType::Boolean) => Some(Value::Boolean(*i != 0)),
95            _ => None,
96        }
97    }
98
99    /// Numeric ordering for Integer and Real values (promotes to f64 for mixed).
100    fn numeric_cmp(&self, other: &Value) -> Option<Ordering> {
101        match (self, other) {
102            (Value::Integer(a), Value::Integer(b)) => Some(a.cmp(b)),
103            (Value::Real(a), Value::Real(b)) => a.partial_cmp(b),
104            (Value::Integer(a), Value::Real(b)) => (*a as f64).partial_cmp(b),
105            (Value::Real(a), Value::Integer(b)) => a.partial_cmp(&(*b as f64)),
106            _ => None,
107        }
108    }
109}
110
111impl PartialEq for Value {
112    fn eq(&self, other: &Self) -> bool {
113        match (self, other) {
114            (Value::Null, Value::Null) => true,
115            (Value::Integer(a), Value::Integer(b)) => a == b,
116            (Value::Real(a), Value::Real(b)) => a == b,
117            (Value::Integer(a), Value::Real(b)) => (*a as f64) == *b,
118            (Value::Real(a), Value::Integer(b)) => *a == (*b as f64),
119            (Value::Text(a), Value::Text(b)) => a == b,
120            (Value::Blob(a), Value::Blob(b)) => a == b,
121            (Value::Boolean(a), Value::Boolean(b)) => a == b,
122            _ => false,
123        }
124    }
125}
126
127impl Eq for Value {}
128
129impl Hash for Value {
130    fn hash<H: Hasher>(&self, state: &mut H) {
131        match self {
132            Value::Null => 0u8.hash(state),
133            Value::Integer(i) => {
134                // Hash via f64 bits so Integer(n) and Real(n.0) produce the same hash,
135                // matching the cross-type PartialEq contract.
136                1u8.hash(state);
137                (*i as f64).to_bits().hash(state);
138            }
139            Value::Real(r) => {
140                1u8.hash(state);
141                r.to_bits().hash(state);
142            }
143            Value::Text(s) => { 2u8.hash(state); s.hash(state); }
144            Value::Blob(b) => { 3u8.hash(state); b.hash(state); }
145            Value::Boolean(b) => { 4u8.hash(state); b.hash(state); }
146        }
147    }
148}
149
150impl PartialOrd for Value {
151    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
152        // NULL < BOOLEAN < INTEGER/REAL (numeric) < TEXT < BLOB
153        match (self, other) {
154            (Value::Null, Value::Null) => Some(Ordering::Equal),
155            (Value::Null, _) => Some(Ordering::Less),
156            (_, Value::Null) => Some(Ordering::Greater),
157
158            (Value::Boolean(a), Value::Boolean(b)) => Some(a.cmp(b)),
159            (Value::Boolean(_), _) => Some(Ordering::Less),
160            (_, Value::Boolean(_)) => Some(Ordering::Greater),
161
162            // Numeric: Integer and Real are comparable
163            (Value::Integer(_) | Value::Real(_), Value::Integer(_) | Value::Real(_)) => {
164                self.numeric_cmp(other)
165            }
166            (Value::Integer(_) | Value::Real(_), _) => Some(Ordering::Less),
167            (_, Value::Integer(_) | Value::Real(_)) => Some(Ordering::Greater),
168
169            (Value::Text(a), Value::Text(b)) => Some(a.cmp(b)),
170            (Value::Text(_), _) => Some(Ordering::Less),
171            (_, Value::Text(_)) => Some(Ordering::Greater),
172
173            (Value::Blob(a), Value::Blob(b)) => Some(a.cmp(b)),
174        }
175    }
176}
177
178impl Ord for Value {
179    fn cmp(&self, other: &Self) -> Ordering {
180        self.partial_cmp(other).unwrap_or(Ordering::Equal)
181    }
182}
183
184impl fmt::Display for Value {
185    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
186        match self {
187            Value::Null => write!(f, "NULL"),
188            Value::Integer(i) => write!(f, "{i}"),
189            Value::Real(r) => {
190                if r.fract() == 0.0 && r.is_finite() {
191                    write!(f, "{r:.1}")
192                } else {
193                    write!(f, "{r}")
194                }
195            }
196            Value::Text(s) => write!(f, "{s}"),
197            Value::Blob(b) => write!(f, "X'{}'", hex_encode(b)),
198            Value::Boolean(b) => write!(f, "{}", if *b { "TRUE" } else { "FALSE" }),
199        }
200    }
201}
202
/// Encodes `data` as uppercase hexadecimal, two characters per byte.
fn hex_encode(data: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789ABCDEF";
    let mut s = String::with_capacity(data.len() * 2);
    for &byte in data {
        // Nibble lookup avoids allocating a temporary String for every byte.
        s.push(char::from(DIGITS[usize::from(byte >> 4)]));
        s.push(char::from(DIGITS[usize::from(byte & 0x0F)]));
    }
    s
}
210
/// Column definition.
///
/// `TableSchema::serialize` writes each column as a u16-length-prefixed name,
/// the `DataType::type_tag` byte, a nullable byte (1/0), and the little-endian
/// `position`.
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name as declared; `TableSchema::column_index` matches it
    /// ignoring ASCII case.
    pub name: String,
    /// Declared SQL type of the column.
    pub data_type: DataType,
    /// Whether the column accepts NULL. NOTE(review): enforcement happens
    /// outside this file — confirm where.
    pub nullable: bool,
    /// Column position; presumably the ordinal within the table's column list
    /// (the fixtures in this file use 0, 1, 2, …) — TODO confirm.
    pub position: u16,
}
219
/// Index definition stored as part of the table schema.
///
/// Serialized only by schema format version 2 and later.
#[derive(Debug, Clone)]
pub struct IndexDef {
    /// Index name. `TableSchema::index_table_name` combines it with the table
    /// name to form the name of the backing KV table.
    pub name: String,
    /// Indexed columns, presumably positions into the owning table's
    /// `columns` list (the fixtures in this file use column positions) — TODO confirm.
    pub columns: Vec<u16>,
    /// Whether the index is declared UNIQUE. NOTE(review): enforcement is
    /// outside this file.
    pub unique: bool,
}
227
228/// Table schema stored in the _schema table.
229#[derive(Debug, Clone)]
230pub struct TableSchema {
231    pub name: String,
232    pub columns: Vec<ColumnDef>,
233    pub primary_key_columns: Vec<u16>,
234    pub indices: Vec<IndexDef>,
235}
236
237const SCHEMA_VERSION: u8 = 2;
238
239impl TableSchema {
240    pub fn serialize(&self) -> Vec<u8> {
241        let mut buf = Vec::new();
242        buf.push(SCHEMA_VERSION);
243
244        // Table name
245        let name_bytes = self.name.as_bytes();
246        buf.extend_from_slice(&(name_bytes.len() as u16).to_le_bytes());
247        buf.extend_from_slice(name_bytes);
248
249        // Column count
250        buf.extend_from_slice(&(self.columns.len() as u16).to_le_bytes());
251
252        // Columns
253        for col in &self.columns {
254            let col_name = col.name.as_bytes();
255            buf.extend_from_slice(&(col_name.len() as u16).to_le_bytes());
256            buf.extend_from_slice(col_name);
257            buf.push(col.data_type.type_tag());
258            buf.push(if col.nullable { 1 } else { 0 });
259            buf.extend_from_slice(&col.position.to_le_bytes());
260        }
261
262        // Primary key columns
263        buf.extend_from_slice(&(self.primary_key_columns.len() as u16).to_le_bytes());
264        for &pk_idx in &self.primary_key_columns {
265            buf.extend_from_slice(&pk_idx.to_le_bytes());
266        }
267
268        // Indices (v2+)
269        buf.extend_from_slice(&(self.indices.len() as u16).to_le_bytes());
270        for idx in &self.indices {
271            let idx_name = idx.name.as_bytes();
272            buf.extend_from_slice(&(idx_name.len() as u16).to_le_bytes());
273            buf.extend_from_slice(idx_name);
274            buf.extend_from_slice(&(idx.columns.len() as u16).to_le_bytes());
275            for &col_idx in &idx.columns {
276                buf.extend_from_slice(&col_idx.to_le_bytes());
277            }
278            buf.push(if idx.unique { 1 } else { 0 });
279        }
280
281        buf
282    }
283
284    pub fn deserialize(data: &[u8]) -> crate::error::Result<Self> {
285        let mut pos = 0;
286
287        if data.is_empty() || (data[0] != SCHEMA_VERSION && data[0] != 1) {
288            return Err(crate::error::SqlError::InvalidValue(
289                "invalid schema version".into(),
290            ));
291        }
292        let version = data[0];
293        pos += 1;
294
295        // Table name
296        let name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
297        pos += 2;
298        let name = String::from_utf8_lossy(&data[pos..pos + name_len]).into_owned();
299        pos += name_len;
300
301        // Column count
302        let col_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
303        pos += 2;
304
305        let mut columns = Vec::with_capacity(col_count);
306        for _ in 0..col_count {
307            let col_name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
308            pos += 2;
309            let col_name = String::from_utf8_lossy(&data[pos..pos + col_name_len]).into_owned();
310            pos += col_name_len;
311            let data_type = DataType::from_tag(data[pos]).ok_or_else(|| {
312                crate::error::SqlError::InvalidValue("unknown data type tag".into())
313            })?;
314            pos += 1;
315            let nullable = data[pos] != 0;
316            pos += 1;
317            let position = u16::from_le_bytes([data[pos], data[pos + 1]]);
318            pos += 2;
319            columns.push(ColumnDef {
320                name: col_name,
321                data_type,
322                nullable,
323                position,
324            });
325        }
326
327        // Primary key columns
328        let pk_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
329        pos += 2;
330        let mut primary_key_columns = Vec::with_capacity(pk_count);
331        for _ in 0..pk_count {
332            let pk_idx = u16::from_le_bytes([data[pos], data[pos + 1]]);
333            pos += 2;
334            primary_key_columns.push(pk_idx);
335        }
336        let indices = if version >= 2 && pos + 2 <= data.len() {
337            let idx_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
338            pos += 2;
339            let mut idxs = Vec::with_capacity(idx_count);
340            for _ in 0..idx_count {
341                let idx_name_len = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
342                pos += 2;
343                let idx_name = String::from_utf8_lossy(&data[pos..pos + idx_name_len]).into_owned();
344                pos += idx_name_len;
345                let col_count = u16::from_le_bytes([data[pos], data[pos + 1]]) as usize;
346                pos += 2;
347                let mut cols = Vec::with_capacity(col_count);
348                for _ in 0..col_count {
349                    let col_idx = u16::from_le_bytes([data[pos], data[pos + 1]]);
350                    pos += 2;
351                    cols.push(col_idx);
352                }
353                let unique = data[pos] != 0;
354                pos += 1;
355                idxs.push(IndexDef { name: idx_name, columns: cols, unique });
356            }
357            idxs
358        } else {
359            vec![]
360        };
361        let _ = pos;
362
363        Ok(Self {
364            name,
365            columns,
366            primary_key_columns,
367            indices,
368        })
369    }
370
371    /// Get column index by name (case-insensitive).
372    pub fn column_index(&self, name: &str) -> Option<usize> {
373        let lower = name.to_ascii_lowercase();
374        self.columns.iter().position(|c| c.name.to_ascii_lowercase() == lower)
375    }
376
377    /// Get indices of non-PK columns (columns stored in the B+ tree value).
378    pub fn non_pk_indices(&self) -> Vec<usize> {
379        (0..self.columns.len())
380            .filter(|i| !self.primary_key_columns.contains(&(*i as u16)))
381            .collect()
382    }
383
384    /// Get the PK column indices as usize.
385    pub fn pk_indices(&self) -> Vec<usize> {
386        self.primary_key_columns.iter().map(|&i| i as usize).collect()
387    }
388
389    /// Get index definition by name (case-insensitive).
390    pub fn index_by_name(&self, name: &str) -> Option<&IndexDef> {
391        let lower = name.to_ascii_lowercase();
392        self.indices.iter().find(|i| i.name == lower)
393    }
394
395    /// Get the KV table name for an index.
396    pub fn index_table_name(table_name: &str, index_name: &str) -> Vec<u8> {
397        format!("__idx_{table_name}_{index_name}").into_bytes()
398    }
399}
400
/// Result of executing a SQL statement.
#[derive(Debug)]
pub enum ExecutionResult {
    /// Number of rows affected; presumably produced by INSERT/UPDATE/DELETE —
    /// confirm against the executor.
    RowsAffected(u64),
    /// Rows produced by a query.
    Query(QueryResult),
    /// Statement completed with neither a row count nor a result set
    /// (presumably DDL or transaction control — confirm).
    Ok,
}

/// Result of a SELECT query.
#[derive(Debug)]
pub struct QueryResult {
    /// Output column names.
    pub columns: Vec<String>,
    /// Result rows; each row presumably holds one value per entry in
    /// `columns` — TODO confirm.
    pub rows: Vec<Vec<Value>>,
}
415
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a [`ColumnDef`] for schema fixtures.
    fn col(name: &str, data_type: DataType, nullable: bool, position: u16) -> ColumnDef {
        ColumnDef {
            name: name.into(),
            data_type,
            nullable,
            position,
        }
    }

    #[test]
    fn value_ordering() {
        // Every (lo, hi) pair must satisfy `lo < hi`.
        let ascending = [
            (Value::Null, Value::Boolean(false)),
            (Value::Boolean(false), Value::Boolean(true)),
            (Value::Boolean(true), Value::Integer(0)),
            (Value::Integer(-1), Value::Integer(0)),
            (Value::Integer(0), Value::Real(0.5)),
            (Value::Real(1.0), Value::Text(String::new())),
            (Value::Text("a".into()), Value::Text("b".into())),
            (Value::Text("z".into()), Value::Blob(Vec::new())),
            (Value::Blob(vec![0]), Value::Blob(vec![1])),
        ];
        for (lo, hi) in &ascending {
            assert!(lo < hi, "expected {lo:?} < {hi:?}");
        }
    }

    #[test]
    fn value_numeric_mixed() {
        let one = Value::Integer(1);
        assert_eq!(one, Value::Real(1.0));
        assert!(one < Value::Real(1.5));
        assert!(Value::Real(0.5) < one);
    }

    #[test]
    fn value_display() {
        let cases = [
            (Value::Null, "NULL"),
            (Value::Integer(42), "42"),
            (Value::Real(3.14), "3.14"),
            (Value::Real(1.0), "1.0"),
            (Value::Text("hello".into()), "hello"),
            (Value::Blob(vec![0xDE, 0xAD]), "X'DEAD'"),
            (Value::Boolean(true), "TRUE"),
            (Value::Boolean(false), "FALSE"),
        ];
        for (value, expected) in cases {
            assert_eq!(value.to_string(), expected);
        }
    }

    #[test]
    fn value_coerce() {
        let cases = [
            (Value::Integer(42), DataType::Real, Some(Value::Real(42.0))),
            (Value::Boolean(true), DataType::Integer, Some(Value::Integer(1))),
            (Value::Null, DataType::Integer, Some(Value::Null)),
            (Value::Text("x".into()), DataType::Integer, None),
        ];
        for (input, target, expected) in cases {
            assert_eq!(input.coerce_to(target), expected);
        }
    }

    #[test]
    fn schema_roundtrip() {
        let schema = TableSchema {
            name: "users".into(),
            columns: vec![
                col("id", DataType::Integer, false, 0),
                col("name", DataType::Text, true, 1),
                col("active", DataType::Boolean, false, 2),
            ],
            primary_key_columns: vec![0],
            indices: vec![],
        };

        let restored = TableSchema::deserialize(&schema.serialize()).unwrap();

        assert_eq!(restored.name, "users");
        assert_eq!(restored.columns.len(), 3);
        let expected = [
            ("id", DataType::Integer),
            ("name", DataType::Text),
            ("active", DataType::Boolean),
        ];
        for (column, (name, data_type)) in restored.columns.iter().zip(expected) {
            assert_eq!(column.name, name);
            assert_eq!(column.data_type, data_type);
        }
        assert!(!restored.columns[0].nullable);
        assert!(restored.columns[1].nullable);
        assert_eq!(restored.primary_key_columns, vec![0]);
    }

    #[test]
    fn schema_roundtrip_with_indices() {
        let schema = TableSchema {
            name: "orders".into(),
            columns: vec![
                col("id", DataType::Integer, false, 0),
                col("customer", DataType::Text, false, 1),
                col("amount", DataType::Real, true, 2),
            ],
            primary_key_columns: vec![0],
            indices: vec![
                IndexDef { name: "idx_customer".into(), columns: vec![1], unique: false },
                IndexDef { name: "idx_amount_uniq".into(), columns: vec![2], unique: true },
            ],
        };

        let restored = TableSchema::deserialize(&schema.serialize()).unwrap();

        let expected = [
            ("idx_customer", vec![1u16], false),
            ("idx_amount_uniq", vec![2u16], true),
        ];
        assert_eq!(restored.indices.len(), expected.len());
        for (index, (name, columns, unique)) in restored.indices.iter().zip(expected) {
            assert_eq!(index.name, name);
            assert_eq!(index.columns, columns);
            assert_eq!(index.unique, unique);
        }
    }

    #[test]
    fn schema_v1_backward_compat() {
        let mut data = TableSchema {
            name: "test".into(),
            columns: vec![col("id", DataType::Integer, false, 0)],
            primary_key_columns: vec![0],
            indices: vec![],
        }
        .serialize();
        // Rewrite as a v1 record: version byte 1 and no trailing index count
        // (the final 2 bytes encode an index count of 0).
        data[0] = 1;
        data.truncate(data.len() - 2);

        let restored = TableSchema::deserialize(&data).unwrap();
        assert_eq!(restored.name, "test");
        assert!(restored.indices.is_empty());
    }

    #[test]
    fn data_type_display() {
        let cases = [
            (DataType::Integer, "INTEGER"),
            (DataType::Text, "TEXT"),
            (DataType::Boolean, "BOOLEAN"),
        ];
        for (data_type, expected) in cases {
            assert_eq!(data_type.to_string(), expected);
        }
    }
}