/// Binary arithmetic operator carried by [`Expr::BinaryOp`].
#[derive(Clone, Debug, PartialEq)]
pub enum ArithOp {
    /// Addition, rendered as `+`.
    Add,
    /// Subtraction, rendered as `-`.
    Sub,
    /// Multiplication, rendered as `*`.
    Mul,
    /// Division, rendered as `/`.
    Div,
    /// Remainder, rendered as `%`.
    Mod,
}
13
14#[derive(Debug, Clone, PartialEq)]
15pub enum Expr {
16 Literal(SqlValue),
17 Column(String),
18 BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
19 UnaryMinus(Box<Expr>),
20 Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
21 DateAdd { date: Box<Expr>, days: Box<Expr> },
22 DateDiff { left: Box<Expr>, right: Box<Expr> },
23 CurrentDate,
24 CurrentTimestamp,
25 Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
26 Subquery(Box<SelectQuery>),
27}
28
29impl Expr {
30 pub fn as_column(&self) -> Option<&str> {
31 if let Expr::Column(name) = self { Some(name) } else { None }
32 }
33
34 pub fn display_name(&self) -> String {
35 match self {
36 Expr::Literal(SqlValue::Int(n)) => n.to_string(),
37 Expr::Literal(SqlValue::Float(f)) => f.to_string(),
38 Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
39 Expr::Literal(SqlValue::Null) => "NULL".to_string(),
40 Expr::Literal(SqlValue::List(_)) => "list".to_string(),
41 Expr::Column(name) => name.clone(),
42 Expr::BinaryOp { left, op, right } => {
43 let op_str = match op {
44 ArithOp::Add => "+",
45 ArithOp::Sub => "-",
46 ArithOp::Mul => "*",
47 ArithOp::Div => "/",
48 ArithOp::Mod => "%",
49 };
50 format!("{} {} {}", left.display_name(), op_str, right.display_name())
51 }
52 Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
53 Expr::Case { .. } => "CASE".to_string(),
54 Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
55 Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
56 Expr::CurrentDate => "CURRENT_DATE".to_string(),
57 Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
58 Expr::Aggregate { func, arg, .. } => {
59 let func_name = match func {
60 AggFunc::Count => "COUNT",
61 AggFunc::Sum => "SUM",
62 AggFunc::Avg => "AVG",
63 AggFunc::Min => "MIN",
64 AggFunc::Max => "MAX",
65 };
66 format!("{}({})", func_name, arg)
67 }
68 Expr::Subquery(_) => "(subquery)".to_string(),
69 }
70 }
71
72 pub fn contains_aggregate(&self) -> bool {
73 match self {
74 Expr::Aggregate { .. } => true,
75 Expr::BinaryOp { left, right, .. } => {
76 left.contains_aggregate() || right.contains_aggregate()
77 }
78 Expr::UnaryMinus(inner) => inner.contains_aggregate(),
79 Expr::Case { whens, else_expr } => {
80 whens.iter().any(|(_, e)| e.contains_aggregate())
81 || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
82 }
83 Expr::Subquery(_) => false,
84 _ => false,
85 }
86 }
87}
88
89#[derive(Debug, Clone, PartialEq)]
90pub struct OrderSpec {
91 pub column: String,
92 pub expr: Option<Expr>,
93 pub descending: bool,
94}
95
/// Comparison operator of a [`Comparison`] predicate.
#[derive(Clone, Debug, PartialEq)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN (...)`
    In,
    /// `IS NULL`, rendered without a right-hand operand.
    IsNull,
    /// `IS NOT NULL`, rendered without a right-hand operand.
    IsNotNull,
}
110
/// Logical connective joining two conditions in a [`BoolOp`].
#[derive(Clone, Debug, PartialEq)]
pub enum BoolOpKind {
    /// `AND`
    And,
    /// `OR`
    Or,
}
116
117#[derive(Debug, Clone, PartialEq)]
118pub struct Comparison {
119 pub column: String,
120 pub op: CmpOp,
121 pub value: Option<SqlValue>,
122 pub left_expr: Option<Expr>,
123 pub right_expr: Option<Expr>,
124}
125
126#[derive(Debug, Clone, PartialEq)]
127pub struct BoolOp {
128 pub op: BoolOpKind,
129 pub left: Box<WhereClause>,
130 pub right: Box<WhereClause>,
131}
132
133#[derive(Debug, Clone, PartialEq)]
134pub enum WhereClause {
135 Comparison(Comparison),
136 BoolOp(BoolOp),
137}
138
/// A constant value appearing in a query.
#[derive(Clone, Debug, PartialEq)]
pub enum SqlValue {
    /// A string literal.
    String(String),
    /// An integer literal.
    Int(i64),
    /// A floating-point literal.
    Float(f64),
    /// `NULL`.
    Null,
    /// A list of values, rendered as `(a, b, ...)`, e.g. the operand of `IN`.
    List(Vec<SqlValue>),
}
147
/// Kind of join in a [`JoinClause`].
#[derive(Clone, Debug, PartialEq)]
pub enum JoinType {
    /// `INNER JOIN`.
    Inner,
    /// `LEFT JOIN`.
    Left,
}
153
154#[derive(Debug, Clone, PartialEq)]
155pub struct JoinClause {
156 pub join_type: JoinType,
157 pub table: String,
158 pub alias: Option<String>,
159 pub condition: WhereClause,
160}
161
/// Aggregate function, named `COUNT`, `SUM`, `AVG`, `MIN` or `MAX` in output.
#[derive(Clone, Debug, PartialEq)]
pub enum AggFunc {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
}
170
171#[derive(Debug, Clone, PartialEq)]
172pub enum SelectExpr {
173 Column(String),
174 Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
175 Expr { expr: Expr, alias: Option<String> },
176}
177
178impl SelectExpr {
179 pub fn output_name(&self) -> String {
180 match self {
181 SelectExpr::Column(name) => name.clone(),
182 SelectExpr::Aggregate { func, arg, alias, .. } => {
183 if let Some(a) = alias {
184 a.clone()
185 } else {
186 let func_name = match func {
187 AggFunc::Count => "COUNT",
188 AggFunc::Sum => "SUM",
189 AggFunc::Avg => "AVG",
190 AggFunc::Min => "MIN",
191 AggFunc::Max => "MAX",
192 };
193 format!("{}({})", func_name, arg)
194 }
195 }
196 SelectExpr::Expr { expr, alias } => {
197 alias.clone().unwrap_or_else(|| expr.display_name())
198 }
199 }
200 }
201
202 pub fn is_aggregate(&self) -> bool {
203 match self {
204 SelectExpr::Aggregate { .. } => true,
205 SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
206 _ => false,
207 }
208 }
209}
210
211#[derive(Debug, Clone, PartialEq)]
212pub struct CteClause {
213 pub name: String,
214 pub query: Box<SelectQuery>,
215}
216
217#[derive(Debug, Clone, PartialEq)]
218pub struct SelectQuery {
219 pub columns: ColumnList,
220 pub table: String,
221 pub table_alias: Option<String>,
222 pub subquery: Option<Box<SelectQuery>>,
223 pub joins: Vec<JoinClause>,
224 pub where_clause: Option<WhereClause>,
225 pub group_by: Option<Vec<String>>,
226 pub having: Option<WhereClause>,
227 pub order_by: Option<Vec<OrderSpec>>,
228 pub limit: Option<i64>,
229 pub ctes: Vec<CteClause>,
230}
231
232#[derive(Debug, Clone, PartialEq)]
233pub enum ColumnList {
234 All,
235 Named(Vec<SelectExpr>),
236}
237
238#[derive(Debug, Clone, PartialEq)]
239pub struct InsertQuery {
240 pub table: String,
241 pub columns: Vec<String>,
242 pub values: Vec<SqlValue>,
243}
244
245#[derive(Debug, Clone, PartialEq)]
246pub struct UpdateQuery {
247 pub table: String,
248 pub assignments: Vec<(String, SqlValue)>,
249 pub where_clause: Option<WhereClause>,
250}
251
/// Mode modifier of a [`DeleteQuery`].
#[derive(Clone, Debug, PartialEq)]
pub enum DeleteMode {
    /// Default delete behaviour.
    Default,
    /// `CASCADE`.
    Cascade,
    /// `RESTRICT`.
    Restrict,
}
258
259#[derive(Debug, Clone, PartialEq)]
260pub struct DeleteQuery {
261 pub table: String,
262 pub where_clause: Option<WhereClause>,
263 pub mode: DeleteMode,
264}
265
/// `ALTER` statement renaming field `old_name` of `table` to `new_name`.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterRenameFieldQuery {
    /// Table being altered.
    pub table: String,
    /// Current field name.
    pub old_name: String,
    /// Replacement field name.
    pub new_name: String,
}
272
/// `ALTER` statement dropping field `field_name` from `table`.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterDropFieldQuery {
    /// Table being altered.
    pub table: String,
    /// Field to drop.
    pub field_name: String,
}
278
/// `ALTER` statement merging the `sources` fields of `table` into the field
/// named `into`.
///
/// How the values are combined is decided by the executor, not shown here.
#[derive(Clone, Debug, PartialEq)]
pub struct AlterMergeFieldsQuery {
    /// Table being altered.
    pub table: String,
    /// Fields to merge.
    pub sources: Vec<String>,
    /// Destination field name.
    pub into: String,
}
285
286#[derive(Debug, Clone, PartialEq)]
287pub struct CreateViewQuery {
288 pub view_name: String,
289 pub columns: Option<Vec<String>>,
290 pub query: Box<SelectQuery>,
291}
292
/// `DROP VIEW` statement.
#[derive(Clone, Debug, PartialEq)]
pub struct DropViewQuery {
    /// Name of the view to drop.
    pub view_name: String,
}
297
298#[derive(Debug, Clone, PartialEq)]
299pub enum Statement {
300 Select(SelectQuery),
301 Insert(InsertQuery),
302 Update(UpdateQuery),
303 Delete(DeleteQuery),
304 AlterRename(AlterRenameFieldQuery),
305 AlterDrop(AlterDropFieldQuery),
306 AlterMerge(AlterMergeFieldsQuery),
307 CreateView(CreateViewQuery),
308 DropView(DropViewQuery),
309}
310
311impl Statement {
312 pub fn table_name(&self) -> &str {
313 match self {
314 Statement::Select(q) => &q.table,
315 Statement::Insert(q) => &q.table,
316 Statement::Update(q) => &q.table,
317 Statement::Delete(q) => &q.table,
318 Statement::AlterRename(q) => &q.table,
319 Statement::AlterDrop(q) => &q.table,
320 Statement::AlterMerge(q) => &q.table,
321 Statement::CreateView(q) => &q.view_name,
322 Statement::DropView(q) => &q.view_name,
323 }
324 }
325}
326
327pub fn where_clause_to_sql(clause: &WhereClause) -> String {
328 match clause {
329 WhereClause::BoolOp(bop) => {
330 let left = where_clause_to_sql(&bop.left);
331 let right = where_clause_to_sql(&bop.right);
332 let op = match bop.op {
333 BoolOpKind::And => "AND",
334 BoolOpKind::Or => "OR",
335 };
336 format!("{} {} {}", left, op, right)
337 }
338 WhereClause::Comparison(cmp) => {
339 let op_str = match cmp.op {
340 CmpOp::Eq => "=",
341 CmpOp::Ne => "!=",
342 CmpOp::Lt => "<",
343 CmpOp::Gt => ">",
344 CmpOp::Le => "<=",
345 CmpOp::Ge => ">=",
346 CmpOp::Like => "LIKE",
347 CmpOp::NotLike => "NOT LIKE",
348 CmpOp::In => "IN",
349 CmpOp::IsNull => "IS NULL",
350 CmpOp::IsNotNull => "IS NOT NULL",
351 };
352 if matches!(cmp.op, CmpOp::IsNull | CmpOp::IsNotNull) {
353 if let Some(ref expr) = cmp.left_expr {
354 return format!("{} {}", expr.display_name(), op_str);
355 }
356 return format!("{} {}", cmp.column, op_str);
357 }
358 if let (Some(ref left), Some(ref right)) = (&cmp.left_expr, &cmp.right_expr) {
359 return format!("{} {} {}", left.display_name(), op_str, right.display_name());
360 }
361 match &cmp.value {
362 Some(SqlValue::String(s)) => format!("{} {} '{}'", cmp.column, op_str, s),
363 Some(SqlValue::Int(n)) => format!("{} {} {}", cmp.column, op_str, n),
364 Some(SqlValue::Float(f)) => format!("{} {} {}", cmp.column, op_str, f),
365 Some(SqlValue::Null) => format!("{} {} NULL", cmp.column, op_str),
366 Some(SqlValue::List(items)) => {
367 let vals: Vec<String> = items.iter().map(|v| match v {
368 SqlValue::String(s) => format!("'{}'", s),
369 SqlValue::Int(n) => n.to_string(),
370 SqlValue::Float(f) => f.to_string(),
371 _ => "NULL".to_string(),
372 }).collect();
373 format!("{} {} ({})", cmp.column, op_str, vals.join(", "))
374 }
375 None => format!("{} {}", cmp.column, op_str),
376 }
377 }
378 }
379}