/// Top-level SQL statement node of the AST.
///
/// Query-like variants can nest other statements (boxed to keep the type finite); the remaining
/// variants wrap a dedicated statement struct.
#[derive(Debug, Clone)]
pub enum Statement {
    /// A statement preceded by common table expressions (`with ...`).
    With(WithStatement),
    /// A plain `select` query.
    Select(SelectStatement),
    /// Two statements combined by a set operator such as `union` or `except`.
    SetOp {
        /// Left-hand input of the set operation.
        left: Box<Statement>,
        /// Operator combining both inputs.
        op: SetOperator,
        /// Right-hand input of the set operation.
        right: Box<Statement>,
    },
    /// `explain` wrapped around another statement.
    Explain(Box<Statement>),
    /// `create table` statement.
    CreateTable(CreateTableStatement),
    /// `drop table` statement.
    DropTable(DropTableStatement),
    /// `truncate table` statement.
    Truncate(TruncateStatement),
    /// `analyze` statement.
    Analyze(AnalyzeStatement),
    /// `insert into` statement.
    Insert(InsertStatement),
    /// Update statement.
    Update(UpdateStatement),
    /// Delete statement.
    Delete(DeleteStatement),
}
19
/// Payload of a `create table` statement.
#[derive(Debug, Clone)]
pub struct CreateTableStatement {
    /// Name of the table to create, rendered verbatim.
    pub name: String,
    /// When true the statement is rendered with `if not exists`.
    pub if_not_exists: bool,
    /// Column definitions; when empty no parenthesized column list is rendered.
    pub columns: Vec<ColumnDef>,
}
26
/// A single column definition inside a `create table` statement.
#[derive(Debug, Clone)]
pub struct ColumnDef {
    /// Column name.
    pub name: String,
    /// Column type kept as free-form text; it is rendered verbatim after the name.
    pub data_type: String,
}
32
/// Payload of a `drop table` statement.
#[derive(Debug, Clone)]
pub struct DropTableStatement {
    /// Name of the table to drop.
    pub name: String,
    /// When true the statement is rendered with `if exists`.
    pub if_exists: bool,
}
38
/// Payload of a `truncate table` statement.
#[derive(Debug, Clone)]
pub struct TruncateStatement {
    /// Name of the table to truncate.
    pub table: String,
}
43
/// Payload of an `analyze` statement.
#[derive(Debug, Clone)]
pub struct AnalyzeStatement {
    /// Name of the table to analyze.
    pub table: String,
}
48
/// Payload of an `insert into` statement.
#[derive(Debug, Clone)]
pub struct InsertStatement {
    /// Target table name.
    pub table: String,
    /// Explicit target column list; when empty no column list is rendered.
    pub columns: Vec<String>,
    /// Where the inserted rows come from.
    pub source: InsertSource,
    /// Items of the returning list (presumably a `returning` clause; empty when absent).
    pub returning: Vec<SelectItem>,
}
56
/// Row source of an [`InsertStatement`].
#[derive(Debug, Clone)]
pub enum InsertSource {
    /// Literal rows: the outer vector holds rows, the inner vector the expressions of one row.
    Values(Vec<Vec<Expr>>),
    /// Rows produced by a nested statement.
    Query(Box<Statement>),
    /// Rendered as `default values`.
    DefaultValues,
}
63
/// Payload of an update statement.
#[derive(Debug, Clone)]
pub struct UpdateStatement {
    /// Target table name.
    pub table: String,
    /// Column assignments to apply.
    pub assignments: Vec<Assignment>,
    /// Optional filter expression (presumably the `where` predicate, as in `SelectStatement`).
    pub selection: Option<Expr>,
    /// Items of the returning list (presumably a `returning` clause; empty when absent).
    pub returning: Vec<SelectItem>,
}
71
/// One `column = value` pair of an [`UpdateStatement`].
#[derive(Debug, Clone)]
pub struct Assignment {
    /// Name of the assigned column.
    pub column: String,
    /// Expression producing the new value.
    pub value: Expr,
}
77
/// Payload of a delete statement.
#[derive(Debug, Clone)]
pub struct DeleteStatement {
    /// Target table name.
    pub table: String,
    /// Optional filter expression (presumably the `where` predicate, as in `SelectStatement`).
    pub selection: Option<Expr>,
    /// Items of the returning list (presumably a `returning` clause; empty when absent).
    pub returning: Vec<SelectItem>,
}
84
/// A single `select` query.
///
/// Fields are listed in the order in which `select_to_sql` renders their clauses.
#[derive(Debug, Clone)]
pub struct SelectStatement {
    /// Whether the query uses `distinct`.
    pub distinct: bool,
    /// Expressions of a `distinct on (...)` list; empty when absent.
    pub distinct_on: Vec<Expr>,
    /// Projected items of the select list.
    pub projection: Vec<SelectItem>,
    /// Optional `from` source, including its joins.
    pub from: Option<TableRef>,
    /// Optional `where` predicate.
    pub selection: Option<Expr>,
    /// `group by` expressions; empty when absent.
    pub group_by: Vec<Expr>,
    /// Optional `having` predicate.
    pub having: Option<Expr>,
    /// Optional `qualify` predicate.
    pub qualify: Option<Expr>,
    /// `order by` items; empty when absent.
    pub order_by: Vec<OrderByExpr>,
    /// Optional `limit` row count.
    pub limit: Option<u64>,
    /// Optional `offset` row count.
    pub offset: Option<u64>,
}
99
/// A statement preceded by one or more common table expressions.
#[derive(Debug, Clone)]
pub struct WithStatement {
    /// The common table expressions, in declaration order.
    pub ctes: Vec<Cte>,
    /// When true the clause is rendered as `with recursive`.
    pub recursive: bool,
    /// The statement that follows the CTE list.
    pub statement: Box<Statement>,
}
106
/// A single common table expression: `name (columns) as (query)`.
#[derive(Debug, Clone)]
pub struct Cte {
    /// Name the CTE is bound to.
    pub name: String,
    /// Optional column name list; when empty no list is rendered.
    pub columns: Vec<String>,
    /// The statement defining the CTE.
    pub query: Box<Statement>,
}
113
/// Operator of a [`Statement::SetOp`]; each variant is rendered as the matching SQL keyword(s).
#[derive(Debug, Clone, Copy)]
pub enum SetOperator {
    /// `union`
    Union,
    /// `union all`
    UnionAll,
    /// `intersect`
    Intersect,
    /// `intersect all`
    IntersectAll,
    /// `except`
    Except,
    /// `except all`
    ExceptAll,
}
123
/// A table source in a `from` clause together with the joins attached to it.
#[derive(Debug, Clone)]
pub struct TableRef {
    /// The underlying table or derived query.
    pub factor: TableFactor,
    /// Optional alias, rendered as `as <alias>`.
    pub alias: Option<String>,
    /// Column aliases following the table alias; only rendered when `alias` is present.
    pub column_aliases: Vec<String>,
    /// Joins applied to this source, in order.
    pub joins: Vec<Join>,
}
131
/// The base relation of a [`TableRef`].
#[derive(Debug, Clone)]
pub enum TableFactor {
    /// A named table, rendered verbatim.
    Table { name: String },
    /// A derived table: a nested statement rendered inside parentheses.
    Derived { query: Box<Statement> },
}
137
/// One join attached to a [`TableRef`].
#[derive(Debug, Clone)]
pub struct Join {
    /// Kind of join.
    pub join_type: JoinType,
    /// The joined (right-hand) table source.
    pub right: TableRef,
    /// Join condition rendered after `on`.
    pub on: Expr,
}
144
/// Kind of a [`Join`].
#[derive(Debug, Clone, Copy)]
pub enum JoinType {
    /// Rendered as plain `join`.
    Inner,
    /// Rendered as `left join`.
    Left,
    /// Rendered as `right join`.
    Right,
    /// Rendered as `full join`.
    Full,
}
152
/// One item of a select list or returning list.
#[derive(Debug, Clone)]
pub struct SelectItem {
    /// The projected expression.
    pub expr: Expr,
    /// Optional output alias for the expression.
    pub alias: Option<String>,
}
158
/// One sort key of an `order by` list (query level or inside a window specification).
#[derive(Debug, Clone)]
pub struct OrderByExpr {
    /// Expression to sort by.
    pub expr: Expr,
    /// `true` renders `asc`, `false` renders `desc`; the direction is always rendered.
    pub asc: bool,
    /// `Some(true)` renders `nulls first`, `Some(false)` renders `nulls last`, `None` renders
    /// nothing.
    pub nulls_first: Option<bool>,
}
165
/// Scalar expression node of the AST.
#[derive(Debug, Clone)]
pub enum Expr {
    /// A name rendered verbatim (no quoting is applied when rendering).
    Identifier(String),
    /// A literal value.
    Literal(Literal),
    /// Infix operator applied to two operands.
    BinaryOp {
        /// Left operand.
        left: Box<Expr>,
        /// The operator.
        op: BinaryOperator,
        /// Right operand.
        right: Box<Expr>,
    },
    /// `expr is null`, or `expr is not null` when `negated` is true.
    IsNull {
        /// Tested expression.
        expr: Box<Expr>,
        /// Selects the `is not null` form.
        negated: bool,
    },
    /// Prefix operator applied to one operand.
    UnaryOp {
        /// The operator.
        op: UnaryOperator,
        /// The operand.
        expr: Box<Expr>,
    },
    /// `name(args...)`.
    FunctionCall {
        /// Function name, rendered verbatim.
        name: String,
        /// Arguments in call order.
        args: Vec<Expr>,
    },
    /// `function over (spec)`.
    WindowFunction {
        /// The expression placed before `over` (presumably a `FunctionCall` — the type does not
        /// enforce it).
        function: Box<Expr>,
        /// Partitioning, ordering and frame of the window.
        spec: WindowSpec,
    },
    /// Parenthesized subquery used as an expression.
    Subquery(Box<SelectStatement>),
    /// `exists (subquery)`.
    Exists(Box<SelectStatement>),
    /// `expr in (subquery)`.
    InSubquery {
        /// Expression tested for membership.
        expr: Box<Expr>,
        /// Subquery producing the candidate values.
        subquery: Box<SelectStatement>,
    },
    /// `case [operand] when ... then ... [else ...] end`.
    Case {
        /// Optional operand following `case`.
        operand: Option<Box<Expr>>,
        /// `(when, then)` pairs in source order.
        when_then: Vec<(Expr, Expr)>,
        /// Optional `else` result.
        else_expr: Option<Box<Expr>>,
    },
    /// The `*` wildcard.
    Wildcard,
}
204
/// Literal value of an [`Expr::Literal`]. There is no variant for SQL `NULL`.
#[derive(Debug, Clone)]
pub enum Literal {
    /// String literal; stored without the surrounding quotes.
    String(String),
    /// Numeric literal; every number is stored as `f64`.
    Number(f64),
    /// Boolean literal, rendered as `true` / `false`.
    Bool(bool),
}
211
/// Infix operators of [`Expr::BinaryOp`], with the token each one is rendered as.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BinaryOperator {
    /// `=`
    Eq,
    /// `!=`
    NotEq,
    /// `<`
    Lt,
    /// `<=`
    LtEq,
    /// `>`
    Gt,
    /// `>=`
    GtEq,
    /// `and`
    And,
    /// `or`
    Or,
    /// `+`
    Add,
    /// `-`
    Sub,
    /// `*`
    Mul,
    /// `/`
    Div,
}
227
/// Prefix operators of [`Expr::UnaryOp`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum UnaryOperator {
    /// Logical negation, rendered as `not `.
    Not,
    /// Arithmetic negation, rendered as a leading `-`.
    Neg,
}
233
/// Contents of the parentheses following `over` in a window function call.
#[derive(Debug, Clone)]
pub struct WindowSpec {
    /// `partition by` expressions; empty when absent.
    pub partition_by: Vec<Expr>,
    /// `order by` items; empty when absent.
    pub order_by: Vec<OrderByExpr>,
    /// Optional frame clause.
    pub frame: Option<WindowFrame>,
}
240
/// Unit of a window frame, rendered as the lowercase keyword of the same name.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum WindowFrameKind {
    /// `rows`
    Rows,
    /// `range`
    Range,
    /// `groups`
    Groups,
}
247
/// One boundary of a window frame.
#[derive(Debug, Clone)]
pub enum WindowFrameBound {
    /// `unbounded preceding`
    UnboundedPreceding,
    /// `<expr> preceding`
    Preceding(Box<Expr>),
    /// `current row`
    CurrentRow,
    /// `<expr> following`
    Following(Box<Expr>),
    /// `unbounded following`
    UnboundedFollowing,
}
256
/// Frame clause of a window specification.
#[derive(Debug, Clone)]
pub struct WindowFrame {
    /// Frame unit (`rows`, `range` or `groups`).
    pub kind: WindowFrameKind,
    /// Start bound; the only bound in the short form `<kind> <start>`.
    pub start: WindowFrameBound,
    /// End bound; when present the frame renders as `<kind> between <start> and <end>`.
    pub end: Option<WindowFrameBound>,
}
263
264impl Expr {
265 pub fn to_sql(&self) -> String {
266 match self {
267 Expr::Identifier(name) => name.clone(),
268 Expr::Literal(Literal::String(value)) => format!("'{}'", value),
269 Expr::Literal(Literal::Number(value)) => value.to_string(),
270 Expr::Literal(Literal::Bool(value)) => {
271 if *value {
272 "true".to_string()
273 } else {
274 "false".to_string()
275 }
276 }
277 Expr::BinaryOp { left, op, right } => {
278 let op_str = match op {
279 BinaryOperator::Eq => "=",
280 BinaryOperator::NotEq => "!=",
281 BinaryOperator::Lt => "<",
282 BinaryOperator::LtEq => "<=",
283 BinaryOperator::Gt => ">",
284 BinaryOperator::GtEq => ">=",
285 BinaryOperator::And => "and",
286 BinaryOperator::Or => "or",
287 BinaryOperator::Add => "+",
288 BinaryOperator::Sub => "-",
289 BinaryOperator::Mul => "*",
290 BinaryOperator::Div => "/",
291 };
292 format!("{} {} {}", left.to_sql(), op_str, right.to_sql())
293 }
294 Expr::IsNull { expr, negated } => {
295 if *negated {
296 format!("{} is not null", expr.to_sql())
297 } else {
298 format!("{} is null", expr.to_sql())
299 }
300 }
301 Expr::UnaryOp { op, expr } => match op {
302 UnaryOperator::Not => format!("not {}", expr.to_sql()),
303 UnaryOperator::Neg => format!("-{}", expr.to_sql()),
304 },
305 Expr::FunctionCall { name, args } => {
306 let args_sql = args.iter().map(|arg| arg.to_sql()).collect::<Vec<_>>();
307 format!("{}({})", name, args_sql.join(", "))
308 }
309 Expr::WindowFunction { function, spec } => {
310 let mut clauses = Vec::new();
311 if !spec.partition_by.is_empty() {
312 let partition = spec
313 .partition_by
314 .iter()
315 .map(|expr| expr.to_sql())
316 .collect::<Vec<_>>()
317 .join(", ");
318 clauses.push(format!("partition by {partition}"));
319 }
320 if !spec.order_by.is_empty() {
321 let order = spec
322 .order_by
323 .iter()
324 .map(|item| {
325 let dir = if item.asc { "asc" } else { "desc" };
326 let mut rendered = format!("{} {dir}", item.expr.to_sql());
327 if let Some(nulls_first) = item.nulls_first {
328 if nulls_first {
329 rendered.push_str(" nulls first");
330 } else {
331 rendered.push_str(" nulls last");
332 }
333 }
334 rendered
335 })
336 .collect::<Vec<_>>()
337 .join(", ");
338 clauses.push(format!("order by {order}"));
339 }
340 if let Some(frame) = &spec.frame {
341 clauses.push(frame_to_sql(frame));
342 }
343 format!("{} over ({})", function.to_sql(), clauses.join(" "))
344 }
345 Expr::Exists(select) => format!("exists ({})", select_to_sql(select)),
346 Expr::InSubquery { expr, subquery } => {
347 format!("{} in ({})", expr.to_sql(), select_to_sql(subquery))
348 }
349 Expr::Subquery(select) => format!("({})", select_to_sql(select)),
350 Expr::Case {
351 operand,
352 when_then,
353 else_expr,
354 } => {
355 let mut output = String::from("case");
356 if let Some(expr) = operand {
357 output.push(' ');
358 output.push_str(&expr.to_sql());
359 }
360 for (when_expr, then_expr) in when_then {
361 output.push_str(" when ");
362 output.push_str(&when_expr.to_sql());
363 output.push_str(" then ");
364 output.push_str(&then_expr.to_sql());
365 }
366 if let Some(expr) = else_expr {
367 output.push_str(" else ");
368 output.push_str(&expr.to_sql());
369 }
370 output.push_str(" end");
371 output
372 }
373 Expr::Wildcard => "*".to_string(),
374 }
375 }
376
377 pub fn structural_eq(&self, other: &Expr) -> bool {
378 const FLOAT_EPSILON: f64 = 1e-9;
379 match (self, other) {
380 (Expr::Identifier(left), Expr::Identifier(right)) => left == right,
381 (Expr::Literal(left), Expr::Literal(right)) => match (left, right) {
382 (Literal::String(left), Literal::String(right)) => left == right,
383 (Literal::Number(left), Literal::Number(right)) => {
384 if left.is_nan() || right.is_nan() {
385 false
386 } else if left.is_infinite() || right.is_infinite() {
387 left == right
388 } else {
389 (left - right).abs() <= FLOAT_EPSILON
390 }
391 }
392 (Literal::Bool(left), Literal::Bool(right)) => left == right,
393 _ => false,
394 },
395 (
396 Expr::UnaryOp {
397 op: left_op,
398 expr: left,
399 },
400 Expr::UnaryOp {
401 op: right_op,
402 expr: right,
403 },
404 ) => left_op == right_op && left.structural_eq(right),
405 (
406 Expr::BinaryOp {
407 left: left_lhs,
408 op: left_op,
409 right: left_rhs,
410 },
411 Expr::BinaryOp {
412 left: right_lhs,
413 op: right_op,
414 right: right_rhs,
415 },
416 ) => {
417 left_op == right_op
418 && left_lhs.structural_eq(right_lhs)
419 && left_rhs.structural_eq(right_rhs)
420 }
421 (
422 Expr::IsNull {
423 expr: left,
424 negated: left_negated,
425 },
426 Expr::IsNull {
427 expr: right,
428 negated: right_negated,
429 },
430 ) => left_negated == right_negated && left.structural_eq(right),
431 (
432 Expr::FunctionCall {
433 name: left_name,
434 args: left_args,
435 },
436 Expr::FunctionCall {
437 name: right_name,
438 args: right_args,
439 },
440 ) => {
441 left_name == right_name
442 && left_args.len() == right_args.len()
443 && left_args
444 .iter()
445 .zip(right_args.iter())
446 .all(|(left, right)| left.structural_eq(right))
447 }
448 (
449 Expr::WindowFunction {
450 function: left_func,
451 spec: left_spec,
452 },
453 Expr::WindowFunction {
454 function: right_func,
455 spec: right_spec,
456 },
457 ) => {
458 left_func.structural_eq(right_func)
459 && left_spec.partition_by.len() == right_spec.partition_by.len()
460 && left_spec
461 .partition_by
462 .iter()
463 .zip(right_spec.partition_by.iter())
464 .all(|(left, right)| left.structural_eq(right))
465 && left_spec.order_by.len() == right_spec.order_by.len()
466 && left_spec
467 .order_by
468 .iter()
469 .zip(right_spec.order_by.iter())
470 .all(|(left, right)| {
471 left.asc == right.asc
472 && left.nulls_first == right.nulls_first
473 && left.expr.structural_eq(&right.expr)
474 })
475 && window_frame_eq(&left_spec.frame, &right_spec.frame)
476 }
477 (Expr::Subquery(left), Expr::Subquery(right)) => {
478 select_to_sql(left) == select_to_sql(right)
479 }
480 (Expr::Exists(left), Expr::Exists(right)) => {
481 select_to_sql(left) == select_to_sql(right)
482 }
483 (
484 Expr::InSubquery {
485 expr: left_expr,
486 subquery: left_subquery,
487 },
488 Expr::InSubquery {
489 expr: right_expr,
490 subquery: right_subquery,
491 },
492 ) => {
493 left_expr.structural_eq(right_expr)
494 && select_to_sql(left_subquery) == select_to_sql(right_subquery)
495 }
496 (
497 Expr::Case {
498 operand: left_operand,
499 when_then: left_when_then,
500 else_expr: left_else,
501 },
502 Expr::Case {
503 operand: right_operand,
504 when_then: right_when_then,
505 else_expr: right_else,
506 },
507 ) => {
508 left_operand
509 .as_ref()
510 .zip(right_operand.as_ref())
511 .map(|(left, right)| left.structural_eq(right))
512 .unwrap_or(left_operand.is_none() && right_operand.is_none())
513 && left_when_then.len() == right_when_then.len()
514 && left_when_then
515 .iter()
516 .zip(right_when_then.iter())
517 .all(|(left, right)| {
518 left.0.structural_eq(&right.0) && left.1.structural_eq(&right.1)
519 })
520 && left_else
521 .as_ref()
522 .zip(right_else.as_ref())
523 .map(|(left, right)| left.structural_eq(right))
524 .unwrap_or(left_else.is_none() && right_else.is_none())
525 }
526 (Expr::Wildcard, Expr::Wildcard) => true,
527 _ => false,
528 }
529 }
530
531 pub fn normalize(&self) -> Expr {
532 let normalized = match self {
533 Expr::BinaryOp { left, op, right } => {
534 let left_norm = left.normalize();
535 let right_norm = right.normalize();
536 if matches!(op, BinaryOperator::And | BinaryOperator::Or) {
537 if left_norm.to_sql() > right_norm.to_sql() {
538 return Expr::BinaryOp {
539 left: Box::new(right_norm),
540 op: *op,
541 right: Box::new(left_norm),
542 };
543 }
544 }
545 Expr::BinaryOp {
546 left: Box::new(left_norm),
547 op: *op,
548 right: Box::new(right_norm),
549 }
550 }
551 Expr::IsNull { expr, negated } => Expr::IsNull {
552 expr: Box::new(expr.normalize()),
553 negated: *negated,
554 },
555 Expr::UnaryOp { op, expr } => Expr::UnaryOp {
556 op: *op,
557 expr: Box::new(expr.normalize()),
558 },
559 Expr::FunctionCall { name, args } => Expr::FunctionCall {
560 name: name.clone(),
561 args: args.iter().map(|arg| arg.normalize()).collect(),
562 },
563 Expr::Case {
564 operand,
565 when_then,
566 else_expr,
567 } => Expr::Case {
568 operand: operand.as_ref().map(|expr| Box::new(expr.normalize())),
569 when_then: when_then
570 .iter()
571 .map(|(when_expr, then_expr)| (when_expr.normalize(), then_expr.normalize()))
572 .collect(),
573 else_expr: else_expr.as_ref().map(|expr| Box::new(expr.normalize())),
574 },
575 Expr::WindowFunction { function, spec } => Expr::WindowFunction {
576 function: Box::new(function.normalize()),
577 spec: WindowSpec {
578 partition_by: spec
579 .partition_by
580 .iter()
581 .map(|expr| expr.normalize())
582 .collect(),
583 order_by: spec
584 .order_by
585 .iter()
586 .map(|order| OrderByExpr {
587 expr: order.expr.normalize(),
588 asc: order.asc,
589 nulls_first: order.nulls_first,
590 })
591 .collect(),
592 frame: spec.frame.as_ref().map(|frame| WindowFrame {
593 kind: frame.kind,
594 start: normalize_frame_bound(&frame.start),
595 end: frame.end.as_ref().map(normalize_frame_bound),
596 }),
597 },
598 },
599 Expr::Exists(select) => Expr::Exists(Box::new(normalize_select_inner(select))),
600 Expr::InSubquery { expr, subquery } => Expr::InSubquery {
601 expr: Box::new(expr.normalize()),
602 subquery: Box::new(normalize_select_inner(subquery)),
603 },
604 Expr::Subquery(select) => Expr::Subquery(Box::new(normalize_select_inner(select))),
605 other => other.clone(),
606 };
607 rewrite_strong_expr(normalized)
608 }
609}
610
611fn rewrite_strong_expr(expr: Expr) -> Expr {
612 match expr {
613 Expr::UnaryOp {
614 op: UnaryOperator::Not,
615 expr,
616 } => match *expr {
617 Expr::Literal(Literal::Bool(value)) => Expr::Literal(Literal::Bool(!value)),
618 Expr::UnaryOp {
619 op: UnaryOperator::Not,
620 expr,
621 } => *expr,
622 Expr::IsNull { expr, negated } => Expr::IsNull {
623 expr,
624 negated: !negated,
625 },
626 Expr::BinaryOp { left, op, right } => match op {
627 BinaryOperator::Eq => Expr::BinaryOp {
628 left,
629 op: BinaryOperator::NotEq,
630 right,
631 },
632 BinaryOperator::NotEq => Expr::BinaryOp {
633 left,
634 op: BinaryOperator::Eq,
635 right,
636 },
637 BinaryOperator::Lt => Expr::BinaryOp {
638 left,
639 op: BinaryOperator::GtEq,
640 right,
641 },
642 BinaryOperator::LtEq => Expr::BinaryOp {
643 left,
644 op: BinaryOperator::Gt,
645 right,
646 },
647 BinaryOperator::Gt => Expr::BinaryOp {
648 left,
649 op: BinaryOperator::LtEq,
650 right,
651 },
652 BinaryOperator::GtEq => Expr::BinaryOp {
653 left,
654 op: BinaryOperator::Lt,
655 right,
656 },
657 BinaryOperator::And => Expr::BinaryOp {
658 left: Box::new(Expr::UnaryOp {
659 op: UnaryOperator::Not,
660 expr: left,
661 }),
662 op: BinaryOperator::Or,
663 right: Box::new(Expr::UnaryOp {
664 op: UnaryOperator::Not,
665 expr: right,
666 }),
667 },
668 BinaryOperator::Or => Expr::BinaryOp {
669 left: Box::new(Expr::UnaryOp {
670 op: UnaryOperator::Not,
671 expr: left,
672 }),
673 op: BinaryOperator::And,
674 right: Box::new(Expr::UnaryOp {
675 op: UnaryOperator::Not,
676 expr: right,
677 }),
678 },
679 _ => Expr::UnaryOp {
680 op: UnaryOperator::Not,
681 expr: Box::new(Expr::BinaryOp { left, op, right }),
682 },
683 },
684 other => Expr::UnaryOp {
685 op: UnaryOperator::Not,
686 expr: Box::new(other),
687 },
688 },
689 Expr::UnaryOp {
690 op: UnaryOperator::Neg,
691 expr,
692 } => match *expr {
693 Expr::Literal(Literal::Number(value)) => Expr::Literal(Literal::Number(-value)),
694 other => Expr::UnaryOp {
695 op: UnaryOperator::Neg,
696 expr: Box::new(other),
697 },
698 },
699 Expr::BinaryOp { left, op, right } => {
700 if matches!(op, BinaryOperator::And | BinaryOperator::Or) && left.structural_eq(&right)
701 {
702 return *left;
703 }
704 let same_expr = left.structural_eq(&right);
705 if same_expr {
706 return match op {
707 BinaryOperator::Eq | BinaryOperator::LtEq | BinaryOperator::GtEq => {
708 Expr::Literal(Literal::Bool(true))
709 }
710 BinaryOperator::NotEq | BinaryOperator::Lt | BinaryOperator::Gt => {
711 Expr::Literal(Literal::Bool(false))
712 }
713 _ => Expr::BinaryOp { left, op, right },
714 };
715 }
716 match (*left, op, *right) {
717 (
718 Expr::Literal(Literal::Bool(a)),
719 BinaryOperator::And,
720 Expr::Literal(Literal::Bool(b)),
721 ) => Expr::Literal(Literal::Bool(a && b)),
722 (
723 Expr::Literal(Literal::Bool(a)),
724 BinaryOperator::Or,
725 Expr::Literal(Literal::Bool(b)),
726 ) => Expr::Literal(Literal::Bool(a || b)),
727 (Expr::Literal(Literal::Bool(a)), BinaryOperator::And, other) => {
728 if a {
729 other
730 } else {
731 Expr::Literal(Literal::Bool(false))
732 }
733 }
734 (other, BinaryOperator::And, Expr::Literal(Literal::Bool(b))) => {
735 if b {
736 other
737 } else {
738 Expr::Literal(Literal::Bool(false))
739 }
740 }
741 (Expr::Literal(Literal::Bool(a)), BinaryOperator::Or, other) => {
742 if a {
743 Expr::Literal(Literal::Bool(true))
744 } else {
745 other
746 }
747 }
748 (other, BinaryOperator::Or, Expr::Literal(Literal::Bool(b))) => {
749 if b {
750 Expr::Literal(Literal::Bool(true))
751 } else {
752 other
753 }
754 }
755 (
756 Expr::Literal(Literal::Number(a)),
757 BinaryOperator::Eq,
758 Expr::Literal(Literal::Number(b)),
759 ) => Expr::Literal(Literal::Bool(a == b)),
760 (
761 Expr::Literal(Literal::Number(a)),
762 BinaryOperator::NotEq,
763 Expr::Literal(Literal::Number(b)),
764 ) => Expr::Literal(Literal::Bool(a != b)),
765 (
766 Expr::Literal(Literal::Number(a)),
767 BinaryOperator::Lt,
768 Expr::Literal(Literal::Number(b)),
769 ) => Expr::Literal(Literal::Bool(a < b)),
770 (
771 Expr::Literal(Literal::Number(a)),
772 BinaryOperator::LtEq,
773 Expr::Literal(Literal::Number(b)),
774 ) => Expr::Literal(Literal::Bool(a <= b)),
775 (
776 Expr::Literal(Literal::Number(a)),
777 BinaryOperator::Gt,
778 Expr::Literal(Literal::Number(b)),
779 ) => Expr::Literal(Literal::Bool(a > b)),
780 (
781 Expr::Literal(Literal::Number(a)),
782 BinaryOperator::GtEq,
783 Expr::Literal(Literal::Number(b)),
784 ) => Expr::Literal(Literal::Bool(a >= b)),
785 (
786 Expr::Literal(Literal::String(a)),
787 BinaryOperator::Eq,
788 Expr::Literal(Literal::String(b)),
789 ) => Expr::Literal(Literal::Bool(a == b)),
790 (
791 Expr::Literal(Literal::String(a)),
792 BinaryOperator::NotEq,
793 Expr::Literal(Literal::String(b)),
794 ) => Expr::Literal(Literal::Bool(a != b)),
795 (
796 Expr::Literal(Literal::Bool(a)),
797 BinaryOperator::Eq,
798 Expr::Literal(Literal::Bool(b)),
799 ) => Expr::Literal(Literal::Bool(a == b)),
800 (
801 Expr::Literal(Literal::Bool(a)),
802 BinaryOperator::NotEq,
803 Expr::Literal(Literal::Bool(b)),
804 ) => Expr::Literal(Literal::Bool(a != b)),
805 (
806 Expr::Literal(Literal::Number(a)),
807 BinaryOperator::Add,
808 Expr::Literal(Literal::Number(b)),
809 ) => Expr::Literal(Literal::Number(a + b)),
810 (
811 Expr::Literal(Literal::Number(a)),
812 BinaryOperator::Sub,
813 Expr::Literal(Literal::Number(b)),
814 ) => Expr::Literal(Literal::Number(a - b)),
815 (
816 Expr::Literal(Literal::Number(a)),
817 BinaryOperator::Mul,
818 Expr::Literal(Literal::Number(b)),
819 ) => Expr::Literal(Literal::Number(a * b)),
820 (
821 Expr::Literal(Literal::Number(a)),
822 BinaryOperator::Div,
823 Expr::Literal(Literal::Number(b)),
824 ) => {
825 if b == 0.0 {
826 Expr::BinaryOp {
827 left: Box::new(Expr::Literal(Literal::Number(a))),
828 op: BinaryOperator::Div,
829 right: Box::new(Expr::Literal(Literal::Number(b))),
830 }
831 } else {
832 Expr::Literal(Literal::Number(a / b))
833 }
834 }
835 (left, op, right) => Expr::BinaryOp {
836 left: Box::new(left),
837 op,
838 right: Box::new(right),
839 },
840 }
841 }
842 other => other,
843 }
844}
845
846pub fn normalize_statement(statement: &Statement) -> Statement {
847 match statement {
848 Statement::With(stmt) => Statement::With(WithStatement {
849 ctes: stmt
850 .ctes
851 .iter()
852 .map(|cte| Cte {
853 name: cte.name.clone(),
854 columns: cte.columns.clone(),
855 query: Box::new(normalize_statement(&cte.query)),
856 })
857 .collect(),
858 recursive: stmt.recursive,
859 statement: Box::new(normalize_statement(&stmt.statement)),
860 }),
861 Statement::Select(select) => Statement::Select(normalize_select(select)),
862 Statement::SetOp { left, op, right } => Statement::SetOp {
863 left: Box::new(normalize_statement(left)),
864 op: *op,
865 right: Box::new(normalize_statement(right)),
866 },
867 Statement::Explain(inner) => Statement::Explain(Box::new(normalize_statement(inner))),
868 Statement::CreateTable(stmt) => Statement::CreateTable(stmt.clone()),
869 Statement::DropTable(stmt) => Statement::DropTable(stmt.clone()),
870 Statement::Truncate(stmt) => Statement::Truncate(stmt.clone()),
871 Statement::Analyze(stmt) => Statement::Analyze(stmt.clone()),
872 Statement::Insert(stmt) => Statement::Insert(InsertStatement {
873 table: stmt.table.clone(),
874 columns: stmt.columns.clone(),
875 source: match &stmt.source {
876 InsertSource::Values(values) => InsertSource::Values(
877 values
878 .iter()
879 .map(|row| row.iter().map(|expr| expr.normalize()).collect())
880 .collect(),
881 ),
882 InsertSource::Query(statement) => {
883 InsertSource::Query(Box::new(normalize_statement(statement)))
884 }
885 InsertSource::DefaultValues => InsertSource::DefaultValues,
886 },
887 returning: stmt
888 .returning
889 .iter()
890 .map(|item| SelectItem {
891 expr: item.expr.normalize(),
892 alias: item.alias.clone(),
893 })
894 .collect(),
895 }),
896 Statement::Update(stmt) => Statement::Update(UpdateStatement {
897 table: stmt.table.clone(),
898 assignments: stmt
899 .assignments
900 .iter()
901 .map(|assign| Assignment {
902 column: assign.column.clone(),
903 value: assign.value.normalize(),
904 })
905 .collect(),
906 selection: stmt.selection.as_ref().map(|expr| expr.normalize()),
907 returning: stmt
908 .returning
909 .iter()
910 .map(|item| SelectItem {
911 expr: item.expr.normalize(),
912 alias: item.alias.clone(),
913 })
914 .collect(),
915 }),
916 Statement::Delete(stmt) => Statement::Delete(DeleteStatement {
917 table: stmt.table.clone(),
918 selection: stmt.selection.as_ref().map(|expr| expr.normalize()),
919 returning: stmt
920 .returning
921 .iter()
922 .map(|item| SelectItem {
923 expr: item.expr.normalize(),
924 alias: item.alias.clone(),
925 })
926 .collect(),
927 }),
928 }
929}
930
931fn normalize_select(select: &SelectStatement) -> SelectStatement {
932 SelectStatement {
933 distinct: select.distinct,
934 distinct_on: select
935 .distinct_on
936 .iter()
937 .map(|expr| expr.normalize())
938 .collect(),
939 projection: select
940 .projection
941 .iter()
942 .map(|item| SelectItem {
943 expr: item.expr.normalize(),
944 alias: item.alias.clone(),
945 })
946 .collect(),
947 from: select.from.as_ref().map(normalize_table_ref),
948 selection: select.selection.as_ref().map(|expr| expr.normalize()),
949 group_by: select
950 .group_by
951 .iter()
952 .map(|expr| expr.normalize())
953 .collect(),
954 having: select.having.as_ref().map(|expr| expr.normalize()),
955 qualify: select.qualify.as_ref().map(|expr| expr.normalize()),
956 order_by: select
957 .order_by
958 .iter()
959 .map(|order| OrderByExpr {
960 expr: order.expr.normalize(),
961 asc: order.asc,
962 nulls_first: order.nulls_first,
963 })
964 .collect(),
965 limit: select.limit,
966 offset: select.offset,
967 }
968}
969
/// Normalizes a `select` that appears inside an expression (`exists`, `in (...)`, scalar
/// subquery).
///
/// Currently a plain delegation to `normalize_select`.
fn normalize_select_inner(select: &SelectStatement) -> SelectStatement {
    normalize_select(select)
}
973
974fn normalize_table_ref(table: &TableRef) -> TableRef {
975 TableRef {
976 factor: match &table.factor {
977 TableFactor::Table { name } => TableFactor::Table { name: name.clone() },
978 TableFactor::Derived { query } => TableFactor::Derived {
979 query: Box::new(normalize_statement(query)),
980 },
981 },
982 alias: table.alias.clone(),
983 column_aliases: table.column_aliases.clone(),
984 joins: table
985 .joins
986 .iter()
987 .map(|join| Join {
988 join_type: join.join_type,
989 right: normalize_table_ref(&join.right),
990 on: join.on.normalize(),
991 })
992 .collect(),
993 }
994}
995
996fn select_to_sql(select: &SelectStatement) -> String {
997 let mut output = String::from("select ");
998 if select.distinct {
999 output.push_str("distinct ");
1000 if !select.distinct_on.is_empty() {
1001 let distinct_on = select
1002 .distinct_on
1003 .iter()
1004 .map(|expr| expr.to_sql())
1005 .collect::<Vec<_>>()
1006 .join(", ");
1007 output.push_str("on (");
1008 output.push_str(&distinct_on);
1009 output.push_str(") ");
1010 }
1011 }
1012 let projection = select
1013 .projection
1014 .iter()
1015 .map(|item| item.expr.to_sql())
1016 .collect::<Vec<_>>()
1017 .join(", ");
1018 output.push_str(&projection);
1019 if let Some(from) = &select.from {
1020 output.push_str(" from ");
1021 output.push_str(&table_ref_to_sql(from));
1022 }
1023 if let Some(selection) = &select.selection {
1024 output.push_str(" where ");
1025 output.push_str(&selection.to_sql());
1026 }
1027 if !select.group_by.is_empty() {
1028 let group_by = select
1029 .group_by
1030 .iter()
1031 .map(|expr| expr.to_sql())
1032 .collect::<Vec<_>>()
1033 .join(", ");
1034 output.push_str(" group by ");
1035 output.push_str(&group_by);
1036 }
1037 if let Some(having) = &select.having {
1038 output.push_str(" having ");
1039 output.push_str(&having.to_sql());
1040 }
1041 if let Some(qualify) = &select.qualify {
1042 output.push_str(" qualify ");
1043 output.push_str(&qualify.to_sql());
1044 }
1045 if !select.order_by.is_empty() {
1046 let order_by = select
1047 .order_by
1048 .iter()
1049 .map(|item| {
1050 let mut rendered = item.expr.to_sql();
1051 rendered.push(' ');
1052 rendered.push_str(if item.asc { "asc" } else { "desc" });
1053 if let Some(nulls_first) = item.nulls_first {
1054 rendered.push_str(" nulls ");
1055 rendered.push_str(if nulls_first { "first" } else { "last" });
1056 }
1057 rendered
1058 })
1059 .collect::<Vec<_>>()
1060 .join(", ");
1061 output.push_str(" order by ");
1062 output.push_str(&order_by);
1063 }
1064 if let Some(limit) = select.limit {
1065 output.push_str(" limit ");
1066 output.push_str(&limit.to_string());
1067 }
1068 if let Some(offset) = select.offset {
1069 output.push_str(" offset ");
1070 output.push_str(&offset.to_string());
1071 }
1072 output
1073}
1074
1075fn frame_to_sql(frame: &WindowFrame) -> String {
1076 let kind = match frame.kind {
1077 WindowFrameKind::Rows => "rows",
1078 WindowFrameKind::Range => "range",
1079 WindowFrameKind::Groups => "groups",
1080 };
1081 let start = frame_bound_to_sql(&frame.start);
1082 if let Some(end) = &frame.end {
1083 format!("{kind} between {start} and {}", frame_bound_to_sql(end))
1084 } else {
1085 format!("{kind} {start}")
1086 }
1087}
1088
1089fn frame_bound_to_sql(bound: &WindowFrameBound) -> String {
1090 match bound {
1091 WindowFrameBound::UnboundedPreceding => "unbounded preceding".to_string(),
1092 WindowFrameBound::Preceding(expr) => format!("{} preceding", expr.to_sql()),
1093 WindowFrameBound::CurrentRow => "current row".to_string(),
1094 WindowFrameBound::Following(expr) => format!("{} following", expr.to_sql()),
1095 WindowFrameBound::UnboundedFollowing => "unbounded following".to_string(),
1096 }
1097}
1098
1099fn normalize_frame_bound(bound: &WindowFrameBound) -> WindowFrameBound {
1100 match bound {
1101 WindowFrameBound::UnboundedPreceding => WindowFrameBound::UnboundedPreceding,
1102 WindowFrameBound::Preceding(expr) => {
1103 WindowFrameBound::Preceding(Box::new(expr.normalize()))
1104 }
1105 WindowFrameBound::CurrentRow => WindowFrameBound::CurrentRow,
1106 WindowFrameBound::Following(expr) => {
1107 WindowFrameBound::Following(Box::new(expr.normalize()))
1108 }
1109 WindowFrameBound::UnboundedFollowing => WindowFrameBound::UnboundedFollowing,
1110 }
1111}
1112
1113fn window_frame_eq(left: &Option<WindowFrame>, right: &Option<WindowFrame>) -> bool {
1114 match (left, right) {
1115 (None, None) => true,
1116 (Some(left), Some(right)) => {
1117 left.kind == right.kind
1118 && frame_bound_eq(&left.start, &right.start)
1119 && match (&left.end, &right.end) {
1120 (None, None) => true,
1121 (Some(a), Some(b)) => frame_bound_eq(a, b),
1122 _ => false,
1123 }
1124 }
1125 _ => false,
1126 }
1127}
1128
1129fn frame_bound_eq(left: &WindowFrameBound, right: &WindowFrameBound) -> bool {
1130 match (left, right) {
1131 (WindowFrameBound::UnboundedPreceding, WindowFrameBound::UnboundedPreceding) => true,
1132 (WindowFrameBound::CurrentRow, WindowFrameBound::CurrentRow) => true,
1133 (WindowFrameBound::UnboundedFollowing, WindowFrameBound::UnboundedFollowing) => true,
1134 (WindowFrameBound::Preceding(a), WindowFrameBound::Preceding(b)) => a.structural_eq(b),
1135 (WindowFrameBound::Following(a), WindowFrameBound::Following(b)) => a.structural_eq(b),
1136 _ => false,
1137 }
1138}
1139
1140fn table_ref_to_sql(table: &TableRef) -> String {
1141 let mut output = match &table.factor {
1142 TableFactor::Table { name } => name.clone(),
1143 TableFactor::Derived { query } => format!("({})", statement_to_sql(query)),
1144 };
1145 if let Some(alias) = &table.alias {
1146 output.push_str(" as ");
1147 output.push_str(alias);
1148 if !table.column_aliases.is_empty() {
1149 output.push_str(" (");
1150 output.push_str(&table.column_aliases.join(", "));
1151 output.push(')');
1152 }
1153 }
1154 for join in &table.joins {
1155 let join_type = match join.join_type {
1156 JoinType::Inner => "join",
1157 JoinType::Left => "left join",
1158 JoinType::Right => "right join",
1159 JoinType::Full => "full join",
1160 };
1161 output.push(' ');
1162 output.push_str(join_type);
1163 output.push(' ');
1164 output.push_str(&table_ref_to_sql(&join.right));
1165 output.push_str(" on ");
1166 output.push_str(&join.on.to_sql());
1167 }
1168 output
1169}
1170
1171pub fn statement_to_sql(statement: &Statement) -> String {
1172 match statement {
1173 Statement::Select(select) => select_to_sql(select),
1174 Statement::SetOp { left, op, right } => {
1175 let left_sql = statement_to_sql(left);
1176 let right_sql = statement_to_sql(right);
1177 let op_str = match op {
1178 SetOperator::Union => "union",
1179 SetOperator::UnionAll => "union all",
1180 SetOperator::Intersect => "intersect",
1181 SetOperator::IntersectAll => "intersect all",
1182 SetOperator::Except => "except",
1183 SetOperator::ExceptAll => "except all",
1184 };
1185 format!("{left_sql} {op_str} {right_sql}")
1186 }
1187 Statement::With(with_stmt) => {
1188 let ctes = with_stmt
1189 .ctes
1190 .iter()
1191 .map(|cte| {
1192 let mut name = cte.name.clone();
1193 if !cte.columns.is_empty() {
1194 let cols = cte.columns.join(", ");
1195 name.push_str(" (");
1196 name.push_str(&cols);
1197 name.push(')');
1198 }
1199 format!("{name} as ({})", statement_to_sql(&cte.query))
1200 })
1201 .collect::<Vec<_>>()
1202 .join(", ");
1203 let keyword = if with_stmt.recursive {
1204 "with recursive"
1205 } else {
1206 "with"
1207 };
1208 format!(
1209 "{keyword} {ctes} {}",
1210 statement_to_sql(&with_stmt.statement)
1211 )
1212 }
1213 Statement::Explain(inner) => format!("explain {}", statement_to_sql(inner)),
1214 Statement::CreateTable(stmt) => {
1215 if stmt.if_not_exists {
1216 if stmt.columns.is_empty() {
1217 format!("create table if not exists {}", stmt.name)
1218 } else {
1219 let columns = stmt
1220 .columns
1221 .iter()
1222 .map(|col| format!("{} {}", col.name, col.data_type))
1223 .collect::<Vec<_>>()
1224 .join(", ");
1225 format!("create table if not exists {} ({})", stmt.name, columns)
1226 }
1227 } else {
1228 if stmt.columns.is_empty() {
1229 format!("create table {}", stmt.name)
1230 } else {
1231 let columns = stmt
1232 .columns
1233 .iter()
1234 .map(|col| format!("{} {}", col.name, col.data_type))
1235 .collect::<Vec<_>>()
1236 .join(", ");
1237 format!("create table {} ({})", stmt.name, columns)
1238 }
1239 }
1240 }
1241 Statement::DropTable(stmt) => {
1242 if stmt.if_exists {
1243 format!("drop table if exists {}", stmt.name)
1244 } else {
1245 format!("drop table {}", stmt.name)
1246 }
1247 }
1248 Statement::Truncate(stmt) => format!("truncate table {}", stmt.table),
1249 Statement::Analyze(stmt) => format!("analyze {}", stmt.table),
1250 Statement::Insert(stmt) => {
1251 let mut output = format!("insert into {}", stmt.table);
1252 if !stmt.columns.is_empty() {
1253 output.push_str(" (");
1254 output.push_str(&stmt.columns.join(", "));
1255 output.push(')');
1256 }
1257 match &stmt.source {
1258 InsertSource::DefaultValues => {
1259 output.push_str(" default values");
1260 }
1261 InsertSource::Values(values) => {
1262 let rows = values
1263 .iter()
1264 .map(|row| {
1265 let values = row
1266 .iter()
1267 .map(|expr| expr.to_sql())
1268 .collect::<Vec<_>>()
1269 .join(", ");
1270 format!("({values})")
1271 })
1272 .collect::<Vec<_>>()
1273 .join(", ");
1274 output.push_str(" values ");
1275 output.push_str(&rows);
1276 }
1277 InsertSource::Query(statement) => {
1278 output.push(' ');
1279 output.push_str(&statement_to_sql(statement));
1280 }
1281 }
1282 if !stmt.returning.is_empty() {
1283 let returning = stmt
1284 .returning
1285 .iter()
1286 .map(|item| item.expr.to_sql())
1287 .collect::<Vec<_>>()
1288 .join(", ");
1289 output.push_str(" returning ");
1290 output.push_str(&returning);
1291 }
1292 output
1293 }
1294 Statement::Update(stmt) => {
1295 let mut output = format!("update {} set ", stmt.table);
1296 let assignments = stmt
1297 .assignments
1298 .iter()
1299 .map(|assign| format!("{} = {}", assign.column, assign.value.to_sql()))
1300 .collect::<Vec<_>>()
1301 .join(", ");
1302 output.push_str(&assignments);
1303 if let Some(selection) = &stmt.selection {
1304 output.push_str(" where ");
1305 output.push_str(&selection.to_sql());
1306 }
1307 if !stmt.returning.is_empty() {
1308 let returning = stmt
1309 .returning
1310 .iter()
1311 .map(|item| item.expr.to_sql())
1312 .collect::<Vec<_>>()
1313 .join(", ");
1314 output.push_str(" returning ");
1315 output.push_str(&returning);
1316 }
1317 output
1318 }
1319 Statement::Delete(stmt) => {
1320 let mut output = format!("delete from {}", stmt.table);
1321 if let Some(selection) = &stmt.selection {
1322 output.push_str(" where ");
1323 output.push_str(&selection.to_sql());
1324 }
1325 if !stmt.returning.is_empty() {
1326 let returning = stmt
1327 .returning
1328 .iter()
1329 .map(|item| item.expr.to_sql())
1330 .collect::<Vec<_>>()
1331 .join(", ");
1332 output.push_str(" returning ");
1333 output.push_str(&returning);
1334 }
1335 output
1336 }
1337 }
1338}
1339
#[cfg(test)]
mod tests {
    use super::{BinaryOperator, Expr, Literal, UnaryOperator};

    // Small AST builders so each test reads as the expression it normalizes.

    /// Identifier expression named `name`.
    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    /// Boolean literal expression.
    fn boolean(value: bool) -> Expr {
        Expr::Literal(Literal::Bool(value))
    }

    /// Numeric literal expression.
    fn number(value: f64) -> Expr {
        Expr::Literal(Literal::Number(value))
    }

    /// Binary expression `left <op> right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Logical negation of `expr`.
    fn not(expr: Expr) -> Expr {
        Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: Box::new(expr),
        }
    }

    /// `b and a` is expected to come back with its operands ordered `a`, `b`.
    #[test]
    fn normalize_commutative_predicate() {
        let input = binary(ident("b"), BinaryOperator::And, ident("a"));
        match input.normalize() {
            Expr::BinaryOp { left, right, .. } => {
                assert!(matches!(left.as_ref(), Expr::Identifier(name) if name == "a"));
                assert!(matches!(right.as_ref(), Expr::Identifier(name) if name == "b"));
            }
            _ => panic!("expected binary op"),
        }
    }

    /// `not (a = b)` is expected to become a binary op using `NotEq`.
    #[test]
    fn normalize_not_comparison() {
        let input = not(binary(ident("a"), BinaryOperator::Eq, ident("b")));
        match input.normalize() {
            Expr::BinaryOp { op, .. } => assert!(matches!(op, BinaryOperator::NotEq)),
            _ => panic!("expected binary op"),
        }
    }

    /// `not not flag` is expected to collapse to a bare identifier.
    #[test]
    fn normalize_double_not() {
        let input = not(not(ident("flag")));
        let result = input.normalize();
        assert!(matches!(result, Expr::Identifier(_)));
    }

    /// `2 * 4` is expected to fold into the literal `8`.
    #[test]
    fn normalize_constant_fold() {
        let input = binary(number(2.0), BinaryOperator::Mul, number(4.0));
        match input.normalize() {
            Expr::Literal(Literal::Number(value)) => assert_eq!(value, 8.0),
            other => panic!("expected literal, got {other:?}"),
        }
    }

    /// `flag and true` is expected to reduce to a bare identifier.
    #[test]
    fn normalize_boolean_identities() {
        let input = binary(ident("flag"), BinaryOperator::And, boolean(true));
        let result = input.normalize();
        assert!(matches!(result, Expr::Identifier(_)));
    }

    /// `a or a` and `a and a` are both expected to reduce to `a`.
    #[test]
    fn normalize_duplicate_and_or() {
        for op in [BinaryOperator::Or, BinaryOperator::And] {
            let input = binary(ident("a"), op, ident("a"));
            let result = input.normalize();
            assert!(matches!(result, Expr::Identifier(name) if name == "a"));
        }
    }

    /// Comparing `a` with itself is expected to fold to a boolean literal.
    #[test]
    fn normalize_self_comparison() {
        let cases = [
            (BinaryOperator::Eq, true),
            (BinaryOperator::NotEq, false),
            (BinaryOperator::Lt, false),
            (BinaryOperator::LtEq, true),
        ];
        for (op, expected) in cases {
            let input = binary(ident("a"), op, ident("a"));
            let result = input.normalize();
            assert!(matches!(
                result,
                Expr::Literal(Literal::Bool(value)) if value == expected
            ));
        }
    }

    /// `not true` is expected to fold to the literal `false`.
    #[test]
    fn normalize_not_literal() {
        let input = not(boolean(true));
        let result = input.normalize();
        assert!(matches!(result, Expr::Literal(Literal::Bool(false))));
    }
}