Skip to main content

qcraft_core/ast/
expr.rs

1use super::common::{FieldRef, OrderByDef};
2use super::conditions::Conditions;
3use super::custom::{CustomBinaryOp, CustomExpr};
4use super::query::QueryStmt;
5use super::value::Value;
6
7/// An expression in a SQL statement.
8#[derive(Debug, Clone)]
9pub enum Expr {
10    /// Literal value.
11    Value(Value),
12
13    /// Column reference.
14    Field(FieldRef),
15
16    /// Binary operation: `left op right`.
17    Binary {
18        left: Box<Expr>,
19        op: BinaryOp,
20        right: Box<Expr>,
21    },
22
23    /// Unary operation: `-expr`, `NOT expr`.
24    Unary { op: UnaryOp, expr: Box<Expr> },
25
26    /// Function call: `name(args...)`.
27    Func { name: String, args: Vec<Expr> },
28
29    /// Aggregate function: `COUNT(expr)`, `SUM(DISTINCT expr) FILTER (WHERE ...)`.
30    Aggregate(AggregationDef),
31
32    /// Type cast: `expr::type` (PG) or `CAST(expr AS type)`.
33    Cast { expr: Box<Expr>, to_type: String },
34
35    /// CASE WHEN ... THEN ... ELSE ... END.
36    Case(CaseDef),
37
38    /// Window function: `expr OVER (PARTITION BY ... ORDER BY ... frame)`.
39    Window(WindowDef),
40
41    /// EXISTS (subquery).
42    Exists(Box<QueryStmt>),
43
44    /// Scalar subquery.
45    SubQuery(Box<QueryStmt>),
46
47    /// ARRAY(subquery).
48    ArraySubQuery(Box<QueryStmt>),
49
50    /// Collation override: `expr COLLATE "name"`.
51    Collate { expr: Box<Expr>, collation: String },
52
53    /// Raw SQL with parameters (escape hatch).
54    Raw { sql: String, params: Vec<Value> },
55
56    /// User-defined expression (extension point).
57    Custom(Box<dyn CustomExpr>),
58}
59
60impl Expr {
61    /// Column reference: `table.field`.
62    pub fn field(table: &str, name: &str) -> Self {
63        Expr::Field(FieldRef::new(table, name))
64    }
65
66    /// Literal value.
67    pub fn value(val: impl Into<Value>) -> Self {
68        Expr::Value(val.into())
69    }
70
71    /// Raw SQL expression (no parameters).
72    pub fn raw(sql: impl Into<String>) -> Self {
73        Expr::Raw {
74            sql: sql.into(),
75            params: vec![],
76        }
77    }
78
79    /// Function call: `name(args...)`.
80    pub fn func(name: impl Into<String>, args: Vec<Expr>) -> Self {
81        Expr::Func {
82            name: name.into(),
83            args,
84        }
85    }
86
87    /// Type cast: `CAST(expr AS to_type)`.
88    pub fn cast(expr: Expr, to_type: impl Into<String>) -> Self {
89        Expr::Cast {
90            expr: Box::new(expr),
91            to_type: to_type.into(),
92        }
93    }
94
95    /// COUNT(expr).
96    pub fn count(expr: Expr) -> Self {
97        Expr::Aggregate(AggregationDef {
98            name: "COUNT".into(),
99            expression: Some(Box::new(expr)),
100            distinct: false,
101            filter: None,
102            args: None,
103            order_by: None,
104        })
105    }
106
107    /// COUNT(*).
108    pub fn count_all() -> Self {
109        Expr::Aggregate(AggregationDef {
110            name: "COUNT".into(),
111            expression: None,
112            distinct: false,
113            filter: None,
114            args: None,
115            order_by: None,
116        })
117    }
118
119    /// SUM(expr).
120    pub fn sum(expr: Expr) -> Self {
121        Expr::Aggregate(AggregationDef {
122            name: "SUM".into(),
123            expression: Some(Box::new(expr)),
124            distinct: false,
125            filter: None,
126            args: None,
127            order_by: None,
128        })
129    }
130
131    /// AVG(expr).
132    pub fn avg(expr: Expr) -> Self {
133        Expr::Aggregate(AggregationDef {
134            name: "AVG".into(),
135            expression: Some(Box::new(expr)),
136            distinct: false,
137            filter: None,
138            args: None,
139            order_by: None,
140        })
141    }
142
143    /// MIN(expr).
144    pub fn min(expr: Expr) -> Self {
145        Expr::Aggregate(AggregationDef {
146            name: "MIN".into(),
147            expression: Some(Box::new(expr)),
148            distinct: false,
149            filter: None,
150            args: None,
151            order_by: None,
152        })
153    }
154
155    /// MAX(expr).
156    pub fn max(expr: Expr) -> Self {
157        Expr::Aggregate(AggregationDef {
158            name: "MAX".into(),
159            expression: Some(Box::new(expr)),
160            distinct: false,
161            filter: None,
162            args: None,
163            order_by: None,
164        })
165    }
166
167    /// EXISTS (subquery).
168    pub fn exists(query: QueryStmt) -> Self {
169        Expr::Exists(Box::new(query))
170    }
171
172    /// Scalar subquery.
173    pub fn subquery(query: QueryStmt) -> Self {
174        Expr::SubQuery(Box::new(query))
175    }
176
177    /// Collation override: `expr COLLATE "name"`.
178    pub fn collate(self, collation: impl Into<String>) -> Self {
179        Expr::Collate {
180            expr: Box::new(self),
181            collation: collation.into(),
182        }
183    }
184}
185
186impl From<Value> for Expr {
187    fn from(v: Value) -> Self {
188        Expr::Value(v)
189    }
190}
191
192impl From<FieldRef> for Expr {
193    fn from(f: FieldRef) -> Self {
194        Expr::Field(f)
195    }
196}
197
198/// Binary operators.
199#[derive(Debug, Clone)]
200pub enum BinaryOp {
201    Add,
202    Sub,
203    Mul,
204    Div,
205    Mod,
206    BitwiseAnd,
207    BitwiseOr,
208    ShiftLeft,
209    ShiftRight,
210    Concat,
211
212    /// User-defined binary operator (extension point).
213    Custom(Box<dyn CustomBinaryOp>),
214}
215
/// Prefix operators carried by [`Expr::Unary`].
///
/// `Hash` is derived alongside `Eq` so operators can be used as keys in
/// hash-based collections (e.g. dialect lookup tables).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOp {
    /// Arithmetic negation: `-expr`.
    Neg,
    /// Logical negation: `NOT expr`.
    Not,
    /// Bitwise complement (commonly `~expr`).
    BitwiseNot,
}
223
224/// Aggregate function definition.
225#[derive(Debug, Clone)]
226pub struct AggregationDef {
227    pub name: String,
228    pub expression: Option<Box<Expr>>,
229    pub distinct: bool,
230    pub filter: Option<Conditions>,
231    pub args: Option<Vec<Expr>>,
232    pub order_by: Option<Vec<OrderByDef>>,
233}
234
235impl AggregationDef {
236    pub fn new(name: impl Into<String>, expr: Expr) -> Self {
237        Self {
238            name: name.into(),
239            expression: Some(Box::new(expr)),
240            distinct: false,
241            filter: None,
242            args: None,
243            order_by: None,
244        }
245    }
246
247    pub fn count_all() -> Self {
248        Self {
249            name: "COUNT".into(),
250            expression: None,
251            distinct: false,
252            filter: None,
253            args: None,
254            order_by: None,
255        }
256    }
257
258    pub fn distinct(mut self) -> Self {
259        self.distinct = true;
260        self
261    }
262
263    pub fn filter(mut self, cond: Conditions) -> Self {
264        self.filter = Some(cond);
265        self
266    }
267
268    pub fn order_by(mut self, order: Vec<OrderByDef>) -> Self {
269        self.order_by = Some(order);
270        self
271    }
272}
273
274/// CASE expression.
275#[derive(Debug, Clone)]
276pub struct CaseDef {
277    pub cases: Vec<WhenClause>,
278    pub default: Option<Box<Expr>>,
279}
280
281/// WHEN condition THEN result.
282#[derive(Debug, Clone)]
283pub struct WhenClause {
284    pub condition: Conditions,
285    pub result: Expr,
286}
287
288/// Window function definition.
289#[derive(Debug, Clone)]
290pub struct WindowDef {
291    pub expression: Box<Expr>,
292    pub partition_by: Option<Vec<Expr>>,
293    pub order_by: Option<Vec<OrderByDef>>,
294    pub frame: Option<WindowFrameDef>,
295}
296
297/// Window frame specification.
298#[derive(Debug, Clone)]
299pub struct WindowFrameDef {
300    pub frame_type: WindowFrameType,
301    pub start: WindowFrameBound,
302    pub end: Option<WindowFrameBound>,
303}
304
/// Unit in which a window frame is measured.
///
/// `Hash` is derived alongside `Eq` so frame types can key hash-based
/// collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowFrameType {
    /// `ROWS`: bounds count physical rows.
    Rows,
    /// `RANGE`: bounds are measured on the ordering value.
    Range,
    /// `GROUPS`: bounds count peer groups.
    Groups,
}
312
/// One end of a window frame.
///
/// The payload is plain data (`Option<u64>`), so the type is `Copy` like
/// [`WindowFrameType`], and `Hash` is derived alongside `Eq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowFrameBound {
    /// `CURRENT ROW`.
    CurrentRow,
    /// `n PRECEDING`; presumably `None` means `UNBOUNDED PRECEDING` — confirm.
    Preceding(Option<u64>),
    /// `n FOLLOWING`; presumably `None` means `UNBOUNDED FOLLOWING` — confirm.
    Following(Option<u64>),
}