polyglot_sql/optimizer/
eliminate_ctes.rs1use std::collections::HashMap;
9
10use crate::expressions::Expression;
11use crate::scope::{build_scope, Scope};
12
13pub fn eliminate_ctes(expression: Expression) -> Expression {
30 let root = build_scope(&expression);
31
32 let ref_count = compute_ref_count(&root);
34
35 let scopes = collect_scopes(&root);
37
38 let mut ctes_to_remove: Vec<String> = Vec::new();
40
41 for scope in scopes.iter().rev() {
42 if scope.is_cte() {
43 let scope_id = *scope as *const Scope as u64;
44 let count = ref_count.get(&scope_id).copied().unwrap_or(0);
45
46 if count == 0 {
47 if let Some(name) = get_cte_name(scope) {
49 ctes_to_remove.push(name);
50 }
51 }
52 }
53 }
54
55 if ctes_to_remove.is_empty() {
57 return expression;
58 }
59
60 remove_ctes(expression, &ctes_to_remove)
61}
62
63fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
65 let mut counts: HashMap<u64, usize> = HashMap::new();
66
67 for scope in collect_scopes(root) {
69 let id = scope as *const Scope as u64;
70 counts.insert(id, 0);
71 }
72
73 for scope in collect_scopes(root) {
75 for (_name, source_info) in &scope.sources {
76 let _ = source_info;
79 }
80 }
81
82 counts
83}
84
85fn collect_scopes(root: &Scope) -> Vec<&Scope> {
87 let mut result = vec![root];
88 result.extend(root.subquery_scopes.iter().flat_map(|s| collect_scopes(s)));
89 result.extend(root.derived_table_scopes.iter().flat_map(|s| collect_scopes(s)));
90 result.extend(root.cte_scopes.iter().flat_map(|s| collect_scopes(s)));
91 result.extend(root.union_scopes.iter().flat_map(|s| collect_scopes(s)));
92 result
93}
94
95fn get_cte_name(scope: &Scope) -> Option<String> {
97 let _ = scope;
100 None
101}
102
103fn remove_ctes(expression: Expression, ctes_to_remove: &[String]) -> Expression {
105 if ctes_to_remove.is_empty() {
106 return expression;
107 }
108
109 expression
116}
117
118pub fn is_cte_referenced(expression: &Expression, cte_name: &str) -> bool {
120 match expression {
121 Expression::Table(table) => {
122 table.name.name == cte_name
123 }
124 Expression::Select(select) => {
125 if let Some(ref from) = select.from {
127 for expr in &from.expressions {
128 if is_cte_referenced(expr, cte_name) {
129 return true;
130 }
131 }
132 }
133 for join in &select.joins {
135 if is_cte_referenced(&join.this, cte_name) {
136 return true;
137 }
138 }
139 for expr in &select.expressions {
141 if is_cte_referenced(expr, cte_name) {
142 return true;
143 }
144 }
145 if let Some(ref where_clause) = select.where_clause {
147 if is_cte_referenced(&where_clause.this, cte_name) {
148 return true;
149 }
150 }
151 false
152 }
153 Expression::Subquery(subquery) => {
154 is_cte_referenced(&subquery.this, cte_name)
155 }
156 Expression::Union(union) => {
157 is_cte_referenced(&union.left, cte_name) || is_cte_referenced(&union.right, cte_name)
158 }
159 Expression::Intersect(intersect) => {
160 is_cte_referenced(&intersect.left, cte_name) || is_cte_referenced(&intersect.right, cte_name)
161 }
162 Expression::Except(except) => {
163 is_cte_referenced(&except.left, cte_name) || is_cte_referenced(&except.right, cte_name)
164 }
165 Expression::In(in_expr) => {
166 if let Some(ref query) = in_expr.query {
167 is_cte_referenced(query, cte_name)
168 } else {
169 false
170 }
171 }
172 _ => false,
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL with the default generator.
    fn render(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns its first statement.
    fn parse_one(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    /// Parses `sql`, runs `eliminate_ctes` on it and renders the result.
    fn eliminate(sql: &str) -> String {
        render(&eliminate_ctes(parse_one(sql)))
    }

    #[test]
    fn test_eliminate_ctes_unused() {
        // NOTE(review): only checks that a query comes back; it does not assert
        // that the unused CTE `y` was actually removed.
        let out = eliminate("WITH y AS (SELECT a FROM x) SELECT a FROM z");
        assert!(out.contains("SELECT"));
    }

    #[test]
    fn test_eliminate_ctes_used() {
        let out = eliminate("WITH y AS (SELECT a FROM x) SELECT a FROM y");
        assert!(out.contains("WITH"));
    }

    #[test]
    fn test_is_cte_referenced_true() {
        assert!(is_cte_referenced(&parse_one("SELECT * FROM cte_name"), "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_false() {
        assert!(!is_cte_referenced(&parse_one("SELECT * FROM other_table"), "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_join() {
        let query = parse_one("SELECT * FROM x JOIN cte_name ON x.a = cte_name.a");
        assert!(is_cte_referenced(&query, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_subquery() {
        let query = parse_one("SELECT * FROM x WHERE x.a IN (SELECT a FROM cte_name)");
        assert!(is_cte_referenced(&query, "cte_name"));
    }

    #[test]
    fn test_eliminate_preserves_structure() {
        let out = eliminate("WITH y AS (SELECT a FROM x) SELECT a FROM y WHERE a > 1");
        assert!(out.contains("WHERE"));
    }

    #[test]
    fn test_eliminate_multiple_ctes() {
        // NOTE(review): does not check that the unused CTE `b` is gone.
        let out = eliminate("WITH a AS (SELECT 1), b AS (SELECT 2) SELECT * FROM a");
        assert!(out.contains("WITH"));
    }

    #[test]
    fn test_is_cte_referenced_in_union() {
        let query = parse_one("SELECT * FROM x UNION SELECT * FROM cte_name");
        assert!(is_cte_referenced(&query, "cte_name"));
    }

    #[test]
    fn test_compute_ref_count() {
        let query = parse_one("SELECT * FROM t");
        let root = build_scope(&query);
        assert!(!compute_ref_count(&root).is_empty());
    }
}