polyglot-sql 0.3.3

SQL parsing, validating, formatting, and dialect translation library
Documentation
//! CTE Elimination Module
//!
//! This module provides functionality for removing unused CTEs
//! from SQL expressions.
//!
//! Ported from sqlglot's optimizer/eliminate_ctes.py

use std::collections::HashMap;

use crate::expressions::Expression;
use crate::scope::{build_scope, Scope};

/// Remove unused CTEs from an expression.
///
/// # Example
///
/// ```sql
/// -- Before:
/// WITH y AS (SELECT a FROM x) SELECT a FROM z
/// -- After:
/// SELECT a FROM z
/// ```
///
/// # Arguments
/// * `expression` - The expression to optimize
///
/// # Returns
/// The optimized expression with unused CTEs removed
pub fn eliminate_ctes(expression: Expression) -> Expression {
    let root = build_scope(&expression);

    // Compute reference counts for each scope
    let ref_count = compute_ref_count(&root);

    // Collect scopes to process (in reverse order)
    let scopes = collect_scopes(&root);

    // Track which CTEs to remove
    let mut ctes_to_remove: Vec<String> = Vec::new();

    for scope in scopes.iter().rev() {
        if scope.is_cte() {
            let scope_id = *scope as *const Scope as u64;
            let count = ref_count.get(&scope_id).copied().unwrap_or(0);

            if count == 0 {
                // This CTE is unused, mark for removal
                if let Some(name) = get_cte_name(scope) {
                    ctes_to_remove.push(name);
                }
            }
        }
    }

    // Remove the marked CTEs
    if ctes_to_remove.is_empty() {
        return expression;
    }

    remove_ctes(expression, &ctes_to_remove)
}

/// Compute reference counts for each scope
fn compute_ref_count(root: &Scope) -> HashMap<u64, usize> {
    let mut counts: HashMap<u64, usize> = HashMap::new();

    // Initialize all scopes with count 0
    for scope in collect_scopes(root) {
        let id = scope as *const Scope as u64;
        counts.insert(id, 0);
    }

    // Count references
    for scope in collect_scopes(root) {
        for (_name, source_info) in &scope.sources {
            // If this source references a CTE scope, increment its count
            // In a full implementation, we'd track which sources are CTEs
            let _ = source_info;
        }
    }

    counts
}

/// Collect all scopes from the tree
fn collect_scopes(root: &Scope) -> Vec<&Scope> {
    let mut result = vec![root];
    result.extend(root.subquery_scopes.iter().flat_map(|s| collect_scopes(s)));
    result.extend(
        root.derived_table_scopes
            .iter()
            .flat_map(|s| collect_scopes(s)),
    );
    result.extend(root.cte_scopes.iter().flat_map(|s| collect_scopes(s)));
    result.extend(root.union_scopes.iter().flat_map(|s| collect_scopes(s)));
    result
}

/// Get the CTE name from a scope
fn get_cte_name(scope: &Scope) -> Option<String> {
    // In a full implementation, we'd extract the CTE name from the scope's expression
    // For now, return None
    let _ = scope;
    None
}

/// Remove the specified CTEs from an expression
fn remove_ctes(expression: Expression, ctes_to_remove: &[String]) -> Expression {
    if ctes_to_remove.is_empty() {
        return expression;
    }

    // In a full implementation, we would:
    // 1. Find the WITH clause
    // 2. Remove the specified CTEs
    // 3. If WITH clause is empty, remove it entirely
    //
    // For now, return unchanged
    expression
}

/// Check if a CTE is referenced anywhere in the query
pub fn is_cte_referenced(expression: &Expression, cte_name: &str) -> bool {
    match expression {
        Expression::Table(table) => table.name.name == cte_name,
        Expression::Select(select) => {
            // Check FROM
            if let Some(ref from) = select.from {
                for expr in &from.expressions {
                    if is_cte_referenced(expr, cte_name) {
                        return true;
                    }
                }
            }
            // Check JOINs
            for join in &select.joins {
                if is_cte_referenced(&join.this, cte_name) {
                    return true;
                }
            }
            // Check subqueries in SELECT list
            for expr in &select.expressions {
                if is_cte_referenced(expr, cte_name) {
                    return true;
                }
            }
            // Check WHERE
            if let Some(ref where_clause) = select.where_clause {
                if is_cte_referenced(&where_clause.this, cte_name) {
                    return true;
                }
            }
            false
        }
        Expression::Subquery(subquery) => is_cte_referenced(&subquery.this, cte_name),
        Expression::Union(union) => {
            is_cte_referenced(&union.left, cte_name) || is_cte_referenced(&union.right, cte_name)
        }
        Expression::Intersect(intersect) => {
            is_cte_referenced(&intersect.left, cte_name)
                || is_cte_referenced(&intersect.right, cte_name)
        }
        Expression::Except(except) => {
            is_cte_referenced(&except.left, cte_name) || is_cte_referenced(&except.right, cte_name)
        }
        Expression::In(in_expr) => {
            if let Some(ref query) = in_expr.query {
                is_cte_referenced(query, cte_name)
            } else {
                false
            }
        }
        _ => false,
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::generator::Generator;
    use crate::parser::Parser;

    fn gen(expr: &Expression) -> String {
        Generator::new().generate(expr).unwrap()
    }

    fn parse(sql: &str) -> Expression {
        Parser::parse_sql(sql).expect("Failed to parse")[0].clone()
    }

    #[test]
    fn test_eliminate_ctes_unused() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM z");
        let result = eliminate_ctes(expr);
        let sql = gen(&result);
        // In a full implementation, the CTE would be removed
        assert!(sql.contains("SELECT"));
    }

    #[test]
    fn test_eliminate_ctes_used() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y");
        let result = eliminate_ctes(expr);
        let sql = gen(&result);
        // CTE is used, should be preserved
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_is_cte_referenced_true() {
        let expr = parse("SELECT * FROM cte_name");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_false() {
        let expr = parse("SELECT * FROM other_table");
        assert!(!is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_join() {
        let expr = parse("SELECT * FROM x JOIN cte_name ON x.a = cte_name.a");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_is_cte_referenced_in_subquery() {
        let expr = parse("SELECT * FROM x WHERE x.a IN (SELECT a FROM cte_name)");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_eliminate_preserves_structure() {
        let expr = parse("WITH y AS (SELECT a FROM x) SELECT a FROM y WHERE a > 1");
        let result = eliminate_ctes(expr);
        let sql = gen(&result);
        assert!(sql.contains("WHERE"));
    }

    #[test]
    fn test_eliminate_multiple_ctes() {
        let expr = parse("WITH a AS (SELECT 1), b AS (SELECT 2) SELECT * FROM a");
        let result = eliminate_ctes(expr);
        let sql = gen(&result);
        // In a full implementation, unused CTE 'b' would be removed
        assert!(sql.contains("WITH"));
    }

    #[test]
    fn test_is_cte_referenced_in_union() {
        let expr = parse("SELECT * FROM x UNION SELECT * FROM cte_name");
        assert!(is_cte_referenced(&expr, "cte_name"));
    }

    #[test]
    fn test_compute_ref_count() {
        let expr = parse("SELECT * FROM t");
        let root = build_scope(&expr);
        let counts = compute_ref_count(&root);
        // Should have at least one scope
        assert!(!counts.is_empty());
    }
}