Skip to main content

polyglot_sql/
scope.rs

1//! Scope Analysis Module
2//!
3//! This module provides scope analysis for SQL queries, enabling detection of
4//! correlated subqueries, column references, and scope relationships.
5//!
6//! Ported from sqlglot's optimizer/scope.py
7
8use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
/// Type of scope in a SQL query.
///
/// Every [`Scope`] stores one of these in `scope_type`. `Scope::new` always
/// produces [`ScopeType::Root`]; child scopes receive whatever type is passed
/// to `Scope::branch` / `Scope::branch_with_options`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "bindings", derive(TS))]
#[cfg_attr(feature = "bindings", ts(export))]
pub enum ScopeType {
    /// Root scope of the query
    Root,
    /// Subquery scope (e.g., WHERE x IN (SELECT ...))
    Subquery,
    /// Derived table scope (e.g., FROM (SELECT ...) AS t)
    DerivedTable,
    /// Common Table Expression scope
    Cte,
    /// Union/Intersect/Except scope
    SetOperation,
    /// User-Defined Table Function scope
    Udtf,
}
32
/// Information about a source (table or subquery) in a scope.
///
/// Instances are stored by value in the `sources`, `lateral_sources` and
/// `cte_sources` maps of a [`Scope`], keyed by the source's name or alias.
#[derive(Debug, Clone)]
pub struct SourceInfo {
    /// The source expression (Table or subquery)
    pub expression: Expression,
    /// Whether this source is a scope (vs. a plain table)
    pub is_scope: bool,
}
41
/// A column reference found in a scope.
///
/// Two references compare equal when both the qualifier and the name match;
/// the `Hash`/`Eq` derives allow references to be collected into hash sets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// The table/alias qualifier (if any)
    pub table: Option<String>,
    /// The column name
    pub name: String,
}
50
/// Represents a scope in a SQL query
///
/// A scope is the context of a SELECT statement and its sources.
/// Scopes can be nested (subqueries, CTEs, derived tables) and form a tree.
///
/// Child scopes are owned by value in the `*_scopes` vectors, so cloning a
/// scope clones its whole subtree.
#[derive(Debug, Clone)]
pub struct Scope {
    /// The expression at the root of this scope
    pub expression: Expression,

    /// Type of this scope relative to its parent
    pub scope_type: ScopeType,

    /// Mapping of source names to their info
    pub sources: HashMap<String, SourceInfo>,

    /// Sources from LATERAL views (have access to preceding sources)
    pub lateral_sources: HashMap<String, SourceInfo>,

    /// CTE sources available to this scope
    /// (child scopes start with a copy of their parent's map)
    pub cte_sources: HashMap<String, SourceInfo>,

    /// If this is a derived table or CTE with alias columns, this is that list
    /// e.g., `SELECT * FROM (SELECT ...) AS y(col1, col2)` => ["col1", "col2"]
    pub outer_columns: Vec<String>,

    /// Whether this scope can potentially be correlated
    /// (true for subqueries and UDTFs, and for any scope branched from a
    /// scope that can itself be correlated)
    pub can_be_correlated: bool,

    /// Child subquery scopes
    pub subquery_scopes: Vec<Scope>,

    /// Child derived table scopes
    pub derived_table_scopes: Vec<Scope>,

    /// Child CTE scopes
    pub cte_scopes: Vec<Scope>,

    /// Child UDTF (User Defined Table Function) scopes
    pub udtf_scopes: Vec<Scope>,

    /// Combined derived_table_scopes + udtf_scopes in definition order
    pub table_scopes: Vec<Scope>,

    /// Union/set operation scopes (left and right)
    pub union_scopes: Vec<Scope>,

    /// Cached columns; filled lazily by `columns()` and reset by `clear_cache()`
    columns_cache: Option<Vec<ColumnRef>>,

    /// Cached external columns; filled lazily by `external_columns()` and
    /// reset by `clear_cache()`
    external_columns_cache: Option<Vec<ColumnRef>>,
}
104
105impl Scope {
106    /// Create a new root scope
107    pub fn new(expression: Expression) -> Self {
108        Self {
109            expression,
110            scope_type: ScopeType::Root,
111            sources: HashMap::new(),
112            lateral_sources: HashMap::new(),
113            cte_sources: HashMap::new(),
114            outer_columns: Vec::new(),
115            can_be_correlated: false,
116            subquery_scopes: Vec::new(),
117            derived_table_scopes: Vec::new(),
118            cte_scopes: Vec::new(),
119            udtf_scopes: Vec::new(),
120            table_scopes: Vec::new(),
121            union_scopes: Vec::new(),
122            columns_cache: None,
123            external_columns_cache: None,
124        }
125    }
126
127    /// Create a child scope branching from this one
128    pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
129        self.branch_with_options(expression, scope_type, None, None, None)
130    }
131
132    /// Create a child scope with additional options
133    pub fn branch_with_options(
134        &self,
135        expression: Expression,
136        scope_type: ScopeType,
137        sources: Option<HashMap<String, SourceInfo>>,
138        lateral_sources: Option<HashMap<String, SourceInfo>>,
139        outer_columns: Option<Vec<String>>,
140    ) -> Self {
141        let can_be_correlated = self.can_be_correlated
142            || scope_type == ScopeType::Subquery
143            || scope_type == ScopeType::Udtf;
144
145        Self {
146            expression,
147            scope_type,
148            sources: sources.unwrap_or_default(),
149            lateral_sources: lateral_sources.unwrap_or_default(),
150            cte_sources: self.cte_sources.clone(),
151            outer_columns: outer_columns.unwrap_or_default(),
152            can_be_correlated,
153            subquery_scopes: Vec::new(),
154            derived_table_scopes: Vec::new(),
155            cte_scopes: Vec::new(),
156            udtf_scopes: Vec::new(),
157            table_scopes: Vec::new(),
158            union_scopes: Vec::new(),
159            columns_cache: None,
160            external_columns_cache: None,
161        }
162    }
163
164    /// Clear all cached properties
165    pub fn clear_cache(&mut self) {
166        self.columns_cache = None;
167        self.external_columns_cache = None;
168    }
169
170    /// Add a source to this scope
171    pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
172        self.sources
173            .insert(name, SourceInfo { expression, is_scope });
174        self.clear_cache();
175    }
176
177    /// Add a lateral source to this scope
178    pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
179        self.lateral_sources.insert(
180            name.clone(),
181            SourceInfo {
182                expression: expression.clone(),
183                is_scope,
184            },
185        );
186        self.sources.insert(name, SourceInfo { expression, is_scope });
187        self.clear_cache();
188    }
189
190    /// Add a CTE source to this scope
191    pub fn add_cte_source(&mut self, name: String, expression: Expression) {
192        self.cte_sources.insert(
193            name.clone(),
194            SourceInfo {
195                expression: expression.clone(),
196                is_scope: true,
197            },
198        );
199        self.sources.insert(
200            name,
201            SourceInfo {
202                expression,
203                is_scope: true,
204            },
205        );
206        self.clear_cache();
207    }
208
209    /// Rename a source
210    pub fn rename_source(&mut self, old_name: &str, new_name: String) {
211        if let Some(source) = self.sources.remove(old_name) {
212            self.sources.insert(new_name, source);
213        }
214        self.clear_cache();
215    }
216
217    /// Remove a source
218    pub fn remove_source(&mut self, name: &str) {
219        self.sources.remove(name);
220        self.clear_cache();
221    }
222
223    /// Collect all column references in this scope
224    pub fn columns(&mut self) -> &[ColumnRef] {
225        if self.columns_cache.is_none() {
226            let mut columns = Vec::new();
227            collect_columns(&self.expression, &mut columns);
228            self.columns_cache = Some(columns);
229        }
230        self.columns_cache.as_ref().unwrap()
231    }
232
233    /// Get all source names in this scope
234    pub fn source_names(&self) -> HashSet<String> {
235        let mut names: HashSet<String> = self.sources.keys().cloned().collect();
236        names.extend(self.cte_sources.keys().cloned());
237        names
238    }
239
240    /// Get columns that reference sources outside this scope
241    pub fn external_columns(&mut self) -> Vec<ColumnRef> {
242        if self.external_columns_cache.is_some() {
243            return self.external_columns_cache.clone().unwrap();
244        }
245
246        let source_names = self.source_names();
247        let columns = self.columns().to_vec();
248
249        let external: Vec<ColumnRef> = columns
250            .into_iter()
251            .filter(|col| {
252                // A column is external if it has a table qualifier that's not in our sources
253                match &col.table {
254                    Some(table) => !source_names.contains(table),
255                    None => false, // Unqualified columns might be local
256                }
257            })
258            .collect();
259
260        self.external_columns_cache = Some(external.clone());
261        external
262    }
263
264    /// Get columns that reference sources in this scope (not external)
265    pub fn local_columns(&mut self) -> Vec<ColumnRef> {
266        let external_set: HashSet<_> = self.external_columns().into_iter().collect();
267        let columns = self.columns().to_vec();
268
269        columns
270            .into_iter()
271            .filter(|col| !external_set.contains(col))
272            .collect()
273    }
274
275    /// Get unqualified columns (columns without table qualifier)
276    pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
277        self.columns()
278            .iter()
279            .filter(|c| c.table.is_none())
280            .cloned()
281            .collect()
282    }
283
284    /// Get columns for a specific source
285    pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
286        self.columns()
287            .iter()
288            .filter(|col| col.table.as_deref() == Some(source_name))
289            .cloned()
290            .collect()
291    }
292
293    /// Determine if this scope is a correlated subquery
294    ///
295    /// A subquery is correlated if:
296    /// 1. It can be correlated (is a subquery or UDTF), AND
297    /// 2. It references columns from outer scopes
298    pub fn is_correlated_subquery(&mut self) -> bool {
299        self.can_be_correlated && !self.external_columns().is_empty()
300    }
301
302    /// Check if this is a subquery scope
303    pub fn is_subquery(&self) -> bool {
304        self.scope_type == ScopeType::Subquery
305    }
306
307    /// Check if this is a derived table scope
308    pub fn is_derived_table(&self) -> bool {
309        self.scope_type == ScopeType::DerivedTable
310    }
311
312    /// Check if this is a CTE scope
313    pub fn is_cte(&self) -> bool {
314        self.scope_type == ScopeType::Cte
315    }
316
317    /// Check if this is the root scope
318    pub fn is_root(&self) -> bool {
319        self.scope_type == ScopeType::Root
320    }
321
322    /// Check if this is a UDTF scope
323    pub fn is_udtf(&self) -> bool {
324        self.scope_type == ScopeType::Udtf
325    }
326
327    /// Check if this is a union/set operation scope
328    pub fn is_union(&self) -> bool {
329        self.scope_type == ScopeType::SetOperation
330    }
331
332    /// Traverse all scopes in this tree (depth-first post-order)
333    pub fn traverse(&self) -> Vec<&Scope> {
334        let mut result = Vec::new();
335        self.traverse_impl(&mut result);
336        result
337    }
338
339    fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
340        // First traverse children
341        for scope in &self.cte_scopes {
342            scope.traverse_impl(result);
343        }
344        for scope in &self.union_scopes {
345            scope.traverse_impl(result);
346        }
347        for scope in &self.table_scopes {
348            scope.traverse_impl(result);
349        }
350        for scope in &self.subquery_scopes {
351            scope.traverse_impl(result);
352        }
353        // Then add self
354        result.push(self);
355    }
356
357    /// Count references to each scope in this tree
358    pub fn ref_count(&self) -> HashMap<usize, usize> {
359        let mut counts: HashMap<usize, usize> = HashMap::new();
360
361        for scope in self.traverse() {
362            for (_, source_info) in scope.sources.iter() {
363                if source_info.is_scope {
364                    let id = &source_info.expression as *const _ as usize;
365                    *counts.entry(id).or_insert(0) += 1;
366                }
367            }
368        }
369
370        counts
371    }
372}
373
374/// Collect all column references from an expression tree
375fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
376    match expr {
377        Expression::Column(col) => {
378            columns.push(ColumnRef {
379                table: col.table.as_ref().map(|t| t.name.clone()),
380                name: col.name.name.clone(),
381            });
382        }
383        Expression::Select(select) => {
384            // Collect from SELECT expressions
385            for e in &select.expressions {
386                collect_columns(e, columns);
387            }
388            // Collect from WHERE
389            if let Some(where_clause) = &select.where_clause {
390                collect_columns(&where_clause.this, columns);
391            }
392            // Collect from HAVING
393            if let Some(having) = &select.having {
394                collect_columns(&having.this, columns);
395            }
396            // Collect from ORDER BY
397            if let Some(order_by) = &select.order_by {
398                for ord in &order_by.expressions {
399                    collect_columns(&ord.this, columns);
400                }
401            }
402            // Collect from GROUP BY
403            if let Some(group_by) = &select.group_by {
404                for e in &group_by.expressions {
405                    collect_columns(e, columns);
406                }
407            }
408            // Note: We don't recurse into FROM/JOIN subqueries here
409            // as those create their own scopes
410        }
411        // Binary operations
412        Expression::And(bin) | Expression::Or(bin) |
413        Expression::Add(bin) | Expression::Sub(bin) |
414        Expression::Mul(bin) | Expression::Div(bin) |
415        Expression::Mod(bin) | Expression::Eq(bin) |
416        Expression::Neq(bin) | Expression::Lt(bin) |
417        Expression::Lte(bin) | Expression::Gt(bin) |
418        Expression::Gte(bin) | Expression::BitwiseAnd(bin) |
419        Expression::BitwiseOr(bin) | Expression::BitwiseXor(bin) |
420        Expression::Concat(bin) => {
421            collect_columns(&bin.left, columns);
422            collect_columns(&bin.right, columns);
423        }
424        // LIKE/ILIKE operations
425        Expression::Like(like) | Expression::ILike(like) => {
426            collect_columns(&like.left, columns);
427            collect_columns(&like.right, columns);
428            if let Some(escape) = &like.escape {
429                collect_columns(escape, columns);
430            }
431        }
432        // Unary operations
433        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
434            collect_columns(&un.this, columns);
435        }
436        Expression::Function(func) => {
437            for arg in &func.args {
438                collect_columns(arg, columns);
439            }
440        }
441        Expression::AggregateFunction(agg) => {
442            for arg in &agg.args {
443                collect_columns(arg, columns);
444            }
445        }
446        Expression::WindowFunction(wf) => {
447            collect_columns(&wf.this, columns);
448            for e in &wf.over.partition_by {
449                collect_columns(e, columns);
450            }
451            for e in &wf.over.order_by {
452                collect_columns(&e.this, columns);
453            }
454        }
455        Expression::Alias(alias) => {
456            collect_columns(&alias.this, columns);
457        }
458        Expression::Case(case) => {
459            if let Some(operand) = &case.operand {
460                collect_columns(operand, columns);
461            }
462            for (when_expr, then_expr) in &case.whens {
463                collect_columns(when_expr, columns);
464                collect_columns(then_expr, columns);
465            }
466            if let Some(else_clause) = &case.else_ {
467                collect_columns(else_clause, columns);
468            }
469        }
470        Expression::Paren(paren) => {
471            collect_columns(&paren.this, columns);
472        }
473        Expression::Ordered(ord) => {
474            collect_columns(&ord.this, columns);
475        }
476        Expression::In(in_expr) => {
477            collect_columns(&in_expr.this, columns);
478            for e in &in_expr.expressions {
479                collect_columns(e, columns);
480            }
481            // Note: in_expr.query is a subquery - creates its own scope
482        }
483        Expression::Between(between) => {
484            collect_columns(&between.this, columns);
485            collect_columns(&between.low, columns);
486            collect_columns(&between.high, columns);
487        }
488        Expression::IsNull(is_null) => {
489            collect_columns(&is_null.this, columns);
490        }
491        Expression::Cast(cast) => {
492            collect_columns(&cast.this, columns);
493        }
494        Expression::Extract(extract) => {
495            collect_columns(&extract.this, columns);
496        }
497        Expression::Exists(_) | Expression::Subquery(_) => {
498            // These create their own scopes - don't collect from here
499        }
500        _ => {
501            // For other expressions, we might need to add more cases
502        }
503    }
504}
505
506/// Build scope tree from an expression
507///
508/// This traverses the expression tree and builds a hierarchy of Scope objects
509/// that track sources and column references at each level.
510pub fn build_scope(expression: &Expression) -> Scope {
511    let mut root = Scope::new(expression.clone());
512    build_scope_impl(expression, &mut root);
513    root
514}
515
516fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
517    match expression {
518        Expression::Select(select) => {
519            // Process CTEs first
520            if let Some(with) = &select.with {
521                for cte in &with.ctes {
522                    let cte_name = cte.alias.name.clone();
523                    let mut cte_scope = current_scope.branch(
524                        Expression::Cte(Box::new(cte.clone())),
525                        ScopeType::Cte,
526                    );
527                    build_scope_impl(&cte.this, &mut cte_scope);
528                    current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
529                    current_scope.cte_scopes.push(cte_scope);
530                }
531            }
532
533            // Process FROM clause
534            if let Some(from) = &select.from {
535                for table in &from.expressions {
536                    add_table_to_scope(table, current_scope);
537                }
538            }
539
540            // Process JOINs
541            for join in &select.joins {
542                add_table_to_scope(&join.this, current_scope);
543            }
544
545            // Process subqueries in WHERE, SELECT expressions, etc.
546            collect_subqueries(expression, current_scope);
547        }
548        Expression::Union(union) => {
549            let mut left_scope = current_scope.branch(
550                union.left.clone(),
551                ScopeType::SetOperation,
552            );
553            build_scope_impl(&union.left, &mut left_scope);
554
555            let mut right_scope = current_scope.branch(
556                union.right.clone(),
557                ScopeType::SetOperation,
558            );
559            build_scope_impl(&union.right, &mut right_scope);
560
561            current_scope.union_scopes.push(left_scope);
562            current_scope.union_scopes.push(right_scope);
563        }
564        Expression::Intersect(intersect) => {
565            let mut left_scope = current_scope.branch(
566                intersect.left.clone(),
567                ScopeType::SetOperation,
568            );
569            build_scope_impl(&intersect.left, &mut left_scope);
570
571            let mut right_scope = current_scope.branch(
572                intersect.right.clone(),
573                ScopeType::SetOperation,
574            );
575            build_scope_impl(&intersect.right, &mut right_scope);
576
577            current_scope.union_scopes.push(left_scope);
578            current_scope.union_scopes.push(right_scope);
579        }
580        Expression::Except(except) => {
581            let mut left_scope = current_scope.branch(
582                except.left.clone(),
583                ScopeType::SetOperation,
584            );
585            build_scope_impl(&except.left, &mut left_scope);
586
587            let mut right_scope = current_scope.branch(
588                except.right.clone(),
589                ScopeType::SetOperation,
590            );
591            build_scope_impl(&except.right, &mut right_scope);
592
593            current_scope.union_scopes.push(left_scope);
594            current_scope.union_scopes.push(right_scope);
595        }
596        _ => {}
597    }
598}
599
600fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
601    match expr {
602        Expression::Table(table) => {
603            let name = table.alias.as_ref()
604                .map(|a| a.name.clone())
605                .unwrap_or_else(|| table.name.name.clone());
606            scope.add_source(name, expr.clone(), false);
607        }
608        Expression::Subquery(subquery) => {
609            let name = subquery.alias.as_ref()
610                .map(|a| a.name.clone())
611                .unwrap_or_default();
612
613            let mut derived_scope = scope.branch(
614                subquery.this.clone(),
615                ScopeType::DerivedTable,
616            );
617            build_scope_impl(&subquery.this, &mut derived_scope);
618
619            scope.add_source(name.clone(), expr.clone(), true);
620            scope.derived_table_scopes.push(derived_scope);
621        }
622        Expression::Paren(paren) => {
623            add_table_to_scope(&paren.this, scope);
624        }
625        _ => {}
626    }
627}
628
629fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
630    match expr {
631        Expression::Select(select) => {
632            // Check WHERE for subqueries
633            if let Some(where_clause) = &select.where_clause {
634                collect_subqueries_in_expr(&where_clause.this, parent_scope);
635            }
636            // Check SELECT expressions for subqueries
637            for e in &select.expressions {
638                collect_subqueries_in_expr(e, parent_scope);
639            }
640            // Check HAVING for subqueries
641            if let Some(having) = &select.having {
642                collect_subqueries_in_expr(&having.this, parent_scope);
643            }
644        }
645        _ => {}
646    }
647}
648
649fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
650    match expr {
651        Expression::Subquery(subquery) if subquery.alias.is_none() => {
652            // This is a scalar subquery or IN subquery (not a derived table)
653            let mut sub_scope = parent_scope.branch(
654                subquery.this.clone(),
655                ScopeType::Subquery,
656            );
657            build_scope_impl(&subquery.this, &mut sub_scope);
658            parent_scope.subquery_scopes.push(sub_scope);
659        }
660        Expression::In(in_expr) => {
661            collect_subqueries_in_expr(&in_expr.this, parent_scope);
662            if let Some(query) = &in_expr.query {
663                let mut sub_scope = parent_scope.branch(
664                    query.clone(),
665                    ScopeType::Subquery,
666                );
667                build_scope_impl(query, &mut sub_scope);
668                parent_scope.subquery_scopes.push(sub_scope);
669            }
670        }
671        Expression::Exists(exists) => {
672            let mut sub_scope = parent_scope.branch(
673                exists.this.clone(),
674                ScopeType::Subquery,
675            );
676            build_scope_impl(&exists.this, &mut sub_scope);
677            parent_scope.subquery_scopes.push(sub_scope);
678        }
679        // Binary operations
680        Expression::And(bin) | Expression::Or(bin) |
681        Expression::Add(bin) | Expression::Sub(bin) |
682        Expression::Mul(bin) | Expression::Div(bin) |
683        Expression::Mod(bin) | Expression::Eq(bin) |
684        Expression::Neq(bin) | Expression::Lt(bin) |
685        Expression::Lte(bin) | Expression::Gt(bin) |
686        Expression::Gte(bin) | Expression::BitwiseAnd(bin) |
687        Expression::BitwiseOr(bin) | Expression::BitwiseXor(bin) |
688        Expression::Concat(bin) => {
689            collect_subqueries_in_expr(&bin.left, parent_scope);
690            collect_subqueries_in_expr(&bin.right, parent_scope);
691        }
692        // LIKE/ILIKE operations (have different structure with escape)
693        Expression::Like(like) | Expression::ILike(like) => {
694            collect_subqueries_in_expr(&like.left, parent_scope);
695            collect_subqueries_in_expr(&like.right, parent_scope);
696            if let Some(escape) = &like.escape {
697                collect_subqueries_in_expr(escape, parent_scope);
698            }
699        }
700        // Unary operations
701        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
702            collect_subqueries_in_expr(&un.this, parent_scope);
703        }
704        Expression::Function(func) => {
705            for arg in &func.args {
706                collect_subqueries_in_expr(arg, parent_scope);
707            }
708        }
709        Expression::Case(case) => {
710            if let Some(operand) = &case.operand {
711                collect_subqueries_in_expr(operand, parent_scope);
712            }
713            for (when_expr, then_expr) in &case.whens {
714                collect_subqueries_in_expr(when_expr, parent_scope);
715                collect_subqueries_in_expr(then_expr, parent_scope);
716            }
717            if let Some(else_clause) = &case.else_ {
718                collect_subqueries_in_expr(else_clause, parent_scope);
719            }
720        }
721        Expression::Paren(paren) => {
722            collect_subqueries_in_expr(&paren.this, parent_scope);
723        }
724        Expression::Alias(alias) => {
725            collect_subqueries_in_expr(&alias.this, parent_scope);
726        }
727        _ => {}
728    }
729}
730
/// Walk within a scope, yielding expressions without crossing scope boundaries.
///
/// This iterator visits the nodes of the syntax tree, stopping at nodes that
/// start child scopes (CTEs, derived tables, subqueries in FROM/JOIN).
///
/// This function is a thin wrapper: it only constructs a `WalkInScopeIter`
/// seeded with `expression`; the walking itself is done by that iterator.
///
/// # Arguments
/// * `expression` - The expression to walk
/// * `bfs` - If true, uses breadth-first search; otherwise uses depth-first search
///
/// # Returns
/// An iterator over expressions within the scope
pub fn walk_in_scope<'a>(
    expression: &'a Expression,
    bfs: bool,
) -> impl Iterator<Item = &'a Expression> {
    WalkInScopeIter::new(expression, bfs)
}
748
/// Iterator for walking within a scope
///
/// Created by `walk_in_scope`; holds borrowed expressions only.
struct WalkInScopeIter<'a> {
    /// Expressions still waiting to be visited (seeded with the start expression)
    queue: VecDeque<&'a Expression>,
    /// Traversal mode requested by the caller: breadth-first when true
    bfs: bool,
}
754
755impl<'a> WalkInScopeIter<'a> {
756    fn new(expression: &'a Expression, bfs: bool) -> Self {
757        let mut queue = VecDeque::new();
758        queue.push_back(expression);
759        Self { queue, bfs }
760    }
761
762    fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
763        if is_root {
764            return false;
765        }
766
767        // Stop at CTE definitions
768        if matches!(expr, Expression::Cte(_)) {
769            return true;
770        }
771
772        // Stop at subqueries that are derived tables (in FROM/JOIN)
773        if let Expression::Subquery(subquery) = expr {
774            if subquery.alias.is_some() {
775                return true;
776            }
777        }
778
779        // Stop at standalone SELECT/UNION/etc that would be subqueries
780        if matches!(
781            expr,
782            Expression::Select(_)
783                | Expression::Union(_)
784                | Expression::Intersect(_)
785                | Expression::Except(_)
786        ) {
787            return true;
788        }
789
790        false
791    }
792
793    fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
794        let mut children = Vec::new();
795
796        match expr {
797            Expression::Select(select) => {
798                // Walk SELECT expressions
799                for e in &select.expressions {
800                    children.push(e);
801                }
802                // Walk FROM (but tables/subqueries create new scopes)
803                if let Some(from) = &select.from {
804                    for table in &from.expressions {
805                        if !self.should_stop_at(table, false) {
806                            children.push(table);
807                        }
808                    }
809                }
810                // Walk JOINs (but their sources create new scopes)
811                for join in &select.joins {
812                    if let Some(on) = &join.on {
813                        children.push(on);
814                    }
815                    // Don't traverse join.this as it's a source (table or subquery)
816                }
817                // Walk WHERE
818                if let Some(where_clause) = &select.where_clause {
819                    children.push(&where_clause.this);
820                }
821                // Walk GROUP BY
822                if let Some(group_by) = &select.group_by {
823                    for e in &group_by.expressions {
824                        children.push(e);
825                    }
826                }
827                // Walk HAVING
828                if let Some(having) = &select.having {
829                    children.push(&having.this);
830                }
831                // Walk ORDER BY
832                if let Some(order_by) = &select.order_by {
833                    for ord in &order_by.expressions {
834                        children.push(&ord.this);
835                    }
836                }
837                // Walk LIMIT
838                if let Some(limit) = &select.limit {
839                    children.push(&limit.this);
840                }
841                // Walk OFFSET
842                if let Some(offset) = &select.offset {
843                    children.push(&offset.this);
844                }
845            }
846            Expression::And(bin)
847            | Expression::Or(bin)
848            | Expression::Add(bin)
849            | Expression::Sub(bin)
850            | Expression::Mul(bin)
851            | Expression::Div(bin)
852            | Expression::Mod(bin)
853            | Expression::Eq(bin)
854            | Expression::Neq(bin)
855            | Expression::Lt(bin)
856            | Expression::Lte(bin)
857            | Expression::Gt(bin)
858            | Expression::Gte(bin)
859            | Expression::BitwiseAnd(bin)
860            | Expression::BitwiseOr(bin)
861            | Expression::BitwiseXor(bin)
862            | Expression::Concat(bin) => {
863                children.push(&bin.left);
864                children.push(&bin.right);
865            }
866            Expression::Like(like) | Expression::ILike(like) => {
867                children.push(&like.left);
868                children.push(&like.right);
869                if let Some(escape) = &like.escape {
870                    children.push(escape);
871                }
872            }
873            Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
874                children.push(&un.this);
875            }
876            Expression::Function(func) => {
877                for arg in &func.args {
878                    children.push(arg);
879                }
880            }
881            Expression::AggregateFunction(agg) => {
882                for arg in &agg.args {
883                    children.push(arg);
884                }
885            }
886            Expression::WindowFunction(wf) => {
887                children.push(&wf.this);
888                for e in &wf.over.partition_by {
889                    children.push(e);
890                }
891                for e in &wf.over.order_by {
892                    children.push(&e.this);
893                }
894            }
895            Expression::Alias(alias) => {
896                children.push(&alias.this);
897            }
898            Expression::Case(case) => {
899                if let Some(operand) = &case.operand {
900                    children.push(operand);
901                }
902                for (when_expr, then_expr) in &case.whens {
903                    children.push(when_expr);
904                    children.push(then_expr);
905                }
906                if let Some(else_clause) = &case.else_ {
907                    children.push(else_clause);
908                }
909            }
910            Expression::Paren(paren) => {
911                children.push(&paren.this);
912            }
913            Expression::Ordered(ord) => {
914                children.push(&ord.this);
915            }
916            Expression::In(in_expr) => {
917                children.push(&in_expr.this);
918                for e in &in_expr.expressions {
919                    children.push(e);
920                }
921                // Note: in_expr.query creates a new scope - don't traverse
922            }
923            Expression::Between(between) => {
924                children.push(&between.this);
925                children.push(&between.low);
926                children.push(&between.high);
927            }
928            Expression::IsNull(is_null) => {
929                children.push(&is_null.this);
930            }
931            Expression::Cast(cast) => {
932                children.push(&cast.this);
933            }
934            Expression::Extract(extract) => {
935                children.push(&extract.this);
936            }
937            Expression::Coalesce(coalesce) => {
938                for e in &coalesce.expressions {
939                    children.push(e);
940                }
941            }
942            Expression::NullIf(nullif) => {
943                children.push(&nullif.this);
944                children.push(&nullif.expression);
945            }
946            Expression::Table(_table) => {
947                // Tables don't have child expressions to traverse within scope
948                // (joins are handled at the Select level)
949            }
950            Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
951                // Leaf nodes - no children
952            }
953            // Subqueries and Exists create new scopes - don't traverse into them
954            Expression::Subquery(_) | Expression::Exists(_) => {}
955            _ => {
956                // For other expressions, we could add more cases as needed
957            }
958        }
959
960        children
961    }
962}
963
964impl<'a> Iterator for WalkInScopeIter<'a> {
965    type Item = &'a Expression;
966
967    fn next(&mut self) -> Option<Self::Item> {
968        let expr = if self.bfs {
969            self.queue.pop_front()?
970        } else {
971            self.queue.pop_back()?
972        };
973
974        // Get children that don't cross scope boundaries
975        let children = self.get_children(expr);
976
977        if self.bfs {
978            for child in children {
979                if !self.should_stop_at(child, false) {
980                    self.queue.push_back(child);
981                }
982            }
983        } else {
984            for child in children.into_iter().rev() {
985                if !self.should_stop_at(child, false) {
986                    self.queue.push_back(child);
987                }
988            }
989        }
990
991        Some(expr)
992    }
993}
994
995/// Find the first expression matching the predicate within this scope.
996///
997/// This does NOT traverse into subscopes.
998///
999/// # Arguments
1000/// * `expression` - The root expression
1001/// * `predicate` - Function that returns true for matching expressions
1002/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1003///
1004/// # Returns
1005/// The first matching expression, or None
1006pub fn find_in_scope<'a, F>(expression: &'a Expression, predicate: F, bfs: bool) -> Option<&'a Expression>
1007where
1008    F: Fn(&Expression) -> bool,
1009{
1010    walk_in_scope(expression, bfs).find(|e| predicate(e))
1011}
1012
1013/// Find all expressions matching the predicate within this scope.
1014///
1015/// This does NOT traverse into subscopes.
1016///
1017/// # Arguments
1018/// * `expression` - The root expression
1019/// * `predicate` - Function that returns true for matching expressions
1020/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1021///
1022/// # Returns
1023/// A vector of matching expressions
1024pub fn find_all_in_scope<'a, F>(expression: &'a Expression, predicate: F, bfs: bool) -> Vec<&'a Expression>
1025where
1026    F: Fn(&Expression) -> bool,
1027{
1028    walk_in_scope(expression, bfs).filter(|e| predicate(e)).collect()
1029}
1030
1031/// Traverse an expression by its "scopes".
1032///
1033/// Returns a list of all scopes in depth-first post-order.
1034///
1035/// # Arguments
1036/// * `expression` - The expression to traverse
1037///
1038/// # Returns
1039/// A vector of all scopes found
1040pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1041    match expression {
1042        Expression::Select(_)
1043        | Expression::Union(_)
1044        | Expression::Intersect(_)
1045        | Expression::Except(_) => {
1046            let root = build_scope(expression);
1047            root.traverse().into_iter().cloned().collect()
1048        }
1049        _ => Vec::new(),
1050    }
1051}
1052
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and builds the scope tree for its first statement.
    fn parse_and_build_scope(sql: &str) -> Scope {
        let statements = Parser::parse_sql(sql).expect("Failed to parse SQL");
        build_scope(&statements[0])
    }

    #[test]
    fn test_simple_select_scope() {
        let mut root = parse_and_build_scope("SELECT a, b FROM t");

        assert!(root.is_root());
        assert!(!root.can_be_correlated);
        assert!(root.sources.contains_key("t"));

        // Both projected columns are expected to be collected.
        assert_eq!(root.columns().len(), 2);
    }

    #[test]
    fn test_derived_table_scope() {
        let mut root = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");

        // The derived table is registered under its alias.
        assert!(root.sources.contains_key("x"));
        assert_eq!(root.derived_table_scopes.len(), 1);

        let inner = &mut root.derived_table_scopes[0];
        assert!(inner.is_derived_table());
        assert!(inner.sources.contains_key("t"));
    }

    #[test]
    fn test_non_correlated_subquery() {
        let mut root = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);

        // Only 's' is referenced, and it is one of the subquery's own sources.
        assert!(inner.sources.contains_key("s"));
        assert!(!inner.is_correlated_subquery());
    }

    #[test]
    fn test_correlated_subquery() {
        let mut root = parse_and_build_scope(
            "SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)",
        );

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);

        // 't.y' comes from the outer query, so it must show up as external.
        let outer_refs = inner.external_columns();
        assert!(!outer_refs.is_empty());
        assert!(outer_refs
            .iter()
            .any(|col| col.table.as_deref() == Some("t")));
        assert!(inner.is_correlated_subquery());
    }

    #[test]
    fn test_cte_scope() {
        let root = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");

        assert_eq!(root.cte_scopes.len(), 1);
        assert!(root.cte_sources.contains_key("cte"));

        let first_cte = &root.cte_scopes[0];
        assert!(first_cte.is_cte());
    }

    #[test]
    fn test_multiple_sources() {
        let root = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        assert!(root.sources.contains_key("t"));
        assert!(root.sources.contains_key("s"));
        assert_eq!(root.sources.len(), 2);
    }

    #[test]
    fn test_aliased_table() {
        let root = parse_and_build_scope("SELECT x.a FROM t AS x");

        // The alias is the lookup key; the underlying table name is not.
        assert!(root.sources.contains_key("x"));
        assert!(!root.sources.contains_key("t"));
    }

    #[test]
    fn test_local_columns() {
        let mut root =
            parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        // Both t and s are sources of this scope, so every column is local.
        let local_refs = root.local_columns();
        assert_eq!(local_refs.len(), 3);
        assert!(local_refs.iter().all(|col| col.table.is_some()));
    }

    #[test]
    fn test_unqualified_columns() {
        let mut root = parse_and_build_scope("SELECT a, b, t.c FROM t");

        // `a` and `b` carry no table qualifier; `t.c` does.
        let bare_refs = root.unqualified_columns();
        assert_eq!(bare_refs.len(), 2);
        assert!(bare_refs.iter().all(|col| col.table.is_none()));
    }

    #[test]
    fn test_source_columns() {
        let mut root =
            parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        // At least t.a and t.b (t.id from the JOIN condition may appear too).
        let from_t = root.source_columns("t");
        assert!(from_t.len() >= 2);
        assert!(from_t.iter().all(|col| col.table.as_deref() == Some("t")));

        // At least s.c (s.id from the JOIN condition may appear too).
        let from_s = root.source_columns("s");
        assert!(from_s.len() >= 1);
        assert!(from_s.iter().all(|col| col.table.as_deref() == Some("s")));
    }

    #[test]
    fn test_rename_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");

        assert!(root.sources.contains_key("t"));
        root.rename_source("t", "new_name".to_string());
        assert!(!root.sources.contains_key("t"));
        assert!(root.sources.contains_key("new_name"));
    }

    #[test]
    fn test_remove_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");

        assert!(root.sources.contains_key("t"));
        root.remove_source("t");
        assert!(!root.sources.contains_key("t"));
    }

    #[test]
    fn test_walk_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let root_expr = &statements[0];

        // A breadth-first walk over the statement must yield something.
        let visited: Vec<_> = walk_in_scope(root_expr, true).collect();
        assert!(!visited.is_empty());

        // The root SELECT itself is part of the walk...
        assert!(visited
            .iter()
            .any(|node| matches!(node, Expression::Select(_))));
        // ...and so are column references.
        assert!(visited
            .iter()
            .any(|node| matches!(node, Expression::Column(_))));
    }

    #[test]
    fn test_find_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let root_expr = &statements[0];

        // Look for the first column reference.
        let first_column =
            find_in_scope(root_expr, |node| matches!(node, Expression::Column(_)), true);
        assert!(first_column.is_some());
        assert!(matches!(first_column.unwrap(), Expression::Column(_)));
    }

    #[test]
    fn test_find_all_in_scope() {
        let statements = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
        let root_expr = &statements[0];

        // Every projected column should be found.
        let all_columns =
            find_all_in_scope(root_expr, |node| matches!(node, Expression::Column(_)), true);
        assert_eq!(all_columns.len(), 3);
    }

    #[test]
    fn test_traverse_scope() {
        let statements =
            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
        let root_expr = &statements[0];

        // traverse_scope flattens the tree through Scope::traverse, so the
        // result cannot be empty for a SELECT...
        let all_scopes = traverse_scope(root_expr);
        assert!(!all_scopes.is_empty());
        // ...and the root scope must be among the returned scopes.
        assert!(all_scopes.iter().any(|scope| scope.is_root()));
    }

    #[test]
    fn test_branch_with_options() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let parent = build_scope(&statements[0]);

        // Branch as a Subquery so that can_be_correlated is exercised.
        let branched = parent.branch_with_options(
            statements[0].clone(),
            ScopeType::Subquery,
            None,
            None,
            Some(vec!["col1".to_string(), "col2".to_string()]),
        );

        assert_eq!(branched.outer_columns, vec!["col1", "col2"]);
        // A subquery scope is flagged as potentially correlated.
        assert!(branched.can_be_correlated);
    }

    #[test]
    fn test_is_udtf() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");

        // A freshly constructed scope is not a UDTF scope.
        let plain = Scope::new(statements[0].clone());
        assert!(!plain.is_udtf());

        // Branching with ScopeType::Udtf yields one.
        let parent = build_scope(&statements[0]);
        let branched = parent.branch(statements[0].clone(), ScopeType::Udtf);
        assert!(branched.is_udtf());
    }

    #[test]
    fn test_is_union() {
        let root = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(root.is_root());
        assert_eq!(root.union_scopes.len(), 2);
        // Each side of the UNION is a set-operation scope.
        assert!(root.union_scopes[0].is_union());
        assert!(root.union_scopes[1].is_union());
    }

    #[test]
    fn test_clear_cache() {
        let mut root = parse_and_build_scope("SELECT t.a FROM t");

        // Asking for columns fills the cache.
        let _ = root.columns();
        assert!(root.columns_cache.is_some());

        // Clearing drops the cached column data again.
        root.clear_cache();
        assert!(root.columns_cache.is_none());
        assert!(root.external_columns_cache.is_none());
    }

    #[test]
    fn test_scope_traverse() {
        let root = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        // Expected to cover the CTE scope, the subquery scope and the root.
        let visited = root.traverse();
        assert!(visited.len() >= 3);
    }
}