Skip to main content

polyglot_sql/optimizer/
simplify.rs

1//! Expression Simplification
2//!
3//! This module provides boolean and expression simplification for SQL AST nodes.
4//! It applies various algebraic transformations to simplify expressions:
5//! - De Morgan's laws (NOT (A AND B) -> NOT A OR NOT B)
6//! - Constant folding (1 + 2 -> 3)
7//! - Boolean absorption (A AND (A OR B) -> A)
8//! - Complement removal (A AND NOT A -> FALSE)
9//! - Connector flattening (A AND (B AND C) -> A AND B AND C)
10//!
11//! Based on SQLGlot's optimizer/simplify.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{
15    BinaryOp, BooleanLiteral, Case, ConcatWs, DateTruncFunc, Expression, Literal, Null, Paren,
16    UnaryOp,
17};
18
19/// Main entry point for expression simplification
20pub fn simplify(expression: Expression, dialect: Option<DialectType>) -> Expression {
21    let mut simplifier = Simplifier::new(dialect);
22    simplifier.simplify(expression)
23}
24
25/// Check if expression is always true
26pub fn always_true(expr: &Expression) -> bool {
27    match expr {
28        Expression::Boolean(b) => b.value,
29        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
30            let Literal::Number(n) = lit.as_ref() else { unreachable!() };
31            // Non-zero numbers are truthy
32            if let Ok(num) = n.parse::<f64>() {
33                num != 0.0
34            } else {
35                false
36            }
37        }
38        _ => false,
39    }
40}
41
42/// Check if expression is a boolean TRUE literal (not just truthy)
43pub fn is_boolean_true(expr: &Expression) -> bool {
44    matches!(expr, Expression::Boolean(b) if b.value)
45}
46
47/// Check if expression is a boolean FALSE literal (not just falsy)
48pub fn is_boolean_false(expr: &Expression) -> bool {
49    matches!(expr, Expression::Boolean(b) if !b.value)
50}
51
52/// Check if expression is always false
53pub fn always_false(expr: &Expression) -> bool {
54    is_false(expr) || is_null(expr) || is_zero(expr)
55}
56
57/// Check if expression is boolean FALSE
58pub fn is_false(expr: &Expression) -> bool {
59    matches!(expr, Expression::Boolean(b) if !b.value)
60}
61
/// Check if expression is NULL
///
/// Only the `Expression::Null` variant matches; no other node is inspected.
pub fn is_null(expr: &Expression) -> bool {
    matches!(expr, Expression::Null(_))
}
66
67/// Check if expression is zero
68pub fn is_zero(expr: &Expression) -> bool {
69    match expr {
70        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
71            let Literal::Number(n) = lit.as_ref() else { unreachable!() };
72            if let Ok(num) = n.parse::<f64>() {
73                num == 0.0
74            } else {
75                false
76            }
77        }
78        _ => false,
79    }
80}
81
82/// Check if b is the complement of a (i.e., b = NOT a)
83pub fn is_complement(a: &Expression, b: &Expression) -> bool {
84    if let Expression::Not(not_op) = b {
85        &not_op.this == a
86    } else {
87        false
88    }
89}
90
91/// Create a TRUE boolean literal
92pub fn bool_true() -> Expression {
93    Expression::Boolean(BooleanLiteral { value: true })
94}
95
96/// Create a FALSE boolean literal
97pub fn bool_false() -> Expression {
98    Expression::Boolean(BooleanLiteral { value: false })
99}
100
/// Create a NULL expression
///
/// Returns the `Expression::Null` variant wrapping the `Null` marker value.
pub fn null() -> Expression {
    Expression::Null(Null)
}
105
106/// Evaluate a boolean comparison between two numbers
107pub fn eval_boolean_nums(op: &str, a: f64, b: f64) -> Option<Expression> {
108    let result = match op {
109        "=" | "==" => a == b,
110        "!=" | "<>" => a != b,
111        ">" => a > b,
112        ">=" => a >= b,
113        "<" => a < b,
114        "<=" => a <= b,
115        _ => return None,
116    };
117    Some(if result { bool_true() } else { bool_false() })
118}
119
120/// Evaluate a boolean comparison between two strings
121pub fn eval_boolean_strings(op: &str, a: &str, b: &str) -> Option<Expression> {
122    let result = match op {
123        "=" | "==" => a == b,
124        "!=" | "<>" => a != b,
125        ">" => a > b,
126        ">=" => a >= b,
127        "<" => a < b,
128        "<=" => a <= b,
129        _ => return None,
130    };
131    Some(if result { bool_true() } else { bool_false() })
132}
133
/// Expression simplifier
///
/// Carries the configuration used while repeatedly rewriting an expression.
pub struct Simplifier {
    /// Dialect supplied at construction. The underscore prefix suggests it is
    /// currently unused — NOTE(review): confirm no code outside this view reads it.
    _dialect: Option<DialectType>,
    /// Upper bound on the number of rewrite rounds performed by `simplify`.
    max_iterations: usize,
}
139
140impl Simplifier {
141    /// Create a new simplifier
142    pub fn new(dialect: Option<DialectType>) -> Self {
143        Self {
144            _dialect: dialect,
145            max_iterations: 100,
146        }
147    }
148
149    /// Simplify an expression
150    pub fn simplify(&mut self, expression: Expression) -> Expression {
151        // Apply simplifications until no more changes (or max iterations)
152        let mut current = expression;
153        for _ in 0..self.max_iterations {
154            let simplified = self.simplify_once(current.clone());
155            if expressions_equal(&simplified, &current) {
156                return simplified;
157            }
158            current = simplified;
159        }
160        current
161    }
162
163    /// Apply one round of simplifications
164    fn simplify_once(&mut self, expression: Expression) -> Expression {
165        match expression {
166            // Binary logical operations
167            Expression::And(op) => self.simplify_and(*op),
168            Expression::Or(op) => self.simplify_or(*op),
169
170            // NOT operation - De Morgan's laws
171            Expression::Not(op) => self.simplify_not(*op),
172
173            // Arithmetic operations - constant folding
174            Expression::Add(op) => self.simplify_add(*op),
175            Expression::Sub(op) => self.simplify_sub(*op),
176            Expression::Mul(op) => self.simplify_mul(*op),
177            Expression::Div(op) => self.simplify_div(*op),
178
179            // Comparison operations
180            Expression::Eq(op) => self.simplify_comparison(*op, "="),
181            Expression::Neq(op) => self.simplify_comparison(*op, "!="),
182            Expression::Gt(op) => self.simplify_comparison(*op, ">"),
183            Expression::Gte(op) => self.simplify_comparison(*op, ">="),
184            Expression::Lt(op) => self.simplify_comparison(*op, "<"),
185            Expression::Lte(op) => self.simplify_comparison(*op, "<="),
186
187            // Negation
188            Expression::Neg(op) => self.simplify_neg(*op),
189
190            // CASE expression
191            Expression::Case(case) => self.simplify_case(*case),
192
193            // String concatenation
194            Expression::Concat(op) => self.simplify_concat(*op),
195            Expression::ConcatWs(concat_ws) => self.simplify_concat_ws(*concat_ws),
196
197            // Parentheses - remove if unnecessary
198            Expression::Paren(paren) => self.simplify_paren(*paren),
199
200            // Date truncation
201            Expression::DateTrunc(dt) => self.simplify_datetrunc(*dt),
202            Expression::TimestampTrunc(dt) => self.simplify_datetrunc(*dt),
203
204            // Recursively simplify children for other expressions
205            other => self.simplify_children(other),
206        }
207    }
208
209    /// Simplify AND operation
210    fn simplify_and(&mut self, op: BinaryOp) -> Expression {
211        let left = self.simplify_once(op.left);
212        let right = self.simplify_once(op.right);
213
214        // FALSE AND x -> FALSE
215        // x AND FALSE -> FALSE
216        if is_boolean_false(&left) || is_boolean_false(&right) {
217            return bool_false();
218        }
219
220        // 0 AND x -> FALSE (in boolean context)
221        // x AND 0 -> FALSE
222        if is_zero(&left) || is_zero(&right) {
223            return bool_false();
224        }
225
226        // NULL AND NULL -> NULL
227        // NULL AND TRUE -> NULL
228        // TRUE AND NULL -> NULL
229        if (is_null(&left) && is_null(&right))
230            || (is_null(&left) && is_boolean_true(&right))
231            || (is_boolean_true(&left) && is_null(&right))
232        {
233            return null();
234        }
235
236        // TRUE AND x -> x (only when left is actually boolean TRUE)
237        if is_boolean_true(&left) {
238            return right;
239        }
240
241        // x AND TRUE -> x (only when right is actually boolean TRUE)
242        if is_boolean_true(&right) {
243            return left;
244        }
245
246        // A AND NOT A -> FALSE (complement elimination)
247        if is_complement(&left, &right) || is_complement(&right, &left) {
248            return bool_false();
249        }
250
251        // A AND A -> A (idempotent)
252        if expressions_equal(&left, &right) {
253            return left;
254        }
255
256        // Apply absorption rules
257        // A AND (A OR B) -> A
258        // A AND (NOT A OR B) -> A AND B
259        absorb_and_eliminate_and(left, right)
260    }
261
262    /// Simplify OR operation
263    fn simplify_or(&mut self, op: BinaryOp) -> Expression {
264        let left = self.simplify_once(op.left);
265        let right = self.simplify_once(op.right);
266
267        // TRUE OR x -> TRUE (only when left is actually boolean TRUE)
268        if is_boolean_true(&left) {
269            return bool_true();
270        }
271
272        // x OR TRUE -> TRUE (only when right is actually boolean TRUE)
273        if is_boolean_true(&right) {
274            return bool_true();
275        }
276
277        // NULL OR NULL -> NULL
278        // NULL OR FALSE -> NULL
279        // FALSE OR NULL -> NULL
280        if (is_null(&left) && is_null(&right))
281            || (is_null(&left) && is_boolean_false(&right))
282            || (is_boolean_false(&left) && is_null(&right))
283        {
284            return null();
285        }
286
287        // FALSE OR x -> x (only when left is actually boolean FALSE)
288        if is_boolean_false(&left) {
289            return right;
290        }
291
292        // x OR FALSE -> x (only when right is actually boolean FALSE)
293        if is_boolean_false(&right) {
294            return left;
295        }
296
297        // A OR A -> A (idempotent)
298        if expressions_equal(&left, &right) {
299            return left;
300        }
301
302        // Apply absorption rules
303        // A OR (A AND B) -> A
304        // A OR (NOT A AND B) -> A OR B
305        absorb_and_eliminate_or(left, right)
306    }
307
308    /// Simplify NOT operation (De Morgan's laws)
309    fn simplify_not(&mut self, op: UnaryOp) -> Expression {
310        // Check for De Morgan's laws BEFORE simplifying inner expression
311        // This prevents constant folding from eliminating the comparison operator
312        match &op.this {
313            // NOT (a = b) -> a != b
314            Expression::Eq(inner_op) => {
315                let left = self.simplify_once(inner_op.left.clone());
316                let right = self.simplify_once(inner_op.right.clone());
317                return Expression::Neq(Box::new(BinaryOp::new(left, right)));
318            }
319            // NOT (a != b) -> a = b
320            Expression::Neq(inner_op) => {
321                let left = self.simplify_once(inner_op.left.clone());
322                let right = self.simplify_once(inner_op.right.clone());
323                return Expression::Eq(Box::new(BinaryOp::new(left, right)));
324            }
325            // NOT (a > b) -> a <= b
326            Expression::Gt(inner_op) => {
327                let left = self.simplify_once(inner_op.left.clone());
328                let right = self.simplify_once(inner_op.right.clone());
329                return Expression::Lte(Box::new(BinaryOp::new(left, right)));
330            }
331            // NOT (a >= b) -> a < b
332            Expression::Gte(inner_op) => {
333                let left = self.simplify_once(inner_op.left.clone());
334                let right = self.simplify_once(inner_op.right.clone());
335                return Expression::Lt(Box::new(BinaryOp::new(left, right)));
336            }
337            // NOT (a < b) -> a >= b
338            Expression::Lt(inner_op) => {
339                let left = self.simplify_once(inner_op.left.clone());
340                let right = self.simplify_once(inner_op.right.clone());
341                return Expression::Gte(Box::new(BinaryOp::new(left, right)));
342            }
343            // NOT (a <= b) -> a > b
344            Expression::Lte(inner_op) => {
345                let left = self.simplify_once(inner_op.left.clone());
346                let right = self.simplify_once(inner_op.right.clone());
347                return Expression::Gt(Box::new(BinaryOp::new(left, right)));
348            }
349            _ => {}
350        }
351
352        // Now simplify the inner expression for other patterns
353        let inner = self.simplify_once(op.this);
354
355        // NOT NULL -> NULL (with TRUE for SQL semantics)
356        if is_null(&inner) {
357            return null();
358        }
359
360        // NOT TRUE -> FALSE (only for boolean TRUE literal)
361        if is_boolean_true(&inner) {
362            return bool_false();
363        }
364
365        // NOT FALSE -> TRUE (only for boolean FALSE literal)
366        if is_boolean_false(&inner) {
367            return bool_true();
368        }
369
370        // NOT NOT x -> x (double negation elimination)
371        if let Expression::Not(inner_not) = &inner {
372            return inner_not.this.clone();
373        }
374
375        Expression::Not(Box::new(UnaryOp {
376            this: inner,
377            inferred_type: None,
378        }))
379    }
380
381    /// Simplify addition (constant folding)
382    fn simplify_add(&mut self, op: BinaryOp) -> Expression {
383        let left = self.simplify_once(op.left);
384        let right = self.simplify_once(op.right);
385
386        // Try constant folding for numbers
387        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
388            return Expression::Literal(Box::new(Literal::Number((a + b).to_string())));
389        }
390
391        // x + 0 -> x
392        if is_zero(&right) {
393            return left;
394        }
395
396        // 0 + x -> x
397        if is_zero(&left) {
398            return right;
399        }
400
401        Expression::Add(Box::new(BinaryOp::new(left, right)))
402    }
403
404    /// Simplify subtraction (constant folding)
405    fn simplify_sub(&mut self, op: BinaryOp) -> Expression {
406        let left = self.simplify_once(op.left);
407        let right = self.simplify_once(op.right);
408
409        // Try constant folding for numbers
410        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
411            return Expression::Literal(Box::new(Literal::Number((a - b).to_string())));
412        }
413
414        // x - 0 -> x
415        if is_zero(&right) {
416            return left;
417        }
418
419        // x - x -> 0 (only for literals/constants)
420        if expressions_equal(&left, &right) {
421            if let Expression::Literal(lit) = &left {
422                if let Literal::Number(_) = lit.as_ref() {
423                return Expression::Literal(Box::new(Literal::Number("0".to_string())));
424            }
425            }
426        }
427
428        Expression::Sub(Box::new(BinaryOp::new(left, right)))
429    }
430
431    /// Simplify multiplication (constant folding)
432    fn simplify_mul(&mut self, op: BinaryOp) -> Expression {
433        let left = self.simplify_once(op.left);
434        let right = self.simplify_once(op.right);
435
436        // Try constant folding for numbers
437        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
438            return Expression::Literal(Box::new(Literal::Number((a * b).to_string())));
439        }
440
441        // x * 0 -> 0
442        if is_zero(&right) {
443            return Expression::Literal(Box::new(Literal::Number("0".to_string())));
444        }
445
446        // 0 * x -> 0
447        if is_zero(&left) {
448            return Expression::Literal(Box::new(Literal::Number("0".to_string())));
449        }
450
451        // x * 1 -> x
452        if is_one(&right) {
453            return left;
454        }
455
456        // 1 * x -> x
457        if is_one(&left) {
458            return right;
459        }
460
461        Expression::Mul(Box::new(BinaryOp::new(left, right)))
462    }
463
464    /// Simplify division (constant folding)
465    fn simplify_div(&mut self, op: BinaryOp) -> Expression {
466        let left = self.simplify_once(op.left);
467        let right = self.simplify_once(op.right);
468
469        // Try constant folding for numbers (but not integer division)
470        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
471            // Only fold if both are floats to avoid integer division issues
472            if b != 0.0 && (a.fract() != 0.0 || b.fract() != 0.0) {
473                return Expression::Literal(Box::new(Literal::Number((a / b).to_string())));
474            }
475        }
476
477        // 0 / x -> 0 (when x != 0)
478        if is_zero(&left) && !is_zero(&right) {
479            return Expression::Literal(Box::new(Literal::Number("0".to_string())));
480        }
481
482        // x / 1 -> x
483        if is_one(&right) {
484            return left;
485        }
486
487        Expression::Div(Box::new(BinaryOp::new(left, right)))
488    }
489
490    /// Simplify negation
491    fn simplify_neg(&mut self, op: UnaryOp) -> Expression {
492        let inner = self.simplify_once(op.this);
493
494        // -(-x) -> x (double negation)
495        if let Expression::Neg(inner_neg) = inner {
496            return inner_neg.this;
497        }
498
499        // -(number) -> -number
500        if let Some(n) = get_number(&inner) {
501            return Expression::Literal(Box::new(Literal::Number((-n).to_string())));
502        }
503
504        Expression::Neg(Box::new(UnaryOp {
505            this: inner,
506            inferred_type: None,
507        }))
508    }
509
510    /// Simplify comparison operations (constant folding)
511    fn simplify_comparison(&mut self, op: BinaryOp, operator: &str) -> Expression {
512        let left = self.simplify_once(op.left);
513        let right = self.simplify_once(op.right);
514
515        // Try constant folding for numbers
516        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
517            if let Some(result) = eval_boolean_nums(operator, a, b) {
518                return result;
519            }
520        }
521
522        // Try constant folding for strings
523        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
524            if let Some(result) = eval_boolean_strings(operator, &a, &b) {
525                return result;
526            }
527        }
528
529        // For equality, try to solve simple equations (x + 1 = 3 -> x = 2)
530        if operator == "=" {
531            if let Some(simplified) = self.simplify_equality(left.clone(), right.clone()) {
532                return simplified;
533            }
534        }
535
536        // Reconstruct the comparison
537        let new_op = BinaryOp::new(left, right);
538
539        match operator {
540            "=" => Expression::Eq(Box::new(new_op)),
541            "!=" | "<>" => Expression::Neq(Box::new(new_op)),
542            ">" => Expression::Gt(Box::new(new_op)),
543            ">=" => Expression::Gte(Box::new(new_op)),
544            "<" => Expression::Lt(Box::new(new_op)),
545            "<=" => Expression::Lte(Box::new(new_op)),
546            _ => Expression::Eq(Box::new(new_op)),
547        }
548    }
549
550    /// Simplify CASE expression
551    fn simplify_case(&mut self, case: Case) -> Expression {
552        let mut new_whens = Vec::new();
553
554        for (cond, then_expr) in case.whens {
555            let simplified_cond = self.simplify_once(cond);
556
557            // If condition is always true, return the THEN expression
558            if always_true(&simplified_cond) {
559                return self.simplify_once(then_expr);
560            }
561
562            // If condition is always false, skip this WHEN clause
563            if always_false(&simplified_cond) {
564                continue;
565            }
566
567            new_whens.push((simplified_cond, self.simplify_once(then_expr)));
568        }
569
570        // If no WHEN clauses remain, return the ELSE expression (or NULL)
571        if new_whens.is_empty() {
572            return case
573                .else_
574                .map(|e| self.simplify_once(e))
575                .unwrap_or_else(null);
576        }
577
578        Expression::Case(Box::new(Case {
579            operand: case.operand.map(|e| self.simplify_once(e)),
580            whens: new_whens,
581            else_: case.else_.map(|e| self.simplify_once(e)),
582            comments: Vec::new(),
583            inferred_type: None,
584        }))
585    }
586
587    /// Simplify string concatenation (Concat is || operator)
588    ///
589    /// Folds adjacent string literals:
590    /// - 'a' || 'b' -> 'ab'
591    /// - 'a' || 'b' || 'c' -> 'abc'
592    /// - '' || x -> x
593    /// - x || '' -> x
594    fn simplify_concat(&mut self, op: BinaryOp) -> Expression {
595        let left = self.simplify_once(op.left);
596        let right = self.simplify_once(op.right);
597
598        // Fold two string literals: 'a' || 'b' -> 'ab'
599        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
600            return Expression::Literal(Box::new(Literal::String(format!("{}{}", a, b))));
601        }
602
603        // '' || x -> x
604        if let Some(s) = get_string(&left) {
605            if s.is_empty() {
606                return right;
607            }
608        }
609
610        // x || '' -> x
611        if let Some(s) = get_string(&right) {
612            if s.is_empty() {
613                return left;
614            }
615        }
616
617        // NULL || x -> NULL, x || NULL -> NULL (SQL string concat semantics)
618        if is_null(&left) || is_null(&right) {
619            return null();
620        }
621
622        Expression::Concat(Box::new(BinaryOp::new(left, right)))
623    }
624
625    /// Simplify CONCAT_WS function
626    ///
627    /// CONCAT_WS(sep, a, b, c) -> concatenates with separator, skipping NULLs
628    /// - CONCAT_WS(',', 'a', 'b') -> 'a,b' (when all are literals)
629    /// - CONCAT_WS(',', 'a', NULL, 'b') -> 'a,b' (NULLs are skipped)
630    /// - CONCAT_WS(NULL, ...) -> NULL
631    fn simplify_concat_ws(&mut self, concat_ws: ConcatWs) -> Expression {
632        let separator = self.simplify_once(concat_ws.separator);
633
634        // If separator is NULL, result is NULL
635        if is_null(&separator) {
636            return null();
637        }
638
639        let expressions: Vec<Expression> = concat_ws
640            .expressions
641            .into_iter()
642            .map(|e| self.simplify_once(e))
643            .filter(|e| !is_null(e)) // Skip NULL values
644            .collect();
645
646        // If no expressions remain, return empty string
647        if expressions.is_empty() {
648            return Expression::Literal(Box::new(Literal::String(String::new())));
649        }
650
651        // Try to fold if all are string literals
652        if let Some(sep) = get_string(&separator) {
653            let all_strings: Option<Vec<String>> =
654                expressions.iter().map(|e| get_string(e)).collect();
655
656            if let Some(strings) = all_strings {
657                return Expression::Literal(Box::new(Literal::String(strings.join(&sep))));
658            }
659        }
660
661        // Return simplified CONCAT_WS
662        Expression::ConcatWs(Box::new(ConcatWs {
663            separator,
664            expressions,
665        }))
666    }
667
668    /// Simplify parentheses
669    ///
670    /// Remove unnecessary parentheses:
671    /// - (x) -> x when x is a literal, column, or already parenthesized
672    /// - ((x)) -> (x) -> x (recursive simplification)
673    fn simplify_paren(&mut self, paren: Paren) -> Expression {
674        let inner = self.simplify_once(paren.this);
675
676        // If inner is a literal, column, boolean, null, or already parenthesized,
677        // we can remove the parentheses
678        match &inner {
679            Expression::Literal(_)
680            | Expression::Boolean(_)
681            | Expression::Null(_)
682            | Expression::Column(_)
683            | Expression::Paren(_) => inner,
684            // For other expressions, keep the parentheses
685            _ => Expression::Paren(Box::new(Paren {
686                this: inner,
687                trailing_comments: paren.trailing_comments,
688            })),
689        }
690    }
691
692    /// Simplify DATE_TRUNC and TIMESTAMP_TRUNC
693    ///
694    /// Currently just simplifies children and passes through.
695    /// Future: could fold DATE_TRUNC('day', '2024-01-15') -> '2024-01-15'
696    fn simplify_datetrunc(&mut self, dt: DateTruncFunc) -> Expression {
697        let inner = self.simplify_once(dt.this);
698
699        // For now, just return with simplified inner expression
700        // A more advanced implementation would fold constant date/timestamps
701        Expression::DateTrunc(Box::new(DateTruncFunc {
702            this: inner,
703            unit: dt.unit,
704        }))
705    }
706
707    /// Simplify equality with arithmetic (solve simple equations)
708    ///
709    /// - x + 1 = 3 -> x = 2
710    /// - x - 1 = 3 -> x = 4
711    /// - x * 2 = 6 -> x = 3 (only when divisible)
712    /// - 1 + x = 3 -> x = 2 (commutative)
713    fn simplify_equality(&mut self, left: Expression, right: Expression) -> Option<Expression> {
714        // Only works when right side is a constant
715        let right_val = get_number(&right)?;
716
717        // Check if left side is arithmetic with one constant
718        match left {
719            Expression::Add(ref op) => {
720                // x + c = r -> x = r - c
721                if let Some(c) = get_number(&op.right) {
722                    let new_right =
723                        Expression::Literal(Box::new(Literal::Number((right_val - c).to_string())));
724                    return Some(Expression::Eq(Box::new(BinaryOp::new(
725                        op.left.clone(),
726                        new_right,
727                    ))));
728                }
729                // c + x = r -> x = r - c
730                if let Some(c) = get_number(&op.left) {
731                    let new_right =
732                        Expression::Literal(Box::new(Literal::Number((right_val - c).to_string())));
733                    return Some(Expression::Eq(Box::new(BinaryOp::new(
734                        op.right.clone(),
735                        new_right,
736                    ))));
737                }
738            }
739            Expression::Sub(ref op) => {
740                // x - c = r -> x = r + c
741                if let Some(c) = get_number(&op.right) {
742                    let new_right =
743                        Expression::Literal(Box::new(Literal::Number((right_val + c).to_string())));
744                    return Some(Expression::Eq(Box::new(BinaryOp::new(
745                        op.left.clone(),
746                        new_right,
747                    ))));
748                }
749                // c - x = r -> x = c - r
750                if let Some(c) = get_number(&op.left) {
751                    let new_right =
752                        Expression::Literal(Box::new(Literal::Number((c - right_val).to_string())));
753                    return Some(Expression::Eq(Box::new(BinaryOp::new(
754                        op.right.clone(),
755                        new_right,
756                    ))));
757                }
758            }
759            Expression::Mul(ref op) => {
760                // x * c = r -> x = r / c (only for non-zero c and when divisible)
761                if let Some(c) = get_number(&op.right) {
762                    if c != 0.0 && right_val % c == 0.0 {
763                        let new_right =
764                            Expression::Literal(Box::new(Literal::Number((right_val / c).to_string())));
765                        return Some(Expression::Eq(Box::new(BinaryOp::new(
766                            op.left.clone(),
767                            new_right,
768                        ))));
769                    }
770                }
771                // c * x = r -> x = r / c
772                if let Some(c) = get_number(&op.left) {
773                    if c != 0.0 && right_val % c == 0.0 {
774                        let new_right =
775                            Expression::Literal(Box::new(Literal::Number((right_val / c).to_string())));
776                        return Some(Expression::Eq(Box::new(BinaryOp::new(
777                            op.right.clone(),
778                            new_right,
779                        ))));
780                    }
781                }
782            }
783            _ => {}
784        }
785
786        None
787    }
788
789    /// Recursively simplify children of an expression
790    fn simplify_children(&mut self, expr: Expression) -> Expression {
791        // For expressions we don't have specific simplification rules for,
792        // we still want to simplify their children
793        match expr {
794            Expression::Alias(mut alias) => {
795                alias.this = self.simplify_once(alias.this);
796                Expression::Alias(alias)
797            }
798            Expression::Between(mut between) => {
799                between.this = self.simplify_once(between.this);
800                between.low = self.simplify_once(between.low);
801                between.high = self.simplify_once(between.high);
802                Expression::Between(between)
803            }
804            Expression::In(mut in_expr) => {
805                in_expr.this = self.simplify_once(in_expr.this);
806                in_expr.expressions = in_expr
807                    .expressions
808                    .into_iter()
809                    .map(|e| self.simplify_once(e))
810                    .collect();
811                Expression::In(in_expr)
812            }
813            Expression::Function(mut func) => {
814                func.args = func
815                    .args
816                    .into_iter()
817                    .map(|e| self.simplify_once(e))
818                    .collect();
819                Expression::Function(func)
820            }
821            // For other expressions, return as-is for now
822            other => other,
823        }
824    }
825}
826
827/// Check if expression equals 1
828fn is_one(expr: &Expression) -> bool {
829    match expr {
830        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => {
831            let Literal::Number(n) = lit.as_ref() else { unreachable!() };
832            if let Ok(num) = n.parse::<f64>() {
833                num == 1.0
834            } else {
835                false
836            }
837        }
838        _ => false,
839    }
840}
841
842/// Get numeric value from expression if it's a number literal
843fn get_number(expr: &Expression) -> Option<f64> {
844    match expr {
845        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Number(_)) => { let Literal::Number(n) = lit.as_ref() else { unreachable!() }; n.parse().ok() },
846        _ => None,
847    }
848}
849
850/// Get string value from expression if it's a string literal
851fn get_string(expr: &Expression) -> Option<String> {
852    match expr {
853        Expression::Literal(lit) if matches!(lit.as_ref(), Literal::String(_)) => { let Literal::String(s) = lit.as_ref() else { unreachable!() }; Some(s.clone()) },
854        _ => None,
855    }
856}
857
858/// Check if two expressions are structurally equal
859/// This is a simplified comparison - a full implementation would need deep comparison
860fn expressions_equal(a: &Expression, b: &Expression) -> bool {
861    // For now, use Debug representation for comparison
862    // A proper implementation would do structural comparison
863    format!("{:?}", a) == format!("{:?}", b)
864}
865
866/// Flatten nested AND expressions into a list of operands
867/// e.g., (A AND (B AND C)) -> [A, B, C]
868fn flatten_and(expr: &Expression) -> Vec<Expression> {
869    match expr {
870        Expression::And(op) => {
871            let mut result = flatten_and(&op.left);
872            result.extend(flatten_and(&op.right));
873            result
874        }
875        other => vec![other.clone()],
876    }
877}
878
879/// Flatten nested OR expressions into a list of operands
880/// e.g., (A OR (B OR C)) -> [A, B, C]
881fn flatten_or(expr: &Expression) -> Vec<Expression> {
882    match expr {
883        Expression::Or(op) => {
884            let mut result = flatten_or(&op.left);
885            result.extend(flatten_or(&op.right));
886            result
887        }
888        other => vec![other.clone()],
889    }
890}
891
892/// Rebuild an AND expression from a list of operands
893fn rebuild_and(operands: Vec<Expression>) -> Expression {
894    if operands.is_empty() {
895        return bool_true(); // Empty AND is TRUE
896    }
897    let mut result = operands.into_iter();
898    let first = result.next().unwrap();
899    result.fold(first, |acc, op| {
900        Expression::And(Box::new(BinaryOp::new(acc, op)))
901    })
902}
903
904/// Rebuild an OR expression from a list of operands
905fn rebuild_or(operands: Vec<Expression>) -> Expression {
906    if operands.is_empty() {
907        return bool_false(); // Empty OR is FALSE
908    }
909    let mut result = operands.into_iter();
910    let first = result.next().unwrap();
911    result.fold(first, |acc, op| {
912        Expression::Or(Box::new(BinaryOp::new(acc, op)))
913    })
914}
915
916/// Get the inner expression of a NOT, if it is one
917fn get_not_inner(expr: &Expression) -> Option<&Expression> {
918    match expr {
919        Expression::Not(op) => Some(&op.this),
920        _ => None,
921    }
922}
923
924/// Apply Boolean absorption and elimination rules to an AND expression
925///
926/// Absorption:
927///   A AND (A OR B) -> A
928///   A AND (NOT A OR B) -> A AND B
929///
930/// Elimination:
931///   (A OR B) AND (A OR NOT B) -> A
932pub fn absorb_and_eliminate_and(left: Expression, right: Expression) -> Expression {
933    // Flatten both sides
934    let left_ops = flatten_and(&left);
935    let right_ops = flatten_and(&right);
936    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
937
938    // Build a set of string representations for quick lookup
939    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
940
941    let mut result_ops: Vec<Expression> = Vec::new();
942    let mut absorbed = std::collections::HashSet::new();
943
944    for (i, op) in all_ops.iter().enumerate() {
945        let op_str = gen(op);
946
947        // Skip if already absorbed
948        if absorbed.contains(&op_str) {
949            continue;
950        }
951
952        // Check if this is an OR expression (potential absorption target)
953        if let Expression::Or(_) = op {
954            let or_operands = flatten_or(op);
955
956            // Absorption: A AND (A OR B) -> A
957            // Check if any OR operand is already in our AND operands
958            let absorbed_by_existing = or_operands.iter().any(|or_op| {
959                let or_op_str = gen(or_op);
960                // Check if this OR operand exists in other AND operands (not this OR itself)
961                all_ops
962                    .iter()
963                    .enumerate()
964                    .any(|(j, other)| i != j && gen(other) == or_op_str)
965            });
966
967            if absorbed_by_existing {
968                // This OR is absorbed, skip it
969                absorbed.insert(op_str);
970                continue;
971            }
972
973            // Absorption with complement: A AND (NOT A OR B) -> A AND B
974            // Check if any OR operand's complement is in our AND operands
975            let mut remaining_or_ops: Vec<Expression> = Vec::new();
976            let mut had_complement_absorption = false;
977
978            for or_op in or_operands {
979                let complement_str = if let Some(inner) = get_not_inner(&or_op) {
980                    // or_op is NOT X, complement is X
981                    gen(inner)
982                } else {
983                    // or_op is X, complement is NOT X
984                    format!("NOT {}", gen(&or_op))
985                };
986
987                // Check if complement exists in our AND operands
988                let has_complement = all_ops
989                    .iter()
990                    .enumerate()
991                    .any(|(j, other)| i != j && gen(other) == complement_str)
992                    || op_strings.contains(&complement_str);
993
994                if has_complement {
995                    // This OR operand's complement exists, so this term becomes TRUE in AND context
996                    // NOT A OR B, where A exists, becomes TRUE OR B (when A is true) or B (when A is false)
997                    // Actually: A AND (NOT A OR B) -> A AND B, so we drop NOT A from the OR
998                    had_complement_absorption = true;
999                    // Drop this operand from OR
1000                } else {
1001                    remaining_or_ops.push(or_op);
1002                }
1003            }
1004
1005            if had_complement_absorption {
1006                if remaining_or_ops.is_empty() {
1007                    // All OR operands were absorbed, the OR becomes TRUE
1008                    // A AND TRUE -> A, so we just skip adding this
1009                    absorbed.insert(op_str);
1010                    continue;
1011                } else if remaining_or_ops.len() == 1 {
1012                    // Single remaining operand
1013                    result_ops.push(remaining_or_ops.into_iter().next().unwrap());
1014                    absorbed.insert(op_str);
1015                    continue;
1016                } else {
1017                    // Rebuild the OR with remaining operands
1018                    result_ops.push(rebuild_or(remaining_or_ops));
1019                    absorbed.insert(op_str);
1020                    continue;
1021                }
1022            }
1023        }
1024
1025        result_ops.push(op.clone());
1026    }
1027
1028    // Deduplicate
1029    let mut seen = std::collections::HashSet::new();
1030    result_ops.retain(|op| seen.insert(gen(op)));
1031
1032    if result_ops.is_empty() {
1033        bool_true()
1034    } else {
1035        rebuild_and(result_ops)
1036    }
1037}
1038
1039/// Apply Boolean absorption and elimination rules to an OR expression
1040///
1041/// Absorption:
1042///   A OR (A AND B) -> A
1043///   A OR (NOT A AND B) -> A OR B
1044///
1045/// Elimination:
1046///   (A AND B) OR (A AND NOT B) -> A
1047pub fn absorb_and_eliminate_or(left: Expression, right: Expression) -> Expression {
1048    // Flatten both sides
1049    let left_ops = flatten_or(&left);
1050    let right_ops = flatten_or(&right);
1051    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
1052
1053    // Build a set of string representations for quick lookup
1054    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
1055
1056    let mut result_ops: Vec<Expression> = Vec::new();
1057    let mut absorbed = std::collections::HashSet::new();
1058
1059    for (i, op) in all_ops.iter().enumerate() {
1060        let op_str = gen(op);
1061
1062        // Skip if already absorbed
1063        if absorbed.contains(&op_str) {
1064            continue;
1065        }
1066
1067        // Check if this is an AND expression (potential absorption target)
1068        if let Expression::And(_) = op {
1069            let and_operands = flatten_and(op);
1070
1071            // Absorption: A OR (A AND B) -> A
1072            // Check if any AND operand is already in our OR operands
1073            let absorbed_by_existing = and_operands.iter().any(|and_op| {
1074                let and_op_str = gen(and_op);
1075                // Check if this AND operand exists in other OR operands (not this AND itself)
1076                all_ops
1077                    .iter()
1078                    .enumerate()
1079                    .any(|(j, other)| i != j && gen(other) == and_op_str)
1080            });
1081
1082            if absorbed_by_existing {
1083                // This AND is absorbed, skip it
1084                absorbed.insert(op_str);
1085                continue;
1086            }
1087
1088            // Absorption with complement: A OR (NOT A AND B) -> A OR B
1089            // Check if any AND operand's complement is in our OR operands
1090            let mut remaining_and_ops: Vec<Expression> = Vec::new();
1091            let mut had_complement_absorption = false;
1092
1093            for and_op in and_operands {
1094                let complement_str = if let Some(inner) = get_not_inner(&and_op) {
1095                    // and_op is NOT X, complement is X
1096                    gen(inner)
1097                } else {
1098                    // and_op is X, complement is NOT X
1099                    format!("NOT {}", gen(&and_op))
1100                };
1101
1102                // Check if complement exists in our OR operands
1103                let has_complement = all_ops
1104                    .iter()
1105                    .enumerate()
1106                    .any(|(j, other)| i != j && gen(other) == complement_str)
1107                    || op_strings.contains(&complement_str);
1108
1109                if has_complement {
1110                    // This AND operand's complement exists, so this term becomes FALSE in OR context
1111                    // A OR (NOT A AND B) -> A OR B, so we drop NOT A from the AND
1112                    had_complement_absorption = true;
1113                    // Drop this operand from AND
1114                } else {
1115                    remaining_and_ops.push(and_op);
1116                }
1117            }
1118
1119            if had_complement_absorption {
1120                if remaining_and_ops.is_empty() {
1121                    // All AND operands were absorbed, the AND becomes FALSE
1122                    // A OR FALSE -> A, so we just skip adding this
1123                    absorbed.insert(op_str);
1124                    continue;
1125                } else if remaining_and_ops.len() == 1 {
1126                    // Single remaining operand
1127                    result_ops.push(remaining_and_ops.into_iter().next().unwrap());
1128                    absorbed.insert(op_str);
1129                    continue;
1130                } else {
1131                    // Rebuild the AND with remaining operands
1132                    result_ops.push(rebuild_and(remaining_and_ops));
1133                    absorbed.insert(op_str);
1134                    continue;
1135                }
1136            }
1137        }
1138
1139        result_ops.push(op.clone());
1140    }
1141
1142    // Deduplicate
1143    let mut seen = std::collections::HashSet::new();
1144    result_ops.retain(|op| seen.insert(gen(op)));
1145
1146    if result_ops.is_empty() {
1147        bool_false()
1148    } else {
1149        rebuild_or(result_ops)
1150    }
1151}
1152
1153/// Generate a simple string representation of an expression for sorting/deduping
1154pub fn gen(expr: &Expression) -> String {
1155    match expr {
1156        Expression::Literal(lit) => match lit.as_ref() {
1157            Literal::String(s) => format!("'{}'", s),
1158            Literal::Number(n) => n.clone(),
1159            _ => format!("{:?}", lit),
1160        },
1161        Expression::Boolean(b) => if b.value { "TRUE" } else { "FALSE" }.to_string(),
1162        Expression::Null(_) => "NULL".to_string(),
1163        Expression::Column(col) => {
1164            if let Some(ref table) = col.table {
1165                format!("{}.{}", table.name, col.name.name)
1166            } else {
1167                col.name.name.clone()
1168            }
1169        }
1170        Expression::And(op) => format!("({} AND {})", gen(&op.left), gen(&op.right)),
1171        Expression::Or(op) => format!("({} OR {})", gen(&op.left), gen(&op.right)),
1172        Expression::Not(op) => format!("NOT {}", gen(&op.this)),
1173        Expression::Eq(op) => format!("{} = {}", gen(&op.left), gen(&op.right)),
1174        Expression::Neq(op) => format!("{} <> {}", gen(&op.left), gen(&op.right)),
1175        Expression::Gt(op) => format!("{} > {}", gen(&op.left), gen(&op.right)),
1176        Expression::Gte(op) => format!("{} >= {}", gen(&op.left), gen(&op.right)),
1177        Expression::Lt(op) => format!("{} < {}", gen(&op.left), gen(&op.right)),
1178        Expression::Lte(op) => format!("{} <= {}", gen(&op.left), gen(&op.right)),
1179        Expression::Add(op) => format!("{} + {}", gen(&op.left), gen(&op.right)),
1180        Expression::Sub(op) => format!("{} - {}", gen(&op.left), gen(&op.right)),
1181        Expression::Mul(op) => format!("{} * {}", gen(&op.left), gen(&op.right)),
1182        Expression::Div(op) => format!("{} / {}", gen(&op.left), gen(&op.right)),
1183        Expression::Function(f) => {
1184            let args: Vec<String> = f.args.iter().map(|a| gen(a)).collect();
1185            format!("{}({})", f.name.to_uppercase(), args.join(", "))
1186        }
1187        _ => format!("{:?}", expr),
1188    }
1189}
1190
1191#[cfg(test)]
1192mod tests {
1193    use super::*;
1194
1195    fn make_int(val: i64) -> Expression {
1196        Expression::Literal(Box::new(Literal::Number(val.to_string())))
1197    }
1198
1199    fn make_string(val: &str) -> Expression {
1200        Expression::Literal(Box::new(Literal::String(val.to_string())))
1201    }
1202
1203    fn make_bool(val: bool) -> Expression {
1204        Expression::Boolean(BooleanLiteral { value: val })
1205    }
1206
1207    fn make_column(name: &str) -> Expression {
1208        use crate::expressions::{Column, Identifier};
1209        Expression::boxed_column(Column {
1210            name: Identifier::new(name),
1211            table: None,
1212            join_mark: false,
1213            trailing_comments: vec![],
1214            span: None,
1215            inferred_type: None,
1216        })
1217    }
1218
1219    #[test]
1220    fn test_always_true_false() {
1221        assert!(always_true(&make_bool(true)));
1222        assert!(!always_true(&make_bool(false)));
1223        assert!(always_true(&make_int(1)));
1224        assert!(!always_true(&make_int(0)));
1225
1226        assert!(always_false(&make_bool(false)));
1227        assert!(!always_false(&make_bool(true)));
1228        assert!(always_false(&null()));
1229        assert!(always_false(&make_int(0)));
1230    }
1231
1232    #[test]
1233    fn test_simplify_and_with_true() {
1234        let mut simplifier = Simplifier::new(None);
1235
1236        // TRUE AND TRUE -> TRUE
1237        let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(true))));
1238        let result = simplifier.simplify(expr);
1239        assert!(always_true(&result));
1240
1241        // TRUE AND FALSE -> FALSE
1242        let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(false))));
1243        let result = simplifier.simplify(expr);
1244        assert!(always_false(&result));
1245
1246        // TRUE AND x -> x
1247        let x = make_int(42);
1248        let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), x.clone())));
1249        let result = simplifier.simplify(expr);
1250        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1251    }
1252
1253    #[test]
1254    fn test_simplify_or_with_false() {
1255        let mut simplifier = Simplifier::new(None);
1256
1257        // FALSE OR FALSE -> FALSE
1258        let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), make_bool(false))));
1259        let result = simplifier.simplify(expr);
1260        assert!(always_false(&result));
1261
1262        // FALSE OR TRUE -> TRUE
1263        let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), make_bool(true))));
1264        let result = simplifier.simplify(expr);
1265        assert!(always_true(&result));
1266
1267        // FALSE OR x -> x
1268        let x = make_int(42);
1269        let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), x.clone())));
1270        let result = simplifier.simplify(expr);
1271        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1272    }
1273
1274    #[test]
1275    fn test_simplify_not() {
1276        let mut simplifier = Simplifier::new(None);
1277
1278        // NOT TRUE -> FALSE
1279        let expr = Expression::Not(Box::new(UnaryOp::new(make_bool(true))));
1280        let result = simplifier.simplify(expr);
1281        assert!(is_false(&result));
1282
1283        // NOT FALSE -> TRUE
1284        let expr = Expression::Not(Box::new(UnaryOp::new(make_bool(false))));
1285        let result = simplifier.simplify(expr);
1286        assert!(always_true(&result));
1287
1288        // NOT NOT x -> x
1289        let x = make_int(42);
1290        let inner_not = Expression::Not(Box::new(UnaryOp::new(x.clone())));
1291        let expr = Expression::Not(Box::new(UnaryOp::new(inner_not)));
1292        let result = simplifier.simplify(expr);
1293        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1294    }
1295
1296    #[test]
1297    fn test_simplify_demorgan_comparison() {
1298        let mut simplifier = Simplifier::new(None);
1299
1300        // NOT (a = b) -> a != b (using columns to avoid constant folding)
1301        let a = make_column("a");
1302        let b = make_column("b");
1303        let eq = Expression::Eq(Box::new(BinaryOp::new(a.clone(), b.clone())));
1304        let expr = Expression::Not(Box::new(UnaryOp::new(eq)));
1305        let result = simplifier.simplify(expr);
1306        assert!(matches!(result, Expression::Neq(_)));
1307
1308        // NOT (a > b) -> a <= b
1309        let gt = Expression::Gt(Box::new(BinaryOp::new(a, b)));
1310        let expr = Expression::Not(Box::new(UnaryOp::new(gt)));
1311        let result = simplifier.simplify(expr);
1312        assert!(matches!(result, Expression::Lte(_)));
1313    }
1314
1315    #[test]
1316    fn test_constant_folding_add() {
1317        let mut simplifier = Simplifier::new(None);
1318
1319        // 1 + 2 -> 3
1320        let expr = Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))));
1321        let result = simplifier.simplify(expr);
1322        assert_eq!(get_number(&result), Some(3.0));
1323
1324        // x + 0 -> x
1325        let x = make_int(42);
1326        let expr = Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(0))));
1327        let result = simplifier.simplify(expr);
1328        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1329    }
1330
1331    #[test]
1332    fn test_constant_folding_mul() {
1333        let mut simplifier = Simplifier::new(None);
1334
1335        // 3 * 4 -> 12
1336        let expr = Expression::Mul(Box::new(BinaryOp::new(make_int(3), make_int(4))));
1337        let result = simplifier.simplify(expr);
1338        assert_eq!(get_number(&result), Some(12.0));
1339
1340        // x * 0 -> 0
1341        let x = make_int(42);
1342        let expr = Expression::Mul(Box::new(BinaryOp::new(x, make_int(0))));
1343        let result = simplifier.simplify(expr);
1344        assert_eq!(get_number(&result), Some(0.0));
1345
1346        // x * 1 -> x
1347        let x = make_int(42);
1348        let expr = Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(1))));
1349        let result = simplifier.simplify(expr);
1350        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1351    }
1352
1353    #[test]
1354    fn test_constant_folding_comparison() {
1355        let mut simplifier = Simplifier::new(None);
1356
1357        // 1 = 1 -> TRUE
1358        let expr = Expression::Eq(Box::new(BinaryOp::new(make_int(1), make_int(1))));
1359        let result = simplifier.simplify(expr);
1360        assert!(always_true(&result));
1361
1362        // 1 = 2 -> FALSE
1363        let expr = Expression::Eq(Box::new(BinaryOp::new(make_int(1), make_int(2))));
1364        let result = simplifier.simplify(expr);
1365        assert!(is_false(&result));
1366
1367        // 3 > 2 -> TRUE
1368        let expr = Expression::Gt(Box::new(BinaryOp::new(make_int(3), make_int(2))));
1369        let result = simplifier.simplify(expr);
1370        assert!(always_true(&result));
1371
1372        // 'a' = 'a' -> TRUE
1373        let expr = Expression::Eq(Box::new(BinaryOp::new(
1374            make_string("abc"),
1375            make_string("abc"),
1376        )));
1377        let result = simplifier.simplify(expr);
1378        assert!(always_true(&result));
1379    }
1380
1381    #[test]
1382    fn test_simplify_negation() {
1383        let mut simplifier = Simplifier::new(None);
1384
1385        // -(-5) -> 5
1386        let inner = Expression::Neg(Box::new(UnaryOp::new(make_int(5))));
1387        let expr = Expression::Neg(Box::new(UnaryOp::new(inner)));
1388        let result = simplifier.simplify(expr);
1389        assert_eq!(get_number(&result), Some(5.0));
1390
1391        // -(3) -> -3
1392        let expr = Expression::Neg(Box::new(UnaryOp::new(make_int(3))));
1393        let result = simplifier.simplify(expr);
1394        assert_eq!(get_number(&result), Some(-3.0));
1395    }
1396
1397    #[test]
1398    fn test_gen_simple() {
1399        assert_eq!(gen(&make_int(42)), "42");
1400        assert_eq!(gen(&make_string("hello")), "'hello'");
1401        assert_eq!(gen(&make_bool(true)), "TRUE");
1402        assert_eq!(gen(&make_bool(false)), "FALSE");
1403        assert_eq!(gen(&null()), "NULL");
1404    }
1405
1406    #[test]
1407    fn test_gen_operations() {
1408        let add = Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))));
1409        assert_eq!(gen(&add), "1 + 2");
1410
1411        let and = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(false))));
1412        assert_eq!(gen(&and), "(TRUE AND FALSE)");
1413    }
1414
1415    #[test]
1416    fn test_complement_elimination() {
1417        let mut simplifier = Simplifier::new(None);
1418
1419        // x AND NOT x -> FALSE
1420        let x = make_int(42);
1421        let not_x = Expression::Not(Box::new(UnaryOp::new(x.clone())));
1422        let expr = Expression::And(Box::new(BinaryOp::new(x, not_x)));
1423        let result = simplifier.simplify(expr);
1424        assert!(is_false(&result));
1425    }
1426
1427    #[test]
1428    fn test_idempotent() {
1429        let mut simplifier = Simplifier::new(None);
1430
1431        // x AND x -> x
1432        let x = make_int(42);
1433        let expr = Expression::And(Box::new(BinaryOp::new(x.clone(), x.clone())));
1434        let result = simplifier.simplify(expr);
1435        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1436
1437        // x OR x -> x
1438        let x = make_int(42);
1439        let expr = Expression::Or(Box::new(BinaryOp::new(x.clone(), x.clone())));
1440        let result = simplifier.simplify(expr);
1441        assert_eq!(format!("{:?}", result), format!("{:?}", x));
1442    }
1443
1444    #[test]
1445    fn test_absorption_and() {
1446        let mut simplifier = Simplifier::new(None);
1447
1448        // A AND (A OR B) -> A
1449        let a = make_column("a");
1450        let b = make_column("b");
1451        let a_or_b = Expression::Or(Box::new(BinaryOp::new(a.clone(), b.clone())));
1452        let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), a_or_b)));
1453        let result = simplifier.simplify(expr);
1454        // Result should be just A
1455        assert_eq!(gen(&result), gen(&a));
1456    }
1457
1458    #[test]
1459    fn test_absorption_or() {
1460        let mut simplifier = Simplifier::new(None);
1461
1462        // A OR (A AND B) -> A
1463        let a = make_column("a");
1464        let b = make_column("b");
1465        let a_and_b = Expression::And(Box::new(BinaryOp::new(a.clone(), b.clone())));
1466        let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), a_and_b)));
1467        let result = simplifier.simplify(expr);
1468        // Result should be just A
1469        assert_eq!(gen(&result), gen(&a));
1470    }
1471
1472    #[test]
1473    fn test_absorption_with_complement_and() {
1474        let mut simplifier = Simplifier::new(None);
1475
1476        // A AND (NOT A OR B) -> A AND B
1477        let a = make_column("a");
1478        let b = make_column("b");
1479        let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
1480        let not_a_or_b = Expression::Or(Box::new(BinaryOp::new(not_a, b.clone())));
1481        let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), not_a_or_b)));
1482        let result = simplifier.simplify(expr);
1483        // Result should be A AND B
1484        let expected = Expression::And(Box::new(BinaryOp::new(a, b)));
1485        assert_eq!(gen(&result), gen(&expected));
1486    }
1487
1488    #[test]
1489    fn test_absorption_with_complement_or() {
1490        let mut simplifier = Simplifier::new(None);
1491
1492        // A OR (NOT A AND B) -> A OR B
1493        let a = make_column("a");
1494        let b = make_column("b");
1495        let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
1496        let not_a_and_b = Expression::And(Box::new(BinaryOp::new(not_a, b.clone())));
1497        let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), not_a_and_b)));
1498        let result = simplifier.simplify(expr);
1499        // Result should be A OR B
1500        let expected = Expression::Or(Box::new(BinaryOp::new(a, b)));
1501        assert_eq!(gen(&result), gen(&expected));
1502    }
1503
1504    #[test]
1505    fn test_flatten_and() {
1506        // (A AND (B AND C)) should flatten to [A, B, C]
1507        let a = make_column("a");
1508        let b = make_column("b");
1509        let c = make_column("c");
1510        let b_and_c = Expression::And(Box::new(BinaryOp::new(b.clone(), c.clone())));
1511        let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), b_and_c)));
1512        let flattened = flatten_and(&expr);
1513        assert_eq!(flattened.len(), 3);
1514        assert_eq!(gen(&flattened[0]), "a");
1515        assert_eq!(gen(&flattened[1]), "b");
1516        assert_eq!(gen(&flattened[2]), "c");
1517    }
1518
1519    #[test]
1520    fn test_flatten_or() {
1521        // (A OR (B OR C)) should flatten to [A, B, C]
1522        let a = make_column("a");
1523        let b = make_column("b");
1524        let c = make_column("c");
1525        let b_or_c = Expression::Or(Box::new(BinaryOp::new(b.clone(), c.clone())));
1526        let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), b_or_c)));
1527        let flattened = flatten_or(&expr);
1528        assert_eq!(flattened.len(), 3);
1529        assert_eq!(gen(&flattened[0]), "a");
1530        assert_eq!(gen(&flattened[1]), "b");
1531        assert_eq!(gen(&flattened[2]), "c");
1532    }
1533
1534    #[test]
1535    fn test_simplify_concat() {
1536        let mut simplifier = Simplifier::new(None);
1537
1538        // 'a' || 'b' -> 'ab'
1539        let expr = Expression::Concat(Box::new(BinaryOp::new(
1540            make_string("hello"),
1541            make_string("world"),
1542        )));
1543        let result = simplifier.simplify(expr);
1544        assert_eq!(get_string(&result), Some("helloworld".to_string()));
1545
1546        // '' || x -> x
1547        let x = make_string("test");
1548        let expr = Expression::Concat(Box::new(BinaryOp::new(make_string(""), x.clone())));
1549        let result = simplifier.simplify(expr);
1550        assert_eq!(get_string(&result), Some("test".to_string()));
1551
1552        // x || '' -> x
1553        let expr = Expression::Concat(Box::new(BinaryOp::new(x, make_string(""))));
1554        let result = simplifier.simplify(expr);
1555        assert_eq!(get_string(&result), Some("test".to_string()));
1556
1557        // NULL || x -> NULL
1558        let expr = Expression::Concat(Box::new(BinaryOp::new(null(), make_string("test"))));
1559        let result = simplifier.simplify(expr);
1560        assert!(is_null(&result));
1561    }
1562
1563    #[test]
1564    fn test_simplify_concat_ws() {
1565        let mut simplifier = Simplifier::new(None);
1566
1567        // CONCAT_WS(',', 'a', 'b', 'c') -> 'a,b,c'
1568        let expr = Expression::ConcatWs(Box::new(ConcatWs {
1569            separator: make_string(","),
1570            expressions: vec![make_string("a"), make_string("b"), make_string("c")],
1571        }));
1572        let result = simplifier.simplify(expr);
1573        assert_eq!(get_string(&result), Some("a,b,c".to_string()));
1574
1575        // CONCAT_WS with NULL separator -> NULL
1576        let expr = Expression::ConcatWs(Box::new(ConcatWs {
1577            separator: null(),
1578            expressions: vec![make_string("a"), make_string("b")],
1579        }));
1580        let result = simplifier.simplify(expr);
1581        assert!(is_null(&result));
1582
1583        // CONCAT_WS with empty expressions -> ''
1584        let expr = Expression::ConcatWs(Box::new(ConcatWs {
1585            separator: make_string(","),
1586            expressions: vec![],
1587        }));
1588        let result = simplifier.simplify(expr);
1589        assert_eq!(get_string(&result), Some("".to_string()));
1590
1591        // CONCAT_WS skips NULLs
1592        let expr = Expression::ConcatWs(Box::new(ConcatWs {
1593            separator: make_string("-"),
1594            expressions: vec![make_string("a"), null(), make_string("b")],
1595        }));
1596        let result = simplifier.simplify(expr);
1597        assert_eq!(get_string(&result), Some("a-b".to_string()));
1598    }
1599
1600    #[test]
1601    fn test_simplify_paren() {
1602        let mut simplifier = Simplifier::new(None);
1603
1604        // (42) -> 42
1605        let expr = Expression::Paren(Box::new(Paren {
1606            this: make_int(42),
1607            trailing_comments: vec![],
1608        }));
1609        let result = simplifier.simplify(expr);
1610        assert_eq!(get_number(&result), Some(42.0));
1611
1612        // (TRUE) -> TRUE
1613        let expr = Expression::Paren(Box::new(Paren {
1614            this: make_bool(true),
1615            trailing_comments: vec![],
1616        }));
1617        let result = simplifier.simplify(expr);
1618        assert!(is_boolean_true(&result));
1619
1620        // (NULL) -> NULL
1621        let expr = Expression::Paren(Box::new(Paren {
1622            this: null(),
1623            trailing_comments: vec![],
1624        }));
1625        let result = simplifier.simplify(expr);
1626        assert!(is_null(&result));
1627
1628        // ((x)) -> x
1629        let inner_paren = Expression::Paren(Box::new(Paren {
1630            this: make_int(10),
1631            trailing_comments: vec![],
1632        }));
1633        let expr = Expression::Paren(Box::new(Paren {
1634            this: inner_paren,
1635            trailing_comments: vec![],
1636        }));
1637        let result = simplifier.simplify(expr);
1638        assert_eq!(get_number(&result), Some(10.0));
1639    }
1640
1641    #[test]
1642    fn test_simplify_equality_solve() {
1643        let mut simplifier = Simplifier::new(None);
1644
1645        // x + 1 = 3 -> x = 2
1646        let x = make_column("x");
1647        let x_plus_1 = Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(1))));
1648        let expr = Expression::Eq(Box::new(BinaryOp::new(x_plus_1, make_int(3))));
1649        let result = simplifier.simplify(expr);
1650        // Result should be x = 2
1651        if let Expression::Eq(op) = &result {
1652            assert_eq!(gen(&op.left), "x");
1653            assert_eq!(get_number(&op.right), Some(2.0));
1654        } else {
1655            panic!("Expected Eq expression");
1656        }
1657
1658        // x - 1 = 3 -> x = 4
1659        let x_minus_1 = Expression::Sub(Box::new(BinaryOp::new(x.clone(), make_int(1))));
1660        let expr = Expression::Eq(Box::new(BinaryOp::new(x_minus_1, make_int(3))));
1661        let result = simplifier.simplify(expr);
1662        if let Expression::Eq(op) = &result {
1663            assert_eq!(gen(&op.left), "x");
1664            assert_eq!(get_number(&op.right), Some(4.0));
1665        } else {
1666            panic!("Expected Eq expression");
1667        }
1668
1669        // x * 2 = 6 -> x = 3
1670        let x_times_2 = Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(2))));
1671        let expr = Expression::Eq(Box::new(BinaryOp::new(x_times_2, make_int(6))));
1672        let result = simplifier.simplify(expr);
1673        if let Expression::Eq(op) = &result {
1674            assert_eq!(gen(&op.left), "x");
1675            assert_eq!(get_number(&op.right), Some(3.0));
1676        } else {
1677            panic!("Expected Eq expression");
1678        }
1679
1680        // 1 + x = 3 -> x = 2 (commutative)
1681        let one_plus_x = Expression::Add(Box::new(BinaryOp::new(make_int(1), x.clone())));
1682        let expr = Expression::Eq(Box::new(BinaryOp::new(one_plus_x, make_int(3))));
1683        let result = simplifier.simplify(expr);
1684        if let Expression::Eq(op) = &result {
1685            assert_eq!(gen(&op.left), "x");
1686            assert_eq!(get_number(&op.right), Some(2.0));
1687        } else {
1688            panic!("Expected Eq expression");
1689        }
1690    }
1691
1692    #[test]
1693    fn test_simplify_datetrunc() {
1694        use crate::expressions::DateTimeField;
1695        let mut simplifier = Simplifier::new(None);
1696
1697        // DATE_TRUNC('day', x) with a column just passes through with simplified children
1698        let x = make_column("x");
1699        let expr = Expression::DateTrunc(Box::new(DateTruncFunc {
1700            this: x.clone(),
1701            unit: DateTimeField::Day,
1702        }));
1703        let result = simplifier.simplify(expr);
1704        if let Expression::DateTrunc(dt) = &result {
1705            assert_eq!(gen(&dt.this), "x");
1706            assert_eq!(dt.unit, DateTimeField::Day);
1707        } else {
1708            panic!("Expected DateTrunc expression");
1709        }
1710    }
1711}