Skip to main content

polyglot_sql/
scope.rs

1//! Scope Analysis Module
2//!
3//! This module provides scope analysis for SQL queries, enabling detection of
4//! correlated subqueries, column references, and scope relationships.
5//!
6//! Ported from sqlglot's optimizer/scope.py
7
8use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
14/// Type of scope in a SQL query
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16#[cfg_attr(feature = "bindings", derive(TS))]
17#[cfg_attr(feature = "bindings", ts(export))]
18pub enum ScopeType {
19    /// Root scope of the query
20    Root,
21    /// Subquery scope (e.g., WHERE x IN (SELECT ...))
22    Subquery,
23    /// Derived table scope (e.g., FROM (SELECT ...) AS t)
24    DerivedTable,
25    /// Common Table Expression scope
26    Cte,
27    /// Union/Intersect/Except scope
28    SetOperation,
29    /// User-Defined Table Function scope
30    Udtf,
31}
32
33/// Semantic kind of a source registered in a scope.
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub enum SourceKind {
37    /// Root query or statement context
38    Root,
39    /// Physical table source
40    Table,
41    /// Derived table/subquery source
42    DerivedTable,
43    /// Common Table Expression source
44    Cte,
45    /// Virtual table source such as UNNEST / UDTF
46    Virtual,
47    /// Unresolved or synthetic fallback source
48    Unknown,
49}
50
51impl Default for SourceKind {
52    fn default() -> Self {
53        Self::Unknown
54    }
55}
56
57/// Information about a source (table or subquery) in a scope
58#[derive(Debug, Clone)]
59pub struct SourceInfo {
60    /// The source expression (Table or subquery)
61    pub expression: Expression,
62    /// Whether this source is a scope (vs. a plain table)
63    pub is_scope: bool,
64    /// Semantic source kind for lineage consumers.
65    pub kind: SourceKind,
66    /// User-written alias, when it should be preserved separately from lineage name.
67    pub alias: Option<String>,
68    /// Canonical lineage source name, e.g. synthetic `_0` for virtual sources.
69    pub lineage_name: Option<String>,
70}
71
72impl SourceInfo {
73    pub fn new(expression: Expression, is_scope: bool, kind: SourceKind) -> Self {
74        Self {
75            expression,
76            is_scope,
77            kind,
78            alias: None,
79            lineage_name: None,
80        }
81    }
82
83    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
84        self.alias = Some(alias.into());
85        self
86    }
87
88    pub fn with_lineage_name(mut self, lineage_name: impl Into<String>) -> Self {
89        self.lineage_name = Some(lineage_name.into());
90        self
91    }
92}
93
/// A column referenced inside a scope, e.g. `t.a` or bare `a`.
///
/// Equality and hashing cover both the qualifier and the name, so `t.a` and
/// `a` are distinct references.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table name or alias written before the column, if any.
    pub table: Option<String>,
    /// Column name as written.
    pub name: String,
}
102
103/// Represents a scope in a SQL query
104///
105/// A scope is the context of a SELECT statement and its sources.
106/// Scopes can be nested (subqueries, CTEs, derived tables) and form a tree.
107#[derive(Debug, Clone)]
108pub struct Scope {
109    /// The expression at the root of this scope
110    pub expression: Expression,
111
112    /// Type of this scope relative to its parent
113    pub scope_type: ScopeType,
114
115    /// Mapping of source names to their info
116    pub sources: HashMap<String, SourceInfo>,
117
118    /// Sources from LATERAL views (have access to preceding sources)
119    pub lateral_sources: HashMap<String, SourceInfo>,
120
121    /// CTE sources available to this scope
122    pub cte_sources: HashMap<String, SourceInfo>,
123
124    /// If this is a derived table or CTE with alias columns, this is that list
125    /// e.g., `SELECT * FROM (SELECT ...) AS y(col1, col2)` => ["col1", "col2"]
126    pub outer_columns: Vec<String>,
127
128    /// Whether this scope can potentially be correlated
129    /// (true for subqueries and UDTFs)
130    pub can_be_correlated: bool,
131
132    /// Child subquery scopes
133    pub subquery_scopes: Vec<Scope>,
134
135    /// Child derived table scopes
136    pub derived_table_scopes: Vec<Scope>,
137
138    /// Child CTE scopes
139    pub cte_scopes: Vec<Scope>,
140
141    /// Child UDTF (User Defined Table Function) scopes
142    pub udtf_scopes: Vec<Scope>,
143
144    /// Combined derived_table_scopes + udtf_scopes in definition order
145    pub table_scopes: Vec<Scope>,
146
147    /// Union/set operation scopes (left and right)
148    pub union_scopes: Vec<Scope>,
149
150    /// Cached columns
151    columns_cache: Option<Vec<ColumnRef>>,
152
153    /// Cached external columns
154    external_columns_cache: Option<Vec<ColumnRef>>,
155}
156
157impl Scope {
158    /// Create a new root scope
159    pub fn new(expression: Expression) -> Self {
160        Self {
161            expression,
162            scope_type: ScopeType::Root,
163            sources: HashMap::new(),
164            lateral_sources: HashMap::new(),
165            cte_sources: HashMap::new(),
166            outer_columns: Vec::new(),
167            can_be_correlated: false,
168            subquery_scopes: Vec::new(),
169            derived_table_scopes: Vec::new(),
170            cte_scopes: Vec::new(),
171            udtf_scopes: Vec::new(),
172            table_scopes: Vec::new(),
173            union_scopes: Vec::new(),
174            columns_cache: None,
175            external_columns_cache: None,
176        }
177    }
178
179    /// Create a child scope branching from this one
180    pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
181        self.branch_with_options(expression, scope_type, None, None, None)
182    }
183
184    /// Create a child scope with additional options
185    pub fn branch_with_options(
186        &self,
187        expression: Expression,
188        scope_type: ScopeType,
189        sources: Option<HashMap<String, SourceInfo>>,
190        lateral_sources: Option<HashMap<String, SourceInfo>>,
191        outer_columns: Option<Vec<String>>,
192    ) -> Self {
193        let can_be_correlated = self.can_be_correlated
194            || scope_type == ScopeType::Subquery
195            || scope_type == ScopeType::Udtf;
196
197        Self {
198            expression,
199            scope_type,
200            sources: sources.unwrap_or_default(),
201            lateral_sources: lateral_sources.unwrap_or_default(),
202            cte_sources: self.cte_sources.clone(),
203            outer_columns: outer_columns.unwrap_or_default(),
204            can_be_correlated,
205            subquery_scopes: Vec::new(),
206            derived_table_scopes: Vec::new(),
207            cte_scopes: Vec::new(),
208            udtf_scopes: Vec::new(),
209            table_scopes: Vec::new(),
210            union_scopes: Vec::new(),
211            columns_cache: None,
212            external_columns_cache: None,
213        }
214    }
215
216    /// Clear all cached properties
217    pub fn clear_cache(&mut self) {
218        self.columns_cache = None;
219        self.external_columns_cache = None;
220    }
221
222    /// Add a source to this scope
223    pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
224        let kind = if is_scope {
225            SourceKind::DerivedTable
226        } else {
227            SourceKind::Table
228        };
229        self.add_source_info(name, SourceInfo::new(expression, is_scope, kind));
230    }
231
232    /// Add a preconfigured source to this scope
233    pub fn add_source_info(&mut self, name: String, info: SourceInfo) {
234        self.sources.insert(name, info);
235        self.clear_cache();
236    }
237
238    /// Add a virtual source such as UNNEST / UDTF.
239    pub fn add_virtual_source(&mut self, alias: String, expression: Expression) {
240        let lineage_name = self.next_virtual_source_name();
241        let info = SourceInfo::new(expression, false, SourceKind::Virtual)
242            .with_alias(alias.clone())
243            .with_lineage_name(lineage_name);
244        self.add_source_info(alias, info);
245    }
246
247    fn next_virtual_source_name(&self) -> String {
248        let count = self
249            .sources
250            .values()
251            .filter(|source| source.kind == SourceKind::Virtual)
252            .count();
253        format!("_{}", count)
254    }
255
256    /// Add a lateral source to this scope
257    pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
258        let kind = if is_scope {
259            SourceKind::DerivedTable
260        } else {
261            SourceKind::Table
262        };
263        let info = SourceInfo::new(expression.clone(), is_scope, kind);
264        self.sources.insert(name.clone(), info.clone());
265        self.lateral_sources.insert(name, info);
266        self.clear_cache();
267    }
268
269    /// Add a CTE source to this scope
270    pub fn add_cte_source(&mut self, name: String, expression: Expression) {
271        let info = SourceInfo::new(expression, true, SourceKind::Cte);
272        self.cte_sources.insert(name.clone(), info.clone());
273        self.sources.insert(name, info);
274        self.clear_cache();
275    }
276
277    /// Rename a source
278    pub fn rename_source(&mut self, old_name: &str, new_name: String) {
279        if let Some(source) = self.sources.remove(old_name) {
280            self.sources.insert(new_name, source);
281        }
282        self.clear_cache();
283    }
284
285    /// Remove a source
286    pub fn remove_source(&mut self, name: &str) {
287        self.sources.remove(name);
288        self.clear_cache();
289    }
290
291    /// Collect all column references in this scope
292    pub fn columns(&mut self) -> &[ColumnRef] {
293        if self.columns_cache.is_none() {
294            let mut columns = Vec::new();
295            collect_columns(&self.expression, &mut columns);
296            self.columns_cache = Some(columns);
297        }
298        self.columns_cache.as_ref().unwrap()
299    }
300
301    /// Collect projected output column names for this scope's query expression.
302    ///
303    /// This is intended for result schema style output columns (e.g. UNION
304    /// outputs), unlike [`Self::columns`], which returns raw referenced columns.
305    pub fn output_columns(&self) -> Vec<String> {
306        crate::ast_transforms::get_output_column_names(&self.expression)
307    }
308
309    /// Get all source names in this scope
310    pub fn source_names(&self) -> HashSet<String> {
311        let mut names: HashSet<String> = self.sources.keys().cloned().collect();
312        names.extend(self.cte_sources.keys().cloned());
313        names
314    }
315
316    /// Get columns that reference sources outside this scope
317    pub fn external_columns(&mut self) -> Vec<ColumnRef> {
318        if self.external_columns_cache.is_some() {
319            return self.external_columns_cache.clone().unwrap();
320        }
321
322        let source_names = self.source_names();
323        let columns = self.columns().to_vec();
324
325        let external: Vec<ColumnRef> = columns
326            .into_iter()
327            .filter(|col| {
328                // A column is external if it has a table qualifier that's not in our sources
329                match &col.table {
330                    Some(table) => !source_names.contains(table),
331                    None => false, // Unqualified columns might be local
332                }
333            })
334            .collect();
335
336        self.external_columns_cache = Some(external.clone());
337        external
338    }
339
340    /// Get columns that reference sources in this scope (not external)
341    pub fn local_columns(&mut self) -> Vec<ColumnRef> {
342        let external_set: HashSet<_> = self.external_columns().into_iter().collect();
343        let columns = self.columns().to_vec();
344
345        columns
346            .into_iter()
347            .filter(|col| !external_set.contains(col))
348            .collect()
349    }
350
351    /// Get unqualified columns (columns without table qualifier)
352    pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
353        self.columns()
354            .iter()
355            .filter(|c| c.table.is_none())
356            .cloned()
357            .collect()
358    }
359
360    /// Get columns for a specific source
361    pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
362        self.columns()
363            .iter()
364            .filter(|col| col.table.as_deref() == Some(source_name))
365            .cloned()
366            .collect()
367    }
368
369    /// Determine if this scope is a correlated subquery
370    ///
371    /// A subquery is correlated if:
372    /// 1. It can be correlated (is a subquery or UDTF), AND
373    /// 2. It references columns from outer scopes
374    pub fn is_correlated_subquery(&mut self) -> bool {
375        self.can_be_correlated && !self.external_columns().is_empty()
376    }
377
378    /// Check if this is a subquery scope
379    pub fn is_subquery(&self) -> bool {
380        self.scope_type == ScopeType::Subquery
381    }
382
383    /// Check if this is a derived table scope
384    pub fn is_derived_table(&self) -> bool {
385        self.scope_type == ScopeType::DerivedTable
386    }
387
388    /// Check if this is a CTE scope
389    pub fn is_cte(&self) -> bool {
390        self.scope_type == ScopeType::Cte
391    }
392
393    /// Check if this is the root scope
394    pub fn is_root(&self) -> bool {
395        self.scope_type == ScopeType::Root
396    }
397
398    /// Check if this is a UDTF scope
399    pub fn is_udtf(&self) -> bool {
400        self.scope_type == ScopeType::Udtf
401    }
402
403    /// Check if this is a union/set operation scope
404    pub fn is_union(&self) -> bool {
405        self.scope_type == ScopeType::SetOperation
406    }
407
408    /// Traverse all scopes in this tree (depth-first post-order)
409    pub fn traverse(&self) -> Vec<&Scope> {
410        let mut result = Vec::new();
411        self.traverse_impl(&mut result);
412        result
413    }
414
415    fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
416        // First traverse children
417        for scope in &self.cte_scopes {
418            scope.traverse_impl(result);
419        }
420        for scope in &self.union_scopes {
421            scope.traverse_impl(result);
422        }
423        for scope in &self.table_scopes {
424            scope.traverse_impl(result);
425        }
426        for scope in &self.subquery_scopes {
427            scope.traverse_impl(result);
428        }
429        // Then add self
430        result.push(self);
431    }
432
433    /// Count references to each scope in this tree
434    pub fn ref_count(&self) -> HashMap<usize, usize> {
435        let mut counts: HashMap<usize, usize> = HashMap::new();
436
437        for scope in self.traverse() {
438            for (_, source_info) in scope.sources.iter() {
439                if source_info.is_scope {
440                    let id = &source_info.expression as *const _ as usize;
441                    *counts.entry(id).or_insert(0) += 1;
442                }
443            }
444        }
445
446        counts
447    }
448}
449
450/// Collect all column references from an expression tree
451fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
452    match expr {
453        Expression::Column(col) => {
454            columns.push(ColumnRef {
455                table: col.table.as_ref().map(|t| t.name.clone()),
456                name: col.name.name.clone(),
457            });
458        }
459        Expression::Select(select) => {
460            // Collect from SELECT expressions
461            for e in &select.expressions {
462                collect_columns(e, columns);
463            }
464            // Collect from JOIN ON / MATCH_CONDITION clauses
465            for join in &select.joins {
466                if let Some(on) = &join.on {
467                    collect_columns(on, columns);
468                }
469                if let Some(match_condition) = &join.match_condition {
470                    collect_columns(match_condition, columns);
471                }
472            }
473            // Collect from WHERE
474            if let Some(where_clause) = &select.where_clause {
475                collect_columns(&where_clause.this, columns);
476            }
477            // Collect from HAVING
478            if let Some(having) = &select.having {
479                collect_columns(&having.this, columns);
480            }
481            // Collect from ORDER BY
482            if let Some(order_by) = &select.order_by {
483                for ord in &order_by.expressions {
484                    collect_columns(&ord.this, columns);
485                }
486            }
487            // Collect from GROUP BY
488            if let Some(group_by) = &select.group_by {
489                for e in &group_by.expressions {
490                    collect_columns(e, columns);
491                }
492            }
493            // Note: We don't recurse into FROM/JOIN source subqueries here
494            // as those create their own scopes.
495        }
496        // Binary operations
497        Expression::And(bin)
498        | Expression::Or(bin)
499        | Expression::Add(bin)
500        | Expression::Sub(bin)
501        | Expression::Mul(bin)
502        | Expression::Div(bin)
503        | Expression::Mod(bin)
504        | Expression::Eq(bin)
505        | Expression::Neq(bin)
506        | Expression::Lt(bin)
507        | Expression::Lte(bin)
508        | Expression::Gt(bin)
509        | Expression::Gte(bin)
510        | Expression::BitwiseAnd(bin)
511        | Expression::BitwiseOr(bin)
512        | Expression::BitwiseXor(bin)
513        | Expression::Concat(bin) => {
514            collect_columns(&bin.left, columns);
515            collect_columns(&bin.right, columns);
516        }
517        // LIKE/ILIKE operations
518        Expression::Like(like) | Expression::ILike(like) => {
519            collect_columns(&like.left, columns);
520            collect_columns(&like.right, columns);
521            if let Some(escape) = &like.escape {
522                collect_columns(escape, columns);
523            }
524        }
525        // Unary operations
526        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
527            collect_columns(&un.this, columns);
528        }
529        Expression::Function(func) => {
530            for arg in &func.args {
531                collect_columns(arg, columns);
532            }
533        }
534        Expression::AggregateFunction(agg) => {
535            for arg in &agg.args {
536                collect_columns(arg, columns);
537            }
538        }
539        Expression::WindowFunction(wf) => {
540            collect_columns(&wf.this, columns);
541            for e in &wf.over.partition_by {
542                collect_columns(e, columns);
543            }
544            for e in &wf.over.order_by {
545                collect_columns(&e.this, columns);
546            }
547        }
548        Expression::Alias(alias) => {
549            collect_columns(&alias.this, columns);
550        }
551        Expression::Case(case) => {
552            if let Some(operand) = &case.operand {
553                collect_columns(operand, columns);
554            }
555            for (when_expr, then_expr) in &case.whens {
556                collect_columns(when_expr, columns);
557                collect_columns(then_expr, columns);
558            }
559            if let Some(else_clause) = &case.else_ {
560                collect_columns(else_clause, columns);
561            }
562        }
563        Expression::Paren(paren) => {
564            collect_columns(&paren.this, columns);
565        }
566        Expression::Ordered(ord) => {
567            collect_columns(&ord.this, columns);
568        }
569        Expression::In(in_expr) => {
570            collect_columns(&in_expr.this, columns);
571            for e in &in_expr.expressions {
572                collect_columns(e, columns);
573            }
574            // Note: in_expr.query is a subquery - creates its own scope
575        }
576        Expression::Between(between) => {
577            collect_columns(&between.this, columns);
578            collect_columns(&between.low, columns);
579            collect_columns(&between.high, columns);
580        }
581        Expression::IsNull(is_null) => {
582            collect_columns(&is_null.this, columns);
583        }
584        Expression::Cast(cast) => {
585            collect_columns(&cast.this, columns);
586        }
587        Expression::Extract(extract) => {
588            collect_columns(&extract.this, columns);
589        }
590        Expression::Exists(_) | Expression::Subquery(_) => {
591            // These create their own scopes - don't collect from here
592        }
593        Expression::Prepare(prepare) => {
594            collect_columns(&prepare.statement, columns);
595        }
596        _ => {
597            // For other expressions, we might need to add more cases
598        }
599    }
600}
601
602/// Build scope tree from an expression
603///
604/// This traverses the expression tree and builds a hierarchy of Scope objects
605/// that track sources and column references at each level.
606pub fn build_scope(expression: &Expression) -> Scope {
607    let mut root = Scope::new(expression.clone());
608    build_scope_impl(expression, &mut root);
609    root
610}
611
612fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
613    match expression {
614        Expression::Prepare(prepare) => {
615            build_scope_impl(&prepare.statement, current_scope);
616        }
617        Expression::Select(select) => {
618            // Process CTEs first
619            if let Some(with) = &select.with {
620                for cte in &with.ctes {
621                    let cte_name = cte.alias.name.clone();
622                    let mut cte_scope = current_scope
623                        .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
624                    build_scope_impl(&cte.this, &mut cte_scope);
625                    current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
626                    current_scope.cte_scopes.push(cte_scope);
627                }
628            }
629
630            // Process FROM clause
631            if let Some(from) = &select.from {
632                for table in &from.expressions {
633                    add_table_to_scope(table, current_scope);
634                }
635            }
636
637            // Process JOINs
638            for join in &select.joins {
639                add_table_to_scope(&join.this, current_scope);
640            }
641
642            // Process table-generating lateral views (Hive/Spark style UDTFs).
643            for lateral_view in &select.lateral_views {
644                add_lateral_view_to_scope(lateral_view, current_scope);
645            }
646
647            // Process subqueries in WHERE, SELECT expressions, etc.
648            collect_subqueries(expression, current_scope);
649        }
650        Expression::Union(union) => {
651            let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
652            build_scope_impl(&union.left, &mut left_scope);
653
654            let mut right_scope =
655                current_scope.branch(union.right.clone(), ScopeType::SetOperation);
656            build_scope_impl(&union.right, &mut right_scope);
657
658            current_scope.union_scopes.push(left_scope);
659            current_scope.union_scopes.push(right_scope);
660        }
661        Expression::Intersect(intersect) => {
662            let mut left_scope =
663                current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
664            build_scope_impl(&intersect.left, &mut left_scope);
665
666            let mut right_scope =
667                current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
668            build_scope_impl(&intersect.right, &mut right_scope);
669
670            current_scope.union_scopes.push(left_scope);
671            current_scope.union_scopes.push(right_scope);
672        }
673        Expression::Except(except) => {
674            let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
675            build_scope_impl(&except.left, &mut left_scope);
676
677            let mut right_scope =
678                current_scope.branch(except.right.clone(), ScopeType::SetOperation);
679            build_scope_impl(&except.right, &mut right_scope);
680
681            current_scope.union_scopes.push(left_scope);
682            current_scope.union_scopes.push(right_scope);
683        }
684        Expression::CreateTable(create) => {
685            // Handle CREATE TABLE ... AS [WITH ...] SELECT ...
686            // Process CTEs if present
687            if let Some(with) = &create.with_cte {
688                for cte in &with.ctes {
689                    let cte_name = cte.alias.name.clone();
690                    let mut cte_scope = current_scope
691                        .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
692                    build_scope_impl(&cte.this, &mut cte_scope);
693                    current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
694                    current_scope.cte_scopes.push(cte_scope);
695                }
696            }
697            // Traverse the AS SELECT body
698            if let Some(as_select) = &create.as_select {
699                build_scope_impl(as_select, current_scope);
700            }
701        }
702        _ => {}
703    }
704}
705
706fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
707    match expr {
708        Expression::Table(table) => {
709            let name = table
710                .alias
711                .as_ref()
712                .map(|a| a.name.clone())
713                .unwrap_or_else(|| table.name.name.clone());
714            let cte_source = if table.schema.is_none() && table.catalog.is_none() {
715                scope.cte_sources.get(&table.name.name).or_else(|| {
716                    scope
717                        .cte_sources
718                        .iter()
719                        .find(|(cte_name, _)| cte_name.eq_ignore_ascii_case(&table.name.name))
720                        .map(|(_, source)| source)
721                })
722            } else {
723                None
724            };
725
726            if let Some(source) = cte_source {
727                scope.add_source_info(name, source.clone());
728            } else {
729                let mut source = SourceInfo::new(expr.clone(), false, SourceKind::Table);
730                if let Some(alias) = &table.alias {
731                    source = source.with_alias(alias.name.clone());
732                }
733                scope.add_source_info(name, source);
734            }
735        }
736        Expression::Subquery(subquery) => {
737            let name = subquery
738                .alias
739                .as_ref()
740                .map(|a| a.name.clone())
741                .unwrap_or_default();
742
743            let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
744            build_scope_impl(&subquery.this, &mut derived_scope);
745
746            scope.add_source(name.clone(), expr.clone(), true);
747            scope.derived_table_scopes.push(derived_scope);
748        }
749        Expression::Unnest(unnest) => {
750            if let Some(alias) = &unnest.alias {
751                scope.add_virtual_source(alias.name.clone(), expr.clone());
752            }
753        }
754        Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
755            scope.add_virtual_source(alias.alias.name.clone(), expr.clone());
756        }
757        Expression::Lateral(lateral) => {
758            if let Some(alias) = &lateral.alias {
759                scope.add_virtual_source(alias.clone(), expr.clone());
760            }
761        }
762        Expression::LateralView(lateral_view) => {
763            add_lateral_view_to_scope(lateral_view, scope);
764        }
765        Expression::Paren(paren) => {
766            add_table_to_scope(&paren.this, scope);
767        }
768        _ => {}
769    }
770}
771
772fn add_lateral_view_to_scope(lateral_view: &crate::expressions::LateralView, scope: &mut Scope) {
773    let alias = lateral_view
774        .table_alias
775        .as_ref()
776        .or_else(|| lateral_view.column_aliases.first())
777        .map(|alias| alias.name.clone());
778
779    if let Some(alias) = alias {
780        scope.add_virtual_source(
781            alias,
782            Expression::LateralView(Box::new(lateral_view.clone())),
783        );
784    }
785}
786
787fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
788    match expr {
789        Expression::Select(select) => {
790            // Check WHERE for subqueries
791            if let Some(where_clause) = &select.where_clause {
792                collect_subqueries_in_expr(&where_clause.this, parent_scope);
793            }
794            // Check SELECT expressions for subqueries
795            for e in &select.expressions {
796                collect_subqueries_in_expr(e, parent_scope);
797            }
798            // Check HAVING for subqueries
799            if let Some(having) = &select.having {
800                collect_subqueries_in_expr(&having.this, parent_scope);
801            }
802        }
803        _ => {}
804    }
805}
806
807fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
808    match expr {
809        Expression::Subquery(subquery) if subquery.alias.is_none() => {
810            // This is a scalar subquery or IN subquery (not a derived table)
811            let mut sub_scope = parent_scope.branch(subquery.this.clone(), ScopeType::Subquery);
812            build_scope_impl(&subquery.this, &mut sub_scope);
813            parent_scope.subquery_scopes.push(sub_scope);
814        }
815        Expression::In(in_expr) => {
816            collect_subqueries_in_expr(&in_expr.this, parent_scope);
817            if let Some(query) = &in_expr.query {
818                let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
819                build_scope_impl(query, &mut sub_scope);
820                parent_scope.subquery_scopes.push(sub_scope);
821            }
822        }
823        Expression::Exists(exists) => {
824            let mut sub_scope = parent_scope.branch(exists.this.clone(), ScopeType::Subquery);
825            build_scope_impl(&exists.this, &mut sub_scope);
826            parent_scope.subquery_scopes.push(sub_scope);
827        }
828        // Binary operations
829        Expression::And(bin)
830        | Expression::Or(bin)
831        | Expression::Add(bin)
832        | Expression::Sub(bin)
833        | Expression::Mul(bin)
834        | Expression::Div(bin)
835        | Expression::Mod(bin)
836        | Expression::Eq(bin)
837        | Expression::Neq(bin)
838        | Expression::Lt(bin)
839        | Expression::Lte(bin)
840        | Expression::Gt(bin)
841        | Expression::Gte(bin)
842        | Expression::BitwiseAnd(bin)
843        | Expression::BitwiseOr(bin)
844        | Expression::BitwiseXor(bin)
845        | Expression::Concat(bin) => {
846            collect_subqueries_in_expr(&bin.left, parent_scope);
847            collect_subqueries_in_expr(&bin.right, parent_scope);
848        }
849        // LIKE/ILIKE operations (have different structure with escape)
850        Expression::Like(like) | Expression::ILike(like) => {
851            collect_subqueries_in_expr(&like.left, parent_scope);
852            collect_subqueries_in_expr(&like.right, parent_scope);
853            if let Some(escape) = &like.escape {
854                collect_subqueries_in_expr(escape, parent_scope);
855            }
856        }
857        // Unary operations
858        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
859            collect_subqueries_in_expr(&un.this, parent_scope);
860        }
861        Expression::Function(func) => {
862            for arg in &func.args {
863                collect_subqueries_in_expr(arg, parent_scope);
864            }
865        }
866        Expression::Case(case) => {
867            if let Some(operand) = &case.operand {
868                collect_subqueries_in_expr(operand, parent_scope);
869            }
870            for (when_expr, then_expr) in &case.whens {
871                collect_subqueries_in_expr(when_expr, parent_scope);
872                collect_subqueries_in_expr(then_expr, parent_scope);
873            }
874            if let Some(else_clause) = &case.else_ {
875                collect_subqueries_in_expr(else_clause, parent_scope);
876            }
877        }
878        Expression::Paren(paren) => {
879            collect_subqueries_in_expr(&paren.this, parent_scope);
880        }
881        Expression::Alias(alias) => {
882            collect_subqueries_in_expr(&alias.this, parent_scope);
883        }
884        _ => {}
885    }
886}
887
888/// Walk within a scope, yielding expressions without crossing scope boundaries.
889///
890/// This iterator visits all nodes in the syntax tree, stopping at nodes that
891/// start child scopes (CTEs, derived tables, subqueries in FROM/JOIN).
892///
893/// # Arguments
894/// * `expression` - The expression to walk
895/// * `bfs` - If true, uses breadth-first search; otherwise uses depth-first search
896///
897/// # Returns
898/// An iterator over expressions within the scope
899pub fn walk_in_scope<'a>(
900    expression: &'a Expression,
901    bfs: bool,
902) -> impl Iterator<Item = &'a Expression> {
903    WalkInScopeIter::new(expression, bfs)
904}
905
906/// Iterator for walking within a scope
907struct WalkInScopeIter<'a> {
908    queue: VecDeque<&'a Expression>,
909    bfs: bool,
910}
911
912impl<'a> WalkInScopeIter<'a> {
913    fn new(expression: &'a Expression, bfs: bool) -> Self {
914        let mut queue = VecDeque::new();
915        queue.push_back(expression);
916        Self { queue, bfs }
917    }
918
919    fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
920        if is_root {
921            return false;
922        }
923
924        // Stop at CTE definitions
925        if matches!(expr, Expression::Cte(_)) {
926            return true;
927        }
928
929        // Stop at subqueries that are derived tables (in FROM/JOIN)
930        if let Expression::Subquery(subquery) = expr {
931            if subquery.alias.is_some() {
932                return true;
933            }
934        }
935
936        // Stop at standalone SELECT/UNION/etc that would be subqueries
937        if matches!(
938            expr,
939            Expression::Select(_)
940                | Expression::Union(_)
941                | Expression::Intersect(_)
942                | Expression::Except(_)
943        ) {
944            return true;
945        }
946
947        false
948    }
949
950    fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
951        let mut children = Vec::new();
952
953        match expr {
954            Expression::Prepare(prepare) => {
955                children.push(&prepare.statement);
956            }
957            Expression::Select(select) => {
958                // Walk SELECT expressions
959                for e in &select.expressions {
960                    children.push(e);
961                }
962                // Walk FROM (but tables/subqueries create new scopes)
963                if let Some(from) = &select.from {
964                    for table in &from.expressions {
965                        if !self.should_stop_at(table, false) {
966                            children.push(table);
967                        }
968                    }
969                }
970                // Walk JOINs (but their sources create new scopes)
971                for join in &select.joins {
972                    if let Some(on) = &join.on {
973                        children.push(on);
974                    }
975                    // Don't traverse join.this as it's a source (table or subquery)
976                }
977                // Walk WHERE
978                if let Some(where_clause) = &select.where_clause {
979                    children.push(&where_clause.this);
980                }
981                // Walk GROUP BY
982                if let Some(group_by) = &select.group_by {
983                    for e in &group_by.expressions {
984                        children.push(e);
985                    }
986                }
987                // Walk HAVING
988                if let Some(having) = &select.having {
989                    children.push(&having.this);
990                }
991                // Walk ORDER BY
992                if let Some(order_by) = &select.order_by {
993                    for ord in &order_by.expressions {
994                        children.push(&ord.this);
995                    }
996                }
997                // Walk LIMIT
998                if let Some(limit) = &select.limit {
999                    children.push(&limit.this);
1000                }
1001                // Walk OFFSET
1002                if let Some(offset) = &select.offset {
1003                    children.push(&offset.this);
1004                }
1005            }
1006            Expression::And(bin)
1007            | Expression::Or(bin)
1008            | Expression::Add(bin)
1009            | Expression::Sub(bin)
1010            | Expression::Mul(bin)
1011            | Expression::Div(bin)
1012            | Expression::Mod(bin)
1013            | Expression::Eq(bin)
1014            | Expression::Neq(bin)
1015            | Expression::Lt(bin)
1016            | Expression::Lte(bin)
1017            | Expression::Gt(bin)
1018            | Expression::Gte(bin)
1019            | Expression::BitwiseAnd(bin)
1020            | Expression::BitwiseOr(bin)
1021            | Expression::BitwiseXor(bin)
1022            | Expression::Concat(bin) => {
1023                children.push(&bin.left);
1024                children.push(&bin.right);
1025            }
1026            Expression::Like(like) | Expression::ILike(like) => {
1027                children.push(&like.left);
1028                children.push(&like.right);
1029                if let Some(escape) = &like.escape {
1030                    children.push(escape);
1031                }
1032            }
1033            Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
1034                children.push(&un.this);
1035            }
1036            Expression::Function(func) => {
1037                for arg in &func.args {
1038                    children.push(arg);
1039                }
1040            }
1041            Expression::AggregateFunction(agg) => {
1042                for arg in &agg.args {
1043                    children.push(arg);
1044                }
1045            }
1046            Expression::WindowFunction(wf) => {
1047                children.push(&wf.this);
1048                for e in &wf.over.partition_by {
1049                    children.push(e);
1050                }
1051                for e in &wf.over.order_by {
1052                    children.push(&e.this);
1053                }
1054            }
1055            Expression::Alias(alias) => {
1056                children.push(&alias.this);
1057            }
1058            Expression::Case(case) => {
1059                if let Some(operand) = &case.operand {
1060                    children.push(operand);
1061                }
1062                for (when_expr, then_expr) in &case.whens {
1063                    children.push(when_expr);
1064                    children.push(then_expr);
1065                }
1066                if let Some(else_clause) = &case.else_ {
1067                    children.push(else_clause);
1068                }
1069            }
1070            Expression::Paren(paren) => {
1071                children.push(&paren.this);
1072            }
1073            Expression::Ordered(ord) => {
1074                children.push(&ord.this);
1075            }
1076            Expression::In(in_expr) => {
1077                children.push(&in_expr.this);
1078                for e in &in_expr.expressions {
1079                    children.push(e);
1080                }
1081                // Note: in_expr.query creates a new scope - don't traverse
1082            }
1083            Expression::Between(between) => {
1084                children.push(&between.this);
1085                children.push(&between.low);
1086                children.push(&between.high);
1087            }
1088            Expression::IsNull(is_null) => {
1089                children.push(&is_null.this);
1090            }
1091            Expression::Cast(cast) => {
1092                children.push(&cast.this);
1093            }
1094            Expression::Extract(extract) => {
1095                children.push(&extract.this);
1096            }
1097            Expression::Coalesce(coalesce) => {
1098                for e in &coalesce.expressions {
1099                    children.push(e);
1100                }
1101            }
1102            Expression::NullIf(nullif) => {
1103                children.push(&nullif.this);
1104                children.push(&nullif.expression);
1105            }
1106            Expression::Table(_table) => {
1107                // Tables don't have child expressions to traverse within scope
1108                // (joins are handled at the Select level)
1109            }
1110            Expression::TryCatch(try_catch) => {
1111                for stmt in &try_catch.try_body {
1112                    children.push(stmt);
1113                }
1114                if let Some(catch_body) = &try_catch.catch_body {
1115                    for stmt in catch_body {
1116                        children.push(stmt);
1117                    }
1118                }
1119            }
1120            Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
1121                // Leaf nodes - no children
1122            }
1123            // Subqueries and Exists create new scopes - don't traverse into them
1124            Expression::Subquery(_) | Expression::Exists(_) => {}
1125            _ => {
1126                // For other expressions, we could add more cases as needed
1127            }
1128        }
1129
1130        children
1131    }
1132}
1133
1134impl<'a> Iterator for WalkInScopeIter<'a> {
1135    type Item = &'a Expression;
1136
1137    fn next(&mut self) -> Option<Self::Item> {
1138        let expr = if self.bfs {
1139            self.queue.pop_front()?
1140        } else {
1141            self.queue.pop_back()?
1142        };
1143
1144        // Get children that don't cross scope boundaries
1145        let children = self.get_children(expr);
1146
1147        if self.bfs {
1148            for child in children {
1149                if !self.should_stop_at(child, false) {
1150                    self.queue.push_back(child);
1151                }
1152            }
1153        } else {
1154            for child in children.into_iter().rev() {
1155                if !self.should_stop_at(child, false) {
1156                    self.queue.push_back(child);
1157                }
1158            }
1159        }
1160
1161        Some(expr)
1162    }
1163}
1164
1165/// Find the first expression matching the predicate within this scope.
1166///
1167/// This does NOT traverse into subscopes.
1168///
1169/// # Arguments
1170/// * `expression` - The root expression
1171/// * `predicate` - Function that returns true for matching expressions
1172/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1173///
1174/// # Returns
1175/// The first matching expression, or None
1176pub fn find_in_scope<'a, F>(
1177    expression: &'a Expression,
1178    predicate: F,
1179    bfs: bool,
1180) -> Option<&'a Expression>
1181where
1182    F: Fn(&Expression) -> bool,
1183{
1184    walk_in_scope(expression, bfs).find(|e| predicate(e))
1185}
1186
1187/// Find all expressions matching the predicate within this scope.
1188///
1189/// This does NOT traverse into subscopes.
1190///
1191/// # Arguments
1192/// * `expression` - The root expression
1193/// * `predicate` - Function that returns true for matching expressions
1194/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1195///
1196/// # Returns
1197/// A vector of matching expressions
1198pub fn find_all_in_scope<'a, F>(
1199    expression: &'a Expression,
1200    predicate: F,
1201    bfs: bool,
1202) -> Vec<&'a Expression>
1203where
1204    F: Fn(&Expression) -> bool,
1205{
1206    walk_in_scope(expression, bfs)
1207        .filter(|e| predicate(e))
1208        .collect()
1209}
1210
1211/// Traverse an expression by its "scopes".
1212///
1213/// Returns a list of all scopes in depth-first post-order.
1214///
1215/// # Arguments
1216/// * `expression` - The expression to traverse
1217///
1218/// # Returns
1219/// A vector of all scopes found
1220pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1221    match expression {
1222        Expression::Select(_)
1223        | Expression::Union(_)
1224        | Expression::Intersect(_)
1225        | Expression::Except(_)
1226        | Expression::Prepare(_)
1227        | Expression::CreateTable(_) => {
1228            let root = build_scope(expression);
1229            root.traverse().into_iter().cloned().collect()
1230        }
1231        _ => Vec::new(),
1232    }
1233}
1234
// Unit tests for scope construction (`build_scope` / `traverse_scope`), the
// column and source accessors on `Scope`, and the in-scope walking helpers
// (`walk_in_scope`, `find_in_scope`, `find_all_in_scope`).
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parse `sql` and build the scope tree for its first statement.
    ///
    /// Panics if parsing fails or produces no statements.
    fn parse_and_build_scope(sql: &str) -> Scope {
        let ast = Parser::parse_sql(sql).expect("Failed to parse SQL");
        build_scope(&ast[0])
    }

    #[test]
    fn test_simple_select_scope() {
        let mut scope = parse_and_build_scope("SELECT a, b FROM t");

        assert!(scope.is_root());
        assert!(!scope.can_be_correlated);
        assert!(scope.sources.contains_key("t"));

        let columns = scope.columns();
        assert_eq!(columns.len(), 2);
    }

    #[test]
    fn test_derived_table_scope() {
        let mut scope = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");

        assert!(scope.sources.contains_key("x"));
        assert_eq!(scope.derived_table_scopes.len(), 1);

        let derived = &mut scope.derived_table_scopes[0];
        assert!(derived.is_derived_table());
        assert!(derived.sources.contains_key("t"));
    }

    #[test]
    fn test_non_correlated_subquery() {
        let mut scope = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        // The subquery references only 's', which is in its own sources
        assert!(subquery.sources.contains_key("s"));
        assert!(!subquery.is_correlated_subquery());
    }

    #[test]
    fn test_correlated_subquery() {
        let mut scope =
            parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        // The subquery references 't.y' which is external
        let external = subquery.external_columns();
        assert!(!external.is_empty());
        assert!(external.iter().any(|c| c.table.as_deref() == Some("t")));
        assert!(subquery.is_correlated_subquery());
    }

    #[test]
    fn test_cte_scope() {
        let scope = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");

        assert_eq!(scope.cte_scopes.len(), 1);
        assert!(scope.cte_sources.contains_key("cte"));

        let cte = &scope.cte_scopes[0];
        assert!(cte.is_cte());
    }

    #[test]
    fn test_multiple_sources() {
        let scope = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        assert!(scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("s"));
        assert_eq!(scope.sources.len(), 2);
    }

    #[test]
    fn test_aliased_table() {
        let scope = parse_and_build_scope("SELECT x.a FROM t AS x");

        // Should be indexed by alias, not original name
        assert!(scope.sources.contains_key("x"));
        assert!(!scope.sources.contains_key("t"));
    }

    #[test]
    fn test_local_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let local = scope.local_columns();
        // All columns are local since both t and s are in scope.
        // Includes JOIN ON references (t.id, s.id).
        assert_eq!(local.len(), 5);
        assert!(local.iter().all(|c| c.table.is_some()));
    }

    #[test]
    fn test_columns_include_join_on_clause_references() {
        let mut scope = parse_and_build_scope(
            "SELECT o.total FROM orders o JOIN customers c ON c.id = o.customer_id",
        );

        // Render each column as `table.name` (or bare `name`) for easy matching.
        let cols: Vec<String> = scope
            .columns()
            .iter()
            .map(|c| match &c.table {
                Some(t) => format!("{}.{}", t, c.name),
                None => c.name.clone(),
            })
            .collect();

        assert!(cols.contains(&"o.total".to_string()));
        assert!(cols.contains(&"c.id".to_string()));
        assert!(cols.contains(&"o.customer_id".to_string()));
    }

    #[test]
    fn test_unqualified_columns() {
        let mut scope = parse_and_build_scope("SELECT a, b, t.c FROM t");

        let unqualified = scope.unqualified_columns();
        // Only a and b are unqualified
        assert_eq!(unqualified.len(), 2);
        assert!(unqualified.iter().all(|c| c.table.is_none()));
    }

    #[test]
    fn test_source_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let t_cols = scope.source_columns("t");
        // t.a, t.b, and t.id from JOIN condition
        assert!(t_cols.len() >= 2);
        assert!(t_cols.iter().all(|c| c.table.as_deref() == Some("t")));

        let s_cols = scope.source_columns("s");
        // s.c and s.id from JOIN condition
        assert!(s_cols.len() >= 1);
        assert!(s_cols.iter().all(|c| c.table.as_deref() == Some("s")));
    }

    #[test]
    fn test_rename_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.rename_source("t", "new_name".to_string());
        assert!(!scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("new_name"));
    }

    #[test]
    fn test_remove_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.remove_source("t");
        assert!(!scope.sources.contains_key("t"));
    }

    #[test]
    fn test_walk_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let expr = &ast[0];

        // Walk should visit all expressions within the scope
        let walked: Vec<_> = walk_in_scope(expr, true).collect();
        assert!(!walked.is_empty());

        // Should include the root SELECT
        assert!(walked.iter().any(|e| matches!(e, Expression::Select(_))));
        // Should include columns
        assert!(walked.iter().any(|e| matches!(e, Expression::Column(_))));
    }

    #[test]
    fn test_find_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let expr = &ast[0];

        // Find the first column
        let found = find_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
        assert!(found.is_some());
        assert!(matches!(found.unwrap(), Expression::Column(_)));
    }

    #[test]
    fn test_find_all_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
        let expr = &ast[0];

        // Find all columns
        let found = find_all_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
        assert_eq!(found.len(), 3);
    }

    #[test]
    fn test_traverse_scope() {
        let ast =
            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
        let expr = &ast[0];

        let scopes = traverse_scope(expr);
        // traverse_scope returns all scopes via Scope::traverse
        // which includes derived table and root scopes
        assert!(!scopes.is_empty());
        // The root scope is always included
        assert!(scopes.iter().any(|s| s.is_root()));
    }

    #[test]
    fn test_branch_with_options() {
        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let scope = build_scope(&ast[0]);

        let child = scope.branch_with_options(
            ast[0].clone(),
            ScopeType::Subquery, // Use Subquery to test can_be_correlated
            None,
            None,
            Some(vec!["col1".to_string(), "col2".to_string()]),
        );

        assert_eq!(child.outer_columns, vec!["col1", "col2"]);
        assert!(child.can_be_correlated); // Subquery scopes are allowed to be correlated
    }

    #[test]
    fn test_is_udtf() {
        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let scope = Scope::new(ast[0].clone());
        assert!(!scope.is_udtf());

        // A scope branched with ScopeType::Udtf reports itself as a UDTF scope.
        let root = build_scope(&ast[0]);
        let udtf_scope = root.branch(ast[0].clone(), ScopeType::Udtf);
        assert!(udtf_scope.is_udtf());
    }

    #[test]
    fn test_is_union() {
        let scope = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(scope.is_root());
        assert_eq!(scope.union_scopes.len(), 2);
        // The children are set operation scopes
        assert!(scope.union_scopes[0].is_union());
        assert!(scope.union_scopes[1].is_union());
    }

    #[test]
    fn test_union_output_columns() {
        let scope = parse_and_build_scope(
            "SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees",
        );
        assert_eq!(scope.output_columns(), vec!["id", "name"]);
    }

    #[test]
    fn test_clear_cache() {
        let mut scope = parse_and_build_scope("SELECT t.a FROM t");

        // First call populates cache
        let _ = scope.columns();
        assert!(scope.columns_cache.is_some());

        // Clear cache
        scope.clear_cache();
        assert!(scope.columns_cache.is_none());
        assert!(scope.external_columns_cache.is_none());
    }

    #[test]
    fn test_scope_traverse() {
        let scope = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        let traversed = scope.traverse();
        // Should include: CTE scope, subquery scope, root scope
        assert!(traversed.len() >= 3);
    }

    #[test]
    fn test_create_table_as_select_scope() {
        // Simple CTAS
        let scope = parse_and_build_scope("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        assert!(
            scope.sources.contains_key("src"),
            "CTAS scope should contain the FROM table"
        );
        assert!(
            !scope.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );

        // CTAS with multiple FROM tables
        let scope = parse_and_build_scope(
            "CREATE TABLE out_table AS SELECT a.id FROM foo AS a JOIN bar AS b ON a.id = b.id",
        );
        assert!(scope.sources.contains_key("a"));
        assert!(scope.sources.contains_key("b"));
        assert!(
            !scope.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );

        // CTAS with CTEs
        let scope = parse_and_build_scope(
            "CREATE TABLE out_table AS WITH cte AS (SELECT 1 AS id FROM src) SELECT * FROM cte",
        );
        assert!(
            scope.sources.contains_key("cte"),
            "CTAS with CTE should resolve CTE as source"
        );
        assert!(
            !scope.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );
        assert_eq!(scope.cte_scopes.len(), 1);
    }

    #[test]
    fn test_create_table_as_select_traverse() {
        let ast = Parser::parse_sql("CREATE TABLE t AS SELECT a FROM src").unwrap();
        let scopes = traverse_scope(&ast[0]);
        assert!(
            !scopes.is_empty(),
            "traverse_scope should return scopes for CTAS"
        );
    }
}