1use chryso_core::ast::{BinaryOperator, Expr, Literal, OrderByExpr, UnaryOperator};
2use chryso_planner::LogicalPlan;
3
4pub fn rewrite_plan(plan: &LogicalPlan) -> LogicalPlan {
5 match plan {
6 LogicalPlan::Scan { table } => LogicalPlan::Scan { table: table.clone() },
7 LogicalPlan::IndexScan {
8 table,
9 index,
10 predicate,
11 } => LogicalPlan::IndexScan {
12 table: table.clone(),
13 index: index.clone(),
14 predicate: rewrite_expr(predicate),
15 },
16 LogicalPlan::Dml { sql } => LogicalPlan::Dml { sql: sql.clone() },
17 LogicalPlan::Derived {
18 input,
19 alias,
20 column_aliases,
21 } => LogicalPlan::Derived {
22 input: Box::new(rewrite_plan(input.as_ref())),
23 alias: alias.clone(),
24 column_aliases: column_aliases.clone(),
25 },
26 LogicalPlan::Filter { predicate, input } => LogicalPlan::Filter {
27 predicate: rewrite_expr(predicate),
28 input: Box::new(rewrite_plan(input.as_ref())),
29 },
30 LogicalPlan::Projection { exprs, input } => LogicalPlan::Projection {
31 exprs: exprs.iter().map(rewrite_expr).collect(),
32 input: Box::new(rewrite_plan(input.as_ref())),
33 },
34 LogicalPlan::Join {
35 join_type,
36 left,
37 right,
38 on,
39 } => LogicalPlan::Join {
40 join_type: *join_type,
41 left: Box::new(rewrite_plan(left.as_ref())),
42 right: Box::new(rewrite_plan(right.as_ref())),
43 on: rewrite_expr(on),
44 },
45 LogicalPlan::Aggregate {
46 group_exprs,
47 aggr_exprs,
48 input,
49 } => LogicalPlan::Aggregate {
50 group_exprs: group_exprs.iter().map(rewrite_expr).collect(),
51 aggr_exprs: aggr_exprs.iter().map(rewrite_expr).collect(),
52 input: Box::new(rewrite_plan(input.as_ref())),
53 },
54 LogicalPlan::Distinct { input } => LogicalPlan::Distinct {
55 input: Box::new(rewrite_plan(input.as_ref())),
56 },
57 LogicalPlan::TopN {
58 order_by,
59 limit,
60 input,
61 } => LogicalPlan::TopN {
62 order_by: rewrite_order_by(order_by),
63 limit: *limit,
64 input: Box::new(rewrite_plan(input.as_ref())),
65 },
66 LogicalPlan::Sort { order_by, input } => LogicalPlan::Sort {
67 order_by: rewrite_order_by(order_by),
68 input: Box::new(rewrite_plan(input.as_ref())),
69 },
70 LogicalPlan::Limit {
71 limit,
72 offset,
73 input,
74 } => LogicalPlan::Limit {
75 limit: *limit,
76 offset: *offset,
77 input: Box::new(rewrite_plan(input.as_ref())),
78 },
79 }
80}
81
82pub fn rewrite_expr(expr: &Expr) -> Expr {
83 match expr {
84 Expr::Identifier(name) => Expr::Identifier(name.clone()),
85 Expr::Literal(Literal::String(value)) => Expr::Literal(Literal::String(value.clone())),
86 Expr::Literal(Literal::Number(value)) => Expr::Literal(Literal::Number(*value)),
87 Expr::Literal(Literal::Bool(value)) => Expr::Literal(Literal::Bool(*value)),
88 Expr::UnaryOp { op, expr } => {
89 let inner = rewrite_expr(expr);
90 match (op, inner) {
91 (UnaryOperator::Neg, Expr::Literal(Literal::Number(value))) => {
92 Expr::Literal(Literal::Number(-value))
93 }
94 (UnaryOperator::Not, Expr::Literal(Literal::Bool(value))) => {
95 Expr::Literal(Literal::Bool(!value))
96 }
97 (UnaryOperator::Not, Expr::UnaryOp { op: UnaryOperator::Not, expr }) => *expr,
98 (UnaryOperator::Not, Expr::IsNull { expr, negated }) => Expr::IsNull {
99 expr,
100 negated: !negated,
101 },
102 (UnaryOperator::Not, Expr::BinaryOp { left, op, right }) => match op {
103 BinaryOperator::And => Expr::BinaryOp {
104 left: Box::new(negate_expr(*left)),
105 op: BinaryOperator::Or,
106 right: Box::new(negate_expr(*right)),
107 },
108 BinaryOperator::Or => Expr::BinaryOp {
109 left: Box::new(negate_expr(*left)),
110 op: BinaryOperator::And,
111 right: Box::new(negate_expr(*right)),
112 },
113 _ => Expr::UnaryOp {
114 op: UnaryOperator::Not,
115 expr: Box::new(Expr::BinaryOp { left, op, right }),
116 },
117 },
118 (op, inner) => Expr::UnaryOp {
119 op: *op,
120 expr: Box::new(inner),
121 },
122 }
123 }
124 Expr::BinaryOp { left, op, right } => {
125 let left = rewrite_expr(left);
126 let right = rewrite_expr(right);
127 rewrite_binary(left, *op, right)
128 }
129 Expr::IsNull { expr, negated } => Expr::IsNull {
130 expr: Box::new(rewrite_expr(expr)),
131 negated: *negated,
132 },
133 Expr::FunctionCall { name, args } => Expr::FunctionCall {
134 name: name.clone(),
135 args: args.iter().map(rewrite_expr).collect(),
136 },
137 Expr::WindowFunction { function, spec } => Expr::WindowFunction {
138 function: Box::new(rewrite_expr(function)),
139 spec: chryso_core::ast::WindowSpec {
140 partition_by: spec.partition_by.iter().map(rewrite_expr).collect(),
141 order_by: rewrite_order_by(&spec.order_by),
142 },
143 },
144 Expr::Subquery(select) => Expr::Subquery(select.clone()),
145 Expr::Exists(select) => Expr::Exists(select.clone()),
146 Expr::InSubquery { expr, subquery } => Expr::InSubquery {
147 expr: Box::new(rewrite_expr(expr)),
148 subquery: subquery.clone(),
149 },
150 Expr::Case {
151 operand,
152 when_then,
153 else_expr,
154 } => Expr::Case {
155 operand: operand.as_ref().map(|expr| Box::new(rewrite_expr(expr))),
156 when_then: when_then
157 .iter()
158 .map(|(when_expr, then_expr)| (rewrite_expr(when_expr), rewrite_expr(then_expr)))
159 .collect(),
160 else_expr: else_expr.as_ref().map(|expr| Box::new(rewrite_expr(expr))),
161 },
162 Expr::Wildcard => Expr::Wildcard,
163 }
164}
165
166fn rewrite_binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
167 if let Some(expr) = fold_bool_binary(&left, op, &right) {
168 return expr;
169 }
170 if let Some(expr) = fold_comparison(&left, op, &right) {
171 return expr;
172 }
173 if matches!(op, BinaryOperator::And | BinaryOperator::Or) && left.structural_eq(&right) {
174 return left;
175 }
176 match (op, &left, &right) {
177 (BinaryOperator::Add, Expr::Literal(Literal::Number(0.0)), _) => right,
178 (BinaryOperator::Add, _, Expr::Literal(Literal::Number(0.0))) => left,
179 (BinaryOperator::Sub, _, Expr::Literal(Literal::Number(0.0))) => left,
180 (BinaryOperator::Mul, Expr::Literal(Literal::Number(1.0)), _) => right,
181 (BinaryOperator::Mul, _, Expr::Literal(Literal::Number(1.0))) => left,
182 (BinaryOperator::Mul, Expr::Literal(Literal::Number(0.0)), _) => {
183 Expr::Literal(Literal::Number(0.0))
184 }
185 (BinaryOperator::Mul, _, Expr::Literal(Literal::Number(0.0))) => {
186 Expr::Literal(Literal::Number(0.0))
187 }
188 (BinaryOperator::Div, _, Expr::Literal(Literal::Number(1.0))) => left,
189 (BinaryOperator::Add, Expr::Literal(Literal::Number(a)), Expr::Literal(Literal::Number(b))) => {
190 Expr::Literal(Literal::Number(a + b))
191 }
192 (BinaryOperator::Sub, Expr::Literal(Literal::Number(a)), Expr::Literal(Literal::Number(b))) => {
193 Expr::Literal(Literal::Number(a - b))
194 }
195 (BinaryOperator::Mul, Expr::Literal(Literal::Number(a)), Expr::Literal(Literal::Number(b))) => {
196 Expr::Literal(Literal::Number(a * b))
197 }
198 (BinaryOperator::Div, Expr::Literal(Literal::Number(a)), Expr::Literal(Literal::Number(b))) => {
199 if *b == 0.0 {
200 Expr::BinaryOp {
201 left: Box::new(left),
202 op,
203 right: Box::new(right),
204 }
205 } else {
206 Expr::Literal(Literal::Number(a / b))
207 }
208 }
209 _ => Expr::BinaryOp {
210 left: Box::new(left),
211 op,
212 right: Box::new(right),
213 },
214 }
215}
216
217fn negate_expr(expr: Expr) -> Expr {
218 rewrite_expr(&Expr::UnaryOp {
219 op: UnaryOperator::Not,
220 expr: Box::new(expr),
221 })
222}
223
224fn fold_bool_binary(left: &Expr, op: BinaryOperator, right: &Expr) -> Option<Expr> {
225 let left_bool = match left {
226 Expr::Literal(Literal::Bool(value)) => Some(*value),
227 _ => None,
228 };
229 let right_bool = match right {
230 Expr::Literal(Literal::Bool(value)) => Some(*value),
231 _ => None,
232 };
233 match op {
234 BinaryOperator::And => match (left_bool, right_bool) {
235 (Some(true), _) => Some(right.clone()),
236 (Some(false), _) => Some(Expr::Literal(Literal::Bool(false))),
237 (_, Some(true)) => Some(left.clone()),
238 (_, Some(false)) => Some(Expr::Literal(Literal::Bool(false))),
239 _ => None,
240 },
241 BinaryOperator::Or => match (left_bool, right_bool) {
242 (Some(true), _) => Some(Expr::Literal(Literal::Bool(true))),
243 (Some(false), _) => Some(right.clone()),
244 (_, Some(true)) => Some(Expr::Literal(Literal::Bool(true))),
245 (_, Some(false)) => Some(left.clone()),
246 _ => None,
247 },
248 _ => None,
249 }
250}
251
252fn fold_comparison(left: &Expr, op: BinaryOperator, right: &Expr) -> Option<Expr> {
253 match (left, right) {
254 (Expr::Literal(Literal::Number(left)), Expr::Literal(Literal::Number(right))) => {
255 let result = match op {
256 BinaryOperator::Eq => Some(left == right),
257 BinaryOperator::NotEq => Some(left != right),
258 BinaryOperator::Lt => Some(left < right),
259 BinaryOperator::LtEq => Some(left <= right),
260 BinaryOperator::Gt => Some(left > right),
261 BinaryOperator::GtEq => Some(left >= right),
262 _ => None,
263 };
264 result.map(|value| Expr::Literal(Literal::Bool(value)))
265 }
266 (Expr::Literal(Literal::Bool(left)), Expr::Literal(Literal::Bool(right))) => {
267 let result = match op {
268 BinaryOperator::Eq => Some(left == right),
269 BinaryOperator::NotEq => Some(left != right),
270 _ => None,
271 };
272 result.map(|value| Expr::Literal(Literal::Bool(value)))
273 }
274 _ => None,
275 }
276}
277
278fn rewrite_order_by(order_by: &[OrderByExpr]) -> Vec<OrderByExpr> {
279 order_by
280 .iter()
281 .map(|item| OrderByExpr {
282 expr: rewrite_expr(&item.expr),
283 asc: item.asc,
284 nulls_first: item.nulls_first,
285 })
286 .collect()
287}
288
#[cfg(test)]
mod tests {
    use super::{rewrite_expr, rewrite_plan};
    use chryso_core::ast::{BinaryOperator, Expr, Literal, UnaryOperator};
    use chryso_planner::LogicalPlan;

    // Small constructors to keep the expression trees below readable.

    fn ident(name: &str) -> Expr {
        Expr::Identifier(name.to_string())
    }

    fn num(value: f64) -> Expr {
        Expr::Literal(Literal::Number(value))
    }

    fn boolean(value: bool) -> Expr {
        Expr::Literal(Literal::Bool(value))
    }

    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    fn unary(op: UnaryOperator, expr: Expr) -> Expr {
        Expr::UnaryOp {
            op,
            expr: Box::new(expr),
        }
    }

    #[test]
    fn folds_numeric_arithmetic() {
        let expr = binary(
            num(1.0),
            BinaryOperator::Add,
            binary(num(2.0), BinaryOperator::Mul, num(3.0)),
        );
        match rewrite_expr(&expr) {
            Expr::Literal(Literal::Number(value)) => assert_eq!(value, 7.0),
            other => panic!("expected folded literal, got {other:?}"),
        }
    }

    #[test]
    fn folds_boolean_logic() {
        let expr = binary(ident("a"), BinaryOperator::And, boolean(true));
        match rewrite_expr(&expr) {
            Expr::Identifier(name) => assert_eq!(name, "a"),
            other => panic!("expected identifier, got {other:?}"),
        }

        let expr = binary(ident("a"), BinaryOperator::Or, boolean(true));
        match rewrite_expr(&expr) {
            Expr::Literal(Literal::Bool(value)) => assert!(value),
            other => panic!("expected literal true, got {other:?}"),
        }
    }

    #[test]
    fn false_and_short_circuits() {
        let expr = binary(boolean(false), BinaryOperator::And, ident("a"));
        match rewrite_expr(&expr) {
            Expr::Literal(Literal::Bool(value)) => assert!(!value),
            other => panic!("expected literal false, got {other:?}"),
        }
    }

    #[test]
    fn folds_boolean_comparisons() {
        let expr = binary(boolean(true), BinaryOperator::NotEq, boolean(false));
        match rewrite_expr(&expr) {
            Expr::Literal(Literal::Bool(value)) => assert!(value),
            other => panic!("expected literal true, got {other:?}"),
        }
    }

    #[test]
    fn folds_numeric_comparisons() {
        let expr = binary(num(1.0), BinaryOperator::Lt, num(2.0));
        match rewrite_expr(&expr) {
            Expr::Literal(Literal::Bool(value)) => assert!(value),
            other => panic!("expected literal true, got {other:?}"),
        }
    }

    #[test]
    fn keeps_division_by_zero() {
        let expr = binary(num(1.0), BinaryOperator::Div, num(0.0));
        match rewrite_expr(&expr) {
            Expr::BinaryOp {
                op: BinaryOperator::Div,
                ..
            } => {}
            other => panic!("expected unfolded division, got {other:?}"),
        }
    }

    #[test]
    fn folds_unary_negation_of_literal() {
        let expr = unary(UnaryOperator::Neg, num(5.0));
        match rewrite_expr(&expr) {
            Expr::Literal(Literal::Number(value)) => assert_eq!(value, -5.0),
            other => panic!("expected folded literal, got {other:?}"),
        }
    }

    #[test]
    fn normalizes_not() {
        let expr = unary(UnaryOperator::Not, unary(UnaryOperator::Not, ident("a")));
        match rewrite_expr(&expr) {
            Expr::Identifier(name) => assert_eq!(name, "a"),
            other => panic!("expected identifier, got {other:?}"),
        }
    }

    #[test]
    fn not_flips_is_null() {
        let expr = unary(
            UnaryOperator::Not,
            Expr::IsNull {
                expr: Box::new(ident("a")),
                negated: false,
            },
        );
        match rewrite_expr(&expr) {
            Expr::IsNull { expr, negated } => {
                assert!(negated);
                match *expr {
                    Expr::Identifier(name) => assert_eq!(name, "a"),
                    other => panic!("expected identifier, got {other:?}"),
                }
            }
            other => panic!("expected IS NOT NULL, got {other:?}"),
        }
    }

    #[test]
    fn applies_de_morgan() {
        let expr = unary(
            UnaryOperator::Not,
            binary(ident("a"), BinaryOperator::And, ident("b")),
        );
        match rewrite_expr(&expr) {
            Expr::BinaryOp {
                op: BinaryOperator::Or,
                left,
                right,
            } => match (*left, *right) {
                (
                    Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    },
                    Expr::UnaryOp {
                        op: UnaryOperator::Not,
                        ..
                    },
                ) => {}
                other => panic!("expected negated operands, got {other:?}"),
            },
            other => panic!("expected OR, got {other:?}"),
        }
    }

    #[test]
    fn dedups_boolean_idempotence() {
        let expr = binary(ident("a"), BinaryOperator::And, ident("a"));
        match rewrite_expr(&expr) {
            Expr::Identifier(name) => assert_eq!(name, "a"),
            other => panic!("expected identifier, got {other:?}"),
        }
    }

    #[test]
    fn rewrites_filter_predicate() {
        let plan = LogicalPlan::Filter {
            predicate: binary(num(10.0), BinaryOperator::Sub, num(3.0)),
            input: Box::new(LogicalPlan::Scan {
                table: "t".to_string(),
            }),
        };
        match rewrite_plan(&plan) {
            LogicalPlan::Filter { predicate, .. } => match predicate {
                Expr::Literal(Literal::Number(value)) => assert_eq!(value, 7.0),
                other => panic!("expected folded literal, got {other:?}"),
            },
            other => panic!("unexpected plan: {other:?}"),
        }
    }
}