1use crate::dialects::DialectType;
14use crate::expressions::{
15 BinaryOp, BooleanLiteral, Case, ConcatWs, DateTruncFunc, Expression, Literal, Null, Paren,
16 UnaryOp,
17};
18
19pub fn simplify(expression: Expression, dialect: Option<DialectType>) -> Expression {
21 let mut simplifier = Simplifier::new(dialect);
22 simplifier.simplify(expression)
23}
24
25pub fn always_true(expr: &Expression) -> bool {
27 match expr {
28 Expression::Boolean(b) => b.value,
29 Expression::Literal(Literal::Number(n)) => {
30 if let Ok(num) = n.parse::<f64>() {
32 num != 0.0
33 } else {
34 false
35 }
36 }
37 _ => false,
38 }
39}
40
41pub fn is_boolean_true(expr: &Expression) -> bool {
43 matches!(expr, Expression::Boolean(b) if b.value)
44}
45
46pub fn is_boolean_false(expr: &Expression) -> bool {
48 matches!(expr, Expression::Boolean(b) if !b.value)
49}
50
51pub fn always_false(expr: &Expression) -> bool {
53 is_false(expr) || is_null(expr) || is_zero(expr)
54}
55
56pub fn is_false(expr: &Expression) -> bool {
58 matches!(expr, Expression::Boolean(b) if !b.value)
59}
60
61pub fn is_null(expr: &Expression) -> bool {
63 matches!(expr, Expression::Null(_))
64}
65
66pub fn is_zero(expr: &Expression) -> bool {
68 match expr {
69 Expression::Literal(Literal::Number(n)) => {
70 if let Ok(num) = n.parse::<f64>() {
71 num == 0.0
72 } else {
73 false
74 }
75 }
76 _ => false,
77 }
78}
79
80pub fn is_complement(a: &Expression, b: &Expression) -> bool {
82 if let Expression::Not(not_op) = b {
83 ¬_op.this == a
84 } else {
85 false
86 }
87}
88
89pub fn bool_true() -> Expression {
91 Expression::Boolean(BooleanLiteral { value: true })
92}
93
94pub fn bool_false() -> Expression {
96 Expression::Boolean(BooleanLiteral { value: false })
97}
98
99pub fn null() -> Expression {
101 Expression::Null(Null)
102}
103
104pub fn eval_boolean_nums(op: &str, a: f64, b: f64) -> Option<Expression> {
106 let result = match op {
107 "=" | "==" => a == b,
108 "!=" | "<>" => a != b,
109 ">" => a > b,
110 ">=" => a >= b,
111 "<" => a < b,
112 "<=" => a <= b,
113 _ => return None,
114 };
115 Some(if result { bool_true() } else { bool_false() })
116}
117
118pub fn eval_boolean_strings(op: &str, a: &str, b: &str) -> Option<Expression> {
120 let result = match op {
121 "=" | "==" => a == b,
122 "!=" | "<>" => a != b,
123 ">" => a > b,
124 ">=" => a >= b,
125 "<" => a < b,
126 "<=" => a <= b,
127 _ => return None,
128 };
129 Some(if result { bool_true() } else { bool_false() })
130}
131
132pub struct Simplifier {
134 _dialect: Option<DialectType>,
135 max_iterations: usize,
136}
137
138impl Simplifier {
139 pub fn new(dialect: Option<DialectType>) -> Self {
141 Self {
142 _dialect: dialect,
143 max_iterations: 100,
144 }
145 }
146
147 pub fn simplify(&mut self, expression: Expression) -> Expression {
149 let mut current = expression;
151 for _ in 0..self.max_iterations {
152 let simplified = self.simplify_once(current.clone());
153 if expressions_equal(&simplified, ¤t) {
154 return simplified;
155 }
156 current = simplified;
157 }
158 current
159 }
160
161 fn simplify_once(&mut self, expression: Expression) -> Expression {
163 match expression {
164 Expression::And(op) => self.simplify_and(*op),
166 Expression::Or(op) => self.simplify_or(*op),
167
168 Expression::Not(op) => self.simplify_not(*op),
170
171 Expression::Add(op) => self.simplify_add(*op),
173 Expression::Sub(op) => self.simplify_sub(*op),
174 Expression::Mul(op) => self.simplify_mul(*op),
175 Expression::Div(op) => self.simplify_div(*op),
176
177 Expression::Eq(op) => self.simplify_comparison(*op, "="),
179 Expression::Neq(op) => self.simplify_comparison(*op, "!="),
180 Expression::Gt(op) => self.simplify_comparison(*op, ">"),
181 Expression::Gte(op) => self.simplify_comparison(*op, ">="),
182 Expression::Lt(op) => self.simplify_comparison(*op, "<"),
183 Expression::Lte(op) => self.simplify_comparison(*op, "<="),
184
185 Expression::Neg(op) => self.simplify_neg(*op),
187
188 Expression::Case(case) => self.simplify_case(*case),
190
191 Expression::Concat(op) => self.simplify_concat(*op),
193 Expression::ConcatWs(concat_ws) => self.simplify_concat_ws(*concat_ws),
194
195 Expression::Paren(paren) => self.simplify_paren(*paren),
197
198 Expression::DateTrunc(dt) => self.simplify_datetrunc(*dt),
200 Expression::TimestampTrunc(dt) => self.simplify_datetrunc(*dt),
201
202 other => self.simplify_children(other),
204 }
205 }
206
207 fn simplify_and(&mut self, op: BinaryOp) -> Expression {
209 let left = self.simplify_once(op.left);
210 let right = self.simplify_once(op.right);
211
212 if is_boolean_false(&left) || is_boolean_false(&right) {
215 return bool_false();
216 }
217
218 if is_zero(&left) || is_zero(&right) {
221 return bool_false();
222 }
223
224 if (is_null(&left) && is_null(&right))
228 || (is_null(&left) && is_boolean_true(&right))
229 || (is_boolean_true(&left) && is_null(&right))
230 {
231 return null();
232 }
233
234 if is_boolean_true(&left) {
236 return right;
237 }
238
239 if is_boolean_true(&right) {
241 return left;
242 }
243
244 if is_complement(&left, &right) || is_complement(&right, &left) {
246 return bool_false();
247 }
248
249 if expressions_equal(&left, &right) {
251 return left;
252 }
253
254 absorb_and_eliminate_and(left, right)
258 }
259
260 fn simplify_or(&mut self, op: BinaryOp) -> Expression {
262 let left = self.simplify_once(op.left);
263 let right = self.simplify_once(op.right);
264
265 if is_boolean_true(&left) {
267 return bool_true();
268 }
269
270 if is_boolean_true(&right) {
272 return bool_true();
273 }
274
275 if (is_null(&left) && is_null(&right))
279 || (is_null(&left) && is_boolean_false(&right))
280 || (is_boolean_false(&left) && is_null(&right))
281 {
282 return null();
283 }
284
285 if is_boolean_false(&left) {
287 return right;
288 }
289
290 if is_boolean_false(&right) {
292 return left;
293 }
294
295 if expressions_equal(&left, &right) {
297 return left;
298 }
299
300 absorb_and_eliminate_or(left, right)
304 }
305
306 fn simplify_not(&mut self, op: UnaryOp) -> Expression {
308 match &op.this {
311 Expression::Eq(inner_op) => {
313 let left = self.simplify_once(inner_op.left.clone());
314 let right = self.simplify_once(inner_op.right.clone());
315 return Expression::Neq(Box::new(BinaryOp::new(left, right)));
316 }
317 Expression::Neq(inner_op) => {
319 let left = self.simplify_once(inner_op.left.clone());
320 let right = self.simplify_once(inner_op.right.clone());
321 return Expression::Eq(Box::new(BinaryOp::new(left, right)));
322 }
323 Expression::Gt(inner_op) => {
325 let left = self.simplify_once(inner_op.left.clone());
326 let right = self.simplify_once(inner_op.right.clone());
327 return Expression::Lte(Box::new(BinaryOp::new(left, right)));
328 }
329 Expression::Gte(inner_op) => {
331 let left = self.simplify_once(inner_op.left.clone());
332 let right = self.simplify_once(inner_op.right.clone());
333 return Expression::Lt(Box::new(BinaryOp::new(left, right)));
334 }
335 Expression::Lt(inner_op) => {
337 let left = self.simplify_once(inner_op.left.clone());
338 let right = self.simplify_once(inner_op.right.clone());
339 return Expression::Gte(Box::new(BinaryOp::new(left, right)));
340 }
341 Expression::Lte(inner_op) => {
343 let left = self.simplify_once(inner_op.left.clone());
344 let right = self.simplify_once(inner_op.right.clone());
345 return Expression::Gt(Box::new(BinaryOp::new(left, right)));
346 }
347 _ => {}
348 }
349
350 let inner = self.simplify_once(op.this);
352
353 if is_null(&inner) {
355 return null();
356 }
357
358 if is_boolean_true(&inner) {
360 return bool_false();
361 }
362
363 if is_boolean_false(&inner) {
365 return bool_true();
366 }
367
368 if let Expression::Not(inner_not) = &inner {
370 return inner_not.this.clone();
371 }
372
373 Expression::Not(Box::new(UnaryOp {
374 this: inner,
375 inferred_type: None,
376 }))
377 }
378
379 fn simplify_add(&mut self, op: BinaryOp) -> Expression {
381 let left = self.simplify_once(op.left);
382 let right = self.simplify_once(op.right);
383
384 if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
386 return Expression::Literal(Literal::Number((a + b).to_string()));
387 }
388
389 if is_zero(&right) {
391 return left;
392 }
393
394 if is_zero(&left) {
396 return right;
397 }
398
399 Expression::Add(Box::new(BinaryOp::new(left, right)))
400 }
401
402 fn simplify_sub(&mut self, op: BinaryOp) -> Expression {
404 let left = self.simplify_once(op.left);
405 let right = self.simplify_once(op.right);
406
407 if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
409 return Expression::Literal(Literal::Number((a - b).to_string()));
410 }
411
412 if is_zero(&right) {
414 return left;
415 }
416
417 if expressions_equal(&left, &right) {
419 if let Expression::Literal(Literal::Number(_)) = &left {
420 return Expression::Literal(Literal::Number("0".to_string()));
421 }
422 }
423
424 Expression::Sub(Box::new(BinaryOp::new(left, right)))
425 }
426
427 fn simplify_mul(&mut self, op: BinaryOp) -> Expression {
429 let left = self.simplify_once(op.left);
430 let right = self.simplify_once(op.right);
431
432 if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
434 return Expression::Literal(Literal::Number((a * b).to_string()));
435 }
436
437 if is_zero(&right) {
439 return Expression::Literal(Literal::Number("0".to_string()));
440 }
441
442 if is_zero(&left) {
444 return Expression::Literal(Literal::Number("0".to_string()));
445 }
446
447 if is_one(&right) {
449 return left;
450 }
451
452 if is_one(&left) {
454 return right;
455 }
456
457 Expression::Mul(Box::new(BinaryOp::new(left, right)))
458 }
459
460 fn simplify_div(&mut self, op: BinaryOp) -> Expression {
462 let left = self.simplify_once(op.left);
463 let right = self.simplify_once(op.right);
464
465 if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
467 if b != 0.0 && (a.fract() != 0.0 || b.fract() != 0.0) {
469 return Expression::Literal(Literal::Number((a / b).to_string()));
470 }
471 }
472
473 if is_zero(&left) && !is_zero(&right) {
475 return Expression::Literal(Literal::Number("0".to_string()));
476 }
477
478 if is_one(&right) {
480 return left;
481 }
482
483 Expression::Div(Box::new(BinaryOp::new(left, right)))
484 }
485
486 fn simplify_neg(&mut self, op: UnaryOp) -> Expression {
488 let inner = self.simplify_once(op.this);
489
490 if let Expression::Neg(inner_neg) = inner {
492 return inner_neg.this;
493 }
494
495 if let Some(n) = get_number(&inner) {
497 return Expression::Literal(Literal::Number((-n).to_string()));
498 }
499
500 Expression::Neg(Box::new(UnaryOp {
501 this: inner,
502 inferred_type: None,
503 }))
504 }
505
506 fn simplify_comparison(&mut self, op: BinaryOp, operator: &str) -> Expression {
508 let left = self.simplify_once(op.left);
509 let right = self.simplify_once(op.right);
510
511 if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
513 if let Some(result) = eval_boolean_nums(operator, a, b) {
514 return result;
515 }
516 }
517
518 if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
520 if let Some(result) = eval_boolean_strings(operator, &a, &b) {
521 return result;
522 }
523 }
524
525 if operator == "=" {
527 if let Some(simplified) = self.simplify_equality(left.clone(), right.clone()) {
528 return simplified;
529 }
530 }
531
532 let new_op = BinaryOp::new(left, right);
534
535 match operator {
536 "=" => Expression::Eq(Box::new(new_op)),
537 "!=" | "<>" => Expression::Neq(Box::new(new_op)),
538 ">" => Expression::Gt(Box::new(new_op)),
539 ">=" => Expression::Gte(Box::new(new_op)),
540 "<" => Expression::Lt(Box::new(new_op)),
541 "<=" => Expression::Lte(Box::new(new_op)),
542 _ => Expression::Eq(Box::new(new_op)),
543 }
544 }
545
546 fn simplify_case(&mut self, case: Case) -> Expression {
548 let mut new_whens = Vec::new();
549
550 for (cond, then_expr) in case.whens {
551 let simplified_cond = self.simplify_once(cond);
552
553 if always_true(&simplified_cond) {
555 return self.simplify_once(then_expr);
556 }
557
558 if always_false(&simplified_cond) {
560 continue;
561 }
562
563 new_whens.push((simplified_cond, self.simplify_once(then_expr)));
564 }
565
566 if new_whens.is_empty() {
568 return case
569 .else_
570 .map(|e| self.simplify_once(e))
571 .unwrap_or_else(null);
572 }
573
574 Expression::Case(Box::new(Case {
575 operand: case.operand.map(|e| self.simplify_once(e)),
576 whens: new_whens,
577 else_: case.else_.map(|e| self.simplify_once(e)),
578 comments: Vec::new(),
579 inferred_type: None,
580 }))
581 }
582
583 fn simplify_concat(&mut self, op: BinaryOp) -> Expression {
591 let left = self.simplify_once(op.left);
592 let right = self.simplify_once(op.right);
593
594 if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
596 return Expression::Literal(Literal::String(format!("{}{}", a, b)));
597 }
598
599 if let Some(s) = get_string(&left) {
601 if s.is_empty() {
602 return right;
603 }
604 }
605
606 if let Some(s) = get_string(&right) {
608 if s.is_empty() {
609 return left;
610 }
611 }
612
613 if is_null(&left) || is_null(&right) {
615 return null();
616 }
617
618 Expression::Concat(Box::new(BinaryOp::new(left, right)))
619 }
620
621 fn simplify_concat_ws(&mut self, concat_ws: ConcatWs) -> Expression {
628 let separator = self.simplify_once(concat_ws.separator);
629
630 if is_null(&separator) {
632 return null();
633 }
634
635 let expressions: Vec<Expression> = concat_ws
636 .expressions
637 .into_iter()
638 .map(|e| self.simplify_once(e))
639 .filter(|e| !is_null(e)) .collect();
641
642 if expressions.is_empty() {
644 return Expression::Literal(Literal::String(String::new()));
645 }
646
647 if let Some(sep) = get_string(&separator) {
649 let all_strings: Option<Vec<String>> =
650 expressions.iter().map(|e| get_string(e)).collect();
651
652 if let Some(strings) = all_strings {
653 return Expression::Literal(Literal::String(strings.join(&sep)));
654 }
655 }
656
657 Expression::ConcatWs(Box::new(ConcatWs {
659 separator,
660 expressions,
661 }))
662 }
663
664 fn simplify_paren(&mut self, paren: Paren) -> Expression {
670 let inner = self.simplify_once(paren.this);
671
672 match &inner {
675 Expression::Literal(_)
676 | Expression::Boolean(_)
677 | Expression::Null(_)
678 | Expression::Column(_)
679 | Expression::Paren(_) => inner,
680 _ => Expression::Paren(Box::new(Paren {
682 this: inner,
683 trailing_comments: paren.trailing_comments,
684 })),
685 }
686 }
687
688 fn simplify_datetrunc(&mut self, dt: DateTruncFunc) -> Expression {
693 let inner = self.simplify_once(dt.this);
694
695 Expression::DateTrunc(Box::new(DateTruncFunc {
698 this: inner,
699 unit: dt.unit,
700 }))
701 }
702
703 fn simplify_equality(&mut self, left: Expression, right: Expression) -> Option<Expression> {
710 let right_val = get_number(&right)?;
712
713 match left {
715 Expression::Add(ref op) => {
716 if let Some(c) = get_number(&op.right) {
718 let new_right =
719 Expression::Literal(Literal::Number((right_val - c).to_string()));
720 return Some(Expression::Eq(Box::new(BinaryOp::new(
721 op.left.clone(),
722 new_right,
723 ))));
724 }
725 if let Some(c) = get_number(&op.left) {
727 let new_right =
728 Expression::Literal(Literal::Number((right_val - c).to_string()));
729 return Some(Expression::Eq(Box::new(BinaryOp::new(
730 op.right.clone(),
731 new_right,
732 ))));
733 }
734 }
735 Expression::Sub(ref op) => {
736 if let Some(c) = get_number(&op.right) {
738 let new_right =
739 Expression::Literal(Literal::Number((right_val + c).to_string()));
740 return Some(Expression::Eq(Box::new(BinaryOp::new(
741 op.left.clone(),
742 new_right,
743 ))));
744 }
745 if let Some(c) = get_number(&op.left) {
747 let new_right =
748 Expression::Literal(Literal::Number((c - right_val).to_string()));
749 return Some(Expression::Eq(Box::new(BinaryOp::new(
750 op.right.clone(),
751 new_right,
752 ))));
753 }
754 }
755 Expression::Mul(ref op) => {
756 if let Some(c) = get_number(&op.right) {
758 if c != 0.0 && right_val % c == 0.0 {
759 let new_right =
760 Expression::Literal(Literal::Number((right_val / c).to_string()));
761 return Some(Expression::Eq(Box::new(BinaryOp::new(
762 op.left.clone(),
763 new_right,
764 ))));
765 }
766 }
767 if let Some(c) = get_number(&op.left) {
769 if c != 0.0 && right_val % c == 0.0 {
770 let new_right =
771 Expression::Literal(Literal::Number((right_val / c).to_string()));
772 return Some(Expression::Eq(Box::new(BinaryOp::new(
773 op.right.clone(),
774 new_right,
775 ))));
776 }
777 }
778 }
779 _ => {}
780 }
781
782 None
783 }
784
785 fn simplify_children(&mut self, expr: Expression) -> Expression {
787 match expr {
790 Expression::Alias(mut alias) => {
791 alias.this = self.simplify_once(alias.this);
792 Expression::Alias(alias)
793 }
794 Expression::Between(mut between) => {
795 between.this = self.simplify_once(between.this);
796 between.low = self.simplify_once(between.low);
797 between.high = self.simplify_once(between.high);
798 Expression::Between(between)
799 }
800 Expression::In(mut in_expr) => {
801 in_expr.this = self.simplify_once(in_expr.this);
802 in_expr.expressions = in_expr
803 .expressions
804 .into_iter()
805 .map(|e| self.simplify_once(e))
806 .collect();
807 Expression::In(in_expr)
808 }
809 Expression::Function(mut func) => {
810 func.args = func
811 .args
812 .into_iter()
813 .map(|e| self.simplify_once(e))
814 .collect();
815 Expression::Function(func)
816 }
817 other => other,
819 }
820 }
821}
822
823fn is_one(expr: &Expression) -> bool {
825 match expr {
826 Expression::Literal(Literal::Number(n)) => {
827 if let Ok(num) = n.parse::<f64>() {
828 num == 1.0
829 } else {
830 false
831 }
832 }
833 _ => false,
834 }
835}
836
837fn get_number(expr: &Expression) -> Option<f64> {
839 match expr {
840 Expression::Literal(Literal::Number(n)) => n.parse().ok(),
841 _ => None,
842 }
843}
844
845fn get_string(expr: &Expression) -> Option<String> {
847 match expr {
848 Expression::Literal(Literal::String(s)) => Some(s.clone()),
849 _ => None,
850 }
851}
852
853fn expressions_equal(a: &Expression, b: &Expression) -> bool {
856 format!("{:?}", a) == format!("{:?}", b)
859}
860
861fn flatten_and(expr: &Expression) -> Vec<Expression> {
864 match expr {
865 Expression::And(op) => {
866 let mut result = flatten_and(&op.left);
867 result.extend(flatten_and(&op.right));
868 result
869 }
870 other => vec![other.clone()],
871 }
872}
873
874fn flatten_or(expr: &Expression) -> Vec<Expression> {
877 match expr {
878 Expression::Or(op) => {
879 let mut result = flatten_or(&op.left);
880 result.extend(flatten_or(&op.right));
881 result
882 }
883 other => vec![other.clone()],
884 }
885}
886
887fn rebuild_and(operands: Vec<Expression>) -> Expression {
889 if operands.is_empty() {
890 return bool_true(); }
892 let mut result = operands.into_iter();
893 let first = result.next().unwrap();
894 result.fold(first, |acc, op| {
895 Expression::And(Box::new(BinaryOp::new(acc, op)))
896 })
897}
898
899fn rebuild_or(operands: Vec<Expression>) -> Expression {
901 if operands.is_empty() {
902 return bool_false(); }
904 let mut result = operands.into_iter();
905 let first = result.next().unwrap();
906 result.fold(first, |acc, op| {
907 Expression::Or(Box::new(BinaryOp::new(acc, op)))
908 })
909}
910
911fn get_not_inner(expr: &Expression) -> Option<&Expression> {
913 match expr {
914 Expression::Not(op) => Some(&op.this),
915 _ => None,
916 }
917}
918
919pub fn absorb_and_eliminate_and(left: Expression, right: Expression) -> Expression {
928 let left_ops = flatten_and(&left);
930 let right_ops = flatten_and(&right);
931 let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
932
933 let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
935
936 let mut result_ops: Vec<Expression> = Vec::new();
937 let mut absorbed = std::collections::HashSet::new();
938
939 for (i, op) in all_ops.iter().enumerate() {
940 let op_str = gen(op);
941
942 if absorbed.contains(&op_str) {
944 continue;
945 }
946
947 if let Expression::Or(_) = op {
949 let or_operands = flatten_or(op);
950
951 let absorbed_by_existing = or_operands.iter().any(|or_op| {
954 let or_op_str = gen(or_op);
955 all_ops
957 .iter()
958 .enumerate()
959 .any(|(j, other)| i != j && gen(other) == or_op_str)
960 });
961
962 if absorbed_by_existing {
963 absorbed.insert(op_str);
965 continue;
966 }
967
968 let mut remaining_or_ops: Vec<Expression> = Vec::new();
971 let mut had_complement_absorption = false;
972
973 for or_op in or_operands {
974 let complement_str = if let Some(inner) = get_not_inner(&or_op) {
975 gen(inner)
977 } else {
978 format!("NOT {}", gen(&or_op))
980 };
981
982 let has_complement = all_ops
984 .iter()
985 .enumerate()
986 .any(|(j, other)| i != j && gen(other) == complement_str)
987 || op_strings.contains(&complement_str);
988
989 if has_complement {
990 had_complement_absorption = true;
994 } else {
996 remaining_or_ops.push(or_op);
997 }
998 }
999
1000 if had_complement_absorption {
1001 if remaining_or_ops.is_empty() {
1002 absorbed.insert(op_str);
1005 continue;
1006 } else if remaining_or_ops.len() == 1 {
1007 result_ops.push(remaining_or_ops.into_iter().next().unwrap());
1009 absorbed.insert(op_str);
1010 continue;
1011 } else {
1012 result_ops.push(rebuild_or(remaining_or_ops));
1014 absorbed.insert(op_str);
1015 continue;
1016 }
1017 }
1018 }
1019
1020 result_ops.push(op.clone());
1021 }
1022
1023 let mut seen = std::collections::HashSet::new();
1025 result_ops.retain(|op| seen.insert(gen(op)));
1026
1027 if result_ops.is_empty() {
1028 bool_true()
1029 } else {
1030 rebuild_and(result_ops)
1031 }
1032}
1033
1034pub fn absorb_and_eliminate_or(left: Expression, right: Expression) -> Expression {
1043 let left_ops = flatten_or(&left);
1045 let right_ops = flatten_or(&right);
1046 let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
1047
1048 let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
1050
1051 let mut result_ops: Vec<Expression> = Vec::new();
1052 let mut absorbed = std::collections::HashSet::new();
1053
1054 for (i, op) in all_ops.iter().enumerate() {
1055 let op_str = gen(op);
1056
1057 if absorbed.contains(&op_str) {
1059 continue;
1060 }
1061
1062 if let Expression::And(_) = op {
1064 let and_operands = flatten_and(op);
1065
1066 let absorbed_by_existing = and_operands.iter().any(|and_op| {
1069 let and_op_str = gen(and_op);
1070 all_ops
1072 .iter()
1073 .enumerate()
1074 .any(|(j, other)| i != j && gen(other) == and_op_str)
1075 });
1076
1077 if absorbed_by_existing {
1078 absorbed.insert(op_str);
1080 continue;
1081 }
1082
1083 let mut remaining_and_ops: Vec<Expression> = Vec::new();
1086 let mut had_complement_absorption = false;
1087
1088 for and_op in and_operands {
1089 let complement_str = if let Some(inner) = get_not_inner(&and_op) {
1090 gen(inner)
1092 } else {
1093 format!("NOT {}", gen(&and_op))
1095 };
1096
1097 let has_complement = all_ops
1099 .iter()
1100 .enumerate()
1101 .any(|(j, other)| i != j && gen(other) == complement_str)
1102 || op_strings.contains(&complement_str);
1103
1104 if has_complement {
1105 had_complement_absorption = true;
1108 } else {
1110 remaining_and_ops.push(and_op);
1111 }
1112 }
1113
1114 if had_complement_absorption {
1115 if remaining_and_ops.is_empty() {
1116 absorbed.insert(op_str);
1119 continue;
1120 } else if remaining_and_ops.len() == 1 {
1121 result_ops.push(remaining_and_ops.into_iter().next().unwrap());
1123 absorbed.insert(op_str);
1124 continue;
1125 } else {
1126 result_ops.push(rebuild_and(remaining_and_ops));
1128 absorbed.insert(op_str);
1129 continue;
1130 }
1131 }
1132 }
1133
1134 result_ops.push(op.clone());
1135 }
1136
1137 let mut seen = std::collections::HashSet::new();
1139 result_ops.retain(|op| seen.insert(gen(op)));
1140
1141 if result_ops.is_empty() {
1142 bool_false()
1143 } else {
1144 rebuild_or(result_ops)
1145 }
1146}
1147
1148pub fn gen(expr: &Expression) -> String {
1150 match expr {
1151 Expression::Literal(lit) => match lit {
1152 Literal::String(s) => format!("'{}'", s),
1153 Literal::Number(n) => n.clone(),
1154 _ => format!("{:?}", lit),
1155 },
1156 Expression::Boolean(b) => if b.value { "TRUE" } else { "FALSE" }.to_string(),
1157 Expression::Null(_) => "NULL".to_string(),
1158 Expression::Column(col) => {
1159 if let Some(ref table) = col.table {
1160 format!("{}.{}", table.name, col.name.name)
1161 } else {
1162 col.name.name.clone()
1163 }
1164 }
1165 Expression::And(op) => format!("({} AND {})", gen(&op.left), gen(&op.right)),
1166 Expression::Or(op) => format!("({} OR {})", gen(&op.left), gen(&op.right)),
1167 Expression::Not(op) => format!("NOT {}", gen(&op.this)),
1168 Expression::Eq(op) => format!("{} = {}", gen(&op.left), gen(&op.right)),
1169 Expression::Neq(op) => format!("{} <> {}", gen(&op.left), gen(&op.right)),
1170 Expression::Gt(op) => format!("{} > {}", gen(&op.left), gen(&op.right)),
1171 Expression::Gte(op) => format!("{} >= {}", gen(&op.left), gen(&op.right)),
1172 Expression::Lt(op) => format!("{} < {}", gen(&op.left), gen(&op.right)),
1173 Expression::Lte(op) => format!("{} <= {}", gen(&op.left), gen(&op.right)),
1174 Expression::Add(op) => format!("{} + {}", gen(&op.left), gen(&op.right)),
1175 Expression::Sub(op) => format!("{} - {}", gen(&op.left), gen(&op.right)),
1176 Expression::Mul(op) => format!("{} * {}", gen(&op.left), gen(&op.right)),
1177 Expression::Div(op) => format!("{} / {}", gen(&op.left), gen(&op.right)),
1178 Expression::Function(f) => {
1179 let args: Vec<String> = f.args.iter().map(|a| gen(a)).collect();
1180 format!("{}({})", f.name.to_uppercase(), args.join(", "))
1181 }
1182 _ => format!("{:?}", expr),
1183 }
1184}
1185
1186#[cfg(test)]
1187mod tests {
1188 use super::*;
1189
1190 fn make_int(val: i64) -> Expression {
1191 Expression::Literal(Literal::Number(val.to_string()))
1192 }
1193
1194 fn make_string(val: &str) -> Expression {
1195 Expression::Literal(Literal::String(val.to_string()))
1196 }
1197
1198 fn make_bool(val: bool) -> Expression {
1199 Expression::Boolean(BooleanLiteral { value: val })
1200 }
1201
1202 fn make_column(name: &str) -> Expression {
1203 use crate::expressions::{Column, Identifier};
1204 Expression::Column(Column {
1205 name: Identifier::new(name),
1206 table: None,
1207 join_mark: false,
1208 trailing_comments: vec![],
1209 span: None,
1210 inferred_type: None,
1211 })
1212 }
1213
1214 #[test]
1215 fn test_always_true_false() {
1216 assert!(always_true(&make_bool(true)));
1217 assert!(!always_true(&make_bool(false)));
1218 assert!(always_true(&make_int(1)));
1219 assert!(!always_true(&make_int(0)));
1220
1221 assert!(always_false(&make_bool(false)));
1222 assert!(!always_false(&make_bool(true)));
1223 assert!(always_false(&null()));
1224 assert!(always_false(&make_int(0)));
1225 }
1226
1227 #[test]
1228 fn test_simplify_and_with_true() {
1229 let mut simplifier = Simplifier::new(None);
1230
1231 let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(true))));
1233 let result = simplifier.simplify(expr);
1234 assert!(always_true(&result));
1235
1236 let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(false))));
1238 let result = simplifier.simplify(expr);
1239 assert!(always_false(&result));
1240
1241 let x = make_int(42);
1243 let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), x.clone())));
1244 let result = simplifier.simplify(expr);
1245 assert_eq!(format!("{:?}", result), format!("{:?}", x));
1246 }
1247
1248 #[test]
1249 fn test_simplify_or_with_false() {
1250 let mut simplifier = Simplifier::new(None);
1251
1252 let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), make_bool(false))));
1254 let result = simplifier.simplify(expr);
1255 assert!(always_false(&result));
1256
1257 let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), make_bool(true))));
1259 let result = simplifier.simplify(expr);
1260 assert!(always_true(&result));
1261
1262 let x = make_int(42);
1264 let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), x.clone())));
1265 let result = simplifier.simplify(expr);
1266 assert_eq!(format!("{:?}", result), format!("{:?}", x));
1267 }
1268
1269 #[test]
1270 fn test_simplify_not() {
1271 let mut simplifier = Simplifier::new(None);
1272
1273 let expr = Expression::Not(Box::new(UnaryOp::new(make_bool(true))));
1275 let result = simplifier.simplify(expr);
1276 assert!(is_false(&result));
1277
1278 let expr = Expression::Not(Box::new(UnaryOp::new(make_bool(false))));
1280 let result = simplifier.simplify(expr);
1281 assert!(always_true(&result));
1282
1283 let x = make_int(42);
1285 let inner_not = Expression::Not(Box::new(UnaryOp::new(x.clone())));
1286 let expr = Expression::Not(Box::new(UnaryOp::new(inner_not)));
1287 let result = simplifier.simplify(expr);
1288 assert_eq!(format!("{:?}", result), format!("{:?}", x));
1289 }
1290
1291 #[test]
1292 fn test_simplify_demorgan_comparison() {
1293 let mut simplifier = Simplifier::new(None);
1294
1295 let a = make_column("a");
1297 let b = make_column("b");
1298 let eq = Expression::Eq(Box::new(BinaryOp::new(a.clone(), b.clone())));
1299 let expr = Expression::Not(Box::new(UnaryOp::new(eq)));
1300 let result = simplifier.simplify(expr);
1301 assert!(matches!(result, Expression::Neq(_)));
1302
1303 let gt = Expression::Gt(Box::new(BinaryOp::new(a, b)));
1305 let expr = Expression::Not(Box::new(UnaryOp::new(gt)));
1306 let result = simplifier.simplify(expr);
1307 assert!(matches!(result, Expression::Lte(_)));
1308 }
1309
1310 #[test]
1311 fn test_constant_folding_add() {
1312 let mut simplifier = Simplifier::new(None);
1313
1314 let expr = Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))));
1316 let result = simplifier.simplify(expr);
1317 assert_eq!(get_number(&result), Some(3.0));
1318
1319 let x = make_int(42);
1321 let expr = Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(0))));
1322 let result = simplifier.simplify(expr);
1323 assert_eq!(format!("{:?}", result), format!("{:?}", x));
1324 }
1325
1326 #[test]
1327 fn test_constant_folding_mul() {
1328 let mut simplifier = Simplifier::new(None);
1329
1330 let expr = Expression::Mul(Box::new(BinaryOp::new(make_int(3), make_int(4))));
1332 let result = simplifier.simplify(expr);
1333 assert_eq!(get_number(&result), Some(12.0));
1334
1335 let x = make_int(42);
1337 let expr = Expression::Mul(Box::new(BinaryOp::new(x, make_int(0))));
1338 let result = simplifier.simplify(expr);
1339 assert_eq!(get_number(&result), Some(0.0));
1340
1341 let x = make_int(42);
1343 let expr = Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(1))));
1344 let result = simplifier.simplify(expr);
1345 assert_eq!(format!("{:?}", result), format!("{:?}", x));
1346 }
1347
1348 #[test]
1349 fn test_constant_folding_comparison() {
1350 let mut simplifier = Simplifier::new(None);
1351
1352 let expr = Expression::Eq(Box::new(BinaryOp::new(make_int(1), make_int(1))));
1354 let result = simplifier.simplify(expr);
1355 assert!(always_true(&result));
1356
1357 let expr = Expression::Eq(Box::new(BinaryOp::new(make_int(1), make_int(2))));
1359 let result = simplifier.simplify(expr);
1360 assert!(is_false(&result));
1361
1362 let expr = Expression::Gt(Box::new(BinaryOp::new(make_int(3), make_int(2))));
1364 let result = simplifier.simplify(expr);
1365 assert!(always_true(&result));
1366
1367 let expr = Expression::Eq(Box::new(BinaryOp::new(
1369 make_string("abc"),
1370 make_string("abc"),
1371 )));
1372 let result = simplifier.simplify(expr);
1373 assert!(always_true(&result));
1374 }
1375
#[test]
fn test_simplify_negation() {
    let mut s = Simplifier::new(None);

    // Double negation cancels out: -(-5) -> 5.
    let negated_once = Expression::Neg(Box::new(UnaryOp::new(make_int(5))));
    let negated_twice = Expression::Neg(Box::new(UnaryOp::new(negated_once)));
    assert_eq!(get_number(&s.simplify(negated_twice)), Some(5.0));

    // Negating a literal folds into a negative literal: -(3) -> -3.
    let negated_literal = Expression::Neg(Box::new(UnaryOp::new(make_int(3))));
    assert_eq!(get_number(&s.simplify(negated_literal)), Some(-3.0));
}
1391
#[test]
fn test_gen_simple() {
    // Each leaf expression paired with the SQL text `gen` must emit for it.
    let cases = [
        (make_int(42), "42"),
        (make_string("hello"), "'hello'"),
        (make_bool(true), "TRUE"),
        (make_bool(false), "FALSE"),
        (null(), "NULL"),
    ];
    for (expr, expected) in &cases {
        assert_eq!(gen(expr), *expected);
    }
}
1400
#[test]
fn test_gen_operations() {
    // Arithmetic is rendered infix without surrounding parentheses.
    assert_eq!(
        gen(&Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))))),
        "1 + 2"
    );

    // A conjunction is rendered wrapped in parentheses.
    assert_eq!(
        gen(&Expression::And(Box::new(BinaryOp::new(
            make_bool(true),
            make_bool(false),
        )))),
        "(TRUE AND FALSE)"
    );
}
1409
#[test]
fn test_complement_elimination() {
    let mut s = Simplifier::new(None);

    // `x AND NOT x` with x = 42 should simplify to FALSE.
    let operand = make_int(42);
    let negation = Expression::Not(Box::new(UnaryOp::new(operand.clone())));
    let contradiction = Expression::And(Box::new(BinaryOp::new(operand, negation)));
    assert!(is_false(&s.simplify(contradiction)));
}
1421
#[test]
fn test_idempotent() {
    let mut s = Simplifier::new(None);

    // Idempotence of AND: `x AND x` collapses to `x`.
    let operand = make_int(42);
    let conj = Expression::And(Box::new(BinaryOp::new(operand.clone(), operand.clone())));
    assert_eq!(format!("{:?}", s.simplify(conj)), format!("{:?}", operand));

    // Idempotence of OR: `x OR x` collapses to `x`.
    let operand = make_int(42);
    let disj = Expression::Or(Box::new(BinaryOp::new(operand.clone(), operand.clone())));
    assert_eq!(format!("{:?}", s.simplify(disj)), format!("{:?}", operand));
}
1438
#[test]
fn test_absorption_and() {
    let mut s = Simplifier::new(None);

    // Absorption law: `a AND (a OR b)` reduces to `a`.
    let a = make_column("a");
    let b = make_column("b");
    let disjunction = Expression::Or(Box::new(BinaryOp::new(a.clone(), b.clone())));
    let absorbed = s.simplify(Expression::And(Box::new(BinaryOp::new(
        a.clone(),
        disjunction,
    ))));

    // Compared through the SQL generator so only the rendered text has to match.
    assert_eq!(gen(&absorbed), gen(&a));
}
1452
#[test]
fn test_absorption_or() {
    let mut s = Simplifier::new(None);

    // Absorption law: `a OR (a AND b)` reduces to `a`.
    let a = make_column("a");
    let b = make_column("b");
    let conjunction = Expression::And(Box::new(BinaryOp::new(a.clone(), b.clone())));
    let absorbed = s.simplify(Expression::Or(Box::new(BinaryOp::new(
        a.clone(),
        conjunction,
    ))));

    // Compared through the SQL generator so only the rendered text has to match.
    assert_eq!(gen(&absorbed), gen(&a));
}
1466
#[test]
fn test_absorption_with_complement_and() {
    let mut s = Simplifier::new(None);

    // `a AND (NOT a OR b)` drops the complemented term and becomes `a AND b`.
    let a = make_column("a");
    let b = make_column("b");
    let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
    let rhs = Expression::Or(Box::new(BinaryOp::new(not_a, b.clone())));
    let simplified = s.simplify(Expression::And(Box::new(BinaryOp::new(a.clone(), rhs))));

    // Build the expected shape and compare the generated SQL text.
    let want = Expression::And(Box::new(BinaryOp::new(a, b)));
    assert_eq!(gen(&simplified), gen(&want));
}
1482
#[test]
fn test_absorption_with_complement_or() {
    let mut s = Simplifier::new(None);

    // `a OR (NOT a AND b)` drops the complemented term and becomes `a OR b`.
    let a = make_column("a");
    let b = make_column("b");
    let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
    let rhs = Expression::And(Box::new(BinaryOp::new(not_a, b.clone())));
    let simplified = s.simplify(Expression::Or(Box::new(BinaryOp::new(a.clone(), rhs))));

    // Build the expected shape and compare the generated SQL text.
    let want = Expression::Or(Box::new(BinaryOp::new(a, b)));
    assert_eq!(gen(&simplified), gen(&want));
}
1498
#[test]
fn test_flatten_and() {
    // `a AND (b AND c)` flattens into its three conjuncts, left to right.
    let a = make_column("a");
    let b = make_column("b");
    let c = make_column("c");
    let inner = Expression::And(Box::new(BinaryOp::new(b.clone(), c.clone())));
    let nested = Expression::And(Box::new(BinaryOp::new(a.clone(), inner)));

    let parts = flatten_and(&nested);
    assert_eq!(parts.len(), 3);
    // The length is checked first, so zip cannot silently truncate here.
    for (part, name) in parts.iter().zip(["a", "b", "c"]) {
        assert_eq!(gen(part), name);
    }
}
1513
#[test]
fn test_flatten_or() {
    // `a OR (b OR c)` flattens into its three disjuncts, left to right.
    let a = make_column("a");
    let b = make_column("b");
    let c = make_column("c");
    let inner = Expression::Or(Box::new(BinaryOp::new(b.clone(), c.clone())));
    let nested = Expression::Or(Box::new(BinaryOp::new(a.clone(), inner)));

    let parts = flatten_or(&nested);
    assert_eq!(parts.len(), 3);
    // The length is checked first, so zip cannot silently truncate here.
    for (part, name) in parts.iter().zip(["a", "b", "c"]) {
        assert_eq!(gen(part), name);
    }
}
1528
#[test]
fn test_simplify_concat() {
    let mut s = Simplifier::new(None);
    // Small builder so each case reads as `concat(left, right)`.
    let concat =
        |left: Expression, right: Expression| Expression::Concat(Box::new(BinaryOp::new(left, right)));

    // Two string literals fold into one: 'hello' || 'world' -> 'helloworld'.
    let joined = s.simplify(concat(make_string("hello"), make_string("world")));
    assert_eq!(get_string(&joined), Some("helloworld".to_string()));

    // An empty string on either side is dropped.
    let text = make_string("test");
    let empty_left = s.simplify(concat(make_string(""), text.clone()));
    assert_eq!(get_string(&empty_left), Some("test".to_string()));
    let empty_right = s.simplify(concat(text, make_string("")));
    assert_eq!(get_string(&empty_right), Some("test".to_string()));

    // A NULL left operand makes the concatenation NULL.
    let with_null = s.simplify(concat(null(), make_string("test")));
    assert!(is_null(&with_null));
}
1557
#[test]
fn test_simplify_concat_ws() {
    let mut s = Simplifier::new(None);
    // Small builder for CONCAT_WS(separator, expressions...).
    let concat_ws = |separator: Expression, expressions: Vec<Expression>| {
        Expression::ConcatWs(Box::new(ConcatWs {
            separator,
            expressions,
        }))
    };

    // All-literal arguments fold into one separator-joined string.
    let joined = s.simplify(concat_ws(
        make_string(","),
        vec![make_string("a"), make_string("b"), make_string("c")],
    ));
    assert_eq!(get_string(&joined), Some("a,b,c".to_string()));

    // A NULL separator makes the result NULL.
    let null_separator = s.simplify(concat_ws(null(), vec![make_string("a"), make_string("b")]));
    assert!(is_null(&null_separator));

    // With no arguments the result is the empty string.
    let no_args = s.simplify(concat_ws(make_string(","), vec![]));
    assert_eq!(get_string(&no_args), Some("".to_string()));

    // NULL arguments are skipped rather than making the result NULL.
    let skips_null = s.simplify(concat_ws(
        make_string("-"),
        vec![make_string("a"), null(), make_string("b")],
    ));
    assert_eq!(get_string(&skips_null), Some("a-b".to_string()));
}
1594
#[test]
fn test_simplify_paren() {
    let mut s = Simplifier::new(None);
    // Wraps an expression in parentheses with no trailing comments.
    let parens = |this: Expression| {
        Expression::Paren(Box::new(Paren {
            this,
            trailing_comments: vec![],
        }))
    };

    // Parentheses around a literal are removed, leaving the literal itself.
    assert_eq!(get_number(&s.simplify(parens(make_int(42)))), Some(42.0));
    assert!(is_boolean_true(&s.simplify(parens(make_bool(true)))));
    assert!(is_null(&s.simplify(parens(null()))));

    // Nested parentheses are unwrapped completely: ((10)) -> 10.
    assert_eq!(
        get_number(&s.simplify(parens(parens(make_int(10))))),
        Some(10.0)
    );
}
1635
#[test]
fn test_simplify_equality_solve() {
    let mut simplifier = Simplifier::new(None);

    // Simplifies `lhs = rhs` and checks that it was rewritten to `x = expected`.
    let mut assert_solved = |lhs: Expression, rhs: Expression, expected: f64| {
        let solved = simplifier.simplify(Expression::Eq(Box::new(BinaryOp::new(lhs, rhs))));
        match &solved {
            Expression::Eq(op) => {
                assert_eq!(gen(&op.left), "x");
                assert_eq!(get_number(&op.right), Some(expected));
            }
            _ => panic!("Expected Eq expression"),
        }
    };

    let x = make_column("x");

    // x + 1 = 3  ->  x = 2
    assert_solved(
        Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(1)))),
        make_int(3),
        2.0,
    );
    // x - 1 = 3  ->  x = 4
    assert_solved(
        Expression::Sub(Box::new(BinaryOp::new(x.clone(), make_int(1)))),
        make_int(3),
        4.0,
    );
    // x * 2 = 6  ->  x = 3
    assert_solved(
        Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(2)))),
        make_int(6),
        3.0,
    );
    // 1 + x = 3  ->  x = 2 (constant on the left of the addition)
    assert_solved(
        Expression::Add(Box::new(BinaryOp::new(make_int(1), x.clone()))),
        make_int(3),
        2.0,
    );
}
1686
#[test]
fn test_simplify_datetrunc() {
    use crate::expressions::DateTimeField;
    let mut s = Simplifier::new(None);

    // DATE_TRUNC(x, DAY) should come back as a DateTrunc with the same operand and unit.
    let truncated = Expression::DateTrunc(Box::new(DateTruncFunc {
        this: make_column("x"),
        unit: DateTimeField::Day,
    }));
    match s.simplify(truncated) {
        Expression::DateTrunc(func) => {
            assert_eq!(gen(&func.this), "x");
            assert_eq!(func.unit, DateTimeField::Day);
        }
        _ => panic!("Expected DateTrunc expression"),
    }
}
1706}