Skip to main content

mdql_core/
query_ast.rs

//! AST types for the MDQL SQL subset.

// ── AST nodes ──────────────────────────────────────────────────────────────

/// Binary arithmetic operator appearing in an [`Expr::BinaryOp`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArithOp {
    /// Addition (`+`).
    Add,
    /// Subtraction (`-`).
    Sub,
    /// Multiplication (`*`).
    Mul,
    /// Division (`/`).
    Div,
    /// Remainder (`%`).
    Mod,
}

impl ArithOp {
    /// SQL symbol used when rendering this operator.
    fn symbol(&self) -> &'static str {
        match self {
            ArithOp::Add => "+",
            ArithOp::Sub => "-",
            ArithOp::Mul => "*",
            ArithOp::Div => "/",
            ArithOp::Mod => "%",
        }
    }

    /// Binding strength used only for rendering; higher binds tighter.
    ///
    /// NOTE(review): assumes the parser uses standard SQL precedence (`*`, `/`, `%` above
    /// `+`, `-`) — confirm against the parser.
    fn precedence(&self) -> u8 {
        match self {
            ArithOp::Add | ArithOp::Sub => 1,
            ArithOp::Mul | ArithOp::Div | ArithOp::Mod => 2,
        }
    }
}

/// Scalar expression tree, used in select items, `ORDER BY` keys and comparisons.
#[derive(Debug, Clone, PartialEq)]
pub enum Expr {
    /// A constant value.
    Literal(SqlValue),
    /// A column reference.
    Column(String),
    /// `left op right` arithmetic.
    BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
    /// Arithmetic negation, `-expr`.
    UnaryMinus(Box<Expr>),
    /// `CASE WHEN … THEN … [ELSE …] END`: each `whens` entry pairs a condition with its result.
    Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
    /// `DATE_ADD(date, days)`.
    DateAdd { date: Box<Expr>, days: Box<Expr> },
    /// `DATEDIFF(left, right)`.
    DateDiff { left: Box<Expr>, right: Box<Expr> },
    /// `CURRENT_DATE`.
    CurrentDate,
    /// `CURRENT_TIMESTAMP`.
    CurrentTimestamp,
}

impl Expr {
    /// Returns the column name if this expression is a bare column reference.
    pub fn as_column(&self) -> Option<&str> {
        match self {
            Expr::Column(name) => Some(name.as_str()),
            _ => None,
        }
    }

    /// Renders the expression as SQL-like text (e.g. the default output column name).
    ///
    /// Parentheses are added only where the tree's grouping would otherwise be lost, so flat
    /// expressions such as `a + b * c` or `a - b - c` render exactly as before, while
    /// `(a + b) * c` no longer collapses into `a + b * c`.
    pub fn display_name(&self) -> String {
        match self {
            Expr::Literal(SqlValue::Int(n)) => n.to_string(),
            Expr::Literal(SqlValue::Float(f)) => f.to_string(),
            // Double embedded quotes (SQL escaping) so the rendered literal is unambiguous.
            Expr::Literal(SqlValue::String(s)) => format!("'{}'", s.replace('\'', "''")),
            Expr::Literal(SqlValue::Null) => "NULL".to_string(),
            Expr::Literal(SqlValue::List(_)) => "list".to_string(),
            Expr::Column(name) => name.clone(),
            Expr::BinaryOp { left, op, right } => {
                let prec = op.precedence();
                format!(
                    "{} {} {}",
                    left.operand_name(prec, false),
                    op.symbol(),
                    right.operand_name(prec, true)
                )
            }
            Expr::UnaryMinus(inner) => {
                let operand = inner.display_name();
                // Wrap compound operands so `-(a + b)` is not shown as `-a + b`, and avoid
                // emitting `--`, which SQL treats as the start of a comment.
                if matches!(**inner, Expr::BinaryOp { .. }) || operand.starts_with('-') {
                    format!("-({})", operand)
                } else {
                    format!("-{}", operand)
                }
            }
            Expr::Case { .. } => "CASE".to_string(),
            Expr::DateAdd { date, days } => {
                format!("DATE_ADD({}, {})", date.display_name(), days.display_name())
            }
            Expr::DateDiff { left, right } => {
                format!("DATEDIFF({}, {})", left.display_name(), right.display_name())
            }
            Expr::CurrentDate => "CURRENT_DATE".to_string(),
            Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
        }
    }

    /// Renders `self` as an operand of a binary operator whose precedence is `parent`.
    ///
    /// Operators are treated as left-associative (NOTE(review): confirm against the parser).
    /// A nested binary operation is therefore parenthesised when it binds more loosely than its
    /// parent, or equally loosely when it is the right-hand operand (`a - (b - c)`).
    fn operand_name(&self, parent: u8, is_right: bool) -> String {
        let text = self.display_name();
        match self {
            Expr::BinaryOp { op, .. }
                if op.precedence() < parent || (is_right && op.precedence() == parent) =>
            {
                format!("({})", text)
            }
            _ => text,
        }
    }
}

/// One key of an `ORDER BY` clause.
#[derive(Debug, Clone, PartialEq)]
pub struct OrderSpec {
    /// Name of the sort key.
    pub column: String,
    /// Expression to sort by when the key is not a bare column.
    /// NOTE(review): confirm how this interacts with `column` in the executor.
    pub expr: Option<Expr>,
    /// `true` for descending order, `false` for ascending.
    pub descending: bool,
}

/// A single comparison predicate inside a [`WhereClause`].
///
/// Simple predicates use `column`, `op` and `value`. The optional `left_expr` / `right_expr`
/// presumably carry computed operands when a side is not a plain column or literal
/// (NOTE(review): confirm against the parser/evaluator).
#[derive(Debug, Clone, PartialEq)]
pub struct Comparison {
    /// Column on the left-hand side.
    pub column: String,
    /// Comparison operator, stored as text.
    pub op: String,
    /// Right-hand value; presumably `None` for operators that take no value — confirm.
    pub value: Option<SqlValue>,
    /// Computed left-hand operand, if any.
    pub left_expr: Option<Expr>,
    /// Computed right-hand operand, if any.
    pub right_expr: Option<Expr>,
}

/// Logical combination of two conditions.
#[derive(Debug, Clone, PartialEq)]
pub struct BoolOp {
    /// Connective, stored as text (presumably `AND` / `OR` — confirm against the parser).
    pub op: String,
    /// Left-hand condition.
    pub left: Box<WhereClause>,
    /// Right-hand condition.
    pub right: Box<WhereClause>,
}

/// Boolean condition tree, used for `WHERE`, `HAVING` and `CASE WHEN` conditions.
#[derive(Debug, Clone, PartialEq)]
pub enum WhereClause {
    /// A leaf comparison.
    Comparison(Comparison),
    /// Two conditions joined by a logical connective.
    BoolOp(BoolOp),
}

/// A literal SQL value.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// Text literal.
    String(String),
    /// 64-bit signed integer literal.
    Int(i64),
    /// 64-bit floating-point literal.
    Float(f64),
    /// `NULL`.
    Null,
    /// A list of values (presumably the right-hand side of `IN (...)` — confirm).
    List(Vec<SqlValue>),
}

/// A `JOIN` against another table.
///
/// The join condition is stored as the column pair `left_col` / `right_col`; presumably an
/// `ON left = right` equality (NOTE(review): confirm in the parser).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct JoinClause {
    /// Name of the joined table.
    pub table: String,
    /// Optional alias for the joined table.
    pub alias: Option<String>,
    /// Column on the left-hand side of the join condition.
    pub left_col: String,
    /// Column on the right-hand side of the join condition.
    pub right_col: String,
}

/// Aggregate function usable in a select list.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggFunc {
    /// `COUNT`.
    Count,
    /// `SUM`.
    Sum,
    /// `AVG`.
    Avg,
    /// `MIN`.
    Min,
    /// `MAX`.
    Max,
}

115#[derive(Debug, Clone, PartialEq)]
116pub enum SelectExpr {
117    Column(String),
118    Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
119    Expr { expr: Expr, alias: Option<String> },
120}
121
122impl SelectExpr {
123    pub fn output_name(&self) -> String {
124        match self {
125            SelectExpr::Column(name) => name.clone(),
126            SelectExpr::Aggregate { func, arg, alias, .. } => {
127                if let Some(a) = alias {
128                    a.clone()
129                } else {
130                    let func_name = match func {
131                        AggFunc::Count => "COUNT",
132                        AggFunc::Sum => "SUM",
133                        AggFunc::Avg => "AVG",
134                        AggFunc::Min => "MIN",
135                        AggFunc::Max => "MAX",
136                    };
137                    format!("{}({})", func_name, arg)
138                }
139            }
140            SelectExpr::Expr { expr, alias } => {
141                alias.clone().unwrap_or_else(|| expr.display_name())
142            }
143        }
144    }
145
146    pub fn is_aggregate(&self) -> bool {
147        matches!(self, SelectExpr::Aggregate { .. })
148    }
149}
150
151#[derive(Debug, Clone, PartialEq)]
152pub struct SelectQuery {
153    pub columns: ColumnList,
154    pub table: String,
155    pub table_alias: Option<String>,
156    pub joins: Vec<JoinClause>,
157    pub where_clause: Option<WhereClause>,
158    pub group_by: Option<Vec<String>>,
159    pub having: Option<WhereClause>,
160    pub order_by: Option<Vec<OrderSpec>>,
161    pub limit: Option<i64>,
162}
163
164#[derive(Debug, Clone, PartialEq)]
165pub enum ColumnList {
166    All,
167    Named(Vec<SelectExpr>),
168}
169
170#[derive(Debug, Clone, PartialEq)]
171pub struct InsertQuery {
172    pub table: String,
173    pub columns: Vec<String>,
174    pub values: Vec<SqlValue>,
175}
176
177#[derive(Debug, Clone, PartialEq)]
178pub struct UpdateQuery {
179    pub table: String,
180    pub assignments: Vec<(String, SqlValue)>,
181    pub where_clause: Option<WhereClause>,
182}
183
184#[derive(Debug, Clone, PartialEq)]
185pub struct DeleteQuery {
186    pub table: String,
187    pub where_clause: Option<WhereClause>,
188}
189
/// Schema change that renames a field of a table.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterRenameFieldQuery {
    /// Table whose field is renamed.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// New field name.
    pub new_name: String,
}

/// Schema change that drops a field from a table.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterDropFieldQuery {
    /// Table whose field is dropped.
    pub table: String,
    /// Name of the field to drop.
    pub field_name: String,
}

/// Schema change that merges several fields of a table into one.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterMergeFieldsQuery {
    /// Table whose fields are merged.
    pub table: String,
    /// Names of the fields being merged.
    pub sources: Vec<String>,
    /// Name of the destination field.
    pub into: String,
}

210#[derive(Debug, Clone, PartialEq)]
211pub struct CreateViewQuery {
212    pub view_name: String,
213    pub columns: Option<Vec<String>>,
214    pub query: Box<SelectQuery>,
215}
216
/// A parsed `DROP VIEW` statement.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DropViewQuery {
    /// Name of the view to drop.
    pub view_name: String,
}

222#[derive(Debug, Clone, PartialEq)]
223pub enum Statement {
224    Select(SelectQuery),
225    Insert(InsertQuery),
226    Update(UpdateQuery),
227    Delete(DeleteQuery),
228    AlterRename(AlterRenameFieldQuery),
229    AlterDrop(AlterDropFieldQuery),
230    AlterMerge(AlterMergeFieldsQuery),
231    CreateView(CreateViewQuery),
232    DropView(DropViewQuery),
233}
234
235impl Statement {
236    pub fn table_name(&self) -> &str {
237        match self {
238            Statement::Select(q) => &q.table,
239            Statement::Insert(q) => &q.table,
240            Statement::Update(q) => &q.table,
241            Statement::Delete(q) => &q.table,
242            Statement::AlterRename(q) => &q.table,
243            Statement::AlterDrop(q) => &q.table,
244            Statement::AlterMerge(q) => &q.table,
245            Statement::CreateView(q) => &q.view_name,
246            Statement::DropView(q) => &q.view_name,
247        }
248    }
249}