Skip to main content

mdql_core/
query_ast.rs

1//! AST types for the MDQL SQL subset.
2
3// ── AST nodes ──────────────────────────────────────────────────────────────
4
/// Binary arithmetic operators carried by [`Expr::BinaryOp`].
///
/// `Expr::display_name` renders them as `+`, `-`, `*`, `/` and `%` respectively.
#[derive(Clone, Debug, PartialEq)]
pub enum ArithOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`).
    Div,
    /// Remainder (`%`).
    Mod,
}
13
14#[derive(Debug, Clone, PartialEq)]
15pub enum Expr {
16    Literal(SqlValue),
17    Column(String),
18    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
19    UnaryMinus(Box<Expr>),
20    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
21    DateAdd { date: Box<Expr>, days: Box<Expr> },
22    DateDiff { left: Box<Expr>, right: Box<Expr> },
23    CurrentDate,
24    CurrentTimestamp,
25    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
26    Subquery(Box<SelectQuery>),
27}
28
29impl Expr {
30    pub fn as_column(&self) -> Option<&str> {
31        if let Expr::Column(name) = self { Some(name) } else { None }
32    }
33
34    pub fn display_name(&self) -> String {
35        match self {
36            Expr::Literal(SqlValue::Int(n)) => n.to_string(),
37            Expr::Literal(SqlValue::Float(f)) => f.to_string(),
38            Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
39            Expr::Literal(SqlValue::Null) => "NULL".to_string(),
40            Expr::Literal(SqlValue::List(_)) => "list".to_string(),
41            Expr::Column(name) => name.clone(),
42            Expr::BinaryOp { left, op, right } => {
43                let op_str = match op {
44                    ArithOp::Add => "+",
45                    ArithOp::Sub => "-",
46                    ArithOp::Mul => "*",
47                    ArithOp::Div => "/",
48                    ArithOp::Mod => "%",
49                };
50                format!("{} {} {}", left.display_name(), op_str, right.display_name())
51            }
52            Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
53            Expr::Case { .. } => "CASE".to_string(),
54            Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
55            Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
56            Expr::CurrentDate => "CURRENT_DATE".to_string(),
57            Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
58            Expr::Aggregate { func, arg, .. } => {
59                let func_name = match func {
60                    AggFunc::Count => "COUNT",
61                    AggFunc::Sum => "SUM",
62                    AggFunc::Avg => "AVG",
63                    AggFunc::Min => "MIN",
64                    AggFunc::Max => "MAX",
65                };
66                format!("{}({})", func_name, arg)
67            }
68            Expr::Subquery(_) => "(subquery)".to_string(),
69        }
70    }
71
72    pub fn contains_aggregate(&self) -> bool {
73        match self {
74            Expr::Aggregate { .. } => true,
75            Expr::BinaryOp { left, right, .. } => {
76                left.contains_aggregate() || right.contains_aggregate()
77            }
78            Expr::UnaryMinus(inner) => inner.contains_aggregate(),
79            Expr::Case { whens, else_expr } => {
80                whens.iter().any(|(_, e)| e.contains_aggregate())
81                    || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
82            }
83            Expr::Subquery(_) => false,
84            _ => false,
85        }
86    }
87}
88
89#[derive(Debug, Clone, PartialEq)]
90pub struct OrderSpec {
91    pub column: String,
92    pub expr: Option<Expr>,
93    pub descending: bool,
94}
95
/// Comparison operators used in [`Comparison`].
///
/// `where_clause_to_sql` renders `IsNull`/`IsNotNull` as postfix operators with no
/// right-hand value.
#[derive(Clone, Debug, PartialEq)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN (...)`
    In,
    /// `IS NULL`
    IsNull,
    /// `IS NOT NULL`
    IsNotNull,
}
110
/// Logical connective joining two [`WhereClause`]s.
#[derive(Clone, Debug, PartialEq)]
pub enum BoolOpKind {
    /// `AND`
    And,
    /// `OR`
    Or,
}
116
117#[derive(Debug, Clone, PartialEq)]
118pub struct Comparison {
119    pub column: String,
120    pub op: CmpOp,
121    pub value: Option<SqlValue>,
122    pub left_expr: Option<Expr>,
123    pub right_expr: Option<Expr>,
124}
125
126#[derive(Debug, Clone, PartialEq)]
127pub struct BoolOp {
128    pub op: BoolOpKind,
129    pub left: Box<WhereClause>,
130    pub right: Box<WhereClause>,
131}
132
133#[derive(Debug, Clone, PartialEq)]
134pub enum WhereClause {
135    Comparison(Comparison),
136    BoolOp(BoolOp),
137}
138
/// A literal SQL value.
#[derive(Clone, Debug, PartialEq)]
pub enum SqlValue {
    /// A quoted string literal.
    String(String),
    /// An integer literal.
    Int(i64),
    /// A floating-point literal.
    Float(f64),
    /// `NULL`.
    Null,
    /// A list of values, rendered as a parenthesised list (e.g. for `IN (...)`).
    List(Vec<SqlValue>),
}
147
/// Kind of join in a [`JoinClause`].
#[derive(Clone, Debug, PartialEq)]
pub enum JoinType {
    /// `INNER JOIN`
    Inner,
    /// `LEFT JOIN`
    Left,
}
153
154#[derive(Debug, Clone, PartialEq)]
155pub struct JoinClause {
156    pub join_type: JoinType,
157    pub table: String,
158    pub alias: Option<String>,
159    pub condition: WhereClause,
160}
161
/// Supported aggregate functions.
#[derive(Clone, Debug, PartialEq)]
pub enum AggFunc {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
}
170
171#[derive(Debug, Clone, PartialEq)]
172pub enum SelectExpr {
173    Column(String),
174    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
175    Expr { expr: Expr, alias: Option<String> },
176}
177
178impl SelectExpr {
179    pub fn output_name(&self) -> String {
180        match self {
181            SelectExpr::Column(name) => name.clone(),
182            SelectExpr::Aggregate { func, arg, alias, .. } => {
183                if let Some(a) = alias {
184                    a.clone()
185                } else {
186                    let func_name = match func {
187                        AggFunc::Count => "COUNT",
188                        AggFunc::Sum => "SUM",
189                        AggFunc::Avg => "AVG",
190                        AggFunc::Min => "MIN",
191                        AggFunc::Max => "MAX",
192                    };
193                    format!("{}({})", func_name, arg)
194                }
195            }
196            SelectExpr::Expr { expr, alias } => {
197                alias.clone().unwrap_or_else(|| expr.display_name())
198            }
199        }
200    }
201
202    pub fn is_aggregate(&self) -> bool {
203        match self {
204            SelectExpr::Aggregate { .. } => true,
205            SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
206            _ => false,
207        }
208    }
209}
210
211#[derive(Debug, Clone, PartialEq)]
212pub struct CteClause {
213    pub name: String,
214    pub query: Box<SelectQuery>,
215}
216
217#[derive(Debug, Clone, PartialEq)]
218pub struct SelectQuery {
219    pub columns: ColumnList,
220    pub table: String,
221    pub table_alias: Option<String>,
222    pub subquery: Option<Box<SelectQuery>>,
223    pub joins: Vec<JoinClause>,
224    pub where_clause: Option<WhereClause>,
225    pub group_by: Option<Vec<String>>,
226    pub having: Option<WhereClause>,
227    pub order_by: Option<Vec<OrderSpec>>,
228    pub limit: Option<i64>,
229    pub ctes: Vec<CteClause>,
230}
231
232#[derive(Debug, Clone, PartialEq)]
233pub enum ColumnList {
234    All,
235    Named(Vec<SelectExpr>),
236}
237
238#[derive(Debug, Clone, PartialEq)]
239pub struct InsertQuery {
240    pub table: String,
241    pub columns: Vec<String>,
242    pub values: Vec<SqlValue>,
243}
244
245#[derive(Debug, Clone, PartialEq)]
246pub struct UpdateQuery {
247    pub table: String,
248    pub assignments: Vec<(String, SqlValue)>,
249    pub where_clause: Option<WhereClause>,
250}
251
/// Modifier on a `DELETE`; the exact semantics of each mode are defined by the executor,
/// not by this module.
#[derive(Clone, Debug, PartialEq)]
pub enum DeleteMode {
    /// No modifier given.
    Default,
    /// `CASCADE` modifier.
    Cascade,
    /// `RESTRICT` modifier.
    Restrict,
}
258
259#[derive(Debug, Clone, PartialEq)]
260pub struct DeleteQuery {
261    pub table: String,
262    pub where_clause: Option<WhereClause>,
263    pub mode: DeleteMode,
264}
265
/// An `ALTER` that renames a field of a table.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Table being altered.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// Replacement field name.
    pub new_name: String,
}
272
/// An `ALTER` that drops a field from a table.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Table being altered.
    pub table: String,
    /// Field to drop.
    pub field_name: String,
}
278
/// An `ALTER` that merges several fields into one.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Table being altered.
    pub table: String,
    /// Fields being merged.
    pub sources: Vec<String>,
    /// Destination field name.
    pub into: String,
}
285
286#[derive(Debug, Clone, PartialEq)]
287pub struct CreateViewQuery {
288    pub view_name: String,
289    pub columns: Option<Vec<String>>,
290    pub query: Box<SelectQuery>,
291}
292
/// A parsed `DROP VIEW` statement.
#[derive(Clone, Debug, PartialEq)]
pub struct DropViewQuery {
    /// Name of the view to drop.
    pub view_name: String,
}
297
298#[derive(Debug, Clone, PartialEq)]
299pub enum Statement {
300    Select(SelectQuery),
301    Insert(InsertQuery),
302    Update(UpdateQuery),
303    Delete(DeleteQuery),
304    AlterRename(AlterRenameFieldQuery),
305    AlterDrop(AlterDropFieldQuery),
306    AlterMerge(AlterMergeFieldsQuery),
307    CreateView(CreateViewQuery),
308    DropView(DropViewQuery),
309}
310
311impl Statement {
312    pub fn table_name(&self) -> &str {
313        match self {
314            Statement::Select(q) => &q.table,
315            Statement::Insert(q) => &q.table,
316            Statement::Update(q) => &q.table,
317            Statement::Delete(q) => &q.table,
318            Statement::AlterRename(q) => &q.table,
319            Statement::AlterDrop(q) => &q.table,
320            Statement::AlterMerge(q) => &q.table,
321            Statement::CreateView(q) => &q.view_name,
322            Statement::DropView(q) => &q.view_name,
323        }
324    }
325}
326
327pub fn where_clause_to_sql(clause: &WhereClause) -> String {
328    match clause {
329        WhereClause::BoolOp(bop) => {
330            let left = where_clause_to_sql(&bop.left);
331            let right = where_clause_to_sql(&bop.right);
332            let op = match bop.op {
333                BoolOpKind::And => "AND",
334                BoolOpKind::Or => "OR",
335            };
336            format!("{} {} {}", left, op, right)
337        }
338        WhereClause::Comparison(cmp) => {
339            let op_str = match cmp.op {
340                CmpOp::Eq => "=",
341                CmpOp::Ne => "!=",
342                CmpOp::Lt => "<",
343                CmpOp::Gt => ">",
344                CmpOp::Le => "<=",
345                CmpOp::Ge => ">=",
346                CmpOp::Like => "LIKE",
347                CmpOp::NotLike => "NOT LIKE",
348                CmpOp::In => "IN",
349                CmpOp::IsNull => "IS NULL",
350                CmpOp::IsNotNull => "IS NOT NULL",
351            };
352            if matches!(cmp.op, CmpOp::IsNull | CmpOp::IsNotNull) {
353                if let Some(ref expr) = cmp.left_expr {
354                    return format!("{} {}", expr.display_name(), op_str);
355                }
356                return format!("{} {}", cmp.column, op_str);
357            }
358            if let (Some(ref left), Some(ref right)) = (&cmp.left_expr, &cmp.right_expr) {
359                return format!("{} {} {}", left.display_name(), op_str, right.display_name());
360            }
361            match &cmp.value {
362                Some(SqlValue::String(s)) => format!("{} {} '{}'", cmp.column, op_str, s),
363                Some(SqlValue::Int(n)) => format!("{} {} {}", cmp.column, op_str, n),
364                Some(SqlValue::Float(f)) => format!("{} {} {}", cmp.column, op_str, f),
365                Some(SqlValue::Null) => format!("{} {} NULL", cmp.column, op_str),
366                Some(SqlValue::List(items)) => {
367                    let vals: Vec<String> = items.iter().map(|v| match v {
368                        SqlValue::String(s) => format!("'{}'", s),
369                        SqlValue::Int(n) => n.to_string(),
370                        SqlValue::Float(f) => f.to_string(),
371                        _ => "NULL".to_string(),
372                    }).collect();
373                    format!("{} {} ({})", cmp.column, op_str, vals.join(", "))
374                }
375                None => format!("{} {}", cmp.column, op_str),
376            }
377        }
378    }
379}