/// Binary arithmetic operator used in [`Expr::BinaryOp`]: `+`, `-`, `*`,
/// `/` and `%` respectively.
///
/// Fieldless, so it supports total equality and hashing and can be used as a
/// `HashMap`/`HashSet` key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum ArithOp {
    Add,
    Sub,
    Mul,
    Div,
    Mod,
}
13
14#[derive(Debug, Clone, PartialEq)]
15pub enum Expr {
16 Literal(SqlValue),
17 Column(String),
18 BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
19 UnaryMinus(Box<Expr>),
20 Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
21 DateAdd { date: Box<Expr>, days: Box<Expr> },
22 DateDiff { left: Box<Expr>, right: Box<Expr> },
23 CurrentDate,
24 CurrentTimestamp,
25 Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
26 Subquery(Box<SelectQuery>),
27 Window { func: WindowFunc, args: Vec<Expr>, over: WindowSpec },
28}
29
30impl Expr {
31 pub fn as_column(&self) -> Option<&str> {
32 if let Expr::Column(name) = self { Some(name) } else { None }
33 }
34
35 pub fn display_name(&self) -> String {
36 match self {
37 Expr::Literal(SqlValue::Int(n)) => n.to_string(),
38 Expr::Literal(SqlValue::Float(f)) => f.to_string(),
39 Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
40 Expr::Literal(SqlValue::Bool(b)) => b.to_string(),
41 Expr::Literal(SqlValue::Null) => "NULL".to_string(),
42 Expr::Literal(SqlValue::List(_)) => "list".to_string(),
43 Expr::Column(name) => name.clone(),
44 Expr::BinaryOp { left, op, right } => {
45 let op_str = match op {
46 ArithOp::Add => "+",
47 ArithOp::Sub => "-",
48 ArithOp::Mul => "*",
49 ArithOp::Div => "/",
50 ArithOp::Mod => "%",
51 };
52 format!("{} {} {}", left.display_name(), op_str, right.display_name())
53 }
54 Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
55 Expr::Case { .. } => "CASE".to_string(),
56 Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
57 Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
58 Expr::CurrentDate => "CURRENT_DATE".to_string(),
59 Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
60 Expr::Aggregate { func, arg, .. } => {
61 let func_name = match func {
62 AggFunc::Count => "COUNT",
63 AggFunc::Sum => "SUM",
64 AggFunc::Avg => "AVG",
65 AggFunc::Min => "MIN",
66 AggFunc::Max => "MAX",
67 };
68 format!("{}({})", func_name, arg)
69 }
70 Expr::Subquery(_) => "(subquery)".to_string(),
71 Expr::Window { func, args, .. } => {
72 let func_name = match func {
73 WindowFunc::RowNumber => "ROW_NUMBER",
74 WindowFunc::Rank => "RANK",
75 WindowFunc::DenseRank => "DENSE_RANK",
76 WindowFunc::Lag => "LAG",
77 WindowFunc::Lead => "LEAD",
78 WindowFunc::Agg(af) => match af {
79 AggFunc::Count => "COUNT",
80 AggFunc::Sum => "SUM",
81 AggFunc::Avg => "AVG",
82 AggFunc::Min => "MIN",
83 AggFunc::Max => "MAX",
84 },
85 };
86 if args.is_empty() {
87 format!("{}()", func_name)
88 } else {
89 let arg_strs: Vec<String> = args.iter().map(|a| a.display_name()).collect();
90 format!("{}({})", func_name, arg_strs.join(", "))
91 }
92 }
93 }
94 }
95
96 pub fn contains_aggregate(&self) -> bool {
97 match self {
98 Expr::Aggregate { .. } => true,
99 Expr::BinaryOp { left, right, .. } => {
100 left.contains_aggregate() || right.contains_aggregate()
101 }
102 Expr::UnaryMinus(inner) => inner.contains_aggregate(),
103 Expr::Case { whens, else_expr } => {
104 whens.iter().any(|(_, e)| e.contains_aggregate())
105 || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
106 }
107 Expr::Subquery(_) => false,
108 Expr::Window { .. } => false,
109 _ => false,
110 }
111 }
112
113 pub fn contains_window(&self) -> bool {
114 match self {
115 Expr::Window { .. } => true,
116 Expr::BinaryOp { left, right, .. } => {
117 left.contains_window() || right.contains_window()
118 }
119 Expr::UnaryMinus(inner) => inner.contains_window(),
120 _ => false,
121 }
122 }
123}
124
125#[derive(Debug, Clone, PartialEq)]
126pub struct OrderSpec {
127 pub column: String,
128 pub expr: Option<Expr>,
129 pub descending: bool,
130}
131
/// Comparison operator of a [`Comparison`] predicate.
///
/// `In` compares against a list value; `IsNull` / `IsNotNull` take no
/// right-hand operand. Fieldless, so it supports total equality and hashing.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN (...)`
    In,
    /// `IS NULL`
    IsNull,
    /// `IS NOT NULL`
    IsNotNull,
}
146
/// Logical connective of a [`BoolOp`]: `AND` or `OR`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum BoolOpKind {
    And,
    Or,
}
152
153#[derive(Debug, Clone, PartialEq)]
154pub struct Comparison {
155 pub column: String,
156 pub op: CmpOp,
157 pub value: Option<SqlValue>,
158 pub left_expr: Option<Expr>,
159 pub right_expr: Option<Expr>,
160}
161
162#[derive(Debug, Clone, PartialEq)]
163pub struct BoolOp {
164 pub op: BoolOpKind,
165 pub left: Box<WhereClause>,
166 pub right: Box<WhereClause>,
167}
168
169#[derive(Debug, Clone, PartialEq)]
170pub enum WhereClause {
171 Comparison(Comparison),
172 BoolOp(BoolOp),
173}
174
/// A literal SQL value.
///
/// Contains `f64`, so it only supports partial equality (`NaN != NaN`).
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// A string literal (stored without surrounding quotes).
    String(String),
    /// An integer literal.
    Int(i64),
    /// A floating-point literal.
    Float(f64),
    /// A boolean literal.
    Bool(bool),
    /// SQL `NULL`.
    Null,
    /// A parenthesized value list, as used by `IN (...)`.
    List(Vec<SqlValue>),
}
184
/// Kind of join: `INNER JOIN` or `LEFT JOIN`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum JoinType {
    Inner,
    Left,
}
190
191#[derive(Debug, Clone, PartialEq)]
192pub struct JoinClause {
193 pub join_type: JoinType,
194 pub table: String,
195 pub alias: Option<String>,
196 pub condition: WhereClause,
197}
198
/// Aggregate function: `COUNT`, `SUM`, `AVG`, `MIN` or `MAX`.
///
/// Fieldless, so it supports total equality and hashing.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggFunc {
    Count,
    Sum,
    Avg,
    Min,
    Max,
}
207
208#[derive(Debug, Clone, PartialEq)]
209pub enum WindowFunc {
210 RowNumber,
211 Rank,
212 DenseRank,
213 Lag,
214 Lead,
215 Agg(AggFunc),
216}
217
218#[derive(Debug, Clone, PartialEq)]
219pub struct WindowSpec {
220 pub partition_by: Vec<String>,
221 pub order_by: Vec<OrderSpec>,
222}
223
224#[derive(Debug, Clone, PartialEq)]
225pub enum SelectExpr {
226 Column(String),
227 Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
228 Expr { expr: Expr, alias: Option<String> },
229}
230
231impl SelectExpr {
232 pub fn output_name(&self) -> String {
233 match self {
234 SelectExpr::Column(name) => name.clone(),
235 SelectExpr::Aggregate { func, arg, alias, .. } => {
236 if let Some(a) = alias {
237 a.clone()
238 } else {
239 let func_name = match func {
240 AggFunc::Count => "COUNT",
241 AggFunc::Sum => "SUM",
242 AggFunc::Avg => "AVG",
243 AggFunc::Min => "MIN",
244 AggFunc::Max => "MAX",
245 };
246 format!("{}({})", func_name, arg)
247 }
248 }
249 SelectExpr::Expr { expr, alias } => {
250 alias.clone().unwrap_or_else(|| expr.display_name())
251 }
252 }
253 }
254
255 pub fn is_aggregate(&self) -> bool {
256 match self {
257 SelectExpr::Aggregate { .. } => true,
258 SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
259 _ => false,
260 }
261 }
262}
263
264#[derive(Debug, Clone, PartialEq)]
265pub struct CteClause {
266 pub name: String,
267 pub query: Box<SelectQuery>,
268}
269
270#[derive(Debug, Clone, PartialEq)]
271pub struct SelectQuery {
272 pub distinct: bool,
273 pub columns: ColumnList,
274 pub table: String,
275 pub table_alias: Option<String>,
276 pub subquery: Option<Box<SelectQuery>>,
277 pub joins: Vec<JoinClause>,
278 pub where_clause: Option<WhereClause>,
279 pub group_by: Option<Vec<String>>,
280 pub having: Option<WhereClause>,
281 pub order_by: Option<Vec<OrderSpec>>,
282 pub limit: Option<i64>,
283 pub ctes: Vec<CteClause>,
284}
285
286#[derive(Debug, Clone, PartialEq)]
287pub enum ColumnList {
288 All,
289 Named(Vec<SelectExpr>),
290}
291
292#[derive(Debug, Clone, PartialEq)]
293pub struct InsertQuery {
294 pub table: String,
295 pub columns: Vec<String>,
296 pub values: Vec<SqlValue>,
297}
298
299#[derive(Debug, Clone, PartialEq)]
300pub struct UpdateQuery {
301 pub table: String,
302 pub assignments: Vec<(String, SqlValue)>,
303 pub where_clause: Option<WhereClause>,
304}
305
/// Mode selected for a `DELETE` statement.
///
/// NOTE(review): what `Cascade` and `Restrict` do is decided by the executor,
/// which is not visible here. Confirm there before relying on it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum DeleteMode {
    Default,
    Cascade,
    Restrict,
}
312
313#[derive(Debug, Clone, PartialEq)]
314pub struct DeleteQuery {
315 pub table: String,
316 pub where_clause: Option<WhereClause>,
317 pub mode: DeleteMode,
318}
319
/// An `ALTER` statement renaming field `old_name` of `table` to `new_name`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterRenameFieldQuery {
    pub table: String,
    pub old_name: String,
    pub new_name: String,
}
326
/// An `ALTER` statement dropping field `field_name` from `table`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterDropFieldQuery {
    pub table: String,
    pub field_name: String,
}
332
/// An `ALTER` statement merging the `sources` fields of `table` into the
/// field named `into` (NOTE(review): how values are combined is defined by
/// the executor; confirm there).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterMergeFieldsQuery {
    pub table: String,
    pub sources: Vec<String>,
    pub into: String,
}
339
340#[derive(Debug, Clone, PartialEq)]
341pub struct CreateViewQuery {
342 pub view_name: String,
343 pub columns: Option<Vec<String>>,
344 pub query: Box<SelectQuery>,
345}
346
/// A parsed `DROP VIEW view_name` statement.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DropViewQuery {
    pub view_name: String,
}
351
352#[derive(Debug, Clone, PartialEq)]
353pub enum Statement {
354 Select(SelectQuery),
355 Insert(InsertQuery),
356 Update(UpdateQuery),
357 Delete(DeleteQuery),
358 AlterRename(AlterRenameFieldQuery),
359 AlterDrop(AlterDropFieldQuery),
360 AlterMerge(AlterMergeFieldsQuery),
361 CreateView(CreateViewQuery),
362 DropView(DropViewQuery),
363}
364
365impl Statement {
366 pub fn table_name(&self) -> &str {
367 match self {
368 Statement::Select(q) => &q.table,
369 Statement::Insert(q) => &q.table,
370 Statement::Update(q) => &q.table,
371 Statement::Delete(q) => &q.table,
372 Statement::AlterRename(q) => &q.table,
373 Statement::AlterDrop(q) => &q.table,
374 Statement::AlterMerge(q) => &q.table,
375 Statement::CreateView(q) => &q.view_name,
376 Statement::DropView(q) => &q.view_name,
377 }
378 }
379}
380
381pub fn where_clause_to_sql(clause: &WhereClause) -> String {
382 match clause {
383 WhereClause::BoolOp(bop) => {
384 let left = where_clause_to_sql(&bop.left);
385 let right = where_clause_to_sql(&bop.right);
386 let op = match bop.op {
387 BoolOpKind::And => "AND",
388 BoolOpKind::Or => "OR",
389 };
390 format!("{} {} {}", left, op, right)
391 }
392 WhereClause::Comparison(cmp) => {
393 let op_str = match cmp.op {
394 CmpOp::Eq => "=",
395 CmpOp::Ne => "!=",
396 CmpOp::Lt => "<",
397 CmpOp::Gt => ">",
398 CmpOp::Le => "<=",
399 CmpOp::Ge => ">=",
400 CmpOp::Like => "LIKE",
401 CmpOp::NotLike => "NOT LIKE",
402 CmpOp::In => "IN",
403 CmpOp::IsNull => "IS NULL",
404 CmpOp::IsNotNull => "IS NOT NULL",
405 };
406 if matches!(cmp.op, CmpOp::IsNull | CmpOp::IsNotNull) {
407 if let Some(ref expr) = cmp.left_expr {
408 return format!("{} {}", expr.display_name(), op_str);
409 }
410 return format!("{} {}", cmp.column, op_str);
411 }
412 if let (Some(ref left), Some(ref right)) = (&cmp.left_expr, &cmp.right_expr) {
413 return format!("{} {} {}", left.display_name(), op_str, right.display_name());
414 }
415 match &cmp.value {
416 Some(SqlValue::String(s)) => format!("{} {} '{}'", cmp.column, op_str, s),
417 Some(SqlValue::Int(n)) => format!("{} {} {}", cmp.column, op_str, n),
418 Some(SqlValue::Float(f)) => format!("{} {} {}", cmp.column, op_str, f),
419 Some(SqlValue::Bool(b)) => format!("{} {} {}", cmp.column, op_str, b),
420 Some(SqlValue::Null) => format!("{} {} NULL", cmp.column, op_str),
421 Some(SqlValue::List(items)) => {
422 let vals: Vec<String> = items.iter().map(|v| match v {
423 SqlValue::String(s) => format!("'{}'", s),
424 SqlValue::Int(n) => n.to_string(),
425 SqlValue::Float(f) => f.to_string(),
426 _ => "NULL".to_string(),
427 }).collect();
428 format!("{} {} ({})", cmp.column, op_str, vals.join(", "))
429 }
430 None => format!("{} {}", cmp.column, op_str),
431 }
432 }
433 }
434}