// File: cynos_query/ast/expr.rs

//! Expression AST definitions.

use alloc::boxed::Box;
use alloc::string::String;
use alloc::vec::Vec;
use cynos_core::Value;

/// Reference to a column in a table.
///
/// Two references compare equal (and hash identically) only when the table
/// name, the column name and the schema index all match.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table name (or alias).
    pub table: String,
    /// Column name.
    pub column: String,
    /// Column index in the table schema.
    pub index: usize,
}

impl ColumnRef {
    /// Creates a new column reference.
    ///
    /// `table` and `column` accept anything convertible into a `String`
    /// (for example `&str` or an owned `String`); `index` is stored as given.
    pub fn new(table: impl Into<String>, column: impl Into<String>, index: usize) -> Self {
        Self {
            table: table.into(),
            column: column.into(),
            index,
        }
    }

    /// Returns the normalized name, formatted as `table.column`.
    ///
    /// The schema index is not part of the normalized name.
    pub fn normalized_name(&self) -> String {
        alloc::format!("{}.{}", self.table, self.column)
    }
}

/// Binary operators.
///
/// The enum is a plain, field-less tag: it is `Copy`, comparable and
/// hashable, so it can be used directly as a map or set key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BinaryOp {
    // Comparison
    /// Equality.
    Eq,
    /// Not-equal.
    Ne,
    /// Less-than.
    Lt,
    /// Less-than-or-equal.
    Le,
    /// Greater-than.
    Gt,
    /// Greater-than-or-equal.
    Ge,
    // Logical
    /// Logical AND.
    And,
    /// Logical OR.
    Or,
    // Arithmetic
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Modulo.
    Mod,
    // String/Pattern
    /// LIKE pattern operator.
    Like,
    // Set
    /// IN operator.
    In,
    /// BETWEEN operator.
    Between,
}

/// Unary operators.
///
/// A field-less, `Copy` tag that is comparable and hashable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Logical NOT.
    Not,
    /// Negation (presumably arithmetic; no builder in this file uses it).
    Neg,
    /// IS NULL test.
    IsNull,
    /// IS NOT NULL test.
    IsNotNull,
}

/// Aggregate functions.
///
/// A field-less, `Copy` tag that is comparable and hashable.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum AggregateFunc {
    /// COUNT.
    Count,
    /// SUM.
    Sum,
    /// AVG.
    Avg,
    /// MIN.
    Min,
    /// MAX.
    Max,
    /// DISTINCT (not used by any builder in this file; semantics are defined
    /// by the consumer of the AST).
    Distinct,
    /// Presumably standard deviation; confirm against the evaluator.
    StdDev,
    /// Presumably geometric mean; confirm against the evaluator.
    GeoMean,
}

/// Sort order.
///
/// Defaults to [`SortOrder::Asc`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
pub enum SortOrder {
    /// Ascending order (the default).
    #[default]
    Asc,
    /// Descending order.
    Desc,
}

91/// Expression AST node.
92#[derive(Clone, Debug)]
93pub enum Expr {
94    /// Column reference.
95    Column(ColumnRef),
96    /// Literal value.
97    Literal(Value),
98    /// Binary operation.
99    BinaryOp {
100        left: Box<Expr>,
101        op: BinaryOp,
102        right: Box<Expr>,
103    },
104    /// Unary operation.
105    UnaryOp { op: UnaryOp, expr: Box<Expr> },
106    /// Function call.
107    Function { name: String, args: Vec<Expr> },
108    /// Aggregate function.
109    Aggregate {
110        func: AggregateFunc,
111        expr: Option<Box<Expr>>,
112        distinct: bool,
113    },
114    /// BETWEEN expression.
115    Between {
116        expr: Box<Expr>,
117        low: Box<Expr>,
118        high: Box<Expr>,
119    },
120    /// NOT BETWEEN expression.
121    NotBetween {
122        expr: Box<Expr>,
123        low: Box<Expr>,
124        high: Box<Expr>,
125    },
126    /// IN expression.
127    In {
128        expr: Box<Expr>,
129        list: Vec<Expr>,
130    },
131    /// NOT IN expression.
132    NotIn {
133        expr: Box<Expr>,
134        list: Vec<Expr>,
135    },
136    /// LIKE expression.
137    Like {
138        expr: Box<Expr>,
139        pattern: String,
140    },
141    /// NOT LIKE expression.
142    NotLike {
143        expr: Box<Expr>,
144        pattern: String,
145    },
146    /// MATCH (regex) expression.
147    Match {
148        expr: Box<Expr>,
149        pattern: String,
150    },
151    /// NOT MATCH (regex) expression.
152    NotMatch {
153        expr: Box<Expr>,
154        pattern: String,
155    },
156}
157
158impl Expr {
159    /// Creates a column reference expression.
160    pub fn column(table: impl Into<String>, column: impl Into<String>, index: usize) -> Self {
161        Expr::Column(ColumnRef::new(table, column, index))
162    }
163
164    /// Creates a literal expression.
165    pub fn literal(value: impl Into<Value>) -> Self {
166        Expr::Literal(value.into())
167    }
168
169    /// Creates an equality expression.
170    pub fn eq(left: Expr, right: Expr) -> Self {
171        Expr::BinaryOp {
172            left: Box::new(left),
173            op: BinaryOp::Eq,
174            right: Box::new(right),
175        }
176    }
177
178    /// Creates a not-equal expression.
179    pub fn ne(left: Expr, right: Expr) -> Self {
180        Expr::BinaryOp {
181            left: Box::new(left),
182            op: BinaryOp::Ne,
183            right: Box::new(right),
184        }
185    }
186
187    /// Creates a less-than expression.
188    pub fn lt(left: Expr, right: Expr) -> Self {
189        Expr::BinaryOp {
190            left: Box::new(left),
191            op: BinaryOp::Lt,
192            right: Box::new(right),
193        }
194    }
195
196    /// Creates a less-than-or-equal expression.
197    pub fn le(left: Expr, right: Expr) -> Self {
198        Expr::BinaryOp {
199            left: Box::new(left),
200            op: BinaryOp::Le,
201            right: Box::new(right),
202        }
203    }
204
205    /// Creates a greater-than expression.
206    pub fn gt(left: Expr, right: Expr) -> Self {
207        Expr::BinaryOp {
208            left: Box::new(left),
209            op: BinaryOp::Gt,
210            right: Box::new(right),
211        }
212    }
213
214    /// Creates a greater-than-or-equal expression.
215    pub fn ge(left: Expr, right: Expr) -> Self {
216        Expr::BinaryOp {
217            left: Box::new(left),
218            op: BinaryOp::Ge,
219            right: Box::new(right),
220        }
221    }
222
223    /// Creates an AND expression.
224    pub fn and(left: Expr, right: Expr) -> Self {
225        Expr::BinaryOp {
226            left: Box::new(left),
227            op: BinaryOp::And,
228            right: Box::new(right),
229        }
230    }
231
232    /// Creates an OR expression.
233    pub fn or(left: Expr, right: Expr) -> Self {
234        Expr::BinaryOp {
235            left: Box::new(left),
236            op: BinaryOp::Or,
237            right: Box::new(right),
238        }
239    }
240
241    /// Creates a NOT expression.
242    pub fn not(expr: Expr) -> Self {
243        Expr::UnaryOp {
244            op: UnaryOp::Not,
245            expr: Box::new(expr),
246        }
247    }
248
249    /// Creates an IS NULL expression.
250    pub fn is_null(expr: Expr) -> Self {
251        Expr::UnaryOp {
252            op: UnaryOp::IsNull,
253            expr: Box::new(expr),
254        }
255    }
256
257    /// Creates an IS NOT NULL expression.
258    pub fn is_not_null(expr: Expr) -> Self {
259        Expr::UnaryOp {
260            op: UnaryOp::IsNotNull,
261            expr: Box::new(expr),
262        }
263    }
264
265    /// Creates a COUNT(*) aggregate.
266    pub fn count_star() -> Self {
267        Expr::Aggregate {
268            func: AggregateFunc::Count,
269            expr: None,
270            distinct: false,
271        }
272    }
273
274    /// Creates a COUNT(expr) aggregate.
275    pub fn count(expr: Expr) -> Self {
276        Expr::Aggregate {
277            func: AggregateFunc::Count,
278            expr: Some(Box::new(expr)),
279            distinct: false,
280        }
281    }
282
283    /// Creates a SUM aggregate.
284    pub fn sum(expr: Expr) -> Self {
285        Expr::Aggregate {
286            func: AggregateFunc::Sum,
287            expr: Some(Box::new(expr)),
288            distinct: false,
289        }
290    }
291
292    /// Creates an AVG aggregate.
293    pub fn avg(expr: Expr) -> Self {
294        Expr::Aggregate {
295            func: AggregateFunc::Avg,
296            expr: Some(Box::new(expr)),
297            distinct: false,
298        }
299    }
300
301    /// Creates a MIN aggregate.
302    pub fn min(expr: Expr) -> Self {
303        Expr::Aggregate {
304            func: AggregateFunc::Min,
305            expr: Some(Box::new(expr)),
306            distinct: false,
307        }
308    }
309
310    /// Creates a MAX aggregate.
311    pub fn max(expr: Expr) -> Self {
312        Expr::Aggregate {
313            func: AggregateFunc::Max,
314            expr: Some(Box::new(expr)),
315            distinct: false,
316        }
317    }
318
319    /// Creates a greater-than-or-equal expression.
320    pub fn gte(left: Expr, right: Expr) -> Self {
321        Expr::BinaryOp {
322            left: Box::new(left),
323            op: BinaryOp::Ge,
324            right: Box::new(right),
325        }
326    }
327
328    /// Creates a less-than-or-equal expression.
329    pub fn lte(left: Expr, right: Expr) -> Self {
330        Expr::BinaryOp {
331            left: Box::new(left),
332            op: BinaryOp::Le,
333            right: Box::new(right),
334        }
335    }
336
337    /// Creates a BETWEEN expression.
338    pub fn between(expr: Expr, low: Expr, high: Expr) -> Self {
339        Expr::Between {
340            expr: Box::new(expr),
341            low: Box::new(low),
342            high: Box::new(high),
343        }
344    }
345
346    /// Creates a NOT BETWEEN expression.
347    pub fn not_between(expr: Expr, low: Expr, high: Expr) -> Self {
348        Expr::NotBetween {
349            expr: Box::new(expr),
350            low: Box::new(low),
351            high: Box::new(high),
352        }
353    }
354
355    /// Creates an IN expression.
356    pub fn in_list(expr: Expr, values: Vec<Value>) -> Self {
357        Expr::In {
358            expr: Box::new(expr),
359            list: values.into_iter().map(Expr::Literal).collect(),
360        }
361    }
362
363    /// Creates a NOT IN expression.
364    pub fn not_in_list(expr: Expr, values: Vec<Value>) -> Self {
365        Expr::NotIn {
366            expr: Box::new(expr),
367            list: values.into_iter().map(Expr::Literal).collect(),
368        }
369    }
370
371    /// Creates a LIKE expression.
372    pub fn like(expr: Expr, pattern: &str) -> Self {
373        Expr::Like {
374            expr: Box::new(expr),
375            pattern: pattern.into(),
376        }
377    }
378
379    /// Creates a NOT LIKE expression.
380    pub fn not_like(expr: Expr, pattern: &str) -> Self {
381        Expr::NotLike {
382            expr: Box::new(expr),
383            pattern: pattern.into(),
384        }
385    }
386
387    /// Creates a MATCH (regex) expression.
388    pub fn regex_match(expr: Expr, pattern: &str) -> Self {
389        Expr::Match {
390            expr: Box::new(expr),
391            pattern: pattern.into(),
392        }
393    }
394
395    /// Creates a NOT MATCH (regex) expression.
396    pub fn not_regex_match(expr: Expr, pattern: &str) -> Self {
397        Expr::NotMatch {
398            expr: Box::new(expr),
399            pattern: pattern.into(),
400        }
401    }
402
403    /// Creates a JSONB path equality expression.
404    pub fn jsonb_path_eq(expr: Expr, path: &str, value: Value) -> Self {
405        // Simplified: treat as function call
406        Expr::Function {
407            name: "jsonb_path_eq".into(),
408            args: alloc::vec![expr, Expr::literal(path), Expr::Literal(value)],
409        }
410    }
411
412    /// Creates a JSONB contains expression.
413    pub fn jsonb_contains(expr: Expr, path: &str) -> Self {
414        Expr::Function {
415            name: "jsonb_contains".into(),
416            args: alloc::vec![expr, Expr::literal(path)],
417        }
418    }
419
420    /// Creates a JSONB exists expression.
421    pub fn jsonb_exists(expr: Expr, path: &str) -> Self {
422        Expr::Function {
423            name: "jsonb_exists".into(),
424            args: alloc::vec![expr, Expr::literal(path)],
425        }
426    }
427
428    /// Checks if this is an equi-join condition (column = column).
429    pub fn is_equi_join(&self) -> bool {
430        matches!(
431            self,
432            Expr::BinaryOp {
433                op: BinaryOp::Eq,
434                left,
435                right
436            } if matches!(left.as_ref(), Expr::Column(_)) && matches!(right.as_ref(), Expr::Column(_))
437        )
438    }
439
440    /// Checks if this is a range join condition (>, <, >=, <=).
441    pub fn is_range_join(&self) -> bool {
442        matches!(
443            self,
444            Expr::BinaryOp {
445                op: BinaryOp::Lt | BinaryOp::Le | BinaryOp::Gt | BinaryOp::Ge,
446                left,
447                right
448            } if matches!(left.as_ref(), Expr::Column(_)) && matches!(right.as_ref(), Expr::Column(_))
449        )
450    }
451}
452
#[cfg(test)]
mod tests {
    use super::*;

    /// `ColumnRef::new` stores its arguments and formats `table.column`.
    #[test]
    fn test_column_ref() {
        let col = ColumnRef::new("users", "id", 0);
        assert_eq!(col.table, "users");
        assert_eq!(col.column, "id");
        assert_eq!(col.index, 0);
        assert_eq!(col.normalized_name(), "users.id");
    }

    /// The basic builders produce the expected `Expr` variants.
    #[test]
    fn test_expr_builders() {
        let col = Expr::column("t", "c", 0);
        assert!(matches!(col, Expr::Column(_)));

        let lit = Expr::literal(42i64);
        assert!(matches!(lit, Expr::Literal(Value::Int64(42))));

        let eq = Expr::eq(Expr::column("t", "a", 0), Expr::column("t", "b", 1));
        assert!(matches!(eq, Expr::BinaryOp { op: BinaryOp::Eq, .. }));
    }

    /// Equi-join detection requires `=` between two column references.
    #[test]
    fn test_is_equi_join() {
        let equi = Expr::eq(Expr::column("a", "id", 0), Expr::column("b", "id", 0));
        assert!(equi.is_equi_join());

        let non_equi = Expr::eq(Expr::column("a", "id", 0), Expr::literal(1i64));
        assert!(!non_equi.is_equi_join());

        let range = Expr::gt(Expr::column("a", "id", 0), Expr::column("b", "id", 0));
        assert!(!range.is_equi_join());
        assert!(range.is_range_join());
    }

    /// Range-join detection rejects comparisons against a literal and
    /// plain equality between columns.
    #[test]
    fn test_is_range_join_negative_cases() {
        let against_literal = Expr::lt(Expr::column("a", "id", 0), Expr::literal(1i64));
        assert!(!against_literal.is_range_join());

        let equality = Expr::eq(Expr::column("a", "id", 0), Expr::column("b", "id", 0));
        assert!(!equality.is_range_join());
    }
}