Skip to main content

mdql_core/
query_ast.rs

1//! AST types for the MDQL SQL subset.
2
3// ── AST nodes ──────────────────────────────────────────────────────────────
4
/// Binary arithmetic operators used by [`Expr::BinaryOp`].
///
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash` as well: operators can be
/// passed by value and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ArithOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
}
13
14#[derive(Debug, Clone, PartialEq)]
15pub enum Expr {
16    Literal(SqlValue),
17    Column(String),
18    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
19    UnaryMinus(Box<Expr>),
20    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
21    DateAdd { date: Box<Expr>, days: Box<Expr> },
22    DateDiff { left: Box<Expr>, right: Box<Expr> },
23    CurrentDate,
24    CurrentTimestamp,
25    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
26    Subquery(Box<SelectQuery>),
27    Window { func: WindowFunc, args: Vec<Expr>, over: WindowSpec },
28}
29
30impl Expr {
31    pub fn as_column(&self) -> Option<&str> {
32        if let Expr::Column(name) = self { Some(name) } else { None }
33    }
34
35    pub fn display_name(&self) -> String {
36        match self {
37            Expr::Literal(SqlValue::Int(n)) => n.to_string(),
38            Expr::Literal(SqlValue::Float(f)) => f.to_string(),
39            Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
40            Expr::Literal(SqlValue::Null) => "NULL".to_string(),
41            Expr::Literal(SqlValue::List(_)) => "list".to_string(),
42            Expr::Column(name) => name.clone(),
43            Expr::BinaryOp { left, op, right } => {
44                let op_str = match op {
45                    ArithOp::Add => "+",
46                    ArithOp::Sub => "-",
47                    ArithOp::Mul => "*",
48                    ArithOp::Div => "/",
49                    ArithOp::Mod => "%",
50                };
51                format!("{} {} {}", left.display_name(), op_str, right.display_name())
52            }
53            Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
54            Expr::Case { .. } => "CASE".to_string(),
55            Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
56            Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
57            Expr::CurrentDate => "CURRENT_DATE".to_string(),
58            Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
59            Expr::Aggregate { func, arg, .. } => {
60                let func_name = match func {
61                    AggFunc::Count => "COUNT",
62                    AggFunc::Sum => "SUM",
63                    AggFunc::Avg => "AVG",
64                    AggFunc::Min => "MIN",
65                    AggFunc::Max => "MAX",
66                };
67                format!("{}({})", func_name, arg)
68            }
69            Expr::Subquery(_) => "(subquery)".to_string(),
70            Expr::Window { func, args, .. } => {
71                let func_name = match func {
72                    WindowFunc::RowNumber => "ROW_NUMBER",
73                    WindowFunc::Rank => "RANK",
74                    WindowFunc::DenseRank => "DENSE_RANK",
75                    WindowFunc::Lag => "LAG",
76                    WindowFunc::Lead => "LEAD",
77                    WindowFunc::Agg(af) => match af {
78                        AggFunc::Count => "COUNT",
79                        AggFunc::Sum => "SUM",
80                        AggFunc::Avg => "AVG",
81                        AggFunc::Min => "MIN",
82                        AggFunc::Max => "MAX",
83                    },
84                };
85                if args.is_empty() {
86                    format!("{}()", func_name)
87                } else {
88                    let arg_strs: Vec<String> = args.iter().map(|a| a.display_name()).collect();
89                    format!("{}({})", func_name, arg_strs.join(", "))
90                }
91            }
92        }
93    }
94
95    pub fn contains_aggregate(&self) -> bool {
96        match self {
97            Expr::Aggregate { .. } => true,
98            Expr::BinaryOp { left, right, .. } => {
99                left.contains_aggregate() || right.contains_aggregate()
100            }
101            Expr::UnaryMinus(inner) => inner.contains_aggregate(),
102            Expr::Case { whens, else_expr } => {
103                whens.iter().any(|(_, e)| e.contains_aggregate())
104                    || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
105            }
106            Expr::Subquery(_) => false,
107            Expr::Window { .. } => false,
108            _ => false,
109        }
110    }
111
112    pub fn contains_window(&self) -> bool {
113        match self {
114            Expr::Window { .. } => true,
115            Expr::BinaryOp { left, right, .. } => {
116                left.contains_window() || right.contains_window()
117            }
118            Expr::UnaryMinus(inner) => inner.contains_window(),
119            _ => false,
120        }
121    }
122}
123
124#[derive(Debug, Clone, PartialEq)]
125pub struct OrderSpec {
126    pub column: String,
127    pub expr: Option<Expr>,
128    pub descending: bool,
129}
130
/// Comparison operators available in filter clauses.
///
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash` as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN (…)`
    In,
    /// `IS NULL` (takes no right-hand value)
    IsNull,
    /// `IS NOT NULL` (takes no right-hand value)
    IsNotNull,
}
145
/// Boolean connective joining two filter clauses.
///
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash` as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoolOpKind {
    /// `AND`
    And,
    /// `OR`
    Or,
}
151
152#[derive(Debug, Clone, PartialEq)]
153pub struct Comparison {
154    pub column: String,
155    pub op: CmpOp,
156    pub value: Option<SqlValue>,
157    pub left_expr: Option<Expr>,
158    pub right_expr: Option<Expr>,
159}
160
161#[derive(Debug, Clone, PartialEq)]
162pub struct BoolOp {
163    pub op: BoolOpKind,
164    pub left: Box<WhereClause>,
165    pub right: Box<WhereClause>,
166}
167
168#[derive(Debug, Clone, PartialEq)]
169pub enum WhereClause {
170    Comparison(Comparison),
171    BoolOp(BoolOp),
172}
173
/// A literal value appearing in a query.
///
/// `Float` holds an `f64`, so this type cannot derive `Eq` or `Hash`.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// A quoted string literal (stored without its quotes).
    String(String),
    /// An integer literal.
    Int(i64),
    /// A floating-point literal.
    Float(f64),
    /// `NULL`.
    Null,
    /// A parenthesized list of values, e.g. the right-hand side of `IN (…)`.
    List(Vec<SqlValue>),
}
182
/// Kind of `JOIN`.
///
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash` as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// `INNER JOIN`
    Inner,
    /// `LEFT JOIN`
    Left,
}
188
189#[derive(Debug, Clone, PartialEq)]
190pub struct JoinClause {
191    pub join_type: JoinType,
192    pub table: String,
193    pub alias: Option<String>,
194    pub condition: WhereClause,
195}
196
/// Aggregate functions, usable both as plain aggregates and as window functions.
///
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash` as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggFunc {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
}
205
206#[derive(Debug, Clone, PartialEq)]
207pub enum WindowFunc {
208    RowNumber,
209    Rank,
210    DenseRank,
211    Lag,
212    Lead,
213    Agg(AggFunc),
214}
215
216#[derive(Debug, Clone, PartialEq)]
217pub struct WindowSpec {
218    pub partition_by: Vec<String>,
219    pub order_by: Vec<OrderSpec>,
220}
221
222#[derive(Debug, Clone, PartialEq)]
223pub enum SelectExpr {
224    Column(String),
225    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
226    Expr { expr: Expr, alias: Option<String> },
227}
228
229impl SelectExpr {
230    pub fn output_name(&self) -> String {
231        match self {
232            SelectExpr::Column(name) => name.clone(),
233            SelectExpr::Aggregate { func, arg, alias, .. } => {
234                if let Some(a) = alias {
235                    a.clone()
236                } else {
237                    let func_name = match func {
238                        AggFunc::Count => "COUNT",
239                        AggFunc::Sum => "SUM",
240                        AggFunc::Avg => "AVG",
241                        AggFunc::Min => "MIN",
242                        AggFunc::Max => "MAX",
243                    };
244                    format!("{}({})", func_name, arg)
245                }
246            }
247            SelectExpr::Expr { expr, alias } => {
248                alias.clone().unwrap_or_else(|| expr.display_name())
249            }
250        }
251    }
252
253    pub fn is_aggregate(&self) -> bool {
254        match self {
255            SelectExpr::Aggregate { .. } => true,
256            SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
257            _ => false,
258        }
259    }
260}
261
262#[derive(Debug, Clone, PartialEq)]
263pub struct CteClause {
264    pub name: String,
265    pub query: Box<SelectQuery>,
266}
267
268#[derive(Debug, Clone, PartialEq)]
269pub struct SelectQuery {
270    pub columns: ColumnList,
271    pub table: String,
272    pub table_alias: Option<String>,
273    pub subquery: Option<Box<SelectQuery>>,
274    pub joins: Vec<JoinClause>,
275    pub where_clause: Option<WhereClause>,
276    pub group_by: Option<Vec<String>>,
277    pub having: Option<WhereClause>,
278    pub order_by: Option<Vec<OrderSpec>>,
279    pub limit: Option<i64>,
280    pub ctes: Vec<CteClause>,
281}
282
283#[derive(Debug, Clone, PartialEq)]
284pub enum ColumnList {
285    All,
286    Named(Vec<SelectExpr>),
287}
288
289#[derive(Debug, Clone, PartialEq)]
290pub struct InsertQuery {
291    pub table: String,
292    pub columns: Vec<String>,
293    pub values: Vec<SqlValue>,
294}
295
296#[derive(Debug, Clone, PartialEq)]
297pub struct UpdateQuery {
298    pub table: String,
299    pub assignments: Vec<(String, SqlValue)>,
300    pub where_clause: Option<WhereClause>,
301}
302
/// Modifier on a `DELETE` statement.
///
/// NOTE(review): presumably controls how records that reference the deleted ones are
/// treated (`CASCADE` / `RESTRICT`) — confirm against the executor.
/// The enum is fieldless, so it derives `Copy`, `Eq` and `Hash` as well.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeleteMode {
    /// No modifier given.
    Default,
    /// `CASCADE`
    Cascade,
    /// `RESTRICT`
    Restrict,
}
309
310#[derive(Debug, Clone, PartialEq)]
311pub struct DeleteQuery {
312    pub table: String,
313    pub where_clause: Option<WhereClause>,
314    pub mode: DeleteMode,
315}
316
/// An `ALTER` statement renaming a field of a table.
///
/// All fields are `String`s, so the struct derives `Eq` and `Hash` as well.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterRenameFieldQuery {
    /// Target table.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// New field name.
    pub new_name: String,
}
323
/// An `ALTER` statement dropping a field from a table.
///
/// All fields are `String`s, so the struct derives `Eq` and `Hash` as well.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterDropFieldQuery {
    /// Target table.
    pub table: String,
    /// Field to drop.
    pub field_name: String,
}
329
/// An `ALTER` statement merging several fields of a table into one.
///
/// All fields are `String`s (or lists of them), so the struct derives `Eq` and `Hash` as well.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterMergeFieldsQuery {
    /// Target table.
    pub table: String,
    /// Fields being merged.
    pub sources: Vec<String>,
    /// Name of the resulting field.
    pub into: String,
}
336
337#[derive(Debug, Clone, PartialEq)]
338pub struct CreateViewQuery {
339    pub view_name: String,
340    pub columns: Option<Vec<String>>,
341    pub query: Box<SelectQuery>,
342}
343
/// A parsed `DROP VIEW` statement.
///
/// Its only field is a `String`, so the struct derives `Eq` and `Hash` as well.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DropViewQuery {
    /// Name of the view to drop.
    pub view_name: String,
}
348
349#[derive(Debug, Clone, PartialEq)]
350pub enum Statement {
351    Select(SelectQuery),
352    Insert(InsertQuery),
353    Update(UpdateQuery),
354    Delete(DeleteQuery),
355    AlterRename(AlterRenameFieldQuery),
356    AlterDrop(AlterDropFieldQuery),
357    AlterMerge(AlterMergeFieldsQuery),
358    CreateView(CreateViewQuery),
359    DropView(DropViewQuery),
360}
361
362impl Statement {
363    pub fn table_name(&self) -> &str {
364        match self {
365            Statement::Select(q) => &q.table,
366            Statement::Insert(q) => &q.table,
367            Statement::Update(q) => &q.table,
368            Statement::Delete(q) => &q.table,
369            Statement::AlterRename(q) => &q.table,
370            Statement::AlterDrop(q) => &q.table,
371            Statement::AlterMerge(q) => &q.table,
372            Statement::CreateView(q) => &q.view_name,
373            Statement::DropView(q) => &q.view_name,
374        }
375    }
376}
377
378pub fn where_clause_to_sql(clause: &WhereClause) -> String {
379    match clause {
380        WhereClause::BoolOp(bop) => {
381            let left = where_clause_to_sql(&bop.left);
382            let right = where_clause_to_sql(&bop.right);
383            let op = match bop.op {
384                BoolOpKind::And => "AND",
385                BoolOpKind::Or => "OR",
386            };
387            format!("{} {} {}", left, op, right)
388        }
389        WhereClause::Comparison(cmp) => {
390            let op_str = match cmp.op {
391                CmpOp::Eq => "=",
392                CmpOp::Ne => "!=",
393                CmpOp::Lt => "<",
394                CmpOp::Gt => ">",
395                CmpOp::Le => "<=",
396                CmpOp::Ge => ">=",
397                CmpOp::Like => "LIKE",
398                CmpOp::NotLike => "NOT LIKE",
399                CmpOp::In => "IN",
400                CmpOp::IsNull => "IS NULL",
401                CmpOp::IsNotNull => "IS NOT NULL",
402            };
403            if matches!(cmp.op, CmpOp::IsNull | CmpOp::IsNotNull) {
404                if let Some(ref expr) = cmp.left_expr {
405                    return format!("{} {}", expr.display_name(), op_str);
406                }
407                return format!("{} {}", cmp.column, op_str);
408            }
409            if let (Some(ref left), Some(ref right)) = (&cmp.left_expr, &cmp.right_expr) {
410                return format!("{} {} {}", left.display_name(), op_str, right.display_name());
411            }
412            match &cmp.value {
413                Some(SqlValue::String(s)) => format!("{} {} '{}'", cmp.column, op_str, s),
414                Some(SqlValue::Int(n)) => format!("{} {} {}", cmp.column, op_str, n),
415                Some(SqlValue::Float(f)) => format!("{} {} {}", cmp.column, op_str, f),
416                Some(SqlValue::Null) => format!("{} {} NULL", cmp.column, op_str),
417                Some(SqlValue::List(items)) => {
418                    let vals: Vec<String> = items.iter().map(|v| match v {
419                        SqlValue::String(s) => format!("'{}'", s),
420                        SqlValue::Int(n) => n.to_string(),
421                        SqlValue::Float(f) => f.to_string(),
422                        _ => "NULL".to_string(),
423                    }).collect();
424                    format!("{} {} ({})", cmp.column, op_str, vals.join(", "))
425                }
426                None => format!("{} {}", cmp.column, op_str),
427            }
428        }
429    }
430}