Skip to main content

polyglot_sql/optimizer/
simplify.rs

1//! Expression Simplification
2//!
3//! This module provides boolean and expression simplification for SQL AST nodes.
4//! It applies various algebraic transformations to simplify expressions:
5//! - De Morgan's laws (NOT (A AND B) -> NOT A OR NOT B)
6//! - Constant folding (1 + 2 -> 3)
7//! - Boolean absorption (A AND (A OR B) -> A)
8//! - Complement removal (A AND NOT A -> FALSE)
9//! - Connector flattening (A AND (B AND C) -> A AND B AND C)
10//!
11//! Based on SQLGlot's optimizer/simplify.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{
15    BinaryOp, BooleanLiteral, Case, ConcatWs, DateTruncFunc, Expression, Literal, Null, Paren,
16    UnaryOp,
17};
18
19/// Main entry point for expression simplification
20pub fn simplify(expression: Expression, dialect: Option<DialectType>) -> Expression {
21    let mut simplifier = Simplifier::new(dialect);
22    simplifier.simplify(expression)
23}
24
25/// Check if expression is always true
26pub fn always_true(expr: &Expression) -> bool {
27    match expr {
28        Expression::Boolean(b) => b.value,
29        Expression::Literal(Literal::Number(n)) => {
30            // Non-zero numbers are truthy
31            if let Ok(num) = n.parse::<f64>() {
32                num != 0.0
33            } else {
34                false
35            }
36        }
37        _ => false,
38    }
39}
40
41/// Check if expression is a boolean TRUE literal (not just truthy)
42pub fn is_boolean_true(expr: &Expression) -> bool {
43    matches!(expr, Expression::Boolean(b) if b.value)
44}
45
46/// Check if expression is a boolean FALSE literal (not just falsy)
47pub fn is_boolean_false(expr: &Expression) -> bool {
48    matches!(expr, Expression::Boolean(b) if !b.value)
49}
50
51/// Check if expression is always false
52pub fn always_false(expr: &Expression) -> bool {
53    is_false(expr) || is_null(expr) || is_zero(expr)
54}
55
56/// Check if expression is boolean FALSE
57pub fn is_false(expr: &Expression) -> bool {
58    matches!(expr, Expression::Boolean(b) if !b.value)
59}
60
61/// Check if expression is NULL
62pub fn is_null(expr: &Expression) -> bool {
63    matches!(expr, Expression::Null(_))
64}
65
66/// Check if expression is zero
67pub fn is_zero(expr: &Expression) -> bool {
68    match expr {
69        Expression::Literal(Literal::Number(n)) => {
70            if let Ok(num) = n.parse::<f64>() {
71                num == 0.0
72            } else {
73                false
74            }
75        }
76        _ => false,
77    }
78}
79
80/// Check if b is the complement of a (i.e., b = NOT a)
81pub fn is_complement(a: &Expression, b: &Expression) -> bool {
82    if let Expression::Not(not_op) = b {
83        &not_op.this == a
84    } else {
85        false
86    }
87}
88
89/// Create a TRUE boolean literal
90pub fn bool_true() -> Expression {
91    Expression::Boolean(BooleanLiteral { value: true })
92}
93
94/// Create a FALSE boolean literal
95pub fn bool_false() -> Expression {
96    Expression::Boolean(BooleanLiteral { value: false })
97}
98
99/// Create a NULL expression
100pub fn null() -> Expression {
101    Expression::Null(Null)
102}
103
104/// Evaluate a boolean comparison between two numbers
105pub fn eval_boolean_nums(op: &str, a: f64, b: f64) -> Option<Expression> {
106    let result = match op {
107        "=" | "==" => a == b,
108        "!=" | "<>" => a != b,
109        ">" => a > b,
110        ">=" => a >= b,
111        "<" => a < b,
112        "<=" => a <= b,
113        _ => return None,
114    };
115    Some(if result { bool_true() } else { bool_false() })
116}
117
118/// Evaluate a boolean comparison between two strings
119pub fn eval_boolean_strings(op: &str, a: &str, b: &str) -> Option<Expression> {
120    let result = match op {
121        "=" | "==" => a == b,
122        "!=" | "<>" => a != b,
123        ">" => a > b,
124        ">=" => a >= b,
125        "<" => a < b,
126        "<=" => a <= b,
127        _ => return None,
128    };
129    Some(if result { bool_true() } else { bool_false() })
130}
131
/// Expression simplifier.
///
/// Holds the configuration used while repeatedly rewriting an expression
/// until it stops changing.
pub struct Simplifier {
    // Dialect supplied by the caller; stored but not read by the methods
    // visible in this file (hence the leading underscore).
    _dialect: Option<DialectType>,
    // Upper bound on the number of rewrite rounds `simplify` will run.
    max_iterations: usize,
}
137
138impl Simplifier {
139    /// Create a new simplifier
140    pub fn new(dialect: Option<DialectType>) -> Self {
141        Self {
142            _dialect: dialect,
143            max_iterations: 100,
144        }
145    }
146
147    /// Simplify an expression
148    pub fn simplify(&mut self, expression: Expression) -> Expression {
149        // Apply simplifications until no more changes (or max iterations)
150        let mut current = expression;
151        for _ in 0..self.max_iterations {
152            let simplified = self.simplify_once(current.clone());
153            if expressions_equal(&simplified, &current) {
154                return simplified;
155            }
156            current = simplified;
157        }
158        current
159    }
160
161    /// Apply one round of simplifications
162    fn simplify_once(&mut self, expression: Expression) -> Expression {
163        match expression {
164            // Binary logical operations
165            Expression::And(op) => self.simplify_and(*op),
166            Expression::Or(op) => self.simplify_or(*op),
167
168            // NOT operation - De Morgan's laws
169            Expression::Not(op) => self.simplify_not(*op),
170
171            // Arithmetic operations - constant folding
172            Expression::Add(op) => self.simplify_add(*op),
173            Expression::Sub(op) => self.simplify_sub(*op),
174            Expression::Mul(op) => self.simplify_mul(*op),
175            Expression::Div(op) => self.simplify_div(*op),
176
177            // Comparison operations
178            Expression::Eq(op) => self.simplify_comparison(*op, "="),
179            Expression::Neq(op) => self.simplify_comparison(*op, "!="),
180            Expression::Gt(op) => self.simplify_comparison(*op, ">"),
181            Expression::Gte(op) => self.simplify_comparison(*op, ">="),
182            Expression::Lt(op) => self.simplify_comparison(*op, "<"),
183            Expression::Lte(op) => self.simplify_comparison(*op, "<="),
184
185            // Negation
186            Expression::Neg(op) => self.simplify_neg(*op),
187
188            // CASE expression
189            Expression::Case(case) => self.simplify_case(*case),
190
191            // String concatenation
192            Expression::Concat(op) => self.simplify_concat(*op),
193            Expression::ConcatWs(concat_ws) => self.simplify_concat_ws(*concat_ws),
194
195            // Parentheses - remove if unnecessary
196            Expression::Paren(paren) => self.simplify_paren(*paren),
197
198            // Date truncation
199            Expression::DateTrunc(dt) => self.simplify_datetrunc(*dt),
200            Expression::TimestampTrunc(dt) => self.simplify_datetrunc(*dt),
201
202            // Recursively simplify children for other expressions
203            other => self.simplify_children(other),
204        }
205    }
206
207    /// Simplify AND operation
208    fn simplify_and(&mut self, op: BinaryOp) -> Expression {
209        let left = self.simplify_once(op.left);
210        let right = self.simplify_once(op.right);
211
212        // FALSE AND x -> FALSE
213        // x AND FALSE -> FALSE
214        if is_boolean_false(&left) || is_boolean_false(&right) {
215            return bool_false();
216        }
217
218        // 0 AND x -> FALSE (in boolean context)
219        // x AND 0 -> FALSE
220        if is_zero(&left) || is_zero(&right) {
221            return bool_false();
222        }
223
224        // NULL AND NULL -> NULL
225        // NULL AND TRUE -> NULL
226        // TRUE AND NULL -> NULL
227        if (is_null(&left) && is_null(&right))
228            || (is_null(&left) && is_boolean_true(&right))
229            || (is_boolean_true(&left) && is_null(&right))
230        {
231            return null();
232        }
233
234        // TRUE AND x -> x (only when left is actually boolean TRUE)
235        if is_boolean_true(&left) {
236            return right;
237        }
238
239        // x AND TRUE -> x (only when right is actually boolean TRUE)
240        if is_boolean_true(&right) {
241            return left;
242        }
243
244        // A AND NOT A -> FALSE (complement elimination)
245        if is_complement(&left, &right) || is_complement(&right, &left) {
246            return bool_false();
247        }
248
249        // A AND A -> A (idempotent)
250        if expressions_equal(&left, &right) {
251            return left;
252        }
253
254        // Apply absorption rules
255        // A AND (A OR B) -> A
256        // A AND (NOT A OR B) -> A AND B
257        absorb_and_eliminate_and(left, right)
258    }
259
260    /// Simplify OR operation
261    fn simplify_or(&mut self, op: BinaryOp) -> Expression {
262        let left = self.simplify_once(op.left);
263        let right = self.simplify_once(op.right);
264
265        // TRUE OR x -> TRUE (only when left is actually boolean TRUE)
266        if is_boolean_true(&left) {
267            return bool_true();
268        }
269
270        // x OR TRUE -> TRUE (only when right is actually boolean TRUE)
271        if is_boolean_true(&right) {
272            return bool_true();
273        }
274
275        // NULL OR NULL -> NULL
276        // NULL OR FALSE -> NULL
277        // FALSE OR NULL -> NULL
278        if (is_null(&left) && is_null(&right))
279            || (is_null(&left) && is_boolean_false(&right))
280            || (is_boolean_false(&left) && is_null(&right))
281        {
282            return null();
283        }
284
285        // FALSE OR x -> x (only when left is actually boolean FALSE)
286        if is_boolean_false(&left) {
287            return right;
288        }
289
290        // x OR FALSE -> x (only when right is actually boolean FALSE)
291        if is_boolean_false(&right) {
292            return left;
293        }
294
295        // A OR A -> A (idempotent)
296        if expressions_equal(&left, &right) {
297            return left;
298        }
299
300        // Apply absorption rules
301        // A OR (A AND B) -> A
302        // A OR (NOT A AND B) -> A OR B
303        absorb_and_eliminate_or(left, right)
304    }
305
306    /// Simplify NOT operation (De Morgan's laws)
307    fn simplify_not(&mut self, op: UnaryOp) -> Expression {
308        // Check for De Morgan's laws BEFORE simplifying inner expression
309        // This prevents constant folding from eliminating the comparison operator
310        match &op.this {
311            // NOT (a = b) -> a != b
312            Expression::Eq(inner_op) => {
313                let left = self.simplify_once(inner_op.left.clone());
314                let right = self.simplify_once(inner_op.right.clone());
315                return Expression::Neq(Box::new(BinaryOp::new(left, right)));
316            }
317            // NOT (a != b) -> a = b
318            Expression::Neq(inner_op) => {
319                let left = self.simplify_once(inner_op.left.clone());
320                let right = self.simplify_once(inner_op.right.clone());
321                return Expression::Eq(Box::new(BinaryOp::new(left, right)));
322            }
323            // NOT (a > b) -> a <= b
324            Expression::Gt(inner_op) => {
325                let left = self.simplify_once(inner_op.left.clone());
326                let right = self.simplify_once(inner_op.right.clone());
327                return Expression::Lte(Box::new(BinaryOp::new(left, right)));
328            }
329            // NOT (a >= b) -> a < b
330            Expression::Gte(inner_op) => {
331                let left = self.simplify_once(inner_op.left.clone());
332                let right = self.simplify_once(inner_op.right.clone());
333                return Expression::Lt(Box::new(BinaryOp::new(left, right)));
334            }
335            // NOT (a < b) -> a >= b
336            Expression::Lt(inner_op) => {
337                let left = self.simplify_once(inner_op.left.clone());
338                let right = self.simplify_once(inner_op.right.clone());
339                return Expression::Gte(Box::new(BinaryOp::new(left, right)));
340            }
341            // NOT (a <= b) -> a > b
342            Expression::Lte(inner_op) => {
343                let left = self.simplify_once(inner_op.left.clone());
344                let right = self.simplify_once(inner_op.right.clone());
345                return Expression::Gt(Box::new(BinaryOp::new(left, right)));
346            }
347            _ => {}
348        }
349
350        // Now simplify the inner expression for other patterns
351        let inner = self.simplify_once(op.this);
352
353        // NOT NULL -> NULL (with TRUE for SQL semantics)
354        if is_null(&inner) {
355            return null();
356        }
357
358        // NOT TRUE -> FALSE (only for boolean TRUE literal)
359        if is_boolean_true(&inner) {
360            return bool_false();
361        }
362
363        // NOT FALSE -> TRUE (only for boolean FALSE literal)
364        if is_boolean_false(&inner) {
365            return bool_true();
366        }
367
368        // NOT NOT x -> x (double negation elimination)
369        if let Expression::Not(inner_not) = &inner {
370            return inner_not.this.clone();
371        }
372
373        Expression::Not(Box::new(UnaryOp {
374            this: inner,
375            inferred_type: None,
376        }))
377    }
378
379    /// Simplify addition (constant folding)
380    fn simplify_add(&mut self, op: BinaryOp) -> Expression {
381        let left = self.simplify_once(op.left);
382        let right = self.simplify_once(op.right);
383
384        // Try constant folding for numbers
385        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
386            return Expression::Literal(Literal::Number((a + b).to_string()));
387        }
388
389        // x + 0 -> x
390        if is_zero(&right) {
391            return left;
392        }
393
394        // 0 + x -> x
395        if is_zero(&left) {
396            return right;
397        }
398
399        Expression::Add(Box::new(BinaryOp::new(left, right)))
400    }
401
402    /// Simplify subtraction (constant folding)
403    fn simplify_sub(&mut self, op: BinaryOp) -> Expression {
404        let left = self.simplify_once(op.left);
405        let right = self.simplify_once(op.right);
406
407        // Try constant folding for numbers
408        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
409            return Expression::Literal(Literal::Number((a - b).to_string()));
410        }
411
412        // x - 0 -> x
413        if is_zero(&right) {
414            return left;
415        }
416
417        // x - x -> 0 (only for literals/constants)
418        if expressions_equal(&left, &right) {
419            if let Expression::Literal(Literal::Number(_)) = &left {
420                return Expression::Literal(Literal::Number("0".to_string()));
421            }
422        }
423
424        Expression::Sub(Box::new(BinaryOp::new(left, right)))
425    }
426
427    /// Simplify multiplication (constant folding)
428    fn simplify_mul(&mut self, op: BinaryOp) -> Expression {
429        let left = self.simplify_once(op.left);
430        let right = self.simplify_once(op.right);
431
432        // Try constant folding for numbers
433        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
434            return Expression::Literal(Literal::Number((a * b).to_string()));
435        }
436
437        // x * 0 -> 0
438        if is_zero(&right) {
439            return Expression::Literal(Literal::Number("0".to_string()));
440        }
441
442        // 0 * x -> 0
443        if is_zero(&left) {
444            return Expression::Literal(Literal::Number("0".to_string()));
445        }
446
447        // x * 1 -> x
448        if is_one(&right) {
449            return left;
450        }
451
452        // 1 * x -> x
453        if is_one(&left) {
454            return right;
455        }
456
457        Expression::Mul(Box::new(BinaryOp::new(left, right)))
458    }
459
460    /// Simplify division (constant folding)
461    fn simplify_div(&mut self, op: BinaryOp) -> Expression {
462        let left = self.simplify_once(op.left);
463        let right = self.simplify_once(op.right);
464
465        // Try constant folding for numbers (but not integer division)
466        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
467            // Only fold if both are floats to avoid integer division issues
468            if b != 0.0 && (a.fract() != 0.0 || b.fract() != 0.0) {
469                return Expression::Literal(Literal::Number((a / b).to_string()));
470            }
471        }
472
473        // 0 / x -> 0 (when x != 0)
474        if is_zero(&left) && !is_zero(&right) {
475            return Expression::Literal(Literal::Number("0".to_string()));
476        }
477
478        // x / 1 -> x
479        if is_one(&right) {
480            return left;
481        }
482
483        Expression::Div(Box::new(BinaryOp::new(left, right)))
484    }
485
486    /// Simplify negation
487    fn simplify_neg(&mut self, op: UnaryOp) -> Expression {
488        let inner = self.simplify_once(op.this);
489
490        // -(-x) -> x (double negation)
491        if let Expression::Neg(inner_neg) = inner {
492            return inner_neg.this;
493        }
494
495        // -(number) -> -number
496        if let Some(n) = get_number(&inner) {
497            return Expression::Literal(Literal::Number((-n).to_string()));
498        }
499
500        Expression::Neg(Box::new(UnaryOp {
501            this: inner,
502            inferred_type: None,
503        }))
504    }
505
506    /// Simplify comparison operations (constant folding)
507    fn simplify_comparison(&mut self, op: BinaryOp, operator: &str) -> Expression {
508        let left = self.simplify_once(op.left);
509        let right = self.simplify_once(op.right);
510
511        // Try constant folding for numbers
512        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
513            if let Some(result) = eval_boolean_nums(operator, a, b) {
514                return result;
515            }
516        }
517
518        // Try constant folding for strings
519        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
520            if let Some(result) = eval_boolean_strings(operator, &a, &b) {
521                return result;
522            }
523        }
524
525        // For equality, try to solve simple equations (x + 1 = 3 -> x = 2)
526        if operator == "=" {
527            if let Some(simplified) = self.simplify_equality(left.clone(), right.clone()) {
528                return simplified;
529            }
530        }
531
532        // Reconstruct the comparison
533        let new_op = BinaryOp::new(left, right);
534
535        match operator {
536            "=" => Expression::Eq(Box::new(new_op)),
537            "!=" | "<>" => Expression::Neq(Box::new(new_op)),
538            ">" => Expression::Gt(Box::new(new_op)),
539            ">=" => Expression::Gte(Box::new(new_op)),
540            "<" => Expression::Lt(Box::new(new_op)),
541            "<=" => Expression::Lte(Box::new(new_op)),
542            _ => Expression::Eq(Box::new(new_op)),
543        }
544    }
545
546    /// Simplify CASE expression
547    fn simplify_case(&mut self, case: Case) -> Expression {
548        let mut new_whens = Vec::new();
549
550        for (cond, then_expr) in case.whens {
551            let simplified_cond = self.simplify_once(cond);
552
553            // If condition is always true, return the THEN expression
554            if always_true(&simplified_cond) {
555                return self.simplify_once(then_expr);
556            }
557
558            // If condition is always false, skip this WHEN clause
559            if always_false(&simplified_cond) {
560                continue;
561            }
562
563            new_whens.push((simplified_cond, self.simplify_once(then_expr)));
564        }
565
566        // If no WHEN clauses remain, return the ELSE expression (or NULL)
567        if new_whens.is_empty() {
568            return case
569                .else_
570                .map(|e| self.simplify_once(e))
571                .unwrap_or_else(null);
572        }
573
574        Expression::Case(Box::new(Case {
575            operand: case.operand.map(|e| self.simplify_once(e)),
576            whens: new_whens,
577            else_: case.else_.map(|e| self.simplify_once(e)),
578            comments: Vec::new(),
579            inferred_type: None,
580        }))
581    }
582
583    /// Simplify string concatenation (Concat is || operator)
584    ///
585    /// Folds adjacent string literals:
586    /// - 'a' || 'b' -> 'ab'
587    /// - 'a' || 'b' || 'c' -> 'abc'
588    /// - '' || x -> x
589    /// - x || '' -> x
590    fn simplify_concat(&mut self, op: BinaryOp) -> Expression {
591        let left = self.simplify_once(op.left);
592        let right = self.simplify_once(op.right);
593
594        // Fold two string literals: 'a' || 'b' -> 'ab'
595        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
596            return Expression::Literal(Literal::String(format!("{}{}", a, b)));
597        }
598
599        // '' || x -> x
600        if let Some(s) = get_string(&left) {
601            if s.is_empty() {
602                return right;
603            }
604        }
605
606        // x || '' -> x
607        if let Some(s) = get_string(&right) {
608            if s.is_empty() {
609                return left;
610            }
611        }
612
613        // NULL || x -> NULL, x || NULL -> NULL (SQL string concat semantics)
614        if is_null(&left) || is_null(&right) {
615            return null();
616        }
617
618        Expression::Concat(Box::new(BinaryOp::new(left, right)))
619    }
620
621    /// Simplify CONCAT_WS function
622    ///
623    /// CONCAT_WS(sep, a, b, c) -> concatenates with separator, skipping NULLs
624    /// - CONCAT_WS(',', 'a', 'b') -> 'a,b' (when all are literals)
625    /// - CONCAT_WS(',', 'a', NULL, 'b') -> 'a,b' (NULLs are skipped)
626    /// - CONCAT_WS(NULL, ...) -> NULL
627    fn simplify_concat_ws(&mut self, concat_ws: ConcatWs) -> Expression {
628        let separator = self.simplify_once(concat_ws.separator);
629
630        // If separator is NULL, result is NULL
631        if is_null(&separator) {
632            return null();
633        }
634
635        let expressions: Vec<Expression> = concat_ws
636            .expressions
637            .into_iter()
638            .map(|e| self.simplify_once(e))
639            .filter(|e| !is_null(e)) // Skip NULL values
640            .collect();
641
642        // If no expressions remain, return empty string
643        if expressions.is_empty() {
644            return Expression::Literal(Literal::String(String::new()));
645        }
646
647        // Try to fold if all are string literals
648        if let Some(sep) = get_string(&separator) {
649            let all_strings: Option<Vec<String>> =
650                expressions.iter().map(|e| get_string(e)).collect();
651
652            if let Some(strings) = all_strings {
653                return Expression::Literal(Literal::String(strings.join(&sep)));
654            }
655        }
656
657        // Return simplified CONCAT_WS
658        Expression::ConcatWs(Box::new(ConcatWs {
659            separator,
660            expressions,
661        }))
662    }
663
664    /// Simplify parentheses
665    ///
666    /// Remove unnecessary parentheses:
667    /// - (x) -> x when x is a literal, column, or already parenthesized
668    /// - ((x)) -> (x) -> x (recursive simplification)
669    fn simplify_paren(&mut self, paren: Paren) -> Expression {
670        let inner = self.simplify_once(paren.this);
671
672        // If inner is a literal, column, boolean, null, or already parenthesized,
673        // we can remove the parentheses
674        match &inner {
675            Expression::Literal(_)
676            | Expression::Boolean(_)
677            | Expression::Null(_)
678            | Expression::Column(_)
679            | Expression::Paren(_) => inner,
680            // For other expressions, keep the parentheses
681            _ => Expression::Paren(Box::new(Paren {
682                this: inner,
683                trailing_comments: paren.trailing_comments,
684            })),
685        }
686    }
687
688    /// Simplify DATE_TRUNC and TIMESTAMP_TRUNC
689    ///
690    /// Currently just simplifies children and passes through.
691    /// Future: could fold DATE_TRUNC('day', '2024-01-15') -> '2024-01-15'
692    fn simplify_datetrunc(&mut self, dt: DateTruncFunc) -> Expression {
693        let inner = self.simplify_once(dt.this);
694
695        // For now, just return with simplified inner expression
696        // A more advanced implementation would fold constant date/timestamps
697        Expression::DateTrunc(Box::new(DateTruncFunc {
698            this: inner,
699            unit: dt.unit,
700        }))
701    }
702
703    /// Simplify equality with arithmetic (solve simple equations)
704    ///
705    /// - x + 1 = 3 -> x = 2
706    /// - x - 1 = 3 -> x = 4
707    /// - x * 2 = 6 -> x = 3 (only when divisible)
708    /// - 1 + x = 3 -> x = 2 (commutative)
709    fn simplify_equality(&mut self, left: Expression, right: Expression) -> Option<Expression> {
710        // Only works when right side is a constant
711        let right_val = get_number(&right)?;
712
713        // Check if left side is arithmetic with one constant
714        match left {
715            Expression::Add(ref op) => {
716                // x + c = r -> x = r - c
717                if let Some(c) = get_number(&op.right) {
718                    let new_right =
719                        Expression::Literal(Literal::Number((right_val - c).to_string()));
720                    return Some(Expression::Eq(Box::new(BinaryOp::new(
721                        op.left.clone(),
722                        new_right,
723                    ))));
724                }
725                // c + x = r -> x = r - c
726                if let Some(c) = get_number(&op.left) {
727                    let new_right =
728                        Expression::Literal(Literal::Number((right_val - c).to_string()));
729                    return Some(Expression::Eq(Box::new(BinaryOp::new(
730                        op.right.clone(),
731                        new_right,
732                    ))));
733                }
734            }
735            Expression::Sub(ref op) => {
736                // x - c = r -> x = r + c
737                if let Some(c) = get_number(&op.right) {
738                    let new_right =
739                        Expression::Literal(Literal::Number((right_val + c).to_string()));
740                    return Some(Expression::Eq(Box::new(BinaryOp::new(
741                        op.left.clone(),
742                        new_right,
743                    ))));
744                }
745                // c - x = r -> x = c - r
746                if let Some(c) = get_number(&op.left) {
747                    let new_right =
748                        Expression::Literal(Literal::Number((c - right_val).to_string()));
749                    return Some(Expression::Eq(Box::new(BinaryOp::new(
750                        op.right.clone(),
751                        new_right,
752                    ))));
753                }
754            }
755            Expression::Mul(ref op) => {
756                // x * c = r -> x = r / c (only for non-zero c and when divisible)
757                if let Some(c) = get_number(&op.right) {
758                    if c != 0.0 && right_val % c == 0.0 {
759                        let new_right =
760                            Expression::Literal(Literal::Number((right_val / c).to_string()));
761                        return Some(Expression::Eq(Box::new(BinaryOp::new(
762                            op.left.clone(),
763                            new_right,
764                        ))));
765                    }
766                }
767                // c * x = r -> x = r / c
768                if let Some(c) = get_number(&op.left) {
769                    if c != 0.0 && right_val % c == 0.0 {
770                        let new_right =
771                            Expression::Literal(Literal::Number((right_val / c).to_string()));
772                        return Some(Expression::Eq(Box::new(BinaryOp::new(
773                            op.right.clone(),
774                            new_right,
775                        ))));
776                    }
777                }
778            }
779            _ => {}
780        }
781
782        None
783    }
784
785    /// Recursively simplify children of an expression
786    fn simplify_children(&mut self, expr: Expression) -> Expression {
787        // For expressions we don't have specific simplification rules for,
788        // we still want to simplify their children
789        match expr {
790            Expression::Alias(mut alias) => {
791                alias.this = self.simplify_once(alias.this);
792                Expression::Alias(alias)
793            }
794            Expression::Between(mut between) => {
795                between.this = self.simplify_once(between.this);
796                between.low = self.simplify_once(between.low);
797                between.high = self.simplify_once(between.high);
798                Expression::Between(between)
799            }
800            Expression::In(mut in_expr) => {
801                in_expr.this = self.simplify_once(in_expr.this);
802                in_expr.expressions = in_expr
803                    .expressions
804                    .into_iter()
805                    .map(|e| self.simplify_once(e))
806                    .collect();
807                Expression::In(in_expr)
808            }
809            Expression::Function(mut func) => {
810                func.args = func
811                    .args
812                    .into_iter()
813                    .map(|e| self.simplify_once(e))
814                    .collect();
815                Expression::Function(func)
816            }
817            // For other expressions, return as-is for now
818            other => other,
819        }
820    }
821}
822
823/// Check if expression equals 1
824fn is_one(expr: &Expression) -> bool {
825    match expr {
826        Expression::Literal(Literal::Number(n)) => {
827            if let Ok(num) = n.parse::<f64>() {
828                num == 1.0
829            } else {
830                false
831            }
832        }
833        _ => false,
834    }
835}
836
837/// Get numeric value from expression if it's a number literal
838fn get_number(expr: &Expression) -> Option<f64> {
839    match expr {
840        Expression::Literal(Literal::Number(n)) => n.parse().ok(),
841        _ => None,
842    }
843}
844
845/// Get string value from expression if it's a string literal
846fn get_string(expr: &Expression) -> Option<String> {
847    match expr {
848        Expression::Literal(Literal::String(s)) => Some(s.clone()),
849        _ => None,
850    }
851}
852
853/// Check if two expressions are structurally equal
854/// This is a simplified comparison - a full implementation would need deep comparison
855fn expressions_equal(a: &Expression, b: &Expression) -> bool {
856    // For now, use Debug representation for comparison
857    // A proper implementation would do structural comparison
858    format!("{:?}", a) == format!("{:?}", b)
859}
860
861/// Flatten nested AND expressions into a list of operands
862/// e.g., (A AND (B AND C)) -> [A, B, C]
863fn flatten_and(expr: &Expression) -> Vec<Expression> {
864    match expr {
865        Expression::And(op) => {
866            let mut result = flatten_and(&op.left);
867            result.extend(flatten_and(&op.right));
868            result
869        }
870        other => vec![other.clone()],
871    }
872}
873
874/// Flatten nested OR expressions into a list of operands
875/// e.g., (A OR (B OR C)) -> [A, B, C]
876fn flatten_or(expr: &Expression) -> Vec<Expression> {
877    match expr {
878        Expression::Or(op) => {
879            let mut result = flatten_or(&op.left);
880            result.extend(flatten_or(&op.right));
881            result
882        }
883        other => vec![other.clone()],
884    }
885}
886
887/// Rebuild an AND expression from a list of operands
888fn rebuild_and(operands: Vec<Expression>) -> Expression {
889    if operands.is_empty() {
890        return bool_true(); // Empty AND is TRUE
891    }
892    let mut result = operands.into_iter();
893    let first = result.next().unwrap();
894    result.fold(first, |acc, op| {
895        Expression::And(Box::new(BinaryOp::new(acc, op)))
896    })
897}
898
899/// Rebuild an OR expression from a list of operands
900fn rebuild_or(operands: Vec<Expression>) -> Expression {
901    if operands.is_empty() {
902        return bool_false(); // Empty OR is FALSE
903    }
904    let mut result = operands.into_iter();
905    let first = result.next().unwrap();
906    result.fold(first, |acc, op| {
907        Expression::Or(Box::new(BinaryOp::new(acc, op)))
908    })
909}
910
911/// Get the inner expression of a NOT, if it is one
912fn get_not_inner(expr: &Expression) -> Option<&Expression> {
913    match expr {
914        Expression::Not(op) => Some(&op.this),
915        _ => None,
916    }
917}
918
919/// Apply Boolean absorption and elimination rules to an AND expression
920///
921/// Absorption:
922///   A AND (A OR B) -> A
923///   A AND (NOT A OR B) -> A AND B
924///
925/// Elimination:
926///   (A OR B) AND (A OR NOT B) -> A
927pub fn absorb_and_eliminate_and(left: Expression, right: Expression) -> Expression {
928    // Flatten both sides
929    let left_ops = flatten_and(&left);
930    let right_ops = flatten_and(&right);
931    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
932
933    // Build a set of string representations for quick lookup
934    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
935
936    let mut result_ops: Vec<Expression> = Vec::new();
937    let mut absorbed = std::collections::HashSet::new();
938
939    for (i, op) in all_ops.iter().enumerate() {
940        let op_str = gen(op);
941
942        // Skip if already absorbed
943        if absorbed.contains(&op_str) {
944            continue;
945        }
946
947        // Check if this is an OR expression (potential absorption target)
948        if let Expression::Or(_) = op {
949            let or_operands = flatten_or(op);
950
951            // Absorption: A AND (A OR B) -> A
952            // Check if any OR operand is already in our AND operands
953            let absorbed_by_existing = or_operands.iter().any(|or_op| {
954                let or_op_str = gen(or_op);
955                // Check if this OR operand exists in other AND operands (not this OR itself)
956                all_ops
957                    .iter()
958                    .enumerate()
959                    .any(|(j, other)| i != j && gen(other) == or_op_str)
960            });
961
962            if absorbed_by_existing {
963                // This OR is absorbed, skip it
964                absorbed.insert(op_str);
965                continue;
966            }
967
968            // Absorption with complement: A AND (NOT A OR B) -> A AND B
969            // Check if any OR operand's complement is in our AND operands
970            let mut remaining_or_ops: Vec<Expression> = Vec::new();
971            let mut had_complement_absorption = false;
972
973            for or_op in or_operands {
974                let complement_str = if let Some(inner) = get_not_inner(&or_op) {
975                    // or_op is NOT X, complement is X
976                    gen(inner)
977                } else {
978                    // or_op is X, complement is NOT X
979                    format!("NOT {}", gen(&or_op))
980                };
981
982                // Check if complement exists in our AND operands
983                let has_complement = all_ops
984                    .iter()
985                    .enumerate()
986                    .any(|(j, other)| i != j && gen(other) == complement_str)
987                    || op_strings.contains(&complement_str);
988
989                if has_complement {
990                    // This OR operand's complement exists, so this term becomes TRUE in AND context
991                    // NOT A OR B, where A exists, becomes TRUE OR B (when A is true) or B (when A is false)
992                    // Actually: A AND (NOT A OR B) -> A AND B, so we drop NOT A from the OR
993                    had_complement_absorption = true;
994                    // Drop this operand from OR
995                } else {
996                    remaining_or_ops.push(or_op);
997                }
998            }
999
1000            if had_complement_absorption {
1001                if remaining_or_ops.is_empty() {
1002                    // All OR operands were absorbed, the OR becomes TRUE
1003                    // A AND TRUE -> A, so we just skip adding this
1004                    absorbed.insert(op_str);
1005                    continue;
1006                } else if remaining_or_ops.len() == 1 {
1007                    // Single remaining operand
1008                    result_ops.push(remaining_or_ops.into_iter().next().unwrap());
1009                    absorbed.insert(op_str);
1010                    continue;
1011                } else {
1012                    // Rebuild the OR with remaining operands
1013                    result_ops.push(rebuild_or(remaining_or_ops));
1014                    absorbed.insert(op_str);
1015                    continue;
1016                }
1017            }
1018        }
1019
1020        result_ops.push(op.clone());
1021    }
1022
1023    // Deduplicate
1024    let mut seen = std::collections::HashSet::new();
1025    result_ops.retain(|op| seen.insert(gen(op)));
1026
1027    if result_ops.is_empty() {
1028        bool_true()
1029    } else {
1030        rebuild_and(result_ops)
1031    }
1032}
1033
1034/// Apply Boolean absorption and elimination rules to an OR expression
1035///
1036/// Absorption:
1037///   A OR (A AND B) -> A
1038///   A OR (NOT A AND B) -> A OR B
1039///
1040/// Elimination:
1041///   (A AND B) OR (A AND NOT B) -> A
1042pub fn absorb_and_eliminate_or(left: Expression, right: Expression) -> Expression {
1043    // Flatten both sides
1044    let left_ops = flatten_or(&left);
1045    let right_ops = flatten_or(&right);
1046    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
1047
1048    // Build a set of string representations for quick lookup
1049    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
1050
1051    let mut result_ops: Vec<Expression> = Vec::new();
1052    let mut absorbed = std::collections::HashSet::new();
1053
1054    for (i, op) in all_ops.iter().enumerate() {
1055        let op_str = gen(op);
1056
1057        // Skip if already absorbed
1058        if absorbed.contains(&op_str) {
1059            continue;
1060        }
1061
1062        // Check if this is an AND expression (potential absorption target)
1063        if let Expression::And(_) = op {
1064            let and_operands = flatten_and(op);
1065
1066            // Absorption: A OR (A AND B) -> A
1067            // Check if any AND operand is already in our OR operands
1068            let absorbed_by_existing = and_operands.iter().any(|and_op| {
1069                let and_op_str = gen(and_op);
1070                // Check if this AND operand exists in other OR operands (not this AND itself)
1071                all_ops
1072                    .iter()
1073                    .enumerate()
1074                    .any(|(j, other)| i != j && gen(other) == and_op_str)
1075            });
1076
1077            if absorbed_by_existing {
1078                // This AND is absorbed, skip it
1079                absorbed.insert(op_str);
1080                continue;
1081            }
1082
1083            // Absorption with complement: A OR (NOT A AND B) -> A OR B
1084            // Check if any AND operand's complement is in our OR operands
1085            let mut remaining_and_ops: Vec<Expression> = Vec::new();
1086            let mut had_complement_absorption = false;
1087
1088            for and_op in and_operands {
1089                let complement_str = if let Some(inner) = get_not_inner(&and_op) {
1090                    // and_op is NOT X, complement is X
1091                    gen(inner)
1092                } else {
1093                    // and_op is X, complement is NOT X
1094                    format!("NOT {}", gen(&and_op))
1095                };
1096
1097                // Check if complement exists in our OR operands
1098                let has_complement = all_ops
1099                    .iter()
1100                    .enumerate()
1101                    .any(|(j, other)| i != j && gen(other) == complement_str)
1102                    || op_strings.contains(&complement_str);
1103
1104                if has_complement {
1105                    // This AND operand's complement exists, so this term becomes FALSE in OR context
1106                    // A OR (NOT A AND B) -> A OR B, so we drop NOT A from the AND
1107                    had_complement_absorption = true;
1108                    // Drop this operand from AND
1109                } else {
1110                    remaining_and_ops.push(and_op);
1111                }
1112            }
1113
1114            if had_complement_absorption {
1115                if remaining_and_ops.is_empty() {
1116                    // All AND operands were absorbed, the AND becomes FALSE
1117                    // A OR FALSE -> A, so we just skip adding this
1118                    absorbed.insert(op_str);
1119                    continue;
1120                } else if remaining_and_ops.len() == 1 {
1121                    // Single remaining operand
1122                    result_ops.push(remaining_and_ops.into_iter().next().unwrap());
1123                    absorbed.insert(op_str);
1124                    continue;
1125                } else {
1126                    // Rebuild the AND with remaining operands
1127                    result_ops.push(rebuild_and(remaining_and_ops));
1128                    absorbed.insert(op_str);
1129                    continue;
1130                }
1131            }
1132        }
1133
1134        result_ops.push(op.clone());
1135    }
1136
1137    // Deduplicate
1138    let mut seen = std::collections::HashSet::new();
1139    result_ops.retain(|op| seen.insert(gen(op)));
1140
1141    if result_ops.is_empty() {
1142        bool_false()
1143    } else {
1144        rebuild_or(result_ops)
1145    }
1146}
1147
1148/// Generate a simple string representation of an expression for sorting/deduping
1149pub fn gen(expr: &Expression) -> String {
1150    match expr {
1151        Expression::Literal(lit) => match lit {
1152            Literal::String(s) => format!("'{}'", s),
1153            Literal::Number(n) => n.clone(),
1154            _ => format!("{:?}", lit),
1155        },
1156        Expression::Boolean(b) => if b.value { "TRUE" } else { "FALSE" }.to_string(),
1157        Expression::Null(_) => "NULL".to_string(),
1158        Expression::Column(col) => {
1159            if let Some(ref table) = col.table {
1160                format!("{}.{}", table.name, col.name.name)
1161            } else {
1162                col.name.name.clone()
1163            }
1164        }
1165        Expression::And(op) => format!("({} AND {})", gen(&op.left), gen(&op.right)),
1166        Expression::Or(op) => format!("({} OR {})", gen(&op.left), gen(&op.right)),
1167        Expression::Not(op) => format!("NOT {}", gen(&op.this)),
1168        Expression::Eq(op) => format!("{} = {}", gen(&op.left), gen(&op.right)),
1169        Expression::Neq(op) => format!("{} <> {}", gen(&op.left), gen(&op.right)),
1170        Expression::Gt(op) => format!("{} > {}", gen(&op.left), gen(&op.right)),
1171        Expression::Gte(op) => format!("{} >= {}", gen(&op.left), gen(&op.right)),
1172        Expression::Lt(op) => format!("{} < {}", gen(&op.left), gen(&op.right)),
1173        Expression::Lte(op) => format!("{} <= {}", gen(&op.left), gen(&op.right)),
1174        Expression::Add(op) => format!("{} + {}", gen(&op.left), gen(&op.right)),
1175        Expression::Sub(op) => format!("{} - {}", gen(&op.left), gen(&op.right)),
1176        Expression::Mul(op) => format!("{} * {}", gen(&op.left), gen(&op.right)),
1177        Expression::Div(op) => format!("{} / {}", gen(&op.left), gen(&op.right)),
1178        Expression::Function(f) => {
1179            let args: Vec<String> = f.args.iter().map(|a| gen(a)).collect();
1180            format!("{}({})", f.name.to_uppercase(), args.join(", "))
1181        }
1182        _ => format!("{:?}", expr),
1183    }
1184}
1185
1186#[cfg(test)]
1187mod tests {
1188    use super::*;
1189
1190    fn make_int(val: i64) -> Expression {
1191        Expression::Literal(Literal::Number(val.to_string()))
1192    }
1193
1194    fn make_string(val: &str) -> Expression {
1195        Expression::Literal(Literal::String(val.to_string()))
1196    }
1197
1198    fn make_bool(val: bool) -> Expression {
1199        Expression::Boolean(BooleanLiteral { value: val })
1200    }
1201
1202    fn make_column(name: &str) -> Expression {
1203        use crate::expressions::{Column, Identifier};
1204        Expression::Column(Column {
1205            name: Identifier::new(name),
1206            table: None,
1207            join_mark: false,
1208            trailing_comments: vec![],
1209            span: None,
1210            inferred_type: None,
1211        })
1212    }
1213
#[test]
fn test_always_true_false() {
    let (yes, no) = (make_bool(true), make_bool(false));
    let (one, zero) = (make_int(1), make_int(0));

    // Truthy: TRUE and non-zero numbers; FALSE and zero are not.
    assert!(always_true(&yes));
    assert!(!always_true(&no));
    assert!(always_true(&one));
    assert!(!always_true(&zero));

    // Falsy: FALSE, NULL and zero; TRUE is not.
    assert!(always_false(&no));
    assert!(!always_false(&yes));
    assert!(always_false(&null()));
    assert!(always_false(&zero));
}
1226
#[test]
fn test_simplify_and_with_true() {
    let and = |lhs: Expression, rhs: Expression| {
        Expression::And(Box::new(BinaryOp::new(lhs, rhs)))
    };
    let mut simplifier = Simplifier::new(None);

    // TRUE AND TRUE collapses to TRUE.
    let both_true = simplifier.simplify(and(make_bool(true), make_bool(true)));
    assert!(always_true(&both_true));

    // TRUE AND FALSE collapses to FALSE.
    let mixed = simplifier.simplify(and(make_bool(true), make_bool(false)));
    assert!(always_false(&mixed));

    // TRUE AND x leaves just x.
    let operand = make_int(42);
    let kept = simplifier.simplify(and(make_bool(true), operand.clone()));
    assert_eq!(format!("{:?}", kept), format!("{:?}", operand));
}
1247
#[test]
fn test_simplify_or_with_false() {
    let or = |lhs: Expression, rhs: Expression| {
        Expression::Or(Box::new(BinaryOp::new(lhs, rhs)))
    };
    let mut simplifier = Simplifier::new(None);

    // FALSE OR FALSE collapses to FALSE.
    let both_false = simplifier.simplify(or(make_bool(false), make_bool(false)));
    assert!(always_false(&both_false));

    // FALSE OR TRUE collapses to TRUE.
    let mixed = simplifier.simplify(or(make_bool(false), make_bool(true)));
    assert!(always_true(&mixed));

    // FALSE OR x leaves just x.
    let operand = make_int(42);
    let kept = simplifier.simplify(or(make_bool(false), operand.clone()));
    assert_eq!(format!("{:?}", kept), format!("{:?}", operand));
}
1268
#[test]
fn test_simplify_not() {
    let negate = |inner: Expression| Expression::Not(Box::new(UnaryOp::new(inner)));
    let mut simplifier = Simplifier::new(None);

    // Negating TRUE gives FALSE.
    let negated_true = simplifier.simplify(negate(make_bool(true)));
    assert!(is_false(&negated_true));

    // Negating FALSE gives TRUE.
    let negated_false = simplifier.simplify(negate(make_bool(false)));
    assert!(always_true(&negated_false));

    // A double negation cancels out.
    let operand = make_int(42);
    let twice_negated = negate(negate(operand.clone()));
    let restored = simplifier.simplify(twice_negated);
    assert_eq!(format!("{:?}", restored), format!("{:?}", operand));
}
1290
#[test]
fn test_simplify_demorgan_comparison() {
    let negate = |inner: Expression| Expression::Not(Box::new(UnaryOp::new(inner)));
    let mut simplifier = Simplifier::new(None);

    // Columns are used so that constant folding cannot kick in.
    let (a, b) = (make_column("a"), make_column("b"));

    // NOT (a = b) becomes a <> b.
    let equality = Expression::Eq(Box::new(BinaryOp::new(a.clone(), b.clone())));
    let flipped_eq = simplifier.simplify(negate(equality));
    assert!(matches!(flipped_eq, Expression::Neq(_)));

    // NOT (a > b) becomes a <= b.
    let greater = Expression::Gt(Box::new(BinaryOp::new(a, b)));
    let flipped_gt = simplifier.simplify(negate(greater));
    assert!(matches!(flipped_gt, Expression::Lte(_)));
}
1309
#[test]
fn test_constant_folding_add() {
    let add = |lhs: Expression, rhs: Expression| {
        Expression::Add(Box::new(BinaryOp::new(lhs, rhs)))
    };
    let mut simplifier = Simplifier::new(None);

    // Two literals are folded: 1 + 2 gives 3.
    let folded = simplifier.simplify(add(make_int(1), make_int(2)));
    assert_eq!(get_number(&folded), Some(3.0));

    // Adding zero is the identity: x + 0 gives x.
    let operand = make_int(42);
    let unchanged = simplifier.simplify(add(operand.clone(), make_int(0)));
    assert_eq!(format!("{:?}", unchanged), format!("{:?}", operand));
}
1325
#[test]
fn test_constant_folding_mul() {
    let mul = |lhs: Expression, rhs: Expression| {
        Expression::Mul(Box::new(BinaryOp::new(lhs, rhs)))
    };
    let mut simplifier = Simplifier::new(None);

    // Two literals are folded: 3 * 4 gives 12.
    let folded = simplifier.simplify(mul(make_int(3), make_int(4)));
    assert_eq!(get_number(&folded), Some(12.0));

    // Multiplying by zero gives zero.
    let zeroed = simplifier.simplify(mul(make_int(42), make_int(0)));
    assert_eq!(get_number(&zeroed), Some(0.0));

    // Multiplying by one is the identity.
    let operand = make_int(42);
    let unchanged = simplifier.simplify(mul(operand.clone(), make_int(1)));
    assert_eq!(format!("{:?}", unchanged), format!("{:?}", operand));
}
1347
#[test]
fn test_constant_folding_comparison() {
    let eq = |lhs: Expression, rhs: Expression| {
        Expression::Eq(Box::new(BinaryOp::new(lhs, rhs)))
    };
    let mut simplifier = Simplifier::new(None);

    // Equal numbers: 1 = 1 folds to TRUE.
    let same_numbers = simplifier.simplify(eq(make_int(1), make_int(1)));
    assert!(always_true(&same_numbers));

    // Different numbers: 1 = 2 folds to FALSE.
    let different_numbers = simplifier.simplify(eq(make_int(1), make_int(2)));
    assert!(is_false(&different_numbers));

    // Ordering: 3 > 2 folds to TRUE.
    let greater = Expression::Gt(Box::new(BinaryOp::new(make_int(3), make_int(2))));
    let ordered = simplifier.simplify(greater);
    assert!(always_true(&ordered));

    // Equal strings: 'abc' = 'abc' folds to TRUE.
    let same_strings = simplifier.simplify(eq(make_string("abc"), make_string("abc")));
    assert!(always_true(&same_strings));
}
1375
#[test]
fn test_simplify_negation() {
    let neg = |inner: Expression| Expression::Neg(Box::new(UnaryOp::new(inner)));
    let mut simplifier = Simplifier::new(None);

    // A double arithmetic negation cancels: -(-5) gives 5.
    let twice_negated = neg(neg(make_int(5)));
    let restored = simplifier.simplify(twice_negated);
    assert_eq!(get_number(&restored), Some(5.0));

    // A negated literal is folded into a negative number: -(3) gives -3.
    let negative = simplifier.simplify(neg(make_int(3)));
    assert_eq!(get_number(&negative), Some(-3.0));
}
1391
#[test]
fn test_gen_simple() {
    // Each leaf node paired with the text `gen` must produce for it.
    let cases = vec![
        (make_int(42), "42"),
        (make_string("hello"), "'hello'"),
        (make_bool(true), "TRUE"),
        (make_bool(false), "FALSE"),
        (null(), "NULL"),
    ];
    for (node, expected) in cases {
        assert_eq!(gen(&node), expected);
    }
}
1400
#[test]
fn test_gen_operations() {
    // Arithmetic is rendered infix without parentheses.
    let sum = Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))));
    let rendered_sum = gen(&sum);
    assert_eq!(rendered_sum, "1 + 2");

    // Logical connectors are wrapped in parentheses.
    let conjunction =
        Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(false))));
    let rendered_conjunction = gen(&conjunction);
    assert_eq!(rendered_conjunction, "(TRUE AND FALSE)");
}
1409
#[test]
fn test_complement_elimination() {
    let mut simplifier = Simplifier::new(None);

    // An operand AND-ed with its own negation can never hold.
    let operand = make_int(42);
    let negated = Expression::Not(Box::new(UnaryOp::new(operand.clone())));
    let contradiction = Expression::And(Box::new(BinaryOp::new(operand, negated)));

    let simplified = simplifier.simplify(contradiction);
    assert!(is_false(&simplified));
}
1421
#[test]
fn test_idempotent() {
    let mut simplifier = Simplifier::new(None);
    let operand = make_int(42);
    let expected = format!("{:?}", operand);

    // x AND x reduces to x.
    let conjunction =
        Expression::And(Box::new(BinaryOp::new(operand.clone(), operand.clone())));
    let from_and = simplifier.simplify(conjunction);
    assert_eq!(format!("{:?}", from_and), expected);

    // x OR x reduces to x.
    let disjunction =
        Expression::Or(Box::new(BinaryOp::new(operand.clone(), operand.clone())));
    let from_or = simplifier.simplify(disjunction);
    assert_eq!(format!("{:?}", from_or), expected);
}
1438
#[test]
fn test_absorption_and() {
    let mut simplifier = Simplifier::new(None);
    let (a, b) = (make_column("a"), make_column("b"));

    // A AND (A OR B): the OR is absorbed, leaving only A.
    let disjunction = Expression::Or(Box::new(BinaryOp::new(a.clone(), b)));
    let input = Expression::And(Box::new(BinaryOp::new(a.clone(), disjunction)));

    let simplified = simplifier.simplify(input);
    assert_eq!(gen(&simplified), gen(&a));
}
1452
#[test]
fn test_absorption_or() {
    let mut simplifier = Simplifier::new(None);
    let (a, b) = (make_column("a"), make_column("b"));

    // A OR (A AND B): the AND is absorbed, leaving only A.
    let conjunction = Expression::And(Box::new(BinaryOp::new(a.clone(), b)));
    let input = Expression::Or(Box::new(BinaryOp::new(a.clone(), conjunction)));

    let simplified = simplifier.simplify(input);
    assert_eq!(gen(&simplified), gen(&a));
}
1466
#[test]
fn test_absorption_with_complement_and() {
    let mut simplifier = Simplifier::new(None);
    let (a, b) = (make_column("a"), make_column("b"));

    // A AND (NOT A OR B): NOT A cannot hold next to A, so the OR shrinks to B.
    let negated_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
    let disjunction = Expression::Or(Box::new(BinaryOp::new(negated_a, b.clone())));
    let input = Expression::And(Box::new(BinaryOp::new(a.clone(), disjunction)));
    let simplified = simplifier.simplify(input);

    // The outcome must render like A AND B.
    let expected = Expression::And(Box::new(BinaryOp::new(a, b)));
    assert_eq!(gen(&simplified), gen(&expected));
}
1482
#[test]
fn test_absorption_with_complement_or() {
    let mut simplifier = Simplifier::new(None);
    let (a, b) = (make_column("a"), make_column("b"));

    // A OR (NOT A AND B): NOT A is implied once A fails, so the AND shrinks to B.
    let negated_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
    let conjunction = Expression::And(Box::new(BinaryOp::new(negated_a, b.clone())));
    let input = Expression::Or(Box::new(BinaryOp::new(a.clone(), conjunction)));
    let simplified = simplifier.simplify(input);

    // The outcome must render like A OR B.
    let expected = Expression::Or(Box::new(BinaryOp::new(a, b)));
    assert_eq!(gen(&simplified), gen(&expected));
}
1498
#[test]
fn test_flatten_and() {
    // Build A AND (B AND C) and expect the operand list [A, B, C].
    let (a, b, c) = (make_column("a"), make_column("b"), make_column("c"));
    let inner = Expression::And(Box::new(BinaryOp::new(b, c)));
    let nested = Expression::And(Box::new(BinaryOp::new(a, inner)));

    let operands = flatten_and(&nested);
    assert_eq!(operands.len(), 3);

    let rendered: Vec<String> = operands.iter().map(gen).collect();
    assert_eq!(rendered[0], "a");
    assert_eq!(rendered[1], "b");
    assert_eq!(rendered[2], "c");
}
1513
#[test]
fn test_flatten_or() {
    // Build A OR (B OR C) and expect the operand list [A, B, C].
    let (a, b, c) = (make_column("a"), make_column("b"), make_column("c"));
    let inner = Expression::Or(Box::new(BinaryOp::new(b, c)));
    let nested = Expression::Or(Box::new(BinaryOp::new(a, inner)));

    let operands = flatten_or(&nested);
    assert_eq!(operands.len(), 3);

    let rendered: Vec<String> = operands.iter().map(gen).collect();
    assert_eq!(rendered[0], "a");
    assert_eq!(rendered[1], "b");
    assert_eq!(rendered[2], "c");
}
1528
#[test]
fn test_simplify_concat() {
    let concat = |lhs: Expression, rhs: Expression| {
        Expression::Concat(Box::new(BinaryOp::new(lhs, rhs)))
    };
    let mut simplifier = Simplifier::new(None);

    // Two string literals are joined: 'hello' || 'world' gives 'helloworld'.
    let joined = simplifier.simplify(concat(make_string("hello"), make_string("world")));
    assert_eq!(get_string(&joined), Some("helloworld".to_string()));

    // An empty string on the left is dropped.
    let operand = make_string("test");
    let left_empty = simplifier.simplify(concat(make_string(""), operand.clone()));
    assert_eq!(get_string(&left_empty), Some("test".to_string()));

    // An empty string on the right is dropped.
    let right_empty = simplifier.simplify(concat(operand, make_string("")));
    assert_eq!(get_string(&right_empty), Some("test".to_string()));

    // NULL on either side makes the whole concatenation NULL.
    let with_null = simplifier.simplify(concat(null(), make_string("test")));
    assert!(is_null(&with_null));
}
1557
#[test]
fn test_simplify_concat_ws() {
    let concat_ws = |separator: Expression, expressions: Vec<Expression>| {
        Expression::ConcatWs(Box::new(ConcatWs {
            separator,
            expressions,
        }))
    };
    let mut simplifier = Simplifier::new(None);

    // All-literal arguments are joined with the separator.
    let joined = simplifier.simplify(concat_ws(
        make_string(","),
        vec![make_string("a"), make_string("b"), make_string("c")],
    ));
    assert_eq!(get_string(&joined), Some("a,b,c".to_string()));

    // A NULL separator makes the result NULL.
    let null_separator = simplifier.simplify(concat_ws(
        null(),
        vec![make_string("a"), make_string("b")],
    ));
    assert!(is_null(&null_separator));

    // No arguments at all gives the empty string.
    let no_arguments = simplifier.simplify(concat_ws(make_string(","), vec![]));
    assert_eq!(get_string(&no_arguments), Some("".to_string()));

    // NULL arguments are skipped rather than propagated.
    let skipping_null = simplifier.simplify(concat_ws(
        make_string("-"),
        vec![make_string("a"), null(), make_string("b")],
    ));
    assert_eq!(get_string(&skipping_null), Some("a-b".to_string()));
}
1594
#[test]
fn test_simplify_paren() {
    let paren = |inner: Expression| {
        Expression::Paren(Box::new(Paren {
            this: inner,
            trailing_comments: vec![],
        }))
    };
    let mut simplifier = Simplifier::new(None);

    // Parentheses around a number literal are removed.
    let number = simplifier.simplify(paren(make_int(42)));
    assert_eq!(get_number(&number), Some(42.0));

    // Parentheses around TRUE are removed.
    let truth = simplifier.simplify(paren(make_bool(true)));
    assert!(is_boolean_true(&truth));

    // Parentheses around NULL are removed.
    let nothing = simplifier.simplify(paren(null()));
    assert!(is_null(&nothing));

    // Nested parentheses are removed all the way down.
    let doubly_wrapped = paren(paren(make_int(10)));
    let unwrapped = simplifier.simplify(doubly_wrapped);
    assert_eq!(get_number(&unwrapped), Some(10.0));
}
1635
#[test]
fn test_simplify_equality_solve() {
    let mut simplifier = Simplifier::new(None);
    let x = make_column("x");

    // Each row: left-hand side, right-hand constant, value x is solved to.
    let cases = vec![
        // x + 1 = 3 solves to x = 2
        (
            Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(1)))),
            3,
            2.0,
        ),
        // x - 1 = 3 solves to x = 4
        (
            Expression::Sub(Box::new(BinaryOp::new(x.clone(), make_int(1)))),
            3,
            4.0,
        ),
        // x * 2 = 6 solves to x = 3
        (
            Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(2)))),
            6,
            3.0,
        ),
        // 1 + x = 3 solves to x = 2 (constant on the left of the sum)
        (
            Expression::Add(Box::new(BinaryOp::new(make_int(1), x.clone()))),
            3,
            2.0,
        ),
    ];

    for (lhs, rhs, solved) in cases {
        let equation = Expression::Eq(Box::new(BinaryOp::new(lhs, make_int(rhs))));
        let simplified = simplifier.simplify(equation);
        match &simplified {
            Expression::Eq(op) => {
                assert_eq!(gen(&op.left), "x");
                assert_eq!(get_number(&op.right), Some(solved));
            }
            _ => panic!("Expected Eq expression"),
        }
    }
}
1686
1687    #[test]
1688    fn test_simplify_datetrunc() {
1689        use crate::expressions::DateTimeField;
1690        let mut simplifier = Simplifier::new(None);
1691
1692        // DATE_TRUNC('day', x) with a column just passes through with simplified children
1693        let x = make_column("x");
1694        let expr = Expression::DateTrunc(Box::new(DateTruncFunc {
1695            this: x.clone(),
1696            unit: DateTimeField::Day,
1697        }));
1698        let result = simplifier.simplify(expr);
1699        if let Expression::DateTrunc(dt) = &result {
1700            assert_eq!(gen(&dt.this), "x");
1701            assert_eq!(dt.unit, DateTimeField::Day);
1702        } else {
1703            panic!("Expected DateTrunc expression");
1704        }
1705    }
1706}