Skip to main content

polyglot_sql/
scope.rs

1//! Scope Analysis Module
2//!
3//! This module provides scope analysis for SQL queries, enabling detection of
4//! correlated subqueries, column references, and scope relationships.
5//!
6//! Ported from sqlglot's optimizer/scope.py
7
8use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
/// Type of scope in a SQL query.
///
/// Records how a scope relates to the scope it was branched from. The value is
/// fixed when the scope is created (`Scope::new` always yields `Root`;
/// `Scope::branch` takes the type as an argument).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "bindings", derive(TS))]
#[cfg_attr(feature = "bindings", ts(export))]
pub enum ScopeType {
    /// Root scope of the query
    Root,
    /// Subquery scope (e.g., WHERE x IN (SELECT ...))
    Subquery,
    /// Derived table scope (e.g., FROM (SELECT ...) AS t)
    DerivedTable,
    /// Common Table Expression scope
    Cte,
    /// Union/Intersect/Except scope
    SetOperation,
    /// User-Defined Table Function scope
    Udtf,
}
32
/// Information about a source (table or subquery) in a scope.
///
/// Each entry owns its own copy of the source expression.
#[derive(Debug, Clone)]
pub struct SourceInfo {
    /// The source expression (Table or subquery)
    pub expression: Expression,
    /// Whether this source is a scope (vs. a plain table).
    /// CTE sources and derived tables are registered with `true`.
    pub is_scope: bool,
}
41
/// A column reference found in a scope.
///
/// Two references are equal (and hash identically) when both the qualifier and
/// the column name match exactly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// The table/alias qualifier (if any)
    pub table: Option<String>,
    /// The column name
    pub name: String,
}
50
/// Represents a scope in a SQL query
///
/// A scope is the context of a SELECT statement and its sources.
/// Scopes can be nested (subqueries, CTEs, derived tables) and form a tree.
///
/// Child scopes are owned by their parent; there is no back-pointer from a
/// child to its parent.
#[derive(Debug, Clone)]
pub struct Scope {
    /// The expression at the root of this scope
    pub expression: Expression,

    /// Type of this scope relative to its parent
    pub scope_type: ScopeType,

    /// Mapping of source names to their info
    pub sources: HashMap<String, SourceInfo>,

    /// Sources from LATERAL views (have access to preceding sources)
    pub lateral_sources: HashMap<String, SourceInfo>,

    /// CTE sources available to this scope.
    /// Child scopes receive a copy of this map when they are branched.
    pub cte_sources: HashMap<String, SourceInfo>,

    /// If this is a derived table or CTE with alias columns, this is that list
    /// e.g., `SELECT * FROM (SELECT ...) AS y(col1, col2)` => ["col1", "col2"]
    pub outer_columns: Vec<String>,

    /// Whether this scope can potentially be correlated
    /// (true for subqueries and UDTFs, and inherited from the parent scope)
    pub can_be_correlated: bool,

    /// Child subquery scopes
    pub subquery_scopes: Vec<Scope>,

    /// Child derived table scopes
    pub derived_table_scopes: Vec<Scope>,

    /// Child CTE scopes
    pub cte_scopes: Vec<Scope>,

    /// Child UDTF (User Defined Table Function) scopes
    pub udtf_scopes: Vec<Scope>,

    /// Combined derived_table_scopes + udtf_scopes in definition order
    pub table_scopes: Vec<Scope>,

    /// Union/set operation scopes (left and right)
    pub union_scopes: Vec<Scope>,

    /// Cached columns; filled lazily by `columns()` and reset by `clear_cache()`
    columns_cache: Option<Vec<ColumnRef>>,

    /// Cached external columns; filled lazily by `external_columns()` and
    /// reset by `clear_cache()`
    external_columns_cache: Option<Vec<ColumnRef>>,
}
104
105impl Scope {
106    /// Create a new root scope
107    pub fn new(expression: Expression) -> Self {
108        Self {
109            expression,
110            scope_type: ScopeType::Root,
111            sources: HashMap::new(),
112            lateral_sources: HashMap::new(),
113            cte_sources: HashMap::new(),
114            outer_columns: Vec::new(),
115            can_be_correlated: false,
116            subquery_scopes: Vec::new(),
117            derived_table_scopes: Vec::new(),
118            cte_scopes: Vec::new(),
119            udtf_scopes: Vec::new(),
120            table_scopes: Vec::new(),
121            union_scopes: Vec::new(),
122            columns_cache: None,
123            external_columns_cache: None,
124        }
125    }
126
127    /// Create a child scope branching from this one
128    pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
129        self.branch_with_options(expression, scope_type, None, None, None)
130    }
131
132    /// Create a child scope with additional options
133    pub fn branch_with_options(
134        &self,
135        expression: Expression,
136        scope_type: ScopeType,
137        sources: Option<HashMap<String, SourceInfo>>,
138        lateral_sources: Option<HashMap<String, SourceInfo>>,
139        outer_columns: Option<Vec<String>>,
140    ) -> Self {
141        let can_be_correlated = self.can_be_correlated
142            || scope_type == ScopeType::Subquery
143            || scope_type == ScopeType::Udtf;
144
145        Self {
146            expression,
147            scope_type,
148            sources: sources.unwrap_or_default(),
149            lateral_sources: lateral_sources.unwrap_or_default(),
150            cte_sources: self.cte_sources.clone(),
151            outer_columns: outer_columns.unwrap_or_default(),
152            can_be_correlated,
153            subquery_scopes: Vec::new(),
154            derived_table_scopes: Vec::new(),
155            cte_scopes: Vec::new(),
156            udtf_scopes: Vec::new(),
157            table_scopes: Vec::new(),
158            union_scopes: Vec::new(),
159            columns_cache: None,
160            external_columns_cache: None,
161        }
162    }
163
164    /// Clear all cached properties
165    pub fn clear_cache(&mut self) {
166        self.columns_cache = None;
167        self.external_columns_cache = None;
168    }
169
170    /// Add a source to this scope
171    pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
172        self.sources.insert(
173            name,
174            SourceInfo {
175                expression,
176                is_scope,
177            },
178        );
179        self.clear_cache();
180    }
181
182    /// Add a lateral source to this scope
183    pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
184        self.lateral_sources.insert(
185            name.clone(),
186            SourceInfo {
187                expression: expression.clone(),
188                is_scope,
189            },
190        );
191        self.sources.insert(
192            name,
193            SourceInfo {
194                expression,
195                is_scope,
196            },
197        );
198        self.clear_cache();
199    }
200
201    /// Add a CTE source to this scope
202    pub fn add_cte_source(&mut self, name: String, expression: Expression) {
203        self.cte_sources.insert(
204            name.clone(),
205            SourceInfo {
206                expression: expression.clone(),
207                is_scope: true,
208            },
209        );
210        self.sources.insert(
211            name,
212            SourceInfo {
213                expression,
214                is_scope: true,
215            },
216        );
217        self.clear_cache();
218    }
219
220    /// Rename a source
221    pub fn rename_source(&mut self, old_name: &str, new_name: String) {
222        if let Some(source) = self.sources.remove(old_name) {
223            self.sources.insert(new_name, source);
224        }
225        self.clear_cache();
226    }
227
228    /// Remove a source
229    pub fn remove_source(&mut self, name: &str) {
230        self.sources.remove(name);
231        self.clear_cache();
232    }
233
234    /// Collect all column references in this scope
235    pub fn columns(&mut self) -> &[ColumnRef] {
236        if self.columns_cache.is_none() {
237            let mut columns = Vec::new();
238            collect_columns(&self.expression, &mut columns);
239            self.columns_cache = Some(columns);
240        }
241        self.columns_cache.as_ref().unwrap()
242    }
243
244    /// Collect projected output column names for this scope's query expression.
245    ///
246    /// This is intended for result schema style output columns (e.g. UNION
247    /// outputs), unlike [`Self::columns`], which returns raw referenced columns.
248    pub fn output_columns(&self) -> Vec<String> {
249        crate::ast_transforms::get_output_column_names(&self.expression)
250    }
251
252    /// Get all source names in this scope
253    pub fn source_names(&self) -> HashSet<String> {
254        let mut names: HashSet<String> = self.sources.keys().cloned().collect();
255        names.extend(self.cte_sources.keys().cloned());
256        names
257    }
258
259    /// Get columns that reference sources outside this scope
260    pub fn external_columns(&mut self) -> Vec<ColumnRef> {
261        if self.external_columns_cache.is_some() {
262            return self.external_columns_cache.clone().unwrap();
263        }
264
265        let source_names = self.source_names();
266        let columns = self.columns().to_vec();
267
268        let external: Vec<ColumnRef> = columns
269            .into_iter()
270            .filter(|col| {
271                // A column is external if it has a table qualifier that's not in our sources
272                match &col.table {
273                    Some(table) => !source_names.contains(table),
274                    None => false, // Unqualified columns might be local
275                }
276            })
277            .collect();
278
279        self.external_columns_cache = Some(external.clone());
280        external
281    }
282
283    /// Get columns that reference sources in this scope (not external)
284    pub fn local_columns(&mut self) -> Vec<ColumnRef> {
285        let external_set: HashSet<_> = self.external_columns().into_iter().collect();
286        let columns = self.columns().to_vec();
287
288        columns
289            .into_iter()
290            .filter(|col| !external_set.contains(col))
291            .collect()
292    }
293
294    /// Get unqualified columns (columns without table qualifier)
295    pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
296        self.columns()
297            .iter()
298            .filter(|c| c.table.is_none())
299            .cloned()
300            .collect()
301    }
302
303    /// Get columns for a specific source
304    pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
305        self.columns()
306            .iter()
307            .filter(|col| col.table.as_deref() == Some(source_name))
308            .cloned()
309            .collect()
310    }
311
312    /// Determine if this scope is a correlated subquery
313    ///
314    /// A subquery is correlated if:
315    /// 1. It can be correlated (is a subquery or UDTF), AND
316    /// 2. It references columns from outer scopes
317    pub fn is_correlated_subquery(&mut self) -> bool {
318        self.can_be_correlated && !self.external_columns().is_empty()
319    }
320
321    /// Check if this is a subquery scope
322    pub fn is_subquery(&self) -> bool {
323        self.scope_type == ScopeType::Subquery
324    }
325
326    /// Check if this is a derived table scope
327    pub fn is_derived_table(&self) -> bool {
328        self.scope_type == ScopeType::DerivedTable
329    }
330
331    /// Check if this is a CTE scope
332    pub fn is_cte(&self) -> bool {
333        self.scope_type == ScopeType::Cte
334    }
335
336    /// Check if this is the root scope
337    pub fn is_root(&self) -> bool {
338        self.scope_type == ScopeType::Root
339    }
340
341    /// Check if this is a UDTF scope
342    pub fn is_udtf(&self) -> bool {
343        self.scope_type == ScopeType::Udtf
344    }
345
346    /// Check if this is a union/set operation scope
347    pub fn is_union(&self) -> bool {
348        self.scope_type == ScopeType::SetOperation
349    }
350
351    /// Traverse all scopes in this tree (depth-first post-order)
352    pub fn traverse(&self) -> Vec<&Scope> {
353        let mut result = Vec::new();
354        self.traverse_impl(&mut result);
355        result
356    }
357
358    fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
359        // First traverse children
360        for scope in &self.cte_scopes {
361            scope.traverse_impl(result);
362        }
363        for scope in &self.union_scopes {
364            scope.traverse_impl(result);
365        }
366        for scope in &self.table_scopes {
367            scope.traverse_impl(result);
368        }
369        for scope in &self.subquery_scopes {
370            scope.traverse_impl(result);
371        }
372        // Then add self
373        result.push(self);
374    }
375
376    /// Count references to each scope in this tree
377    pub fn ref_count(&self) -> HashMap<usize, usize> {
378        let mut counts: HashMap<usize, usize> = HashMap::new();
379
380        for scope in self.traverse() {
381            for (_, source_info) in scope.sources.iter() {
382                if source_info.is_scope {
383                    let id = &source_info.expression as *const _ as usize;
384                    *counts.entry(id).or_insert(0) += 1;
385                }
386            }
387        }
388
389        counts
390    }
391}
392
393/// Collect all column references from an expression tree
394fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
395    match expr {
396        Expression::Column(col) => {
397            columns.push(ColumnRef {
398                table: col.table.as_ref().map(|t| t.name.clone()),
399                name: col.name.name.clone(),
400            });
401        }
402        Expression::Select(select) => {
403            // Collect from SELECT expressions
404            for e in &select.expressions {
405                collect_columns(e, columns);
406            }
407            // Collect from JOIN ON / MATCH_CONDITION clauses
408            for join in &select.joins {
409                if let Some(on) = &join.on {
410                    collect_columns(on, columns);
411                }
412                if let Some(match_condition) = &join.match_condition {
413                    collect_columns(match_condition, columns);
414                }
415            }
416            // Collect from WHERE
417            if let Some(where_clause) = &select.where_clause {
418                collect_columns(&where_clause.this, columns);
419            }
420            // Collect from HAVING
421            if let Some(having) = &select.having {
422                collect_columns(&having.this, columns);
423            }
424            // Collect from ORDER BY
425            if let Some(order_by) = &select.order_by {
426                for ord in &order_by.expressions {
427                    collect_columns(&ord.this, columns);
428                }
429            }
430            // Collect from GROUP BY
431            if let Some(group_by) = &select.group_by {
432                for e in &group_by.expressions {
433                    collect_columns(e, columns);
434                }
435            }
436            // Note: We don't recurse into FROM/JOIN source subqueries here
437            // as those create their own scopes.
438        }
439        // Binary operations
440        Expression::And(bin)
441        | Expression::Or(bin)
442        | Expression::Add(bin)
443        | Expression::Sub(bin)
444        | Expression::Mul(bin)
445        | Expression::Div(bin)
446        | Expression::Mod(bin)
447        | Expression::Eq(bin)
448        | Expression::Neq(bin)
449        | Expression::Lt(bin)
450        | Expression::Lte(bin)
451        | Expression::Gt(bin)
452        | Expression::Gte(bin)
453        | Expression::BitwiseAnd(bin)
454        | Expression::BitwiseOr(bin)
455        | Expression::BitwiseXor(bin)
456        | Expression::Concat(bin) => {
457            collect_columns(&bin.left, columns);
458            collect_columns(&bin.right, columns);
459        }
460        // LIKE/ILIKE operations
461        Expression::Like(like) | Expression::ILike(like) => {
462            collect_columns(&like.left, columns);
463            collect_columns(&like.right, columns);
464            if let Some(escape) = &like.escape {
465                collect_columns(escape, columns);
466            }
467        }
468        // Unary operations
469        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
470            collect_columns(&un.this, columns);
471        }
472        Expression::Function(func) => {
473            for arg in &func.args {
474                collect_columns(arg, columns);
475            }
476        }
477        Expression::AggregateFunction(agg) => {
478            for arg in &agg.args {
479                collect_columns(arg, columns);
480            }
481        }
482        Expression::WindowFunction(wf) => {
483            collect_columns(&wf.this, columns);
484            for e in &wf.over.partition_by {
485                collect_columns(e, columns);
486            }
487            for e in &wf.over.order_by {
488                collect_columns(&e.this, columns);
489            }
490        }
491        Expression::Alias(alias) => {
492            collect_columns(&alias.this, columns);
493        }
494        Expression::Case(case) => {
495            if let Some(operand) = &case.operand {
496                collect_columns(operand, columns);
497            }
498            for (when_expr, then_expr) in &case.whens {
499                collect_columns(when_expr, columns);
500                collect_columns(then_expr, columns);
501            }
502            if let Some(else_clause) = &case.else_ {
503                collect_columns(else_clause, columns);
504            }
505        }
506        Expression::Paren(paren) => {
507            collect_columns(&paren.this, columns);
508        }
509        Expression::Ordered(ord) => {
510            collect_columns(&ord.this, columns);
511        }
512        Expression::In(in_expr) => {
513            collect_columns(&in_expr.this, columns);
514            for e in &in_expr.expressions {
515                collect_columns(e, columns);
516            }
517            // Note: in_expr.query is a subquery - creates its own scope
518        }
519        Expression::Between(between) => {
520            collect_columns(&between.this, columns);
521            collect_columns(&between.low, columns);
522            collect_columns(&between.high, columns);
523        }
524        Expression::IsNull(is_null) => {
525            collect_columns(&is_null.this, columns);
526        }
527        Expression::Cast(cast) => {
528            collect_columns(&cast.this, columns);
529        }
530        Expression::Extract(extract) => {
531            collect_columns(&extract.this, columns);
532        }
533        Expression::Exists(_) | Expression::Subquery(_) => {
534            // These create their own scopes - don't collect from here
535        }
536        _ => {
537            // For other expressions, we might need to add more cases
538        }
539    }
540}
541
542/// Build scope tree from an expression
543///
544/// This traverses the expression tree and builds a hierarchy of Scope objects
545/// that track sources and column references at each level.
546pub fn build_scope(expression: &Expression) -> Scope {
547    let mut root = Scope::new(expression.clone());
548    build_scope_impl(expression, &mut root);
549    root
550}
551
552fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
553    match expression {
554        Expression::Select(select) => {
555            // Process CTEs first
556            if let Some(with) = &select.with {
557                for cte in &with.ctes {
558                    let cte_name = cte.alias.name.clone();
559                    let mut cte_scope = current_scope
560                        .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
561                    build_scope_impl(&cte.this, &mut cte_scope);
562                    current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
563                    current_scope.cte_scopes.push(cte_scope);
564                }
565            }
566
567            // Process FROM clause
568            if let Some(from) = &select.from {
569                for table in &from.expressions {
570                    add_table_to_scope(table, current_scope);
571                }
572            }
573
574            // Process JOINs
575            for join in &select.joins {
576                add_table_to_scope(&join.this, current_scope);
577            }
578
579            // Process subqueries in WHERE, SELECT expressions, etc.
580            collect_subqueries(expression, current_scope);
581        }
582        Expression::Union(union) => {
583            let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
584            build_scope_impl(&union.left, &mut left_scope);
585
586            let mut right_scope =
587                current_scope.branch(union.right.clone(), ScopeType::SetOperation);
588            build_scope_impl(&union.right, &mut right_scope);
589
590            current_scope.union_scopes.push(left_scope);
591            current_scope.union_scopes.push(right_scope);
592        }
593        Expression::Intersect(intersect) => {
594            let mut left_scope =
595                current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
596            build_scope_impl(&intersect.left, &mut left_scope);
597
598            let mut right_scope =
599                current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
600            build_scope_impl(&intersect.right, &mut right_scope);
601
602            current_scope.union_scopes.push(left_scope);
603            current_scope.union_scopes.push(right_scope);
604        }
605        Expression::Except(except) => {
606            let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
607            build_scope_impl(&except.left, &mut left_scope);
608
609            let mut right_scope =
610                current_scope.branch(except.right.clone(), ScopeType::SetOperation);
611            build_scope_impl(&except.right, &mut right_scope);
612
613            current_scope.union_scopes.push(left_scope);
614            current_scope.union_scopes.push(right_scope);
615        }
616        Expression::CreateTable(create) => {
617            // Handle CREATE TABLE ... AS [WITH ...] SELECT ...
618            // Process CTEs if present
619            if let Some(with) = &create.with_cte {
620                for cte in &with.ctes {
621                    let cte_name = cte.alias.name.clone();
622                    let mut cte_scope = current_scope
623                        .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
624                    build_scope_impl(&cte.this, &mut cte_scope);
625                    current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
626                    current_scope.cte_scopes.push(cte_scope);
627                }
628            }
629            // Traverse the AS SELECT body
630            if let Some(as_select) = &create.as_select {
631                build_scope_impl(as_select, current_scope);
632            }
633        }
634        _ => {}
635    }
636}
637
638fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
639    match expr {
640        Expression::Table(table) => {
641            let name = table
642                .alias
643                .as_ref()
644                .map(|a| a.name.clone())
645                .unwrap_or_else(|| table.name.name.clone());
646            let cte_source = if table.schema.is_none() && table.catalog.is_none() {
647                scope.cte_sources.get(&table.name.name).or_else(|| {
648                    scope
649                        .cte_sources
650                        .iter()
651                        .find(|(cte_name, _)| cte_name.eq_ignore_ascii_case(&table.name.name))
652                        .map(|(_, source)| source)
653                })
654            } else {
655                None
656            };
657
658            if let Some(source) = cte_source {
659                scope.add_source(name, source.expression.clone(), true);
660            } else {
661                scope.add_source(name, expr.clone(), false);
662            }
663        }
664        Expression::Subquery(subquery) => {
665            let name = subquery
666                .alias
667                .as_ref()
668                .map(|a| a.name.clone())
669                .unwrap_or_default();
670
671            let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
672            build_scope_impl(&subquery.this, &mut derived_scope);
673
674            scope.add_source(name.clone(), expr.clone(), true);
675            scope.derived_table_scopes.push(derived_scope);
676        }
677        Expression::Paren(paren) => {
678            add_table_to_scope(&paren.this, scope);
679        }
680        _ => {}
681    }
682}
683
684fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
685    match expr {
686        Expression::Select(select) => {
687            // Check WHERE for subqueries
688            if let Some(where_clause) = &select.where_clause {
689                collect_subqueries_in_expr(&where_clause.this, parent_scope);
690            }
691            // Check SELECT expressions for subqueries
692            for e in &select.expressions {
693                collect_subqueries_in_expr(e, parent_scope);
694            }
695            // Check HAVING for subqueries
696            if let Some(having) = &select.having {
697                collect_subqueries_in_expr(&having.this, parent_scope);
698            }
699        }
700        _ => {}
701    }
702}
703
704fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
705    match expr {
706        Expression::Subquery(subquery) if subquery.alias.is_none() => {
707            // This is a scalar subquery or IN subquery (not a derived table)
708            let mut sub_scope = parent_scope.branch(subquery.this.clone(), ScopeType::Subquery);
709            build_scope_impl(&subquery.this, &mut sub_scope);
710            parent_scope.subquery_scopes.push(sub_scope);
711        }
712        Expression::In(in_expr) => {
713            collect_subqueries_in_expr(&in_expr.this, parent_scope);
714            if let Some(query) = &in_expr.query {
715                let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
716                build_scope_impl(query, &mut sub_scope);
717                parent_scope.subquery_scopes.push(sub_scope);
718            }
719        }
720        Expression::Exists(exists) => {
721            let mut sub_scope = parent_scope.branch(exists.this.clone(), ScopeType::Subquery);
722            build_scope_impl(&exists.this, &mut sub_scope);
723            parent_scope.subquery_scopes.push(sub_scope);
724        }
725        // Binary operations
726        Expression::And(bin)
727        | Expression::Or(bin)
728        | Expression::Add(bin)
729        | Expression::Sub(bin)
730        | Expression::Mul(bin)
731        | Expression::Div(bin)
732        | Expression::Mod(bin)
733        | Expression::Eq(bin)
734        | Expression::Neq(bin)
735        | Expression::Lt(bin)
736        | Expression::Lte(bin)
737        | Expression::Gt(bin)
738        | Expression::Gte(bin)
739        | Expression::BitwiseAnd(bin)
740        | Expression::BitwiseOr(bin)
741        | Expression::BitwiseXor(bin)
742        | Expression::Concat(bin) => {
743            collect_subqueries_in_expr(&bin.left, parent_scope);
744            collect_subqueries_in_expr(&bin.right, parent_scope);
745        }
746        // LIKE/ILIKE operations (have different structure with escape)
747        Expression::Like(like) | Expression::ILike(like) => {
748            collect_subqueries_in_expr(&like.left, parent_scope);
749            collect_subqueries_in_expr(&like.right, parent_scope);
750            if let Some(escape) = &like.escape {
751                collect_subqueries_in_expr(escape, parent_scope);
752            }
753        }
754        // Unary operations
755        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
756            collect_subqueries_in_expr(&un.this, parent_scope);
757        }
758        Expression::Function(func) => {
759            for arg in &func.args {
760                collect_subqueries_in_expr(arg, parent_scope);
761            }
762        }
763        Expression::Case(case) => {
764            if let Some(operand) = &case.operand {
765                collect_subqueries_in_expr(operand, parent_scope);
766            }
767            for (when_expr, then_expr) in &case.whens {
768                collect_subqueries_in_expr(when_expr, parent_scope);
769                collect_subqueries_in_expr(then_expr, parent_scope);
770            }
771            if let Some(else_clause) = &case.else_ {
772                collect_subqueries_in_expr(else_clause, parent_scope);
773            }
774        }
775        Expression::Paren(paren) => {
776            collect_subqueries_in_expr(&paren.this, parent_scope);
777        }
778        Expression::Alias(alias) => {
779            collect_subqueries_in_expr(&alias.this, parent_scope);
780        }
781        _ => {}
782    }
783}
784
785/// Walk within a scope, yielding expressions without crossing scope boundaries.
786///
787/// This iterator visits all nodes in the syntax tree, stopping at nodes that
788/// start child scopes (CTEs, derived tables, subqueries in FROM/JOIN).
789///
790/// # Arguments
791/// * `expression` - The expression to walk
792/// * `bfs` - If true, uses breadth-first search; otherwise uses depth-first search
793///
794/// # Returns
795/// An iterator over expressions within the scope
796pub fn walk_in_scope<'a>(
797    expression: &'a Expression,
798    bfs: bool,
799) -> impl Iterator<Item = &'a Expression> {
800    WalkInScopeIter::new(expression, bfs)
801}
802
/// Iterator for walking within a scope.
///
/// Created by `walk_in_scope`; borrows the expression tree for `'a`.
struct WalkInScopeIter<'a> {
    /// Expressions still waiting to be visited; seeded with the root expression
    queue: VecDeque<&'a Expression>,
    /// Traversal order flag: breadth-first when `true`, depth-first otherwise
    /// (presumably decides which end of `queue` is consumed — the `next`
    /// implementation is outside this view; TODO confirm)
    bfs: bool,
}
808
809impl<'a> WalkInScopeIter<'a> {
810    fn new(expression: &'a Expression, bfs: bool) -> Self {
811        let mut queue = VecDeque::new();
812        queue.push_back(expression);
813        Self { queue, bfs }
814    }
815
816    fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
817        if is_root {
818            return false;
819        }
820
821        // Stop at CTE definitions
822        if matches!(expr, Expression::Cte(_)) {
823            return true;
824        }
825
826        // Stop at subqueries that are derived tables (in FROM/JOIN)
827        if let Expression::Subquery(subquery) = expr {
828            if subquery.alias.is_some() {
829                return true;
830            }
831        }
832
833        // Stop at standalone SELECT/UNION/etc that would be subqueries
834        if matches!(
835            expr,
836            Expression::Select(_)
837                | Expression::Union(_)
838                | Expression::Intersect(_)
839                | Expression::Except(_)
840        ) {
841            return true;
842        }
843
844        false
845    }
846
847    fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
848        let mut children = Vec::new();
849
850        match expr {
851            Expression::Select(select) => {
852                // Walk SELECT expressions
853                for e in &select.expressions {
854                    children.push(e);
855                }
856                // Walk FROM (but tables/subqueries create new scopes)
857                if let Some(from) = &select.from {
858                    for table in &from.expressions {
859                        if !self.should_stop_at(table, false) {
860                            children.push(table);
861                        }
862                    }
863                }
864                // Walk JOINs (but their sources create new scopes)
865                for join in &select.joins {
866                    if let Some(on) = &join.on {
867                        children.push(on);
868                    }
869                    // Don't traverse join.this as it's a source (table or subquery)
870                }
871                // Walk WHERE
872                if let Some(where_clause) = &select.where_clause {
873                    children.push(&where_clause.this);
874                }
875                // Walk GROUP BY
876                if let Some(group_by) = &select.group_by {
877                    for e in &group_by.expressions {
878                        children.push(e);
879                    }
880                }
881                // Walk HAVING
882                if let Some(having) = &select.having {
883                    children.push(&having.this);
884                }
885                // Walk ORDER BY
886                if let Some(order_by) = &select.order_by {
887                    for ord in &order_by.expressions {
888                        children.push(&ord.this);
889                    }
890                }
891                // Walk LIMIT
892                if let Some(limit) = &select.limit {
893                    children.push(&limit.this);
894                }
895                // Walk OFFSET
896                if let Some(offset) = &select.offset {
897                    children.push(&offset.this);
898                }
899            }
900            Expression::And(bin)
901            | Expression::Or(bin)
902            | Expression::Add(bin)
903            | Expression::Sub(bin)
904            | Expression::Mul(bin)
905            | Expression::Div(bin)
906            | Expression::Mod(bin)
907            | Expression::Eq(bin)
908            | Expression::Neq(bin)
909            | Expression::Lt(bin)
910            | Expression::Lte(bin)
911            | Expression::Gt(bin)
912            | Expression::Gte(bin)
913            | Expression::BitwiseAnd(bin)
914            | Expression::BitwiseOr(bin)
915            | Expression::BitwiseXor(bin)
916            | Expression::Concat(bin) => {
917                children.push(&bin.left);
918                children.push(&bin.right);
919            }
920            Expression::Like(like) | Expression::ILike(like) => {
921                children.push(&like.left);
922                children.push(&like.right);
923                if let Some(escape) = &like.escape {
924                    children.push(escape);
925                }
926            }
927            Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
928                children.push(&un.this);
929            }
930            Expression::Function(func) => {
931                for arg in &func.args {
932                    children.push(arg);
933                }
934            }
935            Expression::AggregateFunction(agg) => {
936                for arg in &agg.args {
937                    children.push(arg);
938                }
939            }
940            Expression::WindowFunction(wf) => {
941                children.push(&wf.this);
942                for e in &wf.over.partition_by {
943                    children.push(e);
944                }
945                for e in &wf.over.order_by {
946                    children.push(&e.this);
947                }
948            }
949            Expression::Alias(alias) => {
950                children.push(&alias.this);
951            }
952            Expression::Case(case) => {
953                if let Some(operand) = &case.operand {
954                    children.push(operand);
955                }
956                for (when_expr, then_expr) in &case.whens {
957                    children.push(when_expr);
958                    children.push(then_expr);
959                }
960                if let Some(else_clause) = &case.else_ {
961                    children.push(else_clause);
962                }
963            }
964            Expression::Paren(paren) => {
965                children.push(&paren.this);
966            }
967            Expression::Ordered(ord) => {
968                children.push(&ord.this);
969            }
970            Expression::In(in_expr) => {
971                children.push(&in_expr.this);
972                for e in &in_expr.expressions {
973                    children.push(e);
974                }
975                // Note: in_expr.query creates a new scope - don't traverse
976            }
977            Expression::Between(between) => {
978                children.push(&between.this);
979                children.push(&between.low);
980                children.push(&between.high);
981            }
982            Expression::IsNull(is_null) => {
983                children.push(&is_null.this);
984            }
985            Expression::Cast(cast) => {
986                children.push(&cast.this);
987            }
988            Expression::Extract(extract) => {
989                children.push(&extract.this);
990            }
991            Expression::Coalesce(coalesce) => {
992                for e in &coalesce.expressions {
993                    children.push(e);
994                }
995            }
996            Expression::NullIf(nullif) => {
997                children.push(&nullif.this);
998                children.push(&nullif.expression);
999            }
1000            Expression::Table(_table) => {
1001                // Tables don't have child expressions to traverse within scope
1002                // (joins are handled at the Select level)
1003            }
1004            Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
1005                // Leaf nodes - no children
1006            }
1007            // Subqueries and Exists create new scopes - don't traverse into them
1008            Expression::Subquery(_) | Expression::Exists(_) => {}
1009            _ => {
1010                // For other expressions, we could add more cases as needed
1011            }
1012        }
1013
1014        children
1015    }
1016}
1017
1018impl<'a> Iterator for WalkInScopeIter<'a> {
1019    type Item = &'a Expression;
1020
1021    fn next(&mut self) -> Option<Self::Item> {
1022        let expr = if self.bfs {
1023            self.queue.pop_front()?
1024        } else {
1025            self.queue.pop_back()?
1026        };
1027
1028        // Get children that don't cross scope boundaries
1029        let children = self.get_children(expr);
1030
1031        if self.bfs {
1032            for child in children {
1033                if !self.should_stop_at(child, false) {
1034                    self.queue.push_back(child);
1035                }
1036            }
1037        } else {
1038            for child in children.into_iter().rev() {
1039                if !self.should_stop_at(child, false) {
1040                    self.queue.push_back(child);
1041                }
1042            }
1043        }
1044
1045        Some(expr)
1046    }
1047}
1048
1049/// Find the first expression matching the predicate within this scope.
1050///
1051/// This does NOT traverse into subscopes.
1052///
1053/// # Arguments
1054/// * `expression` - The root expression
1055/// * `predicate` - Function that returns true for matching expressions
1056/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1057///
1058/// # Returns
1059/// The first matching expression, or None
1060pub fn find_in_scope<'a, F>(
1061    expression: &'a Expression,
1062    predicate: F,
1063    bfs: bool,
1064) -> Option<&'a Expression>
1065where
1066    F: Fn(&Expression) -> bool,
1067{
1068    walk_in_scope(expression, bfs).find(|e| predicate(e))
1069}
1070
1071/// Find all expressions matching the predicate within this scope.
1072///
1073/// This does NOT traverse into subscopes.
1074///
1075/// # Arguments
1076/// * `expression` - The root expression
1077/// * `predicate` - Function that returns true for matching expressions
1078/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1079///
1080/// # Returns
1081/// A vector of matching expressions
1082pub fn find_all_in_scope<'a, F>(
1083    expression: &'a Expression,
1084    predicate: F,
1085    bfs: bool,
1086) -> Vec<&'a Expression>
1087where
1088    F: Fn(&Expression) -> bool,
1089{
1090    walk_in_scope(expression, bfs)
1091        .filter(|e| predicate(e))
1092        .collect()
1093}
1094
1095/// Traverse an expression by its "scopes".
1096///
1097/// Returns a list of all scopes in depth-first post-order.
1098///
1099/// # Arguments
1100/// * `expression` - The expression to traverse
1101///
1102/// # Returns
1103/// A vector of all scopes found
1104pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1105    match expression {
1106        Expression::Select(_)
1107        | Expression::Union(_)
1108        | Expression::Intersect(_)
1109        | Expression::Except(_)
1110        | Expression::CreateTable(_) => {
1111            let root = build_scope(expression);
1112            root.traverse().into_iter().cloned().collect()
1113        }
1114        _ => Vec::new(),
1115    }
1116}
1117
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and builds the scope tree for its first statement.
    fn parse_and_build_scope(sql: &str) -> Scope {
        let statements = Parser::parse_sql(sql).expect("Failed to parse SQL");
        build_scope(&statements[0])
    }

    #[test]
    fn test_simple_select_scope() {
        let mut root = parse_and_build_scope("SELECT a, b FROM t");

        assert!(root.is_root());
        assert!(!root.can_be_correlated);
        assert!(root.sources.contains_key("t"));

        // Both projected columns are referenced in the scope.
        let referenced = root.columns();
        assert_eq!(referenced.len(), 2);
    }

    #[test]
    fn test_derived_table_scope() {
        let mut root = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");

        // The derived table is registered under its alias...
        assert!(root.sources.contains_key("x"));
        assert_eq!(root.derived_table_scopes.len(), 1);

        // ...and owns a scope of its own that reads from `t`.
        let inner = &mut root.derived_table_scopes[0];
        assert!(inner.is_derived_table());
        assert!(inner.sources.contains_key("t"));
    }

    #[test]
    fn test_non_correlated_subquery() {
        let mut root = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);

        // Only `s` is referenced, and `s` is one of the subquery's own sources.
        assert!(inner.sources.contains_key("s"));
        assert!(!inner.is_correlated_subquery());
    }

    #[test]
    fn test_correlated_subquery() {
        let mut root =
            parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);

        // `t.y` comes from the enclosing query, so it must be reported as external.
        let outside_refs = inner.external_columns();
        assert!(!outside_refs.is_empty());
        assert!(outside_refs
            .iter()
            .any(|col| col.table.as_deref() == Some("t")));
        assert!(inner.is_correlated_subquery());
    }

    #[test]
    fn test_cte_scope() {
        let root = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");

        assert_eq!(root.cte_scopes.len(), 1);
        assert!(root.cte_sources.contains_key("cte"));
        assert!(root.cte_scopes[0].is_cte());
    }

    #[test]
    fn test_multiple_sources() {
        let root = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        assert_eq!(root.sources.len(), 2);
        assert!(root.sources.contains_key("t"));
        assert!(root.sources.contains_key("s"));
    }

    #[test]
    fn test_aliased_table() {
        let root = parse_and_build_scope("SELECT x.a FROM t AS x");

        // Sources are keyed by alias; the underlying table name is not a key.
        assert!(root.sources.contains_key("x"));
        assert!(!root.sources.contains_key("t"));
    }

    #[test]
    fn test_local_columns() {
        let mut root = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        // `t` and `s` are both in scope, so every reference is local: three
        // from the projection plus `t.id` and `s.id` from the JOIN condition.
        let local_refs = root.local_columns();
        assert_eq!(local_refs.len(), 5);
        assert!(local_refs.iter().all(|col| col.table.is_some()));
    }

    #[test]
    fn test_columns_include_join_on_clause_references() {
        let mut root = parse_and_build_scope(
            "SELECT o.total FROM orders o JOIN customers c ON c.id = o.customer_id",
        );

        // Render every reference as `table.name` (or just `name` when unqualified).
        let rendered: Vec<String> = root
            .columns()
            .iter()
            .map(|col| {
                col.table.as_ref().map_or_else(
                    || col.name.clone(),
                    |table| format!("{}.{}", table, col.name),
                )
            })
            .collect();

        for expected in ["o.total", "c.id", "o.customer_id"] {
            assert!(rendered.iter().any(|name| name == expected));
        }
    }

    #[test]
    fn test_unqualified_columns() {
        let mut root = parse_and_build_scope("SELECT a, b, t.c FROM t");

        // `a` and `b` carry no qualifier; `t.c` does.
        let bare = root.unqualified_columns();
        assert_eq!(bare.len(), 2);
        assert!(bare.iter().all(|col| col.table.is_none()));
    }

    #[test]
    fn test_source_columns() {
        let mut root = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        // References to `t`: t.a, t.b, and t.id from the JOIN condition.
        let from_t = root.source_columns("t");
        assert!(from_t.len() >= 2);
        assert!(from_t.iter().all(|col| col.table.as_deref() == Some("t")));

        // References to `s`: s.c, and s.id from the JOIN condition.
        let from_s = root.source_columns("s");
        assert!(!from_s.is_empty());
        assert!(from_s.iter().all(|col| col.table.as_deref() == Some("s")));
    }

    #[test]
    fn test_rename_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");
        assert!(root.sources.contains_key("t"));

        root.rename_source("t", "new_name".to_string());

        assert!(!root.sources.contains_key("t"));
        assert!(root.sources.contains_key("new_name"));
    }

    #[test]
    fn test_remove_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");
        assert!(root.sources.contains_key("t"));

        root.remove_source("t");

        assert!(!root.sources.contains_key("t"));
    }

    #[test]
    fn test_walk_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let query = &statements[0];

        // Breadth-first walk over everything that belongs to the scope.
        let visited: Vec<_> = walk_in_scope(query, true).collect();
        assert!(!visited.is_empty());

        // The root SELECT itself is part of the walk...
        assert!(visited
            .iter()
            .any(|node| matches!(node, Expression::Select(_))));
        // ...and so are the column references.
        assert!(visited
            .iter()
            .any(|node| matches!(node, Expression::Column(_))));
    }

    #[test]
    fn test_find_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let query = &statements[0];

        // The first match for a column predicate must exist and be a column.
        let first_column = find_in_scope(query, |node| matches!(node, Expression::Column(_)), true);
        assert!(matches!(first_column, Some(Expression::Column(_))));
    }

    #[test]
    fn test_find_all_in_scope() {
        let statements = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
        let query = &statements[0];

        // One hit per projected column.
        let all_columns =
            find_all_in_scope(query, |node| matches!(node, Expression::Column(_)), true);
        assert_eq!(all_columns.len(), 3);
    }

    #[test]
    fn test_traverse_scope() {
        let statements =
            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
        let query = &statements[0];

        // The flattened list comes from Scope::traverse and must contain at
        // least the root scope.
        let all_scopes = traverse_scope(query);
        assert!(!all_scopes.is_empty());
        assert!(all_scopes.iter().any(|scope| scope.is_root()));
    }

    #[test]
    fn test_branch_with_options() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let parent = build_scope(&statements[0]);

        // Branch as a subquery so that `can_be_correlated` gets exercised.
        let outer_columns = Some(vec!["col1".to_string(), "col2".to_string()]);
        let branched = parent.branch_with_options(
            statements[0].clone(),
            ScopeType::Subquery,
            None,
            None,
            outer_columns,
        );

        assert_eq!(branched.outer_columns, vec!["col1", "col2"]);
        // Subquery scopes may be correlated.
        assert!(branched.can_be_correlated);
    }

    #[test]
    fn test_is_udtf() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");

        // A freshly created scope is not a UDTF scope.
        let plain = Scope::new(statements[0].clone());
        assert!(!plain.is_udtf());

        // Branching with `ScopeType::Udtf` yields one.
        let parent = build_scope(&statements[0]);
        let branched = parent.branch(statements[0].clone(), ScopeType::Udtf);
        assert!(branched.is_udtf());
    }

    #[test]
    fn test_is_union() {
        let root = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(root.is_root());
        assert_eq!(root.union_scopes.len(), 2);
        // Both operands are set-operation scopes.
        assert!(root.union_scopes.iter().all(|operand| operand.is_union()));
    }

    #[test]
    fn test_union_output_columns() {
        let root = parse_and_build_scope(
            "SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees",
        );

        assert_eq!(root.output_columns(), vec!["id", "name"]);
    }

    #[test]
    fn test_clear_cache() {
        let mut root = parse_and_build_scope("SELECT t.a FROM t");

        // Asking for the columns once fills the cache.
        let _ = root.columns();
        assert!(root.columns_cache.is_some());

        // Clearing must reset both cached column lists.
        root.clear_cache();
        assert!(root.columns_cache.is_none());
        assert!(root.external_columns_cache.is_none());
    }

    #[test]
    fn test_scope_traverse() {
        let root = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        // Expected at minimum: the CTE scope, the subquery scope and the root.
        let flattened = root.traverse();
        assert!(flattened.len() >= 3);
    }

    #[test]
    fn test_create_table_as_select_scope() {
        /// The table being created must never show up among the sources.
        fn assert_target_is_not_a_source(scope: &Scope) {
            assert!(
                !scope.sources.contains_key("out_table"),
                "CTAS target table should not be treated as a source"
            );
        }

        // Simple CTAS
        let simple = parse_and_build_scope("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        assert!(
            simple.sources.contains_key("src"),
            "CTAS scope should contain the FROM table"
        );
        assert_target_is_not_a_source(&simple);

        // CTAS reading from several tables
        let joined = parse_and_build_scope(
            "CREATE TABLE out_table AS SELECT a.id FROM foo AS a JOIN bar AS b ON a.id = b.id",
        );
        assert!(joined.sources.contains_key("a"));
        assert!(joined.sources.contains_key("b"));
        assert_target_is_not_a_source(&joined);

        // CTAS whose query starts with a CTE
        let with_cte = parse_and_build_scope(
            "CREATE TABLE out_table AS WITH cte AS (SELECT 1 AS id FROM src) SELECT * FROM cte",
        );
        assert!(
            with_cte.sources.contains_key("cte"),
            "CTAS with CTE should resolve CTE as source"
        );
        assert_target_is_not_a_source(&with_cte);
        assert_eq!(with_cte.cte_scopes.len(), 1);
    }

    #[test]
    fn test_create_table_as_select_traverse() {
        let statements = Parser::parse_sql("CREATE TABLE t AS SELECT a FROM src").unwrap();

        let all_scopes = traverse_scope(&statements[0]);
        assert!(
            !all_scopes.is_empty(),
            "traverse_scope should return scopes for CTAS"
        );
    }
}