Skip to main content

polyglot_sql/optimizer/
simplify.rs

1//! Expression Simplification
2//!
3//! This module provides boolean and expression simplification for SQL AST nodes.
4//! It applies various algebraic transformations to simplify expressions:
//! - Negation of comparisons (NOT (a = b) -> a != b) and double-negation removal (NOT NOT x -> x)
6//! - Constant folding (1 + 2 -> 3)
7//! - Boolean absorption (A AND (A OR B) -> A)
8//! - Complement removal (A AND NOT A -> FALSE)
9//! - Connector flattening (A AND (B AND C) -> A AND B AND C)
10//!
11//! Based on SQLGlot's optimizer/simplify.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{
15    BinaryOp, BooleanLiteral, Case, ConcatWs, DateTruncFunc, Expression, Literal, Null, Paren,
16    UnaryOp,
17};
18
19/// Main entry point for expression simplification
20pub fn simplify(expression: Expression, dialect: Option<DialectType>) -> Expression {
21    let mut simplifier = Simplifier::new(dialect);
22    simplifier.simplify(expression)
23}
24
25/// Check if expression is always true
26pub fn always_true(expr: &Expression) -> bool {
27    match expr {
28        Expression::Boolean(b) => b.value,
29        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
30            let Literal::Number(n) = lit.as_ref() else {
31                unreachable!()
32            };
33            // Non-zero numbers are truthy
34            if let Ok(num) = n.parse::<f64>() {
35                num != 0.0
36            } else {
37                false
38            }
39        }
40        _ => false,
41    }
42}
43
44/// Check if expression is a boolean TRUE literal (not just truthy)
45pub fn is_boolean_true(expr: &Expression) -> bool {
46    matches!(expr, Expression::Boolean(b) if b.value)
47}
48
49/// Check if expression is a boolean FALSE literal (not just falsy)
50pub fn is_boolean_false(expr: &Expression) -> bool {
51    matches!(expr, Expression::Boolean(b) if !b.value)
52}
53
54/// Check if expression is always false
55pub fn always_false(expr: &Expression) -> bool {
56    is_false(expr) || is_null(expr) || is_zero(expr)
57}
58
59/// Check if expression is boolean FALSE
60pub fn is_false(expr: &Expression) -> bool {
61    matches!(expr, Expression::Boolean(b) if !b.value)
62}
63
64/// Check if expression is NULL
65pub fn is_null(expr: &Expression) -> bool {
66    matches!(expr, Expression::Null(_))
67}
68
69/// Check if expression is zero
70pub fn is_zero(expr: &Expression) -> bool {
71    match expr {
72        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
73            let Literal::Number(n) = lit.as_ref() else {
74                unreachable!()
75            };
76            if let Ok(num) = n.parse::<f64>() {
77                num == 0.0
78            } else {
79                false
80            }
81        }
82        _ => false,
83    }
84}
85
86/// Check if b is the complement of a (i.e., b = NOT a)
87pub fn is_complement(a: &Expression, b: &Expression) -> bool {
88    if let Expression::Not(not_op) = b {
89        &not_op.this == a
90    } else {
91        false
92    }
93}
94
95/// Create a TRUE boolean literal
96pub fn bool_true() -> Expression {
97    Expression::Boolean(BooleanLiteral { value: true })
98}
99
100/// Create a FALSE boolean literal
101pub fn bool_false() -> Expression {
102    Expression::Boolean(BooleanLiteral { value: false })
103}
104
105/// Create a NULL expression
106pub fn null() -> Expression {
107    Expression::Null(Null)
108}
109
110/// Evaluate a boolean comparison between two numbers
111pub fn eval_boolean_nums(op: &str, a: f64, b: f64) -> Option<Expression> {
112    let result = match op {
113        "=" | "==" => a == b,
114        "!=" | "<>" => a != b,
115        ">" => a > b,
116        ">=" => a >= b,
117        "<" => a < b,
118        "<=" => a <= b,
119        _ => return None,
120    };
121    Some(if result { bool_true() } else { bool_false() })
122}
123
124/// Evaluate a boolean comparison between two strings
125pub fn eval_boolean_strings(op: &str, a: &str, b: &str) -> Option<Expression> {
126    let result = match op {
127        "=" | "==" => a == b,
128        "!=" | "<>" => a != b,
129        ">" => a > b,
130        ">=" => a >= b,
131        "<" => a < b,
132        "<=" => a <= b,
133        _ => return None,
134    };
135    Some(if result { bool_true() } else { bool_false() })
136}
137
138/// Expression simplifier
139pub struct Simplifier {
140    _dialect: Option<DialectType>,
141    max_iterations: usize,
142}
143
144impl Simplifier {
145    /// Create a new simplifier
146    pub fn new(dialect: Option<DialectType>) -> Self {
147        Self {
148            _dialect: dialect,
149            max_iterations: 100,
150        }
151    }
152
153    /// Simplify an expression
154    pub fn simplify(&mut self, expression: Expression) -> Expression {
155        // Apply simplifications until no more changes (or max iterations)
156        let mut current = expression;
157        for _ in 0..self.max_iterations {
158            let simplified = self.simplify_once(current.clone());
159            if expressions_equal(&simplified, &current) {
160                return simplified;
161            }
162            current = simplified;
163        }
164        current
165    }
166
167    /// Apply one round of simplifications
168    fn simplify_once(&mut self, expression: Expression) -> Expression {
169        match expression {
170            // Binary logical operations
171            Expression::And(op) => self.simplify_and(*op),
172            Expression::Or(op) => self.simplify_or(*op),
173
174            // NOT operation - De Morgan's laws
175            Expression::Not(op) => self.simplify_not(*op),
176
177            // Arithmetic operations - constant folding
178            Expression::Add(op) => self.simplify_add(*op),
179            Expression::Sub(op) => self.simplify_sub(*op),
180            Expression::Mul(op) => self.simplify_mul(*op),
181            Expression::Div(op) => self.simplify_div(*op),
182
183            // Comparison operations
184            Expression::Eq(op) => self.simplify_comparison(*op, "="),
185            Expression::Neq(op) => self.simplify_comparison(*op, "!="),
186            Expression::Gt(op) => self.simplify_comparison(*op, ">"),
187            Expression::Gte(op) => self.simplify_comparison(*op, ">="),
188            Expression::Lt(op) => self.simplify_comparison(*op, "<"),
189            Expression::Lte(op) => self.simplify_comparison(*op, "<="),
190
191            // Negation
192            Expression::Neg(op) => self.simplify_neg(*op),
193
194            // CASE expression
195            Expression::Case(case) => self.simplify_case(*case),
196
197            // String concatenation
198            Expression::Concat(op) => self.simplify_concat(*op),
199            Expression::ConcatWs(concat_ws) => self.simplify_concat_ws(*concat_ws),
200
201            // Parentheses - remove if unnecessary
202            Expression::Paren(paren) => self.simplify_paren(*paren),
203
204            // Date truncation
205            Expression::DateTrunc(dt) => self.simplify_datetrunc(*dt),
206            Expression::TimestampTrunc(dt) => self.simplify_datetrunc(*dt),
207
208            // Recursively simplify children for other expressions
209            other => self.simplify_children(other),
210        }
211    }
212
213    /// Simplify AND operation
214    fn simplify_and(&mut self, op: BinaryOp) -> Expression {
215        let left = self.simplify_once(op.left);
216        let right = self.simplify_once(op.right);
217
218        // FALSE AND x -> FALSE
219        // x AND FALSE -> FALSE
220        if is_boolean_false(&left) || is_boolean_false(&right) {
221            return bool_false();
222        }
223
224        // 0 AND x -> FALSE (in boolean context)
225        // x AND 0 -> FALSE
226        if is_zero(&left) || is_zero(&right) {
227            return bool_false();
228        }
229
230        // NULL AND NULL -> NULL
231        // NULL AND TRUE -> NULL
232        // TRUE AND NULL -> NULL
233        if (is_null(&left) && is_null(&right))
234            || (is_null(&left) && is_boolean_true(&right))
235            || (is_boolean_true(&left) && is_null(&right))
236        {
237            return null();
238        }
239
240        // TRUE AND x -> x (only when left is actually boolean TRUE)
241        if is_boolean_true(&left) {
242            return right;
243        }
244
245        // x AND TRUE -> x (only when right is actually boolean TRUE)
246        if is_boolean_true(&right) {
247            return left;
248        }
249
250        // A AND NOT A -> FALSE (complement elimination)
251        if is_complement(&left, &right) || is_complement(&right, &left) {
252            return bool_false();
253        }
254
255        // A AND A -> A (idempotent)
256        if expressions_equal(&left, &right) {
257            return left;
258        }
259
260        // Apply absorption rules
261        // A AND (A OR B) -> A
262        // A AND (NOT A OR B) -> A AND B
263        absorb_and_eliminate_and(left, right)
264    }
265
266    /// Simplify OR operation
267    fn simplify_or(&mut self, op: BinaryOp) -> Expression {
268        let left = self.simplify_once(op.left);
269        let right = self.simplify_once(op.right);
270
271        // TRUE OR x -> TRUE (only when left is actually boolean TRUE)
272        if is_boolean_true(&left) {
273            return bool_true();
274        }
275
276        // x OR TRUE -> TRUE (only when right is actually boolean TRUE)
277        if is_boolean_true(&right) {
278            return bool_true();
279        }
280
281        // NULL OR NULL -> NULL
282        // NULL OR FALSE -> NULL
283        // FALSE OR NULL -> NULL
284        if (is_null(&left) && is_null(&right))
285            || (is_null(&left) && is_boolean_false(&right))
286            || (is_boolean_false(&left) && is_null(&right))
287        {
288            return null();
289        }
290
291        // FALSE OR x -> x (only when left is actually boolean FALSE)
292        if is_boolean_false(&left) {
293            return right;
294        }
295
296        // x OR FALSE -> x (only when right is actually boolean FALSE)
297        if is_boolean_false(&right) {
298            return left;
299        }
300
301        // A OR A -> A (idempotent)
302        if expressions_equal(&left, &right) {
303            return left;
304        }
305
306        // Apply absorption rules
307        // A OR (A AND B) -> A
308        // A OR (NOT A AND B) -> A OR B
309        absorb_and_eliminate_or(left, right)
310    }
311
312    /// Simplify NOT operation (De Morgan's laws)
313    fn simplify_not(&mut self, op: UnaryOp) -> Expression {
314        // Check for De Morgan's laws BEFORE simplifying inner expression
315        // This prevents constant folding from eliminating the comparison operator
316        match &op.this {
317            // NOT (a = b) -> a != b
318            Expression::Eq(inner_op) => {
319                let left = self.simplify_once(inner_op.left.clone());
320                let right = self.simplify_once(inner_op.right.clone());
321                return Expression::Neq(Box::new(BinaryOp::new(left, right)));
322            }
323            // NOT (a != b) -> a = b
324            Expression::Neq(inner_op) => {
325                let left = self.simplify_once(inner_op.left.clone());
326                let right = self.simplify_once(inner_op.right.clone());
327                return Expression::Eq(Box::new(BinaryOp::new(left, right)));
328            }
329            // NOT (a > b) -> a <= b
330            Expression::Gt(inner_op) => {
331                let left = self.simplify_once(inner_op.left.clone());
332                let right = self.simplify_once(inner_op.right.clone());
333                return Expression::Lte(Box::new(BinaryOp::new(left, right)));
334            }
335            // NOT (a >= b) -> a < b
336            Expression::Gte(inner_op) => {
337                let left = self.simplify_once(inner_op.left.clone());
338                let right = self.simplify_once(inner_op.right.clone());
339                return Expression::Lt(Box::new(BinaryOp::new(left, right)));
340            }
341            // NOT (a < b) -> a >= b
342            Expression::Lt(inner_op) => {
343                let left = self.simplify_once(inner_op.left.clone());
344                let right = self.simplify_once(inner_op.right.clone());
345                return Expression::Gte(Box::new(BinaryOp::new(left, right)));
346            }
347            // NOT (a <= b) -> a > b
348            Expression::Lte(inner_op) => {
349                let left = self.simplify_once(inner_op.left.clone());
350                let right = self.simplify_once(inner_op.right.clone());
351                return Expression::Gt(Box::new(BinaryOp::new(left, right)));
352            }
353            _ => {}
354        }
355
356        // Now simplify the inner expression for other patterns
357        let inner = self.simplify_once(op.this);
358
359        // NOT NULL -> NULL (with TRUE for SQL semantics)
360        if is_null(&inner) {
361            return null();
362        }
363
364        // NOT TRUE -> FALSE (only for boolean TRUE literal)
365        if is_boolean_true(&inner) {
366            return bool_false();
367        }
368
369        // NOT FALSE -> TRUE (only for boolean FALSE literal)
370        if is_boolean_false(&inner) {
371            return bool_true();
372        }
373
374        // NOT NOT x -> x (double negation elimination)
375        if let Expression::Not(inner_not) = &inner {
376            return inner_not.this.clone();
377        }
378
379        Expression::Not(Box::new(UnaryOp {
380            this: inner,
381            inferred_type: None,
382        }))
383    }
384
385    /// Simplify addition (constant folding)
386    fn simplify_add(&mut self, op: BinaryOp) -> Expression {
387        let left = self.simplify_once(op.left);
388        let right = self.simplify_once(op.right);
389
390        // Try constant folding for numbers
391        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
392            return Expression::Literal(Box::new(Literal::Number((a + b).to_string())));
393        }
394
395        // x + 0 -> x
396        if is_zero(&right) {
397            return left;
398        }
399
400        // 0 + x -> x
401        if is_zero(&left) {
402            return right;
403        }
404
405        Expression::Add(Box::new(BinaryOp::new(left, right)))
406    }
407
408    /// Simplify subtraction (constant folding)
409    fn simplify_sub(&mut self, op: BinaryOp) -> Expression {
410        let left = self.simplify_once(op.left);
411        let right = self.simplify_once(op.right);
412
413        // Try constant folding for numbers
414        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
415            return Expression::Literal(Box::new(Literal::Number((a - b).to_string())));
416        }
417
418        // x - 0 -> x
419        if is_zero(&right) {
420            return left;
421        }
422
423        // x - x -> 0 (only for literals/constants)
424        if expressions_equal(&left, &right) {
425            if let Expression::Literal(lit) = &left {
426                if let Literal::Number(_) = lit.as_ref() {
427                    return Expression::Literal(Box::new(Literal::Number("0".to_string())));
428                }
429            }
430        }
431
432        Expression::Sub(Box::new(BinaryOp::new(left, right)))
433    }
434
435    /// Simplify multiplication (constant folding)
436    fn simplify_mul(&mut self, op: BinaryOp) -> Expression {
437        let left = self.simplify_once(op.left);
438        let right = self.simplify_once(op.right);
439
440        // Try constant folding for numbers
441        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
442            return Expression::Literal(Box::new(Literal::Number((a * b).to_string())));
443        }
444
445        // x * 0 -> 0
446        if is_zero(&right) {
447            return Expression::Literal(Box::new(Literal::Number("0".to_string())));
448        }
449
450        // 0 * x -> 0
451        if is_zero(&left) {
452            return Expression::Literal(Box::new(Literal::Number("0".to_string())));
453        }
454
455        // x * 1 -> x
456        if is_one(&right) {
457            return left;
458        }
459
460        // 1 * x -> x
461        if is_one(&left) {
462            return right;
463        }
464
465        Expression::Mul(Box::new(BinaryOp::new(left, right)))
466    }
467
468    /// Simplify division (constant folding)
469    fn simplify_div(&mut self, op: BinaryOp) -> Expression {
470        let left = self.simplify_once(op.left);
471        let right = self.simplify_once(op.right);
472
473        // Try constant folding for numbers (but not integer division)
474        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
475            // Only fold if both are floats to avoid integer division issues
476            if b != 0.0 && (a.fract() != 0.0 || b.fract() != 0.0) {
477                return Expression::Literal(Box::new(Literal::Number((a / b).to_string())));
478            }
479        }
480
481        // 0 / x -> 0 (when x != 0)
482        if is_zero(&left) && !is_zero(&right) {
483            return Expression::Literal(Box::new(Literal::Number("0".to_string())));
484        }
485
486        // x / 1 -> x
487        if is_one(&right) {
488            return left;
489        }
490
491        Expression::Div(Box::new(BinaryOp::new(left, right)))
492    }
493
494    /// Simplify negation
495    fn simplify_neg(&mut self, op: UnaryOp) -> Expression {
496        let inner = self.simplify_once(op.this);
497
498        // -(-x) -> x (double negation)
499        if let Expression::Neg(inner_neg) = inner {
500            return inner_neg.this;
501        }
502
503        // -(number) -> -number
504        if let Some(n) = get_number(&inner) {
505            return Expression::Literal(Box::new(Literal::Number((-n).to_string())));
506        }
507
508        Expression::Neg(Box::new(UnaryOp {
509            this: inner,
510            inferred_type: None,
511        }))
512    }
513
514    /// Simplify comparison operations (constant folding)
515    fn simplify_comparison(&mut self, op: BinaryOp, operator: &str) -> Expression {
516        let left = self.simplify_once(op.left);
517        let right = self.simplify_once(op.right);
518
519        // Try constant folding for numbers
520        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
521            if let Some(result) = eval_boolean_nums(operator, a, b) {
522                return result;
523            }
524        }
525
526        // Try constant folding for strings
527        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
528            if let Some(result) = eval_boolean_strings(operator, &a, &b) {
529                return result;
530            }
531        }
532
533        // For equality, try to solve simple equations (x + 1 = 3 -> x = 2)
534        if operator == "=" {
535            if let Some(simplified) = self.simplify_equality(left.clone(), right.clone()) {
536                return simplified;
537            }
538        }
539
540        // Reconstruct the comparison
541        let new_op = BinaryOp::new(left, right);
542
543        match operator {
544            "=" => Expression::Eq(Box::new(new_op)),
545            "!=" | "<>" => Expression::Neq(Box::new(new_op)),
546            ">" => Expression::Gt(Box::new(new_op)),
547            ">=" => Expression::Gte(Box::new(new_op)),
548            "<" => Expression::Lt(Box::new(new_op)),
549            "<=" => Expression::Lte(Box::new(new_op)),
550            _ => Expression::Eq(Box::new(new_op)),
551        }
552    }
553
554    /// Simplify CASE expression
555    fn simplify_case(&mut self, case: Case) -> Expression {
556        let mut new_whens = Vec::new();
557
558        for (cond, then_expr) in case.whens {
559            let simplified_cond = self.simplify_once(cond);
560
561            // If condition is always true, return the THEN expression
562            if always_true(&simplified_cond) {
563                return self.simplify_once(then_expr);
564            }
565
566            // If condition is always false, skip this WHEN clause
567            if always_false(&simplified_cond) {
568                continue;
569            }
570
571            new_whens.push((simplified_cond, self.simplify_once(then_expr)));
572        }
573
574        // If no WHEN clauses remain, return the ELSE expression (or NULL)
575        if new_whens.is_empty() {
576            return case
577                .else_
578                .map(|e| self.simplify_once(e))
579                .unwrap_or_else(null);
580        }
581
582        Expression::Case(Box::new(Case {
583            operand: case.operand.map(|e| self.simplify_once(e)),
584            whens: new_whens,
585            else_: case.else_.map(|e| self.simplify_once(e)),
586            comments: Vec::new(),
587            inferred_type: None,
588        }))
589    }
590
591    /// Simplify string concatenation (Concat is || operator)
592    ///
593    /// Folds adjacent string literals:
594    /// - 'a' || 'b' -> 'ab'
595    /// - 'a' || 'b' || 'c' -> 'abc'
596    /// - '' || x -> x
597    /// - x || '' -> x
598    fn simplify_concat(&mut self, op: BinaryOp) -> Expression {
599        let left = self.simplify_once(op.left);
600        let right = self.simplify_once(op.right);
601
602        // Fold two string literals: 'a' || 'b' -> 'ab'
603        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
604            return Expression::Literal(Box::new(Literal::String(format!("{}{}", a, b))));
605        }
606
607        // '' || x -> x
608        if let Some(s) = get_string(&left) {
609            if s.is_empty() {
610                return right;
611            }
612        }
613
614        // x || '' -> x
615        if let Some(s) = get_string(&right) {
616            if s.is_empty() {
617                return left;
618            }
619        }
620
621        // NULL || x -> NULL, x || NULL -> NULL (SQL string concat semantics)
622        if is_null(&left) || is_null(&right) {
623            return null();
624        }
625
626        Expression::Concat(Box::new(BinaryOp::new(left, right)))
627    }
628
629    /// Simplify CONCAT_WS function
630    ///
631    /// CONCAT_WS(sep, a, b, c) -> concatenates with separator, skipping NULLs
632    /// - CONCAT_WS(',', 'a', 'b') -> 'a,b' (when all are literals)
633    /// - CONCAT_WS(',', 'a', NULL, 'b') -> 'a,b' (NULLs are skipped)
634    /// - CONCAT_WS(NULL, ...) -> NULL
635    fn simplify_concat_ws(&mut self, concat_ws: ConcatWs) -> Expression {
636        let separator = self.simplify_once(concat_ws.separator);
637
638        // If separator is NULL, result is NULL
639        if is_null(&separator) {
640            return null();
641        }
642
643        let expressions: Vec<Expression> = concat_ws
644            .expressions
645            .into_iter()
646            .map(|e| self.simplify_once(e))
647            .filter(|e| !is_null(e)) // Skip NULL values
648            .collect();
649
650        // If no expressions remain, return empty string
651        if expressions.is_empty() {
652            return Expression::Literal(Box::new(Literal::String(String::new())));
653        }
654
655        // Try to fold if all are string literals
656        if let Some(sep) = get_string(&separator) {
657            let all_strings: Option<Vec<String>> =
658                expressions.iter().map(|e| get_string(e)).collect();
659
660            if let Some(strings) = all_strings {
661                return Expression::Literal(Box::new(Literal::String(strings.join(&sep))));
662            }
663        }
664
665        // Return simplified CONCAT_WS
666        Expression::ConcatWs(Box::new(ConcatWs {
667            separator,
668            expressions,
669        }))
670    }
671
672    /// Simplify parentheses
673    ///
674    /// Remove unnecessary parentheses:
675    /// - (x) -> x when x is a literal, column, or already parenthesized
676    /// - ((x)) -> (x) -> x (recursive simplification)
677    fn simplify_paren(&mut self, paren: Paren) -> Expression {
678        let inner = self.simplify_once(paren.this);
679
680        // If inner is a literal, column, boolean, null, or already parenthesized,
681        // we can remove the parentheses
682        match &inner {
683            Expression::Literal(_)
684            | Expression::Boolean(_)
685            | Expression::Null(_)
686            | Expression::Column(_)
687            | Expression::Paren(_) => inner,
688            // For other expressions, keep the parentheses
689            _ => Expression::Paren(Box::new(Paren {
690                this: inner,
691                trailing_comments: paren.trailing_comments,
692            })),
693        }
694    }
695
696    /// Simplify DATE_TRUNC and TIMESTAMP_TRUNC
697    ///
698    /// Currently just simplifies children and passes through.
699    /// Future: could fold DATE_TRUNC('day', '2024-01-15') -> '2024-01-15'
700    fn simplify_datetrunc(&mut self, dt: DateTruncFunc) -> Expression {
701        let inner = self.simplify_once(dt.this);
702
703        // For now, just return with simplified inner expression
704        // A more advanced implementation would fold constant date/timestamps
705        Expression::DateTrunc(Box::new(DateTruncFunc {
706            this: inner,
707            unit: dt.unit,
708        }))
709    }
710
711    /// Simplify equality with arithmetic (solve simple equations)
712    ///
713    /// - x + 1 = 3 -> x = 2
714    /// - x - 1 = 3 -> x = 4
715    /// - x * 2 = 6 -> x = 3 (only when divisible)
716    /// - 1 + x = 3 -> x = 2 (commutative)
717    fn simplify_equality(&mut self, left: Expression, right: Expression) -> Option<Expression> {
718        // Only works when right side is a constant
719        let right_val = get_number(&right)?;
720
721        // Check if left side is arithmetic with one constant
722        match left {
723            Expression::Add(ref op) => {
724                // x + c = r -> x = r - c
725                if let Some(c) = get_number(&op.right) {
726                    let new_right =
727                        Expression::Literal(Box::new(Literal::Number((right_val - c).to_string())));
728                    return Some(Expression::Eq(Box::new(BinaryOp::new(
729                        op.left.clone(),
730                        new_right,
731                    ))));
732                }
733                // c + x = r -> x = r - c
734                if let Some(c) = get_number(&op.left) {
735                    let new_right =
736                        Expression::Literal(Box::new(Literal::Number((right_val - c).to_string())));
737                    return Some(Expression::Eq(Box::new(BinaryOp::new(
738                        op.right.clone(),
739                        new_right,
740                    ))));
741                }
742            }
743            Expression::Sub(ref op) => {
744                // x - c = r -> x = r + c
745                if let Some(c) = get_number(&op.right) {
746                    let new_right =
747                        Expression::Literal(Box::new(Literal::Number((right_val + c).to_string())));
748                    return Some(Expression::Eq(Box::new(BinaryOp::new(
749                        op.left.clone(),
750                        new_right,
751                    ))));
752                }
753                // c - x = r -> x = c - r
754                if let Some(c) = get_number(&op.left) {
755                    let new_right =
756                        Expression::Literal(Box::new(Literal::Number((c - right_val).to_string())));
757                    return Some(Expression::Eq(Box::new(BinaryOp::new(
758                        op.right.clone(),
759                        new_right,
760                    ))));
761                }
762            }
763            Expression::Mul(ref op) => {
764                // x * c = r -> x = r / c (only for non-zero c and when divisible)
765                if let Some(c) = get_number(&op.right) {
766                    if c != 0.0 && right_val % c == 0.0 {
767                        let new_right = Expression::Literal(Box::new(Literal::Number(
768                            (right_val / c).to_string(),
769                        )));
770                        return Some(Expression::Eq(Box::new(BinaryOp::new(
771                            op.left.clone(),
772                            new_right,
773                        ))));
774                    }
775                }
776                // c * x = r -> x = r / c
777                if let Some(c) = get_number(&op.left) {
778                    if c != 0.0 && right_val % c == 0.0 {
779                        let new_right = Expression::Literal(Box::new(Literal::Number(
780                            (right_val / c).to_string(),
781                        )));
782                        return Some(Expression::Eq(Box::new(BinaryOp::new(
783                            op.right.clone(),
784                            new_right,
785                        ))));
786                    }
787                }
788            }
789            _ => {}
790        }
791
792        None
793    }
794
795    /// Recursively simplify children of an expression
796    fn simplify_children(&mut self, expr: Expression) -> Expression {
797        // For expressions we don't have specific simplification rules for,
798        // we still want to simplify their children
799        match expr {
800            Expression::Alias(mut alias) => {
801                alias.this = self.simplify_once(alias.this);
802                Expression::Alias(alias)
803            }
804            Expression::Between(mut between) => {
805                between.this = self.simplify_once(between.this);
806                between.low = self.simplify_once(between.low);
807                between.high = self.simplify_once(between.high);
808                Expression::Between(between)
809            }
810            Expression::In(mut in_expr) => {
811                in_expr.this = self.simplify_once(in_expr.this);
812                in_expr.expressions = in_expr
813                    .expressions
814                    .into_iter()
815                    .map(|e| self.simplify_once(e))
816                    .collect();
817                Expression::In(in_expr)
818            }
819            Expression::Function(mut func) => {
820                func.args = func
821                    .args
822                    .into_iter()
823                    .map(|e| self.simplify_once(e))
824                    .collect();
825                Expression::Function(func)
826            }
827            // For other expressions, return as-is for now
828            other => other,
829        }
830    }
831}
832
833/// Check if expression equals 1
834fn is_one(expr: &Expression) -> bool {
835    match expr {
836        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
837            let Literal::Number(n) = lit.as_ref() else {
838                unreachable!()
839            };
840            if let Ok(num) = n.parse::<f64>() {
841                num == 1.0
842            } else {
843                false
844            }
845        }
846        _ => false,
847    }
848}
849
850/// Get numeric value from expression if it's a number literal
851fn get_number(expr: &Expression) -> Option<f64> {
852    match expr {
853        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
854            let Literal::Number(n) = lit.as_ref() else {
855                unreachable!()
856            };
857            n.parse().ok()
858        }
859        _ => None,
860    }
861}
862
863/// Get string value from expression if it's a string literal
864fn get_string(expr: &Expression) -> Option<String> {
865    match expr {
866        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::String(_)) => {
867            let Literal::String(s) = lit.as_ref() else {
868                unreachable!()
869            };
870            Some(s.clone())
871        }
872        _ => None,
873    }
874}
875
876/// Check if two expressions are structurally equal
877/// This is a simplified comparison - a full implementation would need deep comparison
878fn expressions_equal(a: &Expression, b: &Expression) -> bool {
879    // For now, use Debug representation for comparison
880    // A proper implementation would do structural comparison
881    format!("{:?}", a) == format!("{:?}", b)
882}
883
884/// Flatten nested AND expressions into a list of operands
885/// e.g., (A AND (B AND C)) -> [A, B, C]
886fn flatten_and(expr: &Expression) -> Vec<Expression> {
887    match expr {
888        Expression::And(op) => {
889            let mut result = flatten_and(&op.left);
890            result.extend(flatten_and(&op.right));
891            result
892        }
893        other => vec![other.clone()],
894    }
895}
896
897/// Flatten nested OR expressions into a list of operands
898/// e.g., (A OR (B OR C)) -> [A, B, C]
899fn flatten_or(expr: &Expression) -> Vec<Expression> {
900    match expr {
901        Expression::Or(op) => {
902            let mut result = flatten_or(&op.left);
903            result.extend(flatten_or(&op.right));
904            result
905        }
906        other => vec![other.clone()],
907    }
908}
909
910/// Rebuild an AND expression from a list of operands
911fn rebuild_and(operands: Vec<Expression>) -> Expression {
912    if operands.is_empty() {
913        return bool_true(); // Empty AND is TRUE
914    }
915    let mut result = operands.into_iter();
916    let first = result.next().unwrap();
917    result.fold(first, |acc, op| {
918        Expression::And(Box::new(BinaryOp::new(acc, op)))
919    })
920}
921
922/// Rebuild an OR expression from a list of operands
923fn rebuild_or(operands: Vec<Expression>) -> Expression {
924    if operands.is_empty() {
925        return bool_false(); // Empty OR is FALSE
926    }
927    let mut result = operands.into_iter();
928    let first = result.next().unwrap();
929    result.fold(first, |acc, op| {
930        Expression::Or(Box::new(BinaryOp::new(acc, op)))
931    })
932}
933
934/// Get the inner expression of a NOT, if it is one
935fn get_not_inner(expr: &Expression) -> Option<&Expression> {
936    match expr {
937        Expression::Not(op) => Some(&op.this),
938        _ => None,
939    }
940}
941
942/// Apply Boolean absorption and elimination rules to an AND expression
943///
944/// Absorption:
945///   A AND (A OR B) -> A
946///   A AND (NOT A OR B) -> A AND B
947///
948/// Elimination:
949///   (A OR B) AND (A OR NOT B) -> A
950pub fn absorb_and_eliminate_and(left: Expression, right: Expression) -> Expression {
951    // Flatten both sides
952    let left_ops = flatten_and(&left);
953    let right_ops = flatten_and(&right);
954    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
955
956    // Build a set of string representations for quick lookup
957    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
958
959    let mut result_ops: Vec<Expression> = Vec::new();
960    let mut absorbed = std::collections::HashSet::new();
961
962    for (i, op) in all_ops.iter().enumerate() {
963        let op_str = gen(op);
964
965        // Skip if already absorbed
966        if absorbed.contains(&op_str) {
967            continue;
968        }
969
970        // Check if this is an OR expression (potential absorption target)
971        if let Expression::Or(_) = op {
972            let or_operands = flatten_or(op);
973
974            // Absorption: A AND (A OR B) -> A
975            // Check if any OR operand is already in our AND operands
976            let absorbed_by_existing = or_operands.iter().any(|or_op| {
977                let or_op_str = gen(or_op);
978                // Check if this OR operand exists in other AND operands (not this OR itself)
979                all_ops
980                    .iter()
981                    .enumerate()
982                    .any(|(j, other)| i != j && gen(other) == or_op_str)
983            });
984
985            if absorbed_by_existing {
986                // This OR is absorbed, skip it
987                absorbed.insert(op_str);
988                continue;
989            }
990
991            // Absorption with complement: A AND (NOT A OR B) -> A AND B
992            // Check if any OR operand's complement is in our AND operands
993            let mut remaining_or_ops: Vec<Expression> = Vec::new();
994            let mut had_complement_absorption = false;
995
996            for or_op in or_operands {
997                let complement_str = if let Some(inner) = get_not_inner(&or_op) {
998                    // or_op is NOT X, complement is X
999                    gen(inner)
1000                } else {
1001                    // or_op is X, complement is NOT X
1002                    format!("NOT {}", gen(&or_op))
1003                };
1004
1005                // Check if complement exists in our AND operands
1006                let has_complement = all_ops
1007                    .iter()
1008                    .enumerate()
1009                    .any(|(j, other)| i != j && gen(other) == complement_str)
1010                    || op_strings.contains(&complement_str);
1011
1012                if has_complement {
1013                    // This OR operand's complement exists, so this term becomes TRUE in AND context
1014                    // NOT A OR B, where A exists, becomes TRUE OR B (when A is true) or B (when A is false)
1015                    // Actually: A AND (NOT A OR B) -> A AND B, so we drop NOT A from the OR
1016                    had_complement_absorption = true;
1017                    // Drop this operand from OR
1018                } else {
1019                    remaining_or_ops.push(or_op);
1020                }
1021            }
1022
1023            if had_complement_absorption {
1024                if remaining_or_ops.is_empty() {
1025                    // All OR operands were absorbed, the OR becomes TRUE
1026                    // A AND TRUE -> A, so we just skip adding this
1027                    absorbed.insert(op_str);
1028                    continue;
1029                } else if remaining_or_ops.len() == 1 {
1030                    // Single remaining operand
1031                    result_ops.push(remaining_or_ops.into_iter().next().unwrap());
1032                    absorbed.insert(op_str);
1033                    continue;
1034                } else {
1035                    // Rebuild the OR with remaining operands
1036                    result_ops.push(rebuild_or(remaining_or_ops));
1037                    absorbed.insert(op_str);
1038                    continue;
1039                }
1040            }
1041        }
1042
1043        result_ops.push(op.clone());
1044    }
1045
1046    // Deduplicate
1047    let mut seen = std::collections::HashSet::new();
1048    result_ops.retain(|op| seen.insert(gen(op)));
1049
1050    if result_ops.is_empty() {
1051        bool_true()
1052    } else {
1053        rebuild_and(result_ops)
1054    }
1055}
1056
1057/// Apply Boolean absorption and elimination rules to an OR expression
1058///
1059/// Absorption:
1060///   A OR (A AND B) -> A
1061///   A OR (NOT A AND B) -> A OR B
1062///
1063/// Elimination:
1064///   (A AND B) OR (A AND NOT B) -> A
1065pub fn absorb_and_eliminate_or(left: Expression, right: Expression) -> Expression {
1066    // Flatten both sides
1067    let left_ops = flatten_or(&left);
1068    let right_ops = flatten_or(&right);
1069    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
1070
1071    // Build a set of string representations for quick lookup
1072    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
1073
1074    let mut result_ops: Vec<Expression> = Vec::new();
1075    let mut absorbed = std::collections::HashSet::new();
1076
1077    for (i, op) in all_ops.iter().enumerate() {
1078        let op_str = gen(op);
1079
1080        // Skip if already absorbed
1081        if absorbed.contains(&op_str) {
1082            continue;
1083        }
1084
1085        // Check if this is an AND expression (potential absorption target)
1086        if let Expression::And(_) = op {
1087            let and_operands = flatten_and(op);
1088
1089            // Absorption: A OR (A AND B) -> A
1090            // Check if any AND operand is already in our OR operands
1091            let absorbed_by_existing = and_operands.iter().any(|and_op| {
1092                let and_op_str = gen(and_op);
1093                // Check if this AND operand exists in other OR operands (not this AND itself)
1094                all_ops
1095                    .iter()
1096                    .enumerate()
1097                    .any(|(j, other)| i != j && gen(other) == and_op_str)
1098            });
1099
1100            if absorbed_by_existing {
1101                // This AND is absorbed, skip it
1102                absorbed.insert(op_str);
1103                continue;
1104            }
1105
1106            // Absorption with complement: A OR (NOT A AND B) -> A OR B
1107            // Check if any AND operand's complement is in our OR operands
1108            let mut remaining_and_ops: Vec<Expression> = Vec::new();
1109            let mut had_complement_absorption = false;
1110
1111            for and_op in and_operands {
1112                let complement_str = if let Some(inner) = get_not_inner(&and_op) {
1113                    // and_op is NOT X, complement is X
1114                    gen(inner)
1115                } else {
1116                    // and_op is X, complement is NOT X
1117                    format!("NOT {}", gen(&and_op))
1118                };
1119
1120                // Check if complement exists in our OR operands
1121                let has_complement = all_ops
1122                    .iter()
1123                    .enumerate()
1124                    .any(|(j, other)| i != j && gen(other) == complement_str)
1125                    || op_strings.contains(&complement_str);
1126
1127                if has_complement {
1128                    // This AND operand's complement exists, so this term becomes FALSE in OR context
1129                    // A OR (NOT A AND B) -> A OR B, so we drop NOT A from the AND
1130                    had_complement_absorption = true;
1131                    // Drop this operand from AND
1132                } else {
1133                    remaining_and_ops.push(and_op);
1134                }
1135            }
1136
1137            if had_complement_absorption {
1138                if remaining_and_ops.is_empty() {
1139                    // All AND operands were absorbed, the AND becomes FALSE
1140                    // A OR FALSE -> A, so we just skip adding this
1141                    absorbed.insert(op_str);
1142                    continue;
1143                } else if remaining_and_ops.len() == 1 {
1144                    // Single remaining operand
1145                    result_ops.push(remaining_and_ops.into_iter().next().unwrap());
1146                    absorbed.insert(op_str);
1147                    continue;
1148                } else {
1149                    // Rebuild the AND with remaining operands
1150                    result_ops.push(rebuild_and(remaining_and_ops));
1151                    absorbed.insert(op_str);
1152                    continue;
1153                }
1154            }
1155        }
1156
1157        result_ops.push(op.clone());
1158    }
1159
1160    // Deduplicate
1161    let mut seen = std::collections::HashSet::new();
1162    result_ops.retain(|op| seen.insert(gen(op)));
1163
1164    if result_ops.is_empty() {
1165        bool_false()
1166    } else {
1167        rebuild_or(result_ops)
1168    }
1169}
1170
1171/// Generate a simple string representation of an expression for sorting/deduping
1172pub fn gen(expr: &Expression) -> String {
1173    match expr {
1174        Expression::Literal(lit) => match lit.as_ref() {
1175            Literal::String(s) => format!("'{}'", s),
1176            Literal::Number(n) => n.clone(),
1177            _ => format!("{:?}", lit),
1178        },
1179        Expression::Boolean(b) => if b.value { "TRUE" } else { "FALSE" }.to_string(),
1180        Expression::Null(_) => "NULL".to_string(),
1181        Expression::Column(col) => {
1182            if let Some(ref table) = col.table {
1183                format!("{}.{}", table.name, col.name.name)
1184            } else {
1185                col.name.name.clone()
1186            }
1187        }
1188        Expression::And(op) => format!("({} AND {})", gen(&op.left), gen(&op.right)),
1189        Expression::Or(op) => format!("({} OR {})", gen(&op.left), gen(&op.right)),
1190        Expression::Not(op) => format!("NOT {}", gen(&op.this)),
1191        Expression::Eq(op) => format!("{} = {}", gen(&op.left), gen(&op.right)),
1192        Expression::Neq(op) => format!("{} <> {}", gen(&op.left), gen(&op.right)),
1193        Expression::Gt(op) => format!("{} > {}", gen(&op.left), gen(&op.right)),
1194        Expression::Gte(op) => format!("{} >= {}", gen(&op.left), gen(&op.right)),
1195        Expression::Lt(op) => format!("{} < {}", gen(&op.left), gen(&op.right)),
1196        Expression::Lte(op) => format!("{} <= {}", gen(&op.left), gen(&op.right)),
1197        Expression::Add(op) => format!("{} + {}", gen(&op.left), gen(&op.right)),
1198        Expression::Sub(op) => format!("{} - {}", gen(&op.left), gen(&op.right)),
1199        Expression::Mul(op) => format!("{} * {}", gen(&op.left), gen(&op.right)),
1200        Expression::Div(op) => format!("{} / {}", gen(&op.left), gen(&op.right)),
1201        Expression::Function(f) => {
1202            let args: Vec<String> = f.args.iter().map(|a| gen(a)).collect();
1203            format!("{}({})", f.name.to_uppercase(), args.join(", "))
1204        }
1205        _ => format!("{:?}", expr),
1206    }
1207}
1208
1209#[cfg(test)]
1210mod tests {
1211    use super::*;
1212
1213    fn make_int(val: i64) -> Expression {
1214        Expression::Literal(Box::new(Literal::Number(val.to_string())))
1215    }
1216
1217    fn make_string(val: &str) -> Expression {
1218        Expression::Literal(Box::new(Literal::String(val.to_string())))
1219    }
1220
1221    fn make_bool(val: bool) -> Expression {
1222        Expression::Boolean(BooleanLiteral { value: val })
1223    }
1224
1225    fn make_column(name: &str) -> Expression {
1226        use crate::expressions::{Column, Identifier};
1227        Expression::boxed_column(Column {
1228            name: Identifier::new(name),
1229            table: None,
1230            join_mark: false,
1231            trailing_comments: vec![],
1232            span: None,
1233            inferred_type: None,
1234        })
1235    }
1236
1237    #[test]
1238    fn test_always_true_false() {
1239        assert!(always_true(&make_bool(true)));
1240        assert!(!always_true(&make_bool(false)));
1241        assert!(always_true(&make_int(1)));
1242        assert!(!always_true(&make_int(0)));
1243
1244        assert!(always_false(&make_bool(false)));
1245        assert!(!always_false(&make_bool(true)));
1246        assert!(always_false(&null()));
1247        assert!(always_false(&make_int(0)));
1248    }
1249
1250    #[test]
1251    fn test_simplify_and_with_true() {
1252        let mut simplifier = Simplifier::new(None);
1253
1254        // TRUE AND TRUE -> TRUE
1255        let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(true))));
1256        let result = simplifier.simplify(expr);
1257        assert!(always_true(&result));
1258
1259        // TRUE AND FALSE -> FALSE
1260        let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(false))));
1261        let result = simplifier.simplify(expr);
1262        assert!(always_false(&result));
1263
1264        // TRUE AND x -> x
1265        let x = make_int(42);
1266        let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), x.clone())));
1267        let result = simplifier.simplify(expr);
1268        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1269    }
1270
1271    #[test]
1272    fn test_simplify_or_with_false() {
1273        let mut simplifier = Simplifier::new(None);
1274
1275        // FALSE OR FALSE -> FALSE
1276        let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), make_bool(false))));
1277        let result = simplifier.simplify(expr);
1278        assert!(always_false(&result));
1279
1280        // FALSE OR TRUE -> TRUE
1281        let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), make_bool(true))));
1282        let result = simplifier.simplify(expr);
1283        assert!(always_true(&result));
1284
1285        // FALSE OR x -> x
1286        let x = make_int(42);
1287        let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), x.clone())));
1288        let result = simplifier.simplify(expr);
1289        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1290    }
1291
1292    #[test]
1293    fn test_simplify_not() {
1294        let mut simplifier = Simplifier::new(None);
1295
1296        // NOT TRUE -> FALSE
1297        let expr = Expression::Not(Box::new(UnaryOp::new(make_bool(true))));
1298        let result = simplifier.simplify(expr);
1299        assert!(is_false(&result));
1300
1301        // NOT FALSE -> TRUE
1302        let expr = Expression::Not(Box::new(UnaryOp::new(make_bool(false))));
1303        let result = simplifier.simplify(expr);
1304        assert!(always_true(&result));
1305
1306        // NOT NOT x -> x
1307        let x = make_int(42);
1308        let inner_not = Expression::Not(Box::new(UnaryOp::new(x.clone())));
1309        let expr = Expression::Not(Box::new(UnaryOp::new(inner_not)));
1310        let result = simplifier.simplify(expr);
1311        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1312    }
1313
1314    #[test]
1315    fn test_simplify_demorgan_comparison() {
1316        let mut simplifier = Simplifier::new(None);
1317
1318        // NOT (a = b) -> a != b (using columns to avoid constant folding)
1319        let a = make_column("a");
1320        let b = make_column("b");
1321        let eq = Expression::Eq(Box::new(BinaryOp::new(a.clone(), b.clone())));
1322        let expr = Expression::Not(Box::new(UnaryOp::new(eq)));
1323        let result = simplifier.simplify(expr);
1324        assert!(matches!(result, Expression::Neq(_)));
1325
1326        // NOT (a > b) -> a <= b
1327        let gt = Expression::Gt(Box::new(BinaryOp::new(a, b)));
1328        let expr = Expression::Not(Box::new(UnaryOp::new(gt)));
1329        let result = simplifier.simplify(expr);
1330        assert!(matches!(result, Expression::Lte(_)));
1331    }
1332
1333    #[test]
1334    fn test_constant_folding_add() {
1335        let mut simplifier = Simplifier::new(None);
1336
1337        // 1 + 2 -> 3
1338        let expr = Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))));
1339        let result = simplifier.simplify(expr);
1340        assert_eq!(get_number(&result), Some(3.0));
1341
1342        // x + 0 -> x
1343        let x = make_int(42);
1344        let expr = Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(0))));
1345        let result = simplifier.simplify(expr);
1346        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1347    }
1348
1349    #[test]
1350    fn test_constant_folding_mul() {
1351        let mut simplifier = Simplifier::new(None);
1352
1353        // 3 * 4 -> 12
1354        let expr = Expression::Mul(Box::new(BinaryOp::new(make_int(3), make_int(4))));
1355        let result = simplifier.simplify(expr);
1356        assert_eq!(get_number(&result), Some(12.0));
1357
1358        // x * 0 -> 0
1359        let x = make_int(42);
1360        let expr = Expression::Mul(Box::new(BinaryOp::new(x, make_int(0))));
1361        let result = simplifier.simplify(expr);
1362        assert_eq!(get_number(&result), Some(0.0));
1363
1364        // x * 1 -> x
1365        let x = make_int(42);
1366        let expr = Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(1))));
1367        let result = simplifier.simplify(expr);
1368        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1369    }
1370
1371    #[test]
1372    fn test_constant_folding_comparison() {
1373        let mut simplifier = Simplifier::new(None);
1374
1375        // 1 = 1 -> TRUE
1376        let expr = Expression::Eq(Box::new(BinaryOp::new(make_int(1), make_int(1))));
1377        let result = simplifier.simplify(expr);
1378        assert!(always_true(&result));
1379
1380        // 1 = 2 -> FALSE
1381        let expr = Expression::Eq(Box::new(BinaryOp::new(make_int(1), make_int(2))));
1382        let result = simplifier.simplify(expr);
1383        assert!(is_false(&result));
1384
1385        // 3 > 2 -> TRUE
1386        let expr = Expression::Gt(Box::new(BinaryOp::new(make_int(3), make_int(2))));
1387        let result = simplifier.simplify(expr);
1388        assert!(always_true(&result));
1389
1390        // 'a' = 'a' -> TRUE
1391        let expr = Expression::Eq(Box::new(BinaryOp::new(
1392            make_string("abc"),
1393            make_string("abc"),
1394        )));
1395        let result = simplifier.simplify(expr);
1396        assert!(always_true(&result));
1397    }
1398
1399    #[test]
1400    fn test_simplify_negation() {
1401        let mut simplifier = Simplifier::new(None);
1402
1403        // -(-5) -> 5
1404        let inner = Expression::Neg(Box::new(UnaryOp::new(make_int(5))));
1405        let expr = Expression::Neg(Box::new(UnaryOp::new(inner)));
1406        let result = simplifier.simplify(expr);
1407        assert_eq!(get_number(&result), Some(5.0));
1408
1409        // -(3) -> -3
1410        let expr = Expression::Neg(Box::new(UnaryOp::new(make_int(3))));
1411        let result = simplifier.simplify(expr);
1412        assert_eq!(get_number(&result), Some(-3.0));
1413    }
1414
1415    #[test]
1416    fn test_gen_simple() {
1417        assert_eq!(gen(&make_int(42)), "42");
1418        assert_eq!(gen(&make_string("hello")), "'hello'");
1419        assert_eq!(gen(&make_bool(true)), "TRUE");
1420        assert_eq!(gen(&make_bool(false)), "FALSE");
1421        assert_eq!(gen(&null()), "NULL");
1422    }
1423
1424    #[test]
1425    fn test_gen_operations() {
1426        let add = Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))));
1427        assert_eq!(gen(&add), "1 + 2");
1428
1429        let and = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(false))));
1430        assert_eq!(gen(&and), "(TRUE AND FALSE)");
1431    }
1432
1433    #[test]
1434    fn test_complement_elimination() {
1435        let mut simplifier = Simplifier::new(None);
1436
1437        // x AND NOT x -> FALSE
1438        let x = make_int(42);
1439        let not_x = Expression::Not(Box::new(UnaryOp::new(x.clone())));
1440        let expr = Expression::And(Box::new(BinaryOp::new(x, not_x)));
1441        let result = simplifier.simplify(expr);
1442        assert!(is_false(&result));
1443    }
1444
1445    #[test]
1446    fn test_idempotent() {
1447        let mut simplifier = Simplifier::new(None);
1448
1449        // x AND x -> x
1450        let x = make_int(42);
1451        let expr = Expression::And(Box::new(BinaryOp::new(x.clone(), x.clone())));
1452        let result = simplifier.simplify(expr);
1453        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1454
1455        // x OR x -> x
1456        let x = make_int(42);
1457        let expr = Expression::Or(Box::new(BinaryOp::new(x.clone(), x.clone())));
1458        let result = simplifier.simplify(expr);
1459        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1460    }
1461
1462    #[test]
1463    fn test_absorption_and() {
1464        let mut simplifier = Simplifier::new(None);
1465
1466        // A AND (A OR B) -> A
1467        let a = make_column("a");
1468        let b = make_column("b");
1469        let a_or_b = Expression::Or(Box::new(BinaryOp::new(a.clone(), b.clone())));
1470        let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), a_or_b)));
1471        let result = simplifier.simplify(expr);
1472        // Result should be just A
1473        assert_eq!(gen(&result), gen(&a));
1474    }
1475
1476    #[test]
1477    fn test_absorption_or() {
1478        let mut simplifier = Simplifier::new(None);
1479
1480        // A OR (A AND B) -> A
1481        let a = make_column("a");
1482        let b = make_column("b");
1483        let a_and_b = Expression::And(Box::new(BinaryOp::new(a.clone(), b.clone())));
1484        let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), a_and_b)));
1485        let result = simplifier.simplify(expr);
1486        // Result should be just A
1487        assert_eq!(gen(&result), gen(&a));
1488    }
1489
1490    #[test]
1491    fn test_absorption_with_complement_and() {
1492        let mut simplifier = Simplifier::new(None);
1493
1494        // A AND (NOT A OR B) -> A AND B
1495        let a = make_column("a");
1496        let b = make_column("b");
1497        let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
1498        let not_a_or_b = Expression::Or(Box::new(BinaryOp::new(not_a, b.clone())));
1499        let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), not_a_or_b)));
1500        let result = simplifier.simplify(expr);
1501        // Result should be A AND B
1502        let expected = Expression::And(Box::new(BinaryOp::new(a, b)));
1503        assert_eq!(gen(&result), gen(&expected));
1504    }
1505
1506    #[test]
1507    fn test_absorption_with_complement_or() {
1508        let mut simplifier = Simplifier::new(None);
1509
1510        // A OR (NOT A AND B) -> A OR B
1511        let a = make_column("a");
1512        let b = make_column("b");
1513        let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
1514        let not_a_and_b = Expression::And(Box::new(BinaryOp::new(not_a, b.clone())));
1515        let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), not_a_and_b)));
1516        let result = simplifier.simplify(expr);
1517        // Result should be A OR B
1518        let expected = Expression::Or(Box::new(BinaryOp::new(a, b)));
1519        assert_eq!(gen(&result), gen(&expected));
1520    }
1521
1522    #[test]
1523    fn test_flatten_and() {
1524        // (A AND (B AND C)) should flatten to [A, B, C]
1525        let a = make_column("a");
1526        let b = make_column("b");
1527        let c = make_column("c");
1528        let b_and_c = Expression::And(Box::new(BinaryOp::new(b.clone(), c.clone())));
1529        let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), b_and_c)));
1530        let flattened = flatten_and(&expr);
1531        assert_eq!(flattened.len(), 3);
1532        assert_eq!(gen(&flattened[0]), "a");
1533        assert_eq!(gen(&flattened[1]), "b");
1534        assert_eq!(gen(&flattened[2]), "c");
1535    }
1536
1537    #[test]
1538    fn test_flatten_or() {
1539        // (A OR (B OR C)) should flatten to [A, B, C]
1540        let a = make_column("a");
1541        let b = make_column("b");
1542        let c = make_column("c");
1543        let b_or_c = Expression::Or(Box::new(BinaryOp::new(b.clone(), c.clone())));
1544        let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), b_or_c)));
1545        let flattened = flatten_or(&expr);
1546        assert_eq!(flattened.len(), 3);
1547        assert_eq!(gen(&flattened[0]), "a");
1548        assert_eq!(gen(&flattened[1]), "b");
1549        assert_eq!(gen(&flattened[2]), "c");
1550    }
1551
1552    #[test]
1553    fn test_simplify_concat() {
1554        let mut simplifier = Simplifier::new(None);
1555
1556        // 'a' || 'b' -> 'ab'
1557        let expr = Expression::Concat(Box::new(BinaryOp::new(
1558            make_string("hello"),
1559            make_string("world"),
1560        )));
1561        let result = simplifier.simplify(expr);
1562        assert_eq!(get_string(&result), Some("helloworld".to_string()));
1563
1564        // '' || x -> x
1565        let x = make_string("test");
1566        let expr = Expression::Concat(Box::new(BinaryOp::new(make_string(""), x.clone())));
1567        let result = simplifier.simplify(expr);
1568        assert_eq!(get_string(&result), Some("test".to_string()));
1569
1570        // x || '' -> x
1571        let expr = Expression::Concat(Box::new(BinaryOp::new(x, make_string(""))));
1572        let result = simplifier.simplify(expr);
1573        assert_eq!(get_string(&result), Some("test".to_string()));
1574
1575        // NULL || x -> NULL
1576        let expr = Expression::Concat(Box::new(BinaryOp::new(null(), make_string("test"))));
1577        let result = simplifier.simplify(expr);
1578        assert!(is_null(&result));
1579    }
1580
1581    #[test]
1582    fn test_simplify_concat_ws() {
1583        let mut simplifier = Simplifier::new(None);
1584
1585        // CONCAT_WS(',', 'a', 'b', 'c') -> 'a,b,c'
1586        let expr = Expression::ConcatWs(Box::new(ConcatWs {
1587            separator: make_string(","),
1588            expressions: vec![make_string("a"), make_string("b"), make_string("c")],
1589        }));
1590        let result = simplifier.simplify(expr);
1591        assert_eq!(get_string(&result), Some("a,b,c".to_string()));
1592
1593        // CONCAT_WS with NULL separator -> NULL
1594        let expr = Expression::ConcatWs(Box::new(ConcatWs {
1595            separator: null(),
1596            expressions: vec![make_string("a"), make_string("b")],
1597        }));
1598        let result = simplifier.simplify(expr);
1599        assert!(is_null(&result));
1600
1601        // CONCAT_WS with empty expressions -> ''
1602        let expr = Expression::ConcatWs(Box::new(ConcatWs {
1603            separator: make_string(","),
1604            expressions: vec![],
1605        }));
1606        let result = simplifier.simplify(expr);
1607        assert_eq!(get_string(&result), Some("".to_string()));
1608
1609        // CONCAT_WS skips NULLs
1610        let expr = Expression::ConcatWs(Box::new(ConcatWs {
1611            separator: make_string("-"),
1612            expressions: vec![make_string("a"), null(), make_string("b")],
1613        }));
1614        let result = simplifier.simplify(expr);
1615        assert_eq!(get_string(&result), Some("a-b".to_string()));
1616    }
1617
1618    #[test]
1619    fn test_simplify_paren() {
1620        let mut simplifier = Simplifier::new(None);
1621
1622        // (42) -> 42
1623        let expr = Expression::Paren(Box::new(Paren {
1624            this: make_int(42),
1625            trailing_comments: vec![],
1626        }));
1627        let result = simplifier.simplify(expr);
1628        assert_eq!(get_number(&result), Some(42.0));
1629
1630        // (TRUE) -> TRUE
1631        let expr = Expression::Paren(Box::new(Paren {
1632            this: make_bool(true),
1633            trailing_comments: vec![],
1634        }));
1635        let result = simplifier.simplify(expr);
1636        assert!(is_boolean_true(&result));
1637
1638        // (NULL) -> NULL
1639        let expr = Expression::Paren(Box::new(Paren {
1640            this: null(),
1641            trailing_comments: vec![],
1642        }));
1643        let result = simplifier.simplify(expr);
1644        assert!(is_null(&result));
1645
1646        // ((x)) -> x
1647        let inner_paren = Expression::Paren(Box::new(Paren {
1648            this: make_int(10),
1649            trailing_comments: vec![],
1650        }));
1651        let expr = Expression::Paren(Box::new(Paren {
1652            this: inner_paren,
1653            trailing_comments: vec![],
1654        }));
1655        let result = simplifier.simplify(expr);
1656        assert_eq!(get_number(&result), Some(10.0));
1657    }
1658
1659    #[test]
1660    fn test_simplify_equality_solve() {
1661        let mut simplifier = Simplifier::new(None);
1662
1663        // x + 1 = 3 -> x = 2
1664        let x = make_column("x");
1665        let x_plus_1 = Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(1))));
1666        let expr = Expression::Eq(Box::new(BinaryOp::new(x_plus_1, make_int(3))));
1667        let result = simplifier.simplify(expr);
1668        // Result should be x = 2
1669        if let Expression::Eq(op) = &result {
1670            assert_eq!(gen(&op.left), "x");
1671            assert_eq!(get_number(&op.right), Some(2.0));
1672        } else {
1673            panic!("Expected Eq expression");
1674        }
1675
1676        // x - 1 = 3 -> x = 4
1677        let x_minus_1 = Expression::Sub(Box::new(BinaryOp::new(x.clone(), make_int(1))));
1678        let expr = Expression::Eq(Box::new(BinaryOp::new(x_minus_1, make_int(3))));
1679        let result = simplifier.simplify(expr);
1680        if let Expression::Eq(op) = &result {
1681            assert_eq!(gen(&op.left), "x");
1682            assert_eq!(get_number(&op.right), Some(4.0));
1683        } else {
1684            panic!("Expected Eq expression");
1685        }
1686
1687        // x * 2 = 6 -> x = 3
1688        let x_times_2 = Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(2))));
1689        let expr = Expression::Eq(Box::new(BinaryOp::new(x_times_2, make_int(6))));
1690        let result = simplifier.simplify(expr);
1691        if let Expression::Eq(op) = &result {
1692            assert_eq!(gen(&op.left), "x");
1693            assert_eq!(get_number(&op.right), Some(3.0));
1694        } else {
1695            panic!("Expected Eq expression");
1696        }
1697
1698        // 1 + x = 3 -> x = 2 (commutative)
1699        let one_plus_x = Expression::Add(Box::new(BinaryOp::new(make_int(1), x.clone())));
1700        let expr = Expression::Eq(Box::new(BinaryOp::new(one_plus_x, make_int(3))));
1701        let result = simplifier.simplify(expr);
1702        if let Expression::Eq(op) = &result {
1703            assert_eq!(gen(&op.left), "x");
1704            assert_eq!(get_number(&op.right), Some(2.0));
1705        } else {
1706            panic!("Expected Eq expression");
1707        }
1708    }
1709
1710    #[test]
1711    fn test_simplify_datetrunc() {
1712        use crate::expressions::DateTimeField;
1713        let mut simplifier = Simplifier::new(None);
1714
1715        // DATE_TRUNC('day', x) with a column just passes through with simplified children
1716        let x = make_column("x");
1717        let expr = Expression::DateTrunc(Box::new(DateTruncFunc {
1718            this: x.clone(),
1719            unit: DateTimeField::Day,
1720        }));
1721        let result = simplifier.simplify(expr);
1722        if let Expression::DateTrunc(dt) = &result {
1723            assert_eq!(gen(&dt.this), "x");
1724            assert_eq!(dt.unit, DateTimeField::Day);
1725        } else {
1726            panic!("Expected DateTrunc expression");
1727        }
1728    }
1729}