Skip to main content

oxigdal_query/optimizer/rules/
cse.rs

1//! Common subexpression elimination rule.
2//!
3//! Identifies repeated expressions across the query and eliminates redundant
4//! computation by replacing duplicates with references to pre-computed results.
5//!
6//! The algorithm works in three phases:
7//! 1. **Registry**: Build a map of projection expressions keyed by their
8//!    canonical string form (using the `Display` trait).
9//! 2. **Detection**: Scan non-projection clauses (WHERE, GROUP BY, HAVING,
10//!    ORDER BY) for subexpressions matching any projection expression.
11//! 3. **Replacement**: Replace detected matches with column references to
12//!    projection aliases, assigning synthetic aliases where needed.
13//!
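//! Rewriting GROUP BY, HAVING, and ORDER BY is safe in SQL dialects that allow
//! those clauses to reference SELECT aliases. The rule also rewrites WHERE,
//! which standard SQL evaluates before the projection, so that rewrite is only
//! valid if the executor resolves projection aliases inside WHERE as well.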
16
17use crate::error::{QueryError, Result};
18use crate::parser::ast::*;
19use oxigdal_core::error::OxiGdalError;
20use std::collections::HashMap;
21
22use super::OptimizationRule;
23
/// Upper bound on the number of distinct projection expressions the rule will
/// register as CSE candidates; `apply` returns an error for queries above it.
const MAX_CSE_CANDIDATES: usize = 1000;

/// Common subexpression elimination rule.
///
/// The rule is stateless: all bookkeeping lives inside a single call to
/// `apply`, so the value itself is a zero-sized marker.
#[derive(Debug, Clone, Copy, Default)]
pub struct CommonSubexpressionElimination;
29
30impl OptimizationRule for CommonSubexpressionElimination {
31    fn apply(&self, mut stmt: SelectStatement) -> Result<SelectStatement> {
32        // Phase 1: Build a registry of projection expressions
33        // Key: canonical string form, Value: (projection index, existing alias if any)
34        let mut proj_registry: HashMap<String, (usize, Option<String>)> = HashMap::new();
35
36        for (idx, item) in stmt.projection.iter().enumerate() {
37            if let SelectItem::Expr { expr, alias } = item {
38                if is_cse_candidate(expr) {
39                    let key = format!("{}", expr);
40                    proj_registry.insert(key, (idx, alias.clone()));
41                }
42            }
43        }
44
45        // Check complexity limit
46        if proj_registry.len() > MAX_CSE_CANDIDATES {
47            return Err(QueryError::optimization(
48                OxiGdalError::invalid_operation_builder("Too many CSE candidates in query")
49                    .with_operation("common_subexpression_elimination")
50                    .with_parameter("candidate_count", proj_registry.len().to_string())
51                    .with_parameter("max_allowed", MAX_CSE_CANDIDATES.to_string())
52                    .with_suggestion(
53                        "Simplify the query or reduce the number of complex expressions in SELECT",
54                    )
55                    .build()
56                    .to_string(),
57            ));
58        }
59
60        if proj_registry.is_empty() {
61            return Ok(stmt);
62        }
63
64        // Phase 2: Detect common subexpressions in non-projection clauses
65        let mut replacement_map: HashMap<String, String> = HashMap::new();
66        let mut proj_alias_assignments: HashMap<usize, String> = HashMap::new();
67        let mut next_cse_id = 0usize;
68
69        if let Some(ref selection) = stmt.selection {
70            detect_cse_matches(
71                selection,
72                &proj_registry,
73                &mut replacement_map,
74                &mut proj_alias_assignments,
75                &mut next_cse_id,
76            );
77        }
78        for expr in &stmt.group_by {
79            detect_cse_matches(
80                expr,
81                &proj_registry,
82                &mut replacement_map,
83                &mut proj_alias_assignments,
84                &mut next_cse_id,
85            );
86        }
87        if let Some(ref having) = stmt.having {
88            detect_cse_matches(
89                having,
90                &proj_registry,
91                &mut replacement_map,
92                &mut proj_alias_assignments,
93                &mut next_cse_id,
94            );
95        }
96        for order in &stmt.order_by {
97            detect_cse_matches(
98                &order.expr,
99                &proj_registry,
100                &mut replacement_map,
101                &mut proj_alias_assignments,
102                &mut next_cse_id,
103            );
104        }
105
106        if replacement_map.is_empty() {
107            return Ok(stmt);
108        }
109
110        // Phase 3: Assign aliases to projection items that need them
111        for (idx, alias_name) in &proj_alias_assignments {
112            if let Some(SelectItem::Expr { alias, .. }) = stmt.projection.get_mut(*idx) {
113                if alias.is_none() {
114                    *alias = Some(alias_name.clone());
115                }
116            }
117        }
118
119        // Replace common subexpressions in non-projection clauses
120        if let Some(selection) = stmt.selection.take() {
121            stmt.selection = Some(replace_cse(selection, &replacement_map));
122        }
123        stmt.group_by = stmt
124            .group_by
125            .into_iter()
126            .map(|expr| replace_cse(expr, &replacement_map))
127            .collect();
128        if let Some(having) = stmt.having.take() {
129            stmt.having = Some(replace_cse(having, &replacement_map));
130        }
131        stmt.order_by = stmt
132            .order_by
133            .into_iter()
134            .map(|order| OrderByExpr {
135                expr: replace_cse(order.expr, &replacement_map),
136                asc: order.asc,
137                nulls_first: order.nulls_first,
138            })
139            .collect();
140
141        Ok(stmt)
142    }
143}
144
145/// Check if an expression is a candidate for CSE (non-trivial computation).
146/// Simple column references and literals are never worth extracting.
147pub(crate) fn is_cse_candidate(expr: &Expr) -> bool {
148    !matches!(
149        expr,
150        Expr::Column { .. } | Expr::Literal(_) | Expr::Wildcard
151    )
152}
153
154/// Walk an expression tree looking for subexpressions that match entries
155/// in `proj_registry`. When a match is found, record the mapping in
156/// `replacement_map` and, if needed, generate a synthetic alias in
157/// `proj_alias_assignments`.
158fn detect_cse_matches(
159    expr: &Expr,
160    proj_registry: &HashMap<String, (usize, Option<String>)>,
161    replacement_map: &mut HashMap<String, String>,
162    proj_alias_assignments: &mut HashMap<usize, String>,
163    next_cse_id: &mut usize,
164) {
165    let key = format!("{}", expr);
166
167    // Check if this (sub)expression matches a projection expression
168    if let Some((idx, existing_alias)) = proj_registry.get(&key) {
169        let alias = if let Some(a) = existing_alias {
170            a.clone()
171        } else if let Some(a) = proj_alias_assignments.get(idx) {
172            a.clone()
173        } else {
174            let a = format!("__cse_{}", *next_cse_id);
175            *next_cse_id += 1;
176            proj_alias_assignments.insert(*idx, a.clone());
177            a
178        };
179        replacement_map.insert(key, alias);
180        return; // Whole expression will be replaced; no need to recurse deeper
181    }
182
183    // Recurse into children to find deeper matches
184    match expr {
185        Expr::BinaryOp { left, right, .. } => {
186            detect_cse_matches(
187                left,
188                proj_registry,
189                replacement_map,
190                proj_alias_assignments,
191                next_cse_id,
192            );
193            detect_cse_matches(
194                right,
195                proj_registry,
196                replacement_map,
197                proj_alias_assignments,
198                next_cse_id,
199            );
200        }
201        Expr::UnaryOp { expr: inner, .. } => {
202            detect_cse_matches(
203                inner,
204                proj_registry,
205                replacement_map,
206                proj_alias_assignments,
207                next_cse_id,
208            );
209        }
210        Expr::Function { args, .. } => {
211            for arg in args {
212                detect_cse_matches(
213                    arg,
214                    proj_registry,
215                    replacement_map,
216                    proj_alias_assignments,
217                    next_cse_id,
218                );
219            }
220        }
221        Expr::Case {
222            operand,
223            when_then,
224            else_result,
225        } => {
226            if let Some(op) = operand {
227                detect_cse_matches(
228                    op,
229                    proj_registry,
230                    replacement_map,
231                    proj_alias_assignments,
232                    next_cse_id,
233                );
234            }
235            for (when, then) in when_then {
236                detect_cse_matches(
237                    when,
238                    proj_registry,
239                    replacement_map,
240                    proj_alias_assignments,
241                    next_cse_id,
242                );
243                detect_cse_matches(
244                    then,
245                    proj_registry,
246                    replacement_map,
247                    proj_alias_assignments,
248                    next_cse_id,
249                );
250            }
251            if let Some(else_expr) = else_result {
252                detect_cse_matches(
253                    else_expr,
254                    proj_registry,
255                    replacement_map,
256                    proj_alias_assignments,
257                    next_cse_id,
258                );
259            }
260        }
261        Expr::Cast { expr: inner, .. } => {
262            detect_cse_matches(
263                inner,
264                proj_registry,
265                replacement_map,
266                proj_alias_assignments,
267                next_cse_id,
268            );
269        }
270        Expr::IsNull(inner) | Expr::IsNotNull(inner) => {
271            detect_cse_matches(
272                inner,
273                proj_registry,
274                replacement_map,
275                proj_alias_assignments,
276                next_cse_id,
277            );
278        }
279        Expr::InList {
280            expr: inner, list, ..
281        } => {
282            detect_cse_matches(
283                inner,
284                proj_registry,
285                replacement_map,
286                proj_alias_assignments,
287                next_cse_id,
288            );
289            for item in list {
290                detect_cse_matches(
291                    item,
292                    proj_registry,
293                    replacement_map,
294                    proj_alias_assignments,
295                    next_cse_id,
296                );
297            }
298        }
299        Expr::Between {
300            expr: inner,
301            low,
302            high,
303            ..
304        } => {
305            detect_cse_matches(
306                inner,
307                proj_registry,
308                replacement_map,
309                proj_alias_assignments,
310                next_cse_id,
311            );
312            detect_cse_matches(
313                low,
314                proj_registry,
315                replacement_map,
316                proj_alias_assignments,
317                next_cse_id,
318            );
319            detect_cse_matches(
320                high,
321                proj_registry,
322                replacement_map,
323                proj_alias_assignments,
324                next_cse_id,
325            );
326        }
327        // Terminals and subqueries (different scope) - no recursion
328        Expr::Column { .. } | Expr::Literal(_) | Expr::Wildcard | Expr::Subquery(_) => {}
329    }
330}
331
332/// Replace common subexpressions with column references (top-down traversal).
333/// Checks the current node first; if it matches, replaces the whole subtree.
334/// Otherwise, recurses into children.
335fn replace_cse(expr: Expr, replacements: &HashMap<String, String>) -> Expr {
336    let key = format!("{}", expr);
337    if let Some(alias) = replacements.get(&key) {
338        return Expr::Column {
339            table: None,
340            name: alias.clone(),
341        };
342    }
343
344    match expr {
345        Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
346            left: Box::new(replace_cse(*left, replacements)),
347            op,
348            right: Box::new(replace_cse(*right, replacements)),
349        },
350        Expr::UnaryOp { op, expr: inner } => Expr::UnaryOp {
351            op,
352            expr: Box::new(replace_cse(*inner, replacements)),
353        },
354        Expr::Function { name, args } => Expr::Function {
355            name,
356            args: args
357                .into_iter()
358                .map(|a| replace_cse(a, replacements))
359                .collect(),
360        },
361        Expr::Case {
362            operand,
363            when_then,
364            else_result,
365        } => Expr::Case {
366            operand: operand.map(|e| Box::new(replace_cse(*e, replacements))),
367            when_then: when_then
368                .into_iter()
369                .map(|(w, t)| (replace_cse(w, replacements), replace_cse(t, replacements)))
370                .collect(),
371            else_result: else_result.map(|e| Box::new(replace_cse(*e, replacements))),
372        },
373        Expr::Cast {
374            expr: inner,
375            data_type,
376        } => Expr::Cast {
377            expr: Box::new(replace_cse(*inner, replacements)),
378            data_type,
379        },
380        Expr::IsNull(inner) => Expr::IsNull(Box::new(replace_cse(*inner, replacements))),
381        Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(replace_cse(*inner, replacements))),
382        Expr::InList {
383            expr: inner,
384            list,
385            negated,
386        } => Expr::InList {
387            expr: Box::new(replace_cse(*inner, replacements)),
388            list: list
389                .into_iter()
390                .map(|i| replace_cse(i, replacements))
391                .collect(),
392            negated,
393        },
394        Expr::Between {
395            expr: inner,
396            low,
397            high,
398            negated,
399        } => Expr::Between {
400            expr: Box::new(replace_cse(*inner, replacements)),
401            low: Box::new(replace_cse(*low, replacements)),
402            high: Box::new(replace_cse(*high, replacements)),
403            negated,
404        },
405        // Column, Literal, Wildcard, Subquery: return as-is
406        other => other,
407    }
408}
409
// Test code is allowed to panic: a failed expectation should abort the test.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    /// Builds an unqualified column reference.
    fn col(name: &str) -> Expr {
        Expr::Column {
            table: None,
            name: name.to_string(),
        }
    }

    /// Builds the binary expression `left <op> right`.
    fn binary(left: Expr, op: BinaryOperator, right: Expr) -> Expr {
        Expr::BinaryOp {
            left: Box::new(left),
            op,
            right: Box::new(right),
        }
    }

    /// Builds the expression `a + b` shared by most tests.
    fn a_plus_b() -> Expr {
        binary(col("a"), BinaryOperator::Plus, col("b"))
    }

    /// Builds `SELECT <projection> FROM t` with every other clause empty.
    fn select_from_t(projection: Vec<SelectItem>) -> SelectStatement {
        SelectStatement {
            projection,
            from: Some(TableReference::Table {
                name: "t".to_string(),
                alias: None,
            }),
            selection: None,
            group_by: Vec::new(),
            having: None,
            order_by: Vec::new(),
            limit: None,
            offset: None,
        }
    }

    /// Runs the rule and fails the test if it reports an error.
    fn optimize(stmt: SelectStatement) -> SelectStatement {
        CommonSubexpressionElimination
            .apply(stmt)
            .expect("CSE should succeed")
    }

    /// Returns the alias of projection item `idx`, failing the test if the
    /// item is not an expression.
    fn projection_alias(stmt: &SelectStatement, idx: usize) -> Option<&str> {
        match &stmt.projection[idx] {
            SelectItem::Expr { alias, .. } => alias.as_deref(),
            _ => panic!("projection item {} should be an expression", idx),
        }
    }

    #[test]
    fn test_cse_projection_to_where() {
        // SELECT (a + b), x FROM t WHERE (a + b) > 10
        // -> SELECT (a + b) AS __cse_0, x FROM t WHERE __cse_0 > 10
        let mut stmt = select_from_t(vec![
            SelectItem::Expr {
                expr: a_plus_b(),
                alias: None,
            },
            SelectItem::Expr {
                expr: col("x"),
                alias: None,
            },
        ]);
        stmt.selection = Some(binary(
            a_plus_b(),
            BinaryOperator::Gt,
            Expr::Literal(Literal::Integer(10)),
        ));

        let result = optimize(stmt);

        // Projection should have an alias assigned; the trivial column must not
        let alias =
            projection_alias(&result, 0).expect("CSE should assign alias to common expression");
        assert_eq!(
            projection_alias(&result, 1),
            None,
            "trivial column reference must stay unaliased"
        );

        // WHERE should use a column reference to that alias
        let Some(Expr::BinaryOp { left, .. }) = &result.selection else {
            panic!("WHERE should still be a binary comparison after CSE");
        };
        let Expr::Column { name, .. } = left.as_ref() else {
            panic!("CSE should replace expression in WHERE with column ref");
        };
        assert_eq!(name.as_str(), alias);
    }

    #[test]
    fn test_cse_with_existing_alias() {
        // SELECT (a + b) AS total FROM t ORDER BY (a + b)
        // -> SELECT (a + b) AS total FROM t ORDER BY total
        let mut stmt = select_from_t(vec![SelectItem::Expr {
            expr: a_plus_b(),
            alias: Some("total".to_string()),
        }]);
        stmt.order_by = vec![OrderByExpr {
            expr: a_plus_b(),
            asc: true,
            nulls_first: false,
        }];

        let result = optimize(stmt);

        // The user-supplied alias must be kept, not replaced by a synthetic one
        assert_eq!(projection_alias(&result, 0), Some("total"));

        // ORDER BY should now reference "total"
        assert_eq!(result.order_by.len(), 1);
        let Expr::Column { name, .. } = &result.order_by[0].expr else {
            panic!("ORDER BY should be a column reference after CSE");
        };
        assert_eq!(name, "total");
    }

    #[test]
    fn test_cse_no_common_expressions() {
        // SELECT a FROM t WHERE b > 5
        // No common subexpressions between projection and WHERE
        let mut stmt = select_from_t(vec![SelectItem::Expr {
            expr: col("a"),
            alias: None,
        }]);
        stmt.selection = Some(binary(
            col("b"),
            BinaryOperator::Gt,
            Expr::Literal(Literal::Integer(5)),
        ));
        let where_before = stmt.selection.as_ref().map(ToString::to_string);

        let result = optimize(stmt);

        // No aliases should be assigned (column ref is trivial, not a CSE candidate)
        assert_eq!(projection_alias(&result, 0), None);

        // WHERE must be left exactly as written
        let where_after = result.selection.as_ref().map(ToString::to_string);
        assert_eq!(where_after, where_before);
    }

    #[test]
    fn test_cse_subexpression_in_where() {
        // SELECT (a + b) FROM t WHERE ((a + b) * 2) > 10
        // -> SELECT (a + b) AS __cse_0 FROM t WHERE (__cse_0 * 2) > 10
        let mut stmt = select_from_t(vec![SelectItem::Expr {
            expr: a_plus_b(),
            alias: None,
        }]);
        stmt.selection = Some(binary(
            binary(
                a_plus_b(),
                BinaryOperator::Multiply,
                Expr::Literal(Literal::Integer(2)),
            ),
            BinaryOperator::Gt,
            Expr::Literal(Literal::Integer(10)),
        ));

        let result = optimize(stmt);

        // Projection should have alias
        let alias = projection_alias(&result, 0).expect("projection should receive an alias");

        // WHERE: ((a+b)*2) > 10 should become (__cse_0 * 2) > 10
        let Some(Expr::BinaryOp {
            left: outer_left, ..
        }) = &result.selection
        else {
            panic!("WHERE should still be a binary comparison after CSE");
        };
        let Expr::BinaryOp {
            left: inner_left, ..
        } = outer_left.as_ref()
        else {
            panic!("the enclosing multiplication must be preserved");
        };
        let Expr::Column { name, .. } = inner_left.as_ref() else {
            panic!("a+b should be replaced with column ref inside larger expression");
        };
        assert_eq!(name.as_str(), alias);
    }

    #[test]
    fn test_is_cse_candidate() {
        // Column reference: not a candidate
        assert!(!is_cse_candidate(&col("a")));

        // Literal: not a candidate
        assert!(!is_cse_candidate(&Expr::Literal(Literal::Integer(42))));

        // Wildcard: not a candidate
        assert!(!is_cse_candidate(&Expr::Wildcard));

        // Binary op: is a candidate
        assert!(is_cse_candidate(&a_plus_b()));

        // Function call: is a candidate
        assert!(is_cse_candidate(&Expr::Function {
            name: "SUM".to_string(),
            args: vec![col("x")],
        }));
    }
}