Skip to main content

polyglot_sql/optimizer/
eliminate_ctes.rs

1//! CTE Elimination Module
2//!
3//! This module provides functionality for removing unused CTEs
4//! from SQL expressions.
5//!
6//! Ported from sqlglot's optimizer/eliminate_ctes.py
7
8use std::collections::HashMap;
9
10use crate::expressions::Expression;
11use crate::scope::{build_scope, Scope};
12
13/// Remove unused CTEs from an expression.
14///
15/// # Example
16///
17/// ```sql
18/// -- Before:
19/// WITH y AS (SELECT a FROM x) SELECT a FROM z
20/// -- After:
21/// SELECT a FROM z
22/// ```
23///
24/// # Arguments
25/// * `expression` - The expression to optimize
26///
27/// # Returns
28/// The optimized expression with unused CTEs removed
29pub fn eliminate_ctes(expression: Expression) -> Expression {
30    let root = build_scope(&expression);
31
32    // Compute reference counts for each scope
33    let ref_count = compute_ref_count(&root);
34
35    // Collect scopes to process (in reverse order)
36    let scopes = collect_scopes(&root);
37
38    // Track which CTEs to remove
39    let mut ctes_to_remove: Vec<String> = Vec::new();
40
41    for scope in scopes.iter().rev() {
42        if scope.is_cte() {
43            let scope_id = *scope as *const Scope as u64;
44            let count = ref_count.get(&scope_id).copied().unwrap_or(0);
45
46            if count == 0 {
47                // This CTE is unused, mark for removal
48                if let Some(name) = get_cte_name(scope) {
49                    ctes_to_remove.push(name);
50                }
51            }
52        }
53    }
54
55    // Remove the marked CTEs
56    if ctes_to_remove.is_empty() {
57        return expression;
58    }
59
60    remove_ctes(expression, &ctes_to_remove)
61}
62
63/// Compute reference counts for each scope
64fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
65    let mut counts: HashMap<u64, usize> = HashMap::new();
66
67    // Initialize all scopes with count 0
68    for scope in collect_scopes(root) {
69        let id = scope as *const Scope as u64;
70        counts.insert(id, 0);
71    }
72
73    // Count references
74    for scope in collect_scopes(root) {
75        for (_name, source_info) in &scope.sources {
76            // If this source references a CTE scope, increment its count
77            // In a full implementation, we'd track which sources are CTEs
78            let _ = source_info;
79        }
80    }
81
82    counts
83}
84
85/// Collect all scopes from the tree
86fn collect_scopes(root: &Scope) -> Vec<&Scope> {
87    let mut result = vec![root];
88    result.extend(root.subquery_scopes.iter().flat_map(|s| collect_scopes(s)));
89    result.extend(
90        root.derived_table_scopes
91            .iter()
92            .flat_map(|s| collect_scopes(s)),
93    );
94    result.extend(root.cte_scopes.iter().flat_map(|s| collect_scopes(s)));
95    result.extend(root.union_scopes.iter().flat_map(|s| collect_scopes(s)));
96    result
97}
98
99/// Get the CTE name from a scope
100fn get_cte_name(scope: &Scope) -> Option<String> {
101    // In a full implementation, we'd extract the CTE name from the scope's expression
102    // For now, return None
103    let _ = scope;
104    None
105}
106
107/// Remove the specified CTEs from an expression
108fn remove_ctes(expression: Expression, ctes_to_remove: &[String]) -> Expression {
109    if ctes_to_remove.is_empty() {
110        return expression;
111    }
112
113    // In a full implementation, we would:
114    // 1. Find the WITH clause
115    // 2. Remove the specified CTEs
116    // 3. If WITH clause is empty, remove it entirely
117    //
118    // For now, return unchanged
119    expression
120}
121
122/// Check if a CTE is referenced anywhere in the query
123pub fn is_cte_referenced(expression: &Expression, cte_name: &str) -> bool {
124    match expression {
125        Expression::Table(table) => table.name.name == cte_name,
126        Expression::Select(select) => {
127            // Check FROM
128            if let Some(ref from) = select.from {
129                for expr in &from.expressions {
130                    if is_cte_referenced(expr, cte_name) {
131                        return true;
132                    }
133                }
134            }
135            // Check JOINs
136            for join in &select.joins {
137                if is_cte_referenced(&join.this, cte_name) {
138                    return true;
139                }
140            }
141            // Check subqueries in SELECT list
142            for expr in &select.expressions {
143                if is_cte_referenced(expr, cte_name) {
144                    return true;
145                }
146            }
147            // Check WHERE
148            if let Some(ref where_clause) = select.where_clause {
149                if is_cte_referenced(&where_clause.this, cte_name) {
150                    return true;
151                }
152            }
153            false
154        }
155        Expression::Subquery(subquery) => is_cte_referenced(&subquery.this, cte_name),
156        Expression::Union(union) => {
157            is_cte_referenced(&union.left, cte_name) || is_cte_referenced(&union.right, cte_name)
158        }
159        Expression::Intersect(intersect) => {
160            is_cte_referenced(&intersect.left, cte_name)
161                || is_cte_referenced(&intersect.right, cte_name)
162        }
163        Expression::Except(except) => {
164            is_cte_referenced(&except.left, cte_name) || is_cte_referenced(&except.right, cte_name)
165        }
166        Expression::In(in_expr) => {
167            if let Some(ref query) = in_expr.query {
168                is_cte_referenced(query, cte_name)
169            } else {
170                false
171            }
172        }
173        _ => false,
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render `expr` back to SQL text.
    ///
    /// Named `generate_sql` rather than `gen` because `gen` is a reserved
    /// keyword in the Rust 2024 edition.
    fn generate_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse")
            .first()
            .cloned()
            .expect("Parser returned no statements")
    }

    #[test]
    fn test_eliminate_ctes_unused() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM z");
        let result = eliminate_ctes(expr);
        let sql = generate_sql(&result);
        // In a full implementation, the CTE would be removed
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_eliminate_ctes_used() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y");
        let result = eliminate_ctes(expr);
        let sql = generate_sql(&result);
        // CTE is used, should be preserved
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_is_cte_referenced_true() {
        let expr = parse("SELECT * FROM cte_name");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_false() {
        let expr = parse("SELECT * FROM other_table");
        assert!(!is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_join() {
        let expr = parse("SELECT * FROM x JOIN cte_name ON x.a = cte_name.a");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_subquery() {
        let expr = parse("SELECT * FROM x WHERE x.a IN (SELECT a FROM cte_name)");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_eliminate_preserves_structure() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y WHERE a > 1");
        let result = eliminate_ctes(expr);
        let sql = generate_sql(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_eliminate_multiple_ctes() {
        let expr = parse("WITH a AS (SELECT 1), b AS (SELECT 2) SELECT * FROM a");
        let result = eliminate_ctes(expr);
        let sql = generate_sql(&result);
        // In a full implementation, unused CTE 'b' would be removed
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_is_cte_referenced_in_union() {
        let expr = parse("SELECT * FROM x UNION SELECT * FROM cte_name");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_compute_ref_count() {
        let expr = parse("SELECT * FROM t");
        let root = build_scope(&expr);
        let counts = compute_ref_count(&root);
        // Should have at least one scope
        assert!(!counts.is_empty());
    }
}