chryso_optimizer/
expr_rewrite.rs

1use chryso_core::ast::{BinaryOperator, Expr, Literal, OrderByExpr, UnaryOperator};
2use chryso_planner::LogicalPlan;
3
4pub fn rewrite_plan(plan: &LogicalPlan) -> LogicalPlan {
5    match plan {
6        LogicalPlan::Scan { table } => LogicalPlan::Scan { table: table.clone() },
7        LogicalPlan::IndexScan {
8            table,
9            index,
10            predicate,
11        } => LogicalPlan::IndexScan {
12            table: table.clone(),
13            index: index.clone(),
14            predicate: rewrite_expr(predicate),
15        },
16        LogicalPlan::Dml { sql } => LogicalPlan::Dml { sql: sql.clone() },
17        LogicalPlan::Derived {
18            input,
19            alias,
20            column_aliases,
21        } => LogicalPlan::Derived {
22            input: Box::new(rewrite_plan(input.as_ref())),
23            alias: alias.clone(),
24            column_aliases: column_aliases.clone(),
25        },
26        LogicalPlan::Filter { predicate, input } => LogicalPlan::Filter {
27            predicate: rewrite_expr(predicate),
28            input: Box::new(rewrite_plan(input.as_ref())),
29        },
30        LogicalPlan::Projection { exprs, input } => LogicalPlan::Projection {
31            exprs: exprs.iter().map(rewrite_expr).collect(),
32            input: Box::new(rewrite_plan(input.as_ref())),
33        },
34        LogicalPlan::Join {
35            join_type,
36            left,
37            right,
38            on,
39        } => LogicalPlan::Join {
40            join_type: *join_type,
41            left: Box::new(rewrite_plan(left.as_ref())),
42            right: Box::new(rewrite_plan(right.as_ref())),
43            on: rewrite_expr(on),
44        },
45        LogicalPlan::Aggregate {
46            group_exprs,
47            aggr_exprs,
48            input,
49        } => LogicalPlan::Aggregate {
50            group_exprs: group_exprs.iter().map(rewrite_expr).collect(),
51            aggr_exprs: aggr_exprs.iter().map(rewrite_expr).collect(),
52            input: Box::new(rewrite_plan(input.as_ref())),
53        },
54        LogicalPlan::Distinct { input } => LogicalPlan::Distinct {
55            input: Box::new(rewrite_plan(input.as_ref())),
56        },
57        LogicalPlan::TopN {
58            order_by,
59            limit,
60            input,
61        } => LogicalPlan::TopN {
62            order_by: rewrite_order_by(order_by),
63            limit: *limit,
64            input: Box::new(rewrite_plan(input.as_ref())),
65        },
66        LogicalPlan::Sort { order_by, input } => LogicalPlan::Sort {
67            order_by: rewrite_order_by(order_by),
68            input: Box::new(rewrite_plan(input.as_ref())),
69        },
70        LogicalPlan::Limit {
71            limit,
72            offset,
73            input,
74        } => LogicalPlan::Limit {
75            limit: *limit,
76            offset: *offset,
77            input: Box::new(rewrite_plan(input.as_ref())),
78        },
79    }
80}
81
82pub fn rewrite_expr(expr: &Expr) -> Expr {
83    match expr {
84        Expr::Identifier(name) => Expr::Identifier(name.clone()),
85        Expr::Literal(Literal::String(value)) => Expr::Literal(Literal::String(value.clone())),
86        Expr::Literal(Literal::Number(value)) => Expr::Literal(Literal::Number(*value)),
87        Expr::Literal(Literal::Bool(value)) => Expr::Literal(Literal::Bool(*value)),
88        Expr::UnaryOp { op, expr } => {
89            let inner = rewrite_expr(expr);
90            match (op, inner) {
91                (UnaryOperator::Neg, Expr::Literal(Literal::Number(value))) => {
92                    Expr::Literal(Literal::Number(-value))
93                }
94                (UnaryOperator::Not, Expr::Literal(Literal::Bool(value))) => {
95                    Expr::Literal(Literal::Bool(!value))
96                }
97                (UnaryOperator::Not, Expr::UnaryOp { op: UnaryOperator::Not, expr }) => *expr,
98                (UnaryOperator::Not, Expr::IsNull { expr, negated }) => Expr::IsNull {
99                    expr,
100                    negated: !negated,
101                },
102                (UnaryOperator::Not, Expr::BinaryOp { left, op, right }) => match op {
103                    BinaryOperator::And => Expr::BinaryOp {
104                        left: Box::new(negate_expr(*left)),
105                        op: BinaryOperator::Or,
106                        right: Box::new(negate_expr(*right)),
107                    },
108                    BinaryOperator::Or => Expr::BinaryOp {
109                        left: Box::new(negate_expr(*left)),
110                        op: BinaryOperator::And,
111                        right: Box::new(negate_expr(*right)),
112                    },
113                    _ => Expr::UnaryOp {
114                        op: UnaryOperator::Not,
115                        expr: Box::new(Expr::BinaryOp { left, op, right }),
116                    },
117                },
118                (op, inner) => Expr::UnaryOp {
119                    op: *op,
120                    expr: Box::new(inner),
121                },
122            }
123        }
124        Expr::BinaryOp { left, op, right } => {
125            let left = rewrite_expr(left);
126            let right = rewrite_expr(right);
127            rewrite_binary(left, *op, right)
128        }
129        Expr::IsNull { expr, negated } => Expr::IsNull {
130            expr: Box::new(rewrite_expr(expr)),
131            negated: *negated,
132        },
133        Expr::FunctionCall { name, args } => Expr::FunctionCall {
134            name: name.clone(),
135            args: args.iter().map(rewrite_expr).collect(),
136        },
137        Expr::WindowFunction { function, spec } => Expr::WindowFunction {
138            function: Box::new(rewrite_expr(function)),
139            spec: chryso_core::ast::WindowSpec {
140                partition_by: spec.partition_by.iter().map(rewrite_expr).collect(),
141                order_by: rewrite_order_by(&spec.order_by),
142            },
143        },
144        Expr::Subquery(select) => Expr::Subquery(select.clone()),
145        Expr::Exists(select) => Expr::Exists(select.clone()),
146        Expr::InSubquery { expr, subquery } => Expr::InSubquery {
147            expr: Box::new(rewrite_expr(expr)),
148            subquery: subquery.clone(),
149        },
150        Expr::Case {
151            operand,
152            when_then,
153            else_expr,
154        } => Expr::Case {
155            operand: operand.as_ref().map(|expr| Box::new(rewrite_expr(expr))),
156            when_then: when_then
157                .iter()
158                .map(|(when_expr, then_expr)| (rewrite_expr(when_expr), rewrite_expr(then_expr)))
159                .collect(),
160            else_expr: else_expr.as_ref().map(|expr| Box::new(rewrite_expr(expr))),
161        },
162        Expr::Wildcard => Expr::Wildcard,
163    }
164}
165
166fn rewrite_binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
167    if let Some(expr) = fold_bool_binary(&left, op, &right) {
168        return expr;
169    }
170    if let Some(expr) = fold_comparison(&left, op, &right) {
171        return expr;
172    }
173    if matches!(op, BinaryOperator::And | BinaryOperator::Or) && left.structural_eq(&right) {
174        return left;
175    }
176    match (op, &left, &right) {
177        (BinaryOperator::Add, Expr::Literal(Literal::Number(0.0)), _) => right,
178        (BinaryOperator::Add, _, Expr::Literal(Literal::Number(0.0))) => left,
179        (BinaryOperator::Sub, _, Expr::Literal(Literal::Number(0.0))) => left,
180        (BinaryOperator::Mul, Expr::Literal(Literal::Number(1.0)), _) => right,
181        (BinaryOperator::Mul, _, Expr::Literal(Literal::Number(1.0))) => left,
182        (BinaryOperator::Mul, Expr::Literal(Literal::Number(0.0)), _) => {
183            Expr::Literal(Literal::Number(0.0))
184        }
185        (BinaryOperator::Mul, _, Expr::Literal(Literal::Number(0.0))) => {
186            Expr::Literal(Literal::Number(0.0))
187        }
188        (BinaryOperator::Div, _, Expr::Literal(Literal::Number(1.0))) => left,
189        (BinaryOperator::Add, Expr::Literal(Literal::Number(a)), Expr::Literal(Literal::Number(b))) => {
190            Expr::Literal(Literal::Number(a + b))
191        }
192        (BinaryOperator::Sub, Expr::Literal(Literal::Number(a)), Expr::Literal(Literal::Number(b))) => {
193            Expr::Literal(Literal::Number(a - b))
194        }
195        (BinaryOperator::Mul, Expr::Literal(Literal::Number(a)), Expr::Literal(Literal::Number(b))) => {
196            Expr::Literal(Literal::Number(a * b))
197        }
198        (BinaryOperator::Div, Expr::Literal(Literal::Number(a)), Expr::Literal(Literal::Number(b))) => {
199            if *b == 0.0 {
200                Expr::BinaryOp {
201                    left: Box::new(left),
202                    op,
203                    right: Box::new(right),
204                }
205            } else {
206                Expr::Literal(Literal::Number(a / b))
207            }
208        }
209        _ => Expr::BinaryOp {
210            left: Box::new(left),
211            op,
212            right: Box::new(right),
213        },
214    }
215}
216
217fn negate_expr(expr: Expr) -> Expr {
218    rewrite_expr(&Expr::UnaryOp {
219        op: UnaryOperator::Not,
220        expr: Box::new(expr),
221    })
222}
223
224fn fold_bool_binary(left: &Expr, op: BinaryOperator, right: &Expr) -> Option<Expr> {
225    let left_bool = match left {
226        Expr::Literal(Literal::Bool(value)) => Some(*value),
227        _ => None,
228    };
229    let right_bool = match right {
230        Expr::Literal(Literal::Bool(value)) => Some(*value),
231        _ => None,
232    };
233    match op {
234        BinaryOperator::And => match (left_bool, right_bool) {
235            (Some(true), _) => Some(right.clone()),
236            (Some(false), _) => Some(Expr::Literal(Literal::Bool(false))),
237            (_, Some(true)) => Some(left.clone()),
238            (_, Some(false)) => Some(Expr::Literal(Literal::Bool(false))),
239            _ => None,
240        },
241        BinaryOperator::Or => match (left_bool, right_bool) {
242            (Some(true), _) => Some(Expr::Literal(Literal::Bool(true))),
243            (Some(false), _) => Some(right.clone()),
244            (_, Some(true)) => Some(Expr::Literal(Literal::Bool(true))),
245            (_, Some(false)) => Some(left.clone()),
246            _ => None,
247        },
248        _ => None,
249    }
250}
251
252fn fold_comparison(left: &Expr, op: BinaryOperator, right: &Expr) -> Option<Expr> {
253    match (left, right) {
254        (Expr::Literal(Literal::Number(left)), Expr::Literal(Literal::Number(right))) => {
255            let result = match op {
256                BinaryOperator::Eq => Some(left == right),
257                BinaryOperator::NotEq => Some(left != right),
258                BinaryOperator::Lt => Some(left < right),
259                BinaryOperator::LtEq => Some(left <= right),
260                BinaryOperator::Gt => Some(left > right),
261                BinaryOperator::GtEq => Some(left >= right),
262                _ => None,
263            };
264            result.map(|value| Expr::Literal(Literal::Bool(value)))
265        }
266        (Expr::Literal(Literal::Bool(left)), Expr::Literal(Literal::Bool(right))) => {
267            let result = match op {
268                BinaryOperator::Eq => Some(left == right),
269                BinaryOperator::NotEq => Some(left != right),
270                _ => None,
271            };
272            result.map(|value| Expr::Literal(Literal::Bool(value)))
273        }
274        _ => None,
275    }
276}
277
278fn rewrite_order_by(order_by: &[OrderByExpr]) -> Vec<OrderByExpr> {
279    order_by
280        .iter()
281        .map(|item| OrderByExpr {
282            expr: rewrite_expr(&item.expr),
283            asc: item.asc,
284            nulls_first: item.nulls_first,
285        })
286        .collect()
287}
288
#[cfg(test)]
mod tests {
    use super::{rewrite_expr, rewrite_plan};
    use chryso_core::ast::{BinaryOperator, Expr, Literal, UnaryOperator};
    use chryso_planner::LogicalPlan;

    // --- Expression constructors keeping each test's tree readable. ---

    fn num(value: f64) -> Expr {
        Expr::Literal(Literal::Number(value))
    }

    fn boolean(value: bool) -> Expr {
        Expr::Literal(Literal::Bool(value))
    }

    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    fn not(expr: Expr) -> Expr {
        Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: Box::new(expr),
        }
    }

    // --- Assertion helpers shared by several tests. ---

    fn assert_number(expr: Expr, expected: f64) {
        match expr {
            Expr::Literal(Literal::Number(value)) => assert_eq!(value, expected),
            other => panic!("expected folded literal, got {other:?}"),
        }
    }

    fn assert_literal_true(expr: Expr) {
        match expr {
            Expr::Literal(Literal::Bool(value)) => assert!(value),
            other => panic!("expected literal true, got {other:?}"),
        }
    }

    fn assert_identifier(expr: Expr, expected: &str) {
        match expr {
            Expr::Identifier(name) => assert_eq!(name, expected),
            other => panic!("expected identifier, got {other:?}"),
        }
    }

    #[test]
    fn folds_numeric_arithmetic() {
        // 1 + (2 * 3) collapses to a single literal.
        let expr = binary(
            num(1.0),
            BinaryOperator::Add,
            binary(num(2.0), BinaryOperator::Mul, num(3.0)),
        );
        assert_number(rewrite_expr(&expr), 7.0);
    }

    #[test]
    fn folds_boolean_logic() {
        let and_true = binary(ident("a"), BinaryOperator::And, boolean(true));
        assert_identifier(rewrite_expr(&and_true), "a");

        let or_true = binary(ident("a"), BinaryOperator::Or, boolean(true));
        assert_literal_true(rewrite_expr(&or_true));
    }

    #[test]
    fn folds_boolean_comparisons() {
        let expr = binary(boolean(true), BinaryOperator::NotEq, boolean(false));
        assert_literal_true(rewrite_expr(&expr));
    }

    #[test]
    fn folds_numeric_comparisons() {
        let expr = binary(num(1.0), BinaryOperator::Lt, num(2.0));
        assert_literal_true(rewrite_expr(&expr));
    }

    #[test]
    fn normalizes_not() {
        assert_identifier(rewrite_expr(&not(not(ident("a")))), "a");
    }

    #[test]
    fn applies_de_morgan() {
        let expr = not(binary(ident("a"), BinaryOperator::And, ident("b")));
        match rewrite_expr(&expr) {
            Expr::BinaryOp {
                op: BinaryOperator::Or,
                left,
                right,
            } => match (*left, *right) {
                (
                    Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    },
                    Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    },
                ) => {}
                other => panic!("expected negated operands, got {other:?}"),
            },
            other => panic!("expected OR, got {other:?}"),
        }
    }

    #[test]
    fn dedups_boolean_idempotence() {
        let expr = binary(ident("a"), BinaryOperator::And, ident("a"));
        assert_identifier(rewrite_expr(&expr), "a");
    }

    #[test]
    fn rewrites_filter_predicate() {
        let plan = LogicalPlan::Filter {
            predicate: binary(num(10.0), BinaryOperator::Sub, num(3.0)),
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        match rewrite_plan(&plan) {
            LogicalPlan::Filter { predicate, .. } => assert_number(predicate, 7.0),
            other => panic!("unexpected plan: {other:?}"),
        }
    }
}