Skip to main content

mdql_core/
query_ast.rs

//! Abstract syntax tree (AST) types for the SQL subset understood by MDQL.

// ── AST nodes ──────────────────────────────────────────────────────────────
4
/// Binary arithmetic operator used by [`Expr::BinaryOp`].
///
/// `Expr::display_name` renders the variants as `+`, `-`, `*`, `/` and `%`.
/// The type is a plain fieldless tag, so it is `Copy` and can be used as a
/// hash-map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ArithOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`).
    Div,
    /// Remainder (`%`).
    Mod,
}
13
/// A scalar expression, as used in select items, `ORDER BY` keys,
/// comparison operands and `CASE` branches.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A literal constant.
    Literal(SqlValue),
    /// A reference to a column by name.
    Column(String),
    /// Binary arithmetic: `left op right`.
    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
    /// Arithmetic negation: `-inner`.
    UnaryMinus(Box<Expr>),
    /// `CASE WHEN … THEN … [ELSE …] END`: each `whens` entry pairs a condition
    /// with the result it selects; `else_expr` is the optional `ELSE` branch.
    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
    /// `DATE_ADD(date, days)`.
    DateAdd { date: Box<Expr>, days: Box<Expr> },
    /// `DATEDIFF(left, right)`.
    DateDiff { left: Box<Expr>, right: Box<Expr> },
    /// `CURRENT_DATE`.
    CurrentDate,
    /// `CURRENT_TIMESTAMP`.
    CurrentTimestamp,
    /// An aggregate call such as `SUM(arg)`. `arg` is the argument text shown
    /// by `display_name`; `arg_expr` presumably carries the parsed argument
    /// when it is more than a bare column — confirm against the parser.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
    /// A nested `SELECT` used in expression position.
    Subquery(Box<SelectQuery>),
    /// A window-function call: `func(args) OVER (over)`.
    Window { func: WindowFunc, args: Vec<Expr>, over: WindowSpec },
}
29
30impl Expr {
31    pub fn as_column(&self) -> Option<&str> {
32        if let Expr::Column(name) = self { Some(name) } else { None }
33    }
34
35    pub fn display_name(&self) -> String {
36        match self {
37            Expr::Literal(SqlValue::Int(n)) => n.to_string(),
38            Expr::Literal(SqlValue::Float(f)) => f.to_string(),
39            Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
40            Expr::Literal(SqlValue::Bool(b)) => b.to_string(),
41            Expr::Literal(SqlValue::Null) => "NULL".to_string(),
42            Expr::Literal(SqlValue::List(_)) => "list".to_string(),
43            Expr::Column(name) => name.clone(),
44            Expr::BinaryOp { left, op, right } => {
45                let op_str = match op {
46                    ArithOp::Add => "+",
47                    ArithOp::Sub => "-",
48                    ArithOp::Mul => "*",
49                    ArithOp::Div => "/",
50                    ArithOp::Mod => "%",
51                };
52                format!("{} {} {}", left.display_name(), op_str, right.display_name())
53            }
54            Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
55            Expr::Case { .. } => "CASE".to_string(),
56            Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
57            Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
58            Expr::CurrentDate => "CURRENT_DATE".to_string(),
59            Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
60            Expr::Aggregate { func, arg, .. } => {
61                let func_name = match func {
62                    AggFunc::Count => "COUNT",
63                    AggFunc::Sum => "SUM",
64                    AggFunc::Avg => "AVG",
65                    AggFunc::Min => "MIN",
66                    AggFunc::Max => "MAX",
67                };
68                format!("{}({})", func_name, arg)
69            }
70            Expr::Subquery(_) => "(subquery)".to_string(),
71            Expr::Window { func, args, .. } => {
72                let func_name = match func {
73                    WindowFunc::RowNumber => "ROW_NUMBER",
74                    WindowFunc::Rank => "RANK",
75                    WindowFunc::DenseRank => "DENSE_RANK",
76                    WindowFunc::Lag => "LAG",
77                    WindowFunc::Lead => "LEAD",
78                    WindowFunc::Agg(af) => match af {
79                        AggFunc::Count => "COUNT",
80                        AggFunc::Sum => "SUM",
81                        AggFunc::Avg => "AVG",
82                        AggFunc::Min => "MIN",
83                        AggFunc::Max => "MAX",
84                    },
85                };
86                if args.is_empty() {
87                    format!("{}()", func_name)
88                } else {
89                    let arg_strs: Vec<String> = args.iter().map(|a| a.display_name()).collect();
90                    format!("{}({})", func_name, arg_strs.join(", "))
91                }
92            }
93        }
94    }
95
96    pub fn contains_aggregate(&self) -> bool {
97        match self {
98            Expr::Aggregate { .. } => true,
99            Expr::BinaryOp { left, right, .. } => {
100                left.contains_aggregate() || right.contains_aggregate()
101            }
102            Expr::UnaryMinus(inner) => inner.contains_aggregate(),
103            Expr::Case { whens, else_expr } => {
104                whens.iter().any(|(_, e)| e.contains_aggregate())
105                    || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
106            }
107            Expr::Subquery(_) => false,
108            Expr::Window { .. } => false,
109            _ => false,
110        }
111    }
112
113    pub fn contains_window(&self) -> bool {
114        match self {
115            Expr::Window { .. } => true,
116            Expr::BinaryOp { left, right, .. } => {
117                left.contains_window() || right.contains_window()
118            }
119            Expr::UnaryMinus(inner) => inner.contains_window(),
120            _ => false,
121        }
122    }
123}
124
/// One sort key of an `ORDER BY`: either a query's ordering or the ordering
/// inside a window's `OVER (...)`.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Name of the column to sort by.
    pub column: String,
    /// Sort expression when the key is more than a plain column.
    /// NOTE(review): presumably `column` then holds the expression's label —
    /// confirm against the parser.
    pub expr: Option<Expr>,
    /// `true` for `DESC`; `false` for ascending order.
    pub descending: bool,
}
131
/// Comparison operator of a [`Comparison`] predicate.
///
/// `where_clause_to_sql` renders the variants as
/// `=`, `!=`, `<`, `>`, `<=`, `>=`, `LIKE`, `NOT LIKE`, `IN`, `IS NULL` and
/// `IS NOT NULL`. `IsNull` and `IsNotNull` are unary: they print no right
/// operand. The type is a plain fieldless tag, so it is `Copy` and hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN (...)`
    In,
    /// `IS NULL` (unary)
    IsNull,
    /// `IS NOT NULL` (unary)
    IsNotNull,
}
146
/// Boolean connective joining two conditions in a [`BoolOp`].
///
/// The type is a plain fieldless tag, so it is `Copy` and hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoolOpKind {
    /// `AND`
    And,
    /// `OR`
    Or,
}
152
/// A single predicate: the leaf of a [`WhereClause`] tree.
///
/// Operands come in two shapes, as rendered by `where_clause_to_sql`:
/// - a named `column` compared against a literal `value`;
/// - when both `left_expr` and `right_expr` are set, a comparison between two
///   expressions.
///
/// For `IS NULL` / `IS NOT NULL` only the subject is used: `left_expr` when
/// present, otherwise `column`.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Name of the column on the left-hand side.
    pub column: String,
    /// The comparison operator.
    pub op: CmpOp,
    /// Literal right-hand side (`None` for the unary `IS [NOT] NULL`).
    pub value: Option<SqlValue>,
    /// Left-hand expression, when the left side is not just `column`.
    pub left_expr: Option<Expr>,
    /// Right-hand expression, when the right side is not just `value`.
    pub right_expr: Option<Expr>,
}

/// Two conditions joined by `AND` or `OR`.
#[derive(Debug, Clone, PartialEq)]
pub struct BoolOp {
    /// The connective.
    pub op: BoolOpKind,
    /// Left operand.
    pub left: Box<WhereClause>,
    /// Right operand.
    pub right: Box<WhereClause>,
}

/// A boolean condition tree.
///
/// Leaves are comparisons; inner nodes are `AND`/`OR` connectives. It is used
/// for `WHERE`, `HAVING`, `JOIN … ON` and `CASE WHEN` conditions.
#[derive(Debug, Clone, PartialEq)]
pub enum WhereClause {
    /// A single predicate.
    Comparison(Comparison),
    /// An `AND` / `OR` of two sub-conditions.
    BoolOp(BoolOp),
}
174
/// A literal value.
///
/// `List` holds the elements of a parenthesized value list, such as the
/// right-hand side of `IN (...)`. `Float` wraps an `f64`, which is why this
/// type (and everything containing it) derives only `PartialEq`, not `Eq`.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// A string literal.
    String(String),
    /// A 64-bit signed integer literal.
    Int(i64),
    /// A 64-bit floating-point literal.
    Float(f64),
    /// A boolean literal.
    Bool(bool),
    /// SQL `NULL`.
    Null,
    /// A parenthesized list of values.
    List(Vec<SqlValue>),
}
184
/// Kind of a [`JoinClause`].
///
/// The type is a plain fieldless tag, so it is `Copy` and hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// `[INNER] JOIN`
    Inner,
    /// `LEFT [OUTER] JOIN`
    Left,
}
190
/// One `JOIN` attached to a query's `FROM` source.
#[derive(Debug, Clone, PartialEq)]
pub struct JoinClause {
    /// Inner or left join.
    pub join_type: JoinType,
    /// Name of the joined table.
    pub table: String,
    /// Optional alias for the joined table.
    pub alias: Option<String>,
    /// The `ON` condition.
    pub condition: WhereClause,
}
198
/// Aggregate function: `COUNT`, `SUM`, `AVG`, `MIN` or `MAX`.
///
/// The type is a plain fieldless tag, so it is `Copy` and hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggFunc {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
}

/// Function called in an `OVER (...)` window expression.
///
/// `Agg` wraps an ordinary aggregate used as a window function
/// (e.g. `SUM(x) OVER (...)`). `WindowFunc` can only be `Copy` because
/// [`AggFunc`] is, so the two are defined together.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WindowFunc {
    /// `ROW_NUMBER()`
    RowNumber,
    /// `RANK()`
    Rank,
    /// `DENSE_RANK()`
    DenseRank,
    /// `LAG(...)`
    Lag,
    /// `LEAD(...)`
    Lead,
    /// An aggregate evaluated over a window.
    Agg(AggFunc),
}
217
/// The `OVER (...)` clause of a window function.
#[derive(Debug, Clone, PartialEq)]
pub struct WindowSpec {
    /// `PARTITION BY` column names.
    pub partition_by: Vec<String>,
    /// `ORDER BY` keys of the window.
    pub order_by: Vec<OrderSpec>,
}
223
/// One item of a `SELECT` list.
///
/// Its result-column name is computed by `SelectExpr::output_name`.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectExpr {
    /// A bare column, output under its own name.
    Column(String),
    /// An aggregate such as `COUNT(arg)` with an optional `AS alias`.
    /// NOTE(review): `arg_expr` presumably holds the parsed argument when it is
    /// more than a bare column — confirm against the parser.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
    /// Any other expression, with an optional `AS alias`.
    Expr { expr: Expr, alias: Option<String> },
}
230
231impl SelectExpr {
232    pub fn output_name(&self) -> String {
233        match self {
234            SelectExpr::Column(name) => name.clone(),
235            SelectExpr::Aggregate { func, arg, alias, .. } => {
236                if let Some(a) = alias {
237                    a.clone()
238                } else {
239                    let func_name = match func {
240                        AggFunc::Count => "COUNT",
241                        AggFunc::Sum => "SUM",
242                        AggFunc::Avg => "AVG",
243                        AggFunc::Min => "MIN",
244                        AggFunc::Max => "MAX",
245                    };
246                    format!("{}({})", func_name, arg)
247                }
248            }
249            SelectExpr::Expr { expr, alias } => {
250                alias.clone().unwrap_or_else(|| expr.display_name())
251            }
252        }
253    }
254
255    pub fn is_aggregate(&self) -> bool {
256        match self {
257            SelectExpr::Aggregate { .. } => true,
258            SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
259            _ => false,
260        }
261    }
262}
263
/// One common table expression from a leading `WITH name AS (query)` clause.
#[derive(Debug, Clone, PartialEq)]
pub struct CteClause {
    /// Name the CTE is referenced by.
    pub name: String,
    /// The query defining the CTE.
    pub query: Box<SelectQuery>,
}

/// A parsed `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectQuery {
    /// `SELECT DISTINCT` was requested.
    pub distinct: bool,
    /// The projection: all columns or an explicit list.
    pub columns: ColumnList,
    /// Source table name; `Statement::table_name` returns this for `SELECT`.
    pub table: String,
    /// Optional alias for `table`.
    pub table_alias: Option<String>,
    /// NOTE(review): presumably a derived table (`FROM (SELECT …)`) read
    /// instead of `table` — confirm against the parser/executor.
    pub subquery: Option<Box<SelectQuery>>,
    /// `JOIN` clauses attached to the source.
    pub joins: Vec<JoinClause>,
    /// `WHERE` condition, if any.
    pub where_clause: Option<WhereClause>,
    /// `GROUP BY` column names, if any.
    pub group_by: Option<Vec<String>>,
    /// `HAVING` condition, if any.
    pub having: Option<WhereClause>,
    /// `ORDER BY` keys, if any.
    pub order_by: Option<Vec<OrderSpec>>,
    /// `LIMIT` value, if any.
    pub limit: Option<i64>,
    /// CTEs declared by a leading `WITH` clause.
    pub ctes: Vec<CteClause>,
}

/// The projection of a `SELECT`.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnList {
    /// Every column (presumably `SELECT *`).
    All,
    /// An explicit list of select items.
    Named(Vec<SelectExpr>),
}
291
/// A parsed `INSERT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertQuery {
    /// Target table name.
    pub table: String,
    /// Target column names. NOTE(review): presumably paired positionally with
    /// `values` — confirm in the executor.
    pub columns: Vec<String>,
    /// Literal values to insert.
    pub values: Vec<SqlValue>,
}

/// A parsed `UPDATE` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateQuery {
    /// Target table name.
    pub table: String,
    /// `(column, new value)` pairs from the `SET` clause; only literal values
    /// can be assigned.
    pub assignments: Vec<(String, SqlValue)>,
    /// Row filter; `None` when no `WHERE` clause was given.
    pub where_clause: Option<WhereClause>,
}
305
/// Modifier of a `DELETE` statement.
///
/// NOTE(review): `Cascade`/`Restrict` presumably mirror SQL `CASCADE` /
/// `RESTRICT` handling of dependent records, and `Default` means neither was
/// written — confirm in the executor. The type is a plain fieldless tag, so it
/// is `Copy` and hashable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeleteMode {
    /// No `CASCADE` / `RESTRICT` modifier.
    Default,
    /// `CASCADE` modifier.
    Cascade,
    /// `RESTRICT` modifier.
    Restrict,
}
312
/// A parsed `DELETE` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DeleteQuery {
    /// Target table name.
    pub table: String,
    /// Row filter; `None` when no `WHERE` clause was given.
    pub where_clause: Option<WhereClause>,
    /// `CASCADE` / `RESTRICT` modifier, or `Default` when absent.
    pub mode: DeleteMode,
}

/// `ALTER TABLE … RENAME FIELD`: renames one field of a table.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Table whose field is renamed.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// New field name.
    pub new_name: String,
}

/// `ALTER TABLE … DROP FIELD`: removes one field from a table.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Table whose field is dropped.
    pub table: String,
    /// Name of the field to drop.
    pub field_name: String,
}

/// `ALTER TABLE … MERGE FIELDS`: combines several source fields into one.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Table whose fields are merged.
    pub table: String,
    /// Names of the fields being merged.
    pub sources: Vec<String>,
    /// Name of the destination field.
    pub into: String,
}

/// `CREATE VIEW name [(columns)] AS query`.
#[derive(Debug, Clone, PartialEq)]
pub struct CreateViewQuery {
    /// Name of the view being created.
    pub view_name: String,
    /// Optional explicit column names for the view.
    pub columns: Option<Vec<String>>,
    /// The query defining the view.
    pub query: Box<SelectQuery>,
}

/// `DROP VIEW name`.
#[derive(Debug, Clone, PartialEq)]
pub struct DropViewQuery {
    /// Name of the view to drop.
    pub view_name: String,
}
351
/// Any statement the MDQL parser produces.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    /// `SELECT …`
    Select(SelectQuery),
    /// `INSERT …`
    Insert(InsertQuery),
    /// `UPDATE …`
    Update(UpdateQuery),
    /// `DELETE …`
    Delete(DeleteQuery),
    /// Rename a field of a table.
    AlterRename(AlterRenameFieldQuery),
    /// Drop a field from a table.
    AlterDrop(AlterDropFieldQuery),
    /// Merge several fields of a table into one.
    AlterMerge(AlterMergeFieldsQuery),
    /// `CREATE VIEW …`
    CreateView(CreateViewQuery),
    /// `DROP VIEW …`
    DropView(DropViewQuery),
}
364
365impl Statement {
366    pub fn table_name(&self) -> &str {
367        match self {
368            Statement::Select(q) => &q.table,
369            Statement::Insert(q) => &q.table,
370            Statement::Update(q) => &q.table,
371            Statement::Delete(q) => &q.table,
372            Statement::AlterRename(q) => &q.table,
373            Statement::AlterDrop(q) => &q.table,
374            Statement::AlterMerge(q) => &q.table,
375            Statement::CreateView(q) => &q.view_name,
376            Statement::DropView(q) => &q.view_name,
377        }
378    }
379}
380
381pub fn where_clause_to_sql(clause: &WhereClause) -> String {
382    match clause {
383        WhereClause::BoolOp(bop) => {
384            let left = where_clause_to_sql(&bop.left);
385            let right = where_clause_to_sql(&bop.right);
386            let op = match bop.op {
387                BoolOpKind::And => "AND",
388                BoolOpKind::Or => "OR",
389            };
390            format!("{} {} {}", left, op, right)
391        }
392        WhereClause::Comparison(cmp) => {
393            let op_str = match cmp.op {
394                CmpOp::Eq => "=",
395                CmpOp::Ne => "!=",
396                CmpOp::Lt => "<",
397                CmpOp::Gt => ">",
398                CmpOp::Le => "<=",
399                CmpOp::Ge => ">=",
400                CmpOp::Like => "LIKE",
401                CmpOp::NotLike => "NOT LIKE",
402                CmpOp::In => "IN",
403                CmpOp::IsNull => "IS NULL",
404                CmpOp::IsNotNull => "IS NOT NULL",
405            };
406            if matches!(cmp.op, CmpOp::IsNull | CmpOp::IsNotNull) {
407                if let Some(ref expr) = cmp.left_expr {
408                    return format!("{} {}", expr.display_name(), op_str);
409                }
410                return format!("{} {}", cmp.column, op_str);
411            }
412            if let (Some(ref left), Some(ref right)) = (&cmp.left_expr, &cmp.right_expr) {
413                return format!("{} {} {}", left.display_name(), op_str, right.display_name());
414            }
415            match &cmp.value {
416                Some(SqlValue::String(s)) => format!("{} {} '{}'", cmp.column, op_str, s),
417                Some(SqlValue::Int(n)) => format!("{} {} {}", cmp.column, op_str, n),
418                Some(SqlValue::Float(f)) => format!("{} {} {}", cmp.column, op_str, f),
419                Some(SqlValue::Bool(b)) => format!("{} {} {}", cmp.column, op_str, b),
420                Some(SqlValue::Null) => format!("{} {} NULL", cmp.column, op_str),
421                Some(SqlValue::List(items)) => {
422                    let vals: Vec<String> = items.iter().map(|v| match v {
423                        SqlValue::String(s) => format!("'{}'", s),
424                        SqlValue::Int(n) => n.to_string(),
425                        SqlValue::Float(f) => f.to_string(),
426                        _ => "NULL".to_string(),
427                    }).collect();
428                    format!("{} {} ({})", cmp.column, op_str, vals.join(", "))
429                }
430                None => format!("{} {}", cmp.column, op_str),
431            }
432        }
433    }
434}