Skip to main content

mdql_core/
query_ast.rs

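//! AST types for the MDQL SQL subset: scalar expressions, `WHERE`/`HAVING`
//! predicates, and the statement kinds (SELECT, INSERT, UPDATE, DELETE,
//! ALTER ... FIELD, CREATE VIEW, DROP VIEW).

// ── AST nodes ──────────────────────────────────────────────────────────────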
4
/// Binary arithmetic operator carried by [`Expr::BinaryOp`].
///
/// [`Expr::display_name`] renders these as `+`, `-`, `*`, `/` and `%`.
/// `Eq` and `Hash` are derived so operators can be compared exactly and used
/// as keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArithOp {
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
    /// `%`
    Mod,
}
13
/// A scalar expression in the MDQL SQL subset.
///
/// Used both inside `SELECT` items ([`SelectExpr::Expr`]) and as optional
/// operands of [`Comparison`] and [`OrderSpec`].
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A constant value.
    Literal(SqlValue),
    /// A reference to a column by name.
    Column(String),
    /// `left <op> right` arithmetic.
    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
    /// Unary negation, `-inner`.
    UnaryMinus(Box<Expr>),
    /// `CASE WHEN <cond> THEN <expr> ... [ELSE <expr>] END`: each entry of
    /// `whens` pairs a condition with the result expression for that branch.
    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
    /// `DATE_ADD(date, days)`.
    DateAdd { date: Box<Expr>, days: Box<Expr> },
    /// `DATEDIFF(left, right)`.
    DateDiff { left: Box<Expr>, right: Box<Expr> },
    /// `CURRENT_DATE`.
    CurrentDate,
    /// `CURRENT_TIMESTAMP`.
    CurrentTimestamp,
    /// An aggregate call such as `SUM(x)`. `arg` is the argument text used in
    /// display names. NOTE(review): `arg_expr` presumably holds the parsed
    /// argument when it is more than a bare column — confirm against the parser.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
}
27
28impl Expr {
29    pub fn as_column(&self) -> Option<&str> {
30        if let Expr::Column(name) = self { Some(name) } else { None }
31    }
32
33    pub fn display_name(&self) -> String {
34        match self {
35            Expr::Literal(SqlValue::Int(n)) => n.to_string(),
36            Expr::Literal(SqlValue::Float(f)) => f.to_string(),
37            Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
38            Expr::Literal(SqlValue::Null) => "NULL".to_string(),
39            Expr::Literal(SqlValue::List(_)) => "list".to_string(),
40            Expr::Column(name) => name.clone(),
41            Expr::BinaryOp { left, op, right } => {
42                let op_str = match op {
43                    ArithOp::Add => "+",
44                    ArithOp::Sub => "-",
45                    ArithOp::Mul => "*",
46                    ArithOp::Div => "/",
47                    ArithOp::Mod => "%",
48                };
49                format!("{} {} {}", left.display_name(), op_str, right.display_name())
50            }
51            Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
52            Expr::Case { .. } => "CASE".to_string(),
53            Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
54            Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
55            Expr::CurrentDate => "CURRENT_DATE".to_string(),
56            Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
57            Expr::Aggregate { func, arg, .. } => {
58                let func_name = match func {
59                    AggFunc::Count => "COUNT",
60                    AggFunc::Sum => "SUM",
61                    AggFunc::Avg => "AVG",
62                    AggFunc::Min => "MIN",
63                    AggFunc::Max => "MAX",
64                };
65                format!("{}({})", func_name, arg)
66            }
67        }
68    }
69
70    pub fn contains_aggregate(&self) -> bool {
71        match self {
72            Expr::Aggregate { .. } => true,
73            Expr::BinaryOp { left, right, .. } => {
74                left.contains_aggregate() || right.contains_aggregate()
75            }
76            Expr::UnaryMinus(inner) => inner.contains_aggregate(),
77            Expr::Case { whens, else_expr } => {
78                whens.iter().any(|(_, e)| e.contains_aggregate())
79                    || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
80            }
81            _ => false,
82        }
83    }
84}
85
/// One `ORDER BY` key of a [`SelectQuery`].
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Name of the column to sort by.
    pub column: String,
    /// Optional expression to sort by. NOTE(review): presumably used instead
    /// of `column` when set — confirm against the executor.
    pub expr: Option<Expr>,
    /// `true` for descending order, `false` for ascending.
    pub descending: bool,
}
92
/// A single predicate, the leaf of a [`WhereClause`] tree.
///
/// NOTE(review): which optional fields are populated appears to depend on
/// the form of the predicate (plain column/literal vs. expression operands) —
/// confirm against the parser.
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Left-hand column name.
    pub column: String,
    /// Operator as source text (stringly typed). NOTE(review): the accepted
    /// set, e.g. `=`, `<`, `IN`, is defined by the parser — confirm there.
    pub op: String,
    /// Right-hand literal value, if any. NOTE(review): `SqlValue::List`
    /// presumably carries `IN (...)` operands — confirm.
    pub value: Option<SqlValue>,
    /// Optional expression form of the left-hand operand.
    pub left_expr: Option<Expr>,
    /// Optional expression form of the right-hand operand.
    pub right_expr: Option<Expr>,
}
101
/// Two predicates combined by a boolean operator.
#[derive(Debug, Clone, PartialEq)]
pub struct BoolOp {
    /// Operator as source text. NOTE(review): presumably `AND` / `OR` —
    /// confirm against the parser.
    pub op: String,
    /// Left operand predicate.
    pub left: Box<WhereClause>,
    /// Right operand predicate.
    pub right: Box<WhereClause>,
}
108
/// A boolean predicate tree.
///
/// Used for `WHERE` and `HAVING` clauses and for the conditions of
/// `CASE WHEN` branches ([`Expr::Case`]).
#[derive(Debug, Clone, PartialEq)]
pub enum WhereClause {
    /// A single comparison (leaf).
    Comparison(Comparison),
    /// Two sub-predicates joined by a boolean operator.
    BoolOp(BoolOp),
}
114
/// A literal SQL value.
///
/// Only `PartialEq` is derived: the `Float` variant holds an `f64`, which
/// has no total equality.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// A string literal.
    String(String),
    /// An integer literal.
    Int(i64),
    /// A floating-point literal.
    Float(f64),
    /// `NULL`.
    Null,
    /// A list of values. NOTE(review): presumably the operand of `IN (...)` —
    /// confirm against the parser.
    List(Vec<SqlValue>),
}
123
/// A `JOIN` of another table onto the query's main table.
#[derive(Debug, Clone, PartialEq)]
pub struct JoinClause {
    /// Name of the joined table.
    pub table: String,
    /// Optional alias for the joined table.
    pub alias: Option<String>,
    /// Left-hand join column. NOTE(review): presumably an equality join
    /// `ON left_col = right_col` — confirm against the executor.
    pub left_col: String,
    /// Right-hand join column.
    pub right_col: String,
}
131
/// Aggregate function used by [`Expr::Aggregate`] and [`SelectExpr::Aggregate`].
///
/// Display names render these as `COUNT`, `SUM`, `AVG`, `MIN` and `MAX`.
/// `Eq` and `Hash` are derived so functions can be compared exactly and used
/// as keys in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggFunc {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
}
140
/// One item in a `SELECT` list.
///
/// NOTE(review): aggregates can be represented either by the top-level
/// `Aggregate` variant or nested inside `Expr` via `Expr::Aggregate`;
/// `is_aggregate` checks both.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectExpr {
    /// A bare column reference.
    Column(String),
    /// An aggregate call with an optional `AS` alias; when unaliased it is
    /// named `FUNC(arg)`.
    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
    /// Any other expression, with an optional `AS` alias.
    Expr { expr: Expr, alias: Option<String> },
}
147
148impl SelectExpr {
149    pub fn output_name(&self) -> String {
150        match self {
151            SelectExpr::Column(name) => name.clone(),
152            SelectExpr::Aggregate { func, arg, alias, .. } => {
153                if let Some(a) = alias {
154                    a.clone()
155                } else {
156                    let func_name = match func {
157                        AggFunc::Count => "COUNT",
158                        AggFunc::Sum => "SUM",
159                        AggFunc::Avg => "AVG",
160                        AggFunc::Min => "MIN",
161                        AggFunc::Max => "MAX",
162                    };
163                    format!("{}({})", func_name, arg)
164                }
165            }
166            SelectExpr::Expr { expr, alias } => {
167                alias.clone().unwrap_or_else(|| expr.display_name())
168            }
169        }
170    }
171
172    pub fn is_aggregate(&self) -> bool {
173        match self {
174            SelectExpr::Aggregate { .. } => true,
175            SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
176            _ => false,
177        }
178    }
179}
180
/// A `SELECT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct SelectQuery {
    /// The projection: `*` or an explicit list of items.
    pub columns: ColumnList,
    /// Name of the table being queried (reported by `Statement::table_name`).
    pub table: String,
    /// Optional alias for `table`.
    pub table_alias: Option<String>,
    /// NOTE(review): presumably a `FROM (SELECT ...)` subquery used in place
    /// of `table` — confirm against the parser/executor.
    pub subquery: Option<Box<SelectQuery>>,
    /// `JOIN` clauses.
    pub joins: Vec<JoinClause>,
    /// `WHERE` predicate, if present.
    pub where_clause: Option<WhereClause>,
    /// `GROUP BY` column names, if present.
    pub group_by: Option<Vec<String>>,
    /// `HAVING` predicate, if present.
    pub having: Option<WhereClause>,
    /// `ORDER BY` keys, if present.
    pub order_by: Option<Vec<OrderSpec>>,
    /// `LIMIT` value, if present.
    pub limit: Option<i64>,
}
194
/// The projection of a `SELECT`.
#[derive(Debug, Clone, PartialEq)]
pub enum ColumnList {
    /// `SELECT *`.
    All,
    /// An explicit list of `SELECT` items.
    Named(Vec<SelectExpr>),
}
200
/// An `INSERT` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct InsertQuery {
    /// Target table.
    pub table: String,
    /// Column names being inserted.
    pub columns: Vec<String>,
    /// Literal values. NOTE(review): presumably paired positionally with
    /// `columns` — confirm the parser enforces equal lengths.
    pub values: Vec<SqlValue>,
}
207
/// An `UPDATE ... SET ... [WHERE ...]` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct UpdateQuery {
    /// Target table.
    pub table: String,
    /// `(column, value)` pairs from the `SET` list; values are literals only.
    pub assignments: Vec<(String, SqlValue)>,
    /// Optional `WHERE` predicate restricting which rows are updated.
    pub where_clause: Option<WhereClause>,
}
214
/// A `DELETE FROM ... [WHERE ...]` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DeleteQuery {
    /// Target table.
    pub table: String,
    /// Optional `WHERE` predicate. NOTE(review): `None` presumably deletes
    /// every row — confirm against the executor.
    pub where_clause: Option<WhereClause>,
}
220
/// An `ALTER` statement that renames a field of `table` from `old_name` to
/// `new_name`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Target table.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// New field name.
    pub new_name: String,
}
227
/// An `ALTER` statement that drops the field `field_name` from `table`.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Target table.
    pub table: String,
    /// Name of the field to drop.
    pub field_name: String,
}
233
/// An `ALTER` statement that merges several fields of `table` into one.
///
/// NOTE(review): how the source values are combined is defined by the
/// executor — confirm there.
#[derive(Debug, Clone, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Target table.
    pub table: String,
    /// Names of the fields being merged.
    pub sources: Vec<String>,
    /// Name of the destination field.
    pub into: String,
}
240
/// A `CREATE VIEW` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct CreateViewQuery {
    /// Name of the view to create.
    pub view_name: String,
    /// Optional explicit column names. NOTE(review): presumably they rename
    /// the query's output columns — confirm.
    pub columns: Option<Vec<String>>,
    /// The `SELECT` that defines the view.
    pub query: Box<SelectQuery>,
}
247
/// A `DROP VIEW` statement.
#[derive(Debug, Clone, PartialEq)]
pub struct DropViewQuery {
    /// Name of the view to drop.
    pub view_name: String,
}
252
/// A top-level MDQL statement.
#[derive(Debug, Clone, PartialEq)]
pub enum Statement {
    /// `SELECT`.
    Select(SelectQuery),
    /// `INSERT`.
    Insert(InsertQuery),
    /// `UPDATE`.
    Update(UpdateQuery),
    /// `DELETE`.
    Delete(DeleteQuery),
    /// `ALTER` that renames a field.
    AlterRename(AlterRenameFieldQuery),
    /// `ALTER` that drops a field.
    AlterDrop(AlterDropFieldQuery),
    /// `ALTER` that merges fields.
    AlterMerge(AlterMergeFieldsQuery),
    /// `CREATE VIEW`.
    CreateView(CreateViewQuery),
    /// `DROP VIEW`.
    DropView(DropViewQuery),
}
265
266impl Statement {
267    pub fn table_name(&self) -> &str {
268        match self {
269            Statement::Select(q) => &q.table,
270            Statement::Insert(q) => &q.table,
271            Statement::Update(q) => &q.table,
272            Statement::Delete(q) => &q.table,
273            Statement::AlterRename(q) => &q.table,
274            Statement::AlterDrop(q) => &q.table,
275            Statement::AlterMerge(q) => &q.table,
276            Statement::CreateView(q) => &q.view_name,
277            Statement::DropView(q) => &q.view_name,
278        }
279    }
280}