Skip to main content

mdql_core/
query_ast.rs

//! AST types for the MDQL SQL subset.

// ── AST nodes ──────────────────────────────────────────────────────────────

/// Binary arithmetic operator carried by `Expr::BinaryOp`.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash`; that lets it
/// be used directly as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArithOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
}

14#[derive(Debug, Clone, PartialEq)]
15pub enum Expr {
16    Literal(SqlValue),
17    Column(String),
18    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
19    UnaryMinus(Box<Expr>),
20    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
21    DateAdd { date: Box<Expr>, days: Box<Expr> },
22    DateDiff { left: Box<Expr>, right: Box<Expr> },
23    CurrentDate,
24    CurrentTimestamp,
25    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
26}
27
28impl Expr {
29    pub fn as_column(&self) -> Option<&str> {
30        if let Expr::Column(name) = self { Some(name) } else { None }
31    }
32
33    pub fn display_name(&self) -> String {
34        match self {
35            Expr::Literal(SqlValue::Int(n)) => n.to_string(),
36            Expr::Literal(SqlValue::Float(f)) => f.to_string(),
37            Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
38            Expr::Literal(SqlValue::Null) => "NULL".to_string(),
39            Expr::Literal(SqlValue::List(_)) => "list".to_string(),
40            Expr::Column(name) => name.clone(),
41            Expr::BinaryOp { left, op, right } => {
42                let op_str = match op {
43                    ArithOp::Add => "+",
44                    ArithOp::Sub => "-",
45                    ArithOp::Mul => "*",
46                    ArithOp::Div => "/",
47                    ArithOp::Mod => "%",
48                };
49                format!("{} {} {}", left.display_name(), op_str, right.display_name())
50            }
51            Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
52            Expr::Case { .. } => "CASE".to_string(),
53            Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
54            Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
55            Expr::CurrentDate => "CURRENT_DATE".to_string(),
56            Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
57            Expr::Aggregate { func, arg, .. } => {
58                let func_name = match func {
59                    AggFunc::Count => "COUNT",
60                    AggFunc::Sum => "SUM",
61                    AggFunc::Avg => "AVG",
62                    AggFunc::Min => "MIN",
63                    AggFunc::Max => "MAX",
64                };
65                format!("{}({})", func_name, arg)
66            }
67        }
68    }
69
70    pub fn contains_aggregate(&self) -> bool {
71        match self {
72            Expr::Aggregate { .. } => true,
73            Expr::BinaryOp { left, right, .. } => {
74                left.contains_aggregate() || right.contains_aggregate()
75            }
76            Expr::UnaryMinus(inner) => inner.contains_aggregate(),
77            Expr::Case { whens, else_expr } => {
78                whens.iter().any(|(_, e)| e.contains_aggregate())
79                    || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
80            }
81            _ => false,
82        }
83    }
84}
85
86#[derive(Debug, Clone, PartialEq)]
87pub struct OrderSpec {
88    pub column: String,
89    pub expr: Option<Expr>,
90    pub descending: bool,
91}
92
/// Comparison operator in a `WHERE`/`HAVING`/`ON`/`CASE WHEN` condition.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash` and can be used
/// as a map or set key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN (...)`
    In,
    /// `IS NULL`
    IsNull,
    /// `IS NOT NULL`
    IsNotNull,
}

/// Logical connective joining two conditions.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BoolOpKind {
    /// `AND`
    And,
    /// `OR`
    Or,
}

114#[derive(Debug, Clone, PartialEq)]
115pub struct Comparison {
116    pub column: String,
117    pub op: CmpOp,
118    pub value: Option<SqlValue>,
119    pub left_expr: Option<Expr>,
120    pub right_expr: Option<Expr>,
121}
122
123#[derive(Debug, Clone, PartialEq)]
124pub struct BoolOp {
125    pub op: BoolOpKind,
126    pub left: Box<WhereClause>,
127    pub right: Box<WhereClause>,
128}
129
130#[derive(Debug, Clone, PartialEq)]
131pub enum WhereClause {
132    Comparison(Comparison),
133    BoolOp(BoolOp),
134}
135
/// A literal value appearing in a statement.
///
/// The enum cannot derive `Eq` because `Float` holds an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// A quoted string literal.
    String(String),
    /// An integer literal.
    Int(i64),
    /// A floating-point literal.
    Float(f64),
    /// `NULL`.
    Null,
    /// A parenthesized list of values, e.g. the right side of `IN (...)`.
    List(Vec<SqlValue>),
}

/// Kind of `JOIN`.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Inner join.
    Inner,
    /// Left (outer) join.
    Left,
}

151#[derive(Debug, Clone, PartialEq)]
152pub struct JoinClause {
153    pub join_type: JoinType,
154    pub table: String,
155    pub alias: Option<String>,
156    pub condition: WhereClause,
157}
158
/// Aggregate function name.
///
/// The enum is fieldless, so it also derives `Eq` and `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggFunc {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
}

168#[derive(Debug, Clone, PartialEq)]
169pub enum SelectExpr {
170    Column(String),
171    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
172    Expr { expr: Expr, alias: Option<String> },
173}
174
175impl SelectExpr {
176    pub fn output_name(&self) -> String {
177        match self {
178            SelectExpr::Column(name) => name.clone(),
179            SelectExpr::Aggregate { func, arg, alias, .. } => {
180                if let Some(a) = alias {
181                    a.clone()
182                } else {
183                    let func_name = match func {
184                        AggFunc::Count => "COUNT",
185                        AggFunc::Sum => "SUM",
186                        AggFunc::Avg => "AVG",
187                        AggFunc::Min => "MIN",
188                        AggFunc::Max => "MAX",
189                    };
190                    format!("{}({})", func_name, arg)
191                }
192            }
193            SelectExpr::Expr { expr, alias } => {
194                alias.clone().unwrap_or_else(|| expr.display_name())
195            }
196        }
197    }
198
199    pub fn is_aggregate(&self) -> bool {
200        match self {
201            SelectExpr::Aggregate { .. } => true,
202            SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
203            _ => false,
204        }
205    }
206}
207
208#[derive(Debug, Clone, PartialEq)]
209pub struct CteClause {
210    pub name: String,
211    pub query: Box<SelectQuery>,
212}
213
214#[derive(Debug, Clone, PartialEq)]
215pub struct SelectQuery {
216    pub columns: ColumnList,
217    pub table: String,
218    pub table_alias: Option<String>,
219    pub subquery: Option<Box<SelectQuery>>,
220    pub joins: Vec<JoinClause>,
221    pub where_clause: Option<WhereClause>,
222    pub group_by: Option<Vec<String>>,
223    pub having: Option<WhereClause>,
224    pub order_by: Option<Vec<OrderSpec>>,
225    pub limit: Option<i64>,
226    pub ctes: Vec<CteClause>,
227}
228
229#[derive(Debug, Clone, PartialEq)]
230pub enum ColumnList {
231    All,
232    Named(Vec<SelectExpr>),
233}
234
235#[derive(Debug, Clone, PartialEq)]
236pub struct InsertQuery {
237    pub table: String,
238    pub columns: Vec<String>,
239    pub values: Vec<SqlValue>,
240}
241
242#[derive(Debug, Clone, PartialEq)]
243pub struct UpdateQuery {
244    pub table: String,
245    pub assignments: Vec<(String, SqlValue)>,
246    pub where_clause: Option<WhereClause>,
247}
248
/// Mode requested for a `DELETE`.
///
/// The precise `Cascade`/`Restrict` semantics are implemented by the executor,
/// not here. The enum is fieldless, so it also derives `Eq` and `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeleteMode {
    /// No explicit mode given.
    Default,
    /// `CASCADE` requested.
    Cascade,
    /// `RESTRICT` requested.
    Restrict,
}

256#[derive(Debug, Clone, PartialEq)]
257pub struct DeleteQuery {
258    pub table: String,
259    pub where_clause: Option<WhereClause>,
260    pub mode: DeleteMode,
261}
262
/// Schema change: rename a field of a table.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Table being altered.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// Field name after the rename.
    pub new_name: String,
}

/// Schema change: drop a field from a table.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Table being altered.
    pub table: String,
    /// Field to remove.
    pub field_name: String,
}

/// Schema change: merge several source fields into one target field.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Table being altered.
    pub table: String,
    /// Fields to merge.
    pub sources: Vec<String>,
    /// Destination field name.
    pub into: String,
}

283#[derive(Debug, Clone, PartialEq)]
284pub struct CreateViewQuery {
285    pub view_name: String,
286    pub columns: Option<Vec<String>>,
287    pub query: Box<SelectQuery>,
288}
289
/// A parsed `DROP VIEW` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DropViewQuery {
    /// Name of the view to drop.
    pub view_name: String,
}

295#[derive(Debug, Clone, PartialEq)]
296pub enum Statement {
297    Select(SelectQuery),
298    Insert(InsertQuery),
299    Update(UpdateQuery),
300    Delete(DeleteQuery),
301    AlterRename(AlterRenameFieldQuery),
302    AlterDrop(AlterDropFieldQuery),
303    AlterMerge(AlterMergeFieldsQuery),
304    CreateView(CreateViewQuery),
305    DropView(DropViewQuery),
306}
307
308impl Statement {
309    pub fn table_name(&self) -> &str {
310        match self {
311            Statement::Select(q) => &q.table,
312            Statement::Insert(q) => &q.table,
313            Statement::Update(q) => &q.table,
314            Statement::Delete(q) => &q.table,
315            Statement::AlterRename(q) => &q.table,
316            Statement::AlterDrop(q) => &q.table,
317            Statement::AlterMerge(q) => &q.table,
318            Statement::CreateView(q) => &q.view_name,
319            Statement::DropView(q) => &q.view_name,
320        }
321    }
322}
323
324pub fn where_clause_to_sql(clause: &WhereClause) -> String {
325    match clause {
326        WhereClause::BoolOp(bop) => {
327            let left = where_clause_to_sql(&bop.left);
328            let right = where_clause_to_sql(&bop.right);
329            let op = match bop.op {
330                BoolOpKind::And => "AND",
331                BoolOpKind::Or => "OR",
332            };
333            format!("{} {} {}", left, op, right)
334        }
335        WhereClause::Comparison(cmp) => {
336            let op_str = match cmp.op {
337                CmpOp::Eq => "=",
338                CmpOp::Ne => "!=",
339                CmpOp::Lt => "<",
340                CmpOp::Gt => ">",
341                CmpOp::Le => "<=",
342                CmpOp::Ge => ">=",
343                CmpOp::Like => "LIKE",
344                CmpOp::NotLike => "NOT LIKE",
345                CmpOp::In => "IN",
346                CmpOp::IsNull => "IS NULL",
347                CmpOp::IsNotNull => "IS NOT NULL",
348            };
349            if matches!(cmp.op, CmpOp::IsNull | CmpOp::IsNotNull) {
350                if let Some(ref expr) = cmp.left_expr {
351                    return format!("{} {}", expr.display_name(), op_str);
352                }
353                return format!("{} {}", cmp.column, op_str);
354            }
355            if let (Some(ref left), Some(ref right)) = (&cmp.left_expr, &cmp.right_expr) {
356                return format!("{} {} {}", left.display_name(), op_str, right.display_name());
357            }
358            match &cmp.value {
359                Some(SqlValue::String(s)) => format!("{} {} '{}'", cmp.column, op_str, s),
360                Some(SqlValue::Int(n)) => format!("{} {} {}", cmp.column, op_str, n),
361                Some(SqlValue::Float(f)) => format!("{} {} {}", cmp.column, op_str, f),
362                Some(SqlValue::Null) => format!("{} {} NULL", cmp.column, op_str),
363                Some(SqlValue::List(items)) => {
364                    let vals: Vec<String> = items.iter().map(|v| match v {
365                        SqlValue::String(s) => format!("'{}'", s),
366                        SqlValue::Int(n) => n.to_string(),
367                        SqlValue::Float(f) => f.to_string(),
368                        _ => "NULL".to_string(),
369                    }).collect();
370                    format!("{} {} ({})", cmp.column, op_str, vals.join(", "))
371                }
372                None => format!("{} {}", cmp.column, op_str),
373            }
374        }
375    }
376}