/// Top-level SQL statement node.
///
/// Every variant is rendered by `statement_to_sql` and rewritten by
/// `normalize_statement`.
#[derive(Debug, Clone)]
pub enum Statement {
    /// Statement prefixed by a `with` / `with recursive` clause.
    With(WithStatement),
    /// Plain `select` query.
    Select(SelectStatement),
    /// Two statements combined by a set operator.
    SetOp {
        /// Left-hand operand.
        left: Box<Statement>,
        /// Operator joining both operands.
        op: SetOperator,
        /// Right-hand operand.
        right: Box<Statement>,
    },
    /// `explain` wrapped around another statement.
    Explain(Box<Statement>),
    /// `create table` statement.
    CreateTable(CreateTableStatement),
    /// `drop table` statement.
    DropTable(DropTableStatement),
    /// `truncate table` statement.
    Truncate(TruncateStatement),
    /// `analyze` statement.
    Analyze(AnalyzeStatement),
    /// `insert into` statement.
    Insert(InsertStatement),
    /// `update ... set` statement.
    Update(UpdateStatement),
    /// `delete from` statement.
    Delete(DeleteStatement),
}
19
/// `create table` statement.
#[derive(Debug, Clone)]
pub struct CreateTableStatement {
    /// Name of the table to create.
    pub name: String,
    /// When set, the statement is rendered with `if not exists`.
    pub if_not_exists: bool,
    /// Column definitions; the parenthesized list is omitted from the rendered
    /// SQL when this is empty.
    pub columns: Vec<ColumnDef>,
}
26
/// Single column of a `create table` statement, rendered as `name data_type`.
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Type name, emitted verbatim when rendering.
    pub data_type: String,
}
32
/// `drop table` statement.
#[derive(Debug, Clone)]
pub struct DropTableStatement {
    /// Name of the table to drop.
    pub name: String,
    /// When set, the statement is rendered with `if exists`.
    pub if_exists: bool,
}
38
/// `truncate table` statement.
#[derive(Debug, Clone)]
pub struct TruncateStatement {
    /// Name of the table to truncate.
    pub table: String,
}
43
/// `analyze` statement.
#[derive(Debug, Clone)]
pub struct AnalyzeStatement {
    /// Name of the table to analyze.
    pub table: String,
}
48
/// `insert into` statement.
#[derive(Debug, Clone)]
pub struct InsertStatement {
    /// Target table name.
    pub table: String,
    /// Explicit target column list; omitted from the rendered SQL when empty.
    pub columns: Vec<String>,
    /// Where the inserted rows come from.
    pub source: InsertSource,
    /// Items of the `returning` clause; the clause is omitted when empty.
    pub returning: Vec<SelectItem>,
}
56
/// Row source of an `insert` statement.
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// Literal rows, rendered as `values (..), (..)`; one inner vector per row.
    Values(Vec<Vec<Expr>>),
    /// Rows produced by a nested statement.
    Query(Box<Statement>),
    /// Rendered as `default values`.
    DefaultValues,
}
63
/// `update ... set` statement.
#[derive(Debug, Clone)]
pub struct UpdateStatement {
    /// Target table name.
    pub table: String,
    /// `column = value` pairs of the `set` clause.
    pub assignments: Vec<Assignment>,
    /// Optional `where` predicate.
    pub selection: Option<Expr>,
    /// Items of the `returning` clause; the clause is omitted when empty.
    pub returning: Vec<SelectItem>,
}
71
/// One `column = value` entry of an `update` statement's `set` clause.
#[derive(Debug, Clone)]
pub struct Assignment {
    /// Column being assigned.
    pub column: String,
    /// Expression assigned to the column.
    pub value: Expr,
}
77
/// `delete from` statement.
#[derive(Debug, Clone)]
pub struct DeleteStatement {
    /// Target table name.
    pub table: String,
    /// Optional `where` predicate.
    pub selection: Option<Expr>,
    /// Items of the `returning` clause; the clause is omitted when empty.
    pub returning: Vec<SelectItem>,
}
84
/// `select` query.
#[derive(Debug, Clone)]
pub struct SelectStatement {
    /// When set, the query is rendered with `distinct`.
    pub distinct: bool,
    /// Expressions of `distinct on (...)`; only rendered when `distinct` is
    /// also set and the list is non-empty.
    pub distinct_on: Vec<Expr>,
    /// Items of the select list.
    pub projection: Vec<SelectItem>,
    /// Optional `from` clause.
    pub from: Option<TableRef>,
    /// Optional `where` predicate.
    pub selection: Option<Expr>,
    /// `group by` expressions; the clause is omitted when empty.
    pub group_by: Vec<Expr>,
    /// Optional `having` predicate.
    pub having: Option<Expr>,
    /// `order by` items; the clause is omitted when empty.
    pub order_by: Vec<OrderByExpr>,
    /// Optional `limit` value.
    pub limit: Option<u64>,
    /// Optional `offset` value.
    pub offset: Option<u64>,
}
98
/// Statement introduced by a `with` clause.
#[derive(Debug, Clone)]
pub struct WithStatement {
    /// Common table expressions, in declaration order.
    pub ctes: Vec<Cte>,
    /// When set, the clause is rendered as `with recursive`.
    pub recursive: bool,
    /// Statement that follows the CTE list.
    pub statement: Box<Statement>,
}
105
/// Single common table expression, rendered as `name (columns) as (query)`.
#[derive(Debug, Clone)]
pub struct Cte {
    /// Name of the CTE.
    pub name: String,
    /// Optional column name list; omitted from the rendered SQL when empty.
    pub columns: Vec<String>,
    /// Statement defining the CTE.
    pub query: Box<Statement>,
}
112
/// Set operator that combines the two operands of a set-operation statement.
///
/// Derives `PartialEq`, `Eq` and `Hash` like the other fieldless operator
/// enums in this file, so operators can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SetOperator {
    /// Rendered as `union`.
    Union,
    /// Rendered as `union all`.
    UnionAll,
    /// Rendered as `intersect`.
    Intersect,
    /// Rendered as `intersect all`.
    IntersectAll,
    /// Rendered as `except`.
    Except,
    /// Rendered as `except all`.
    ExceptAll,
}
122
/// Table reference of a `from` clause, together with the joins attached to it.
#[derive(Debug, Clone)]
pub struct TableRef {
    /// The referenced table or derived query.
    pub factor: TableFactor,
    /// Optional alias, rendered as `as alias`.
    pub alias: Option<String>,
    /// Column aliases following the alias; only rendered when `alias` is set.
    pub column_aliases: Vec<String>,
    /// Joins applied to this reference, in order.
    pub joins: Vec<Join>,
}
130
/// Source of rows behind a table reference.
#[derive(Debug, Clone)]
pub enum TableFactor {
    /// Named table, rendered as the bare name.
    Table { name: String },
    /// Nested statement, rendered inside parentheses.
    Derived { query: Box<Statement> },
}
136
/// Single join, rendered as `<join type> <right> on <on>`.
#[derive(Debug, Clone)]
pub struct Join {
    /// Kind of join.
    pub join_type: JoinType,
    /// Joined table reference.
    pub right: TableRef,
    /// Join predicate.
    pub on: Expr,
}
143
/// Kind of join between two table references.
///
/// Derives `PartialEq`, `Eq` and `Hash` like the other fieldless operator
/// enums in this file, so join kinds can be compared directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum JoinType {
    /// Rendered as `join`.
    Inner,
    /// Rendered as `left join`.
    Left,
    /// Rendered as `right join`.
    Right,
    /// Rendered as `full join`.
    Full,
}
151
/// Item of a select list or `returning` clause.
#[derive(Debug, Clone)]
pub struct SelectItem {
    /// Expression producing the value.
    pub expr: Expr,
    /// Optional alias attached to the expression.
    pub alias: Option<String>,
}
157
/// Single sort key of an `order by` clause.
#[derive(Debug, Clone)]
pub struct OrderByExpr {
    /// Expression to sort by.
    pub expr: Expr,
    /// `true` renders `asc`, `false` renders `desc`.
    pub asc: bool,
    /// `Some(true)` renders `nulls first`, `Some(false)` renders `nulls last`,
    /// `None` renders nothing.
    pub nulls_first: Option<bool>,
}
164
/// SQL expression tree.
#[derive(Debug, Clone)]
pub enum Expr {
    /// Identifier, rendered verbatim.
    Identifier(String),
    /// Literal value.
    Literal(Literal),
    /// Binary operation `left op right`.
    BinaryOp {
        /// Left operand.
        left: Box<Expr>,
        /// Operator.
        op: BinaryOperator,
        /// Right operand.
        right: Box<Expr>,
    },
    /// `expr is null`, or `expr is not null` when `negated` is set.
    IsNull {
        /// Expression being tested.
        expr: Box<Expr>,
        /// Selects the `is not null` form.
        negated: bool,
    },
    /// Prefix operation (`not expr` or `-expr`).
    UnaryOp {
        /// Operator.
        op: UnaryOperator,
        /// Operand.
        expr: Box<Expr>,
    },
    /// Function call `name(args...)`.
    FunctionCall {
        /// Function name, rendered verbatim.
        name: String,
        /// Arguments in call order.
        args: Vec<Expr>,
    },
    /// Window function, rendered as `function over (spec)`.
    WindowFunction {
        /// Expression placed before `over`.
        function: Box<Expr>,
        /// Partitioning and ordering of the window.
        spec: WindowSpec,
    },
    /// Parenthesized subquery used as an expression.
    Subquery(Box<SelectStatement>),
    /// `exists (subquery)`.
    Exists(Box<SelectStatement>),
    /// `expr in (subquery)`.
    InSubquery {
        /// Expression on the left of `in`.
        expr: Box<Expr>,
        /// Subquery on the right of `in`.
        subquery: Box<SelectStatement>,
    },
    /// `case [operand] when .. then .. [else ..] end`.
    Case {
        /// Optional expression following `case`.
        operand: Option<Box<Expr>>,
        /// `(when, then)` pairs in source order.
        when_then: Vec<(Expr, Expr)>,
        /// Optional `else` expression.
        else_expr: Option<Box<Expr>>,
    },
    /// `*`.
    Wildcard,
}
203
/// Literal value inside an expression.
#[derive(Debug, Clone)]
pub enum Literal {
    /// String literal, rendered inside single quotes.
    String(String),
    /// Numeric literal stored as `f64`.
    Number(f64),
    /// Boolean literal, rendered as `true` / `false`.
    Bool(bool),
}
210
/// Operator of a binary expression.
///
/// Equality on this fieldless enum is total, so `Eq` and `Hash` are derived
/// alongside `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BinaryOperator {
    /// Rendered as `=`.
    Eq,
    /// Rendered as `!=`.
    NotEq,
    /// Rendered as `<`.
    Lt,
    /// Rendered as `<=`.
    LtEq,
    /// Rendered as `>`.
    Gt,
    /// Rendered as `>=`.
    GtEq,
    /// Rendered as `and`.
    And,
    /// Rendered as `or`.
    Or,
    /// Rendered as `+`.
    Add,
    /// Rendered as `-`.
    Sub,
    /// Rendered as `*`.
    Mul,
    /// Rendered as `/`.
    Div,
}
226
/// Operator of a prefix expression.
///
/// Equality on this fieldless enum is total, so `Eq` and `Hash` are derived
/// alongside `PartialEq`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UnaryOperator {
    /// Logical negation, rendered as `not`.
    Not,
    /// Arithmetic negation, rendered as `-`.
    Neg,
}
232
/// Contents of a window function's `over (...)` clause.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    /// `partition by` expressions; the clause is omitted when empty.
    pub partition_by: Vec<Expr>,
    /// `order by` items; the clause is omitted when empty.
    pub order_by: Vec<OrderByExpr>,
}
238
239impl Expr {
240 pub fn to_sql(&self) -> String {
241 match self {
242 Expr::Identifier(name) => name.clone(),
243 Expr::Literal(Literal::String(value)) => format!("'{}'", value),
244 Expr::Literal(Literal::Number(value)) => value.to_string(),
245 Expr::Literal(Literal::Bool(value)) => {
246 if *value { "true".to_string() } else { "false".to_string() }
247 }
248 Expr::BinaryOp { left, op, right } => {
249 let op_str = match op {
250 BinaryOperator::Eq => "=",
251 BinaryOperator::NotEq => "!=",
252 BinaryOperator::Lt => "<",
253 BinaryOperator::LtEq => "<=",
254 BinaryOperator::Gt => ">",
255 BinaryOperator::GtEq => ">=",
256 BinaryOperator::And => "and",
257 BinaryOperator::Or => "or",
258 BinaryOperator::Add => "+",
259 BinaryOperator::Sub => "-",
260 BinaryOperator::Mul => "*",
261 BinaryOperator::Div => "/",
262 };
263 format!("{} {} {}", left.to_sql(), op_str, right.to_sql())
264 }
265 Expr::IsNull { expr, negated } => {
266 if *negated {
267 format!("{} is not null", expr.to_sql())
268 } else {
269 format!("{} is null", expr.to_sql())
270 }
271 }
272 Expr::UnaryOp { op, expr } => match op {
273 UnaryOperator::Not => format!("not {}", expr.to_sql()),
274 UnaryOperator::Neg => format!("-{}", expr.to_sql()),
275 },
276 Expr::FunctionCall { name, args } => {
277 let args_sql = args.iter().map(|arg| arg.to_sql()).collect::<Vec<_>>();
278 format!("{}({})", name, args_sql.join(", "))
279 }
280 Expr::WindowFunction { function, spec } => {
281 let mut clauses = Vec::new();
282 if !spec.partition_by.is_empty() {
283 let partition = spec
284 .partition_by
285 .iter()
286 .map(|expr| expr.to_sql())
287 .collect::<Vec<_>>()
288 .join(", ");
289 clauses.push(format!("partition by {partition}"));
290 }
291 if !spec.order_by.is_empty() {
292 let order = spec
293 .order_by
294 .iter()
295 .map(|item| {
296 let dir = if item.asc { "asc" } else { "desc" };
297 let mut rendered = format!("{} {dir}", item.expr.to_sql());
298 if let Some(nulls_first) = item.nulls_first {
299 if nulls_first {
300 rendered.push_str(" nulls first");
301 } else {
302 rendered.push_str(" nulls last");
303 }
304 }
305 rendered
306 })
307 .collect::<Vec<_>>()
308 .join(", ");
309 clauses.push(format!("order by {order}"));
310 }
311 format!("{} over ({})", function.to_sql(), clauses.join(" "))
312 }
313 Expr::Exists(select) => format!("exists ({})", select_to_sql(select)),
314 Expr::InSubquery { expr, subquery } => {
315 format!("{} in ({})", expr.to_sql(), select_to_sql(subquery))
316 }
317 Expr::Subquery(select) => format!("({})", select_to_sql(select)),
318 Expr::Case {
319 operand,
320 when_then,
321 else_expr,
322 } => {
323 let mut output = String::from("case");
324 if let Some(expr) = operand {
325 output.push(' ');
326 output.push_str(&expr.to_sql());
327 }
328 for (when_expr, then_expr) in when_then {
329 output.push_str(" when ");
330 output.push_str(&when_expr.to_sql());
331 output.push_str(" then ");
332 output.push_str(&then_expr.to_sql());
333 }
334 if let Some(expr) = else_expr {
335 output.push_str(" else ");
336 output.push_str(&expr.to_sql());
337 }
338 output.push_str(" end");
339 output
340 }
341 Expr::Wildcard => "*".to_string(),
342 }
343 }
344
345 pub fn structural_eq(&self, other: &Expr) -> bool {
346 const FLOAT_EPSILON: f64 = 1e-9;
347 match (self, other) {
348 (Expr::Identifier(left), Expr::Identifier(right)) => left == right,
349 (Expr::Literal(left), Expr::Literal(right)) => match (left, right) {
350 (Literal::String(left), Literal::String(right)) => left == right,
351 (Literal::Number(left), Literal::Number(right)) => {
352 if left.is_nan() || right.is_nan() {
353 false
354 } else if left.is_infinite() || right.is_infinite() {
355 left == right
356 } else {
357 (left - right).abs() <= FLOAT_EPSILON
358 }
359 }
360 (Literal::Bool(left), Literal::Bool(right)) => left == right,
361 _ => false,
362 },
363 (Expr::UnaryOp { op: left_op, expr: left }, Expr::UnaryOp { op: right_op, expr: right }) => {
364 left_op == right_op && left.structural_eq(right)
365 }
366 (
367 Expr::BinaryOp { left: left_lhs, op: left_op, right: left_rhs },
368 Expr::BinaryOp { left: right_lhs, op: right_op, right: right_rhs },
369 ) => left_op == right_op
370 && left_lhs.structural_eq(right_lhs)
371 && left_rhs.structural_eq(right_rhs),
372 (
373 Expr::IsNull { expr: left, negated: left_negated },
374 Expr::IsNull { expr: right, negated: right_negated },
375 ) => left_negated == right_negated && left.structural_eq(right),
376 (
377 Expr::FunctionCall { name: left_name, args: left_args },
378 Expr::FunctionCall { name: right_name, args: right_args },
379 ) => left_name == right_name
380 && left_args.len() == right_args.len()
381 && left_args
382 .iter()
383 .zip(right_args.iter())
384 .all(|(left, right)| left.structural_eq(right)),
385 (
386 Expr::WindowFunction { function: left_func, spec: left_spec },
387 Expr::WindowFunction { function: right_func, spec: right_spec },
388 ) => left_func.structural_eq(right_func)
389 && left_spec.partition_by.len() == right_spec.partition_by.len()
390 && left_spec
391 .partition_by
392 .iter()
393 .zip(right_spec.partition_by.iter())
394 .all(|(left, right)| left.structural_eq(right))
395 && left_spec.order_by.len() == right_spec.order_by.len()
396 && left_spec
397 .order_by
398 .iter()
399 .zip(right_spec.order_by.iter())
400 .all(|(left, right)| {
401 left.asc == right.asc
402 && left.nulls_first == right.nulls_first
403 && left.expr.structural_eq(&right.expr)
404 }),
405 (Expr::Subquery(left), Expr::Subquery(right)) => select_to_sql(left) == select_to_sql(right),
406 (Expr::Exists(left), Expr::Exists(right)) => select_to_sql(left) == select_to_sql(right),
407 (
408 Expr::InSubquery { expr: left_expr, subquery: left_subquery },
409 Expr::InSubquery { expr: right_expr, subquery: right_subquery },
410 ) => left_expr.structural_eq(right_expr)
411 && select_to_sql(left_subquery) == select_to_sql(right_subquery),
412 (
413 Expr::Case { operand: left_operand, when_then: left_when_then, else_expr: left_else },
414 Expr::Case { operand: right_operand, when_then: right_when_then, else_expr: right_else },
415 ) => left_operand
416 .as_ref()
417 .zip(right_operand.as_ref())
418 .map(|(left, right)| left.structural_eq(right))
419 .unwrap_or(left_operand.is_none() && right_operand.is_none())
420 && left_when_then.len() == right_when_then.len()
421 && left_when_then.iter().zip(right_when_then.iter()).all(|(left, right)| {
422 left.0.structural_eq(&right.0) && left.1.structural_eq(&right.1)
423 })
424 && left_else
425 .as_ref()
426 .zip(right_else.as_ref())
427 .map(|(left, right)| left.structural_eq(right))
428 .unwrap_or(left_else.is_none() && right_else.is_none()),
429 (Expr::Wildcard, Expr::Wildcard) => true,
430 _ => false,
431 }
432 }
433
434 pub fn normalize(&self) -> Expr {
435 let normalized = match self {
436 Expr::BinaryOp { left, op, right } => {
437 let left_norm = left.normalize();
438 let right_norm = right.normalize();
439 if matches!(op, BinaryOperator::And | BinaryOperator::Or) {
440 if left_norm.to_sql() > right_norm.to_sql() {
441 return Expr::BinaryOp {
442 left: Box::new(right_norm),
443 op: *op,
444 right: Box::new(left_norm),
445 };
446 }
447 }
448 Expr::BinaryOp {
449 left: Box::new(left_norm),
450 op: *op,
451 right: Box::new(right_norm),
452 }
453 }
454 Expr::IsNull { expr, negated } => Expr::IsNull {
455 expr: Box::new(expr.normalize()),
456 negated: *negated,
457 },
458 Expr::UnaryOp { op, expr } => Expr::UnaryOp {
459 op: *op,
460 expr: Box::new(expr.normalize()),
461 },
462 Expr::FunctionCall { name, args } => Expr::FunctionCall {
463 name: name.clone(),
464 args: args.iter().map(|arg| arg.normalize()).collect(),
465 },
466 Expr::Case {
467 operand,
468 when_then,
469 else_expr,
470 } => Expr::Case {
471 operand: operand.as_ref().map(|expr| Box::new(expr.normalize())),
472 when_then: when_then
473 .iter()
474 .map(|(when_expr, then_expr)| (when_expr.normalize(), then_expr.normalize()))
475 .collect(),
476 else_expr: else_expr.as_ref().map(|expr| Box::new(expr.normalize())),
477 },
478 Expr::WindowFunction { function, spec } => Expr::WindowFunction {
479 function: Box::new(function.normalize()),
480 spec: spec.clone(),
481 },
482 Expr::Exists(select) => Expr::Exists(Box::new(normalize_select_inner(select))),
483 Expr::InSubquery { expr, subquery } => Expr::InSubquery {
484 expr: Box::new(expr.normalize()),
485 subquery: Box::new(normalize_select_inner(subquery)),
486 },
487 Expr::Subquery(select) => Expr::Subquery(Box::new(normalize_select_inner(select))),
488 other => other.clone(),
489 };
490 rewrite_strong_expr(normalized)
491 }
492}
493
494fn rewrite_strong_expr(expr: Expr) -> Expr {
495 match expr {
496 Expr::UnaryOp {
497 op: UnaryOperator::Not,
498 expr,
499 } => match *expr {
500 Expr::Literal(Literal::Bool(value)) => Expr::Literal(Literal::Bool(!value)),
501 Expr::UnaryOp {
502 op: UnaryOperator::Not,
503 expr,
504 } => *expr,
505 Expr::IsNull { expr, negated } => Expr::IsNull {
506 expr,
507 negated: !negated,
508 },
509 Expr::BinaryOp { left, op, right } => match op {
510 BinaryOperator::Eq => Expr::BinaryOp {
511 left,
512 op: BinaryOperator::NotEq,
513 right,
514 },
515 BinaryOperator::NotEq => Expr::BinaryOp {
516 left,
517 op: BinaryOperator::Eq,
518 right,
519 },
520 BinaryOperator::Lt => Expr::BinaryOp {
521 left,
522 op: BinaryOperator::GtEq,
523 right,
524 },
525 BinaryOperator::LtEq => Expr::BinaryOp {
526 left,
527 op: BinaryOperator::Gt,
528 right,
529 },
530 BinaryOperator::Gt => Expr::BinaryOp {
531 left,
532 op: BinaryOperator::LtEq,
533 right,
534 },
535 BinaryOperator::GtEq => Expr::BinaryOp {
536 left,
537 op: BinaryOperator::Lt,
538 right,
539 },
540 BinaryOperator::And => Expr::BinaryOp {
541 left: Box::new(Expr::UnaryOp {
542 op: UnaryOperator::Not,
543 expr: left,
544 }),
545 op: BinaryOperator::Or,
546 right: Box::new(Expr::UnaryOp {
547 op: UnaryOperator::Not,
548 expr: right,
549 }),
550 },
551 BinaryOperator::Or => Expr::BinaryOp {
552 left: Box::new(Expr::UnaryOp {
553 op: UnaryOperator::Not,
554 expr: left,
555 }),
556 op: BinaryOperator::And,
557 right: Box::new(Expr::UnaryOp {
558 op: UnaryOperator::Not,
559 expr: right,
560 }),
561 },
562 _ => Expr::UnaryOp {
563 op: UnaryOperator::Not,
564 expr: Box::new(Expr::BinaryOp { left, op, right }),
565 },
566 },
567 other => Expr::UnaryOp {
568 op: UnaryOperator::Not,
569 expr: Box::new(other),
570 },
571 },
572 Expr::UnaryOp {
573 op: UnaryOperator::Neg,
574 expr,
575 } => match *expr {
576 Expr::Literal(Literal::Number(value)) => {
577 Expr::Literal(Literal::Number(-value))
578 }
579 other => Expr::UnaryOp {
580 op: UnaryOperator::Neg,
581 expr: Box::new(other),
582 },
583 },
584 Expr::BinaryOp { left, op, right } => {
585 if matches!(op, BinaryOperator::And | BinaryOperator::Or) && left.structural_eq(&right) {
586 return *left;
587 }
588 let same_expr = left.structural_eq(&right);
589 if same_expr {
590 return match op {
591 BinaryOperator::Eq | BinaryOperator::LtEq | BinaryOperator::GtEq => {
592 Expr::Literal(Literal::Bool(true))
593 }
594 BinaryOperator::NotEq | BinaryOperator::Lt | BinaryOperator::Gt => {
595 Expr::Literal(Literal::Bool(false))
596 }
597 _ => Expr::BinaryOp { left, op, right },
598 };
599 }
600 match (*left, op, *right) {
601 (Expr::Literal(Literal::Bool(a)), BinaryOperator::And, Expr::Literal(Literal::Bool(b))) => {
602 Expr::Literal(Literal::Bool(a && b))
603 }
604 (Expr::Literal(Literal::Bool(a)), BinaryOperator::Or, Expr::Literal(Literal::Bool(b))) => {
605 Expr::Literal(Literal::Bool(a || b))
606 }
607 (Expr::Literal(Literal::Bool(a)), BinaryOperator::And, other) => {
608 if a { other } else { Expr::Literal(Literal::Bool(false)) }
609 }
610 (other, BinaryOperator::And, Expr::Literal(Literal::Bool(b))) => {
611 if b { other } else { Expr::Literal(Literal::Bool(false)) }
612 }
613 (Expr::Literal(Literal::Bool(a)), BinaryOperator::Or, other) => {
614 if a { Expr::Literal(Literal::Bool(true)) } else { other }
615 }
616 (other, BinaryOperator::Or, Expr::Literal(Literal::Bool(b))) => {
617 if b { Expr::Literal(Literal::Bool(true)) } else { other }
618 }
619 (Expr::Literal(Literal::Number(a)), BinaryOperator::Eq, Expr::Literal(Literal::Number(b))) => {
620 Expr::Literal(Literal::Bool(a == b))
621 }
622 (Expr::Literal(Literal::Number(a)), BinaryOperator::NotEq, Expr::Literal(Literal::Number(b))) => {
623 Expr::Literal(Literal::Bool(a != b))
624 }
625 (Expr::Literal(Literal::Number(a)), BinaryOperator::Lt, Expr::Literal(Literal::Number(b))) => {
626 Expr::Literal(Literal::Bool(a < b))
627 }
628 (Expr::Literal(Literal::Number(a)), BinaryOperator::LtEq, Expr::Literal(Literal::Number(b))) => {
629 Expr::Literal(Literal::Bool(a <= b))
630 }
631 (Expr::Literal(Literal::Number(a)), BinaryOperator::Gt, Expr::Literal(Literal::Number(b))) => {
632 Expr::Literal(Literal::Bool(a > b))
633 }
634 (Expr::Literal(Literal::Number(a)), BinaryOperator::GtEq, Expr::Literal(Literal::Number(b))) => {
635 Expr::Literal(Literal::Bool(a >= b))
636 }
637 (Expr::Literal(Literal::String(a)), BinaryOperator::Eq, Expr::Literal(Literal::String(b))) => {
638 Expr::Literal(Literal::Bool(a == b))
639 }
640 (Expr::Literal(Literal::String(a)), BinaryOperator::NotEq, Expr::Literal(Literal::String(b))) => {
641 Expr::Literal(Literal::Bool(a != b))
642 }
643 (Expr::Literal(Literal::Bool(a)), BinaryOperator::Eq, Expr::Literal(Literal::Bool(b))) => {
644 Expr::Literal(Literal::Bool(a == b))
645 }
646 (Expr::Literal(Literal::Bool(a)), BinaryOperator::NotEq, Expr::Literal(Literal::Bool(b))) => {
647 Expr::Literal(Literal::Bool(a != b))
648 }
649 (Expr::Literal(Literal::Number(a)), BinaryOperator::Add, Expr::Literal(Literal::Number(b))) => {
650 Expr::Literal(Literal::Number(a + b))
651 }
652 (Expr::Literal(Literal::Number(a)), BinaryOperator::Sub, Expr::Literal(Literal::Number(b))) => {
653 Expr::Literal(Literal::Number(a - b))
654 }
655 (Expr::Literal(Literal::Number(a)), BinaryOperator::Mul, Expr::Literal(Literal::Number(b))) => {
656 Expr::Literal(Literal::Number(a * b))
657 }
658 (Expr::Literal(Literal::Number(a)), BinaryOperator::Div, Expr::Literal(Literal::Number(b))) => {
659 if b == 0.0 {
660 Expr::BinaryOp {
661 left: Box::new(Expr::Literal(Literal::Number(a))),
662 op: BinaryOperator::Div,
663 right: Box::new(Expr::Literal(Literal::Number(b))),
664 }
665 } else {
666 Expr::Literal(Literal::Number(a / b))
667 }
668 }
669 (left, op, right) => Expr::BinaryOp {
670 left: Box::new(left),
671 op,
672 right: Box::new(right),
673 },
674 }
675 }
676 other => other,
677 }
678}
679
680pub fn normalize_statement(statement: &Statement) -> Statement {
681 match statement {
682 Statement::With(stmt) => Statement::With(WithStatement {
683 ctes: stmt
684 .ctes
685 .iter()
686 .map(|cte| Cte {
687 name: cte.name.clone(),
688 columns: cte.columns.clone(),
689 query: Box::new(normalize_statement(&cte.query)),
690 })
691 .collect(),
692 recursive: stmt.recursive,
693 statement: Box::new(normalize_statement(&stmt.statement)),
694 }),
695 Statement::Select(select) => Statement::Select(normalize_select(select)),
696 Statement::SetOp { left, op, right } => Statement::SetOp {
697 left: Box::new(normalize_statement(left)),
698 op: *op,
699 right: Box::new(normalize_statement(right)),
700 },
701 Statement::Explain(inner) => Statement::Explain(Box::new(normalize_statement(inner))),
702 Statement::CreateTable(stmt) => Statement::CreateTable(stmt.clone()),
703 Statement::DropTable(stmt) => Statement::DropTable(stmt.clone()),
704 Statement::Truncate(stmt) => Statement::Truncate(stmt.clone()),
705 Statement::Analyze(stmt) => Statement::Analyze(stmt.clone()),
706 Statement::Insert(stmt) => Statement::Insert(InsertStatement {
707 table: stmt.table.clone(),
708 columns: stmt.columns.clone(),
709 source: match &stmt.source {
710 InsertSource::Values(values) => InsertSource::Values(
711 values
712 .iter()
713 .map(|row| row.iter().map(|expr| expr.normalize()).collect())
714 .collect(),
715 ),
716 InsertSource::Query(statement) => {
717 InsertSource::Query(Box::new(normalize_statement(statement)))
718 }
719 InsertSource::DefaultValues => InsertSource::DefaultValues,
720 },
721 returning: stmt
722 .returning
723 .iter()
724 .map(|item| SelectItem {
725 expr: item.expr.normalize(),
726 alias: item.alias.clone(),
727 })
728 .collect(),
729 }),
730 Statement::Update(stmt) => Statement::Update(UpdateStatement {
731 table: stmt.table.clone(),
732 assignments: stmt
733 .assignments
734 .iter()
735 .map(|assign| Assignment {
736 column: assign.column.clone(),
737 value: assign.value.normalize(),
738 })
739 .collect(),
740 selection: stmt.selection.as_ref().map(|expr| expr.normalize()),
741 returning: stmt
742 .returning
743 .iter()
744 .map(|item| SelectItem {
745 expr: item.expr.normalize(),
746 alias: item.alias.clone(),
747 })
748 .collect(),
749 }),
750 Statement::Delete(stmt) => Statement::Delete(DeleteStatement {
751 table: stmt.table.clone(),
752 selection: stmt.selection.as_ref().map(|expr| expr.normalize()),
753 returning: stmt
754 .returning
755 .iter()
756 .map(|item| SelectItem {
757 expr: item.expr.normalize(),
758 alias: item.alias.clone(),
759 })
760 .collect(),
761 }),
762 }
763}
764
765fn normalize_select(select: &SelectStatement) -> SelectStatement {
766 SelectStatement {
767 distinct: select.distinct,
768 distinct_on: select.distinct_on.iter().map(|expr| expr.normalize()).collect(),
769 projection: select
770 .projection
771 .iter()
772 .map(|item| SelectItem {
773 expr: item.expr.normalize(),
774 alias: item.alias.clone(),
775 })
776 .collect(),
777 from: select.from.as_ref().map(normalize_table_ref),
778 selection: select.selection.as_ref().map(|expr| expr.normalize()),
779 group_by: select.group_by.iter().map(|expr| expr.normalize()).collect(),
780 having: select.having.as_ref().map(|expr| expr.normalize()),
781 order_by: select
782 .order_by
783 .iter()
784 .map(|order| OrderByExpr {
785 expr: order.expr.normalize(),
786 asc: order.asc,
787 nulls_first: order.nulls_first,
788 })
789 .collect(),
790 limit: select.limit,
791 offset: select.offset,
792 }
793}
794
/// Normalizes a `select` nested inside an expression (subquery, `exists`,
/// `in (...)`); it simply forwards to `normalize_select`.
fn normalize_select_inner(select: &SelectStatement) -> SelectStatement {
    normalize_select(select)
}
798
799fn normalize_table_ref(table: &TableRef) -> TableRef {
800 TableRef {
801 factor: match &table.factor {
802 TableFactor::Table { name } => TableFactor::Table { name: name.clone() },
803 TableFactor::Derived { query } => {
804 TableFactor::Derived { query: Box::new(normalize_statement(query)) }
805 }
806 },
807 alias: table.alias.clone(),
808 column_aliases: table.column_aliases.clone(),
809 joins: table
810 .joins
811 .iter()
812 .map(|join| Join {
813 join_type: join.join_type,
814 right: normalize_table_ref(&join.right),
815 on: join.on.normalize(),
816 })
817 .collect(),
818 }
819}
820
821fn select_to_sql(select: &SelectStatement) -> String {
822 let mut output = String::from("select ");
823 if select.distinct {
824 output.push_str("distinct ");
825 if !select.distinct_on.is_empty() {
826 let distinct_on = select
827 .distinct_on
828 .iter()
829 .map(|expr| expr.to_sql())
830 .collect::<Vec<_>>()
831 .join(", ");
832 output.push_str("on (");
833 output.push_str(&distinct_on);
834 output.push_str(") ");
835 }
836 }
837 let projection = select
838 .projection
839 .iter()
840 .map(|item| item.expr.to_sql())
841 .collect::<Vec<_>>()
842 .join(", ");
843 output.push_str(&projection);
844 if let Some(from) = &select.from {
845 output.push_str(" from ");
846 output.push_str(&table_ref_to_sql(from));
847 }
848 if let Some(selection) = &select.selection {
849 output.push_str(" where ");
850 output.push_str(&selection.to_sql());
851 }
852 if !select.group_by.is_empty() {
853 let group_by = select
854 .group_by
855 .iter()
856 .map(|expr| expr.to_sql())
857 .collect::<Vec<_>>()
858 .join(", ");
859 output.push_str(" group by ");
860 output.push_str(&group_by);
861 }
862 if let Some(having) = &select.having {
863 output.push_str(" having ");
864 output.push_str(&having.to_sql());
865 }
866 if !select.order_by.is_empty() {
867 let order_by = select
868 .order_by
869 .iter()
870 .map(|item| {
871 let mut rendered = item.expr.to_sql();
872 rendered.push(' ');
873 rendered.push_str(if item.asc { "asc" } else { "desc" });
874 if let Some(nulls_first) = item.nulls_first {
875 rendered.push_str(" nulls ");
876 rendered.push_str(if nulls_first { "first" } else { "last" });
877 }
878 rendered
879 })
880 .collect::<Vec<_>>()
881 .join(", ");
882 output.push_str(" order by ");
883 output.push_str(&order_by);
884 }
885 if let Some(limit) = select.limit {
886 output.push_str(" limit ");
887 output.push_str(&limit.to_string());
888 }
889 if let Some(offset) = select.offset {
890 output.push_str(" offset ");
891 output.push_str(&offset.to_string());
892 }
893 output
894}
895
896fn table_ref_to_sql(table: &TableRef) -> String {
897 let mut output = match &table.factor {
898 TableFactor::Table { name } => name.clone(),
899 TableFactor::Derived { query } => format!("({})", statement_to_sql(query)),
900 };
901 if let Some(alias) = &table.alias {
902 output.push_str(" as ");
903 output.push_str(alias);
904 if !table.column_aliases.is_empty() {
905 output.push_str(" (");
906 output.push_str(&table.column_aliases.join(", "));
907 output.push(')');
908 }
909 }
910 for join in &table.joins {
911 let join_type = match join.join_type {
912 JoinType::Inner => "join",
913 JoinType::Left => "left join",
914 JoinType::Right => "right join",
915 JoinType::Full => "full join",
916 };
917 output.push(' ');
918 output.push_str(join_type);
919 output.push(' ');
920 output.push_str(&table_ref_to_sql(&join.right));
921 output.push_str(" on ");
922 output.push_str(&join.on.to_sql());
923 }
924 output
925}
926
927pub fn statement_to_sql(statement: &Statement) -> String {
928 match statement {
929 Statement::Select(select) => select_to_sql(select),
930 Statement::SetOp { left, op, right } => {
931 let left_sql = statement_to_sql(left);
932 let right_sql = statement_to_sql(right);
933 let op_str = match op {
934 SetOperator::Union => "union",
935 SetOperator::UnionAll => "union all",
936 SetOperator::Intersect => "intersect",
937 SetOperator::IntersectAll => "intersect all",
938 SetOperator::Except => "except",
939 SetOperator::ExceptAll => "except all",
940 };
941 format!("{left_sql} {op_str} {right_sql}")
942 }
943 Statement::With(with_stmt) => {
944 let ctes = with_stmt
945 .ctes
946 .iter()
947 .map(|cte| {
948 let mut name = cte.name.clone();
949 if !cte.columns.is_empty() {
950 let cols = cte.columns.join(", ");
951 name.push_str(" (");
952 name.push_str(&cols);
953 name.push(')');
954 }
955 format!("{name} as ({})", statement_to_sql(&cte.query))
956 })
957 .collect::<Vec<_>>()
958 .join(", ");
959 let keyword = if with_stmt.recursive {
960 "with recursive"
961 } else {
962 "with"
963 };
964 format!("{keyword} {ctes} {}", statement_to_sql(&with_stmt.statement))
965 }
966 Statement::Explain(inner) => format!("explain {}", statement_to_sql(inner)),
967 Statement::CreateTable(stmt) => {
968 if stmt.if_not_exists {
969 if stmt.columns.is_empty() {
970 format!("create table if not exists {}", stmt.name)
971 } else {
972 let columns = stmt
973 .columns
974 .iter()
975 .map(|col| format!("{} {}", col.name, col.data_type))
976 .collect::<Vec<_>>()
977 .join(", ");
978 format!("create table if not exists {} ({})", stmt.name, columns)
979 }
980 } else {
981 if stmt.columns.is_empty() {
982 format!("create table {}", stmt.name)
983 } else {
984 let columns = stmt
985 .columns
986 .iter()
987 .map(|col| format!("{} {}", col.name, col.data_type))
988 .collect::<Vec<_>>()
989 .join(", ");
990 format!("create table {} ({})", stmt.name, columns)
991 }
992 }
993 }
994 Statement::DropTable(stmt) => {
995 if stmt.if_exists {
996 format!("drop table if exists {}", stmt.name)
997 } else {
998 format!("drop table {}", stmt.name)
999 }
1000 }
1001 Statement::Truncate(stmt) => format!("truncate table {}", stmt.table),
1002 Statement::Analyze(stmt) => format!("analyze {}", stmt.table),
1003 Statement::Insert(stmt) => {
1004 let mut output = format!("insert into {}", stmt.table);
1005 if !stmt.columns.is_empty() {
1006 output.push_str(" (");
1007 output.push_str(&stmt.columns.join(", "));
1008 output.push(')');
1009 }
1010 match &stmt.source {
1011 InsertSource::DefaultValues => {
1012 output.push_str(" default values");
1013 }
1014 InsertSource::Values(values) => {
1015 let rows = values
1016 .iter()
1017 .map(|row| {
1018 let values = row
1019 .iter()
1020 .map(|expr| expr.to_sql())
1021 .collect::<Vec<_>>()
1022 .join(", ");
1023 format!("({values})")
1024 })
1025 .collect::<Vec<_>>()
1026 .join(", ");
1027 output.push_str(" values ");
1028 output.push_str(&rows);
1029 }
1030 InsertSource::Query(statement) => {
1031 output.push(' ');
1032 output.push_str(&statement_to_sql(statement));
1033 }
1034 }
1035 if !stmt.returning.is_empty() {
1036 let returning = stmt
1037 .returning
1038 .iter()
1039 .map(|item| item.expr.to_sql())
1040 .collect::<Vec<_>>()
1041 .join(", ");
1042 output.push_str(" returning ");
1043 output.push_str(&returning);
1044 }
1045 output
1046 }
1047 Statement::Update(stmt) => {
1048 let mut output = format!("update {} set ", stmt.table);
1049 let assignments = stmt
1050 .assignments
1051 .iter()
1052 .map(|assign| format!("{} = {}", assign.column, assign.value.to_sql()))
1053 .collect::<Vec<_>>()
1054 .join(", ");
1055 output.push_str(&assignments);
1056 if let Some(selection) = &stmt.selection {
1057 output.push_str(" where ");
1058 output.push_str(&selection.to_sql());
1059 }
1060 if !stmt.returning.is_empty() {
1061 let returning = stmt
1062 .returning
1063 .iter()
1064 .map(|item| item.expr.to_sql())
1065 .collect::<Vec<_>>()
1066 .join(", ");
1067 output.push_str(" returning ");
1068 output.push_str(&returning);
1069 }
1070 output
1071 }
1072 Statement::Delete(stmt) => {
1073 let mut output = format!("delete from {}", stmt.table);
1074 if let Some(selection) = &stmt.selection {
1075 output.push_str(" where ");
1076 output.push_str(&selection.to_sql());
1077 }
1078 if !stmt.returning.is_empty() {
1079 let returning = stmt
1080 .returning
1081 .iter()
1082 .map(|item| item.expr.to_sql())
1083 .collect::<Vec<_>>()
1084 .join(", ");
1085 output.push_str(" returning ");
1086 output.push_str(&returning);
1087 }
1088 output
1089 }
1090 }
1091}
1092
#[cfg(test)]
mod tests {
    use super::{BinaryOperator, Expr, Literal, UnaryOperator};

    /// Builds an identifier expression.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    /// Builds a boolean literal expression.
    fn boolean(value: bool) -> Expr {
        Expr::Literal(Literal::Bool(value))
    }

    /// Builds a numeric literal expression.
    fn number(value: f64) -> Expr {
        Expr::Literal(Literal::Number(value))
    }

    /// Builds `left op right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Builds `not inner`.
    fn negate(inner: Expr) -> Expr {
        Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: Box::new(inner),
        }
    }

    // `b and a` is reordered to `a and b`.
    #[test]
    fn normalize_commutative_predicate() {
        let normalized = binary(ident("b"), BinaryOperator::And, ident("a")).normalize();
        let Expr::BinaryOp { left, right, .. } = normalized else {
            panic!("expected binary op");
        };
        assert!(matches!(left.as_ref(), Expr::Identifier(name) if name == "a"));
        assert!(matches!(right.as_ref(), Expr::Identifier(name) if name == "b"));
    }

    // `not (a = b)` becomes `a != b`.
    #[test]
    fn normalize_not_comparison() {
        let normalized = negate(binary(ident("a"), BinaryOperator::Eq, ident("b"))).normalize();
        let Expr::BinaryOp { op, .. } = normalized else {
            panic!("expected binary op");
        };
        assert!(matches!(op, BinaryOperator::NotEq));
    }

    // `not not flag` collapses to `flag`.
    #[test]
    fn normalize_double_not() {
        let normalized = negate(negate(ident("flag"))).normalize();
        assert!(matches!(normalized, Expr::Identifier(_)));
    }

    // `2 * 4` folds to `8`.
    #[test]
    fn normalize_constant_fold() {
        let normalized = binary(number(2.0), BinaryOperator::Mul, number(4.0)).normalize();
        match normalized {
            Expr::Literal(Literal::Number(value)) => assert_eq!(value, 8.0),
            other => panic!("expected literal, got {other:?}"),
        }
    }

    // `flag and true` reduces to `flag`.
    #[test]
    fn normalize_boolean_identities() {
        let normalized = binary(ident("flag"), BinaryOperator::And, boolean(true)).normalize();
        assert!(matches!(normalized, Expr::Identifier(_)));
    }

    // `a or a` and `a and a` both reduce to `a`.
    #[test]
    fn normalize_duplicate_and_or() {
        for op in [BinaryOperator::Or, BinaryOperator::And] {
            let normalized = binary(ident("a"), op, ident("a")).normalize();
            assert!(matches!(normalized, Expr::Identifier(name) if name == "a"));
        }
    }

    // Comparing an identifier with itself folds to a boolean literal.
    #[test]
    fn normalize_self_comparison() {
        let cases = [
            (BinaryOperator::Eq, true),
            (BinaryOperator::NotEq, false),
            (BinaryOperator::Lt, false),
            (BinaryOperator::LtEq, true),
        ];
        for (op, expected) in cases {
            let normalized = binary(ident("a"), op, ident("a")).normalize();
            assert!(matches!(
                normalized,
                Expr::Literal(Literal::Bool(value)) if value == expected
            ));
        }
    }

    // `not true` folds to `false`.
    #[test]
    fn normalize_not_literal() {
        let normalized = negate(boolean(true)).normalize();
        assert!(matches!(normalized, Expr::Literal(Literal::Bool(false))));
    }
}