Skip to main content

polyglot_sql/optimizer/
eliminate_ctes.rs

1//! CTE Elimination Module
2//!
3//! This module provides functionality for removing unused CTEs
4//! from SQL expressions.
5//!
6//! Ported from sqlglot's optimizer/eliminate_ctes.py
7
8use std::collections::HashMap;
9
10use crate::expressions::Expression;
11use crate::scope::{build_scope, Scope};
12
13/// Remove unused CTEs from an expression.
14///
15/// # Example
16///
17/// ```sql
18/// -- Before:
19/// WITH y AS (SELECT a FROM x) SELECT a FROM z
20/// -- After:
21/// SELECT a FROM z
22/// ```
23///
24/// # Arguments
25/// * `expression` - The expression to optimize
26///
27/// # Returns
28/// The optimized expression with unused CTEs removed
29pub fn eliminate_ctes(expression: Expression) -> Expression {
30    let root = build_scope(&expression);
31
32    // Compute reference counts for each scope
33    let ref_count = compute_ref_count(&root);
34
35    // Collect scopes to process (in reverse order)
36    let scopes = collect_scopes(&root);
37
38    // Track which CTEs to remove
39    let mut ctes_to_remove: Vec<String> = Vec::new();
40
41    for scope in scopes.iter().rev() {
42        if scope.is_cte() {
43            let scope_id = *scope as *const Scope as u64;
44            let count = ref_count.get(&scope_id).copied().unwrap_or(0);
45
46            if count == 0 {
47                // This CTE is unused, mark for removal
48                if let Some(name) = get_cte_name(scope) {
49                    ctes_to_remove.push(name);
50                }
51            }
52        }
53    }
54
55    // Remove the marked CTEs
56    if ctes_to_remove.is_empty() {
57        return expression;
58    }
59
60    remove_ctes(expression, &ctes_to_remove)
61}
62
63/// Compute reference counts for each scope
64fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
65    let mut counts: HashMap<u64, usize> = HashMap::new();
66
67    // Initialize all scopes with count 0
68    for scope in collect_scopes(root) {
69        let id = scope as *const Scope as u64;
70        counts.insert(id, 0);
71    }
72
73    // Count references
74    for scope in collect_scopes(root) {
75        for (_name, source_info) in &scope.sources {
76            // If this source references a CTE scope, increment its count
77            // In a full implementation, we'd track which sources are CTEs
78            let _ = source_info;
79        }
80    }
81
82    counts
83}
84
85/// Collect all scopes from the tree
86fn collect_scopes(root: &Scope) -> Vec<&Scope> {
87    let mut result = vec![root];
88    result.extend(root.subquery_scopes.iter().flat_map(|s| collect_scopes(s)));
89    result.extend(root.derived_table_scopes.iter().flat_map(|s| collect_scopes(s)));
90    result.extend(root.cte_scopes.iter().flat_map(|s| collect_scopes(s)));
91    result.extend(root.union_scopes.iter().flat_map(|s| collect_scopes(s)));
92    result
93}
94
95/// Get the CTE name from a scope
96fn get_cte_name(scope: &Scope) -> Option<String> {
97    // In a full implementation, we'd extract the CTE name from the scope's expression
98    // For now, return None
99    let _ = scope;
100    None
101}
102
103/// Remove the specified CTEs from an expression
104fn remove_ctes(expression: Expression, ctes_to_remove: &[String]) -> Expression {
105    if ctes_to_remove.is_empty() {
106        return expression;
107    }
108
109    // In a full implementation, we would:
110    // 1. Find the WITH clause
111    // 2. Remove the specified CTEs
112    // 3. If WITH clause is empty, remove it entirely
113    //
114    // For now, return unchanged
115    expression
116}
117
118/// Check if a CTE is referenced anywhere in the query
119pub fn is_cte_referenced(expression: &Expression, cte_name: &str) -> bool {
120    match expression {
121        Expression::Table(table) => {
122            table.name.name == cte_name
123        }
124        Expression::Select(select) => {
125            // Check FROM
126            if let Some(ref from) = select.from {
127                for expr in &from.expressions {
128                    if is_cte_referenced(expr, cte_name) {
129                        return true;
130                    }
131                }
132            }
133            // Check JOINs
134            for join in &select.joins {
135                if is_cte_referenced(&join.this, cte_name) {
136                    return true;
137                }
138            }
139            // Check subqueries in SELECT list
140            for expr in &select.expressions {
141                if is_cte_referenced(expr, cte_name) {
142                    return true;
143                }
144            }
145            // Check WHERE
146            if let Some(ref where_clause) = select.where_clause {
147                if is_cte_referenced(&where_clause.this, cte_name) {
148                    return true;
149                }
150            }
151            false
152        }
153        Expression::Subquery(subquery) => {
154            is_cte_referenced(&subquery.this, cte_name)
155        }
156        Expression::Union(union) => {
157            is_cte_referenced(&union.left, cte_name) || is_cte_referenced(&union.right, cte_name)
158        }
159        Expression::Intersect(intersect) => {
160            is_cte_referenced(&intersect.left, cte_name) || is_cte_referenced(&intersect.right, cte_name)
161        }
162        Expression::Except(except) => {
163            is_cte_referenced(&except.left, cte_name) || is_cte_referenced(&except.right, cte_name)
164        }
165        Expression::In(in_expr) => {
166            if let Some(ref query) = in_expr.query {
167                is_cte_referenced(query, cte_name)
168            } else {
169                false
170            }
171        }
172        _ => false,
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Render `expr` back to SQL text.
    ///
    /// Named `render` rather than `gen`: `gen` is a reserved keyword in the
    /// Rust 2024 edition, so the old helper name would stop compiling after an
    /// edition bump.
    fn render(expr: &Expression) -> String {
        Generator::new()
            .generate(expr)
            .expect("generator should render a parsed expression")
    }

    /// Parse `sql` and return its first statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_eliminate_ctes_unused() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM z");
        let sql = render(&eliminate_ctes(expr));
        // TODO: once removal is implemented, also assert `!sql.contains("WITH")`.
        assert!(sql.contains("SELECT"), "output lost its SELECT: {sql}");
    }

    #[test]
    fn test_eliminate_ctes_used() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y");
        let sql = render(&eliminate_ctes(expr));
        assert!(sql.contains("WITH"), "used CTE must be preserved: {sql}");
    }

    #[test]
    fn test_is_cte_referenced_true() {
        let expr = parse("SELECT * FROM cte_name");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_false() {
        let expr = parse("SELECT * FROM other_table");
        assert!(!is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_join() {
        let expr = parse("SELECT * FROM x JOIN cte_name ON x.a = cte_name.a");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_false_in_join() {
        let expr = parse("SELECT * FROM x JOIN other ON x.a = other.a");
        assert!(!is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_subquery() {
        let expr = parse("SELECT * FROM x WHERE x.a IN (SELECT a FROM cte_name)");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_eliminate_preserves_structure() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y WHERE a > 1");
        let sql = render(&eliminate_ctes(expr));
        assert!(sql.contains("WHERE"), "WHERE clause was dropped: {sql}");
    }

    #[test]
    fn test_eliminate_multiple_ctes() {
        let expr = parse("WITH a AS (SELECT 1), b AS (SELECT 2) SELECT * FROM a");
        let sql = render(&eliminate_ctes(expr));
        // TODO: once removal is implemented, assert the unused CTE `b` is gone.
        assert!(sql.contains("WITH"), "used CTE `a` must be preserved: {sql}");
    }

    #[test]
    fn test_is_cte_referenced_in_union() {
        let expr = parse("SELECT * FROM x UNION SELECT * FROM cte_name");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_compute_ref_count() {
        let expr = parse("SELECT * FROM t");
        let root = build_scope(&expr);
        let counts = compute_ref_count(&root);
        // The root scope alone guarantees at least one entry.
        assert!(!counts.is_empty());
    }
}