Skip to main content

polyglot_sql/optimizer/
simplify.rs

1//! Expression Simplification
2//!
3//! This module provides boolean and expression simplification for SQL AST nodes.
4//! It applies various algebraic transformations to simplify expressions:
5//! - De Morgan's laws (NOT (A AND B) -> NOT A OR NOT B)
6//! - Constant folding (1 + 2 -> 3)
7//! - Boolean absorption (A AND (A OR B) -> A)
8//! - Complement removal (A AND NOT A -> FALSE)
9//! - Connector flattening (A AND (B AND C) -> A AND B AND C)
10//!
11//! Based on SQLGlot's optimizer/simplify.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{
15    BinaryOp, BooleanLiteral, Case, ConcatWs, DateTruncFunc, Expression, Literal, Null, Paren,
16    UnaryOp,
17};
18
19/// Main entry point for expression simplification
20pub fn simplify(expression: Expression, dialect: Option<DialectType>) -> Expression {
21    let mut simplifier = Simplifier::new(dialect);
22    simplifier.simplify(expression)
23}
24
25/// Check if expression is always true
26pub fn always_true(expr: &Expression) -> bool {
27    match expr {
28        Expression::Boolean(b) => b.value,
29        Expression::Literal(Literal::Number(n)) => {
30            // Non-zero numbers are truthy
31            if let Ok(num) = n.parse::<f64>() {
32                num != 0.0
33            } else {
34                false
35            }
36        }
37        _ => false,
38    }
39}
40
41/// Check if expression is a boolean TRUE literal (not just truthy)
42pub fn is_boolean_true(expr: &Expression) -> bool {
43    matches!(expr, Expression::Boolean(b) if b.value)
44}
45
46/// Check if expression is a boolean FALSE literal (not just falsy)
47pub fn is_boolean_false(expr: &Expression) -> bool {
48    matches!(expr, Expression::Boolean(b) if !b.value)
49}
50
51/// Check if expression is always false
52pub fn always_false(expr: &Expression) -> bool {
53    is_false(expr) || is_null(expr) || is_zero(expr)
54}
55
56/// Check if expression is boolean FALSE
57pub fn is_false(expr: &Expression) -> bool {
58    matches!(expr, Expression::Boolean(b) if !b.value)
59}
60
61/// Check if expression is NULL
62pub fn is_null(expr: &Expression) -> bool {
63    matches!(expr, Expression::Null(_))
64}
65
66/// Check if expression is zero
67pub fn is_zero(expr: &Expression) -> bool {
68    match expr {
69        Expression::Literal(Literal::Number(n)) => {
70            if let Ok(num) = n.parse::<f64>() {
71                num == 0.0
72            } else {
73                false
74            }
75        }
76        _ => false,
77    }
78}
79
80/// Check if b is the complement of a (i.e., b = NOT a)
81pub fn is_complement(a: &Expression, b: &Expression) -> bool {
82    if let Expression::Not(not_op) = b {
83        &not_op.this == a
84    } else {
85        false
86    }
87}
88
89/// Create a TRUE boolean literal
90pub fn bool_true() -> Expression {
91    Expression::Boolean(BooleanLiteral { value: true })
92}
93
94/// Create a FALSE boolean literal
95pub fn bool_false() -> Expression {
96    Expression::Boolean(BooleanLiteral { value: false })
97}
98
99/// Create a NULL expression
100pub fn null() -> Expression {
101    Expression::Null(Null)
102}
103
104/// Evaluate a boolean comparison between two numbers
105pub fn eval_boolean_nums(op: &str, a: f64, b: f64) -> Option<Expression> {
106    let result = match op {
107        "=" | "==" => a == b,
108        "!=" | "<>" => a != b,
109        ">" => a > b,
110        ">=" => a >= b,
111        "<" => a < b,
112        "<=" => a <= b,
113        _ => return None,
114    };
115    Some(if result { bool_true() } else { bool_false() })
116}
117
118/// Evaluate a boolean comparison between two strings
119pub fn eval_boolean_strings(op: &str, a: &str, b: &str) -> Option<Expression> {
120    let result = match op {
121        "=" | "==" => a == b,
122        "!=" | "<>" => a != b,
123        ">" => a > b,
124        ">=" => a >= b,
125        "<" => a < b,
126        "<=" => a <= b,
127        _ => return None,
128    };
129    Some(if result { bool_true() } else { bool_false() })
130}
131
132/// Expression simplifier
133pub struct Simplifier {
134    _dialect: Option<DialectType>,
135    max_iterations: usize,
136}
137
138impl Simplifier {
139    /// Create a new simplifier
140    pub fn new(dialect: Option<DialectType>) -> Self {
141        Self {
142            _dialect: dialect,
143            max_iterations: 100,
144        }
145    }
146
147    /// Simplify an expression
148    pub fn simplify(&mut self, expression: Expression) -> Expression {
149        // Apply simplifications until no more changes (or max iterations)
150        let mut current = expression;
151        for _ in 0..self.max_iterations {
152            let simplified = self.simplify_once(current.clone());
153            if expressions_equal(&simplified, &current) {
154                return simplified;
155            }
156            current = simplified;
157        }
158        current
159    }
160
161    /// Apply one round of simplifications
162    fn simplify_once(&mut self, expression: Expression) -> Expression {
163        match expression {
164            // Binary logical operations
165            Expression::And(op) => self.simplify_and(*op),
166            Expression::Or(op) => self.simplify_or(*op),
167
168            // NOT operation - De Morgan's laws
169            Expression::Not(op) => self.simplify_not(*op),
170
171            // Arithmetic operations - constant folding
172            Expression::Add(op) => self.simplify_add(*op),
173            Expression::Sub(op) => self.simplify_sub(*op),
174            Expression::Mul(op) => self.simplify_mul(*op),
175            Expression::Div(op) => self.simplify_div(*op),
176
177            // Comparison operations
178            Expression::Eq(op) => self.simplify_comparison(*op, "="),
179            Expression::Neq(op) => self.simplify_comparison(*op, "!="),
180            Expression::Gt(op) => self.simplify_comparison(*op, ">"),
181            Expression::Gte(op) => self.simplify_comparison(*op, ">="),
182            Expression::Lt(op) => self.simplify_comparison(*op, "<"),
183            Expression::Lte(op) => self.simplify_comparison(*op, "<="),
184
185            // Negation
186            Expression::Neg(op) => self.simplify_neg(*op),
187
188            // CASE expression
189            Expression::Case(case) => self.simplify_case(*case),
190
191            // String concatenation
192            Expression::Concat(op) => self.simplify_concat(*op),
193            Expression::ConcatWs(concat_ws) => self.simplify_concat_ws(*concat_ws),
194
195            // Parentheses - remove if unnecessary
196            Expression::Paren(paren) => self.simplify_paren(*paren),
197
198            // Date truncation
199            Expression::DateTrunc(dt) => self.simplify_datetrunc(*dt),
200            Expression::TimestampTrunc(dt) => self.simplify_datetrunc(*dt),
201
202            // Recursively simplify children for other expressions
203            other => self.simplify_children(other),
204        }
205    }
206
207    /// Simplify AND operation
208    fn simplify_and(&mut self, op: BinaryOp) -> Expression {
209        let left = self.simplify_once(op.left);
210        let right = self.simplify_once(op.right);
211
212        // FALSE AND x -> FALSE
213        // x AND FALSE -> FALSE
214        if is_boolean_false(&left) || is_boolean_false(&right) {
215            return bool_false();
216        }
217
218        // 0 AND x -> FALSE (in boolean context)
219        // x AND 0 -> FALSE
220        if is_zero(&left) || is_zero(&right) {
221            return bool_false();
222        }
223
224        // NULL AND NULL -> NULL
225        // NULL AND TRUE -> NULL
226        // TRUE AND NULL -> NULL
227        if (is_null(&left) && is_null(&right))
228            || (is_null(&left) && is_boolean_true(&right))
229            || (is_boolean_true(&left) && is_null(&right))
230        {
231            return null();
232        }
233
234        // TRUE AND x -> x (only when left is actually boolean TRUE)
235        if is_boolean_true(&left) {
236            return right;
237        }
238
239        // x AND TRUE -> x (only when right is actually boolean TRUE)
240        if is_boolean_true(&right) {
241            return left;
242        }
243
244        // A AND NOT A -> FALSE (complement elimination)
245        if is_complement(&left, &right) || is_complement(&right, &left) {
246            return bool_false();
247        }
248
249        // A AND A -> A (idempotent)
250        if expressions_equal(&left, &right) {
251            return left;
252        }
253
254        // Apply absorption rules
255        // A AND (A OR B) -> A
256        // A AND (NOT A OR B) -> A AND B
257        absorb_and_eliminate_and(left, right)
258    }
259
260    /// Simplify OR operation
261    fn simplify_or(&mut self, op: BinaryOp) -> Expression {
262        let left = self.simplify_once(op.left);
263        let right = self.simplify_once(op.right);
264
265        // TRUE OR x -> TRUE (only when left is actually boolean TRUE)
266        if is_boolean_true(&left) {
267            return bool_true();
268        }
269
270        // x OR TRUE -> TRUE (only when right is actually boolean TRUE)
271        if is_boolean_true(&right) {
272            return bool_true();
273        }
274
275        // NULL OR NULL -> NULL
276        // NULL OR FALSE -> NULL
277        // FALSE OR NULL -> NULL
278        if (is_null(&left) && is_null(&right))
279            || (is_null(&left) && is_boolean_false(&right))
280            || (is_boolean_false(&left) && is_null(&right))
281        {
282            return null();
283        }
284
285        // FALSE OR x -> x (only when left is actually boolean FALSE)
286        if is_boolean_false(&left) {
287            return right;
288        }
289
290        // x OR FALSE -> x (only when right is actually boolean FALSE)
291        if is_boolean_false(&right) {
292            return left;
293        }
294
295        // A OR A -> A (idempotent)
296        if expressions_equal(&left, &right) {
297            return left;
298        }
299
300        // Apply absorption rules
301        // A OR (A AND B) -> A
302        // A OR (NOT A AND B) -> A OR B
303        absorb_and_eliminate_or(left, right)
304    }
305
306    /// Simplify NOT operation (De Morgan's laws)
307    fn simplify_not(&mut self, op: UnaryOp) -> Expression {
308        // Check for De Morgan's laws BEFORE simplifying inner expression
309        // This prevents constant folding from eliminating the comparison operator
310        match &op.this {
311            // NOT (a = b) -> a != b
312            Expression::Eq(inner_op) => {
313                let left = self.simplify_once(inner_op.left.clone());
314                let right = self.simplify_once(inner_op.right.clone());
315                return Expression::Neq(Box::new(BinaryOp::new(left, right)));
316            }
317            // NOT (a != b) -> a = b
318            Expression::Neq(inner_op) => {
319                let left = self.simplify_once(inner_op.left.clone());
320                let right = self.simplify_once(inner_op.right.clone());
321                return Expression::Eq(Box::new(BinaryOp::new(left, right)));
322            }
323            // NOT (a > b) -> a <= b
324            Expression::Gt(inner_op) => {
325                let left = self.simplify_once(inner_op.left.clone());
326                let right = self.simplify_once(inner_op.right.clone());
327                return Expression::Lte(Box::new(BinaryOp::new(left, right)));
328            }
329            // NOT (a >= b) -> a < b
330            Expression::Gte(inner_op) => {
331                let left = self.simplify_once(inner_op.left.clone());
332                let right = self.simplify_once(inner_op.right.clone());
333                return Expression::Lt(Box::new(BinaryOp::new(left, right)));
334            }
335            // NOT (a < b) -> a >= b
336            Expression::Lt(inner_op) => {
337                let left = self.simplify_once(inner_op.left.clone());
338                let right = self.simplify_once(inner_op.right.clone());
339                return Expression::Gte(Box::new(BinaryOp::new(left, right)));
340            }
341            // NOT (a <= b) -> a > b
342            Expression::Lte(inner_op) => {
343                let left = self.simplify_once(inner_op.left.clone());
344                let right = self.simplify_once(inner_op.right.clone());
345                return Expression::Gt(Box::new(BinaryOp::new(left, right)));
346            }
347            _ => {}
348        }
349
350        // Now simplify the inner expression for other patterns
351        let inner = self.simplify_once(op.this);
352
353        // NOT NULL -> NULL (with TRUE for SQL semantics)
354        if is_null(&inner) {
355            return null();
356        }
357
358        // NOT TRUE -> FALSE (only for boolean TRUE literal)
359        if is_boolean_true(&inner) {
360            return bool_false();
361        }
362
363        // NOT FALSE -> TRUE (only for boolean FALSE literal)
364        if is_boolean_false(&inner) {
365            return bool_true();
366        }
367
368        // NOT NOT x -> x (double negation elimination)
369        if let Expression::Not(inner_not) = &inner {
370            return inner_not.this.clone();
371        }
372
373        Expression::Not(Box::new(UnaryOp { this: inner }))
374    }
375
376    /// Simplify addition (constant folding)
377    fn simplify_add(&mut self, op: BinaryOp) -> Expression {
378        let left = self.simplify_once(op.left);
379        let right = self.simplify_once(op.right);
380
381        // Try constant folding for numbers
382        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
383            return Expression::Literal(Literal::Number((a + b).to_string()));
384        }
385
386        // x + 0 -> x
387        if is_zero(&right) {
388            return left;
389        }
390
391        // 0 + x -> x
392        if is_zero(&left) {
393            return right;
394        }
395
396        Expression::Add(Box::new(BinaryOp::new(left, right)))
397    }
398
399    /// Simplify subtraction (constant folding)
400    fn simplify_sub(&mut self, op: BinaryOp) -> Expression {
401        let left = self.simplify_once(op.left);
402        let right = self.simplify_once(op.right);
403
404        // Try constant folding for numbers
405        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
406            return Expression::Literal(Literal::Number((a - b).to_string()));
407        }
408
409        // x - 0 -> x
410        if is_zero(&right) {
411            return left;
412        }
413
414        // x - x -> 0 (only for literals/constants)
415        if expressions_equal(&left, &right) {
416            if let Expression::Literal(Literal::Number(_)) = &left {
417                return Expression::Literal(Literal::Number("0".to_string()));
418            }
419        }
420
421        Expression::Sub(Box::new(BinaryOp::new(left, right)))
422    }
423
424    /// Simplify multiplication (constant folding)
425    fn simplify_mul(&mut self, op: BinaryOp) -> Expression {
426        let left = self.simplify_once(op.left);
427        let right = self.simplify_once(op.right);
428
429        // Try constant folding for numbers
430        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
431            return Expression::Literal(Literal::Number((a * b).to_string()));
432        }
433
434        // x * 0 -> 0
435        if is_zero(&right) {
436            return Expression::Literal(Literal::Number("0".to_string()));
437        }
438
439        // 0 * x -> 0
440        if is_zero(&left) {
441            return Expression::Literal(Literal::Number("0".to_string()));
442        }
443
444        // x * 1 -> x
445        if is_one(&right) {
446            return left;
447        }
448
449        // 1 * x -> x
450        if is_one(&left) {
451            return right;
452        }
453
454        Expression::Mul(Box::new(BinaryOp::new(left, right)))
455    }
456
457    /// Simplify division (constant folding)
458    fn simplify_div(&mut self, op: BinaryOp) -> Expression {
459        let left = self.simplify_once(op.left);
460        let right = self.simplify_once(op.right);
461
462        // Try constant folding for numbers (but not integer division)
463        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
464            // Only fold if both are floats to avoid integer division issues
465            if b != 0.0 && (a.fract() != 0.0 || b.fract() != 0.0) {
466                return Expression::Literal(Literal::Number((a / b).to_string()));
467            }
468        }
469
470        // 0 / x -> 0 (when x != 0)
471        if is_zero(&left) && !is_zero(&right) {
472            return Expression::Literal(Literal::Number("0".to_string()));
473        }
474
475        // x / 1 -> x
476        if is_one(&right) {
477            return left;
478        }
479
480        Expression::Div(Box::new(BinaryOp::new(left, right)))
481    }
482
483    /// Simplify negation
484    fn simplify_neg(&mut self, op: UnaryOp) -> Expression {
485        let inner = self.simplify_once(op.this);
486
487        // -(-x) -> x (double negation)
488        if let Expression::Neg(inner_neg) = inner {
489            return inner_neg.this;
490        }
491
492        // -(number) -> -number
493        if let Some(n) = get_number(&inner) {
494            return Expression::Literal(Literal::Number((-n).to_string()));
495        }
496
497        Expression::Neg(Box::new(UnaryOp { this: inner }))
498    }
499
500    /// Simplify comparison operations (constant folding)
501    fn simplify_comparison(&mut self, op: BinaryOp, operator: &str) -> Expression {
502        let left = self.simplify_once(op.left);
503        let right = self.simplify_once(op.right);
504
505        // Try constant folding for numbers
506        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
507            if let Some(result) = eval_boolean_nums(operator, a, b) {
508                return result;
509            }
510        }
511
512        // Try constant folding for strings
513        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
514            if let Some(result) = eval_boolean_strings(operator, &a, &b) {
515                return result;
516            }
517        }
518
519        // For equality, try to solve simple equations (x + 1 = 3 -> x = 2)
520        if operator == "=" {
521            if let Some(simplified) = self.simplify_equality(left.clone(), right.clone()) {
522                return simplified;
523            }
524        }
525
526        // Reconstruct the comparison
527        let new_op = BinaryOp::new(left, right);
528
529        match operator {
530            "=" => Expression::Eq(Box::new(new_op)),
531            "!=" | "<>" => Expression::Neq(Box::new(new_op)),
532            ">" => Expression::Gt(Box::new(new_op)),
533            ">=" => Expression::Gte(Box::new(new_op)),
534            "<" => Expression::Lt(Box::new(new_op)),
535            "<=" => Expression::Lte(Box::new(new_op)),
536            _ => Expression::Eq(Box::new(new_op)),
537        }
538    }
539
540    /// Simplify CASE expression
541    fn simplify_case(&mut self, case: Case) -> Expression {
542        let mut new_whens = Vec::new();
543
544        for (cond, then_expr) in case.whens {
545            let simplified_cond = self.simplify_once(cond);
546
547            // If condition is always true, return the THEN expression
548            if always_true(&simplified_cond) {
549                return self.simplify_once(then_expr);
550            }
551
552            // If condition is always false, skip this WHEN clause
553            if always_false(&simplified_cond) {
554                continue;
555            }
556
557            new_whens.push((simplified_cond, self.simplify_once(then_expr)));
558        }
559
560        // If no WHEN clauses remain, return the ELSE expression (or NULL)
561        if new_whens.is_empty() {
562            return case.else_.map(|e| self.simplify_once(e)).unwrap_or_else(null);
563        }
564
565        Expression::Case(Box::new(Case {
566            operand: case.operand.map(|e| self.simplify_once(e)),
567            whens: new_whens,
568            else_: case.else_.map(|e| self.simplify_once(e)),
569        }))
570    }
571
572    /// Simplify string concatenation (Concat is || operator)
573    ///
574    /// Folds adjacent string literals:
575    /// - 'a' || 'b' -> 'ab'
576    /// - 'a' || 'b' || 'c' -> 'abc'
577    /// - '' || x -> x
578    /// - x || '' -> x
579    fn simplify_concat(&mut self, op: BinaryOp) -> Expression {
580        let left = self.simplify_once(op.left);
581        let right = self.simplify_once(op.right);
582
583        // Fold two string literals: 'a' || 'b' -> 'ab'
584        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
585            return Expression::Literal(Literal::String(format!("{}{}", a, b)));
586        }
587
588        // '' || x -> x
589        if let Some(s) = get_string(&left) {
590            if s.is_empty() {
591                return right;
592            }
593        }
594
595        // x || '' -> x
596        if let Some(s) = get_string(&right) {
597            if s.is_empty() {
598                return left;
599            }
600        }
601
602        // NULL || x -> NULL, x || NULL -> NULL (SQL string concat semantics)
603        if is_null(&left) || is_null(&right) {
604            return null();
605        }
606
607        Expression::Concat(Box::new(BinaryOp::new(left, right)))
608    }
609
610    /// Simplify CONCAT_WS function
611    ///
612    /// CONCAT_WS(sep, a, b, c) -> concatenates with separator, skipping NULLs
613    /// - CONCAT_WS(',', 'a', 'b') -> 'a,b' (when all are literals)
614    /// - CONCAT_WS(',', 'a', NULL, 'b') -> 'a,b' (NULLs are skipped)
615    /// - CONCAT_WS(NULL, ...) -> NULL
616    fn simplify_concat_ws(&mut self, concat_ws: ConcatWs) -> Expression {
617        let separator = self.simplify_once(concat_ws.separator);
618
619        // If separator is NULL, result is NULL
620        if is_null(&separator) {
621            return null();
622        }
623
624        let expressions: Vec<Expression> = concat_ws
625            .expressions
626            .into_iter()
627            .map(|e| self.simplify_once(e))
628            .filter(|e| !is_null(e)) // Skip NULL values
629            .collect();
630
631        // If no expressions remain, return empty string
632        if expressions.is_empty() {
633            return Expression::Literal(Literal::String(String::new()));
634        }
635
636        // Try to fold if all are string literals
637        if let Some(sep) = get_string(&separator) {
638            let all_strings: Option<Vec<String>> = expressions
639                .iter()
640                .map(|e| get_string(e))
641                .collect();
642
643            if let Some(strings) = all_strings {
644                return Expression::Literal(Literal::String(strings.join(&sep)));
645            }
646        }
647
648        // Return simplified CONCAT_WS
649        Expression::ConcatWs(Box::new(ConcatWs {
650            separator,
651            expressions,
652        }))
653    }
654
655    /// Simplify parentheses
656    ///
657    /// Remove unnecessary parentheses:
658    /// - (x) -> x when x is a literal, column, or already parenthesized
659    /// - ((x)) -> (x) -> x (recursive simplification)
660    fn simplify_paren(&mut self, paren: Paren) -> Expression {
661        let inner = self.simplify_once(paren.this);
662
663        // If inner is a literal, column, boolean, null, or already parenthesized,
664        // we can remove the parentheses
665        match &inner {
666            Expression::Literal(_)
667            | Expression::Boolean(_)
668            | Expression::Null(_)
669            | Expression::Column(_)
670            | Expression::Paren(_) => inner,
671            // For other expressions, keep the parentheses
672            _ => Expression::Paren(Box::new(Paren {
673                this: inner,
674                trailing_comments: paren.trailing_comments,
675            })),
676        }
677    }
678
679    /// Simplify DATE_TRUNC and TIMESTAMP_TRUNC
680    ///
681    /// Currently just simplifies children and passes through.
682    /// Future: could fold DATE_TRUNC('day', '2024-01-15') -> '2024-01-15'
683    fn simplify_datetrunc(&mut self, dt: DateTruncFunc) -> Expression {
684        let inner = self.simplify_once(dt.this);
685
686        // For now, just return with simplified inner expression
687        // A more advanced implementation would fold constant date/timestamps
688        Expression::DateTrunc(Box::new(DateTruncFunc {
689            this: inner,
690            unit: dt.unit,
691        }))
692    }
693
694    /// Simplify equality with arithmetic (solve simple equations)
695    ///
696    /// - x + 1 = 3 -> x = 2
697    /// - x - 1 = 3 -> x = 4
698    /// - x * 2 = 6 -> x = 3 (only when divisible)
699    /// - 1 + x = 3 -> x = 2 (commutative)
700    fn simplify_equality(&mut self, left: Expression, right: Expression) -> Option<Expression> {
701        // Only works when right side is a constant
702        let right_val = get_number(&right)?;
703
704        // Check if left side is arithmetic with one constant
705        match left {
706            Expression::Add(ref op) => {
707                // x + c = r -> x = r - c
708                if let Some(c) = get_number(&op.right) {
709                    let new_right = Expression::Literal(Literal::Number((right_val - c).to_string()));
710                    return Some(Expression::Eq(Box::new(BinaryOp::new(op.left.clone(), new_right))));
711                }
712                // c + x = r -> x = r - c
713                if let Some(c) = get_number(&op.left) {
714                    let new_right = Expression::Literal(Literal::Number((right_val - c).to_string()));
715                    return Some(Expression::Eq(Box::new(BinaryOp::new(op.right.clone(), new_right))));
716                }
717            }
718            Expression::Sub(ref op) => {
719                // x - c = r -> x = r + c
720                if let Some(c) = get_number(&op.right) {
721                    let new_right = Expression::Literal(Literal::Number((right_val + c).to_string()));
722                    return Some(Expression::Eq(Box::new(BinaryOp::new(op.left.clone(), new_right))));
723                }
724                // c - x = r -> x = c - r
725                if let Some(c) = get_number(&op.left) {
726                    let new_right = Expression::Literal(Literal::Number((c - right_val).to_string()));
727                    return Some(Expression::Eq(Box::new(BinaryOp::new(op.right.clone(), new_right))));
728                }
729            }
730            Expression::Mul(ref op) => {
731                // x * c = r -> x = r / c (only for non-zero c and when divisible)
732                if let Some(c) = get_number(&op.right) {
733                    if c != 0.0 && right_val % c == 0.0 {
734                        let new_right = Expression::Literal(Literal::Number((right_val / c).to_string()));
735                        return Some(Expression::Eq(Box::new(BinaryOp::new(op.left.clone(), new_right))));
736                    }
737                }
738                // c * x = r -> x = r / c
739                if let Some(c) = get_number(&op.left) {
740                    if c != 0.0 && right_val % c == 0.0 {
741                        let new_right = Expression::Literal(Literal::Number((right_val / c).to_string()));
742                        return Some(Expression::Eq(Box::new(BinaryOp::new(op.right.clone(), new_right))));
743                    }
744                }
745            }
746            _ => {}
747        }
748
749        None
750    }
751
752    /// Recursively simplify children of an expression
753    fn simplify_children(&mut self, expr: Expression) -> Expression {
754        // For expressions we don't have specific simplification rules for,
755        // we still want to simplify their children
756        match expr {
757            Expression::Alias(mut alias) => {
758                alias.this = self.simplify_once(alias.this);
759                Expression::Alias(alias)
760            }
761            Expression::Between(mut between) => {
762                between.this = self.simplify_once(between.this);
763                between.low = self.simplify_once(between.low);
764                between.high = self.simplify_once(between.high);
765                Expression::Between(between)
766            }
767            Expression::In(mut in_expr) => {
768                in_expr.this = self.simplify_once(in_expr.this);
769                in_expr.expressions = in_expr
770                    .expressions
771                    .into_iter()
772                    .map(|e| self.simplify_once(e))
773                    .collect();
774                Expression::In(in_expr)
775            }
776            Expression::Function(mut func) => {
777                func.args = func.args.into_iter().map(|e| self.simplify_once(e)).collect();
778                Expression::Function(func)
779            }
780            // For other expressions, return as-is for now
781            other => other,
782        }
783    }
784}
785
786/// Check if expression equals 1
787fn is_one(expr: &Expression) -> bool {
788    match expr {
789        Expression::Literal(Literal::Number(n)) => {
790            if let Ok(num) = n.parse::<f64>() {
791                num == 1.0
792            } else {
793                false
794            }
795        }
796        _ => false,
797    }
798}
799
800/// Get numeric value from expression if it's a number literal
801fn get_number(expr: &Expression) -> Option<f64> {
802    match expr {
803        Expression::Literal(Literal::Number(n)) => n.parse().ok(),
804        _ => None,
805    }
806}
807
808/// Get string value from expression if it's a string literal
809fn get_string(expr: &Expression) -> Option<String> {
810    match expr {
811        Expression::Literal(Literal::String(s)) => Some(s.clone()),
812        _ => None,
813    }
814}
815
816/// Check if two expressions are structurally equal
817/// This is a simplified comparison - a full implementation would need deep comparison
818fn expressions_equal(a: &Expression, b: &Expression) -> bool {
819    // For now, use Debug representation for comparison
820    // A proper implementation would do structural comparison
821    format!("{:?}", a) == format!("{:?}", b)
822}
823
824/// Flatten nested AND expressions into a list of operands
825/// e.g., (A AND (B AND C)) -> [A, B, C]
826fn flatten_and(expr: &Expression) -> Vec<Expression> {
827    match expr {
828        Expression::And(op) => {
829            let mut result = flatten_and(&op.left);
830            result.extend(flatten_and(&op.right));
831            result
832        }
833        other => vec![other.clone()],
834    }
835}
836
837/// Flatten nested OR expressions into a list of operands
838/// e.g., (A OR (B OR C)) -> [A, B, C]
839fn flatten_or(expr: &Expression) -> Vec<Expression> {
840    match expr {
841        Expression::Or(op) => {
842            let mut result = flatten_or(&op.left);
843            result.extend(flatten_or(&op.right));
844            result
845        }
846        other => vec![other.clone()],
847    }
848}
849
850/// Rebuild an AND expression from a list of operands
851fn rebuild_and(operands: Vec<Expression>) -> Expression {
852    if operands.is_empty() {
853        return bool_true(); // Empty AND is TRUE
854    }
855    let mut result = operands.into_iter();
856    let first = result.next().unwrap();
857    result.fold(first, |acc, op| {
858        Expression::And(Box::new(BinaryOp::new(acc, op)))
859    })
860}
861
862/// Rebuild an OR expression from a list of operands
863fn rebuild_or(operands: Vec<Expression>) -> Expression {
864    if operands.is_empty() {
865        return bool_false(); // Empty OR is FALSE
866    }
867    let mut result = operands.into_iter();
868    let first = result.next().unwrap();
869    result.fold(first, |acc, op| {
870        Expression::Or(Box::new(BinaryOp::new(acc, op)))
871    })
872}
873
874/// Get the inner expression of a NOT, if it is one
875fn get_not_inner(expr: &Expression) -> Option<&Expression> {
876    match expr {
877        Expression::Not(op) => Some(&op.this),
878        _ => None,
879    }
880}
881
882/// Apply Boolean absorption and elimination rules to an AND expression
883///
884/// Absorption:
885///   A AND (A OR B) -> A
886///   A AND (NOT A OR B) -> A AND B
887///
888/// Elimination:
889///   (A OR B) AND (A OR NOT B) -> A
890pub fn absorb_and_eliminate_and(left: Expression, right: Expression) -> Expression {
891    // Flatten both sides
892    let left_ops = flatten_and(&left);
893    let right_ops = flatten_and(&right);
894    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
895
896    // Build a set of string representations for quick lookup
897    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
898
899    let mut result_ops: Vec<Expression> = Vec::new();
900    let mut absorbed = std::collections::HashSet::new();
901
902    for (i, op) in all_ops.iter().enumerate() {
903        let op_str = gen(op);
904
905        // Skip if already absorbed
906        if absorbed.contains(&op_str) {
907            continue;
908        }
909
910        // Check if this is an OR expression (potential absorption target)
911        if let Expression::Or(_) = op {
912            let or_operands = flatten_or(op);
913
914            // Absorption: A AND (A OR B) -> A
915            // Check if any OR operand is already in our AND operands
916            let absorbed_by_existing = or_operands.iter().any(|or_op| {
917                let or_op_str = gen(or_op);
918                // Check if this OR operand exists in other AND operands (not this OR itself)
919                all_ops.iter().enumerate().any(|(j, other)| {
920                    i != j && gen(other) == or_op_str
921                })
922            });
923
924            if absorbed_by_existing {
925                // This OR is absorbed, skip it
926                absorbed.insert(op_str);
927                continue;
928            }
929
930            // Absorption with complement: A AND (NOT A OR B) -> A AND B
931            // Check if any OR operand's complement is in our AND operands
932            let mut remaining_or_ops: Vec<Expression> = Vec::new();
933            let mut had_complement_absorption = false;
934
935            for or_op in or_operands {
936                let complement_str = if let Some(inner) = get_not_inner(&or_op) {
937                    // or_op is NOT X, complement is X
938                    gen(inner)
939                } else {
940                    // or_op is X, complement is NOT X
941                    format!("NOT {}", gen(&or_op))
942                };
943
944                // Check if complement exists in our AND operands
945                let has_complement = all_ops.iter().enumerate().any(|(j, other)| {
946                    i != j && gen(other) == complement_str
947                }) || op_strings.contains(&complement_str);
948
949                if has_complement {
950                    // This OR operand's complement exists, so this term becomes TRUE in AND context
951                    // NOT A OR B, where A exists, becomes TRUE OR B (when A is true) or B (when A is false)
952                    // Actually: A AND (NOT A OR B) -> A AND B, so we drop NOT A from the OR
953                    had_complement_absorption = true;
954                    // Drop this operand from OR
955                } else {
956                    remaining_or_ops.push(or_op);
957                }
958            }
959
960            if had_complement_absorption {
961                if remaining_or_ops.is_empty() {
962                    // All OR operands were absorbed, the OR becomes TRUE
963                    // A AND TRUE -> A, so we just skip adding this
964                    absorbed.insert(op_str);
965                    continue;
966                } else if remaining_or_ops.len() == 1 {
967                    // Single remaining operand
968                    result_ops.push(remaining_or_ops.into_iter().next().unwrap());
969                    absorbed.insert(op_str);
970                    continue;
971                } else {
972                    // Rebuild the OR with remaining operands
973                    result_ops.push(rebuild_or(remaining_or_ops));
974                    absorbed.insert(op_str);
975                    continue;
976                }
977            }
978        }
979
980        result_ops.push(op.clone());
981    }
982
983    // Deduplicate
984    let mut seen = std::collections::HashSet::new();
985    result_ops.retain(|op| seen.insert(gen(op)));
986
987    if result_ops.is_empty() {
988        bool_true()
989    } else {
990        rebuild_and(result_ops)
991    }
992}
993
994/// Apply Boolean absorption and elimination rules to an OR expression
995///
996/// Absorption:
997///   A OR (A AND B) -> A
998///   A OR (NOT A AND B) -> A OR B
999///
1000/// Elimination:
1001///   (A AND B) OR (A AND NOT B) -> A
1002pub fn absorb_and_eliminate_or(left: Expression, right: Expression) -> Expression {
1003    // Flatten both sides
1004    let left_ops = flatten_or(&left);
1005    let right_ops = flatten_or(&right);
1006    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
1007
1008    // Build a set of string representations for quick lookup
1009    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
1010
1011    let mut result_ops: Vec<Expression> = Vec::new();
1012    let mut absorbed = std::collections::HashSet::new();
1013
1014    for (i, op) in all_ops.iter().enumerate() {
1015        let op_str = gen(op);
1016
1017        // Skip if already absorbed
1018        if absorbed.contains(&op_str) {
1019            continue;
1020        }
1021
1022        // Check if this is an AND expression (potential absorption target)
1023        if let Expression::And(_) = op {
1024            let and_operands = flatten_and(op);
1025
1026            // Absorption: A OR (A AND B) -> A
1027            // Check if any AND operand is already in our OR operands
1028            let absorbed_by_existing = and_operands.iter().any(|and_op| {
1029                let and_op_str = gen(and_op);
1030                // Check if this AND operand exists in other OR operands (not this AND itself)
1031                all_ops.iter().enumerate().any(|(j, other)| {
1032                    i != j && gen(other) == and_op_str
1033                })
1034            });
1035
1036            if absorbed_by_existing {
1037                // This AND is absorbed, skip it
1038                absorbed.insert(op_str);
1039                continue;
1040            }
1041
1042            // Absorption with complement: A OR (NOT A AND B) -> A OR B
1043            // Check if any AND operand's complement is in our OR operands
1044            let mut remaining_and_ops: Vec<Expression> = Vec::new();
1045            let mut had_complement_absorption = false;
1046
1047            for and_op in and_operands {
1048                let complement_str = if let Some(inner) = get_not_inner(&and_op) {
1049                    // and_op is NOT X, complement is X
1050                    gen(inner)
1051                } else {
1052                    // and_op is X, complement is NOT X
1053                    format!("NOT {}", gen(&and_op))
1054                };
1055
1056                // Check if complement exists in our OR operands
1057                let has_complement = all_ops.iter().enumerate().any(|(j, other)| {
1058                    i != j && gen(other) == complement_str
1059                }) || op_strings.contains(&complement_str);
1060
1061                if has_complement {
1062                    // This AND operand's complement exists, so this term becomes FALSE in OR context
1063                    // A OR (NOT A AND B) -> A OR B, so we drop NOT A from the AND
1064                    had_complement_absorption = true;
1065                    // Drop this operand from AND
1066                } else {
1067                    remaining_and_ops.push(and_op);
1068                }
1069            }
1070
1071            if had_complement_absorption {
1072                if remaining_and_ops.is_empty() {
1073                    // All AND operands were absorbed, the AND becomes FALSE
1074                    // A OR FALSE -> A, so we just skip adding this
1075                    absorbed.insert(op_str);
1076                    continue;
1077                } else if remaining_and_ops.len() == 1 {
1078                    // Single remaining operand
1079                    result_ops.push(remaining_and_ops.into_iter().next().unwrap());
1080                    absorbed.insert(op_str);
1081                    continue;
1082                } else {
1083                    // Rebuild the AND with remaining operands
1084                    result_ops.push(rebuild_and(remaining_and_ops));
1085                    absorbed.insert(op_str);
1086                    continue;
1087                }
1088            }
1089        }
1090
1091        result_ops.push(op.clone());
1092    }
1093
1094    // Deduplicate
1095    let mut seen = std::collections::HashSet::new();
1096    result_ops.retain(|op| seen.insert(gen(op)));
1097
1098    if result_ops.is_empty() {
1099        bool_false()
1100    } else {
1101        rebuild_or(result_ops)
1102    }
1103}
1104
1105/// Generate a simple string representation of an expression for sorting/deduping
1106pub fn gen(expr: &Expression) -> String {
1107    match expr {
1108        Expression::Literal(lit) => match lit {
1109            Literal::String(s) => format!("'{}'", s),
1110            Literal::Number(n) => n.clone(),
1111            _ => format!("{:?}", lit),
1112        },
1113        Expression::Boolean(b) => if b.value { "TRUE" } else { "FALSE" }.to_string(),
1114        Expression::Null(_) => "NULL".to_string(),
1115        Expression::Column(col) => {
1116            if let Some(ref table) = col.table {
1117                format!("{}.{}", table.name, col.name.name)
1118            } else {
1119                col.name.name.clone()
1120            }
1121        }
1122        Expression::And(op) => format!("({} AND {})", gen(&op.left), gen(&op.right)),
1123        Expression::Or(op) => format!("({} OR {})", gen(&op.left), gen(&op.right)),
1124        Expression::Not(op) => format!("NOT {}", gen(&op.this)),
1125        Expression::Eq(op) => format!("{} = {}", gen(&op.left), gen(&op.right)),
1126        Expression::Neq(op) => format!("{} <> {}", gen(&op.left), gen(&op.right)),
1127        Expression::Gt(op) => format!("{} > {}", gen(&op.left), gen(&op.right)),
1128        Expression::Gte(op) => format!("{} >= {}", gen(&op.left), gen(&op.right)),
1129        Expression::Lt(op) => format!("{} < {}", gen(&op.left), gen(&op.right)),
1130        Expression::Lte(op) => format!("{} <= {}", gen(&op.left), gen(&op.right)),
1131        Expression::Add(op) => format!("{} + {}", gen(&op.left), gen(&op.right)),
1132        Expression::Sub(op) => format!("{} - {}", gen(&op.left), gen(&op.right)),
1133        Expression::Mul(op) => format!("{} * {}", gen(&op.left), gen(&op.right)),
1134        Expression::Div(op) => format!("{} / {}", gen(&op.left), gen(&op.right)),
1135        Expression::Function(f) => {
1136            let args: Vec<String> = f.args.iter().map(|a| gen(a)).collect();
1137            format!("{}({})", f.name.to_uppercase(), args.join(", "))
1138        }
1139        _ => format!("{:?}", expr),
1140    }
1141}
1142
#[cfg(test)]
mod tests {
    use super::*;

    fn make_int(val: i64) -> Expression {
        Expression::Literal(Literal::Number(val.to_string()))
    }

    fn make_string(val: &str) -> Expression {
        Expression::Literal(Literal::String(val.to_string()))
    }

    fn make_bool(val: bool) -> Expression {
        Expression::Boolean(BooleanLiteral { value: val })
    }

    fn make_column(name: &str) -> Expression {
        use crate::expressions::{Column, Identifier};
        Expression::Column(Column {
            name: Identifier::new(name),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        })
    }

    #[test]
    fn test_always_true_false() {
        assert!(always_true(&make_bool(true)));
        assert!(!always_true(&make_bool(false)));
        assert!(always_true(&make_int(1)));
        assert!(!always_true(&make_int(0)));

        assert!(always_false(&make_bool(false)));
        assert!(!always_false(&make_bool(true)));
        assert!(always_false(&null()));
        assert!(always_false(&make_int(0)));
    }

    #[test]
    fn test_simplify_and_with_true() {
        let mut simplifier = Simplifier::new(None);

        // TRUE AND TRUE -> TRUE
        let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(true))));
        let result = simplifier.simplify(expr);
        assert!(always_true(&result));

        // TRUE AND FALSE -> FALSE
        let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(false))));
        let result = simplifier.simplify(expr);
        assert!(always_false(&result));

        // TRUE AND x -> x
        let x = make_int(42);
        let expr = Expression::And(Box::new(BinaryOp::new(make_bool(true), x.clone())));
        let result = simplifier.simplify(expr);
        assert_eq!(format!("{:?}", result), format!("{:?}", x));
    }

    #[test]
    fn test_simplify_or_with_false() {
        let mut simplifier = Simplifier::new(None);

        // FALSE OR FALSE -> FALSE
        let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), make_bool(false))));
        let result = simplifier.simplify(expr);
        assert!(always_false(&result));

        // FALSE OR TRUE -> TRUE
        let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), make_bool(true))));
        let result = simplifier.simplify(expr);
        assert!(always_true(&result));

        // FALSE OR x -> x
        let x = make_int(42);
        let expr = Expression::Or(Box::new(BinaryOp::new(make_bool(false), x.clone())));
        let result = simplifier.simplify(expr);
        assert_eq!(format!("{:?}", result), format!("{:?}", x));
    }

    #[test]
    fn test_simplify_not() {
        let mut simplifier = Simplifier::new(None);

        // NOT TRUE -> FALSE
        let expr = Expression::Not(Box::new(UnaryOp::new(make_bool(true))));
        let result = simplifier.simplify(expr);
        assert!(is_false(&result));

        // NOT FALSE -> TRUE
        let expr = Expression::Not(Box::new(UnaryOp::new(make_bool(false))));
        let result = simplifier.simplify(expr);
        assert!(always_true(&result));

        // NOT NOT x -> x
        let x = make_int(42);
        let inner_not = Expression::Not(Box::new(UnaryOp::new(x.clone())));
        let expr = Expression::Not(Box::new(UnaryOp::new(inner_not)));
        let result = simplifier.simplify(expr);
        assert_eq!(format!("{:?}", result), format!("{:?}", x));
    }

    #[test]
    fn test_simplify_demorgan_comparison() {
        let mut simplifier = Simplifier::new(None);

        // NOT (a = b) -> a <> b (using columns to avoid constant folding)
        let a = make_column("a");
        let b = make_column("b");
        let eq = Expression::Eq(Box::new(BinaryOp::new(a.clone(), b.clone())));
        let expr = Expression::Not(Box::new(UnaryOp::new(eq)));
        let result = simplifier.simplify(expr);
        assert!(matches!(result, Expression::Neq(_)));

        // NOT (a > b) -> a <= b
        let gt = Expression::Gt(Box::new(BinaryOp::new(a, b)));
        let expr = Expression::Not(Box::new(UnaryOp::new(gt)));
        let result = simplifier.simplify(expr);
        assert!(matches!(result, Expression::Lte(_)));
    }

    #[test]
    fn test_constant_folding_add() {
        let mut simplifier = Simplifier::new(None);

        // 1 + 2 -> 3
        let expr = Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))));
        let result = simplifier.simplify(expr);
        assert_eq!(get_number(&result), Some(3.0));

        // x + 0 -> x
        let x = make_int(42);
        let expr = Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(0))));
        let result = simplifier.simplify(expr);
        assert_eq!(format!("{:?}", result), format!("{:?}", x));
    }

    #[test]
    fn test_constant_folding_mul() {
        let mut simplifier = Simplifier::new(None);

        // 3 * 4 -> 12
        let expr = Expression::Mul(Box::new(BinaryOp::new(make_int(3), make_int(4))));
        let result = simplifier.simplify(expr);
        assert_eq!(get_number(&result), Some(12.0));

        // x * 0 -> 0
        let x = make_int(42);
        let expr = Expression::Mul(Box::new(BinaryOp::new(x, make_int(0))));
        let result = simplifier.simplify(expr);
        assert_eq!(get_number(&result), Some(0.0));

        // x * 1 -> x
        let x = make_int(42);
        let expr = Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(1))));
        let result = simplifier.simplify(expr);
        assert_eq!(format!("{:?}", result), format!("{:?}", x));
    }

    #[test]
    fn test_constant_folding_comparison() {
        let mut simplifier = Simplifier::new(None);

        // 1 = 1 -> TRUE
        let expr = Expression::Eq(Box::new(BinaryOp::new(make_int(1), make_int(1))));
        let result = simplifier.simplify(expr);
        assert!(always_true(&result));

        // 1 = 2 -> FALSE
        let expr = Expression::Eq(Box::new(BinaryOp::new(make_int(1), make_int(2))));
        let result = simplifier.simplify(expr);
        assert!(is_false(&result));

        // 3 > 2 -> TRUE
        let expr = Expression::Gt(Box::new(BinaryOp::new(make_int(3), make_int(2))));
        let result = simplifier.simplify(expr);
        assert!(always_true(&result));

        // 'abc' = 'abc' -> TRUE
        let expr = Expression::Eq(Box::new(BinaryOp::new(
            make_string("abc"),
            make_string("abc"),
        )));
        let result = simplifier.simplify(expr);
        assert!(always_true(&result));
    }

    #[test]
    fn test_simplify_negation() {
        let mut simplifier = Simplifier::new(None);

        // -(-5) -> 5
        let inner = Expression::Neg(Box::new(UnaryOp::new(make_int(5))));
        let expr = Expression::Neg(Box::new(UnaryOp::new(inner)));
        let result = simplifier.simplify(expr);
        assert_eq!(get_number(&result), Some(5.0));

        // -(3) -> -3
        let expr = Expression::Neg(Box::new(UnaryOp::new(make_int(3))));
        let result = simplifier.simplify(expr);
        assert_eq!(get_number(&result), Some(-3.0));
    }

    #[test]
    fn test_gen_simple() {
        assert_eq!(gen(&make_int(42)), "42");
        assert_eq!(gen(&make_string("hello")), "'hello'");
        assert_eq!(gen(&make_bool(true)), "TRUE");
        assert_eq!(gen(&make_bool(false)), "FALSE");
        assert_eq!(gen(&null()), "NULL");
    }

    #[test]
    fn test_gen_operations() {
        let add = Expression::Add(Box::new(BinaryOp::new(make_int(1), make_int(2))));
        assert_eq!(gen(&add), "1 + 2");

        let and = Expression::And(Box::new(BinaryOp::new(make_bool(true), make_bool(false))));
        assert_eq!(gen(&and), "(TRUE AND FALSE)");
    }

    #[test]
    fn test_complement_elimination() {
        let mut simplifier = Simplifier::new(None);

        // x AND NOT x -> FALSE
        let x = make_int(42);
        let not_x = Expression::Not(Box::new(UnaryOp::new(x.clone())));
        let expr = Expression::And(Box::new(BinaryOp::new(x, not_x)));
        let result = simplifier.simplify(expr);
        assert!(is_false(&result));
    }

    #[test]
    fn test_idempotent() {
        let mut simplifier = Simplifier::new(None);

        // x AND x -> x
        let x = make_int(42);
        let expr = Expression::And(Box::new(BinaryOp::new(x.clone(), x.clone())));
        let result = simplifier.simplify(expr);
        assert_eq!(format!("{:?}", result), format!("{:?}", x));

        // x OR x -> x
        let x = make_int(42);
        let expr = Expression::Or(Box::new(BinaryOp::new(x.clone(), x.clone())));
        let result = simplifier.simplify(expr);
        assert_eq!(format!("{:?}", result), format!("{:?}", x));
    }

    #[test]
    fn test_absorption_and() {
        let mut simplifier = Simplifier::new(None);

        // A AND (A OR B) -> A
        let a = make_column("a");
        let b = make_column("b");
        let a_or_b = Expression::Or(Box::new(BinaryOp::new(a.clone(), b.clone())));
        let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), a_or_b)));
        let result = simplifier.simplify(expr);
        // Result should be just A
        assert_eq!(gen(&result), gen(&a));
    }

    #[test]
    fn test_absorption_or() {
        let mut simplifier = Simplifier::new(None);

        // A OR (A AND B) -> A
        let a = make_column("a");
        let b = make_column("b");
        let a_and_b = Expression::And(Box::new(BinaryOp::new(a.clone(), b.clone())));
        let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), a_and_b)));
        let result = simplifier.simplify(expr);
        // Result should be just A
        assert_eq!(gen(&result), gen(&a));
    }

    #[test]
    fn test_absorption_with_complement_and() {
        let mut simplifier = Simplifier::new(None);

        // A AND (NOT A OR B) -> A AND B
        let a = make_column("a");
        let b = make_column("b");
        let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
        let not_a_or_b = Expression::Or(Box::new(BinaryOp::new(not_a, b.clone())));
        let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), not_a_or_b)));
        let result = simplifier.simplify(expr);
        // Result should be A AND B
        let expected = Expression::And(Box::new(BinaryOp::new(a, b)));
        assert_eq!(gen(&result), gen(&expected));
    }

    #[test]
    fn test_absorption_with_complement_or() {
        let mut simplifier = Simplifier::new(None);

        // A OR (NOT A AND B) -> A OR B
        let a = make_column("a");
        let b = make_column("b");
        let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
        let not_a_and_b = Expression::And(Box::new(BinaryOp::new(not_a, b.clone())));
        let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), not_a_and_b)));
        let result = simplifier.simplify(expr);
        // Result should be A OR B
        let expected = Expression::Or(Box::new(BinaryOp::new(a, b)));
        assert_eq!(gen(&result), gen(&expected));
    }

    #[test]
    fn test_flatten_and() {
        // (A AND (B AND C)) should flatten to [A, B, C]
        let a = make_column("a");
        let b = make_column("b");
        let c = make_column("c");
        let b_and_c = Expression::And(Box::new(BinaryOp::new(b.clone(), c.clone())));
        let expr = Expression::And(Box::new(BinaryOp::new(a.clone(), b_and_c)));
        let flattened = flatten_and(&expr);
        assert_eq!(flattened.len(), 3);
        assert_eq!(gen(&flattened[0]), "a");
        assert_eq!(gen(&flattened[1]), "b");
        assert_eq!(gen(&flattened[2]), "c");
    }

    #[test]
    fn test_flatten_or() {
        // (A OR (B OR C)) should flatten to [A, B, C]
        let a = make_column("a");
        let b = make_column("b");
        let c = make_column("c");
        let b_or_c = Expression::Or(Box::new(BinaryOp::new(b.clone(), c.clone())));
        let expr = Expression::Or(Box::new(BinaryOp::new(a.clone(), b_or_c)));
        let flattened = flatten_or(&expr);
        assert_eq!(flattened.len(), 3);
        assert_eq!(gen(&flattened[0]), "a");
        assert_eq!(gen(&flattened[1]), "b");
        assert_eq!(gen(&flattened[2]), "c");
    }

    #[test]
    fn test_flatten_left_nested_chain() {
        // ((A AND B) AND C), the shape rebuild_and produces, flattens back to [A, B, C]
        let chain = rebuild_and(vec![make_column("a"), make_column("b"), make_column("c")]);
        let names: Vec<String> = flatten_and(&chain).iter().map(gen).collect();
        assert_eq!(names, vec!["a", "b", "c"]);

        let chain = rebuild_or(vec![make_column("a"), make_column("b"), make_column("c")]);
        let names: Vec<String> = flatten_or(&chain).iter().map(gen).collect();
        assert_eq!(names, vec!["a", "b", "c"]);
    }

    #[test]
    fn test_rebuild_connectors() {
        let a = make_column("a");
        let b = make_column("b");
        let c = make_column("c");

        // Empty operand lists produce the connector's identity element
        assert!(always_true(&rebuild_and(vec![])));
        assert!(always_false(&rebuild_or(vec![])));

        // A single operand is returned unchanged
        assert_eq!(gen(&rebuild_and(vec![a.clone()])), "a");
        assert_eq!(gen(&rebuild_or(vec![a.clone()])), "a");

        // Multiple operands are folded left-associatively
        let and = rebuild_and(vec![a.clone(), b.clone(), c.clone()]);
        assert_eq!(gen(&and), "((a AND b) AND c)");
        let or = rebuild_or(vec![a, b, c]);
        assert_eq!(gen(&or), "((a OR b) OR c)");
    }

    #[test]
    fn test_get_not_inner() {
        let a = make_column("a");
        let not_a = Expression::Not(Box::new(UnaryOp::new(a.clone())));
        assert_eq!(get_not_inner(&not_a).map(gen), Some("a".to_string()));
        assert!(get_not_inner(&a).is_none());
    }

    #[test]
    fn test_expressions_equal() {
        assert!(expressions_equal(&make_column("a"), &make_column("a")));
        assert!(!expressions_equal(&make_column("a"), &make_column("b")));
        assert!(!expressions_equal(&make_int(1), &make_string("1")));
    }

    #[test]
    fn test_simplify_concat() {
        let mut simplifier = Simplifier::new(None);

        // 'hello' || 'world' -> 'helloworld'
        let expr = Expression::Concat(Box::new(BinaryOp::new(
            make_string("hello"),
            make_string("world"),
        )));
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("helloworld".to_string()));

        // '' || x -> x
        let x = make_string("test");
        let expr = Expression::Concat(Box::new(BinaryOp::new(make_string(""), x.clone())));
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("test".to_string()));

        // x || '' -> x
        let expr = Expression::Concat(Box::new(BinaryOp::new(x, make_string(""))));
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("test".to_string()));

        // NULL || x -> NULL
        let expr = Expression::Concat(Box::new(BinaryOp::new(null(), make_string("test"))));
        let result = simplifier.simplify(expr);
        assert!(is_null(&result));
    }

    #[test]
    fn test_simplify_concat_ws() {
        let mut simplifier = Simplifier::new(None);

        // CONCAT_WS(',', 'a', 'b', 'c') -> 'a,b,c'
        let expr = Expression::ConcatWs(Box::new(ConcatWs {
            separator: make_string(","),
            expressions: vec![make_string("a"), make_string("b"), make_string("c")],
        }));
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("a,b,c".to_string()));

        // CONCAT_WS with NULL separator -> NULL
        let expr = Expression::ConcatWs(Box::new(ConcatWs {
            separator: null(),
            expressions: vec![make_string("a"), make_string("b")],
        }));
        let result = simplifier.simplify(expr);
        assert!(is_null(&result));

        // CONCAT_WS with empty expressions -> ''
        let expr = Expression::ConcatWs(Box::new(ConcatWs {
            separator: make_string(","),
            expressions: vec![],
        }));
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("".to_string()));

        // CONCAT_WS skips NULLs
        let expr = Expression::ConcatWs(Box::new(ConcatWs {
            separator: make_string("-"),
            expressions: vec![make_string("a"), null(), make_string("b")],
        }));
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("a-b".to_string()));
    }

    #[test]
    fn test_simplify_paren() {
        let mut simplifier = Simplifier::new(None);

        // (42) -> 42
        let expr = Expression::Paren(Box::new(Paren {
            this: make_int(42),
            trailing_comments: vec![],
        }));
        let result = simplifier.simplify(expr);
        assert_eq!(get_number(&result), Some(42.0));

        // (TRUE) -> TRUE
        let expr = Expression::Paren(Box::new(Paren {
            this: make_bool(true),
            trailing_comments: vec![],
        }));
        let result = simplifier.simplify(expr);
        assert!(is_boolean_true(&result));

        // (NULL) -> NULL
        let expr = Expression::Paren(Box::new(Paren {
            this: null(),
            trailing_comments: vec![],
        }));
        let result = simplifier.simplify(expr);
        assert!(is_null(&result));

        // ((x)) -> x
        let inner_paren = Expression::Paren(Box::new(Paren {
            this: make_int(10),
            trailing_comments: vec![],
        }));
        let expr = Expression::Paren(Box::new(Paren {
            this: inner_paren,
            trailing_comments: vec![],
        }));
        let result = simplifier.simplify(expr);
        assert_eq!(get_number(&result), Some(10.0));
    }

    #[test]
    fn test_simplify_equality_solve() {
        let mut simplifier = Simplifier::new(None);

        // x + 1 = 3 -> x = 2
        let x = make_column("x");
        let x_plus_1 = Expression::Add(Box::new(BinaryOp::new(x.clone(), make_int(1))));
        let expr = Expression::Eq(Box::new(BinaryOp::new(x_plus_1, make_int(3))));
        let result = simplifier.simplify(expr);
        // Result should be x = 2
        if let Expression::Eq(op) = &result {
            assert_eq!(gen(&op.left), "x");
            assert_eq!(get_number(&op.right), Some(2.0));
        } else {
            panic!("Expected Eq expression");
        }

        // x - 1 = 3 -> x = 4
        let x_minus_1 = Expression::Sub(Box::new(BinaryOp::new(x.clone(), make_int(1))));
        let expr = Expression::Eq(Box::new(BinaryOp::new(x_minus_1, make_int(3))));
        let result = simplifier.simplify(expr);
        if let Expression::Eq(op) = &result {
            assert_eq!(gen(&op.left), "x");
            assert_eq!(get_number(&op.right), Some(4.0));
        } else {
            panic!("Expected Eq expression");
        }

        // x * 2 = 6 -> x = 3
        let x_times_2 = Expression::Mul(Box::new(BinaryOp::new(x.clone(), make_int(2))));
        let expr = Expression::Eq(Box::new(BinaryOp::new(x_times_2, make_int(6))));
        let result = simplifier.simplify(expr);
        if let Expression::Eq(op) = &result {
            assert_eq!(gen(&op.left), "x");
            assert_eq!(get_number(&op.right), Some(3.0));
        } else {
            panic!("Expected Eq expression");
        }

        // 1 + x = 3 -> x = 2 (commutative)
        let one_plus_x = Expression::Add(Box::new(BinaryOp::new(make_int(1), x.clone())));
        let expr = Expression::Eq(Box::new(BinaryOp::new(one_plus_x, make_int(3))));
        let result = simplifier.simplify(expr);
        if let Expression::Eq(op) = &result {
            assert_eq!(gen(&op.left), "x");
            assert_eq!(get_number(&op.right), Some(2.0));
        } else {
            panic!("Expected Eq expression");
        }
    }

    #[test]
    fn test_simplify_datetrunc() {
        use crate::expressions::DateTimeField;
        let mut simplifier = Simplifier::new(None);

        // DATE_TRUNC('day', x) with a column just passes through with simplified children
        let x = make_column("x");
        let expr = Expression::DateTrunc(Box::new(DateTruncFunc {
            this: x.clone(),
            unit: DateTimeField::Day,
        }));
        let result = simplifier.simplify(expr);
        if let Expression::DateTrunc(dt) = &result {
            assert_eq!(gen(&dt.this), "x");
            assert_eq!(dt.unit, DateTimeField::Day);
        } else {
            panic!("Expected DateTrunc expression");
        }
    }
}