Skip to main content

chryso_optimizer/
expr_rewrite.rs

1use chryso_core::ast::{BinaryOperator, Expr, Literal, OrderByExpr, UnaryOperator};
2use chryso_planner::LogicalPlan;
3
4pub fn rewrite_plan(plan: &LogicalPlan) -> LogicalPlan {
5    match plan {
6        LogicalPlan::Scan { table } => LogicalPlan::Scan {
7            table: table.clone(),
8        },
9        LogicalPlan::IndexScan {
10            table,
11            index,
12            predicate,
13        } => LogicalPlan::IndexScan {
14            table: table.clone(),
15            index: index.clone(),
16            predicate: rewrite_expr(predicate),
17        },
18        LogicalPlan::Dml { sql } => LogicalPlan::Dml { sql: sql.clone() },
19        LogicalPlan::Derived {
20            input,
21            alias,
22            column_aliases,
23        } => LogicalPlan::Derived {
24            input: Box::new(rewrite_plan(input.as_ref())),
25            alias: alias.clone(),
26            column_aliases: column_aliases.clone(),
27        },
28        LogicalPlan::Filter { predicate, input } => LogicalPlan::Filter {
29            predicate: rewrite_expr(predicate),
30            input: Box::new(rewrite_plan(input.as_ref())),
31        },
32        LogicalPlan::Projection { exprs, input } => LogicalPlan::Projection {
33            exprs: exprs.iter().map(rewrite_expr).collect(),
34            input: Box::new(rewrite_plan(input.as_ref())),
35        },
36        LogicalPlan::Join {
37            join_type,
38            left,
39            right,
40            on,
41        } => LogicalPlan::Join {
42            join_type: *join_type,
43            left: Box::new(rewrite_plan(left.as_ref())),
44            right: Box::new(rewrite_plan(right.as_ref())),
45            on: rewrite_expr(on),
46        },
47        LogicalPlan::Aggregate {
48            group_exprs,
49            aggr_exprs,
50            input,
51        } => LogicalPlan::Aggregate {
52            group_exprs: group_exprs.iter().map(rewrite_expr).collect(),
53            aggr_exprs: aggr_exprs.iter().map(rewrite_expr).collect(),
54            input: Box::new(rewrite_plan(input.as_ref())),
55        },
56        LogicalPlan::Distinct { input } => LogicalPlan::Distinct {
57            input: Box::new(rewrite_plan(input.as_ref())),
58        },
59        LogicalPlan::TopN {
60            order_by,
61            limit,
62            input,
63        } => LogicalPlan::TopN {
64            order_by: rewrite_order_by(order_by),
65            limit: *limit,
66            input: Box::new(rewrite_plan(input.as_ref())),
67        },
68        LogicalPlan::Sort { order_by, input } => LogicalPlan::Sort {
69            order_by: rewrite_order_by(order_by),
70            input: Box::new(rewrite_plan(input.as_ref())),
71        },
72        LogicalPlan::Limit {
73            limit,
74            offset,
75            input,
76        } => LogicalPlan::Limit {
77            limit: *limit,
78            offset: *offset,
79            input: Box::new(rewrite_plan(input.as_ref())),
80        },
81    }
82}
83
84pub fn rewrite_expr(expr: &Expr) -> Expr {
85    match expr {
86        Expr::Identifier(name) => Expr::Identifier(name.clone()),
87        Expr::Literal(Literal::String(value)) => Expr::Literal(Literal::String(value.clone())),
88        Expr::Literal(Literal::Number(value)) => Expr::Literal(Literal::Number(*value)),
89        Expr::Literal(Literal::Bool(value)) => Expr::Literal(Literal::Bool(*value)),
90        Expr::UnaryOp { op, expr } => {
91            let inner = rewrite_expr(expr);
92            match (op, inner) {
93                (UnaryOperator::Neg, Expr::Literal(Literal::Number(value))) => {
94                    Expr::Literal(Literal::Number(-value))
95                }
96                (UnaryOperator::Not, Expr::Literal(Literal::Bool(value))) => {
97                    Expr::Literal(Literal::Bool(!value))
98                }
99                (
100                    UnaryOperator::Not,
101                    Expr::UnaryOp {
102                        op: UnaryOperator::Not,
103                        expr,
104                    },
105                ) => *expr,
106                (UnaryOperator::Not, Expr::IsNull { expr, negated }) => Expr::IsNull {
107                    expr,
108                    negated: !negated,
109                },
110                (UnaryOperator::Not, Expr::BinaryOp { left, op, right }) => match op {
111                    BinaryOperator::And => Expr::BinaryOp {
112                        left: Box::new(negate_expr(*left)),
113                        op: BinaryOperator::Or,
114                        right: Box::new(negate_expr(*right)),
115                    },
116                    BinaryOperator::Or => Expr::BinaryOp {
117                        left: Box::new(negate_expr(*left)),
118                        op: BinaryOperator::And,
119                        right: Box::new(negate_expr(*right)),
120                    },
121                    _ => Expr::UnaryOp {
122                        op: UnaryOperator::Not,
123                        expr: Box::new(Expr::BinaryOp { left, op, right }),
124                    },
125                },
126                (op, inner) => Expr::UnaryOp {
127                    op: *op,
128                    expr: Box::new(inner),
129                },
130            }
131        }
132        Expr::BinaryOp { left, op, right } => {
133            let left = rewrite_expr(left);
134            let right = rewrite_expr(right);
135            rewrite_binary(left, *op, right)
136        }
137        Expr::IsNull { expr, negated } => Expr::IsNull {
138            expr: Box::new(rewrite_expr(expr)),
139            negated: *negated,
140        },
141        Expr::FunctionCall { name, args } => Expr::FunctionCall {
142            name: name.clone(),
143            args: args.iter().map(rewrite_expr).collect(),
144        },
145        Expr::WindowFunction { function, spec } => Expr::WindowFunction {
146            function: Box::new(rewrite_expr(function)),
147            spec: chryso_core::ast::WindowSpec {
148                partition_by: spec.partition_by.iter().map(rewrite_expr).collect(),
149                order_by: rewrite_order_by(&spec.order_by),
150                frame: spec.frame.clone(),
151            },
152        },
153        Expr::Subquery(select) => Expr::Subquery(select.clone()),
154        Expr::Exists(select) => Expr::Exists(select.clone()),
155        Expr::InSubquery { expr, subquery } => Expr::InSubquery {
156            expr: Box::new(rewrite_expr(expr)),
157            subquery: subquery.clone(),
158        },
159        Expr::Case {
160            operand,
161            when_then,
162            else_expr,
163        } => Expr::Case {
164            operand: operand.as_ref().map(|expr| Box::new(rewrite_expr(expr))),
165            when_then: when_then
166                .iter()
167                .map(|(when_expr, then_expr)| (rewrite_expr(when_expr), rewrite_expr(then_expr)))
168                .collect(),
169            else_expr: else_expr.as_ref().map(|expr| Box::new(rewrite_expr(expr))),
170        },
171        Expr::Wildcard => Expr::Wildcard,
172    }
173}
174
175fn rewrite_binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
176    if let Some(expr) = fold_bool_binary(&left, op, &right) {
177        return expr;
178    }
179    if let Some(expr) = fold_comparison(&left, op, &right) {
180        return expr;
181    }
182    if matches!(op, BinaryOperator::And | BinaryOperator::Or) && left.structural_eq(&right) {
183        return left;
184    }
185    match (op, &left, &right) {
186        (BinaryOperator::Add, Expr::Literal(Literal::Number(0.0)), _) => right,
187        (BinaryOperator::Add, _, Expr::Literal(Literal::Number(0.0))) => left,
188        (BinaryOperator::Sub, _, Expr::Literal(Literal::Number(0.0))) => left,
189        (BinaryOperator::Mul, Expr::Literal(Literal::Number(1.0)), _) => right,
190        (BinaryOperator::Mul, _, Expr::Literal(Literal::Number(1.0))) => left,
191        (BinaryOperator::Mul, Expr::Literal(Literal::Number(0.0)), _) => {
192            Expr::Literal(Literal::Number(0.0))
193        }
194        (BinaryOperator::Mul, _, Expr::Literal(Literal::Number(0.0))) => {
195            Expr::Literal(Literal::Number(0.0))
196        }
197        (BinaryOperator::Div, _, Expr::Literal(Literal::Number(1.0))) => left,
198        (
199            BinaryOperator::Add,
200            Expr::Literal(Literal::Number(a)),
201            Expr::Literal(Literal::Number(b)),
202        ) => Expr::Literal(Literal::Number(a + b)),
203        (
204            BinaryOperator::Sub,
205            Expr::Literal(Literal::Number(a)),
206            Expr::Literal(Literal::Number(b)),
207        ) => Expr::Literal(Literal::Number(a - b)),
208        (
209            BinaryOperator::Mul,
210            Expr::Literal(Literal::Number(a)),
211            Expr::Literal(Literal::Number(b)),
212        ) => Expr::Literal(Literal::Number(a * b)),
213        (
214            BinaryOperator::Div,
215            Expr::Literal(Literal::Number(a)),
216            Expr::Literal(Literal::Number(b)),
217        ) => {
218            if *b == 0.0 {
219                Expr::BinaryOp {
220                    left: Box::new(left),
221                    op,
222                    right: Box::new(right),
223                }
224            } else {
225                Expr::Literal(Literal::Number(a / b))
226            }
227        }
228        _ => Expr::BinaryOp {
229            left: Box::new(left),
230            op,
231            right: Box::new(right),
232        },
233    }
234}
235
236fn negate_expr(expr: Expr) -> Expr {
237    rewrite_expr(&Expr::UnaryOp {
238        op: UnaryOperator::Not,
239        expr: Box::new(expr),
240    })
241}
242
243fn fold_bool_binary(left: &Expr, op: BinaryOperator, right: &Expr) -> Option<Expr> {
244    let left_bool = match left {
245        Expr::Literal(Literal::Bool(value)) => Some(*value),
246        _ => None,
247    };
248    let right_bool = match right {
249        Expr::Literal(Literal::Bool(value)) => Some(*value),
250        _ => None,
251    };
252    match op {
253        BinaryOperator::And => match (left_bool, right_bool) {
254            (Some(true), _) => Some(right.clone()),
255            (Some(false), _) => Some(Expr::Literal(Literal::Bool(false))),
256            (_, Some(true)) => Some(left.clone()),
257            (_, Some(false)) => Some(Expr::Literal(Literal::Bool(false))),
258            _ => None,
259        },
260        BinaryOperator::Or => match (left_bool, right_bool) {
261            (Some(true), _) => Some(Expr::Literal(Literal::Bool(true))),
262            (Some(false), _) => Some(right.clone()),
263            (_, Some(true)) => Some(Expr::Literal(Literal::Bool(true))),
264            (_, Some(false)) => Some(left.clone()),
265            _ => None,
266        },
267        _ => None,
268    }
269}
270
271fn fold_comparison(left: &Expr, op: BinaryOperator, right: &Expr) -> Option<Expr> {
272    match (left, right) {
273        (Expr::Literal(Literal::Number(left)), Expr::Literal(Literal::Number(right))) => {
274            let result = match op {
275                BinaryOperator::Eq => Some(left == right),
276                BinaryOperator::NotEq => Some(left != right),
277                BinaryOperator::Lt => Some(left < right),
278                BinaryOperator::LtEq => Some(left <= right),
279                BinaryOperator::Gt => Some(left > right),
280                BinaryOperator::GtEq => Some(left >= right),
281                _ => None,
282            };
283            result.map(|value| Expr::Literal(Literal::Bool(value)))
284        }
285        (Expr::Literal(Literal::Bool(left)), Expr::Literal(Literal::Bool(right))) => {
286            let result = match op {
287                BinaryOperator::Eq => Some(left == right),
288                BinaryOperator::NotEq => Some(left != right),
289                _ => None,
290            };
291            result.map(|value| Expr::Literal(Literal::Bool(value)))
292        }
293        _ => None,
294    }
295}
296
297fn rewrite_order_by(order_by: &[OrderByExpr]) -> Vec<OrderByExpr> {
298    order_by
299        .iter()
300        .map(|item| OrderByExpr {
301            expr: rewrite_expr(&item.expr),
302            asc: item.asc,
303            nulls_first: item.nulls_first,
304        })
305        .collect()
306}
307
#[cfg(test)]
mod tests {
    // Unit tests for the rewriter: constant folding, boolean simplification, NOT
    // normalisation and rewriting of expressions embedded in plans.
    use super::{rewrite_expr, rewrite_plan};
    use chryso_core::ast::{BinaryOperator, Expr, Literal, UnaryOperator};
    use chryso_planner::LogicalPlan;

    #[test]
    fn folds_numeric_arithmetic() {
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Literal(Literal::Number(1.0))),
            op: BinaryOperator::Add,
            right: Box::new(Expr::BinaryOp {
                left: Box::new(Expr::Literal(Literal::Number(2.0))),
                op: BinaryOperator::Mul,
                right: Box::new(Expr::Literal(Literal::Number(3.0))),
            }),
        };
        let rewritten = rewrite_expr(&expr);
        match rewritten {
            Expr::Literal(Literal::Number(value)) => assert_eq!(value, 7.0),
            other => panic!("expected folded literal, got {other:?}"),
        }
    }

    #[test]
    fn folds_boolean_logic() {
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Identifier("a".to_string())),
            op: BinaryOperator::And,
            right: Box::new(Expr::Literal(Literal::Bool(true))),
        };
        let rewritten = rewrite_expr(&expr);
        match rewritten {
            Expr::Identifier(name) => assert_eq!(name, "a"),
            other => panic!("expected identifier, got {other:?}"),
        }

        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Identifier("a".to_string())),
            op: BinaryOperator::Or,
            right: Box::new(Expr::Literal(Literal::Bool(true))),
        };
        let rewritten = rewrite_expr(&expr);
        match rewritten {
            Expr::Literal(Literal::Bool(value)) => assert!(value),
            other => panic!("expected literal true, got {other:?}"),
        }
    }

    #[test]
    fn folds_boolean_comparisons() {
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Literal(Literal::Bool(true))),
            op: BinaryOperator::NotEq,
            right: Box::new(Expr::Literal(Literal::Bool(false))),
        };
        let rewritten = rewrite_expr(&expr);
        match rewritten {
            Expr::Literal(Literal::Bool(value)) => assert!(value),
            other => panic!("expected literal true, got {other:?}"),
        }
    }

    #[test]
    fn folds_numeric_comparisons() {
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Literal(Literal::Number(1.0))),
            op: BinaryOperator::Lt,
            right: Box::new(Expr::Literal(Literal::Number(2.0))),
        };
        let rewritten = rewrite_expr(&expr);
        match rewritten {
            Expr::Literal(Literal::Bool(value)) => assert!(value),
            other => panic!("expected literal true, got {other:?}"),
        }
    }

    #[test]
    fn folds_unary_literals() {
        // NOT TRUE -> FALSE
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: Box::new(Expr::Literal(Literal::Bool(true))),
        };
        match rewrite_expr(&expr) {
            Expr::Literal(Literal::Bool(value)) => assert!(!value),
            other => panic!("expected literal false, got {other:?}"),
        }

        // -(5) -> -5
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Neg,
            expr: Box::new(Expr::Literal(Literal::Number(5.0))),
        };
        match rewrite_expr(&expr) {
            Expr::Literal(Literal::Number(value)) => assert_eq!(value, -5.0),
            other => panic!("expected literal -5, got {other:?}"),
        }
    }

    #[test]
    fn normalizes_not() {
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: Box::new(Expr::UnaryOp {
                op: UnaryOperator::Not,
                expr: Box::new(Expr::Identifier("a".to_string())),
            }),
        };
        let rewritten = rewrite_expr(&expr);
        match rewritten {
            Expr::Identifier(name) => assert_eq!(name, "a"),
            other => panic!("expected identifier, got {other:?}"),
        }
    }

    #[test]
    fn negates_is_null() {
        // NOT (a IS NULL) -> a IS NOT NULL
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: Box::new(Expr::IsNull {
                expr: Box::new(Expr::Identifier("a".to_string())),
                negated: false,
            }),
        };
        match rewrite_expr(&expr) {
            Expr::IsNull { expr, negated } => {
                assert!(negated);
                match *expr {
                    Expr::Identifier(name) => assert_eq!(name, "a"),
                    other => panic!("expected identifier, got {other:?}"),
                }
            }
            other => panic!("expected IS NOT NULL, got {other:?}"),
        }
    }

    #[test]
    fn applies_de_morgan() {
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: Box::new(Expr::BinaryOp {
                left: Box::new(Expr::Identifier("a".to_string())),
                op: BinaryOperator::And,
                right: Box::new(Expr::Identifier("b".to_string())),
            }),
        };
        let rewritten = rewrite_expr(&expr);
        match rewritten {
            Expr::BinaryOp {
                op: BinaryOperator::Or,
                left,
                right,
            } => match (*left, *right) {
                (
                    Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    },
                    Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    },
                ) => {}
                other => panic!("expected negated operands, got {other:?}"),
            },
            other => panic!("expected OR, got {other:?}"),
        }
    }

    #[test]
    fn applies_de_morgan_to_or() {
        let expr = Expr::UnaryOp {
            op: UnaryOperator::Not,
            expr: Box::new(Expr::BinaryOp {
                left: Box::new(Expr::Identifier("a".to_string())),
                op: BinaryOperator::Or,
                right: Box::new(Expr::Identifier("b".to_string())),
            }),
        };
        match rewrite_expr(&expr) {
            Expr::BinaryOp {
                op: BinaryOperator::And,
                left,
                right,
            } => match (*left, *right) {
                (
                    Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    },
                    Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    },
                ) => {}
                other => panic!("expected negated operands, got {other:?}"),
            },
            other => panic!("expected AND, got {other:?}"),
        }
    }

    #[test]
    fn dedups_boolean_idempotence() {
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Identifier("a".to_string())),
            op: BinaryOperator::And,
            right: Box::new(Expr::Identifier("a".to_string())),
        };
        let rewritten = rewrite_expr(&expr);
        match rewritten {
            Expr::Identifier(name) => assert_eq!(name, "a"),
            other => panic!("expected identifier, got {other:?}"),
        }
    }

    #[test]
    fn leaves_division_by_zero_unfolded() {
        let expr = Expr::BinaryOp {
            left: Box::new(Expr::Literal(Literal::Number(1.0))),
            op: BinaryOperator::Div,
            right: Box::new(Expr::Literal(Literal::Number(0.0))),
        };
        match rewrite_expr(&expr) {
            Expr::BinaryOp {
                op: BinaryOperator::Div,
                ..
            } => {}
            other => panic!("expected unfolded division, got {other:?}"),
        }
    }

    #[test]
    fn rewrites_filter_predicate() {
        let plan = LogicalPlan::Filter {
            predicate: Expr::BinaryOp {
                left: Box::new(Expr::Literal(Literal::Number(10.0))),
                op: BinaryOperator::Sub,
                right: Box::new(Expr::Literal(Literal::Number(3.0))),
            },
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        let rewritten = rewrite_plan(&plan);
        match rewritten {
            LogicalPlan::Filter { predicate, .. } => match predicate {
                Expr::Literal(Literal::Number(value)) => assert_eq!(value, 7.0),
                other => panic!("expected folded literal, got {other:?}"),
            },
            other => panic!("unexpected plan: {other:?}"),
        }
    }
}