/// Binary arithmetic operator carried by [`Expr::BinaryOp`].
///
/// The enum has no payload, so it is `Copy` and usable as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ArithOp {
    /// Addition, rendered as `+`.
    Add,
    /// Subtraction, rendered as `-`.
    Sub,
    /// Multiplication, rendered as `*`.
    Mul,
    /// Division, rendered as `/`.
    Div,
    /// Remainder, rendered as `%`.
    Mod,
}
13
14#[derive(Debug, Clone, PartialEq)]
15pub enum Expr {
16 Literal(SqlValue),
17 Column(String),
18 BinaryOp { left: Box<Expr>, op: ArithOp, right: Box<Expr> },
19 UnaryMinus(Box<Expr>),
20 Case { whens: Vec<(WhereClause, Box<Expr>)>, else_expr: Option<Box<Expr>> },
21 DateAdd { date: Box<Expr>, days: Box<Expr> },
22 DateDiff { left: Box<Expr>, right: Box<Expr> },
23 CurrentDate,
24 CurrentTimestamp,
25 Aggregate { func: AggFunc, arg: String, arg_expr: Option<Box<Expr>> },
26 Subquery(Box<SelectQuery>),
27 Window { func: WindowFunc, args: Vec<Expr>, over: WindowSpec },
28}
29
30impl Expr {
31 pub fn as_column(&self) -> Option<&str> {
32 if let Expr::Column(name) = self { Some(name) } else { None }
33 }
34
35 pub fn display_name(&self) -> String {
36 match self {
37 Expr::Literal(SqlValue::Int(n)) => n.to_string(),
38 Expr::Literal(SqlValue::Float(f)) => f.to_string(),
39 Expr::Literal(SqlValue::String(s)) => format!("'{}'", s),
40 Expr::Literal(SqlValue::Null) => "NULL".to_string(),
41 Expr::Literal(SqlValue::List(_)) => "list".to_string(),
42 Expr::Column(name) => name.clone(),
43 Expr::BinaryOp { left, op, right } => {
44 let op_str = match op {
45 ArithOp::Add => "+",
46 ArithOp::Sub => "-",
47 ArithOp::Mul => "*",
48 ArithOp::Div => "/",
49 ArithOp::Mod => "%",
50 };
51 format!("{} {} {}", left.display_name(), op_str, right.display_name())
52 }
53 Expr::UnaryMinus(inner) => format!("-{}", inner.display_name()),
54 Expr::Case { .. } => "CASE".to_string(),
55 Expr::DateAdd { date, days } => format!("DATE_ADD({}, {})", date.display_name(), days.display_name()),
56 Expr::DateDiff { left, right } => format!("DATEDIFF({}, {})", left.display_name(), right.display_name()),
57 Expr::CurrentDate => "CURRENT_DATE".to_string(),
58 Expr::CurrentTimestamp => "CURRENT_TIMESTAMP".to_string(),
59 Expr::Aggregate { func, arg, .. } => {
60 let func_name = match func {
61 AggFunc::Count => "COUNT",
62 AggFunc::Sum => "SUM",
63 AggFunc::Avg => "AVG",
64 AggFunc::Min => "MIN",
65 AggFunc::Max => "MAX",
66 };
67 format!("{}({})", func_name, arg)
68 }
69 Expr::Subquery(_) => "(subquery)".to_string(),
70 Expr::Window { func, args, .. } => {
71 let func_name = match func {
72 WindowFunc::RowNumber => "ROW_NUMBER",
73 WindowFunc::Rank => "RANK",
74 WindowFunc::DenseRank => "DENSE_RANK",
75 WindowFunc::Lag => "LAG",
76 WindowFunc::Lead => "LEAD",
77 WindowFunc::Agg(af) => match af {
78 AggFunc::Count => "COUNT",
79 AggFunc::Sum => "SUM",
80 AggFunc::Avg => "AVG",
81 AggFunc::Min => "MIN",
82 AggFunc::Max => "MAX",
83 },
84 };
85 if args.is_empty() {
86 format!("{}()", func_name)
87 } else {
88 let arg_strs: Vec<String> = args.iter().map(|a| a.display_name()).collect();
89 format!("{}({})", func_name, arg_strs.join(", "))
90 }
91 }
92 }
93 }
94
95 pub fn contains_aggregate(&self) -> bool {
96 match self {
97 Expr::Aggregate { .. } => true,
98 Expr::BinaryOp { left, right, .. } => {
99 left.contains_aggregate() || right.contains_aggregate()
100 }
101 Expr::UnaryMinus(inner) => inner.contains_aggregate(),
102 Expr::Case { whens, else_expr } => {
103 whens.iter().any(|(_, e)| e.contains_aggregate())
104 || else_expr.as_ref().map_or(false, |e| e.contains_aggregate())
105 }
106 Expr::Subquery(_) => false,
107 Expr::Window { .. } => false,
108 _ => false,
109 }
110 }
111
112 pub fn contains_window(&self) -> bool {
113 match self {
114 Expr::Window { .. } => true,
115 Expr::BinaryOp { left, right, .. } => {
116 left.contains_window() || right.contains_window()
117 }
118 Expr::UnaryMinus(inner) => inner.contains_window(),
119 _ => false,
120 }
121 }
122}
123
124#[derive(Debug, Clone, PartialEq)]
125pub struct OrderSpec {
126 pub column: String,
127 pub expr: Option<Expr>,
128 pub descending: bool,
129}
130
/// Comparison operator of a [`Comparison`].
///
/// The enum has no payload, so it is `Copy` and usable as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CmpOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `<`
    Lt,
    /// `>`
    Gt,
    /// `<=`
    Le,
    /// `>=`
    Ge,
    /// `LIKE`
    Like,
    /// `NOT LIKE`
    NotLike,
    /// `IN`
    In,
    /// `IS NULL` (takes no right-hand operand).
    IsNull,
    /// `IS NOT NULL` (takes no right-hand operand).
    IsNotNull,
}
145
/// Logical connective joining the two sides of a [`BoolOp`].
///
/// The enum has no payload, so it is `Copy` and usable as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BoolOpKind {
    /// Logical `AND`.
    And,
    /// Logical `OR`.
    Or,
}
151
152#[derive(Debug, Clone, PartialEq)]
153pub struct Comparison {
154 pub column: String,
155 pub op: CmpOp,
156 pub value: Option<SqlValue>,
157 pub left_expr: Option<Expr>,
158 pub right_expr: Option<Expr>,
159}
160
161#[derive(Debug, Clone, PartialEq)]
162pub struct BoolOp {
163 pub op: BoolOpKind,
164 pub left: Box<WhereClause>,
165 pub right: Box<WhereClause>,
166}
167
168#[derive(Debug, Clone, PartialEq)]
169pub enum WhereClause {
170 Comparison(Comparison),
171 BoolOp(BoolOp),
172}
173
/// A literal SQL value.
///
/// Only `PartialEq` is derived because the `Float` payload is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum SqlValue {
    /// A text value, stored without surrounding quotes.
    String(String),
    /// A 64-bit signed integer.
    Int(i64),
    /// A 64-bit floating-point number.
    Float(f64),
    /// SQL `NULL`.
    Null,
    /// A list of values, e.g. the right-hand side of `IN (...)`.
    List(Vec<SqlValue>),
}
182
/// Kind of join requested by a [`JoinClause`].
///
/// The enum has no payload, so it is `Copy` and usable as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// `INNER JOIN`.
    Inner,
    /// `LEFT JOIN`.
    Left,
}
188
189#[derive(Debug, Clone, PartialEq)]
190pub struct JoinClause {
191 pub join_type: JoinType,
192 pub table: String,
193 pub alias: Option<String>,
194 pub condition: WhereClause,
195}
196
/// Aggregate function selector.
///
/// The enum has no payload, so it is `Copy` and usable as a hash key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AggFunc {
    /// `COUNT`
    Count,
    /// `SUM`
    Sum,
    /// `AVG`
    Avg,
    /// `MIN`
    Min,
    /// `MAX`
    Max,
}
205
206#[derive(Debug, Clone, PartialEq)]
207pub enum WindowFunc {
208 RowNumber,
209 Rank,
210 DenseRank,
211 Lag,
212 Lead,
213 Agg(AggFunc),
214}
215
216#[derive(Debug, Clone, PartialEq)]
217pub struct WindowSpec {
218 pub partition_by: Vec<String>,
219 pub order_by: Vec<OrderSpec>,
220}
221
222#[derive(Debug, Clone, PartialEq)]
223pub enum SelectExpr {
224 Column(String),
225 Aggregate { func: AggFunc, arg: String, arg_expr: Option<Expr>, alias: Option<String> },
226 Expr { expr: Expr, alias: Option<String> },
227}
228
229impl SelectExpr {
230 pub fn output_name(&self) -> String {
231 match self {
232 SelectExpr::Column(name) => name.clone(),
233 SelectExpr::Aggregate { func, arg, alias, .. } => {
234 if let Some(a) = alias {
235 a.clone()
236 } else {
237 let func_name = match func {
238 AggFunc::Count => "COUNT",
239 AggFunc::Sum => "SUM",
240 AggFunc::Avg => "AVG",
241 AggFunc::Min => "MIN",
242 AggFunc::Max => "MAX",
243 };
244 format!("{}({})", func_name, arg)
245 }
246 }
247 SelectExpr::Expr { expr, alias } => {
248 alias.clone().unwrap_or_else(|| expr.display_name())
249 }
250 }
251 }
252
253 pub fn is_aggregate(&self) -> bool {
254 match self {
255 SelectExpr::Aggregate { .. } => true,
256 SelectExpr::Expr { expr, .. } => expr.contains_aggregate(),
257 _ => false,
258 }
259 }
260}
261
262#[derive(Debug, Clone, PartialEq)]
263pub struct CteClause {
264 pub name: String,
265 pub query: Box<SelectQuery>,
266}
267
268#[derive(Debug, Clone, PartialEq)]
269pub struct SelectQuery {
270 pub columns: ColumnList,
271 pub table: String,
272 pub table_alias: Option<String>,
273 pub subquery: Option<Box<SelectQuery>>,
274 pub joins: Vec<JoinClause>,
275 pub where_clause: Option<WhereClause>,
276 pub group_by: Option<Vec<String>>,
277 pub having: Option<WhereClause>,
278 pub order_by: Option<Vec<OrderSpec>>,
279 pub limit: Option<i64>,
280 pub ctes: Vec<CteClause>,
281}
282
283#[derive(Debug, Clone, PartialEq)]
284pub enum ColumnList {
285 All,
286 Named(Vec<SelectExpr>),
287}
288
289#[derive(Debug, Clone, PartialEq)]
290pub struct InsertQuery {
291 pub table: String,
292 pub columns: Vec<String>,
293 pub values: Vec<SqlValue>,
294}
295
296#[derive(Debug, Clone, PartialEq)]
297pub struct UpdateQuery {
298 pub table: String,
299 pub assignments: Vec<(String, SqlValue)>,
300 pub where_clause: Option<WhereClause>,
301}
302
/// Mode attached to a [`DeleteQuery`].
///
/// The enum has no payload, so it is `Copy` and usable as a hash key.
// NOTE(review): the exact semantics of each mode are defined by the executor, not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DeleteMode {
    /// No explicit mode was requested.
    Default,
    /// Cascade mode.
    Cascade,
    /// Restrict mode.
    Restrict,
}
309
310#[derive(Debug, Clone, PartialEq)]
311pub struct DeleteQuery {
312 pub table: String,
313 pub where_clause: Option<WhereClause>,
314 pub mode: DeleteMode,
315}
316
/// A parsed "rename field" alteration.
///
/// All fields are strings, so the type also implements `Eq` and `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterRenameFieldQuery {
    /// Name of the table being altered.
    pub table: String,
    /// Current name of the field.
    pub old_name: String,
    /// Name the field is renamed to.
    pub new_name: String,
}
323
/// A parsed "drop field" alteration.
///
/// All fields are strings, so the type also implements `Eq` and `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterDropFieldQuery {
    /// Name of the table being altered.
    pub table: String,
    /// Name of the field to drop.
    pub field_name: String,
}
329
/// A parsed "merge fields" alteration.
///
/// All fields are strings or lists of strings, so the type also implements
/// `Eq` and `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct AlterMergeFieldsQuery {
    /// Name of the table being altered.
    pub table: String,
    /// Names of the fields being merged.
    pub sources: Vec<String>,
    /// Name of the field that receives the merged result.
    pub into: String,
}
336
337#[derive(Debug, Clone, PartialEq)]
338pub struct CreateViewQuery {
339 pub view_name: String,
340 pub columns: Option<Vec<String>>,
341 pub query: Box<SelectQuery>,
342}
343
/// A parsed `DROP VIEW` statement.
///
/// The only field is a string, so the type also implements `Eq` and `Hash`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct DropViewQuery {
    /// Name of the view to drop.
    pub view_name: String,
}
348
349#[derive(Debug, Clone, PartialEq)]
350pub enum Statement {
351 Select(SelectQuery),
352 Insert(InsertQuery),
353 Update(UpdateQuery),
354 Delete(DeleteQuery),
355 AlterRename(AlterRenameFieldQuery),
356 AlterDrop(AlterDropFieldQuery),
357 AlterMerge(AlterMergeFieldsQuery),
358 CreateView(CreateViewQuery),
359 DropView(DropViewQuery),
360}
361
362impl Statement {
363 pub fn table_name(&self) -> &str {
364 match self {
365 Statement::Select(q) => &q.table,
366 Statement::Insert(q) => &q.table,
367 Statement::Update(q) => &q.table,
368 Statement::Delete(q) => &q.table,
369 Statement::AlterRename(q) => &q.table,
370 Statement::AlterDrop(q) => &q.table,
371 Statement::AlterMerge(q) => &q.table,
372 Statement::CreateView(q) => &q.view_name,
373 Statement::DropView(q) => &q.view_name,
374 }
375 }
376}
377
378pub fn where_clause_to_sql(clause: &WhereClause) -> String {
379 match clause {
380 WhereClause::BoolOp(bop) => {
381 let left = where_clause_to_sql(&bop.left);
382 let right = where_clause_to_sql(&bop.right);
383 let op = match bop.op {
384 BoolOpKind::And => "AND",
385 BoolOpKind::Or => "OR",
386 };
387 format!("{} {} {}", left, op, right)
388 }
389 WhereClause::Comparison(cmp) => {
390 let op_str = match cmp.op {
391 CmpOp::Eq => "=",
392 CmpOp::Ne => "!=",
393 CmpOp::Lt => "<",
394 CmpOp::Gt => ">",
395 CmpOp::Le => "<=",
396 CmpOp::Ge => ">=",
397 CmpOp::Like => "LIKE",
398 CmpOp::NotLike => "NOT LIKE",
399 CmpOp::In => "IN",
400 CmpOp::IsNull => "IS NULL",
401 CmpOp::IsNotNull => "IS NOT NULL",
402 };
403 if matches!(cmp.op, CmpOp::IsNull | CmpOp::IsNotNull) {
404 if let Some(ref expr) = cmp.left_expr {
405 return format!("{} {}", expr.display_name(), op_str);
406 }
407 return format!("{} {}", cmp.column, op_str);
408 }
409 if let (Some(ref left), Some(ref right)) = (&cmp.left_expr, &cmp.right_expr) {
410 return format!("{} {} {}", left.display_name(), op_str, right.display_name());
411 }
412 match &cmp.value {
413 Some(SqlValue::String(s)) => format!("{} {} '{}'", cmp.column, op_str, s),
414 Some(SqlValue::Int(n)) => format!("{} {} {}", cmp.column, op_str, n),
415 Some(SqlValue::Float(f)) => format!("{} {} {}", cmp.column, op_str, f),
416 Some(SqlValue::Null) => format!("{} {} NULL", cmp.column, op_str),
417 Some(SqlValue::List(items)) => {
418 let vals: Vec<String> = items.iter().map(|v| match v {
419 SqlValue::String(s) => format!("'{}'", s),
420 SqlValue::Int(n) => n.to_string(),
421 SqlValue::Float(f) => f.to_string(),
422 _ => "NULL".to_string(),
423 }).collect();
424 format!("{} {} ({})", cmp.column, op_str, vals.join(", "))
425 }
426 None => format!("{} {}", cmp.column, op_str),
427 }
428 }
429 }
430}