polyglot_sql/optimizer/
eliminate_ctes.rs1use std::collections::HashMap;
9
10use crate::expressions::Expression;
11use crate::scope::{build_scope, Scope};
12
13pub fn eliminate_ctes(expression: Expression) -> Expression {
30 let root = build_scope(&expression);
31
32 let ref_count = compute_ref_count(&root);
34
35 let scopes = collect_scopes(&root);
37
38 let mut ctes_to_remove: Vec<String> = Vec::new();
40
41 for scope in scopes.iter().rev() {
42 if scope.is_cte() {
43 let scope_id = *scope as *const Scope as u64;
44 let count = ref_count.get(&scope_id).copied().unwrap_or(0);
45
46 if count == 0 {
47 if let Some(name) = get_cte_name(scope) {
49 ctes_to_remove.push(name);
50 }
51 }
52 }
53 }
54
55 if ctes_to_remove.is_empty() {
57 return expression;
58 }
59
60 remove_ctes(expression, &ctes_to_remove)
61}
62
63fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
65 let mut counts: HashMap<u64, usize> = HashMap::new();
66
67 for scope in collect_scopes(root) {
69 let id = scope as *const Scope as u64;
70 counts.insert(id, 0);
71 }
72
73 for scope in collect_scopes(root) {
75 for (_name, source_info) in &scope.sources {
76 let _ = source_info;
79 }
80 }
81
82 counts
83}
84
85fn collect_scopes(root: &Scope) -> Vec<&Scope> {
87 let mut result = vec![root];
88 result.extend(root.subquery_scopes.iter().flat_map(|s| collect_scopes(s)));
89 result.extend(
90 root.derived_table_scopes
91 .iter()
92 .flat_map(|s| collect_scopes(s)),
93 );
94 result.extend(root.cte_scopes.iter().flat_map(|s| collect_scopes(s)));
95 result.extend(root.union_scopes.iter().flat_map(|s| collect_scopes(s)));
96 result
97}
98
99fn get_cte_name(scope: &Scope) -> Option<String> {
101 let _ = scope;
104 None
105}
106
107fn remove_ctes(expression: Expression, ctes_to_remove: &[String]) -> Expression {
109 if ctes_to_remove.is_empty() {
110 return expression;
111 }
112
113 expression
120}
121
122pub fn is_cte_referenced(expression: &Expression, cte_name: &str) -> bool {
124 match expression {
125 Expression::Table(table) => table.name.name == cte_name,
126 Expression::Select(select) => {
127 if let Some(ref from) = select.from {
129 for expr in &from.expressions {
130 if is_cte_referenced(expr, cte_name) {
131 return true;
132 }
133 }
134 }
135 for join in &select.joins {
137 if is_cte_referenced(&join.this, cte_name) {
138 return true;
139 }
140 }
141 for expr in &select.expressions {
143 if is_cte_referenced(expr, cte_name) {
144 return true;
145 }
146 }
147 if let Some(ref where_clause) = select.where_clause {
149 if is_cte_referenced(&where_clause.this, cte_name) {
150 return true;
151 }
152 }
153 false
154 }
155 Expression::Subquery(subquery) => is_cte_referenced(&subquery.this, cte_name),
156 Expression::Union(union) => {
157 is_cte_referenced(&union.left, cte_name) || is_cte_referenced(&union.right, cte_name)
158 }
159 Expression::Intersect(intersect) => {
160 is_cte_referenced(&intersect.left, cte_name)
161 || is_cte_referenced(&intersect.right, cte_name)
162 }
163 Expression::Except(except) => {
164 is_cte_referenced(&except.left, cte_name) || is_cte_referenced(&except.right, cte_name)
165 }
166 Expression::In(in_expr) => {
167 if let Some(ref query) = in_expr.query {
168 is_cte_referenced(query, cte_name)
169 } else {
170 false
171 }
172 }
173 _ => false,
174 }
175}
176
177#[cfg(test)]
178mod tests {
179 use super::*;
180 use crate::generator::Generator;
181 use crate::parser::Parser;
182
183 fn gen(expr: &Expression) -> String {
184 Generator::new().generate(expr).unwrap()
185 }
186
187 fn parse(sql: &str) -> Expression {
188 Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
189 }
190
191 #[test]
192 fn test_eliminate_ctes_unused() {
193 let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM z");
194 let result = eliminate_ctes(expr);
195 let sql = gen(&result);
196 assert!(sql.contains("SELECT"));
198 }
199
200 #[test]
201 fn test_eliminate_ctes_used() {
202 let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y");
203 let result = eliminate_ctes(expr);
204 let sql = gen(&result);
205 assert!(sql.contains("WITH"));
207 }
208
209 #[test]
210 fn test_is_cte_referenced_true() {
211 let expr = parse("SELECT * FROM cte_name");
212 assert!(is_cte_referenced(&expr, "cte_name"));
213 }
214
215 #[test]
216 fn test_is_cte_referenced_false() {
217 let expr = parse("SELECT * FROM other_table");
218 assert!(!is_cte_referenced(&expr, "cte_name"));
219 }
220
221 #[test]
222 fn test_is_cte_referenced_in_join() {
223 let expr = parse("SELECT * FROM x JOIN cte_name ON x.a = cte_name.a");
224 assert!(is_cte_referenced(&expr, "cte_name"));
225 }
226
227 #[test]
228 fn test_is_cte_referenced_in_subquery() {
229 let expr = parse("SELECT * FROM x WHERE x.a IN (SELECT a FROM cte_name)");
230 assert!(is_cte_referenced(&expr, "cte_name"));
231 }
232
233 #[test]
234 fn test_eliminate_preserves_structure() {
235 let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y WHERE a > 1");
236 let result = eliminate_ctes(expr);
237 let sql = gen(&result);
238 assert!(sql.contains("WHERE"));
239 }
240
241 #[test]
242 fn test_eliminate_multiple_ctes() {
243 let expr = parse("WITH a AS (SELECT 1), b AS (SELECT 2) SELECT * FROM a");
244 let result = eliminate_ctes(expr);
245 let sql = gen(&result);
246 assert!(sql.contains("WITH"));
248 }
249
250 #[test]
251 fn test_is_cte_referenced_in_union() {
252 let expr = parse("SELECT * FROM x UNION SELECT * FROM cte_name");
253 assert!(is_cte_referenced(&expr, "cte_name"));
254 }
255
256 #[test]
257 fn test_compute_ref_count() {
258 let expr = parse("SELECT * FROM t");
259 let root = build_scope(&expr);
260 let counts = compute_ref_count(&root);
261 assert!(!counts.is_empty());
263 }
264}