Skip to main content

polyglot_sql/
scope.rs

1//! Scope Analysis Module
2//!
3//! This module provides scope analysis for SQL queries, enabling detection of
4//! correlated subqueries, column references, and scope relationships.
5//!
6//! Ported from sqlglot's optimizer/scope.py
7
8use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
14/// Type of scope in a SQL query
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16#[cfg_attr(feature = "bindings", derive(TS))]
17#[cfg_attr(feature = "bindings", ts(export))]
18pub enum ScopeType {
19    /// Root scope of the query
20    Root,
21    /// Subquery scope (e.g., WHERE x IN (SELECT ...))
22    Subquery,
23    /// Derived table scope (e.g., FROM (SELECT ...) AS t)
24    DerivedTable,
25    /// Common Table Expression scope
26    Cte,
27    /// Union/Intersect/Except scope
28    SetOperation,
29    /// User-Defined Table Function scope
30    Udtf,
31}
32
33/// Information about a source (table or subquery) in a scope
34#[derive(Debug, Clone)]
35pub struct SourceInfo {
36    /// The source expression (Table or subquery)
37    pub expression: Expression,
38    /// Whether this source is a scope (vs. a plain table)
39    pub is_scope: bool,
40}
41
/// A single column reference discovered while walking a scope.
///
/// Two references are equal (and hash equally) when both the qualifier and the
/// column name match exactly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table or alias qualifier, e.g. `t` in `t.x`; `None` for an unqualified column.
    pub table: Option<String>,
    /// Bare column name, e.g. `x` in `t.x`.
    pub name: String,
}
50
51/// Represents a scope in a SQL query
52///
53/// A scope is the context of a SELECT statement and its sources.
54/// Scopes can be nested (subqueries, CTEs, derived tables) and form a tree.
55#[derive(Debug, Clone)]
56pub struct Scope {
57    /// The expression at the root of this scope
58    pub expression: Expression,
59
60    /// Type of this scope relative to its parent
61    pub scope_type: ScopeType,
62
63    /// Mapping of source names to their info
64    pub sources: HashMap<String, SourceInfo>,
65
66    /// Sources from LATERAL views (have access to preceding sources)
67    pub lateral_sources: HashMap<String, SourceInfo>,
68
69    /// CTE sources available to this scope
70    pub cte_sources: HashMap<String, SourceInfo>,
71
72    /// If this is a derived table or CTE with alias columns, this is that list
73    /// e.g., `SELECT * FROM (SELECT ...) AS y(col1, col2)` => ["col1", "col2"]
74    pub outer_columns: Vec<String>,
75
76    /// Whether this scope can potentially be correlated
77    /// (true for subqueries and UDTFs)
78    pub can_be_correlated: bool,
79
80    /// Child subquery scopes
81    pub subquery_scopes: Vec<Scope>,
82
83    /// Child derived table scopes
84    pub derived_table_scopes: Vec<Scope>,
85
86    /// Child CTE scopes
87    pub cte_scopes: Vec<Scope>,
88
89    /// Child UDTF (User Defined Table Function) scopes
90    pub udtf_scopes: Vec<Scope>,
91
92    /// Combined derived_table_scopes + udtf_scopes in definition order
93    pub table_scopes: Vec<Scope>,
94
95    /// Union/set operation scopes (left and right)
96    pub union_scopes: Vec<Scope>,
97
98    /// Cached columns
99    columns_cache: Option<Vec<ColumnRef>>,
100
101    /// Cached external columns
102    external_columns_cache: Option<Vec<ColumnRef>>,
103}
104
105impl Scope {
106    /// Create a new root scope
107    pub fn new(expression: Expression) -> Self {
108        Self {
109            expression,
110            scope_type: ScopeType::Root,
111            sources: HashMap::new(),
112            lateral_sources: HashMap::new(),
113            cte_sources: HashMap::new(),
114            outer_columns: Vec::new(),
115            can_be_correlated: false,
116            subquery_scopes: Vec::new(),
117            derived_table_scopes: Vec::new(),
118            cte_scopes: Vec::new(),
119            udtf_scopes: Vec::new(),
120            table_scopes: Vec::new(),
121            union_scopes: Vec::new(),
122            columns_cache: None,
123            external_columns_cache: None,
124        }
125    }
126
127    /// Create a child scope branching from this one
128    pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
129        self.branch_with_options(expression, scope_type, None, None, None)
130    }
131
132    /// Create a child scope with additional options
133    pub fn branch_with_options(
134        &self,
135        expression: Expression,
136        scope_type: ScopeType,
137        sources: Option<HashMap<String, SourceInfo>>,
138        lateral_sources: Option<HashMap<String, SourceInfo>>,
139        outer_columns: Option<Vec<String>>,
140    ) -> Self {
141        let can_be_correlated = self.can_be_correlated
142            || scope_type == ScopeType::Subquery
143            || scope_type == ScopeType::Udtf;
144
145        Self {
146            expression,
147            scope_type,
148            sources: sources.unwrap_or_default(),
149            lateral_sources: lateral_sources.unwrap_or_default(),
150            cte_sources: self.cte_sources.clone(),
151            outer_columns: outer_columns.unwrap_or_default(),
152            can_be_correlated,
153            subquery_scopes: Vec::new(),
154            derived_table_scopes: Vec::new(),
155            cte_scopes: Vec::new(),
156            udtf_scopes: Vec::new(),
157            table_scopes: Vec::new(),
158            union_scopes: Vec::new(),
159            columns_cache: None,
160            external_columns_cache: None,
161        }
162    }
163
164    /// Clear all cached properties
165    pub fn clear_cache(&mut self) {
166        self.columns_cache = None;
167        self.external_columns_cache = None;
168    }
169
170    /// Add a source to this scope
171    pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
172        self.sources.insert(
173            name,
174            SourceInfo {
175                expression,
176                is_scope,
177            },
178        );
179        self.clear_cache();
180    }
181
182    /// Add a lateral source to this scope
183    pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
184        self.lateral_sources.insert(
185            name.clone(),
186            SourceInfo {
187                expression: expression.clone(),
188                is_scope,
189            },
190        );
191        self.sources.insert(
192            name,
193            SourceInfo {
194                expression,
195                is_scope,
196            },
197        );
198        self.clear_cache();
199    }
200
201    /// Add a CTE source to this scope
202    pub fn add_cte_source(&mut self, name: String, expression: Expression) {
203        self.cte_sources.insert(
204            name.clone(),
205            SourceInfo {
206                expression: expression.clone(),
207                is_scope: true,
208            },
209        );
210        self.sources.insert(
211            name,
212            SourceInfo {
213                expression,
214                is_scope: true,
215            },
216        );
217        self.clear_cache();
218    }
219
220    /// Rename a source
221    pub fn rename_source(&mut self, old_name: &str, new_name: String) {
222        if let Some(source) = self.sources.remove(old_name) {
223            self.sources.insert(new_name, source);
224        }
225        self.clear_cache();
226    }
227
228    /// Remove a source
229    pub fn remove_source(&mut self, name: &str) {
230        self.sources.remove(name);
231        self.clear_cache();
232    }
233
234    /// Collect all column references in this scope
235    pub fn columns(&mut self) -> &[ColumnRef] {
236        if self.columns_cache.is_none() {
237            let mut columns = Vec::new();
238            collect_columns(&self.expression, &mut columns);
239            self.columns_cache = Some(columns);
240        }
241        self.columns_cache.as_ref().unwrap()
242    }
243
244    /// Get all source names in this scope
245    pub fn source_names(&self) -> HashSet<String> {
246        let mut names: HashSet<String> = self.sources.keys().cloned().collect();
247        names.extend(self.cte_sources.keys().cloned());
248        names
249    }
250
251    /// Get columns that reference sources outside this scope
252    pub fn external_columns(&mut self) -> Vec<ColumnRef> {
253        if self.external_columns_cache.is_some() {
254            return self.external_columns_cache.clone().unwrap();
255        }
256
257        let source_names = self.source_names();
258        let columns = self.columns().to_vec();
259
260        let external: Vec<ColumnRef> = columns
261            .into_iter()
262            .filter(|col| {
263                // A column is external if it has a table qualifier that's not in our sources
264                match &col.table {
265                    Some(table) => !source_names.contains(table),
266                    None => false, // Unqualified columns might be local
267                }
268            })
269            .collect();
270
271        self.external_columns_cache = Some(external.clone());
272        external
273    }
274
275    /// Get columns that reference sources in this scope (not external)
276    pub fn local_columns(&mut self) -> Vec<ColumnRef> {
277        let external_set: HashSet<_> = self.external_columns().into_iter().collect();
278        let columns = self.columns().to_vec();
279
280        columns
281            .into_iter()
282            .filter(|col| !external_set.contains(col))
283            .collect()
284    }
285
286    /// Get unqualified columns (columns without table qualifier)
287    pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
288        self.columns()
289            .iter()
290            .filter(|c| c.table.is_none())
291            .cloned()
292            .collect()
293    }
294
295    /// Get columns for a specific source
296    pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
297        self.columns()
298            .iter()
299            .filter(|col| col.table.as_deref() == Some(source_name))
300            .cloned()
301            .collect()
302    }
303
304    /// Determine if this scope is a correlated subquery
305    ///
306    /// A subquery is correlated if:
307    /// 1. It can be correlated (is a subquery or UDTF), AND
308    /// 2. It references columns from outer scopes
309    pub fn is_correlated_subquery(&mut self) -> bool {
310        self.can_be_correlated && !self.external_columns().is_empty()
311    }
312
313    /// Check if this is a subquery scope
314    pub fn is_subquery(&self) -> bool {
315        self.scope_type == ScopeType::Subquery
316    }
317
318    /// Check if this is a derived table scope
319    pub fn is_derived_table(&self) -> bool {
320        self.scope_type == ScopeType::DerivedTable
321    }
322
323    /// Check if this is a CTE scope
324    pub fn is_cte(&self) -> bool {
325        self.scope_type == ScopeType::Cte
326    }
327
328    /// Check if this is the root scope
329    pub fn is_root(&self) -> bool {
330        self.scope_type == ScopeType::Root
331    }
332
333    /// Check if this is a UDTF scope
334    pub fn is_udtf(&self) -> bool {
335        self.scope_type == ScopeType::Udtf
336    }
337
338    /// Check if this is a union/set operation scope
339    pub fn is_union(&self) -> bool {
340        self.scope_type == ScopeType::SetOperation
341    }
342
343    /// Traverse all scopes in this tree (depth-first post-order)
344    pub fn traverse(&self) -> Vec<&Scope> {
345        let mut result = Vec::new();
346        self.traverse_impl(&mut result);
347        result
348    }
349
350    fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
351        // First traverse children
352        for scope in &self.cte_scopes {
353            scope.traverse_impl(result);
354        }
355        for scope in &self.union_scopes {
356            scope.traverse_impl(result);
357        }
358        for scope in &self.table_scopes {
359            scope.traverse_impl(result);
360        }
361        for scope in &self.subquery_scopes {
362            scope.traverse_impl(result);
363        }
364        // Then add self
365        result.push(self);
366    }
367
368    /// Count references to each scope in this tree
369    pub fn ref_count(&self) -> HashMap<usize, usize> {
370        let mut counts: HashMap<usize, usize> = HashMap::new();
371
372        for scope in self.traverse() {
373            for (_, source_info) in scope.sources.iter() {
374                if source_info.is_scope {
375                    let id = &source_info.expression as *const _ as usize;
376                    *counts.entry(id).or_insert(0) += 1;
377                }
378            }
379        }
380
381        counts
382    }
383}
384
385/// Collect all column references from an expression tree
386fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
387    match expr {
388        Expression::Column(col) => {
389            columns.push(ColumnRef {
390                table: col.table.as_ref().map(|t| t.name.clone()),
391                name: col.name.name.clone(),
392            });
393        }
394        Expression::Select(select) => {
395            // Collect from SELECT expressions
396            for e in &select.expressions {
397                collect_columns(e, columns);
398            }
399            // Collect from WHERE
400            if let Some(where_clause) = &select.where_clause {
401                collect_columns(&where_clause.this, columns);
402            }
403            // Collect from HAVING
404            if let Some(having) = &select.having {
405                collect_columns(&having.this, columns);
406            }
407            // Collect from ORDER BY
408            if let Some(order_by) = &select.order_by {
409                for ord in &order_by.expressions {
410                    collect_columns(&ord.this, columns);
411                }
412            }
413            // Collect from GROUP BY
414            if let Some(group_by) = &select.group_by {
415                for e in &group_by.expressions {
416                    collect_columns(e, columns);
417                }
418            }
419            // Note: We don't recurse into FROM/JOIN subqueries here
420            // as those create their own scopes
421        }
422        // Binary operations
423        Expression::And(bin)
424        | Expression::Or(bin)
425        | Expression::Add(bin)
426        | Expression::Sub(bin)
427        | Expression::Mul(bin)
428        | Expression::Div(bin)
429        | Expression::Mod(bin)
430        | Expression::Eq(bin)
431        | Expression::Neq(bin)
432        | Expression::Lt(bin)
433        | Expression::Lte(bin)
434        | Expression::Gt(bin)
435        | Expression::Gte(bin)
436        | Expression::BitwiseAnd(bin)
437        | Expression::BitwiseOr(bin)
438        | Expression::BitwiseXor(bin)
439        | Expression::Concat(bin) => {
440            collect_columns(&bin.left, columns);
441            collect_columns(&bin.right, columns);
442        }
443        // LIKE/ILIKE operations
444        Expression::Like(like) | Expression::ILike(like) => {
445            collect_columns(&like.left, columns);
446            collect_columns(&like.right, columns);
447            if let Some(escape) = &like.escape {
448                collect_columns(escape, columns);
449            }
450        }
451        // Unary operations
452        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
453            collect_columns(&un.this, columns);
454        }
455        Expression::Function(func) => {
456            for arg in &func.args {
457                collect_columns(arg, columns);
458            }
459        }
460        Expression::AggregateFunction(agg) => {
461            for arg in &agg.args {
462                collect_columns(arg, columns);
463            }
464        }
465        Expression::WindowFunction(wf) => {
466            collect_columns(&wf.this, columns);
467            for e in &wf.over.partition_by {
468                collect_columns(e, columns);
469            }
470            for e in &wf.over.order_by {
471                collect_columns(&e.this, columns);
472            }
473        }
474        Expression::Alias(alias) => {
475            collect_columns(&alias.this, columns);
476        }
477        Expression::Case(case) => {
478            if let Some(operand) = &case.operand {
479                collect_columns(operand, columns);
480            }
481            for (when_expr, then_expr) in &case.whens {
482                collect_columns(when_expr, columns);
483                collect_columns(then_expr, columns);
484            }
485            if let Some(else_clause) = &case.else_ {
486                collect_columns(else_clause, columns);
487            }
488        }
489        Expression::Paren(paren) => {
490            collect_columns(&paren.this, columns);
491        }
492        Expression::Ordered(ord) => {
493            collect_columns(&ord.this, columns);
494        }
495        Expression::In(in_expr) => {
496            collect_columns(&in_expr.this, columns);
497            for e in &in_expr.expressions {
498                collect_columns(e, columns);
499            }
500            // Note: in_expr.query is a subquery - creates its own scope
501        }
502        Expression::Between(between) => {
503            collect_columns(&between.this, columns);
504            collect_columns(&between.low, columns);
505            collect_columns(&between.high, columns);
506        }
507        Expression::IsNull(is_null) => {
508            collect_columns(&is_null.this, columns);
509        }
510        Expression::Cast(cast) => {
511            collect_columns(&cast.this, columns);
512        }
513        Expression::Extract(extract) => {
514            collect_columns(&extract.this, columns);
515        }
516        Expression::Exists(_) | Expression::Subquery(_) => {
517            // These create their own scopes - don't collect from here
518        }
519        _ => {
520            // For other expressions, we might need to add more cases
521        }
522    }
523}
524
525/// Build scope tree from an expression
526///
527/// This traverses the expression tree and builds a hierarchy of Scope objects
528/// that track sources and column references at each level.
529pub fn build_scope(expression: &Expression) -> Scope {
530    let mut root = Scope::new(expression.clone());
531    build_scope_impl(expression, &mut root);
532    root
533}
534
535fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
536    match expression {
537        Expression::Select(select) => {
538            // Process CTEs first
539            if let Some(with) = &select.with {
540                for cte in &with.ctes {
541                    let cte_name = cte.alias.name.clone();
542                    let mut cte_scope = current_scope
543                        .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
544                    build_scope_impl(&cte.this, &mut cte_scope);
545                    current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
546                    current_scope.cte_scopes.push(cte_scope);
547                }
548            }
549
550            // Process FROM clause
551            if let Some(from) = &select.from {
552                for table in &from.expressions {
553                    add_table_to_scope(table, current_scope);
554                }
555            }
556
557            // Process JOINs
558            for join in &select.joins {
559                add_table_to_scope(&join.this, current_scope);
560            }
561
562            // Process subqueries in WHERE, SELECT expressions, etc.
563            collect_subqueries(expression, current_scope);
564        }
565        Expression::Union(union) => {
566            let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
567            build_scope_impl(&union.left, &mut left_scope);
568
569            let mut right_scope =
570                current_scope.branch(union.right.clone(), ScopeType::SetOperation);
571            build_scope_impl(&union.right, &mut right_scope);
572
573            current_scope.union_scopes.push(left_scope);
574            current_scope.union_scopes.push(right_scope);
575        }
576        Expression::Intersect(intersect) => {
577            let mut left_scope =
578                current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
579            build_scope_impl(&intersect.left, &mut left_scope);
580
581            let mut right_scope =
582                current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
583            build_scope_impl(&intersect.right, &mut right_scope);
584
585            current_scope.union_scopes.push(left_scope);
586            current_scope.union_scopes.push(right_scope);
587        }
588        Expression::Except(except) => {
589            let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
590            build_scope_impl(&except.left, &mut left_scope);
591
592            let mut right_scope =
593                current_scope.branch(except.right.clone(), ScopeType::SetOperation);
594            build_scope_impl(&except.right, &mut right_scope);
595
596            current_scope.union_scopes.push(left_scope);
597            current_scope.union_scopes.push(right_scope);
598        }
599        _ => {}
600    }
601}
602
603fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
604    match expr {
605        Expression::Table(table) => {
606            let name = table
607                .alias
608                .as_ref()
609                .map(|a| a.name.clone())
610                .unwrap_or_else(|| table.name.name.clone());
611            scope.add_source(name, expr.clone(), false);
612        }
613        Expression::Subquery(subquery) => {
614            let name = subquery
615                .alias
616                .as_ref()
617                .map(|a| a.name.clone())
618                .unwrap_or_default();
619
620            let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
621            build_scope_impl(&subquery.this, &mut derived_scope);
622
623            scope.add_source(name.clone(), expr.clone(), true);
624            scope.derived_table_scopes.push(derived_scope);
625        }
626        Expression::Paren(paren) => {
627            add_table_to_scope(&paren.this, scope);
628        }
629        _ => {}
630    }
631}
632
633fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
634    match expr {
635        Expression::Select(select) => {
636            // Check WHERE for subqueries
637            if let Some(where_clause) = &select.where_clause {
638                collect_subqueries_in_expr(&where_clause.this, parent_scope);
639            }
640            // Check SELECT expressions for subqueries
641            for e in &select.expressions {
642                collect_subqueries_in_expr(e, parent_scope);
643            }
644            // Check HAVING for subqueries
645            if let Some(having) = &select.having {
646                collect_subqueries_in_expr(&having.this, parent_scope);
647            }
648        }
649        _ => {}
650    }
651}
652
653fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
654    match expr {
655        Expression::Subquery(subquery) if subquery.alias.is_none() => {
656            // This is a scalar subquery or IN subquery (not a derived table)
657            let mut sub_scope = parent_scope.branch(subquery.this.clone(), ScopeType::Subquery);
658            build_scope_impl(&subquery.this, &mut sub_scope);
659            parent_scope.subquery_scopes.push(sub_scope);
660        }
661        Expression::In(in_expr) => {
662            collect_subqueries_in_expr(&in_expr.this, parent_scope);
663            if let Some(query) = &in_expr.query {
664                let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
665                build_scope_impl(query, &mut sub_scope);
666                parent_scope.subquery_scopes.push(sub_scope);
667            }
668        }
669        Expression::Exists(exists) => {
670            let mut sub_scope = parent_scope.branch(exists.this.clone(), ScopeType::Subquery);
671            build_scope_impl(&exists.this, &mut sub_scope);
672            parent_scope.subquery_scopes.push(sub_scope);
673        }
674        // Binary operations
675        Expression::And(bin)
676        | Expression::Or(bin)
677        | Expression::Add(bin)
678        | Expression::Sub(bin)
679        | Expression::Mul(bin)
680        | Expression::Div(bin)
681        | Expression::Mod(bin)
682        | Expression::Eq(bin)
683        | Expression::Neq(bin)
684        | Expression::Lt(bin)
685        | Expression::Lte(bin)
686        | Expression::Gt(bin)
687        | Expression::Gte(bin)
688        | Expression::BitwiseAnd(bin)
689        | Expression::BitwiseOr(bin)
690        | Expression::BitwiseXor(bin)
691        | Expression::Concat(bin) => {
692            collect_subqueries_in_expr(&bin.left, parent_scope);
693            collect_subqueries_in_expr(&bin.right, parent_scope);
694        }
695        // LIKE/ILIKE operations (have different structure with escape)
696        Expression::Like(like) | Expression::ILike(like) => {
697            collect_subqueries_in_expr(&like.left, parent_scope);
698            collect_subqueries_in_expr(&like.right, parent_scope);
699            if let Some(escape) = &like.escape {
700                collect_subqueries_in_expr(escape, parent_scope);
701            }
702        }
703        // Unary operations
704        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
705            collect_subqueries_in_expr(&un.this, parent_scope);
706        }
707        Expression::Function(func) => {
708            for arg in &func.args {
709                collect_subqueries_in_expr(arg, parent_scope);
710            }
711        }
712        Expression::Case(case) => {
713            if let Some(operand) = &case.operand {
714                collect_subqueries_in_expr(operand, parent_scope);
715            }
716            for (when_expr, then_expr) in &case.whens {
717                collect_subqueries_in_expr(when_expr, parent_scope);
718                collect_subqueries_in_expr(then_expr, parent_scope);
719            }
720            if let Some(else_clause) = &case.else_ {
721                collect_subqueries_in_expr(else_clause, parent_scope);
722            }
723        }
724        Expression::Paren(paren) => {
725            collect_subqueries_in_expr(&paren.this, parent_scope);
726        }
727        Expression::Alias(alias) => {
728            collect_subqueries_in_expr(&alias.this, parent_scope);
729        }
730        _ => {}
731    }
732}
733
734/// Walk within a scope, yielding expressions without crossing scope boundaries.
735///
736/// This iterator visits all nodes in the syntax tree, stopping at nodes that
737/// start child scopes (CTEs, derived tables, subqueries in FROM/JOIN).
738///
739/// # Arguments
740/// * `expression` - The expression to walk
741/// * `bfs` - If true, uses breadth-first search; otherwise uses depth-first search
742///
743/// # Returns
744/// An iterator over expressions within the scope
745pub fn walk_in_scope<'a>(
746    expression: &'a Expression,
747    bfs: bool,
748) -> impl Iterator<Item = &'a Expression> {
749    WalkInScopeIter::new(expression, bfs)
750}
751
752/// Iterator for walking within a scope
753struct WalkInScopeIter<'a> {
754    queue: VecDeque<&'a Expression>,
755    bfs: bool,
756}
757
758impl<'a> WalkInScopeIter<'a> {
759    fn new(expression: &'a Expression, bfs: bool) -> Self {
760        let mut queue = VecDeque::new();
761        queue.push_back(expression);
762        Self { queue, bfs }
763    }
764
765    fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
766        if is_root {
767            return false;
768        }
769
770        // Stop at CTE definitions
771        if matches!(expr, Expression::Cte(_)) {
772            return true;
773        }
774
775        // Stop at subqueries that are derived tables (in FROM/JOIN)
776        if let Expression::Subquery(subquery) = expr {
777            if subquery.alias.is_some() {
778                return true;
779            }
780        }
781
782        // Stop at standalone SELECT/UNION/etc that would be subqueries
783        if matches!(
784            expr,
785            Expression::Select(_)
786                | Expression::Union(_)
787                | Expression::Intersect(_)
788                | Expression::Except(_)
789        ) {
790            return true;
791        }
792
793        false
794    }
795
796    fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
797        let mut children = Vec::new();
798
799        match expr {
800            Expression::Select(select) => {
801                // Walk SELECT expressions
802                for e in &select.expressions {
803                    children.push(e);
804                }
805                // Walk FROM (but tables/subqueries create new scopes)
806                if let Some(from) = &select.from {
807                    for table in &from.expressions {
808                        if !self.should_stop_at(table, false) {
809                            children.push(table);
810                        }
811                    }
812                }
813                // Walk JOINs (but their sources create new scopes)
814                for join in &select.joins {
815                    if let Some(on) = &join.on {
816                        children.push(on);
817                    }
818                    // Don't traverse join.this as it's a source (table or subquery)
819                }
820                // Walk WHERE
821                if let Some(where_clause) = &select.where_clause {
822                    children.push(&where_clause.this);
823                }
824                // Walk GROUP BY
825                if let Some(group_by) = &select.group_by {
826                    for e in &group_by.expressions {
827                        children.push(e);
828                    }
829                }
830                // Walk HAVING
831                if let Some(having) = &select.having {
832                    children.push(&having.this);
833                }
834                // Walk ORDER BY
835                if let Some(order_by) = &select.order_by {
836                    for ord in &order_by.expressions {
837                        children.push(&ord.this);
838                    }
839                }
840                // Walk LIMIT
841                if let Some(limit) = &select.limit {
842                    children.push(&limit.this);
843                }
844                // Walk OFFSET
845                if let Some(offset) = &select.offset {
846                    children.push(&offset.this);
847                }
848            }
849            Expression::And(bin)
850            | Expression::Or(bin)
851            | Expression::Add(bin)
852            | Expression::Sub(bin)
853            | Expression::Mul(bin)
854            | Expression::Div(bin)
855            | Expression::Mod(bin)
856            | Expression::Eq(bin)
857            | Expression::Neq(bin)
858            | Expression::Lt(bin)
859            | Expression::Lte(bin)
860            | Expression::Gt(bin)
861            | Expression::Gte(bin)
862            | Expression::BitwiseAnd(bin)
863            | Expression::BitwiseOr(bin)
864            | Expression::BitwiseXor(bin)
865            | Expression::Concat(bin) => {
866                children.push(&bin.left);
867                children.push(&bin.right);
868            }
869            Expression::Like(like) | Expression::ILike(like) => {
870                children.push(&like.left);
871                children.push(&like.right);
872                if let Some(escape) = &like.escape {
873                    children.push(escape);
874                }
875            }
876            Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
877                children.push(&un.this);
878            }
879            Expression::Function(func) => {
880                for arg in &func.args {
881                    children.push(arg);
882                }
883            }
884            Expression::AggregateFunction(agg) => {
885                for arg in &agg.args {
886                    children.push(arg);
887                }
888            }
889            Expression::WindowFunction(wf) => {
890                children.push(&wf.this);
891                for e in &wf.over.partition_by {
892                    children.push(e);
893                }
894                for e in &wf.over.order_by {
895                    children.push(&e.this);
896                }
897            }
898            Expression::Alias(alias) => {
899                children.push(&alias.this);
900            }
901            Expression::Case(case) => {
902                if let Some(operand) = &case.operand {
903                    children.push(operand);
904                }
905                for (when_expr, then_expr) in &case.whens {
906                    children.push(when_expr);
907                    children.push(then_expr);
908                }
909                if let Some(else_clause) = &case.else_ {
910                    children.push(else_clause);
911                }
912            }
913            Expression::Paren(paren) => {
914                children.push(&paren.this);
915            }
916            Expression::Ordered(ord) => {
917                children.push(&ord.this);
918            }
919            Expression::In(in_expr) => {
920                children.push(&in_expr.this);
921                for e in &in_expr.expressions {
922                    children.push(e);
923                }
924                // Note: in_expr.query creates a new scope - don't traverse
925            }
926            Expression::Between(between) => {
927                children.push(&between.this);
928                children.push(&between.low);
929                children.push(&between.high);
930            }
931            Expression::IsNull(is_null) => {
932                children.push(&is_null.this);
933            }
934            Expression::Cast(cast) => {
935                children.push(&cast.this);
936            }
937            Expression::Extract(extract) => {
938                children.push(&extract.this);
939            }
940            Expression::Coalesce(coalesce) => {
941                for e in &coalesce.expressions {
942                    children.push(e);
943                }
944            }
945            Expression::NullIf(nullif) => {
946                children.push(&nullif.this);
947                children.push(&nullif.expression);
948            }
949            Expression::Table(_table) => {
950                // Tables don't have child expressions to traverse within scope
951                // (joins are handled at the Select level)
952            }
953            Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
954                // Leaf nodes - no children
955            }
956            // Subqueries and Exists create new scopes - don't traverse into them
957            Expression::Subquery(_) | Expression::Exists(_) => {}
958            _ => {
959                // For other expressions, we could add more cases as needed
960            }
961        }
962
963        children
964    }
965}
966
967impl<'a> Iterator for WalkInScopeIter<'a> {
968    type Item = &'a Expression;
969
970    fn next(&mut self) -> Option<Self::Item> {
971        let expr = if self.bfs {
972            self.queue.pop_front()?
973        } else {
974            self.queue.pop_back()?
975        };
976
977        // Get children that don't cross scope boundaries
978        let children = self.get_children(expr);
979
980        if self.bfs {
981            for child in children {
982                if !self.should_stop_at(child, false) {
983                    self.queue.push_back(child);
984                }
985            }
986        } else {
987            for child in children.into_iter().rev() {
988                if !self.should_stop_at(child, false) {
989                    self.queue.push_back(child);
990                }
991            }
992        }
993
994        Some(expr)
995    }
996}
997
998/// Find the first expression matching the predicate within this scope.
999///
1000/// This does NOT traverse into subscopes.
1001///
1002/// # Arguments
1003/// * `expression` - The root expression
1004/// * `predicate` - Function that returns true for matching expressions
1005/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1006///
1007/// # Returns
1008/// The first matching expression, or None
1009pub fn find_in_scope<'a, F>(
1010    expression: &'a Expression,
1011    predicate: F,
1012    bfs: bool,
1013) -> Option<&'a Expression>
1014where
1015    F: Fn(&Expression) -> bool,
1016{
1017    walk_in_scope(expression, bfs).find(|e| predicate(e))
1018}
1019
1020/// Find all expressions matching the predicate within this scope.
1021///
1022/// This does NOT traverse into subscopes.
1023///
1024/// # Arguments
1025/// * `expression` - The root expression
1026/// * `predicate` - Function that returns true for matching expressions
1027/// * `bfs` - If true, uses breadth-first search; otherwise depth-first
1028///
1029/// # Returns
1030/// A vector of matching expressions
1031pub fn find_all_in_scope<'a, F>(
1032    expression: &'a Expression,
1033    predicate: F,
1034    bfs: bool,
1035) -> Vec<&'a Expression>
1036where
1037    F: Fn(&Expression) -> bool,
1038{
1039    walk_in_scope(expression, bfs)
1040        .filter(|e| predicate(e))
1041        .collect()
1042}
1043
1044/// Traverse an expression by its "scopes".
1045///
1046/// Returns a list of all scopes in depth-first post-order.
1047///
1048/// # Arguments
1049/// * `expression` - The expression to traverse
1050///
1051/// # Returns
1052/// A vector of all scopes found
1053pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1054    match expression {
1055        Expression::Select(_)
1056        | Expression::Union(_)
1057        | Expression::Intersect(_)
1058        | Expression::Except(_) => {
1059            let root = build_scope(expression);
1060            root.traverse().into_iter().cloned().collect()
1061        }
1062        _ => Vec::new(),
1063    }
1064}
1065
#[cfg(test)]
mod tests {
    // Unit tests for scope construction, column classification, source
    // bookkeeping, and the in-scope walk/find/traverse helpers.

    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and builds the scope tree for its first statement.
    ///
    /// Panics if the SQL fails to parse.
    fn parse_and_build_scope(sql: &str) -> Scope {
        let ast = Parser::parse_sql(sql).expect("Failed to parse SQL");
        build_scope(&ast[0])
    }

    #[test]
    fn test_simple_select_scope() {
        let mut scope = parse_and_build_scope("SELECT a, b FROM t");

        assert!(scope.is_root());
        assert!(!scope.can_be_correlated);
        assert!(scope.sources.contains_key("t"));

        let columns = scope.columns();
        assert_eq!(columns.len(), 2);
    }

    #[test]
    fn test_derived_table_scope() {
        let mut scope = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");

        assert!(scope.sources.contains_key("x"));
        assert_eq!(scope.derived_table_scopes.len(), 1);

        let derived = &mut scope.derived_table_scopes[0];
        assert!(derived.is_derived_table());
        assert!(derived.sources.contains_key("t"));
    }

    #[test]
    fn test_non_correlated_subquery() {
        let mut scope = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        // The subquery references only 's', which is in its own sources
        assert!(subquery.sources.contains_key("s"));
        assert!(!subquery.is_correlated_subquery());
    }

    #[test]
    fn test_correlated_subquery() {
        let mut scope =
            parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        // The subquery references 't.y' which is external
        let external = subquery.external_columns();
        assert!(!external.is_empty());
        assert!(external.iter().any(|c| c.table.as_deref() == Some("t")));
        assert!(subquery.is_correlated_subquery());
    }

    #[test]
    fn test_cte_scope() {
        let scope = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");

        assert_eq!(scope.cte_scopes.len(), 1);
        assert!(scope.cte_sources.contains_key("cte"));

        let cte = &scope.cte_scopes[0];
        assert!(cte.is_cte());
    }

    #[test]
    fn test_multiple_sources() {
        let scope = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        assert!(scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("s"));
        assert_eq!(scope.sources.len(), 2);
    }

    #[test]
    fn test_aliased_table() {
        let scope = parse_and_build_scope("SELECT x.a FROM t AS x");

        // Should be indexed by alias, not original name
        assert!(scope.sources.contains_key("x"));
        assert!(!scope.sources.contains_key("t"));
    }

    #[test]
    fn test_local_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let local = scope.local_columns();
        // All columns are local since both t and s are in scope
        assert_eq!(local.len(), 3);
        assert!(local.iter().all(|c| c.table.is_some()));
    }

    #[test]
    fn test_unqualified_columns() {
        let mut scope = parse_and_build_scope("SELECT a, b, t.c FROM t");

        let unqualified = scope.unqualified_columns();
        // Only a and b are unqualified
        assert_eq!(unqualified.len(), 2);
        assert!(unqualified.iter().all(|c| c.table.is_none()));
    }

    #[test]
    fn test_source_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let t_cols = scope.source_columns("t");
        // At least t.a and t.b; t.id from the JOIN condition may also be counted
        assert!(t_cols.len() >= 2);
        assert!(t_cols.iter().all(|c| c.table.as_deref() == Some("t")));

        let s_cols = scope.source_columns("s");
        // At least s.c; s.id from the JOIN condition may also be counted
        assert!(!s_cols.is_empty());
        assert!(s_cols.iter().all(|c| c.table.as_deref() == Some("s")));
    }

    #[test]
    fn test_rename_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.rename_source("t", "new_name".to_string());
        assert!(!scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("new_name"));
    }

    #[test]
    fn test_remove_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.remove_source("t");
        assert!(!scope.sources.contains_key("t"));
    }

    #[test]
    fn test_walk_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let expr = &ast[0];

        // Walk should visit all expressions within the scope
        let walked: Vec<_> = walk_in_scope(expr, true).collect();
        assert!(!walked.is_empty());

        // Should include the root SELECT
        assert!(walked.iter().any(|e| matches!(e, Expression::Select(_))));
        // Should include columns
        assert!(walked.iter().any(|e| matches!(e, Expression::Column(_))));
    }

    #[test]
    fn test_find_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let expr = &ast[0];

        // Find the first column; a single pattern checks both presence and kind
        let found = find_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
        assert!(matches!(found, Some(Expression::Column(_))));
    }

    #[test]
    fn test_find_all_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
        let expr = &ast[0];

        // Find all columns
        let found = find_all_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
        assert_eq!(found.len(), 3);
    }

    #[test]
    fn test_traverse_scope() {
        let ast =
            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
        let expr = &ast[0];

        let scopes = traverse_scope(expr);
        // traverse_scope returns all scopes via Scope::traverse
        // which includes derived table and root scopes
        assert!(!scopes.is_empty());
        // The root scope is always included
        assert!(scopes.iter().any(|s| s.is_root()));
    }

    #[test]
    fn test_branch_with_options() {
        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let scope = build_scope(&ast[0]);

        let child = scope.branch_with_options(
            ast[0].clone(),
            ScopeType::Subquery, // Use Subquery to test can_be_correlated
            None,
            None,
            Some(vec!["col1".to_string(), "col2".to_string()]),
        );

        assert_eq!(child.outer_columns, vec!["col1", "col2"]);
        assert!(child.can_be_correlated); // Subquery scopes may be correlated
    }

    #[test]
    fn test_is_udtf() {
        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let scope = Scope::new(ast[0].clone());
        assert!(!scope.is_udtf());

        let root = build_scope(&ast[0]);
        let udtf_scope = root.branch(ast[0].clone(), ScopeType::Udtf);
        assert!(udtf_scope.is_udtf());
    }

    #[test]
    fn test_is_union() {
        let scope = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(scope.is_root());
        assert_eq!(scope.union_scopes.len(), 2);
        // The children are set operation scopes
        assert!(scope.union_scopes[0].is_union());
        assert!(scope.union_scopes[1].is_union());
    }

    #[test]
    fn test_clear_cache() {
        let mut scope = parse_and_build_scope("SELECT t.a FROM t");

        // First call populates cache
        let _ = scope.columns();
        assert!(scope.columns_cache.is_some());

        // Clear cache
        scope.clear_cache();
        assert!(scope.columns_cache.is_none());
        assert!(scope.external_columns_cache.is_none());
    }

    #[test]
    fn test_scope_traverse() {
        let scope = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        let traversed = scope.traverse();
        // Should include: CTE scope, subquery scope, root scope
        assert!(traversed.len() >= 3);
    }
}