Skip to main content

polyglot_sql/optimizer/
simplify.rs

1//! Expression Simplification
2//!
3//! This module provides boolean and expression simplification for SQL AST nodes.
4//! It applies various algebraic transformations to simplify expressions:
5//! - De Morgan's laws (NOT (A AND B) -> NOT A OR NOT B)
6//! - Constant folding (1 + 2 -> 3)
7//! - Boolean absorption (A AND (A OR B) -> A)
8//! - Complement removal (A AND NOT A -> FALSE)
9//! - Connector flattening (A AND (B AND C) -> A AND B AND C)
10//!
11//! Based on SQLGlot's optimizer/simplify.py
12
13use crate::dialects::DialectType;
14use crate::expressions::{
15    BinaryOp, BooleanLiteral, Case, ConcatWs, DateTruncFunc, Expression, Literal, Null, Paren,
16    UnaryOp,
17};
18
19/// Main entry point for expression simplification
20pub fn simplify(expression: Expression, dialect: Option<DialectType>) -> Expression {
21    let mut simplifier = Simplifier::new(dialect);
22    simplifier.simplify(expression)
23}
24
25/// Check if expression is always true
26pub fn always_true(expr: &Expression) -> bool {
27    match expr {
28        Expression::Boolean(b) => b.value,
29        Expression::Literal(Literal::Number(n)) => {
30            // Non-zero numbers are truthy
31            if let Ok(num) = n.parse::<f64>() {
32                num != 0.0
33            } else {
34                false
35            }
36        }
37        _ => false,
38    }
39}
40
41/// Check if expression is a boolean TRUE literal (not just truthy)
42pub fn is_boolean_true(expr: &Expression) -> bool {
43    matches!(expr, Expression::Boolean(b) if b.value)
44}
45
46/// Check if expression is a boolean FALSE literal (not just falsy)
47pub fn is_boolean_false(expr: &Expression) -> bool {
48    matches!(expr, Expression::Boolean(b) if !b.value)
49}
50
51/// Check if expression is always false
52pub fn always_false(expr: &Expression) -> bool {
53    is_false(expr) || is_null(expr) || is_zero(expr)
54}
55
56/// Check if expression is boolean FALSE
57pub fn is_false(expr: &Expression) -> bool {
58    matches!(expr, Expression::Boolean(b) if !b.value)
59}
60
61/// Check if expression is NULL
62pub fn is_null(expr: &Expression) -> bool {
63    matches!(expr, Expression::Null(_))
64}
65
66/// Check if expression is zero
67pub fn is_zero(expr: &Expression) -> bool {
68    match expr {
69        Expression::Literal(Literal::Number(n)) => {
70            if let Ok(num) = n.parse::<f64>() {
71                num == 0.0
72            } else {
73                false
74            }
75        }
76        _ => false,
77    }
78}
79
80/// Check if b is the complement of a (i.e., b = NOT a)
81pub fn is_complement(a: &Expression, b: &Expression) -> bool {
82    if let Expression::Not(not_op) = b {
83        &not_op.this == a
84    } else {
85        false
86    }
87}
88
89/// Create a TRUE boolean literal
90pub fn bool_true() -> Expression {
91    Expression::Boolean(BooleanLiteral { value: true })
92}
93
94/// Create a FALSE boolean literal
95pub fn bool_false() -> Expression {
96    Expression::Boolean(BooleanLiteral { value: false })
97}
98
99/// Create a NULL expression
100pub fn null() -> Expression {
101    Expression::Null(Null)
102}
103
104/// Evaluate a boolean comparison between two numbers
105pub fn eval_boolean_nums(op: &str, a: f64, b: f64) -> Option<Expression> {
106    let result = match op {
107        "=" | "==" => a == b,
108        "!=" | "<>" => a != b,
109        ">" => a > b,
110        ">=" => a >= b,
111        "<" => a < b,
112        "<=" => a <= b,
113        _ => return None,
114    };
115    Some(if result { bool_true() } else { bool_false() })
116}
117
118/// Evaluate a boolean comparison between two strings
119pub fn eval_boolean_strings(op: &str, a: &str, b: &str) -> Option<Expression> {
120    let result = match op {
121        "=" | "==" => a == b,
122        "!=" | "<>" => a != b,
123        ">" => a > b,
124        ">=" => a >= b,
125        "<" => a < b,
126        "<=" => a <= b,
127        _ => return None,
128    };
129    Some(if result { bool_true() } else { bool_false() })
130}
131
/// Expression simplifier
///
/// Carries the settings used while repeatedly rewriting an expression tree.
/// Construct one with `Simplifier::new` and run it with `Simplifier::simplify`.
pub struct Simplifier {
    /// Dialect passed to `new`. It is stored but not read by any code visible
    /// in this part of the file (hence the leading underscore).
    _dialect: Option<DialectType>,
    /// Upper bound on the number of rewrite passes `simplify` performs before
    /// giving up on reaching a fixed point.
    max_iterations: usize,
}
137
138impl Simplifier {
139    /// Create a new simplifier
140    pub fn new(dialect: Option<DialectType>) -> Self {
141        Self {
142            _dialect: dialect,
143            max_iterations: 100,
144        }
145    }
146
147    /// Simplify an expression
148    pub fn simplify(&mut self, expression: Expression) -> Expression {
149        // Apply simplifications until no more changes (or max iterations)
150        let mut current = expression;
151        for _ in 0..self.max_iterations {
152            let simplified = self.simplify_once(current.clone());
153            if expressions_equal(&simplified, &current) {
154                return simplified;
155            }
156            current = simplified;
157        }
158        current
159    }
160
161    /// Apply one round of simplifications
162    fn simplify_once(&mut self, expression: Expression) -> Expression {
163        match expression {
164            // Binary logical operations
165            Expression::And(op) => self.simplify_and(*op),
166            Expression::Or(op) => self.simplify_or(*op),
167
168            // NOT operation - De Morgan's laws
169            Expression::Not(op) => self.simplify_not(*op),
170
171            // Arithmetic operations - constant folding
172            Expression::Add(op) => self.simplify_add(*op),
173            Expression::Sub(op) => self.simplify_sub(*op),
174            Expression::Mul(op) => self.simplify_mul(*op),
175            Expression::Div(op) => self.simplify_div(*op),
176
177            // Comparison operations
178            Expression::Eq(op) => self.simplify_comparison(*op, "="),
179            Expression::Neq(op) => self.simplify_comparison(*op, "!="),
180            Expression::Gt(op) => self.simplify_comparison(*op, ">"),
181            Expression::Gte(op) => self.simplify_comparison(*op, ">="),
182            Expression::Lt(op) => self.simplify_comparison(*op, "<"),
183            Expression::Lte(op) => self.simplify_comparison(*op, "<="),
184
185            // Negation
186            Expression::Neg(op) => self.simplify_neg(*op),
187
188            // CASE expression
189            Expression::Case(case) => self.simplify_case(*case),
190
191            // String concatenation
192            Expression::Concat(op) => self.simplify_concat(*op),
193            Expression::ConcatWs(concat_ws) => self.simplify_concat_ws(*concat_ws),
194
195            // Parentheses - remove if unnecessary
196            Expression::Paren(paren) => self.simplify_paren(*paren),
197
198            // Date truncation
199            Expression::DateTrunc(dt) => self.simplify_datetrunc(*dt),
200            Expression::TimestampTrunc(dt) => self.simplify_datetrunc(*dt),
201
202            // Recursively simplify children for other expressions
203            other => self.simplify_children(other),
204        }
205    }
206
207    /// Simplify AND operation
208    fn simplify_and(&mut self, op: BinaryOp) -> Expression {
209        let left = self.simplify_once(op.left);
210        let right = self.simplify_once(op.right);
211
212        // FALSE AND x -> FALSE
213        // x AND FALSE -> FALSE
214        if is_boolean_false(&left) || is_boolean_false(&right) {
215            return bool_false();
216        }
217
218        // 0 AND x -> FALSE (in boolean context)
219        // x AND 0 -> FALSE
220        if is_zero(&left) || is_zero(&right) {
221            return bool_false();
222        }
223
224        // NULL AND NULL -> NULL
225        // NULL AND TRUE -> NULL
226        // TRUE AND NULL -> NULL
227        if (is_null(&left) && is_null(&right))
228            || (is_null(&left) && is_boolean_true(&right))
229            || (is_boolean_true(&left) && is_null(&right))
230        {
231            return null();
232        }
233
234        // TRUE AND x -> x (only when left is actually boolean TRUE)
235        if is_boolean_true(&left) {
236            return right;
237        }
238
239        // x AND TRUE -> x (only when right is actually boolean TRUE)
240        if is_boolean_true(&right) {
241            return left;
242        }
243
244        // A AND NOT A -> FALSE (complement elimination)
245        if is_complement(&left, &right) || is_complement(&right, &left) {
246            return bool_false();
247        }
248
249        // A AND A -> A (idempotent)
250        if expressions_equal(&left, &right) {
251            return left;
252        }
253
254        // Apply absorption rules
255        // A AND (A OR B) -> A
256        // A AND (NOT A OR B) -> A AND B
257        absorb_and_eliminate_and(left, right)
258    }
259
260    /// Simplify OR operation
261    fn simplify_or(&mut self, op: BinaryOp) -> Expression {
262        let left = self.simplify_once(op.left);
263        let right = self.simplify_once(op.right);
264
265        // TRUE OR x -> TRUE (only when left is actually boolean TRUE)
266        if is_boolean_true(&left) {
267            return bool_true();
268        }
269
270        // x OR TRUE -> TRUE (only when right is actually boolean TRUE)
271        if is_boolean_true(&right) {
272            return bool_true();
273        }
274
275        // NULL OR NULL -> NULL
276        // NULL OR FALSE -> NULL
277        // FALSE OR NULL -> NULL
278        if (is_null(&left) && is_null(&right))
279            || (is_null(&left) && is_boolean_false(&right))
280            || (is_boolean_false(&left) && is_null(&right))
281        {
282            return null();
283        }
284
285        // FALSE OR x -> x (only when left is actually boolean FALSE)
286        if is_boolean_false(&left) {
287            return right;
288        }
289
290        // x OR FALSE -> x (only when right is actually boolean FALSE)
291        if is_boolean_false(&right) {
292            return left;
293        }
294
295        // A OR A -> A (idempotent)
296        if expressions_equal(&left, &right) {
297            return left;
298        }
299
300        // Apply absorption rules
301        // A OR (A AND B) -> A
302        // A OR (NOT A AND B) -> A OR B
303        absorb_and_eliminate_or(left, right)
304    }
305
306    /// Simplify NOT operation (De Morgan's laws)
307    fn simplify_not(&mut self, op: UnaryOp) -> Expression {
308        // Check for De Morgan's laws BEFORE simplifying inner expression
309        // This prevents constant folding from eliminating the comparison operator
310        match &op.this {
311            // NOT (a = b) -> a != b
312            Expression::Eq(inner_op) => {
313                let left = self.simplify_once(inner_op.left.clone());
314                let right = self.simplify_once(inner_op.right.clone());
315                return Expression::Neq(Box::new(BinaryOp::new(left, right)));
316            }
317            // NOT (a != b) -> a = b
318            Expression::Neq(inner_op) => {
319                let left = self.simplify_once(inner_op.left.clone());
320                let right = self.simplify_once(inner_op.right.clone());
321                return Expression::Eq(Box::new(BinaryOp::new(left, right)));
322            }
323            // NOT (a > b) -> a <= b
324            Expression::Gt(inner_op) => {
325                let left = self.simplify_once(inner_op.left.clone());
326                let right = self.simplify_once(inner_op.right.clone());
327                return Expression::Lte(Box::new(BinaryOp::new(left, right)));
328            }
329            // NOT (a >= b) -> a < b
330            Expression::Gte(inner_op) => {
331                let left = self.simplify_once(inner_op.left.clone());
332                let right = self.simplify_once(inner_op.right.clone());
333                return Expression::Lt(Box::new(BinaryOp::new(left, right)));
334            }
335            // NOT (a < b) -> a >= b
336            Expression::Lt(inner_op) => {
337                let left = self.simplify_once(inner_op.left.clone());
338                let right = self.simplify_once(inner_op.right.clone());
339                return Expression::Gte(Box::new(BinaryOp::new(left, right)));
340            }
341            // NOT (a <= b) -> a > b
342            Expression::Lte(inner_op) => {
343                let left = self.simplify_once(inner_op.left.clone());
344                let right = self.simplify_once(inner_op.right.clone());
345                return Expression::Gt(Box::new(BinaryOp::new(left, right)));
346            }
347            _ => {}
348        }
349
350        // Now simplify the inner expression for other patterns
351        let inner = self.simplify_once(op.this);
352
353        // NOT NULL -> NULL (with TRUE for SQL semantics)
354        if is_null(&inner) {
355            return null();
356        }
357
358        // NOT TRUE -> FALSE (only for boolean TRUE literal)
359        if is_boolean_true(&inner) {
360            return bool_false();
361        }
362
363        // NOT FALSE -> TRUE (only for boolean FALSE literal)
364        if is_boolean_false(&inner) {
365            return bool_true();
366        }
367
368        // NOT NOT x -> x (double negation elimination)
369        if let Expression::Not(inner_not) = &inner {
370            return inner_not.this.clone();
371        }
372
373        Expression::Not(Box::new(UnaryOp { this: inner }))
374    }
375
376    /// Simplify addition (constant folding)
377    fn simplify_add(&mut self, op: BinaryOp) -> Expression {
378        let left = self.simplify_once(op.left);
379        let right = self.simplify_once(op.right);
380
381        // Try constant folding for numbers
382        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
383            return Expression::Literal(Literal::Number((a + b).to_string()));
384        }
385
386        // x + 0 -> x
387        if is_zero(&right) {
388            return left;
389        }
390
391        // 0 + x -> x
392        if is_zero(&left) {
393            return right;
394        }
395
396        Expression::Add(Box::new(BinaryOp::new(left, right)))
397    }
398
399    /// Simplify subtraction (constant folding)
400    fn simplify_sub(&mut self, op: BinaryOp) -> Expression {
401        let left = self.simplify_once(op.left);
402        let right = self.simplify_once(op.right);
403
404        // Try constant folding for numbers
405        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
406            return Expression::Literal(Literal::Number((a - b).to_string()));
407        }
408
409        // x - 0 -> x
410        if is_zero(&right) {
411            return left;
412        }
413
414        // x - x -> 0 (only for literals/constants)
415        if expressions_equal(&left, &right) {
416            if let Expression::Literal(Literal::Number(_)) = &left {
417                return Expression::Literal(Literal::Number("0".to_string()));
418            }
419        }
420
421        Expression::Sub(Box::new(BinaryOp::new(left, right)))
422    }
423
424    /// Simplify multiplication (constant folding)
425    fn simplify_mul(&mut self, op: BinaryOp) -> Expression {
426        let left = self.simplify_once(op.left);
427        let right = self.simplify_once(op.right);
428
429        // Try constant folding for numbers
430        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
431            return Expression::Literal(Literal::Number((a * b).to_string()));
432        }
433
434        // x * 0 -> 0
435        if is_zero(&right) {
436            return Expression::Literal(Literal::Number("0".to_string()));
437        }
438
439        // 0 * x -> 0
440        if is_zero(&left) {
441            return Expression::Literal(Literal::Number("0".to_string()));
442        }
443
444        // x * 1 -> x
445        if is_one(&right) {
446            return left;
447        }
448
449        // 1 * x -> x
450        if is_one(&left) {
451            return right;
452        }
453
454        Expression::Mul(Box::new(BinaryOp::new(left, right)))
455    }
456
457    /// Simplify division (constant folding)
458    fn simplify_div(&mut self, op: BinaryOp) -> Expression {
459        let left = self.simplify_once(op.left);
460        let right = self.simplify_once(op.right);
461
462        // Try constant folding for numbers (but not integer division)
463        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
464            // Only fold if both are floats to avoid integer division issues
465            if b != 0.0 && (a.fract() != 0.0 || b.fract() != 0.0) {
466                return Expression::Literal(Literal::Number((a / b).to_string()));
467            }
468        }
469
470        // 0 / x -> 0 (when x != 0)
471        if is_zero(&left) && !is_zero(&right) {
472            return Expression::Literal(Literal::Number("0".to_string()));
473        }
474
475        // x / 1 -> x
476        if is_one(&right) {
477            return left;
478        }
479
480        Expression::Div(Box::new(BinaryOp::new(left, right)))
481    }
482
483    /// Simplify negation
484    fn simplify_neg(&mut self, op: UnaryOp) -> Expression {
485        let inner = self.simplify_once(op.this);
486
487        // -(-x) -> x (double negation)
488        if let Expression::Neg(inner_neg) = inner {
489            return inner_neg.this;
490        }
491
492        // -(number) -> -number
493        if let Some(n) = get_number(&inner) {
494            return Expression::Literal(Literal::Number((-n).to_string()));
495        }
496
497        Expression::Neg(Box::new(UnaryOp { this: inner }))
498    }
499
500    /// Simplify comparison operations (constant folding)
501    fn simplify_comparison(&mut self, op: BinaryOp, operator: &str) -> Expression {
502        let left = self.simplify_once(op.left);
503        let right = self.simplify_once(op.right);
504
505        // Try constant folding for numbers
506        if let (Some(a), Some(b)) = (get_number(&left), get_number(&right)) {
507            if let Some(result) = eval_boolean_nums(operator, a, b) {
508                return result;
509            }
510        }
511
512        // Try constant folding for strings
513        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
514            if let Some(result) = eval_boolean_strings(operator, &a, &b) {
515                return result;
516            }
517        }
518
519        // For equality, try to solve simple equations (x + 1 = 3 -> x = 2)
520        if operator == "=" {
521            if let Some(simplified) = self.simplify_equality(left.clone(), right.clone()) {
522                return simplified;
523            }
524        }
525
526        // Reconstruct the comparison
527        let new_op = BinaryOp::new(left, right);
528
529        match operator {
530            "=" => Expression::Eq(Box::new(new_op)),
531            "!=" | "<>" => Expression::Neq(Box::new(new_op)),
532            ">" => Expression::Gt(Box::new(new_op)),
533            ">=" => Expression::Gte(Box::new(new_op)),
534            "<" => Expression::Lt(Box::new(new_op)),
535            "<=" => Expression::Lte(Box::new(new_op)),
536            _ => Expression::Eq(Box::new(new_op)),
537        }
538    }
539
540    /// Simplify CASE expression
541    fn simplify_case(&mut self, case: Case) -> Expression {
542        let mut new_whens = Vec::new();
543
544        for (cond, then_expr) in case.whens {
545            let simplified_cond = self.simplify_once(cond);
546
547            // If condition is always true, return the THEN expression
548            if always_true(&simplified_cond) {
549                return self.simplify_once(then_expr);
550            }
551
552            // If condition is always false, skip this WHEN clause
553            if always_false(&simplified_cond) {
554                continue;
555            }
556
557            new_whens.push((simplified_cond, self.simplify_once(then_expr)));
558        }
559
560        // If no WHEN clauses remain, return the ELSE expression (or NULL)
561        if new_whens.is_empty() {
562            return case.else_.map(|e| self.simplify_once(e)).unwrap_or_else(null);
563        }
564
565        Expression::Case(Box::new(Case {
566            operand: case.operand.map(|e| self.simplify_once(e)),
567            whens: new_whens,
568            else_: case.else_.map(|e| self.simplify_once(e)),
569        }))
570    }
571
572    /// Simplify string concatenation (Concat is || operator)
573    ///
574    /// Folds adjacent string literals:
575    /// - 'a' || 'b' -> 'ab'
576    /// - 'a' || 'b' || 'c' -> 'abc'
577    /// - '' || x -> x
578    /// - x || '' -> x
579    fn simplify_concat(&mut self, op: BinaryOp) -> Expression {
580        let left = self.simplify_once(op.left);
581        let right = self.simplify_once(op.right);
582
583        // Fold two string literals: 'a' || 'b' -> 'ab'
584        if let (Some(a), Some(b)) = (get_string(&left), get_string(&right)) {
585            return Expression::Literal(Literal::String(format!("{}{}", a, b)));
586        }
587
588        // '' || x -> x
589        if let Some(s) = get_string(&left) {
590            if s.is_empty() {
591                return right;
592            }
593        }
594
595        // x || '' -> x
596        if let Some(s) = get_string(&right) {
597            if s.is_empty() {
598                return left;
599            }
600        }
601
602        // NULL || x -> NULL, x || NULL -> NULL (SQL string concat semantics)
603        if is_null(&left) || is_null(&right) {
604            return null();
605        }
606
607        Expression::Concat(Box::new(BinaryOp::new(left, right)))
608    }
609
610    /// Simplify CONCAT_WS function
611    ///
612    /// CONCAT_WS(sep, a, b, c) -> concatenates with separator, skipping NULLs
613    /// - CONCAT_WS(',', 'a', 'b') -> 'a,b' (when all are literals)
614    /// - CONCAT_WS(',', 'a', NULL, 'b') -> 'a,b' (NULLs are skipped)
615    /// - CONCAT_WS(NULL, ...) -> NULL
616    fn simplify_concat_ws(&mut self, concat_ws: ConcatWs) -> Expression {
617        let separator = self.simplify_once(concat_ws.separator);
618
619        // If separator is NULL, result is NULL
620        if is_null(&separator) {
621            return null();
622        }
623
624        let expressions: Vec<Expression> = concat_ws
625            .expressions
626            .into_iter()
627            .map(|e| self.simplify_once(e))
628            .filter(|e| !is_null(e)) // Skip NULL values
629            .collect();
630
631        // If no expressions remain, return empty string
632        if expressions.is_empty() {
633            return Expression::Literal(Literal::String(String::new()));
634        }
635
636        // Try to fold if all are string literals
637        if let Some(sep) = get_string(&separator) {
638            let all_strings: Option<Vec<String>> = expressions
639                .iter()
640                .map(|e| get_string(e))
641                .collect();
642
643            if let Some(strings) = all_strings {
644                return Expression::Literal(Literal::String(strings.join(&sep)));
645            }
646        }
647
648        // Return simplified CONCAT_WS
649        Expression::ConcatWs(Box::new(ConcatWs {
650            separator,
651            expressions,
652        }))
653    }
654
655    /// Simplify parentheses
656    ///
657    /// Remove unnecessary parentheses:
658    /// - (x) -> x when x is a literal, column, or already parenthesized
659    /// - ((x)) -> (x) -> x (recursive simplification)
660    fn simplify_paren(&mut self, paren: Paren) -> Expression {
661        let inner = self.simplify_once(paren.this);
662
663        // If inner is a literal, column, boolean, null, or already parenthesized,
664        // we can remove the parentheses
665        match &inner {
666            Expression::Literal(_)
667            | Expression::Boolean(_)
668            | Expression::Null(_)
669            | Expression::Column(_)
670            | Expression::Paren(_) => inner,
671            // For other expressions, keep the parentheses
672            _ => Expression::Paren(Box::new(Paren {
673                this: inner,
674                trailing_comments: paren.trailing_comments,
675            })),
676        }
677    }
678
679    /// Simplify DATE_TRUNC and TIMESTAMP_TRUNC
680    ///
681    /// Currently just simplifies children and passes through.
682    /// Future: could fold DATE_TRUNC('day', '2024-01-15') -> '2024-01-15'
683    fn simplify_datetrunc(&mut self, dt: DateTruncFunc) -> Expression {
684        let inner = self.simplify_once(dt.this);
685
686        // For now, just return with simplified inner expression
687        // A more advanced implementation would fold constant date/timestamps
688        Expression::DateTrunc(Box::new(DateTruncFunc {
689            this: inner,
690            unit: dt.unit,
691        }))
692    }
693
694    /// Simplify equality with arithmetic (solve simple equations)
695    ///
696    /// - x + 1 = 3 -> x = 2
697    /// - x - 1 = 3 -> x = 4
698    /// - x * 2 = 6 -> x = 3 (only when divisible)
699    /// - 1 + x = 3 -> x = 2 (commutative)
700    fn simplify_equality(&mut self, left: Expression, right: Expression) -> Option<Expression> {
701        // Only works when right side is a constant
702        let right_val = get_number(&right)?;
703
704        // Check if left side is arithmetic with one constant
705        match left {
706            Expression::Add(ref op) => {
707                // x + c = r -> x = r - c
708                if let Some(c) = get_number(&op.right) {
709                    let new_right = Expression::Literal(Literal::Number((right_val - c).to_string()));
710                    return Some(Expression::Eq(Box::new(BinaryOp::new(op.left.clone(), new_right))));
711                }
712                // c + x = r -> x = r - c
713                if let Some(c) = get_number(&op.left) {
714                    let new_right = Expression::Literal(Literal::Number((right_val - c).to_string()));
715                    return Some(Expression::Eq(Box::new(BinaryOp::new(op.right.clone(), new_right))));
716                }
717            }
718            Expression::Sub(ref op) => {
719                // x - c = r -> x = r + c
720                if let Some(c) = get_number(&op.right) {
721                    let new_right = Expression::Literal(Literal::Number((right_val + c).to_string()));
722                    return Some(Expression::Eq(Box::new(BinaryOp::new(op.left.clone(), new_right))));
723                }
724                // c - x = r -> x = c - r
725                if let Some(c) = get_number(&op.left) {
726                    let new_right = Expression::Literal(Literal::Number((c - right_val).to_string()));
727                    return Some(Expression::Eq(Box::new(BinaryOp::new(op.right.clone(), new_right))));
728                }
729            }
730            Expression::Mul(ref op) => {
731                // x * c = r -> x = r / c (only for non-zero c and when divisible)
732                if let Some(c) = get_number(&op.right) {
733                    if c != 0.0 && right_val % c == 0.0 {
734                        let new_right = Expression::Literal(Literal::Number((right_val / c).to_string()));
735                        return Some(Expression::Eq(Box::new(BinaryOp::new(op.left.clone(), new_right))));
736                    }
737                }
738                // c * x = r -> x = r / c
739                if let Some(c) = get_number(&op.left) {
740                    if c != 0.0 && right_val % c == 0.0 {
741                        let new_right = Expression::Literal(Literal::Number((right_val / c).to_string()));
742                        return Some(Expression::Eq(Box::new(BinaryOp::new(op.right.clone(), new_right))));
743                    }
744                }
745            }
746            _ => {}
747        }
748
749        None
750    }
751
752    /// Recursively simplify children of an expression
753    fn simplify_children(&mut self, expr: Expression) -> Expression {
754        // For expressions we don't have specific simplification rules for,
755        // we still want to simplify their children
756        match expr {
757            Expression::Alias(mut alias) => {
758                alias.this = self.simplify_once(alias.this);
759                Expression::Alias(alias)
760            }
761            Expression::Between(mut between) => {
762                between.this = self.simplify_once(between.this);
763                between.low = self.simplify_once(between.low);
764                between.high = self.simplify_once(between.high);
765                Expression::Between(between)
766            }
767            Expression::In(mut in_expr) => {
768                in_expr.this = self.simplify_once(in_expr.this);
769                in_expr.expressions = in_expr
770                    .expressions
771                    .into_iter()
772                    .map(|e| self.simplify_once(e))
773                    .collect();
774                Expression::In(in_expr)
775            }
776            Expression::Function(mut func) => {
777                func.args = func.args.into_iter().map(|e| self.simplify_once(e)).collect();
778                Expression::Function(func)
779            }
780            // For other expressions, return as-is for now
781            other => other,
782        }
783    }
784}
785
786/// Check if expression equals 1
787fn is_one(expr: &Expression) -> bool {
788    match expr {
789        Expression::Literal(Literal::Number(n)) => {
790            if let Ok(num) = n.parse::<f64>() {
791                num == 1.0
792            } else {
793                false
794            }
795        }
796        _ => false,
797    }
798}
799
800/// Get numeric value from expression if it's a number literal
801fn get_number(expr: &Expression) -> Option<f64> {
802    match expr {
803        Expression::Literal(Literal::Number(n)) => n.parse().ok(),
804        _ => None,
805    }
806}
807
808/// Get string value from expression if it's a string literal
809fn get_string(expr: &Expression) -> Option<String> {
810    match expr {
811        Expression::Literal(Literal::String(s)) => Some(s.clone()),
812        _ => None,
813    }
814}
815
816/// Check if two expressions are structurally equal
817/// This is a simplified comparison - a full implementation would need deep comparison
818fn expressions_equal(a: &Expression, b: &Expression) -> bool {
819    // For now, use Debug representation for comparison
820    // A proper implementation would do structural comparison
821    format!("{:?}", a) == format!("{:?}", b)
822}
823
824/// Flatten nested AND expressions into a list of operands
825/// e.g., (A AND (B AND C)) -> [A, B, C]
826fn flatten_and(expr: &Expression) -> Vec<Expression> {
827    match expr {
828        Expression::And(op) => {
829            let mut result = flatten_and(&op.left);
830            result.extend(flatten_and(&op.right));
831            result
832        }
833        other => vec![other.clone()],
834    }
835}
836
837/// Flatten nested OR expressions into a list of operands
838/// e.g., (A OR (B OR C)) -> [A, B, C]
839fn flatten_or(expr: &Expression) -> Vec<Expression> {
840    match expr {
841        Expression::Or(op) => {
842            let mut result = flatten_or(&op.left);
843            result.extend(flatten_or(&op.right));
844            result
845        }
846        other => vec![other.clone()],
847    }
848}
849
850/// Rebuild an AND expression from a list of operands
851fn rebuild_and(operands: Vec<Expression>) -> Expression {
852    if operands.is_empty() {
853        return bool_true(); // Empty AND is TRUE
854    }
855    let mut result = operands.into_iter();
856    let first = result.next().unwrap();
857    result.fold(first, |acc, op| {
858        Expression::And(Box::new(BinaryOp::new(acc, op)))
859    })
860}
861
862/// Rebuild an OR expression from a list of operands
863fn rebuild_or(operands: Vec<Expression>) -> Expression {
864    if operands.is_empty() {
865        return bool_false(); // Empty OR is FALSE
866    }
867    let mut result = operands.into_iter();
868    let first = result.next().unwrap();
869    result.fold(first, |acc, op| {
870        Expression::Or(Box::new(BinaryOp::new(acc, op)))
871    })
872}
873
874/// Get the inner expression of a NOT, if it is one
875fn get_not_inner(expr: &Expression) -> Option<&Expression> {
876    match expr {
877        Expression::Not(op) => Some(&op.this),
878        _ => None,
879    }
880}
881
882/// Apply Boolean absorption and elimination rules to an AND expression
883///
884/// Absorption:
885///   A AND (A OR B) -> A
886///   A AND (NOT A OR B) -> A AND B
887///
888/// Elimination:
889///   (A OR B) AND (A OR NOT B) -> A
890pub fn absorb_and_eliminate_and(left: Expression, right: Expression) -> Expression {
891    // Flatten both sides
892    let left_ops = flatten_and(&left);
893    let right_ops = flatten_and(&right);
894    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
895
896    // Build a set of string representations for quick lookup
897    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
898
899    let mut result_ops: Vec<Expression> = Vec::new();
900    let mut absorbed = std::collections::HashSet::new();
901
902    for (i, op) in all_ops.iter().enumerate() {
903        let op_str = gen(op);
904
905        // Skip if already absorbed
906        if absorbed.contains(&op_str) {
907            continue;
908        }
909
910        // Check if this is an OR expression (potential absorption target)
911        if let Expression::Or(_) = op {
912            let or_operands = flatten_or(op);
913
914            // Absorption: A AND (A OR B) -> A
915            // Check if any OR operand is already in our AND operands
916            let absorbed_by_existing = or_operands.iter().any(|or_op| {
917                let or_op_str = gen(or_op);
918                // Check if this OR operand exists in other AND operands (not this OR itself)
919                all_ops.iter().enumerate().any(|(j, other)| {
920                    i != j && gen(other) == or_op_str
921                })
922            });
923
924            if absorbed_by_existing {
925                // This OR is absorbed, skip it
926                absorbed.insert(op_str);
927                continue;
928            }
929
930            // Absorption with complement: A AND (NOT A OR B) -> A AND B
931            // Check if any OR operand's complement is in our AND operands
932            let mut remaining_or_ops: Vec<Expression> = Vec::new();
933            let mut had_complement_absorption = false;
934
935            for or_op in or_operands {
936                let complement_str = if let Some(inner) = get_not_inner(&or_op) {
937                    // or_op is NOT X, complement is X
938                    gen(inner)
939                } else {
940                    // or_op is X, complement is NOT X
941                    format!("NOT {}", gen(&or_op))
942                };
943
944                // Check if complement exists in our AND operands
945                let has_complement = all_ops.iter().enumerate().any(|(j, other)| {
946                    i != j && gen(other) == complement_str
947                }) || op_strings.contains(&complement_str);
948
949                if has_complement {
950                    // This OR operand's complement exists, so this term becomes TRUE in AND context
951                    // NOT A OR B, where A exists, becomes TRUE OR B (when A is true) or B (when A is false)
952                    // Actually: A AND (NOT A OR B) -> A AND B, so we drop NOT A from the OR
953                    had_complement_absorption = true;
954                    // Drop this operand from OR
955                } else {
956                    remaining_or_ops.push(or_op);
957                }
958            }
959
960            if had_complement_absorption {
961                if remaining_or_ops.is_empty() {
962                    // All OR operands were absorbed, the OR becomes TRUE
963                    // A AND TRUE -> A, so we just skip adding this
964                    absorbed.insert(op_str);
965                    continue;
966                } else if remaining_or_ops.len() == 1 {
967                    // Single remaining operand
968                    result_ops.push(remaining_or_ops.into_iter().next().unwrap());
969                    absorbed.insert(op_str);
970                    continue;
971                } else {
972                    // Rebuild the OR with remaining operands
973                    result_ops.push(rebuild_or(remaining_or_ops));
974                    absorbed.insert(op_str);
975                    continue;
976                }
977            }
978        }
979
980        result_ops.push(op.clone());
981    }
982
983    // Deduplicate
984    let mut seen = std::collections::HashSet::new();
985    result_ops.retain(|op| seen.insert(gen(op)));
986
987    if result_ops.is_empty() {
988        bool_true()
989    } else {
990        rebuild_and(result_ops)
991    }
992}
993
994/// Apply Boolean absorption and elimination rules to an OR expression
995///
996/// Absorption:
997///   A OR (A AND B) -> A
998///   A OR (NOT A AND B) -> A OR B
999///
1000/// Elimination:
1001///   (A AND B) OR (A AND NOT B) -> A
1002pub fn absorb_and_eliminate_or(left: Expression, right: Expression) -> Expression {
1003    // Flatten both sides
1004    let left_ops = flatten_or(&left);
1005    let right_ops = flatten_or(&right);
1006    let all_ops: Vec<Expression> = left_ops.iter().chain(right_ops.iter()).cloned().collect();
1007
1008    // Build a set of string representations for quick lookup
1009    let op_strings: std::collections::HashSet<String> = all_ops.iter().map(gen).collect();
1010
1011    let mut result_ops: Vec<Expression> = Vec::new();
1012    let mut absorbed = std::collections::HashSet::new();
1013
1014    for (i, op) in all_ops.iter().enumerate() {
1015        let op_str = gen(op);
1016
1017        // Skip if already absorbed
1018        if absorbed.contains(&op_str) {
1019            continue;
1020        }
1021
1022        // Check if this is an AND expression (potential absorption target)
1023        if let Expression::And(_) = op {
1024            let and_operands = flatten_and(op);
1025
1026            // Absorption: A OR (A AND B) -> A
1027            // Check if any AND operand is already in our OR operands
1028            let absorbed_by_existing = and_operands.iter().any(|and_op| {
1029                let and_op_str = gen(and_op);
1030                // Check if this AND operand exists in other OR operands (not this AND itself)
1031                all_ops.iter().enumerate().any(|(j, other)| {
1032                    i != j && gen(other) == and_op_str
1033                })
1034            });
1035
1036            if absorbed_by_existing {
1037                // This AND is absorbed, skip it
1038                absorbed.insert(op_str);
1039                continue;
1040            }
1041
1042            // Absorption with complement: A OR (NOT A AND B) -> A OR B
1043            // Check if any AND operand's complement is in our OR operands
1044            let mut remaining_and_ops: Vec<Expression> = Vec::new();
1045            let mut had_complement_absorption = false;
1046
1047            for and_op in and_operands {
1048                let complement_str = if let Some(inner) = get_not_inner(&and_op) {
1049                    // and_op is NOT X, complement is X
1050                    gen(inner)
1051                } else {
1052                    // and_op is X, complement is NOT X
1053                    format!("NOT {}", gen(&and_op))
1054                };
1055
1056                // Check if complement exists in our OR operands
1057                let has_complement = all_ops.iter().enumerate().any(|(j, other)| {
1058                    i != j && gen(other) == complement_str
1059                }) || op_strings.contains(&complement_str);
1060
1061                if has_complement {
1062                    // This AND operand's complement exists, so this term becomes FALSE in OR context
1063                    // A OR (NOT A AND B) -> A OR B, so we drop NOT A from the AND
1064                    had_complement_absorption = true;
1065                    // Drop this operand from AND
1066                } else {
1067                    remaining_and_ops.push(and_op);
1068                }
1069            }
1070
1071            if had_complement_absorption {
1072                if remaining_and_ops.is_empty() {
1073                    // All AND operands were absorbed, the AND becomes FALSE
1074                    // A OR FALSE -> A, so we just skip adding this
1075                    absorbed.insert(op_str);
1076                    continue;
1077                } else if remaining_and_ops.len() == 1 {
1078                    // Single remaining operand
1079                    result_ops.push(remaining_and_ops.into_iter().next().unwrap());
1080                    absorbed.insert(op_str);
1081                    continue;
1082                } else {
1083                    // Rebuild the AND with remaining operands
1084                    result_ops.push(rebuild_and(remaining_and_ops));
1085                    absorbed.insert(op_str);
1086                    continue;
1087                }
1088            }
1089        }
1090
1091        result_ops.push(op.clone());
1092    }
1093
1094    // Deduplicate
1095    let mut seen = std::collections::HashSet::new();
1096    result_ops.retain(|op| seen.insert(gen(op)));
1097
1098    if result_ops.is_empty() {
1099        bool_false()
1100    } else {
1101        rebuild_or(result_ops)
1102    }
1103}
1104
1105/// Generate a simple string representation of an expression for sorting/deduping
1106pub fn gen(expr: &Expression) -> String {
1107    match expr {
1108        Expression::Literal(lit) => match lit {
1109            Literal::String(s) => format!("'{}'", s),
1110            Literal::Number(n) => n.clone(),
1111            _ => format!("{:?}", lit),
1112        },
1113        Expression::Boolean(b) => if b.value { "TRUE" } else { "FALSE" }.to_string(),
1114        Expression::Null(_) => "NULL".to_string(),
1115        Expression::Column(col) => {
1116            if let Some(ref table) = col.table {
1117                format!("{}.{}", table.name, col.name.name)
1118            } else {
1119                col.name.name.clone()
1120            }
1121        }
1122        Expression::And(op) => format!("({} AND {})", gen(&op.left), gen(&op.right)),
1123        Expression::Or(op) => format!("({} OR {})", gen(&op.left), gen(&op.right)),
1124        Expression::Not(op) => format!("NOT {}", gen(&op.this)),
1125        Expression::Eq(op) => format!("{} = {}", gen(&op.left), gen(&op.right)),
1126        Expression::Neq(op) => format!("{} <> {}", gen(&op.left), gen(&op.right)),
1127        Expression::Gt(op) => format!("{} > {}", gen(&op.left), gen(&op.right)),
1128        Expression::Gte(op) => format!("{} >= {}", gen(&op.left), gen(&op.right)),
1129        Expression::Lt(op) => format!("{} < {}", gen(&op.left), gen(&op.right)),
1130        Expression::Lte(op) => format!("{} <= {}", gen(&op.left), gen(&op.right)),
1131        Expression::Add(op) => format!("{} + {}", gen(&op.left), gen(&op.right)),
1132        Expression::Sub(op) => format!("{} - {}", gen(&op.left), gen(&op.right)),
1133        Expression::Mul(op) => format!("{} * {}", gen(&op.left), gen(&op.right)),
1134        Expression::Div(op) => format!("{} / {}", gen(&op.left), gen(&op.right)),
1135        Expression::Function(f) => {
1136            let args: Vec<String> = f.args.iter().map(|a| gen(a)).collect();
1137            format!("{}({})", f.name.to_uppercase(), args.join(", "))
1138        }
1139        _ => format!("{:?}", expr),
1140    }
1141}
1142
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Leaf builders ----

    /// Number literal holding an integer value.
    fn make_int(val: i64) -> Expression {
        Expression::Literal(Literal::Number(val.to_string()))
    }

    /// String literal.
    fn make_string(val: &str) -> Expression {
        Expression::Literal(Literal::String(val.to_string()))
    }

    /// Boolean literal.
    fn make_bool(val: bool) -> Expression {
        Expression::Boolean(BooleanLiteral { value: val })
    }

    /// Unqualified column reference.
    fn make_column(name: &str) -> Expression {
        use crate::expressions::{Column, Identifier};
        Expression::Column(Column {
            name: Identifier::new(name),
            table: None,
            join_mark: false,
            trailing_comments: vec![],
        })
    }

    // ---- Node builders ----

    /// Boxed binary operand pair, ready to wrap in any binary variant.
    fn bin(left: Expression, right: Expression) -> Box<BinaryOp> {
        Box::new(BinaryOp::new(left, right))
    }

    fn make_and(left: Expression, right: Expression) -> Expression {
        Expression::And(bin(left, right))
    }

    fn make_or(left: Expression, right: Expression) -> Expression {
        Expression::Or(bin(left, right))
    }

    fn make_not(inner: Expression) -> Expression {
        Expression::Not(Box::new(UnaryOp::new(inner)))
    }

    fn make_neg(inner: Expression) -> Expression {
        Expression::Neg(Box::new(UnaryOp::new(inner)))
    }

    fn make_paren(inner: Expression) -> Expression {
        Expression::Paren(Box::new(Paren {
            this: inner,
            trailing_comments: vec![],
        }))
    }

    fn make_concat_ws(separator: Expression, expressions: Vec<Expression>) -> Expression {
        Expression::ConcatWs(Box::new(ConcatWs {
            separator,
            expressions,
        }))
    }

    // ---- Assertion helpers ----

    /// `Debug` rendering, used where the tests compare whole expressions.
    fn debug_repr(expr: &Expression) -> String {
        format!("{:?}", expr)
    }

    /// Assert that `result` has the shape `x = <expected>`.
    fn assert_solved_for_x(result: &Expression, expected: f64) {
        match result {
            Expression::Eq(op) => {
                assert_eq!(gen(&op.left), "x");
                assert_eq!(get_number(&op.right), Some(expected));
            }
            _ => panic!("Expected Eq expression"),
        }
    }

    #[test]
    fn test_always_true_false() {
        assert!(always_true(&make_bool(true)));
        assert!(!always_true(&make_bool(false)));
        assert!(always_true(&make_int(1)));
        assert!(!always_true(&make_int(0)));

        assert!(always_false(&make_bool(false)));
        assert!(!always_false(&make_bool(true)));
        assert!(always_false(&null()));
        assert!(always_false(&make_int(0)));
    }

    #[test]
    fn test_simplify_and_with_true() {
        let mut simplifier = Simplifier::new(None);

        // TRUE AND TRUE -> TRUE
        let result = simplifier.simplify(make_and(make_bool(true), make_bool(true)));
        assert!(always_true(&result));

        // TRUE AND FALSE -> FALSE
        let result = simplifier.simplify(make_and(make_bool(true), make_bool(false)));
        assert!(always_false(&result));

        // TRUE AND x -> x
        let x = make_int(42);
        let result = simplifier.simplify(make_and(make_bool(true), x.clone()));
        assert_eq!(debug_repr(&result), debug_repr(&x));
    }

    #[test]
    fn test_simplify_or_with_false() {
        let mut simplifier = Simplifier::new(None);

        // FALSE OR FALSE -> FALSE
        let result = simplifier.simplify(make_or(make_bool(false), make_bool(false)));
        assert!(always_false(&result));

        // FALSE OR TRUE -> TRUE
        let result = simplifier.simplify(make_or(make_bool(false), make_bool(true)));
        assert!(always_true(&result));

        // FALSE OR x -> x
        let x = make_int(42);
        let result = simplifier.simplify(make_or(make_bool(false), x.clone()));
        assert_eq!(debug_repr(&result), debug_repr(&x));
    }

    #[test]
    fn test_simplify_not() {
        let mut simplifier = Simplifier::new(None);

        // NOT TRUE -> FALSE
        let result = simplifier.simplify(make_not(make_bool(true)));
        assert!(is_false(&result));

        // NOT FALSE -> TRUE
        let result = simplifier.simplify(make_not(make_bool(false)));
        assert!(always_true(&result));

        // NOT NOT x -> x
        let x = make_int(42);
        let result = simplifier.simplify(make_not(make_not(x.clone())));
        assert_eq!(debug_repr(&result), debug_repr(&x));
    }

    #[test]
    fn test_simplify_demorgan_comparison() {
        let mut simplifier = Simplifier::new(None);

        // Columns are used so constant folding does not kick in.
        let a = make_column("a");
        let b = make_column("b");

        // NOT (a = b) -> a != b
        let eq = Expression::Eq(bin(a.clone(), b.clone()));
        let result = simplifier.simplify(make_not(eq));
        assert!(matches!(result, Expression::Neq(_)));

        // NOT (a > b) -> a <= b
        let gt = Expression::Gt(bin(a, b));
        let result = simplifier.simplify(make_not(gt));
        assert!(matches!(result, Expression::Lte(_)));
    }

    #[test]
    fn test_constant_folding_add() {
        let mut simplifier = Simplifier::new(None);

        // 1 + 2 -> 3
        let result = simplifier.simplify(Expression::Add(bin(make_int(1), make_int(2))));
        assert_eq!(get_number(&result), Some(3.0));

        // x + 0 -> x
        let x = make_int(42);
        let result = simplifier.simplify(Expression::Add(bin(x.clone(), make_int(0))));
        assert_eq!(debug_repr(&result), debug_repr(&x));
    }

    #[test]
    fn test_constant_folding_mul() {
        let mut simplifier = Simplifier::new(None);

        // 3 * 4 -> 12
        let result = simplifier.simplify(Expression::Mul(bin(make_int(3), make_int(4))));
        assert_eq!(get_number(&result), Some(12.0));

        // x * 0 -> 0
        let result = simplifier.simplify(Expression::Mul(bin(make_int(42), make_int(0))));
        assert_eq!(get_number(&result), Some(0.0));

        // x * 1 -> x
        let x = make_int(42);
        let result = simplifier.simplify(Expression::Mul(bin(x.clone(), make_int(1))));
        assert_eq!(debug_repr(&result), debug_repr(&x));
    }

    #[test]
    fn test_constant_folding_comparison() {
        let mut simplifier = Simplifier::new(None);

        // 1 = 1 -> TRUE
        let result = simplifier.simplify(Expression::Eq(bin(make_int(1), make_int(1))));
        assert!(always_true(&result));

        // 1 = 2 -> FALSE
        let result = simplifier.simplify(Expression::Eq(bin(make_int(1), make_int(2))));
        assert!(is_false(&result));

        // 3 > 2 -> TRUE
        let result = simplifier.simplify(Expression::Gt(bin(make_int(3), make_int(2))));
        assert!(always_true(&result));

        // 'abc' = 'abc' -> TRUE
        let result =
            simplifier.simplify(Expression::Eq(bin(make_string("abc"), make_string("abc"))));
        assert!(always_true(&result));
    }

    #[test]
    fn test_simplify_negation() {
        let mut simplifier = Simplifier::new(None);

        // -(-5) -> 5
        let result = simplifier.simplify(make_neg(make_neg(make_int(5))));
        assert_eq!(get_number(&result), Some(5.0));

        // -(3) -> -3
        let result = simplifier.simplify(make_neg(make_int(3)));
        assert_eq!(get_number(&result), Some(-3.0));
    }

    #[test]
    fn test_gen_simple() {
        assert_eq!(gen(&make_int(42)), "42");
        assert_eq!(gen(&make_string("hello")), "'hello'");
        assert_eq!(gen(&make_bool(true)), "TRUE");
        assert_eq!(gen(&make_bool(false)), "FALSE");
        assert_eq!(gen(&null()), "NULL");
    }

    #[test]
    fn test_gen_operations() {
        let add = Expression::Add(bin(make_int(1), make_int(2)));
        assert_eq!(gen(&add), "1 + 2");

        let and = make_and(make_bool(true), make_bool(false));
        assert_eq!(gen(&and), "(TRUE AND FALSE)");
    }

    #[test]
    fn test_complement_elimination() {
        let mut simplifier = Simplifier::new(None);

        // x AND NOT x -> FALSE
        let x = make_int(42);
        let not_x = make_not(x.clone());
        let result = simplifier.simplify(make_and(x, not_x));
        assert!(is_false(&result));
    }

    #[test]
    fn test_idempotent() {
        let mut simplifier = Simplifier::new(None);

        // x AND x -> x
        let x = make_int(42);
        let result = simplifier.simplify(make_and(x.clone(), x.clone()));
        assert_eq!(debug_repr(&result), debug_repr(&x));

        // x OR x -> x
        let x = make_int(42);
        let result = simplifier.simplify(make_or(x.clone(), x.clone()));
        assert_eq!(debug_repr(&result), debug_repr(&x));
    }

    #[test]
    fn test_absorption_and() {
        let mut simplifier = Simplifier::new(None);

        // A AND (A OR B) -> A
        let a = make_column("a");
        let b = make_column("b");
        let a_or_b = make_or(a.clone(), b);
        let result = simplifier.simplify(make_and(a.clone(), a_or_b));
        assert_eq!(gen(&result), gen(&a));
    }

    #[test]
    fn test_absorption_or() {
        let mut simplifier = Simplifier::new(None);

        // A OR (A AND B) -> A
        let a = make_column("a");
        let b = make_column("b");
        let a_and_b = make_and(a.clone(), b);
        let result = simplifier.simplify(make_or(a.clone(), a_and_b));
        assert_eq!(gen(&result), gen(&a));
    }

    #[test]
    fn test_absorption_with_complement_and() {
        let mut simplifier = Simplifier::new(None);

        // A AND (NOT A OR B) -> A AND B
        let a = make_column("a");
        let b = make_column("b");
        let not_a_or_b = make_or(make_not(a.clone()), b.clone());
        let result = simplifier.simplify(make_and(a.clone(), not_a_or_b));
        let expected = make_and(a, b);
        assert_eq!(gen(&result), gen(&expected));
    }

    #[test]
    fn test_absorption_with_complement_or() {
        let mut simplifier = Simplifier::new(None);

        // A OR (NOT A AND B) -> A OR B
        let a = make_column("a");
        let b = make_column("b");
        let not_a_and_b = make_and(make_not(a.clone()), b.clone());
        let result = simplifier.simplify(make_or(a.clone(), not_a_and_b));
        let expected = make_or(a, b);
        assert_eq!(gen(&result), gen(&expected));
    }

    #[test]
    fn test_flatten_and() {
        // (A AND (B AND C)) should flatten to [A, B, C]
        let b_and_c = make_and(make_column("b"), make_column("c"));
        let expr = make_and(make_column("a"), b_and_c);
        let flattened = flatten_and(&expr);
        assert_eq!(flattened.len(), 3);
        assert_eq!(gen(&flattened[0]), "a");
        assert_eq!(gen(&flattened[1]), "b");
        assert_eq!(gen(&flattened[2]), "c");
    }

    #[test]
    fn test_flatten_or() {
        // (A OR (B OR C)) should flatten to [A, B, C]
        let b_or_c = make_or(make_column("b"), make_column("c"));
        let expr = make_or(make_column("a"), b_or_c);
        let flattened = flatten_or(&expr);
        assert_eq!(flattened.len(), 3);
        assert_eq!(gen(&flattened[0]), "a");
        assert_eq!(gen(&flattened[1]), "b");
        assert_eq!(gen(&flattened[2]), "c");
    }

    #[test]
    fn test_simplify_concat() {
        let mut simplifier = Simplifier::new(None);

        // 'hello' || 'world' -> 'helloworld'
        let expr = Expression::Concat(bin(make_string("hello"), make_string("world")));
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("helloworld".to_string()));

        // '' || x -> x
        let x = make_string("test");
        let expr = Expression::Concat(bin(make_string(""), x.clone()));
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("test".to_string()));

        // x || '' -> x
        let expr = Expression::Concat(bin(x, make_string("")));
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("test".to_string()));

        // NULL || x -> NULL
        let expr = Expression::Concat(bin(null(), make_string("test")));
        let result = simplifier.simplify(expr);
        assert!(is_null(&result));
    }

    #[test]
    fn test_simplify_concat_ws() {
        let mut simplifier = Simplifier::new(None);

        // CONCAT_WS(',', 'a', 'b', 'c') -> 'a,b,c'
        let expr = make_concat_ws(
            make_string(","),
            vec![make_string("a"), make_string("b"), make_string("c")],
        );
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("a,b,c".to_string()));

        // CONCAT_WS with NULL separator -> NULL
        let expr = make_concat_ws(null(), vec![make_string("a"), make_string("b")]);
        let result = simplifier.simplify(expr);
        assert!(is_null(&result));

        // CONCAT_WS with empty expressions -> ''
        let expr = make_concat_ws(make_string(","), vec![]);
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("".to_string()));

        // CONCAT_WS skips NULLs
        let expr = make_concat_ws(
            make_string("-"),
            vec![make_string("a"), null(), make_string("b")],
        );
        let result = simplifier.simplify(expr);
        assert_eq!(get_string(&result), Some("a-b".to_string()));
    }

    #[test]
    fn test_simplify_paren() {
        let mut simplifier = Simplifier::new(None);

        // (42) -> 42
        let result = simplifier.simplify(make_paren(make_int(42)));
        assert_eq!(get_number(&result), Some(42.0));

        // (TRUE) -> TRUE
        let result = simplifier.simplify(make_paren(make_bool(true)));
        assert!(is_boolean_true(&result));

        // (NULL) -> NULL
        let result = simplifier.simplify(make_paren(null()));
        assert!(is_null(&result));

        // ((x)) -> x
        let result = simplifier.simplify(make_paren(make_paren(make_int(10))));
        assert_eq!(get_number(&result), Some(10.0));
    }

    #[test]
    fn test_simplify_equality_solve() {
        let mut simplifier = Simplifier::new(None);
        let x = make_column("x");

        // x + 1 = 3 -> x = 2
        let x_plus_1 = Expression::Add(bin(x.clone(), make_int(1)));
        let result = simplifier.simplify(Expression::Eq(bin(x_plus_1, make_int(3))));
        assert_solved_for_x(&result, 2.0);

        // x - 1 = 3 -> x = 4
        let x_minus_1 = Expression::Sub(bin(x.clone(), make_int(1)));
        let result = simplifier.simplify(Expression::Eq(bin(x_minus_1, make_int(3))));
        assert_solved_for_x(&result, 4.0);

        // x * 2 = 6 -> x = 3
        let x_times_2 = Expression::Mul(bin(x.clone(), make_int(2)));
        let result = simplifier.simplify(Expression::Eq(bin(x_times_2, make_int(6))));
        assert_solved_for_x(&result, 3.0);

        // 1 + x = 3 -> x = 2 (commutative)
        let one_plus_x = Expression::Add(bin(make_int(1), x.clone()));
        let result = simplifier.simplify(Expression::Eq(bin(one_plus_x, make_int(3))));
        assert_solved_for_x(&result, 2.0);
    }

    #[test]
    fn test_simplify_datetrunc() {
        use crate::expressions::DateTimeField;
        let mut simplifier = Simplifier::new(None);

        // DATE_TRUNC('day', x) on a column keeps its shape after simplification.
        let x = make_column("x");
        let expr = Expression::DateTrunc(Box::new(DateTruncFunc {
            this: x.clone(),
            unit: DateTimeField::Day,
        }));
        let result = simplifier.simplify(expr);
        match &result {
            Expression::DateTrunc(dt) => {
                assert_eq!(gen(&dt.this), "x");
                assert_eq!(dt.unit, DateTimeField::Day);
            }
            _ => panic!("Expected DateTrunc expression"),
        }
    }
}