Skip to main content

mdql_core/
query_ast.rs

1//! AST types for the MDQL SQL subset.
2
3// ── AST nodes ──────────────────────────────────────────────────────────────
4
/// Binary arithmetic operator carried by [`Expr::BinaryOp`].
#[derive(Debug, Clone, PartialEq)]
pub enum ArithOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`).
    Div,
    /// Remainder (`%`).
    Mod,
}
13
/// A scalar expression, as used in SELECT lists, comparison operands,
/// `CASE` branches, `ORDER BY` keys and window-function arguments.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A literal constant.
    Literal(SqlValue),
    /// A reference to a column by name (see [`Expr::as_column`]).
    Column(String),
    /// An arithmetic operation rendered as `left op right`.
    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
    /// Arithmetic negation of the inner expression.
    UnaryMinus(Box<Expr>),
    /// `CASE WHEN cond THEN expr ... [ELSE expr] END`: each `whens` entry pairs a
    /// condition with its result; `else_expr` is the optional `ELSE` branch.
    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
    /// `DATE_ADD(date, days)`.
    DateAdd { date: Box<Expr>, days: Box<Expr> },
    /// `DATEDIFF(left, right)`.
    DateDiff { left: Box<Expr>, right: Box<Expr> },
    /// `CURRENT_DATE`.
    CurrentDate,
    /// `CURRENT_TIMESTAMP`.
    CurrentTimestamp,
    /// An aggregate call such as `COUNT(arg)`. `arg` is the argument as text
    /// (used for the display name); `arg_expr` presumably holds the parsed
    /// argument when it is more than a bare column — TODO confirm against the parser.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
    /// A nested `SELECT` used as an expression.
    Subquery(Box<SelectQuery>),
    /// A window-function call `func(args) OVER (over)`.
    Window { func: WindowFunc, args: Vec<Expr>, over: WindowSpec },
}
29
30impl Expr {
31    pub fn as_column(&self) -> Option<&str> {
32        if let Expr::Column(name) = self { Some(name) } else { None }
33    }
34
35    pub fn display_name(&self) -> String {
36        match self {
37            Expr::Literal(SqlValue::Int(n)) => n.to_string(),
38            Expr::Literal(SqlValue::Float(f)) => f.to_string(),
39            Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
40            Expr::Literal(SqlValue::Bool(b)) => b.to_string(),
41            Expr::Literal(SqlValue::Null) => "NULL".to_string(),
42            Expr::Literal(SqlValue::List(_)) => "list".to_string(),
43            Expr::Column(name) => name.clone(),
44            Expr::BinaryOp { left, op, right } => {
45                let op_str = match op {
46                    ArithOp::Add => "+",
47                    ArithOp::Sub => "-",
48                    ArithOp::Mul => "*",
49                    ArithOp::Div => "/",
50                    ArithOp::Mod => "%",
51                };
52                format!("{} {} {}", left.display_name(), op_str, right.display_name())
53            }
54            Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
55            Expr::Case { .. } => "CASE".to_string(),
56            Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
57            Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
58            Expr::CurrentDate => "CURRENT_DATE".to_string(),
59            Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
60            Expr::Aggregate { func, arg, .. } => {
61                let func_name = match func {
62                    AggFunc::Count => "COUNT",
63                    AggFunc::Sum => "SUM",
64                    AggFunc::Avg => "AVG",
65                    AggFunc::Min => "MIN",
66                    AggFunc::Max => "MAX",
67                };
68                format!("{}({})", func_name, arg)
69            }
70            Expr::Subquery(_) => "(subquery)".to_string(),
71            Expr::Window { func, args, .. } => {
72                let func_name = match func {
73                    WindowFunc::RowNumber => "ROW_NUMBER",
74                    WindowFunc::Rank => "RANK",
75                    WindowFunc::DenseRank => "DENSE_RANK",
76                    WindowFunc::Lag => "LAG",
77                    WindowFunc::Lead => "LEAD",
78                    WindowFunc::Agg(af) => match af {
79                        AggFunc::Count => "COUNT",
80                        AggFunc::Sum => "SUM",
81                        AggFunc::Avg => "AVG",
82                        AggFunc::Min => "MIN",
83                        AggFunc::Max => "MAX",
84                    },
85                };
86                if args.is_empty() {
87                    format!("{}()", func_name)
88                } else {
89                    let arg_strs: Vec<String> = args.iter().map(|a| a.display_name()).collect();
90                    format!("{}({})", func_name, arg_strs.join(", "))
91                }
92            }
93        }
94    }
95
96    pub fn contains_aggregate(&self) -> bool {
97        match self {
98            Expr::Aggregate { .. } => true,
99            Expr::BinaryOp { left, right, .. } => {
100                left.contains_aggregate() || right.contains_aggregate()
101            }
102            Expr::UnaryMinus(inner) => inner.contains_aggregate(),
103            Expr::Case { whens, else_expr } => {
104                whens.iter().any(|(_, e)| e.contains_aggregate())
105                    || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
106            }
107            Expr::Subquery(_) => false,
108            Expr::Window { .. } => false,
109            _ => false,
110        }
111    }
112
113    pub fn contains_window(&self) -> bool {
114        match self {
115            Expr::Window { .. } => true,
116            Expr::BinaryOp { left, right, .. } => {
117                left.contains_window() || right.contains_window()
118            }
119            Expr::UnaryMinus(inner) => inner.contains_window(),
120            _ => false,
121        }
122    }
123}
124
/// One key of an `ORDER BY` list; used both by [`SelectQuery`] and by a
/// window's [`WindowSpec`].
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Name of the sort key.
    pub column: String,
    /// Optional expression form of the key; presumably set when the key is not
    /// a plain column — TODO confirm how it relates to `column`.
    pub expr: Option<Expr>,
    /// `true` for descending (`DESC`) order.
    pub descending: bool,
}
131
/// Comparison operator of a [`Comparison`] predicate.
#[derive(Debug, Clone, PartialEq)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN`
    In,
    /// `IS NULL` (takes no right-hand side)
    IsNull,
    /// `IS NOT NULL` (takes no right-hand side)
    IsNotNull,
}
146
/// Logical connective joining two conditions in a [`BoolOp`].
#[derive(Debug, Clone, PartialEq)]
pub enum BoolOpKind {
    /// `AND`
    And,
    /// `OR`
    Or,
}
152
/// A single predicate (leaf of a [`WhereClause`]).
///
/// [`where_clause_to_sql`] recognises these shapes: `IS NULL` / `IS NOT NULL`
/// on `left_expr` (if set) or else `column`; an expression compared with an
/// expression when both `left_expr` and `right_expr` are set; otherwise
/// `column` compared with the literal `value`.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Left-hand column name for column-vs-literal comparisons.
    pub column: String,
    /// The comparison operator.
    pub op: CmpOp,
    /// Right-hand literal, if any.
    pub value: Option<SqlValue>,
    /// Left-hand expression for expression comparisons.
    pub left_expr: Option<Expr>,
    /// Right-hand expression for expression comparisons.
    pub right_expr: Option<Expr>,
}
161
/// `left AND right` or `left OR right`.
#[derive(Debug, Clone, PartialEq)]
pub struct BoolOp {
    /// Which connective joins the two sides.
    pub op: BoolOpKind,
    /// Left operand.
    pub left: Box<WhereClause>,
    /// Right operand.
    pub right: Box<WhereClause>,
}
168
/// A boolean condition tree, used for `WHERE`, `HAVING`, join conditions and
/// `CASE WHEN` tests.
#[derive(Debug, Clone, PartialEq)]
pub enum WhereClause {
    /// A single predicate.
    Comparison(Comparison),
    /// Two conditions joined by `AND` / `OR`.
    BoolOp(BoolOp),
}
174
/// A literal value appearing in a query.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// Text literal (rendered single-quoted).
    String(String),
    /// 64-bit signed integer.
    Int(i64),
    /// 64-bit floating-point number.
    Float(f64),
    /// Boolean literal.
    Bool(bool),
    /// SQL `NULL`.
    Null,
    /// A list of values, rendered as a parenthesized comma-separated list
    /// (e.g. the right-hand side of `IN`).
    List(Vec<SqlValue>),
}
184
/// Kind of `JOIN`.
#[derive(Debug, Clone, PartialEq)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left (outer) join.
    Left,
}
190
/// One `JOIN` attached to a [`SelectQuery`].
#[derive(Debug, Clone, PartialEq)]
pub struct JoinClause {
    /// Inner or left join.
    pub join_type: JoinType,
    /// Name of the joined table.
    pub table: String,
    /// Optional alias for the joined table.
    pub alias: Option<String>,
    /// Join condition, expressed as a [`WhereClause`].
    pub condition: WhereClause,
}
198
/// Aggregate function.
#[derive(Debug, Clone, PartialEq)]
pub enum AggFunc {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
}
207
/// Function invoked by an [`Expr::Window`] call.
#[derive(Debug, Clone, PartialEq)]
pub enum WindowFunc {
    /// `ROW_NUMBER`
    RowNumber,
    /// `RANK`
    Rank,
    /// `DENSE_RANK`
    DenseRank,
    /// `LAG`
    Lag,
    /// `LEAD`
    Lead,
    /// An aggregate used in window form, e.g. `SUM(x) OVER (...)`.
    Agg(AggFunc),
}
217
/// The `OVER (...)` clause of a window-function call.
#[derive(Debug, Clone, PartialEq)]
pub struct WindowSpec {
    /// `PARTITION BY` column names.
    pub partition_by: Vec<String>,
    /// `ORDER BY` keys of the window.
    pub order_by: Vec<OrderSpec>,
}
223
/// One item of a `SELECT` list.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectExpr {
    /// A bare column.
    Column(String),
    /// An aggregate call with an optional `AS` alias; `arg` is the argument
    /// text used to build the default output name `FUNC(arg)`.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
    /// Any other expression, with an optional `AS` alias.
    Expr { expr: Expr, alias: Option<String> },
}
230
231impl SelectExpr {
232    pub fn output_name(&self) -> String {
233        match self {
234            SelectExpr::Column(name) => name.clone(),
235            SelectExpr::Aggregate { func, arg, alias, .. } => {
236                if let Some(a) = alias {
237                    a.clone()
238                } else {
239                    let func_name = match func {
240                        AggFunc::Count => "COUNT",
241                        AggFunc::Sum => "SUM",
242                        AggFunc::Avg => "AVG",
243                        AggFunc::Min => "MIN",
244                        AggFunc::Max => "MAX",
245                    };
246                    format!("{}({})", func_name, arg)
247                }
248            }
249            SelectExpr::Expr { expr, alias } => {
250                alias.clone().unwrap_or_else(|| expr.display_name())
251            }
252        }
253    }
254
255    pub fn is_aggregate(&self) -> bool {
256        match self {
257            SelectExpr::Aggregate { .. } => true,
258            SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
259            _ => false,
260        }
261    }
262}
263
/// One common table expression: `WITH name AS (query)`.
#[derive(Debug, Clone, PartialEq)]
pub struct CteClause {
    /// Name the CTE is referenced by.
    pub name: String,
    /// The query defining the CTE.
    pub query: Box<SelectQuery>,
}
269
/// A parsed `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectQuery {
    /// Projected columns (`*` or an explicit list).
    pub columns: ColumnList,
    /// Table named in `FROM`.
    pub table: String,
    /// Optional alias for `table`.
    pub table_alias: Option<String>,
    /// Derived table in `FROM (...)`, if any.
    /// NOTE(review): confirm what `table` holds when this is set.
    pub subquery: Option<Box<SelectQuery>>,
    /// `JOIN` clauses.
    pub joins: Vec<JoinClause>,
    /// Optional `WHERE` condition.
    pub where_clause: Option<WhereClause>,
    /// `GROUP BY` column names, if present.
    pub group_by: Option<Vec<String>>,
    /// Optional `HAVING` condition.
    pub having: Option<WhereClause>,
    /// `ORDER BY` keys, if present.
    pub order_by: Option<Vec<OrderSpec>>,
    /// `LIMIT` row count, if present.
    pub limit: Option<i64>,
    /// `WITH` clauses attached to this query.
    pub ctes: Vec<CteClause>,
}
284
/// The projection of a `SELECT`.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnList {
    /// `SELECT *`
    All,
    /// An explicit select list.
    Named(Vec<SelectExpr>),
}
290
/// A parsed `INSERT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertQuery {
    /// Target table.
    pub table: String,
    /// Target column names.
    pub columns: Vec<String>,
    /// Values to insert; presumably matched positionally to `columns` — TODO confirm.
    pub values: Vec<SqlValue>,
}
297
/// A parsed `UPDATE` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateQuery {
    /// Target table.
    pub table: String,
    /// `SET column = value` pairs.
    pub assignments: Vec<(String, SqlValue)>,
    /// Optional `WHERE` condition.
    pub where_clause: Option<WhereClause>,
}
304
/// Modifier of a `DELETE` statement.
/// NOTE(review): presumably mirrors SQL `CASCADE` / `RESTRICT` handling of
/// dependent records; the exact behaviour lives in the executor — confirm there.
#[derive(Debug, Clone, PartialEq)]
pub enum DeleteMode {
    /// No modifier given.
    Default,
    /// `CASCADE`
    Cascade,
    /// `RESTRICT`
    Restrict,
}
311
/// A parsed `DELETE` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DeleteQuery {
    /// Target table.
    pub table: String,
    /// Optional `WHERE` condition.
    pub where_clause: Option<WhereClause>,
    /// Delete modifier.
    pub mode: DeleteMode,
}
318
/// `ALTER` statement renaming a field of `table` from `old_name` to `new_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Table being altered.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// New field name.
    pub new_name: String,
}
325
/// `ALTER` statement dropping the field `field_name` from `table`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Table being altered.
    pub table: String,
    /// Field to drop.
    pub field_name: String,
}
331
/// `ALTER` statement merging the `sources` fields of `table` into the field `into`.
/// NOTE(review): how source values are combined is defined by the executor — confirm.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Table being altered.
    pub table: String,
    /// Fields to merge.
    pub sources: Vec<String>,
    /// Destination field.
    pub into: String,
}
338
/// `CREATE VIEW view_name [(columns)] AS query`.
#[derive(Debug, Clone, PartialEq)]
pub struct CreateViewQuery {
    /// Name of the view.
    pub view_name: String,
    /// Optional explicit column names.
    pub columns: Option<Vec<String>>,
    /// The defining query.
    pub query: Box<SelectQuery>,
}
345
/// `DROP VIEW view_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct DropViewQuery {
    /// Name of the view to drop.
    pub view_name: String,
}
350
/// A top-level MDQL statement.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    /// `SELECT ...`
    Select(SelectQuery),
    /// `INSERT ...`
    Insert(InsertQuery),
    /// `UPDATE ...`
    Update(UpdateQuery),
    /// `DELETE ...`
    Delete(DeleteQuery),
    /// `ALTER` that renames a field.
    AlterRename(AlterRenameFieldQuery),
    /// `ALTER` that drops a field.
    AlterDrop(AlterDropFieldQuery),
    /// `ALTER` that merges fields.
    AlterMerge(AlterMergeFieldsQuery),
    /// `CREATE VIEW ...`
    CreateView(CreateViewQuery),
    /// `DROP VIEW ...`
    DropView(DropViewQuery),
}
363
364impl Statement {
365    pub fn table_name(&self) -> &str {
366        match self {
367            Statement::Select(q) => &q.table,
368            Statement::Insert(q) => &q.table,
369            Statement::Update(q) => &q.table,
370            Statement::Delete(q) => &q.table,
371            Statement::AlterRename(q) => &q.table,
372            Statement::AlterDrop(q) => &q.table,
373            Statement::AlterMerge(q) => &q.table,
374            Statement::CreateView(q) => &q.view_name,
375            Statement::DropView(q) => &q.view_name,
376        }
377    }
378}
379
380pub fn where_clause_to_sql(clause: &WhereClause) -> String {
381    match clause {
382        WhereClause::BoolOp(bop) => {
383            let left = where_clause_to_sql(&bop.left);
384            let right = where_clause_to_sql(&bop.right);
385            let op = match bop.op {
386                BoolOpKind::And => "AND",
387                BoolOpKind::Or => "OR",
388            };
389            format!("{} {} {}", left, op, right)
390        }
391        WhereClause::Comparison(cmp) => {
392            let op_str = match cmp.op {
393                CmpOp::Eq => "=",
394                CmpOp::Ne => "!=",
395                CmpOp::Lt => "<",
396                CmpOp::Gt => ">",
397                CmpOp::Le => "<=",
398                CmpOp::Ge => ">=",
399                CmpOp::Like => "LIKE",
400                CmpOp::NotLike => "NOT LIKE",
401                CmpOp::In => "IN",
402                CmpOp::IsNull => "IS NULL",
403                CmpOp::IsNotNull => "IS NOT NULL",
404            };
405            if matches!(cmp.op, CmpOp::IsNull | CmpOp::IsNotNull) {
406                if let Some(ref expr) = cmp.left_expr {
407                    return format!("{} {}", expr.display_name(), op_str);
408                }
409                return format!("{} {}", cmp.column, op_str);
410            }
411            if let (Some(ref left), Some(ref right)) = (&cmp.left_expr, &cmp.right_expr) {
412                return format!("{} {} {}", left.display_name(), op_str, right.display_name());
413            }
414            match &cmp.value {
415                Some(SqlValue::String(s)) => format!("{} {} '{}'", cmp.column, op_str, s),
416                Some(SqlValue::Int(n)) => format!("{} {} {}", cmp.column, op_str, n),
417                Some(SqlValue::Float(f)) => format!("{} {} {}", cmp.column, op_str, f),
418                Some(SqlValue::Bool(b)) => format!("{} {} {}", cmp.column, op_str, b),
419                Some(SqlValue::Null) => format!("{} {} NULL", cmp.column, op_str),
420                Some(SqlValue::List(items)) => {
421                    let vals: Vec<String> = items.iter().map(|v| match v {
422                        SqlValue::String(s) => format!("'{}'", s),
423                        SqlValue::Int(n) => n.to_string(),
424                        SqlValue::Float(f) => f.to_string(),
425                        _ => "NULL".to_string(),
426                    }).collect();
427                    format!("{} {} ({})", cmp.column, op_str, vals.join(", "))
428                }
429                None => format!("{} {}", cmp.column, op_str),
430            }
431        }
432    }
433}