1use crate::error::{QueryError, Result};
18use crate::parser::ast::*;
19use oxigdal_core::error::OxiGdalError;
20use std::collections::HashMap;
21
22use super::OptimizationRule;
23
/// Largest number of distinct candidate projection expressions the CSE rule
/// will track; a projection that yields more candidates makes the rule fail
/// with an optimization error instead of continuing.
const MAX_CSE_CANDIDATES: usize = 1_000;
26
/// Optimization rule that looks for expressions in the WHERE, GROUP BY,
/// HAVING and ORDER BY clauses whose rendered text equals that of a
/// SELECT-list expression, and rewrites them into a column reference to
/// that projection item's alias.
///
/// The rule carries no state; all work happens in its `OptimizationRule`
/// implementation.
pub struct CommonSubexpressionElimination;
29
30impl OptimizationRule for CommonSubexpressionElimination {
31 fn apply(&self, mut stmt: SelectStatement) -> Result<SelectStatement> {
32 let mut proj_registry: HashMap<String, (usize, Option<String>)> = HashMap::new();
35
36 for (idx, item) in stmt.projection.iter().enumerate() {
37 if let SelectItem::Expr { expr, alias } = item {
38 if is_cse_candidate(expr) {
39 let key = format!("{}", expr);
40 proj_registry.insert(key, (idx, alias.clone()));
41 }
42 }
43 }
44
45 if proj_registry.len() > MAX_CSE_CANDIDATES {
47 return Err(QueryError::optimization(
48 OxiGdalError::invalid_operation_builder("Too many CSE candidates in query")
49 .with_operation("common_subexpression_elimination")
50 .with_parameter("candidate_count", proj_registry.len().to_string())
51 .with_parameter("max_allowed", MAX_CSE_CANDIDATES.to_string())
52 .with_suggestion(
53 "Simplify the query or reduce the number of complex expressions in SELECT",
54 )
55 .build()
56 .to_string(),
57 ));
58 }
59
60 if proj_registry.is_empty() {
61 return Ok(stmt);
62 }
63
64 let mut replacement_map: HashMap<String, String> = HashMap::new();
66 let mut proj_alias_assignments: HashMap<usize, String> = HashMap::new();
67 let mut next_cse_id = 0usize;
68
69 if let Some(ref selection) = stmt.selection {
70 detect_cse_matches(
71 selection,
72 &proj_registry,
73 &mut replacement_map,
74 &mut proj_alias_assignments,
75 &mut next_cse_id,
76 );
77 }
78 for expr in &stmt.group_by {
79 detect_cse_matches(
80 expr,
81 &proj_registry,
82 &mut replacement_map,
83 &mut proj_alias_assignments,
84 &mut next_cse_id,
85 );
86 }
87 if let Some(ref having) = stmt.having {
88 detect_cse_matches(
89 having,
90 &proj_registry,
91 &mut replacement_map,
92 &mut proj_alias_assignments,
93 &mut next_cse_id,
94 );
95 }
96 for order in &stmt.order_by {
97 detect_cse_matches(
98 &order.expr,
99 &proj_registry,
100 &mut replacement_map,
101 &mut proj_alias_assignments,
102 &mut next_cse_id,
103 );
104 }
105
106 if replacement_map.is_empty() {
107 return Ok(stmt);
108 }
109
110 for (idx, alias_name) in &proj_alias_assignments {
112 if let Some(SelectItem::Expr { alias, .. }) = stmt.projection.get_mut(*idx) {
113 if alias.is_none() {
114 *alias = Some(alias_name.clone());
115 }
116 }
117 }
118
119 if let Some(selection) = stmt.selection.take() {
121 stmt.selection = Some(replace_cse(selection, &replacement_map));
122 }
123 stmt.group_by = stmt
124 .group_by
125 .into_iter()
126 .map(|expr| replace_cse(expr, &replacement_map))
127 .collect();
128 if let Some(having) = stmt.having.take() {
129 stmt.having = Some(replace_cse(having, &replacement_map));
130 }
131 stmt.order_by = stmt
132 .order_by
133 .into_iter()
134 .map(|order| OrderByExpr {
135 expr: replace_cse(order.expr, &replacement_map),
136 asc: order.asc,
137 nulls_first: order.nulls_first,
138 })
139 .collect();
140
141 Ok(stmt)
142 }
143}
144
145pub(crate) fn is_cse_candidate(expr: &Expr) -> bool {
148 !matches!(
149 expr,
150 Expr::Column { .. } | Expr::Literal(_) | Expr::Wildcard
151 )
152}
153
154fn detect_cse_matches(
159 expr: &Expr,
160 proj_registry: &HashMap<String, (usize, Option<String>)>,
161 replacement_map: &mut HashMap<String, String>,
162 proj_alias_assignments: &mut HashMap<usize, String>,
163 next_cse_id: &mut usize,
164) {
165 let key = format!("{}", expr);
166
167 if let Some((idx, existing_alias)) = proj_registry.get(&key) {
169 let alias = if let Some(a) = existing_alias {
170 a.clone()
171 } else if let Some(a) = proj_alias_assignments.get(idx) {
172 a.clone()
173 } else {
174 let a = format!("__cse_{}", *next_cse_id);
175 *next_cse_id += 1;
176 proj_alias_assignments.insert(*idx, a.clone());
177 a
178 };
179 replacement_map.insert(key, alias);
180 return; }
182
183 match expr {
185 Expr::BinaryOp { left, right, .. } => {
186 detect_cse_matches(
187 left,
188 proj_registry,
189 replacement_map,
190 proj_alias_assignments,
191 next_cse_id,
192 );
193 detect_cse_matches(
194 right,
195 proj_registry,
196 replacement_map,
197 proj_alias_assignments,
198 next_cse_id,
199 );
200 }
201 Expr::UnaryOp { expr: inner, .. } => {
202 detect_cse_matches(
203 inner,
204 proj_registry,
205 replacement_map,
206 proj_alias_assignments,
207 next_cse_id,
208 );
209 }
210 Expr::Function { args, .. } => {
211 for arg in args {
212 detect_cse_matches(
213 arg,
214 proj_registry,
215 replacement_map,
216 proj_alias_assignments,
217 next_cse_id,
218 );
219 }
220 }
221 Expr::Case {
222 operand,
223 when_then,
224 else_result,
225 } => {
226 if let Some(op) = operand {
227 detect_cse_matches(
228 op,
229 proj_registry,
230 replacement_map,
231 proj_alias_assignments,
232 next_cse_id,
233 );
234 }
235 for (when, then) in when_then {
236 detect_cse_matches(
237 when,
238 proj_registry,
239 replacement_map,
240 proj_alias_assignments,
241 next_cse_id,
242 );
243 detect_cse_matches(
244 then,
245 proj_registry,
246 replacement_map,
247 proj_alias_assignments,
248 next_cse_id,
249 );
250 }
251 if let Some(else_expr) = else_result {
252 detect_cse_matches(
253 else_expr,
254 proj_registry,
255 replacement_map,
256 proj_alias_assignments,
257 next_cse_id,
258 );
259 }
260 }
261 Expr::Cast { expr: inner, .. } => {
262 detect_cse_matches(
263 inner,
264 proj_registry,
265 replacement_map,
266 proj_alias_assignments,
267 next_cse_id,
268 );
269 }
270 Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
271 detect_cse_matches(
272 inner,
273 proj_registry,
274 replacement_map,
275 proj_alias_assignments,
276 next_cse_id,
277 );
278 }
279 Expr::InList {
280 expr: inner, list, ..
281 } => {
282 detect_cse_matches(
283 inner,
284 proj_registry,
285 replacement_map,
286 proj_alias_assignments,
287 next_cse_id,
288 );
289 for item in list {
290 detect_cse_matches(
291 item,
292 proj_registry,
293 replacement_map,
294 proj_alias_assignments,
295 next_cse_id,
296 );
297 }
298 }
299 Expr::Between {
300 expr: inner,
301 low,
302 high,
303 ..
304 } => {
305 detect_cse_matches(
306 inner,
307 proj_registry,
308 replacement_map,
309 proj_alias_assignments,
310 next_cse_id,
311 );
312 detect_cse_matches(
313 low,
314 proj_registry,
315 replacement_map,
316 proj_alias_assignments,
317 next_cse_id,
318 );
319 detect_cse_matches(
320 high,
321 proj_registry,
322 replacement_map,
323 proj_alias_assignments,
324 next_cse_id,
325 );
326 }
327 Expr::Column { .. } | Expr::Literal(_) | Expr::Wildcard | Expr::Subquery(_) => {}
329 }
330}
331
332fn replace_cse(expr: Expr, replacements: &HashMap<String, String>) -> Expr {
336 let key = format!("{}", expr);
337 if let Some(alias) = replacements.get(&key) {
338 return Expr::Column {
339 table: None,
340 name: alias.clone(),
341 };
342 }
343
344 match expr {
345 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
346 left: Box::new(replace_cse(*left, replacements)),
347 op,
348 right: Box::new(replace_cse(*right, replacements)),
349 },
350 Expr::UnaryOp { op, expr: inner } => Expr::UnaryOp {
351 op,
352 expr: Box::new(replace_cse(*inner, replacements)),
353 },
354 Expr::Function { name, args } => Expr::Function {
355 name,
356 args: args
357 .into_iter()
358 .map(|a| replace_cse(a, replacements))
359 .collect(),
360 },
361 Expr::Case {
362 operand,
363 when_then,
364 else_result,
365 } => Expr::Case {
366 operand: operand.map(|e| Box::new(replace_cse(*e, replacements))),
367 when_then: when_then
368 .into_iter()
369 .map(|(w, t)| (replace_cse(w, replacements), replace_cse(t, replacements)))
370 .collect(),
371 else_result: else_result.map(|e| Box::new(replace_cse(*e, replacements))),
372 },
373 Expr::Cast {
374 expr: inner,
375 data_type,
376 } => Expr::Cast {
377 expr: Box::new(replace_cse(*inner, replacements)),
378 data_type,
379 },
380 Expr::IsNull(inner) => Expr::IsNull(Box::new(replace_cse(*inner, replacements))),
381 Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(replace_cse(*inner, replacements))),
382 Expr::InList {
383 expr: inner,
384 list,
385 negated,
386 } => Expr::InList {
387 expr: Box::new(replace_cse(*inner, replacements)),
388 list: list
389 .into_iter()
390 .map(|i| replace_cse(i, replacements))
391 .collect(),
392 negated,
393 },
394 Expr::Between {
395 expr: inner,
396 low,
397 high,
398 negated,
399 } => Expr::Between {
400 expr: Box::new(replace_cse(*inner, replacements)),
401 low: Box::new(replace_cse(*low, replacements)),
402 high: Box::new(replace_cse(*high, replacements)),
403 negated,
404 },
405 other => other,
407 }
408}
409
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    /// Builds an unqualified column reference.
    fn col(name: &str) -> Expr {
        Expr::Column {
            table: None,
            name: name.to_string(),
        }
    }

    /// Builds the expression `a + b`, the shared subexpression in these tests.
    fn a_plus_b() -> Expr {
        Expr::BinaryOp {
            left: Box::new(col("a")),
            op: BinaryOperator::Plus,
            right: Box::new(col("b")),
        }
    }

    /// Builds `SELECT <projection> FROM t` with every other clause empty.
    fn select_from_t(projection: Vec<SelectItem>) -> SelectStatement {
        SelectStatement {
            projection,
            from: Some(TableReference::Table {
                name: "t".to_string(),
                alias: None,
            }),
            selection: None,
            group_by: Vec::new(),
            having: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
        }
    }

    /// Runs the rule, failing the test if it reports an error.
    fn run_cse(stmt: SelectStatement) -> SelectStatement {
        CommonSubexpressionElimination
            .apply(stmt)
            .expect("CSE should succeed")
    }

    #[test]
    fn test_cse_projection_to_where() {
        // SELECT a + b, x FROM t WHERE a + b > 10
        let stmt = SelectStatement {
            selection: Some(Expr::BinaryOp {
                left: Box::new(a_plus_b()),
                op: BinaryOperator::Gt,
                right: Box::new(Expr::Literal(Literal::Integer(10))),
            }),
            ..select_from_t(vec![
                SelectItem::Expr {
                    expr: a_plus_b(),
                    alias: None,
                },
                SelectItem::Expr {
                    expr: col("x"),
                    alias: None,
                },
            ])
        };

        let result = run_cse(stmt);

        // Each check uses let-else so an unexpected shape fails the test
        // instead of silently skipping the assertion.
        let SelectItem::Expr {
            alias: Some(alias), ..
        } = &result.projection[0]
        else {
            panic!("CSE should assign alias to common expression");
        };
        assert!(
            matches!(&result.projection[1], SelectItem::Expr { alias: None, .. }),
            "plain column projection must not receive an alias"
        );

        let Some(Expr::BinaryOp { left, .. }) = &result.selection else {
            panic!("WHERE should still be a binary comparison after CSE");
        };
        let Expr::Column { name, .. } = left.as_ref() else {
            panic!("CSE should replace expression in WHERE with column ref");
        };
        assert_eq!(
            name, alias,
            "WHERE must reference the alias given to the projection"
        );
    }

    #[test]
    fn test_cse_with_existing_alias() {
        // SELECT a + b AS total FROM t ORDER BY a + b
        let stmt = SelectStatement {
            order_by: vec![OrderByExpr {
                expr: a_plus_b(),
                asc: true,
                nulls_first: false,
            }],
            ..select_from_t(vec![SelectItem::Expr {
                expr: a_plus_b(),
                alias: Some("total".to_string()),
            }])
        };

        let result = run_cse(stmt);

        let SelectItem::Expr {
            alias: Some(alias), ..
        } = &result.projection[0]
        else {
            panic!("existing projection alias should be preserved");
        };
        assert_eq!(alias, "total");

        let Expr::Column { name, .. } = &result.order_by[0].expr else {
            panic!("ORDER BY should be a column reference after CSE");
        };
        assert_eq!(name, "total");
    }

    #[test]
    fn test_cse_no_common_expressions() {
        // SELECT a FROM t WHERE b > 5 -- nothing in the projection is a candidate.
        let stmt = SelectStatement {
            selection: Some(Expr::BinaryOp {
                left: Box::new(col("b")),
                op: BinaryOperator::Gt,
                right: Box::new(Expr::Literal(Literal::Integer(5))),
            }),
            ..select_from_t(vec![SelectItem::Expr {
                expr: col("a"),
                alias: None,
            }])
        };

        let result = run_cse(stmt);

        assert!(
            matches!(&result.projection[0], SelectItem::Expr { alias: None, .. }),
            "projection must be left without an alias"
        );

        let Some(Expr::BinaryOp { left, .. }) = &result.selection else {
            panic!("WHERE should be left as a binary comparison");
        };
        let Expr::Column { name, .. } = left.as_ref() else {
            panic!("left side of WHERE should still be a column");
        };
        assert_eq!(name, "b");
    }

    #[test]
    fn test_cse_subexpression_in_where() {
        // SELECT a + b FROM t WHERE (a + b) * 2 > 10
        let stmt = SelectStatement {
            selection: Some(Expr::BinaryOp {
                left: Box::new(Expr::BinaryOp {
                    left: Box::new(a_plus_b()),
                    op: BinaryOperator::Multiply,
                    right: Box::new(Expr::Literal(Literal::Integer(2))),
                }),
                op: BinaryOperator::Gt,
                right: Box::new(Expr::Literal(Literal::Integer(10))),
            }),
            ..select_from_t(vec![SelectItem::Expr {
                expr: a_plus_b(),
                alias: None,
            }])
        };

        let result = run_cse(stmt);

        let SelectItem::Expr {
            alias: Some(alias), ..
        } = &result.projection[0]
        else {
            panic!("CSE should assign alias to common expression");
        };

        let Some(Expr::BinaryOp {
            left: outer_left, ..
        }) = &result.selection
        else {
            panic!("WHERE should still be a binary comparison after CSE");
        };
        let Expr::BinaryOp {
            left: inner_left, ..
        } = outer_left.as_ref()
        else {
            panic!("the multiplication around a+b should be preserved");
        };
        let Expr::Column { name, .. } = inner_left.as_ref() else {
            panic!("a+b should be replaced with column ref inside larger expression");
        };
        assert_eq!(
            name, alias,
            "nested reference must use the projection's alias"
        );
    }

    #[test]
    fn test_is_cse_candidate() {
        // Leaves are never candidates.
        assert!(!is_cse_candidate(&col("a")));
        assert!(!is_cse_candidate(&Expr::Literal(Literal::Integer(42))));
        assert!(!is_cse_candidate(&Expr::Wildcard));

        // Compound expressions are.
        assert!(is_cse_candidate(&a_plus_b()));
        assert!(is_cse_candidate(&Expr::Function {
            name: "SUM".to_string(),
            args: vec![col("x")],
        }));
    }
}