use std::collections::HashMap;
use crate::expressions::Expression;
use crate::scope::{build_scope, Scope};
/// Finds CTE scopes with a zero reference count and hands their names to
/// `remove_ctes`.
///
/// The scope tree is built from `expression` with `build_scope`, reference
/// counts come from `compute_ref_count`, and the flattened scope list from
/// `collect_scopes` is visited in reverse order. A scope contributes a name
/// only when `is_cte()` is true, its count is zero (a missing map entry is
/// treated as zero), and `get_cte_name` yields `Some`.
///
/// Returns `expression` untouched when no names were collected; otherwise
/// returns whatever `remove_ctes` produces for the collected names.
pub fn eliminate_ctes(expression: Expression) -> Expression {
    let root = build_scope(&expression);
    let ref_count = compute_ref_count(&root);

    let unreferenced: Vec<String> = collect_scopes(&root)
        .into_iter()
        .rev()
        .filter(|scope| scope.is_cte())
        .filter(|scope| {
            // Scopes are keyed by address, the same identity scheme
            // `compute_ref_count` uses when it fills the map.
            let scope_id = *scope as *const Scope as u64;
            ref_count.get(&scope_id).copied().unwrap_or(0) == 0
        })
        .filter_map(|scope| get_cte_name(scope))
        .collect();

    if unreferenced.is_empty() {
        expression
    } else {
        remove_ctes(expression, &unreferenced)
    }
}
/// Builds a map from scope identity to reference count for every scope
/// reachable from `root`.
///
/// The key is the scope's address cast to `u64`. Every scope starts at zero.
/// The pass over each scope's `sources` does not update any count yet, so all
/// returned counts are currently zero.
fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
    let all_scopes = collect_scopes(root);

    let counts: HashMap<u64, usize> = all_scopes
        .iter()
        .map(|scope| (*scope as *const Scope as u64, 0usize))
        .collect();

    // NOTE(review): the sources are walked but nothing is tallied; this looks
    // like the spot where references would be counted - confirm intent.
    for scope in &all_scopes {
        for (_name, _source_info) in &scope.sources {}
    }

    counts
}
/// Flattens the scope tree rooted at `root` into a list of references.
///
/// `root` comes first, followed by the recursively flattened children in this
/// fixed order: subquery scopes, derived-table scopes, CTE scopes, union
/// scopes.
fn collect_scopes(root: &Scope) -> Vec<&Scope> {
    let mut collected = vec![root];

    for child in root.subquery_scopes.iter() {
        collected.extend(collect_scopes(child));
    }
    for child in root.derived_table_scopes.iter() {
        collected.extend(collect_scopes(child));
    }
    for child in root.cte_scopes.iter() {
        collected.extend(collect_scopes(child));
    }
    for child in root.union_scopes.iter() {
        collected.extend(collect_scopes(child));
    }

    collected
}
/// Looks up the CTE name associated with `scope`.
///
/// The argument is currently ignored and `None` is returned for every input,
/// so callers never receive a name from this function as written.
fn get_cte_name(scope: &Scope) -> Option<String> {
    // Bind the argument so it counts as used while no real lookup exists.
    let _scope = scope;
    None
}
/// Returns `expression` with the CTEs named in `ctes_to_remove` taken out.
///
/// As written, no rewriting happens: the expression is handed back unchanged
/// whether or not `ctes_to_remove` is empty.
fn remove_ctes(expression: Expression, ctes_to_remove: &[String]) -> Expression {
    // The name list does not influence the result yet.
    let _ = ctes_to_remove;
    expression
}
/// Reports whether `expression` contains a table whose name equals
/// `cte_name`.
///
/// Only these shapes are inspected:
/// - `Table`: compares `table.name.name` with `cte_name` directly.
/// - `Select`: checks, in order, the `from` expressions, each join's `this`,
///   the `expressions` list, and the `where_clause`'s `this`.
/// - `Subquery`: recurses into `this`.
/// - `Union` / `Intersect` / `Except`: recurses into `left`, then `right`.
/// - `In`: recurses into `query` when one is present.
///
/// Every other variant yields `false` without being traversed.
pub fn is_cte_referenced(expression: &Expression, cte_name: &str) -> bool {
    match expression {
        Expression::Table(table) => table.name.name == cte_name,
        Expression::Select(select) => {
            // `||` keeps the original visiting order and stops at the first hit.
            select
                .from
                .iter()
                .flat_map(|from| from.expressions.iter())
                .any(|source| is_cte_referenced(source, cte_name))
                || select
                    .joins
                    .iter()
                    .any(|join| is_cte_referenced(&join.this, cte_name))
                || select
                    .expressions
                    .iter()
                    .any(|item| is_cte_referenced(item, cte_name))
                || select
                    .where_clause
                    .iter()
                    .any(|filter| is_cte_referenced(&filter.this, cte_name))
        }
        Expression::Subquery(subquery) => is_cte_referenced(&subquery.this, cte_name),
        Expression::Union(union) => {
            is_cte_referenced(&union.left, cte_name) || is_cte_referenced(&union.right, cte_name)
        }
        Expression::Intersect(intersect) => {
            is_cte_referenced(&intersect.left, cte_name)
                || is_cte_referenced(&intersect.right, cte_name)
        }
        Expression::Except(except) => {
            is_cte_referenced(&except.left, cte_name) || is_cte_referenced(&except.right, cte_name)
        }
        Expression::In(in_expr) => match &in_expr.query {
            Some(query) => is_cte_referenced(query, cte_name),
            None => false,
        },
        _ => false,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    /// Renders `expr` back to SQL text, panicking if generation fails.
    ///
    /// Deliberately not named `gen`: that identifier is a reserved keyword in
    /// the Rust 2024 edition and would stop this module from compiling there.
    fn to_sql(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    /// Parses `sql` and returns a clone of the first parsed statement.
    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_eliminate_ctes_unused() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM z");
        let result = eliminate_ctes(expr);
        let sql = to_sql(&result);
        // Only checks that a SELECT survives; it does not assert that the
        // unused CTE was dropped.
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_eliminate_ctes_used() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y");
        let result = eliminate_ctes(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_is_cte_referenced_true() {
        let expr = parse("SELECT * FROM cte_name");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_false() {
        let expr = parse("SELECT * FROM other_table");
        assert!(!is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_join() {
        let expr = parse("SELECT * FROM x JOIN cte_name ON x.a = cte_name.a");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_subquery() {
        let expr = parse("SELECT * FROM x WHERE x.a IN (SELECT a FROM cte_name)");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_eliminate_preserves_structure() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y WHERE a > 1");
        let result = eliminate_ctes(expr);
        let sql = to_sql(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_eliminate_multiple_ctes() {
        let expr = parse("WITH a AS (SELECT 1), b AS (SELECT 2) SELECT * FROM a");
        let result = eliminate_ctes(expr);
        let sql = to_sql(&result);
        // `a` is still referenced, so a WITH clause must remain.
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_is_cte_referenced_in_union() {
        let expr = parse("SELECT * FROM x UNION SELECT * FROM cte_name");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_compute_ref_count() {
        let expr = parse("SELECT * FROM t");
        let root = build_scope(&expr);
        let counts = compute_ref_count(&root);
        // The root scope itself always gets an entry.
        assert!(!counts.is_empty());
    }
}