1use chryso_core::ast::{BinaryOperator, Expr, Literal, OrderByExpr, UnaryOperator};
2use chryso_planner::LogicalPlan;
3
4pub fn rewrite_plan(plan: &LogicalPlan) -> LogicalPlan {
5 match plan {
6 LogicalPlan::Scan { table } => LogicalPlan::Scan {
7 table: table.clone(),
8 },
9 LogicalPlan::IndexScan {
10 table,
11 index,
12 predicate,
13 } => LogicalPlan::IndexScan {
14 table: table.clone(),
15 index: index.clone(),
16 predicate: rewrite_expr(predicate),
17 },
18 LogicalPlan::Dml { sql } => LogicalPlan::Dml { sql: sql.clone() },
19 LogicalPlan::Derived {
20 input,
21 alias,
22 column_aliases,
23 } => LogicalPlan::Derived {
24 input: Box::new(rewrite_plan(input.as_ref())),
25 alias: alias.clone(),
26 column_aliases: column_aliases.clone(),
27 },
28 LogicalPlan::Filter { predicate, input } => LogicalPlan::Filter {
29 predicate: rewrite_expr(predicate),
30 input: Box::new(rewrite_plan(input.as_ref())),
31 },
32 LogicalPlan::Projection { exprs, input } => LogicalPlan::Projection {
33 exprs: exprs.iter().map(rewrite_expr).collect(),
34 input: Box::new(rewrite_plan(input.as_ref())),
35 },
36 LogicalPlan::Join {
37 join_type,
38 left,
39 right,
40 on,
41 } => LogicalPlan::Join {
42 join_type: *join_type,
43 left: Box::new(rewrite_plan(left.as_ref())),
44 right: Box::new(rewrite_plan(right.as_ref())),
45 on: rewrite_expr(on),
46 },
47 LogicalPlan::Aggregate {
48 group_exprs,
49 aggr_exprs,
50 input,
51 } => LogicalPlan::Aggregate {
52 group_exprs: group_exprs.iter().map(rewrite_expr).collect(),
53 aggr_exprs: aggr_exprs.iter().map(rewrite_expr).collect(),
54 input: Box::new(rewrite_plan(input.as_ref())),
55 },
56 LogicalPlan::Distinct { input } => LogicalPlan::Distinct {
57 input: Box::new(rewrite_plan(input.as_ref())),
58 },
59 LogicalPlan::TopN {
60 order_by,
61 limit,
62 input,
63 } => LogicalPlan::TopN {
64 order_by: rewrite_order_by(order_by),
65 limit: *limit,
66 input: Box::new(rewrite_plan(input.as_ref())),
67 },
68 LogicalPlan::Sort { order_by, input } => LogicalPlan::Sort {
69 order_by: rewrite_order_by(order_by),
70 input: Box::new(rewrite_plan(input.as_ref())),
71 },
72 LogicalPlan::Limit {
73 limit,
74 offset,
75 input,
76 } => LogicalPlan::Limit {
77 limit: *limit,
78 offset: *offset,
79 input: Box::new(rewrite_plan(input.as_ref())),
80 },
81 }
82}
83
84pub fn rewrite_expr(expr: &Expr) -> Expr {
85 match expr {
86 Expr::Identifier(name) => Expr::Identifier(name.clone()),
87 Expr::Literal(Literal::String(value)) => Expr::Literal(Literal::String(value.clone())),
88 Expr::Literal(Literal::Number(value)) => Expr::Literal(Literal::Number(*value)),
89 Expr::Literal(Literal::Bool(value)) => Expr::Literal(Literal::Bool(*value)),
90 Expr::UnaryOp { op, expr } => {
91 let inner = rewrite_expr(expr);
92 match (op, inner) {
93 (UnaryOperator::Neg, Expr::Literal(Literal::Number(value))) => {
94 Expr::Literal(Literal::Number(-value))
95 }
96 (UnaryOperator::Not, Expr::Literal(Literal::Bool(value))) => {
97 Expr::Literal(Literal::Bool(!value))
98 }
99 (
100 UnaryOperator::Not,
101 Expr::UnaryOp {
102 op: UnaryOperator::Not,
103 expr,
104 },
105 ) => *expr,
106 (UnaryOperator::Not, Expr::IsNull { expr, negated }) => Expr::IsNull {
107 expr,
108 negated: !negated,
109 },
110 (UnaryOperator::Not, Expr::BinaryOp { left, op, right }) => match op {
111 BinaryOperator::And => Expr::BinaryOp {
112 left: Box::new(negate_expr(*left)),
113 op: BinaryOperator::Or,
114 right: Box::new(negate_expr(*right)),
115 },
116 BinaryOperator::Or => Expr::BinaryOp {
117 left: Box::new(negate_expr(*left)),
118 op: BinaryOperator::And,
119 right: Box::new(negate_expr(*right)),
120 },
121 _ => Expr::UnaryOp {
122 op: UnaryOperator::Not,
123 expr: Box::new(Expr::BinaryOp { left, op, right }),
124 },
125 },
126 (op, inner) => Expr::UnaryOp {
127 op: *op,
128 expr: Box::new(inner),
129 },
130 }
131 }
132 Expr::BinaryOp { left, op, right } => {
133 let left = rewrite_expr(left);
134 let right = rewrite_expr(right);
135 rewrite_binary(left, *op, right)
136 }
137 Expr::IsNull { expr, negated } => Expr::IsNull {
138 expr: Box::new(rewrite_expr(expr)),
139 negated: *negated,
140 },
141 Expr::FunctionCall { name, args } => Expr::FunctionCall {
142 name: name.clone(),
143 args: args.iter().map(rewrite_expr).collect(),
144 },
145 Expr::WindowFunction { function, spec } => Expr::WindowFunction {
146 function: Box::new(rewrite_expr(function)),
147 spec: chryso_core::ast::WindowSpec {
148 partition_by: spec.partition_by.iter().map(rewrite_expr).collect(),
149 order_by: rewrite_order_by(&spec.order_by),
150 frame: spec.frame.clone(),
151 },
152 },
153 Expr::Subquery(select) => Expr::Subquery(select.clone()),
154 Expr::Exists(select) => Expr::Exists(select.clone()),
155 Expr::InSubquery { expr, subquery } => Expr::InSubquery {
156 expr: Box::new(rewrite_expr(expr)),
157 subquery: subquery.clone(),
158 },
159 Expr::Case {
160 operand,
161 when_then,
162 else_expr,
163 } => Expr::Case {
164 operand: operand.as_ref().map(|expr| Box::new(rewrite_expr(expr))),
165 when_then: when_then
166 .iter()
167 .map(|(when_expr, then_expr)| (rewrite_expr(when_expr), rewrite_expr(then_expr)))
168 .collect(),
169 else_expr: else_expr.as_ref().map(|expr| Box::new(rewrite_expr(expr))),
170 },
171 Expr::Wildcard => Expr::Wildcard,
172 }
173}
174
175fn rewrite_binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
176 if let Some(expr) = fold_bool_binary(&left, op, &right) {
177 return expr;
178 }
179 if let Some(expr) = fold_comparison(&left, op, &right) {
180 return expr;
181 }
182 if matches!(op, BinaryOperator::And | BinaryOperator::Or) && left.structural_eq(&right) {
183 return left;
184 }
185 match (op, &left, &right) {
186 (BinaryOperator::Add, Expr::Literal(Literal::Number(0.0)), _) => right,
187 (BinaryOperator::Add, _, Expr::Literal(Literal::Number(0.0))) => left,
188 (BinaryOperator::Sub, _, Expr::Literal(Literal::Number(0.0))) => left,
189 (BinaryOperator::Mul, Expr::Literal(Literal::Number(1.0)), _) => right,
190 (BinaryOperator::Mul, _, Expr::Literal(Literal::Number(1.0))) => left,
191 (BinaryOperator::Mul, Expr::Literal(Literal::Number(0.0)), _) => {
192 Expr::Literal(Literal::Number(0.0))
193 }
194 (BinaryOperator::Mul, _, Expr::Literal(Literal::Number(0.0))) => {
195 Expr::Literal(Literal::Number(0.0))
196 }
197 (BinaryOperator::Div, _, Expr::Literal(Literal::Number(1.0))) => left,
198 (
199 BinaryOperator::Add,
200 Expr::Literal(Literal::Number(a)),
201 Expr::Literal(Literal::Number(b)),
202 ) => Expr::Literal(Literal::Number(a + b)),
203 (
204 BinaryOperator::Sub,
205 Expr::Literal(Literal::Number(a)),
206 Expr::Literal(Literal::Number(b)),
207 ) => Expr::Literal(Literal::Number(a - b)),
208 (
209 BinaryOperator::Mul,
210 Expr::Literal(Literal::Number(a)),
211 Expr::Literal(Literal::Number(b)),
212 ) => Expr::Literal(Literal::Number(a * b)),
213 (
214 BinaryOperator::Div,
215 Expr::Literal(Literal::Number(a)),
216 Expr::Literal(Literal::Number(b)),
217 ) => {
218 if *b == 0.0 {
219 Expr::BinaryOp {
220 left: Box::new(left),
221 op,
222 right: Box::new(right),
223 }
224 } else {
225 Expr::Literal(Literal::Number(a / b))
226 }
227 }
228 _ => Expr::BinaryOp {
229 left: Box::new(left),
230 op,
231 right: Box::new(right),
232 },
233 }
234}
235
236fn negate_expr(expr: Expr) -> Expr {
237 rewrite_expr(&Expr::UnaryOp {
238 op: UnaryOperator::Not,
239 expr: Box::new(expr),
240 })
241}
242
243fn fold_bool_binary(left: &Expr, op: BinaryOperator, right: &Expr) -> Option<Expr> {
244 let left_bool = match left {
245 Expr::Literal(Literal::Bool(value)) => Some(*value),
246 _ => None,
247 };
248 let right_bool = match right {
249 Expr::Literal(Literal::Bool(value)) => Some(*value),
250 _ => None,
251 };
252 match op {
253 BinaryOperator::And => match (left_bool, right_bool) {
254 (Some(true), _) => Some(right.clone()),
255 (Some(false), _) => Some(Expr::Literal(Literal::Bool(false))),
256 (_, Some(true)) => Some(left.clone()),
257 (_, Some(false)) => Some(Expr::Literal(Literal::Bool(false))),
258 _ => None,
259 },
260 BinaryOperator::Or => match (left_bool, right_bool) {
261 (Some(true), _) => Some(Expr::Literal(Literal::Bool(true))),
262 (Some(false), _) => Some(right.clone()),
263 (_, Some(true)) => Some(Expr::Literal(Literal::Bool(true))),
264 (_, Some(false)) => Some(left.clone()),
265 _ => None,
266 },
267 _ => None,
268 }
269}
270
271fn fold_comparison(left: &Expr, op: BinaryOperator, right: &Expr) -> Option<Expr> {
272 match (left, right) {
273 (Expr::Literal(Literal::Number(left)), Expr::Literal(Literal::Number(right))) => {
274 let result = match op {
275 BinaryOperator::Eq => Some(left == right),
276 BinaryOperator::NotEq => Some(left != right),
277 BinaryOperator::Lt => Some(left < right),
278 BinaryOperator::LtEq => Some(left <= right),
279 BinaryOperator::Gt => Some(left > right),
280 BinaryOperator::GtEq => Some(left >= right),
281 _ => None,
282 };
283 result.map(|value| Expr::Literal(Literal::Bool(value)))
284 }
285 (Expr::Literal(Literal::Bool(left)), Expr::Literal(Literal::Bool(right))) => {
286 let result = match op {
287 BinaryOperator::Eq => Some(left == right),
288 BinaryOperator::NotEq => Some(left != right),
289 _ => None,
290 };
291 result.map(|value| Expr::Literal(Literal::Bool(value)))
292 }
293 _ => None,
294 }
295}
296
297fn rewrite_order_by(order_by: &[OrderByExpr]) -> Vec<OrderByExpr> {
298 order_by
299 .iter()
300 .map(|item| OrderByExpr {
301 expr: rewrite_expr(&item.expr),
302 asc: item.asc,
303 nulls_first: item.nulls_first,
304 })
305 .collect()
306}
307
#[cfg(test)]
mod tests {
    use super::{rewrite_expr, rewrite_plan};
    use chryso_core::ast::{BinaryOperator, Expr, Literal, UnaryOperator};
    use chryso_planner::LogicalPlan;

    // --- Constructors: keep the expression trees in each test readable. ---

    fn num(value: f64) -> Expr {
        Expr::Literal(Literal::Number(value))
    }

    fn boolean(value: bool) -> Expr {
        Expr::Literal(Literal::Bool(value))
    }

    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    fn unary(op: UnaryOperator, expr: Expr) -> Expr {
        Expr::UnaryOp {
            op,
            expr: Box::new(expr),
        }
    }

    // --- Assertions on the shape of a rewritten expression. ---

    fn assert_number(expr: Expr, expected: f64) {
        match expr {
            Expr::Literal(Literal::Number(value)) => assert_eq!(value, expected),
            other => panic!("expected numeric literal {expected}, got {other:?}"),
        }
    }

    fn assert_bool(expr: Expr, expected: bool) {
        match expr {
            Expr::Literal(Literal::Bool(value)) => assert_eq!(value, expected),
            other => panic!("expected boolean literal {expected}, got {other:?}"),
        }
    }

    fn assert_identifier(expr: Expr, expected: &str) {
        match expr {
            Expr::Identifier(name) => assert_eq!(name, expected),
            other => panic!("expected identifier {expected}, got {other:?}"),
        }
    }

    fn assert_negated_identifier(expr: Expr, expected: &str) {
        match expr {
            Expr::UnaryOp {
                op: UnaryOperator::Not,
                expr,
            } => assert_identifier(*expr, expected),
            other => panic!("expected NOT {expected}, got {other:?}"),
        }
    }

    #[test]
    fn folds_numeric_arithmetic() {
        let expr = binary(
            num(1.0),
            BinaryOperator::Add,
            binary(num(2.0), BinaryOperator::Mul, num(3.0)),
        );
        assert_number(rewrite_expr(&expr), 7.0);
    }

    #[test]
    fn folds_boolean_logic() {
        let and_true = binary(ident("a"), BinaryOperator::And, boolean(true));
        assert_identifier(rewrite_expr(&and_true), "a");

        let or_true = binary(ident("a"), BinaryOperator::Or, boolean(true));
        assert_bool(rewrite_expr(&or_true), true);
    }

    #[test]
    fn folds_boolean_comparisons() {
        let expr = binary(boolean(true), BinaryOperator::NotEq, boolean(false));
        assert_bool(rewrite_expr(&expr), true);
    }

    #[test]
    fn folds_numeric_comparisons() {
        let expr = binary(num(1.0), BinaryOperator::Lt, num(2.0));
        assert_bool(rewrite_expr(&expr), true);
    }

    #[test]
    fn folds_negated_literals() {
        assert_number(rewrite_expr(&unary(UnaryOperator::Neg, num(5.0))), -5.0);
        assert_bool(rewrite_expr(&unary(UnaryOperator::Not, boolean(true))), false);
    }

    #[test]
    fn normalizes_not() {
        let expr = unary(UnaryOperator::Not, unary(UnaryOperator::Not, ident("a")));
        assert_identifier(rewrite_expr(&expr), "a");
    }

    #[test]
    fn pushes_not_into_is_null() {
        let expr = unary(
            UnaryOperator::Not,
            Expr::IsNull {
                expr: Box::new(ident("a")),
                negated: false,
            },
        );
        match rewrite_expr(&expr) {
            Expr::IsNull {
                expr,
                negated: true,
            } => assert_identifier(*expr, "a"),
            other => panic!("expected a IS NOT NULL, got {other:?}"),
        }
    }

    #[test]
    fn applies_de_morgan() {
        let expr = unary(
            UnaryOperator::Not,
            binary(ident("a"), BinaryOperator::And, ident("b")),
        );
        match rewrite_expr(&expr) {
            Expr::BinaryOp {
                op: BinaryOperator::Or,
                left,
                right,
            } => {
                assert_negated_identifier(*left, "a");
                assert_negated_identifier(*right, "b");
            }
            other => panic!("expected OR, got {other:?}"),
        }
    }

    #[test]
    fn applies_de_morgan_to_or() {
        let expr = unary(
            UnaryOperator::Not,
            binary(ident("a"), BinaryOperator::Or, ident("b")),
        );
        match rewrite_expr(&expr) {
            Expr::BinaryOp {
                op: BinaryOperator::And,
                left,
                right,
            } => {
                assert_negated_identifier(*left, "a");
                assert_negated_identifier(*right, "b");
            }
            other => panic!("expected AND, got {other:?}"),
        }
    }

    #[test]
    fn applies_arithmetic_identities() {
        let cases = [
            binary(ident("a"), BinaryOperator::Add, num(0.0)),
            binary(num(0.0), BinaryOperator::Add, ident("a")),
            binary(ident("a"), BinaryOperator::Sub, num(0.0)),
            binary(ident("a"), BinaryOperator::Mul, num(1.0)),
            binary(num(1.0), BinaryOperator::Mul, ident("a")),
            binary(ident("a"), BinaryOperator::Div, num(1.0)),
        ];
        for expr in &cases {
            assert_identifier(rewrite_expr(expr), "a");
        }
    }

    #[test]
    fn keeps_division_by_literal_zero() {
        let expr = binary(num(1.0), BinaryOperator::Div, num(0.0));
        let rewritten = rewrite_expr(&expr);
        assert!(
            matches!(
                rewritten,
                Expr::BinaryOp {
                    op: BinaryOperator::Div,
                    ..
                }
            ),
            "expected unfolded division, got {rewritten:?}"
        );
    }

    #[test]
    fn dedups_boolean_idempotence() {
        let expr = binary(ident("a"), BinaryOperator::And, ident("a"));
        assert_identifier(rewrite_expr(&expr), "a");
    }

    #[test]
    fn rewrites_filter_predicate() {
        let plan = LogicalPlan::Filter {
            predicate: binary(num(10.0), BinaryOperator::Sub, num(3.0)),
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        match rewrite_plan(&plan) {
            LogicalPlan::Filter { predicate, .. } => assert_number(predicate, 7.0),
            other => panic!("unexpected plan: {other:?}"),
        }
    }
}