Skip to main content

polyglot_sql/
scope.rs

1//! Scope Analysis Module
2//!
3//! This module provides scope analysis for SQL queries, enabling detection of
4//! correlated subqueries, column references, and scope relationships.
5//!
6//! Ported from sqlglot's optimizer/scope.py
7
8use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
14/// Type of scope in a SQL query
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16#[cfg_attr(feature = "bindings", derive(TS))]
17#[cfg_attr(feature = "bindings", ts(export))]
18pub enum ScopeType {
19    /// Root scope of the query
20    Root,
21    /// Subquery scope (e.g., WHERE x IN (SELECT ...))
22    Subquery,
23    /// Derived table scope (e.g., FROM (SELECT ...) AS t)
24    DerivedTable,
25    /// Common Table Expression scope
26    Cte,
27    /// Union/Intersect/Except scope
28    SetOperation,
29    /// User-Defined Table Function scope
30    Udtf,
31}
32
33/// Information about a source (table or subquery) in a scope
34#[derive(Debug, Clone)]
35pub struct SourceInfo {
36    /// The source expression (Table or subquery)
37    pub expression: Expression,
38    /// Whether this source is a scope (vs. a plain table)
39    pub is_scope: bool,
40}
41
/// A column reference as it appears in a scope, e.g. `t.col` or `col`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table or alias qualifier; `None` for an unqualified column.
    pub table: Option<String>,
    /// Name of the referenced column.
    pub name: String,
}
50
51/// Represents a scope in a SQL query
52///
53/// A scope is the context of a SELECT statement and its sources.
54/// Scopes can be nested (subqueries, CTEs, derived tables) and form a tree.
55#[derive(Debug, Clone)]
56pub struct Scope {
57    /// The expression at the root of this scope
58    pub expression: Expression,
59
60    /// Type of this scope relative to its parent
61    pub scope_type: ScopeType,
62
63    /// Mapping of source names to their info
64    pub sources: HashMap<String, SourceInfo>,
65
66    /// Sources from LATERAL views (have access to preceding sources)
67    pub lateral_sources: HashMap<String, SourceInfo>,
68
69    /// CTE sources available to this scope
70    pub cte_sources: HashMap<String, SourceInfo>,
71
72    /// If this is a derived table or CTE with alias columns, this is that list
73    /// e.g., `SELECT * FROM (SELECT ...) AS y(col1, col2)` => ["col1", "col2"]
74    pub outer_columns: Vec<String>,
75
76    /// Whether this scope can potentially be correlated
77    /// (true for subqueries and UDTFs)
78    pub can_be_correlated: bool,
79
80    /// Child subquery scopes
81    pub subquery_scopes: Vec<Scope>,
82
83    /// Child derived table scopes
84    pub derived_table_scopes: Vec<Scope>,
85
86    /// Child CTE scopes
87    pub cte_scopes: Vec<Scope>,
88
89    /// Child UDTF (User Defined Table Function) scopes
90    pub udtf_scopes: Vec<Scope>,
91
92    /// Combined derived_table_scopes + udtf_scopes in definition order
93    pub table_scopes: Vec<Scope>,
94
95    /// Union/set operation scopes (left and right)
96    pub union_scopes: Vec<Scope>,
97
98    /// Cached columns
99    columns_cache: Option<Vec<ColumnRef>>,
100
101    /// Cached external columns
102    external_columns_cache: Option<Vec<ColumnRef>>,
103}
104
105impl Scope {
106    /// Create a new root scope
107    pub fn new(expression: Expression) -> Self {
108        Self {
109            expression,
110            scope_type: ScopeType::Root,
111            sources: HashMap::new(),
112            lateral_sources: HashMap::new(),
113            cte_sources: HashMap::new(),
114            outer_columns: Vec::new(),
115            can_be_correlated: false,
116            subquery_scopes: Vec::new(),
117            derived_table_scopes: Vec::new(),
118            cte_scopes: Vec::new(),
119            udtf_scopes: Vec::new(),
120            table_scopes: Vec::new(),
121            union_scopes: Vec::new(),
122            columns_cache: None,
123            external_columns_cache: None,
124        }
125    }
126
127    /// Create a child scope branching from this one
128    pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
129        self.branch_with_options(expression, scope_type, None, None, None)
130    }
131
132    /// Create a child scope with additional options
133    pub fn branch_with_options(
134        &self,
135        expression: Expression,
136        scope_type: ScopeType,
137        sources: Option<HashMap<String, SourceInfo>>,
138        lateral_sources: Option<HashMap<String, SourceInfo>>,
139        outer_columns: Option<Vec<String>>,
140    ) -> Self {
141        let can_be_correlated = self.can_be_correlated
142            || scope_type == ScopeType::Subquery
143            || scope_type == ScopeType::Udtf;
144
145        Self {
146            expression,
147            scope_type,
148            sources: sources.unwrap_or_default(),
149            lateral_sources: lateral_sources.unwrap_or_default(),
150            cte_sources: self.cte_sources.clone(),
151            outer_columns: outer_columns.unwrap_or_default(),
152            can_be_correlated,
153            subquery_scopes: Vec::new(),
154            derived_table_scopes: Vec::new(),
155            cte_scopes: Vec::new(),
156            udtf_scopes: Vec::new(),
157            table_scopes: Vec::new(),
158            union_scopes: Vec::new(),
159            columns_cache: None,
160            external_columns_cache: None,
161        }
162    }
163
164    /// Clear all cached properties
165    pub fn clear_cache(&mut self) {
166        self.columns_cache = None;
167        self.external_columns_cache = None;
168    }
169
170    /// Add a source to this scope
171    pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
172        self.sources.insert(
173            name,
174            SourceInfo {
175                expression,
176                is_scope,
177            },
178        );
179        self.clear_cache();
180    }
181
182    /// Add a lateral source to this scope
183    pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
184        self.lateral_sources.insert(
185            name.clone(),
186            SourceInfo {
187                expression: expression.clone(),
188                is_scope,
189            },
190        );
191        self.sources.insert(
192            name,
193            SourceInfo {
194                expression,
195                is_scope,
196            },
197        );
198        self.clear_cache();
199    }
200
201    /// Add a CTE source to this scope
202    pub fn add_cte_source(&mut self, name: String, expression: Expression) {
203        self.cte_sources.insert(
204            name.clone(),
205            SourceInfo {
206                expression: expression.clone(),
207                is_scope: true,
208            },
209        );
210        self.sources.insert(
211            name,
212            SourceInfo {
213                expression,
214                is_scope: true,
215            },
216        );
217        self.clear_cache();
218    }
219
220    /// Rename a source
221    pub fn rename_source(&mut self, old_name: &str, new_name: String) {
222        if let Some(source) = self.sources.remove(old_name) {
223            self.sources.insert(new_name, source);
224        }
225        self.clear_cache();
226    }
227
228    /// Remove a source
229    pub fn remove_source(&mut self, name: &str) {
230        self.sources.remove(name);
231        self.clear_cache();
232    }
233
234    /// Collect all column references in this scope
235    pub fn columns(&mut self) -> &[ColumnRef] {
236        if self.columns_cache.is_none() {
237            let mut columns = Vec::new();
238            collect_columns(&self.expression, &mut columns);
239            self.columns_cache = Some(columns);
240        }
241        self.columns_cache.as_ref().unwrap()
242    }
243
244    /// Get all source names in this scope
245    pub fn source_names(&self) -> HashSet<String> {
246        let mut names: HashSet<String> = self.sources.keys().cloned().collect();
247        names.extend(self.cte_sources.keys().cloned());
248        names
249    }
250
251    /// Get columns that reference sources outside this scope
252    pub fn external_columns(&mut self) -> Vec<ColumnRef> {
253        if self.external_columns_cache.is_some() {
254            return self.external_columns_cache.clone().unwrap();
255        }
256
257        let source_names = self.source_names();
258        let columns = self.columns().to_vec();
259
260        let external: Vec<ColumnRef> = columns
261            .into_iter()
262            .filter(|col| {
263                // A column is external if it has a table qualifier that's not in our sources
264                match &col.table {
265                    Some(table) => !source_names.contains(table),
266                    None => false, // Unqualified columns might be local
267                }
268            })
269            .collect();
270
271        self.external_columns_cache = Some(external.clone());
272        external
273    }
274
275    /// Get columns that reference sources in this scope (not external)
276    pub fn local_columns(&mut self) -> Vec<ColumnRef> {
277        let external_set: HashSet<_> = self.external_columns().into_iter().collect();
278        let columns = self.columns().to_vec();
279
280        columns
281            .into_iter()
282            .filter(|col| !external_set.contains(col))
283            .collect()
284    }
285
286    /// Get unqualified columns (columns without table qualifier)
287    pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
288        self.columns()
289            .iter()
290            .filter(|c| c.table.is_none())
291            .cloned()
292            .collect()
293    }
294
295    /// Get columns for a specific source
296    pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
297        self.columns()
298            .iter()
299            .filter(|col| col.table.as_deref() == Some(source_name))
300            .cloned()
301            .collect()
302    }
303
304    /// Determine if this scope is a correlated subquery
305    ///
306    /// A subquery is correlated if:
307    /// 1. It can be correlated (is a subquery or UDTF), AND
308    /// 2. It references columns from outer scopes
309    pub fn is_correlated_subquery(&mut self) -> bool {
310        self.can_be_correlated && !self.external_columns().is_empty()
311    }
312
313    /// Check if this is a subquery scope
314    pub fn is_subquery(&self) -> bool {
315        self.scope_type == ScopeType::Subquery
316    }
317
318    /// Check if this is a derived table scope
319    pub fn is_derived_table(&self) -> bool {
320        self.scope_type == ScopeType::DerivedTable
321    }
322
323    /// Check if this is a CTE scope
324    pub fn is_cte(&self) -> bool {
325        self.scope_type == ScopeType::Cte
326    }
327
328    /// Check if this is the root scope
329    pub fn is_root(&self) -> bool {
330        self.scope_type == ScopeType::Root
331    }
332
333    /// Check if this is a UDTF scope
334    pub fn is_udtf(&self) -> bool {
335        self.scope_type == ScopeType::Udtf
336    }
337
338    /// Check if this is a union/set operation scope
339    pub fn is_union(&self) -> bool {
340        self.scope_type == ScopeType::SetOperation
341    }
342
343    /// Traverse all scopes in this tree (depth-first post-order)
344    pub fn traverse(&self) -> Vec<&Scope> {
345        let mut result = Vec::new();
346        self.traverse_impl(&mut result);
347        result
348    }
349
350    fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
351        // First traverse children
352        for scope in &self.cte_scopes {
353            scope.traverse_impl(result);
354        }
355        for scope in &self.union_scopes {
356            scope.traverse_impl(result);
357        }
358        for scope in &self.table_scopes {
359            scope.traverse_impl(result);
360        }
361        for scope in &self.subquery_scopes {
362            scope.traverse_impl(result);
363        }
364        // Then add self
365        result.push(self);
366    }
367
368    /// Count references to each scope in this tree
369    pub fn ref_count(&self) -> HashMap<usize, usize> {
370        let mut counts: HashMap<usize, usize> = HashMap::new();
371
372        for scope in self.traverse() {
373            for (_, source_info) in scope.sources.iter() {
374                if source_info.is_scope {
375                    let id = &source_info.expression as *const _ as usize;
376                    *counts.entry(id).or_insert(0) += 1;
377                }
378            }
379        }
380
381        counts
382    }
383}
384
385/// Collect all column references from an expression tree
386fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
387    match expr {
388        Expression::Column(col) => {
389            columns.push(ColumnRef {
390                table: col.table.as_ref().map(|t| t.name.clone()),
391                name: col.name.name.clone(),
392            });
393        }
394        Expression::Select(select) => {
395            // Collect from SELECT expressions
396            for e in &select.expressions {
397                collect_columns(e, columns);
398            }
399            // Collect from JOIN ON / MATCH_CONDITION clauses
400            for join in &select.joins {
401                if let Some(on) = &join.on {
402                    collect_columns(on, columns);
403                }
404                if let Some(match_condition) = &join.match_condition {
405                    collect_columns(match_condition, columns);
406                }
407            }
408            // Collect from WHERE
409            if let Some(where_clause) = &select.where_clause {
410                collect_columns(&where_clause.this, columns);
411            }
412            // Collect from HAVING
413            if let Some(having) = &select.having {
414                collect_columns(&having.this, columns);
415            }
416            // Collect from ORDER BY
417            if let Some(order_by) = &select.order_by {
418                for ord in &order_by.expressions {
419                    collect_columns(&ord.this, columns);
420                }
421            }
422            // Collect from GROUP BY
423            if let Some(group_by) = &select.group_by {
424                for e in &group_by.expressions {
425                    collect_columns(e, columns);
426                }
427            }
428            // Note: We don't recurse into FROM/JOIN source subqueries here
429            // as those create their own scopes.
430        }
431        // Binary operations
432        Expression::And(bin)
433        | Expression::Or(bin)
434        | Expression::Add(bin)
435        | Expression::Sub(bin)
436        | Expression::Mul(bin)
437        | Expression::Div(bin)
438        | Expression::Mod(bin)
439        | Expression::Eq(bin)
440        | Expression::Neq(bin)
441        | Expression::Lt(bin)
442        | Expression::Lte(bin)
443        | Expression::Gt(bin)
444        | Expression::Gte(bin)
445        | Expression::BitwiseAnd(bin)
446        | Expression::BitwiseOr(bin)
447        | Expression::BitwiseXor(bin)
448        | Expression::Concat(bin) => {
449            collect_columns(&bin.left, columns);
450            collect_columns(&bin.right, columns);
451        }
452        // LIKE/ILIKE operations
453        Expression::Like(like) | Expression::ILike(like) => {
454            collect_columns(&like.left, columns);
455            collect_columns(&like.right, columns);
456            if let Some(escape) = &like.escape {
457                collect_columns(escape, columns);
458            }
459        }
460        // Unary operations
461        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
462            collect_columns(&un.this, columns);
463        }
464        Expression::Function(func) => {
465            for arg in &func.args {
466                collect_columns(arg, columns);
467            }
468        }
469        Expression::AggregateFunction(agg) => {
470            for arg in &agg.args {
471                collect_columns(arg, columns);
472            }
473        }
474        Expression::WindowFunction(wf) => {
475            collect_columns(&wf.this, columns);
476            for e in &wf.over.partition_by {
477                collect_columns(e, columns);
478            }
479            for e in &wf.over.order_by {
480                collect_columns(&e.this, columns);
481            }
482        }
483        Expression::Alias(alias) => {
484            collect_columns(&alias.this, columns);
485        }
486        Expression::Case(case) => {
487            if let Some(operand) = &case.operand {
488                collect_columns(operand, columns);
489            }
490            for (when_expr, then_expr) in &case.whens {
491                collect_columns(when_expr, columns);
492                collect_columns(then_expr, columns);
493            }
494            if let Some(else_clause) = &case.else_ {
495                collect_columns(else_clause, columns);
496            }
497        }
498        Expression::Paren(paren) => {
499            collect_columns(&paren.this, columns);
500        }
501        Expression::Ordered(ord) => {
502            collect_columns(&ord.this, columns);
503        }
504        Expression::In(in_expr) => {
505            collect_columns(&in_expr.this, columns);
506            for e in &in_expr.expressions {
507                collect_columns(e, columns);
508            }
509            // Note: in_expr.query is a subquery - creates its own scope
510        }
511        Expression::Between(between) => {
512            collect_columns(&between.this, columns);
513            collect_columns(&between.low, columns);
514            collect_columns(&between.high, columns);
515        }
516        Expression::IsNull(is_null) => {
517            collect_columns(&is_null.this, columns);
518        }
519        Expression::Cast(cast) => {
520            collect_columns(&cast.this, columns);
521        }
522        Expression::Extract(extract) => {
523            collect_columns(&extract.this, columns);
524        }
525        Expression::Exists(_) | Expression::Subquery(_) => {
526            // These create their own scopes - don't collect from here
527        }
528        _ => {
529            // For other expressions, we might need to add more cases
530        }
531    }
532}
533
534/// Build scope tree from an expression
535///
536/// This traverses the expression tree and builds a hierarchy of Scope objects
537/// that track sources and column references at each level.
538pub fn build_scope(expression: &Expression) -> Scope {
539    let mut root = Scope::new(expression.clone());
540    build_scope_impl(expression, &mut root);
541    root
542}
543
544fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
545    match expression {
546        Expression::Select(select) => {
547            // Process CTEs first
548            if let Some(with) = &select.with {
549                for cte in &with.ctes {
550                    let cte_name = cte.alias.name.clone();
551                    let mut cte_scope = current_scope
552                        .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
553                    build_scope_impl(&cte.this, &mut cte_scope);
554                    current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
555                    current_scope.cte_scopes.push(cte_scope);
556                }
557            }
558
559            // Process FROM clause
560            if let Some(from) = &select.from {
561                for table in &from.expressions {
562                    add_table_to_scope(table, current_scope);
563                }
564            }
565
566            // Process JOINs
567            for join in &select.joins {
568                add_table_to_scope(&join.this, current_scope);
569            }
570
571            // Process subqueries in WHERE, SELECT expressions, etc.
572            collect_subqueries(expression, current_scope);
573        }
574        Expression::Union(union) => {
575            let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
576            build_scope_impl(&union.left, &mut left_scope);
577
578            let mut right_scope =
579                current_scope.branch(union.right.clone(), ScopeType::SetOperation);
580            build_scope_impl(&union.right, &mut right_scope);
581
582            current_scope.union_scopes.push(left_scope);
583            current_scope.union_scopes.push(right_scope);
584        }
585        Expression::Intersect(intersect) => {
586            let mut left_scope =
587                current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
588            build_scope_impl(&intersect.left, &mut left_scope);
589
590            let mut right_scope =
591                current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
592            build_scope_impl(&intersect.right, &mut right_scope);
593
594            current_scope.union_scopes.push(left_scope);
595            current_scope.union_scopes.push(right_scope);
596        }
597        Expression::Except(except) => {
598            let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
599            build_scope_impl(&except.left, &mut left_scope);
600
601            let mut right_scope =
602                current_scope.branch(except.right.clone(), ScopeType::SetOperation);
603            build_scope_impl(&except.right, &mut right_scope);
604
605            current_scope.union_scopes.push(left_scope);
606            current_scope.union_scopes.push(right_scope);
607        }
608        _ => {}
609    }
610}
611
612fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
613    match expr {
614        Expression::Table(table) => {
615            let name = table
616                .alias
617                .as_ref()
618                .map(|a| a.name.clone())
619                .unwrap_or_else(|| table.name.name.clone());
620            let cte_source = if table.schema.is_none() && table.catalog.is_none() {
621                scope
622                    .cte_sources
623                    .get(&table.name.name)
624                    .or_else(|| {
625                        scope
626                            .cte_sources
627                            .iter()
628                            .find(|(cte_name, _)| cte_name.eq_ignore_ascii_case(&table.name.name))
629                            .map(|(_, source)| source)
630                    })
631            } else {
632                None
633            };
634
635            if let Some(source) = cte_source {
636                scope.add_source(name, source.expression.clone(), true);
637            } else {
638                scope.add_source(name, expr.clone(), false);
639            }
640        }
641        Expression::Subquery(subquery) => {
642            let name = subquery
643                .alias
644                .as_ref()
645                .map(|a| a.name.clone())
646                .unwrap_or_default();
647
648            let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
649            build_scope_impl(&subquery.this, &mut derived_scope);
650
651            scope.add_source(name.clone(), expr.clone(), true);
652            scope.derived_table_scopes.push(derived_scope);
653        }
654        Expression::Paren(paren) => {
655            add_table_to_scope(&paren.this, scope);
656        }
657        _ => {}
658    }
659}
660
661fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
662    match expr {
663        Expression::Select(select) => {
664            // Check WHERE for subqueries
665            if let Some(where_clause) = &select.where_clause {
666                collect_subqueries_in_expr(&where_clause.this, parent_scope);
667            }
668            // Check SELECT expressions for subqueries
669            for e in &select.expressions {
670                collect_subqueries_in_expr(e, parent_scope);
671            }
672            // Check HAVING for subqueries
673            if let Some(having) = &select.having {
674                collect_subqueries_in_expr(&having.this, parent_scope);
675            }
676        }
677        _ => {}
678    }
679}
680
681fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
682    match expr {
683        Expression::Subquery(subquery) if subquery.alias.is_none() => {
684            // This is a scalar subquery or IN subquery (not a derived table)
685            let mut sub_scope = parent_scope.branch(subquery.this.clone(), ScopeType::Subquery);
686            build_scope_impl(&subquery.this, &mut sub_scope);
687            parent_scope.subquery_scopes.push(sub_scope);
688        }
689        Expression::In(in_expr) => {
690            collect_subqueries_in_expr(&in_expr.this, parent_scope);
691            if let Some(query) = &in_expr.query {
692                let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
693                build_scope_impl(query, &mut sub_scope);
694                parent_scope.subquery_scopes.push(sub_scope);
695            }
696        }
697        Expression::Exists(exists) => {
698            let mut sub_scope = parent_scope.branch(exists.this.clone(), ScopeType::Subquery);
699            build_scope_impl(&exists.this, &mut sub_scope);
700            parent_scope.subquery_scopes.push(sub_scope);
701        }
702        // Binary operations
703        Expression::And(bin)
704        | Expression::Or(bin)
705        | Expression::Add(bin)
706        | Expression::Sub(bin)
707        | Expression::Mul(bin)
708        | Expression::Div(bin)
709        | Expression::Mod(bin)
710        | Expression::Eq(bin)
711        | Expression::Neq(bin)
712        | Expression::Lt(bin)
713        | Expression::Lte(bin)
714        | Expression::Gt(bin)
715        | Expression::Gte(bin)
716        | Expression::BitwiseAnd(bin)
717        | Expression::BitwiseOr(bin)
718        | Expression::BitwiseXor(bin)
719        | Expression::Concat(bin) => {
720            collect_subqueries_in_expr(&bin.left, parent_scope);
721            collect_subqueries_in_expr(&bin.right, parent_scope);
722        }
723        // LIKE/ILIKE operations (have different structure with escape)
724        Expression::Like(like) | Expression::ILike(like) => {
725            collect_subqueries_in_expr(&like.left, parent_scope);
726            collect_subqueries_in_expr(&like.right, parent_scope);
727            if let Some(escape) = &like.escape {
728                collect_subqueries_in_expr(escape, parent_scope);
729            }
730        }
731        // Unary operations
732        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
733            collect_subqueries_in_expr(&un.this, parent_scope);
734        }
735        Expression::Function(func) => {
736            for arg in &func.args {
737                collect_subqueries_in_expr(arg, parent_scope);
738            }
739        }
740        Expression::Case(case) => {
741            if let Some(operand) = &case.operand {
742                collect_subqueries_in_expr(operand, parent_scope);
743            }
744            for (when_expr, then_expr) in &case.whens {
745                collect_subqueries_in_expr(when_expr, parent_scope);
746                collect_subqueries_in_expr(then_expr, parent_scope);
747            }
748            if let Some(else_clause) = &case.else_ {
749                collect_subqueries_in_expr(else_clause, parent_scope);
750            }
751        }
752        Expression::Paren(paren) => {
753            collect_subqueries_in_expr(&paren.this, parent_scope);
754        }
755        Expression::Alias(alias) => {
756            collect_subqueries_in_expr(&alias.this, parent_scope);
757        }
758        _ => {}
759    }
760}
761
762/// Walk within a scope, yielding expressions without crossing scope boundaries.
763///
764/// This iterator visits all nodes in the syntax tree, stopping at nodes that
765/// start child scopes (CTEs, derived tables, subqueries in FROM/JOIN).
766///
767/// # Arguments
768/// * `expression` - The expression to walk
769/// * `bfs` - If true, uses breadth-first search; otherwise uses depth-first search
770///
771/// # Returns
772/// An iterator over expressions within the scope
773pub fn walk_in_scope<'a>(
774    expression: &'a Expression,
775    bfs: bool,
776) -> impl Iterator<Item = &'a Expression> {
777    WalkInScopeIter::new(expression, bfs)
778}
779
780/// Iterator for walking within a scope
781struct WalkInScopeIter<'a> {
782    queue: VecDeque<&'a Expression>,
783    bfs: bool,
784}
785
786impl<'a> WalkInScopeIter<'a> {
787    fn new(expression: &'a Expression, bfs: bool) -> Self {
788        let mut queue = VecDeque::new();
789        queue.push_back(expression);
790        Self { queue, bfs }
791    }
792
793    fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
794        if is_root {
795            return false;
796        }
797
798        // Stop at CTE definitions
799        if matches!(expr, Expression::Cte(_)) {
800            return true;
801        }
802
803        // Stop at subqueries that are derived tables (in FROM/JOIN)
804        if let Expression::Subquery(subquery) = expr {
805            if subquery.alias.is_some() {
806                return true;
807            }
808        }
809
810        // Stop at standalone SELECT/UNION/etc that would be subqueries
811        if matches!(
812            expr,
813            Expression::Select(_)
814                | Expression::Union(_)
815                | Expression::Intersect(_)
816                | Expression::Except(_)
817        ) {
818            return true;
819        }
820
821        false
822    }
823
    /// Returns the direct children of `expr` that the in-scope walk should visit,
    /// in the order they should be yielded.
    ///
    /// Only the variants matched below are descended into. Every other variant
    /// falls through to the `_` arm and contributes no children, so expressions
    /// nested inside such a node are not reached by the walk.
    ///
    /// Scope boundaries are excluded by construction: `Subquery` and `Exists`
    /// yield nothing, the query side of `IN` is skipped, and JOIN sources are
    /// never returned. FROM entries are pre-filtered with `should_stop_at`; all
    /// other children are returned unfiltered, and `Iterator::next` filters them
    /// with `should_stop_at` again before queueing them.
    ///
    /// NOTE(review): FROM entries are returned (unless `should_stop_at` rejects
    /// them) while JOIN sources (`join.this`) never are, so a walk can yield the
    /// FROM table but not the joined ones — confirm this asymmetry is intended.
    fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
        let mut children = Vec::new();

        match expr {
            Expression::Select(select) => {
                // Walk SELECT expressions
                for e in &select.expressions {
                    children.push(e);
                }
                // Walk FROM (but tables/subqueries create new scopes)
                if let Some(from) = &select.from {
                    for table in &from.expressions {
                        if !self.should_stop_at(table, false) {
                            children.push(table);
                        }
                    }
                }
                // Walk JOINs (but their sources create new scopes)
                for join in &select.joins {
                    if let Some(on) = &join.on {
                        children.push(on);
                    }
                    // Don't traverse join.this as it's a source (table or subquery)
                }
                // Walk WHERE
                if let Some(where_clause) = &select.where_clause {
                    children.push(&where_clause.this);
                }
                // Walk GROUP BY
                if let Some(group_by) = &select.group_by {
                    for e in &group_by.expressions {
                        children.push(e);
                    }
                }
                // Walk HAVING
                if let Some(having) = &select.having {
                    children.push(&having.this);
                }
                // Walk ORDER BY (push each sort key, not its `Ordered` wrapper)
                if let Some(order_by) = &select.order_by {
                    for ord in &order_by.expressions {
                        children.push(&ord.this);
                    }
                }
                // Walk LIMIT
                if let Some(limit) = &select.limit {
                    children.push(&limit.this);
                }
                // Walk OFFSET
                if let Some(offset) = &select.offset {
                    children.push(&offset.this);
                }
            }
            // Binary operators: left operand first, then right operand.
            Expression::And(bin)
            | Expression::Or(bin)
            | Expression::Add(bin)
            | Expression::Sub(bin)
            | Expression::Mul(bin)
            | Expression::Div(bin)
            | Expression::Mod(bin)
            | Expression::Eq(bin)
            | Expression::Neq(bin)
            | Expression::Lt(bin)
            | Expression::Lte(bin)
            | Expression::Gt(bin)
            | Expression::Gte(bin)
            | Expression::BitwiseAnd(bin)
            | Expression::BitwiseOr(bin)
            | Expression::BitwiseXor(bin)
            | Expression::Concat(bin) => {
                children.push(&bin.left);
                children.push(&bin.right);
            }
            // LIKE / ILIKE: subject, pattern, then the optional ESCAPE expression.
            Expression::Like(like) | Expression::ILike(like) => {
                children.push(&like.left);
                children.push(&like.right);
                if let Some(escape) = &like.escape {
                    children.push(escape);
                }
            }
            // Unary operators have a single operand.
            Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
                children.push(&un.this);
            }
            Expression::Function(func) => {
                for arg in &func.args {
                    children.push(arg);
                }
            }
            Expression::AggregateFunction(agg) => {
                for arg in &agg.args {
                    children.push(arg);
                }
            }
            // The windowed expression, then the OVER clause's PARTITION BY keys,
            // then its ORDER BY sort keys (without the `Ordered` wrappers).
            Expression::WindowFunction(wf) => {
                children.push(&wf.this);
                for e in &wf.over.partition_by {
                    children.push(e);
                }
                for e in &wf.over.order_by {
                    children.push(&e.this);
                }
            }
            Expression::Alias(alias) => {
                children.push(&alias.this);
            }
            // Optional CASE operand, then each WHEN condition immediately followed
            // by its THEN result, then the optional ELSE branch.
            Expression::Case(case) => {
                if let Some(operand) = &case.operand {
                    children.push(operand);
                }
                for (when_expr, then_expr) in &case.whens {
                    children.push(when_expr);
                    children.push(then_expr);
                }
                if let Some(else_clause) = &case.else_ {
                    children.push(else_clause);
                }
            }
            Expression::Paren(paren) => {
                children.push(&paren.this);
            }
            Expression::Ordered(ord) => {
                children.push(&ord.this);
            }
            Expression::In(in_expr) => {
                children.push(&in_expr.this);
                for e in &in_expr.expressions {
                    children.push(e);
                }
                // Note: in_expr.query creates a new scope - don't traverse
            }
            Expression::Between(between) => {
                children.push(&between.this);
                children.push(&between.low);
                children.push(&between.high);
            }
            Expression::IsNull(is_null) => {
                children.push(&is_null.this);
            }
            Expression::Cast(cast) => {
                children.push(&cast.this);
            }
            Expression::Extract(extract) => {
                children.push(&extract.this);
            }
            Expression::Coalesce(coalesce) => {
                for e in &coalesce.expressions {
                    children.push(e);
                }
            }
            Expression::NullIf(nullif) => {
                children.push(&nullif.this);
                children.push(&nullif.expression);
            }
            Expression::Table(_table) => {
                // Tables don't have child expressions to traverse within scope
                // (joins are handled at the Select level)
            }
            Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
                // Leaf nodes - no children
            }
            // Subqueries and Exists create new scopes - don't traverse into them
            Expression::Subquery(_) | Expression::Exists(_) => {}
            _ => {
                // For other expressions, we could add more cases as needed.
                // Until then, no children are returned, so the walk does not
                // descend into any variant that is not listed above.
            }
        }

        children
    }
993}
994
995impl<'a> Iterator for WalkInScopeIter<'a> {
996    type Item = &'a Expression;
997
998    fn next(&mut self) -> Option<Self::Item> {
999        let expr = if self.bfs {
1000            self.queue.pop_front()?
1001        } else {
1002            self.queue.pop_back()?
1003        };
1004
1005        // Get children that don't cross scope boundaries
1006        let children = self.get_children(expr);
1007
1008        if self.bfs {
1009            for child in children {
1010                if !self.should_stop_at(child, false) {
1011                    self.queue.push_back(child);
1012                }
1013            }
1014        } else {
1015            for child in children.into_iter().rev() {
1016                if !self.should_stop_at(child, false) {
1017                    self.queue.push_back(child);
1018                }
1019            }
1020        }
1021
1022        Some(expr)
1023    }
1024}
1025
1026/// Find the first expression matching the predicate within this scope.
1027///
1028/// This does NOT traverse into subscopes.
1029///
1030/// # Arguments
1031/// * `expression` - The root expression
1032/// * `predicate` - Function that returns true for matching expressions
1033/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1034///
1035/// # Returns
1036/// The first matching expression, or None
1037pub fn find_in_scope<'a, F>(
1038    expression: &'a Expression,
1039    predicate: F,
1040    bfs: bool,
1041) -> Option<&'a Expression>
1042where
1043    F: Fn(&Expression) -> bool,
1044{
1045    walk_in_scope(expression, bfs).find(|e| predicate(e))
1046}
1047
1048/// Find all expressions matching the predicate within this scope.
1049///
1050/// This does NOT traverse into subscopes.
1051///
1052/// # Arguments
1053/// * `expression` - The root expression
1054/// * `predicate` - Function that returns true for matching expressions
1055/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1056///
1057/// # Returns
1058/// A vector of matching expressions
1059pub fn find_all_in_scope<'a, F>(
1060    expression: &'a Expression,
1061    predicate: F,
1062    bfs: bool,
1063) -> Vec<&'a Expression>
1064where
1065    F: Fn(&Expression) -> bool,
1066{
1067    walk_in_scope(expression, bfs)
1068        .filter(|e| predicate(e))
1069        .collect()
1070}
1071
1072/// Traverse an expression by its "scopes".
1073///
1074/// Returns a list of all scopes in depth-first post-order.
1075///
1076/// # Arguments
1077/// * `expression` - The expression to traverse
1078///
1079/// # Returns
1080/// A vector of all scopes found
1081pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1082    match expression {
1083        Expression::Select(_)
1084        | Expression::Union(_)
1085        | Expression::Intersect(_)
1086        | Expression::Except(_) => {
1087            let root = build_scope(expression);
1088            root.traverse().into_iter().cloned().collect()
1089        }
1090        _ => Vec::new(),
1091    }
1092}
1093
1094#[cfg(test)]
1095mod tests {
1096    use super::*;
1097    use crate::parser::Parser;
1098
1099    fn parse_and_build_scope(sql: &str) -> Scope {
1100        let ast = Parser::parse_sql(sql).expect("Failed to parse SQL");
1101        build_scope(&ast[0])
1102    }
1103
1104    #[test]
1105    fn test_simple_select_scope() {
1106        let mut scope = parse_and_build_scope("SELECT a, b FROM t");
1107
1108        assert!(scope.is_root());
1109        assert!(!scope.can_be_correlated);
1110        assert!(scope.sources.contains_key("t"));
1111
1112        let columns = scope.columns();
1113        assert_eq!(columns.len(), 2);
1114    }
1115
1116    #[test]
1117    fn test_derived_table_scope() {
1118        let mut scope = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");
1119
1120        assert!(scope.sources.contains_key("x"));
1121        assert_eq!(scope.derived_table_scopes.len(), 1);
1122
1123        let derived = &mut scope.derived_table_scopes[0];
1124        assert!(derived.is_derived_table());
1125        assert!(derived.sources.contains_key("t"));
1126    }
1127
1128    #[test]
1129    fn test_non_correlated_subquery() {
1130        let mut scope = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");
1131
1132        assert_eq!(scope.subquery_scopes.len(), 1);
1133
1134        let subquery = &mut scope.subquery_scopes[0];
1135        assert!(subquery.is_subquery());
1136        assert!(subquery.can_be_correlated);
1137
1138        // The subquery references only 's', which is in its own sources
1139        assert!(subquery.sources.contains_key("s"));
1140        assert!(!subquery.is_correlated_subquery());
1141    }
1142
1143    #[test]
1144    fn test_correlated_subquery() {
1145        let mut scope =
1146            parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");
1147
1148        assert_eq!(scope.subquery_scopes.len(), 1);
1149
1150        let subquery = &mut scope.subquery_scopes[0];
1151        assert!(subquery.is_subquery());
1152        assert!(subquery.can_be_correlated);
1153
1154        // The subquery references 't.y' which is external
1155        let external = subquery.external_columns();
1156        assert!(!external.is_empty());
1157        assert!(external.iter().any(|c| c.table.as_deref() == Some("t")));
1158        assert!(subquery.is_correlated_subquery());
1159    }
1160
1161    #[test]
1162    fn test_cte_scope() {
1163        let scope = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");
1164
1165        assert_eq!(scope.cte_scopes.len(), 1);
1166        assert!(scope.cte_sources.contains_key("cte"));
1167
1168        let cte = &scope.cte_scopes[0];
1169        assert!(cte.is_cte());
1170    }
1171
1172    #[test]
1173    fn test_multiple_sources() {
1174        let scope = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
1175
1176        assert!(scope.sources.contains_key("t"));
1177        assert!(scope.sources.contains_key("s"));
1178        assert_eq!(scope.sources.len(), 2);
1179    }
1180
1181    #[test]
1182    fn test_aliased_table() {
1183        let scope = parse_and_build_scope("SELECT x.a FROM t AS x");
1184
1185        // Should be indexed by alias, not original name
1186        assert!(scope.sources.contains_key("x"));
1187        assert!(!scope.sources.contains_key("t"));
1188    }
1189
1190    #[test]
1191    fn test_local_columns() {
1192        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");
1193
1194        let local = scope.local_columns();
1195        // All columns are local since both t and s are in scope.
1196        // Includes JOIN ON references (t.id, s.id).
1197        assert_eq!(local.len(), 5);
1198        assert!(local.iter().all(|c| c.table.is_some()));
1199    }
1200
1201    #[test]
1202    fn test_columns_include_join_on_clause_references() {
1203        let mut scope = parse_and_build_scope(
1204            "SELECT o.total FROM orders o JOIN customers c ON c.id = o.customer_id",
1205        );
1206
1207        let cols: Vec<String> = scope
1208            .columns()
1209            .iter()
1210            .map(|c| match &c.table {
1211                Some(t) => format!("{}.{}", t, c.name),
1212                None => c.name.clone(),
1213            })
1214            .collect();
1215
1216        assert!(cols.contains(&"o.total".to_string()));
1217        assert!(cols.contains(&"c.id".to_string()));
1218        assert!(cols.contains(&"o.customer_id".to_string()));
1219    }
1220
1221    #[test]
1222    fn test_unqualified_columns() {
1223        let mut scope = parse_and_build_scope("SELECT a, b, t.c FROM t");
1224
1225        let unqualified = scope.unqualified_columns();
1226        // Only a and b are unqualified
1227        assert_eq!(unqualified.len(), 2);
1228        assert!(unqualified.iter().all(|c| c.table.is_none()));
1229    }
1230
1231    #[test]
1232    fn test_source_columns() {
1233        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");
1234
1235        let t_cols = scope.source_columns("t");
1236        // t.a, t.b, and t.id from JOIN condition
1237        assert!(t_cols.len() >= 2);
1238        assert!(t_cols.iter().all(|c| c.table.as_deref() == Some("t")));
1239
1240        let s_cols = scope.source_columns("s");
1241        // s.c and s.id from JOIN condition
1242        assert!(s_cols.len() >= 1);
1243        assert!(s_cols.iter().all(|c| c.table.as_deref() == Some("s")));
1244    }
1245
1246    #[test]
1247    fn test_rename_source() {
1248        let mut scope = parse_and_build_scope("SELECT a FROM t");
1249
1250        assert!(scope.sources.contains_key("t"));
1251        scope.rename_source("t", "new_name".to_string());
1252        assert!(!scope.sources.contains_key("t"));
1253        assert!(scope.sources.contains_key("new_name"));
1254    }
1255
1256    #[test]
1257    fn test_remove_source() {
1258        let mut scope = parse_and_build_scope("SELECT a FROM t");
1259
1260        assert!(scope.sources.contains_key("t"));
1261        scope.remove_source("t");
1262        assert!(!scope.sources.contains_key("t"));
1263    }
1264
1265    #[test]
1266    fn test_walk_in_scope() {
1267        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
1268        let expr = &ast[0];
1269
1270        // Walk should visit all expressions within the scope
1271        let walked: Vec<_> = walk_in_scope(expr, true).collect();
1272        assert!(!walked.is_empty());
1273
1274        // Should include the root SELECT
1275        assert!(walked.iter().any(|e| matches!(e, Expression::Select(_))));
1276        // Should include columns
1277        assert!(walked.iter().any(|e| matches!(e, Expression::Column(_))));
1278    }
1279
1280    #[test]
1281    fn test_find_in_scope() {
1282        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
1283        let expr = &ast[0];
1284
1285        // Find the first column
1286        let found = find_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
1287        assert!(found.is_some());
1288        assert!(matches!(found.unwrap(), Expression::Column(_)));
1289    }
1290
1291    #[test]
1292    fn test_find_all_in_scope() {
1293        let ast = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
1294        let expr = &ast[0];
1295
1296        // Find all columns
1297        let found = find_all_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
1298        assert_eq!(found.len(), 3);
1299    }
1300
1301    #[test]
1302    fn test_traverse_scope() {
1303        let ast =
1304            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
1305        let expr = &ast[0];
1306
1307        let scopes = traverse_scope(expr);
1308        // traverse_scope returns all scopes via Scope::traverse
1309        // which includes derived table and root scopes
1310        assert!(!scopes.is_empty());
1311        // The root scope is always included
1312        assert!(scopes.iter().any(|s| s.is_root()));
1313    }
1314
1315    #[test]
1316    fn test_branch_with_options() {
1317        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
1318        let scope = build_scope(&ast[0]);
1319
1320        let child = scope.branch_with_options(
1321            ast[0].clone(),
1322            ScopeType::Subquery, // Use Subquery to test can_be_correlated
1323            None,
1324            None,
1325            Some(vec!["col1".to_string(), "col2".to_string()]),
1326        );
1327
1328        assert_eq!(child.outer_columns, vec!["col1", "col2"]);
1329        assert!(child.can_be_correlated); // Subqueries are correlated
1330    }
1331
1332    #[test]
1333    fn test_is_udtf() {
1334        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
1335        let scope = Scope::new(ast[0].clone());
1336        assert!(!scope.is_udtf());
1337
1338        let root = build_scope(&ast[0]);
1339        let udtf_scope = root.branch(ast[0].clone(), ScopeType::Udtf);
1340        assert!(udtf_scope.is_udtf());
1341    }
1342
1343    #[test]
1344    fn test_is_union() {
1345        let scope = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");
1346
1347        assert!(scope.is_root());
1348        assert_eq!(scope.union_scopes.len(), 2);
1349        // The children are set operation scopes
1350        assert!(scope.union_scopes[0].is_union());
1351        assert!(scope.union_scopes[1].is_union());
1352    }
1353
1354    #[test]
1355    fn test_clear_cache() {
1356        let mut scope = parse_and_build_scope("SELECT t.a FROM t");
1357
1358        // First call populates cache
1359        let _ = scope.columns();
1360        assert!(scope.columns_cache.is_some());
1361
1362        // Clear cache
1363        scope.clear_cache();
1364        assert!(scope.columns_cache.is_none());
1365        assert!(scope.external_columns_cache.is_none());
1366    }
1367
1368    #[test]
1369    fn test_scope_traverse() {
1370        let scope = parse_and_build_scope(
1371            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
1372        );
1373
1374        let traversed = scope.traverse();
1375        // Should include: CTE scope, subquery scope, root scope
1376        assert!(traversed.len() >= 3);
1377    }
1378}