1use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16#[cfg_attr(feature = "bindings", derive(TS))]
17#[cfg_attr(feature = "bindings", ts(export))]
18pub enum ScopeType {
19 Root,
21 Subquery,
23 DerivedTable,
25 Cte,
27 SetOperation,
29 Udtf,
31}
32
33#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
35#[serde(rename_all = "snake_case")]
36pub enum SourceKind {
37 Root,
39 Table,
41 DerivedTable,
43 Cte,
45 Virtual,
47 Unknown,
49}
50
51impl Default for SourceKind {
52 fn default() -> Self {
53 Self::Unknown
54 }
55}
56
57#[derive(Debug, Clone)]
59pub struct SourceInfo {
60 pub expression: Expression,
62 pub is_scope: bool,
64 pub kind: SourceKind,
66 pub alias: Option<String>,
68 pub lineage_name: Option<String>,
70}
71
72impl SourceInfo {
73 pub fn new(expression: Expression, is_scope: bool, kind: SourceKind) -> Self {
74 Self {
75 expression,
76 is_scope,
77 kind,
78 alias: None,
79 lineage_name: None,
80 }
81 }
82
83 pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
84 self.alias = Some(alias.into());
85 self
86 }
87
88 pub fn with_lineage_name(mut self, lineage_name: impl Into<String>) -> Self {
89 self.lineage_name = Some(lineage_name.into());
90 self
91 }
92}
93
/// A column reference found in an expression: an optional table qualifier
/// plus the column name.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table qualifier (`t` in `t.a`); `None` for an unqualified column.
    pub table: Option<String>,
    /// The column name itself.
    pub name: String,
}
102
103#[derive(Debug, Clone)]
108pub struct Scope {
109 pub expression: Expression,
111
112 pub scope_type: ScopeType,
114
115 pub sources: HashMap<String, SourceInfo>,
117
118 pub lateral_sources: HashMap<String, SourceInfo>,
120
121 pub cte_sources: HashMap<String, SourceInfo>,
123
124 pub outer_columns: Vec<String>,
127
128 pub can_be_correlated: bool,
131
132 pub subquery_scopes: Vec<Scope>,
134
135 pub derived_table_scopes: Vec<Scope>,
137
138 pub cte_scopes: Vec<Scope>,
140
141 pub udtf_scopes: Vec<Scope>,
143
144 pub table_scopes: Vec<Scope>,
146
147 pub union_scopes: Vec<Scope>,
149
150 columns_cache: Option<Vec<ColumnRef>>,
152
153 external_columns_cache: Option<Vec<ColumnRef>>,
155}
156
157impl Scope {
158 pub fn new(expression: Expression) -> Self {
160 Self {
161 expression,
162 scope_type: ScopeType::Root,
163 sources: HashMap::new(),
164 lateral_sources: HashMap::new(),
165 cte_sources: HashMap::new(),
166 outer_columns: Vec::new(),
167 can_be_correlated: false,
168 subquery_scopes: Vec::new(),
169 derived_table_scopes: Vec::new(),
170 cte_scopes: Vec::new(),
171 udtf_scopes: Vec::new(),
172 table_scopes: Vec::new(),
173 union_scopes: Vec::new(),
174 columns_cache: None,
175 external_columns_cache: None,
176 }
177 }
178
179 pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
181 self.branch_with_options(expression, scope_type, None, None, None)
182 }
183
184 pub fn branch_with_options(
186 &self,
187 expression: Expression,
188 scope_type: ScopeType,
189 sources: Option<HashMap<String, SourceInfo>>,
190 lateral_sources: Option<HashMap<String, SourceInfo>>,
191 outer_columns: Option<Vec<String>>,
192 ) -> Self {
193 let can_be_correlated = self.can_be_correlated
194 || scope_type == ScopeType::Subquery
195 || scope_type == ScopeType::Udtf;
196
197 Self {
198 expression,
199 scope_type,
200 sources: sources.unwrap_or_default(),
201 lateral_sources: lateral_sources.unwrap_or_default(),
202 cte_sources: self.cte_sources.clone(),
203 outer_columns: outer_columns.unwrap_or_default(),
204 can_be_correlated,
205 subquery_scopes: Vec::new(),
206 derived_table_scopes: Vec::new(),
207 cte_scopes: Vec::new(),
208 udtf_scopes: Vec::new(),
209 table_scopes: Vec::new(),
210 union_scopes: Vec::new(),
211 columns_cache: None,
212 external_columns_cache: None,
213 }
214 }
215
216 pub fn clear_cache(&mut self) {
218 self.columns_cache = None;
219 self.external_columns_cache = None;
220 }
221
222 pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
224 let kind = if is_scope {
225 SourceKind::DerivedTable
226 } else {
227 SourceKind::Table
228 };
229 self.add_source_info(name, SourceInfo::new(expression, is_scope, kind));
230 }
231
232 pub fn add_source_info(&mut self, name: String, info: SourceInfo) {
234 self.sources.insert(name, info);
235 self.clear_cache();
236 }
237
238 pub fn add_virtual_source(&mut self, alias: String, expression: Expression) {
240 let lineage_name = self.next_virtual_source_name();
241 let info = SourceInfo::new(expression, false, SourceKind::Virtual)
242 .with_alias(alias.clone())
243 .with_lineage_name(lineage_name);
244 self.add_source_info(alias, info);
245 }
246
247 fn next_virtual_source_name(&self) -> String {
248 let count = self
249 .sources
250 .values()
251 .filter(|source| source.kind == SourceKind::Virtual)
252 .count();
253 format!("_{}", count)
254 }
255
256 pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
258 let kind = if is_scope {
259 SourceKind::DerivedTable
260 } else {
261 SourceKind::Table
262 };
263 let info = SourceInfo::new(expression.clone(), is_scope, kind);
264 self.sources.insert(name.clone(), info.clone());
265 self.lateral_sources.insert(name, info);
266 self.clear_cache();
267 }
268
269 pub fn add_cte_source(&mut self, name: String, expression: Expression) {
271 let info = SourceInfo::new(expression, true, SourceKind::Cte);
272 self.cte_sources.insert(name.clone(), info.clone());
273 self.sources.insert(name, info);
274 self.clear_cache();
275 }
276
277 pub fn rename_source(&mut self, old_name: &str, new_name: String) {
279 if let Some(source) = self.sources.remove(old_name) {
280 self.sources.insert(new_name, source);
281 }
282 self.clear_cache();
283 }
284
285 pub fn remove_source(&mut self, name: &str) {
287 self.sources.remove(name);
288 self.clear_cache();
289 }
290
291 pub fn columns(&mut self) -> &[ColumnRef] {
293 if self.columns_cache.is_none() {
294 let mut columns = Vec::new();
295 collect_columns(&self.expression, &mut columns);
296 self.columns_cache = Some(columns);
297 }
298 self.columns_cache.as_ref().unwrap()
299 }
300
301 pub fn output_columns(&self) -> Vec<String> {
306 crate::ast_transforms::get_output_column_names(&self.expression)
307 }
308
309 pub fn source_names(&self) -> HashSet<String> {
311 let mut names: HashSet<String> = self.sources.keys().cloned().collect();
312 names.extend(self.cte_sources.keys().cloned());
313 names
314 }
315
316 pub fn external_columns(&mut self) -> Vec<ColumnRef> {
318 if self.external_columns_cache.is_some() {
319 return self.external_columns_cache.clone().unwrap();
320 }
321
322 let source_names = self.source_names();
323 let columns = self.columns().to_vec();
324
325 let external: Vec<ColumnRef> = columns
326 .into_iter()
327 .filter(|col| {
328 match &col.table {
330 Some(table) => !source_names.contains(table),
331 None => false, }
333 })
334 .collect();
335
336 self.external_columns_cache = Some(external.clone());
337 external
338 }
339
340 pub fn local_columns(&mut self) -> Vec<ColumnRef> {
342 let external_set: HashSet<_> = self.external_columns().into_iter().collect();
343 let columns = self.columns().to_vec();
344
345 columns
346 .into_iter()
347 .filter(|col| !external_set.contains(col))
348 .collect()
349 }
350
351 pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
353 self.columns()
354 .iter()
355 .filter(|c| c.table.is_none())
356 .cloned()
357 .collect()
358 }
359
360 pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
362 self.columns()
363 .iter()
364 .filter(|col| col.table.as_deref() == Some(source_name))
365 .cloned()
366 .collect()
367 }
368
369 pub fn is_correlated_subquery(&mut self) -> bool {
375 self.can_be_correlated && !self.external_columns().is_empty()
376 }
377
378 pub fn is_subquery(&self) -> bool {
380 self.scope_type == ScopeType::Subquery
381 }
382
383 pub fn is_derived_table(&self) -> bool {
385 self.scope_type == ScopeType::DerivedTable
386 }
387
388 pub fn is_cte(&self) -> bool {
390 self.scope_type == ScopeType::Cte
391 }
392
393 pub fn is_root(&self) -> bool {
395 self.scope_type == ScopeType::Root
396 }
397
398 pub fn is_udtf(&self) -> bool {
400 self.scope_type == ScopeType::Udtf
401 }
402
403 pub fn is_union(&self) -> bool {
405 self.scope_type == ScopeType::SetOperation
406 }
407
408 pub fn traverse(&self) -> Vec<&Scope> {
410 let mut result = Vec::new();
411 self.traverse_impl(&mut result);
412 result
413 }
414
415 fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
416 for scope in &self.cte_scopes {
418 scope.traverse_impl(result);
419 }
420 for scope in &self.union_scopes {
421 scope.traverse_impl(result);
422 }
423 for scope in &self.table_scopes {
424 scope.traverse_impl(result);
425 }
426 for scope in &self.subquery_scopes {
427 scope.traverse_impl(result);
428 }
429 result.push(self);
431 }
432
433 pub fn ref_count(&self) -> HashMap<usize, usize> {
435 let mut counts: HashMap<usize, usize> = HashMap::new();
436
437 for scope in self.traverse() {
438 for (_, source_info) in scope.sources.iter() {
439 if source_info.is_scope {
440 let id = &source_info.expression as *const _ as usize;
441 *counts.entry(id).or_insert(0) += 1;
442 }
443 }
444 }
445
446 counts
447 }
448}
449
450fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
452 match expr {
453 Expression::Column(col) => {
454 columns.push(ColumnRef {
455 table: col.table.as_ref().map(|t| t.name.clone()),
456 name: col.name.name.clone(),
457 });
458 }
459 Expression::Select(select) => {
460 for e in &select.expressions {
462 collect_columns(e, columns);
463 }
464 for join in &select.joins {
466 if let Some(on) = &join.on {
467 collect_columns(on, columns);
468 }
469 if let Some(match_condition) = &join.match_condition {
470 collect_columns(match_condition, columns);
471 }
472 }
473 if let Some(where_clause) = &select.where_clause {
475 collect_columns(&where_clause.this, columns);
476 }
477 if let Some(having) = &select.having {
479 collect_columns(&having.this, columns);
480 }
481 if let Some(order_by) = &select.order_by {
483 for ord in &order_by.expressions {
484 collect_columns(&ord.this, columns);
485 }
486 }
487 if let Some(group_by) = &select.group_by {
489 for e in &group_by.expressions {
490 collect_columns(e, columns);
491 }
492 }
493 }
496 Expression::And(bin)
498 | Expression::Or(bin)
499 | Expression::Add(bin)
500 | Expression::Sub(bin)
501 | Expression::Mul(bin)
502 | Expression::Div(bin)
503 | Expression::Mod(bin)
504 | Expression::Eq(bin)
505 | Expression::Neq(bin)
506 | Expression::Lt(bin)
507 | Expression::Lte(bin)
508 | Expression::Gt(bin)
509 | Expression::Gte(bin)
510 | Expression::BitwiseAnd(bin)
511 | Expression::BitwiseOr(bin)
512 | Expression::BitwiseXor(bin)
513 | Expression::Concat(bin) => {
514 collect_columns(&bin.left, columns);
515 collect_columns(&bin.right, columns);
516 }
517 Expression::Like(like) | Expression::ILike(like) => {
519 collect_columns(&like.left, columns);
520 collect_columns(&like.right, columns);
521 if let Some(escape) = &like.escape {
522 collect_columns(escape, columns);
523 }
524 }
525 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
527 collect_columns(&un.this, columns);
528 }
529 Expression::Function(func) => {
530 for arg in &func.args {
531 collect_columns(arg, columns);
532 }
533 }
534 Expression::AggregateFunction(agg) => {
535 for arg in &agg.args {
536 collect_columns(arg, columns);
537 }
538 }
539 Expression::WindowFunction(wf) => {
540 collect_columns(&wf.this, columns);
541 for e in &wf.over.partition_by {
542 collect_columns(e, columns);
543 }
544 for e in &wf.over.order_by {
545 collect_columns(&e.this, columns);
546 }
547 }
548 Expression::Alias(alias) => {
549 collect_columns(&alias.this, columns);
550 }
551 Expression::Case(case) => {
552 if let Some(operand) = &case.operand {
553 collect_columns(operand, columns);
554 }
555 for (when_expr, then_expr) in &case.whens {
556 collect_columns(when_expr, columns);
557 collect_columns(then_expr, columns);
558 }
559 if let Some(else_clause) = &case.else_ {
560 collect_columns(else_clause, columns);
561 }
562 }
563 Expression::Paren(paren) => {
564 collect_columns(&paren.this, columns);
565 }
566 Expression::Ordered(ord) => {
567 collect_columns(&ord.this, columns);
568 }
569 Expression::In(in_expr) => {
570 collect_columns(&in_expr.this, columns);
571 for e in &in_expr.expressions {
572 collect_columns(e, columns);
573 }
574 }
576 Expression::Between(between) => {
577 collect_columns(&between.this, columns);
578 collect_columns(&between.low, columns);
579 collect_columns(&between.high, columns);
580 }
581 Expression::IsNull(is_null) => {
582 collect_columns(&is_null.this, columns);
583 }
584 Expression::Cast(cast) => {
585 collect_columns(&cast.this, columns);
586 }
587 Expression::Extract(extract) => {
588 collect_columns(&extract.this, columns);
589 }
590 Expression::Exists(_) | Expression::Subquery(_) => {
591 }
593 Expression::Prepare(prepare) => {
594 collect_columns(&prepare.statement, columns);
595 }
596 _ => {
597 }
599 }
600}
601
602pub fn build_scope(expression: &Expression) -> Scope {
607 let mut root = Scope::new(expression.clone());
608 build_scope_impl(expression, &mut root);
609 root
610}
611
612fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
613 match expression {
614 Expression::Prepare(prepare) => {
615 build_scope_impl(&prepare.statement, current_scope);
616 }
617 Expression::Select(select) => {
618 if let Some(with) = &select.with {
620 for cte in &with.ctes {
621 let cte_name = cte.alias.name.clone();
622 let mut cte_scope = current_scope
623 .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
624 build_scope_impl(&cte.this, &mut cte_scope);
625 current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
626 current_scope.cte_scopes.push(cte_scope);
627 }
628 }
629
630 if let Some(from) = &select.from {
632 for table in &from.expressions {
633 add_table_to_scope(table, current_scope);
634 }
635 }
636
637 for join in &select.joins {
639 add_table_to_scope(&join.this, current_scope);
640 }
641
642 for lateral_view in &select.lateral_views {
644 add_lateral_view_to_scope(lateral_view, current_scope);
645 }
646
647 collect_subqueries(expression, current_scope);
649 }
650 Expression::Union(union) => {
651 let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
652 build_scope_impl(&union.left, &mut left_scope);
653
654 let mut right_scope =
655 current_scope.branch(union.right.clone(), ScopeType::SetOperation);
656 build_scope_impl(&union.right, &mut right_scope);
657
658 current_scope.union_scopes.push(left_scope);
659 current_scope.union_scopes.push(right_scope);
660 }
661 Expression::Intersect(intersect) => {
662 let mut left_scope =
663 current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
664 build_scope_impl(&intersect.left, &mut left_scope);
665
666 let mut right_scope =
667 current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
668 build_scope_impl(&intersect.right, &mut right_scope);
669
670 current_scope.union_scopes.push(left_scope);
671 current_scope.union_scopes.push(right_scope);
672 }
673 Expression::Except(except) => {
674 let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
675 build_scope_impl(&except.left, &mut left_scope);
676
677 let mut right_scope =
678 current_scope.branch(except.right.clone(), ScopeType::SetOperation);
679 build_scope_impl(&except.right, &mut right_scope);
680
681 current_scope.union_scopes.push(left_scope);
682 current_scope.union_scopes.push(right_scope);
683 }
684 Expression::CreateTable(create) => {
685 if let Some(with) = &create.with_cte {
688 for cte in &with.ctes {
689 let cte_name = cte.alias.name.clone();
690 let mut cte_scope = current_scope
691 .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
692 build_scope_impl(&cte.this, &mut cte_scope);
693 current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
694 current_scope.cte_scopes.push(cte_scope);
695 }
696 }
697 if let Some(as_select) = &create.as_select {
699 build_scope_impl(as_select, current_scope);
700 }
701 }
702 _ => {}
703 }
704}
705
706fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
707 match expr {
708 Expression::Table(table) => {
709 let name = table
710 .alias
711 .as_ref()
712 .map(|a| a.name.clone())
713 .unwrap_or_else(|| table.name.name.clone());
714 let cte_source = if table.schema.is_none() && table.catalog.is_none() {
715 scope.cte_sources.get(&table.name.name).or_else(|| {
716 scope
717 .cte_sources
718 .iter()
719 .find(|(cte_name, _)| cte_name.eq_ignore_ascii_case(&table.name.name))
720 .map(|(_, source)| source)
721 })
722 } else {
723 None
724 };
725
726 if let Some(source) = cte_source {
727 scope.add_source_info(name, source.clone());
728 } else {
729 let mut source = SourceInfo::new(expr.clone(), false, SourceKind::Table);
730 if let Some(alias) = &table.alias {
731 source = source.with_alias(alias.name.clone());
732 }
733 scope.add_source_info(name, source);
734 }
735 }
736 Expression::Subquery(subquery) => {
737 let name = subquery
738 .alias
739 .as_ref()
740 .map(|a| a.name.clone())
741 .unwrap_or_default();
742
743 let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
744 build_scope_impl(&subquery.this, &mut derived_scope);
745
746 scope.add_source(name.clone(), expr.clone(), true);
747 scope.derived_table_scopes.push(derived_scope);
748 }
749 Expression::Unnest(unnest) => {
750 if let Some(alias) = &unnest.alias {
751 scope.add_virtual_source(alias.name.clone(), expr.clone());
752 }
753 }
754 Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
755 scope.add_virtual_source(alias.alias.name.clone(), expr.clone());
756 }
757 Expression::Lateral(lateral) => {
758 if let Some(alias) = &lateral.alias {
759 scope.add_virtual_source(alias.clone(), expr.clone());
760 }
761 }
762 Expression::LateralView(lateral_view) => {
763 add_lateral_view_to_scope(lateral_view, scope);
764 }
765 Expression::Paren(paren) => {
766 add_table_to_scope(&paren.this, scope);
767 }
768 _ => {}
769 }
770}
771
772fn add_lateral_view_to_scope(lateral_view: &crate::expressions::LateralView, scope: &mut Scope) {
773 let alias = lateral_view
774 .table_alias
775 .as_ref()
776 .or_else(|| lateral_view.column_aliases.first())
777 .map(|alias| alias.name.clone());
778
779 if let Some(alias) = alias {
780 scope.add_virtual_source(
781 alias,
782 Expression::LateralView(Box::new(lateral_view.clone())),
783 );
784 }
785}
786
787fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
788 match expr {
789 Expression::Select(select) => {
790 if let Some(where_clause) = &select.where_clause {
792 collect_subqueries_in_expr(&where_clause.this, parent_scope);
793 }
794 for e in &select.expressions {
796 collect_subqueries_in_expr(e, parent_scope);
797 }
798 if let Some(having) = &select.having {
800 collect_subqueries_in_expr(&having.this, parent_scope);
801 }
802 }
803 _ => {}
804 }
805}
806
807fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
808 match expr {
809 Expression::Subquery(subquery) if subquery.alias.is_none() => {
810 let mut sub_scope = parent_scope.branch(subquery.this.clone(), ScopeType::Subquery);
812 build_scope_impl(&subquery.this, &mut sub_scope);
813 parent_scope.subquery_scopes.push(sub_scope);
814 }
815 Expression::In(in_expr) => {
816 collect_subqueries_in_expr(&in_expr.this, parent_scope);
817 if let Some(query) = &in_expr.query {
818 let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
819 build_scope_impl(query, &mut sub_scope);
820 parent_scope.subquery_scopes.push(sub_scope);
821 }
822 }
823 Expression::Exists(exists) => {
824 let mut sub_scope = parent_scope.branch(exists.this.clone(), ScopeType::Subquery);
825 build_scope_impl(&exists.this, &mut sub_scope);
826 parent_scope.subquery_scopes.push(sub_scope);
827 }
828 Expression::And(bin)
830 | Expression::Or(bin)
831 | Expression::Add(bin)
832 | Expression::Sub(bin)
833 | Expression::Mul(bin)
834 | Expression::Div(bin)
835 | Expression::Mod(bin)
836 | Expression::Eq(bin)
837 | Expression::Neq(bin)
838 | Expression::Lt(bin)
839 | Expression::Lte(bin)
840 | Expression::Gt(bin)
841 | Expression::Gte(bin)
842 | Expression::BitwiseAnd(bin)
843 | Expression::BitwiseOr(bin)
844 | Expression::BitwiseXor(bin)
845 | Expression::Concat(bin) => {
846 collect_subqueries_in_expr(&bin.left, parent_scope);
847 collect_subqueries_in_expr(&bin.right, parent_scope);
848 }
849 Expression::Like(like) | Expression::ILike(like) => {
851 collect_subqueries_in_expr(&like.left, parent_scope);
852 collect_subqueries_in_expr(&like.right, parent_scope);
853 if let Some(escape) = &like.escape {
854 collect_subqueries_in_expr(escape, parent_scope);
855 }
856 }
857 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
859 collect_subqueries_in_expr(&un.this, parent_scope);
860 }
861 Expression::Function(func) => {
862 for arg in &func.args {
863 collect_subqueries_in_expr(arg, parent_scope);
864 }
865 }
866 Expression::Case(case) => {
867 if let Some(operand) = &case.operand {
868 collect_subqueries_in_expr(operand, parent_scope);
869 }
870 for (when_expr, then_expr) in &case.whens {
871 collect_subqueries_in_expr(when_expr, parent_scope);
872 collect_subqueries_in_expr(then_expr, parent_scope);
873 }
874 if let Some(else_clause) = &case.else_ {
875 collect_subqueries_in_expr(else_clause, parent_scope);
876 }
877 }
878 Expression::Paren(paren) => {
879 collect_subqueries_in_expr(&paren.this, parent_scope);
880 }
881 Expression::Alias(alias) => {
882 collect_subqueries_in_expr(&alias.this, parent_scope);
883 }
884 _ => {}
885 }
886}
887
888pub fn walk_in_scope<'a>(
900 expression: &'a Expression,
901 bfs: bool,
902) -> impl Iterator<Item = &'a Expression> {
903 WalkInScopeIter::new(expression, bfs)
904}
905
906struct WalkInScopeIter<'a> {
908 queue: VecDeque<&'a Expression>,
909 bfs: bool,
910}
911
912impl<'a> WalkInScopeIter<'a> {
913 fn new(expression: &'a Expression, bfs: bool) -> Self {
914 let mut queue = VecDeque::new();
915 queue.push_back(expression);
916 Self { queue, bfs }
917 }
918
919 fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
920 if is_root {
921 return false;
922 }
923
924 if matches!(expr, Expression::Cte(_)) {
926 return true;
927 }
928
929 if let Expression::Subquery(subquery) = expr {
931 if subquery.alias.is_some() {
932 return true;
933 }
934 }
935
936 if matches!(
938 expr,
939 Expression::Select(_)
940 | Expression::Union(_)
941 | Expression::Intersect(_)
942 | Expression::Except(_)
943 ) {
944 return true;
945 }
946
947 false
948 }
949
950 fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
951 let mut children = Vec::new();
952
953 match expr {
954 Expression::Prepare(prepare) => {
955 children.push(&prepare.statement);
956 }
957 Expression::Select(select) => {
958 for e in &select.expressions {
960 children.push(e);
961 }
962 if let Some(from) = &select.from {
964 for table in &from.expressions {
965 if !self.should_stop_at(table, false) {
966 children.push(table);
967 }
968 }
969 }
970 for join in &select.joins {
972 if let Some(on) = &join.on {
973 children.push(on);
974 }
975 }
977 if let Some(where_clause) = &select.where_clause {
979 children.push(&where_clause.this);
980 }
981 if let Some(group_by) = &select.group_by {
983 for e in &group_by.expressions {
984 children.push(e);
985 }
986 }
987 if let Some(having) = &select.having {
989 children.push(&having.this);
990 }
991 if let Some(order_by) = &select.order_by {
993 for ord in &order_by.expressions {
994 children.push(&ord.this);
995 }
996 }
997 if let Some(limit) = &select.limit {
999 children.push(&limit.this);
1000 }
1001 if let Some(offset) = &select.offset {
1003 children.push(&offset.this);
1004 }
1005 }
1006 Expression::And(bin)
1007 | Expression::Or(bin)
1008 | Expression::Add(bin)
1009 | Expression::Sub(bin)
1010 | Expression::Mul(bin)
1011 | Expression::Div(bin)
1012 | Expression::Mod(bin)
1013 | Expression::Eq(bin)
1014 | Expression::Neq(bin)
1015 | Expression::Lt(bin)
1016 | Expression::Lte(bin)
1017 | Expression::Gt(bin)
1018 | Expression::Gte(bin)
1019 | Expression::BitwiseAnd(bin)
1020 | Expression::BitwiseOr(bin)
1021 | Expression::BitwiseXor(bin)
1022 | Expression::Concat(bin) => {
1023 children.push(&bin.left);
1024 children.push(&bin.right);
1025 }
1026 Expression::Like(like) | Expression::ILike(like) => {
1027 children.push(&like.left);
1028 children.push(&like.right);
1029 if let Some(escape) = &like.escape {
1030 children.push(escape);
1031 }
1032 }
1033 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
1034 children.push(&un.this);
1035 }
1036 Expression::Function(func) => {
1037 for arg in &func.args {
1038 children.push(arg);
1039 }
1040 }
1041 Expression::AggregateFunction(agg) => {
1042 for arg in &agg.args {
1043 children.push(arg);
1044 }
1045 }
1046 Expression::WindowFunction(wf) => {
1047 children.push(&wf.this);
1048 for e in &wf.over.partition_by {
1049 children.push(e);
1050 }
1051 for e in &wf.over.order_by {
1052 children.push(&e.this);
1053 }
1054 }
1055 Expression::Alias(alias) => {
1056 children.push(&alias.this);
1057 }
1058 Expression::Case(case) => {
1059 if let Some(operand) = &case.operand {
1060 children.push(operand);
1061 }
1062 for (when_expr, then_expr) in &case.whens {
1063 children.push(when_expr);
1064 children.push(then_expr);
1065 }
1066 if let Some(else_clause) = &case.else_ {
1067 children.push(else_clause);
1068 }
1069 }
1070 Expression::Paren(paren) => {
1071 children.push(&paren.this);
1072 }
1073 Expression::Ordered(ord) => {
1074 children.push(&ord.this);
1075 }
1076 Expression::In(in_expr) => {
1077 children.push(&in_expr.this);
1078 for e in &in_expr.expressions {
1079 children.push(e);
1080 }
1081 }
1083 Expression::Between(between) => {
1084 children.push(&between.this);
1085 children.push(&between.low);
1086 children.push(&between.high);
1087 }
1088 Expression::IsNull(is_null) => {
1089 children.push(&is_null.this);
1090 }
1091 Expression::Cast(cast) => {
1092 children.push(&cast.this);
1093 }
1094 Expression::Extract(extract) => {
1095 children.push(&extract.this);
1096 }
1097 Expression::Coalesce(coalesce) => {
1098 for e in &coalesce.expressions {
1099 children.push(e);
1100 }
1101 }
1102 Expression::NullIf(nullif) => {
1103 children.push(&nullif.this);
1104 children.push(&nullif.expression);
1105 }
1106 Expression::Table(_table) => {
1107 }
1110 Expression::TryCatch(try_catch) => {
1111 for stmt in &try_catch.try_body {
1112 children.push(stmt);
1113 }
1114 if let Some(catch_body) = &try_catch.catch_body {
1115 for stmt in catch_body {
1116 children.push(stmt);
1117 }
1118 }
1119 }
1120 Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
1121 }
1123 Expression::Subquery(_) | Expression::Exists(_) => {}
1125 _ => {
1126 }
1128 }
1129
1130 children
1131 }
1132}
1133
1134impl<'a> Iterator for WalkInScopeIter<'a> {
1135 type Item = &'a Expression;
1136
1137 fn next(&mut self) -> Option<Self::Item> {
1138 let expr = if self.bfs {
1139 self.queue.pop_front()?
1140 } else {
1141 self.queue.pop_back()?
1142 };
1143
1144 let children = self.get_children(expr);
1146
1147 if self.bfs {
1148 for child in children {
1149 if !self.should_stop_at(child, false) {
1150 self.queue.push_back(child);
1151 }
1152 }
1153 } else {
1154 for child in children.into_iter().rev() {
1155 if !self.should_stop_at(child, false) {
1156 self.queue.push_back(child);
1157 }
1158 }
1159 }
1160
1161 Some(expr)
1162 }
1163}
1164
1165pub fn find_in_scope<'a, F>(
1177 expression: &'a Expression,
1178 predicate: F,
1179 bfs: bool,
1180) -> Option<&'a Expression>
1181where
1182 F: Fn(&Expression) -> bool,
1183{
1184 walk_in_scope(expression, bfs).find(|e| predicate(e))
1185}
1186
1187pub fn find_all_in_scope<'a, F>(
1199 expression: &'a Expression,
1200 predicate: F,
1201 bfs: bool,
1202) -> Vec<&'a Expression>
1203where
1204 F: Fn(&Expression) -> bool,
1205{
1206 walk_in_scope(expression, bfs)
1207 .filter(|e| predicate(e))
1208 .collect()
1209}
1210
1211pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1221 match expression {
1222 Expression::Select(_)
1223 | Expression::Union(_)
1224 | Expression::Intersect(_)
1225 | Expression::Except(_)
1226 | Expression::Prepare(_)
1227 | Expression::CreateTable(_) => {
1228 let root = build_scope(expression);
1229 root.traverse().into_iter().cloned().collect()
1230 }
1231 _ => Vec::new(),
1232 }
1233}
1234
1235#[cfg(test)]
1236mod tests {
1237 use super::*;
1238 use crate::parser::Parser;
1239
1240 fn parse_and_build_scope(sql: &str) -> Scope {
1241 let ast = Parser::parse_sql(sql).expect("Failed to parse SQL");
1242 build_scope(&ast[0])
1243 }
1244
1245 #[test]
1246 fn test_simple_select_scope() {
1247 let mut scope = parse_and_build_scope("SELECT a, b FROM t");
1248
1249 assert!(scope.is_root());
1250 assert!(!scope.can_be_correlated);
1251 assert!(scope.sources.contains_key("t"));
1252
1253 let columns = scope.columns();
1254 assert_eq!(columns.len(), 2);
1255 }
1256
1257 #[test]
1258 fn test_derived_table_scope() {
1259 let mut scope = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");
1260
1261 assert!(scope.sources.contains_key("x"));
1262 assert_eq!(scope.derived_table_scopes.len(), 1);
1263
1264 let derived = &mut scope.derived_table_scopes[0];
1265 assert!(derived.is_derived_table());
1266 assert!(derived.sources.contains_key("t"));
1267 }
1268
1269 #[test]
1270 fn test_non_correlated_subquery() {
1271 let mut scope = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");
1272
1273 assert_eq!(scope.subquery_scopes.len(), 1);
1274
1275 let subquery = &mut scope.subquery_scopes[0];
1276 assert!(subquery.is_subquery());
1277 assert!(subquery.can_be_correlated);
1278
1279 assert!(subquery.sources.contains_key("s"));
1281 assert!(!subquery.is_correlated_subquery());
1282 }
1283
1284 #[test]
1285 fn test_correlated_subquery() {
1286 let mut scope =
1287 parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");
1288
1289 assert_eq!(scope.subquery_scopes.len(), 1);
1290
1291 let subquery = &mut scope.subquery_scopes[0];
1292 assert!(subquery.is_subquery());
1293 assert!(subquery.can_be_correlated);
1294
1295 let external = subquery.external_columns();
1297 assert!(!external.is_empty());
1298 assert!(external.iter().any(|c| c.table.as_deref() == Some("t")));
1299 assert!(subquery.is_correlated_subquery());
1300 }
1301
1302 #[test]
1303 fn test_cte_scope() {
1304 let scope = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");
1305
1306 assert_eq!(scope.cte_scopes.len(), 1);
1307 assert!(scope.cte_sources.contains_key("cte"));
1308
1309 let cte = &scope.cte_scopes[0];
1310 assert!(cte.is_cte());
1311 }
1312
1313 #[test]
1314 fn test_multiple_sources() {
1315 let scope = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
1316
1317 assert!(scope.sources.contains_key("t"));
1318 assert!(scope.sources.contains_key("s"));
1319 assert_eq!(scope.sources.len(), 2);
1320 }
1321
1322 #[test]
1323 fn test_aliased_table() {
1324 let scope = parse_and_build_scope("SELECT x.a FROM t AS x");
1325
1326 assert!(scope.sources.contains_key("x"));
1328 assert!(!scope.sources.contains_key("t"));
1329 }
1330
1331 #[test]
1332 fn test_local_columns() {
1333 let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");
1334
1335 let local = scope.local_columns();
1336 assert_eq!(local.len(), 5);
1339 assert!(local.iter().all(|c| c.table.is_some()));
1340 }
1341
1342 #[test]
1343 fn test_columns_include_join_on_clause_references() {
1344 let mut scope = parse_and_build_scope(
1345 "SELECT o.total FROM orders o JOIN customers c ON c.id = o.customer_id",
1346 );
1347
1348 let cols: Vec<String> = scope
1349 .columns()
1350 .iter()
1351 .map(|c| match &c.table {
1352 Some(t) => format!("{}.{}", t, c.name),
1353 None => c.name.clone(),
1354 })
1355 .collect();
1356
1357 assert!(cols.contains(&"o.total".to_string()));
1358 assert!(cols.contains(&"c.id".to_string()));
1359 assert!(cols.contains(&"o.customer_id".to_string()));
1360 }
1361
1362 #[test]
1363 fn test_unqualified_columns() {
1364 let mut scope = parse_and_build_scope("SELECT a, b, t.c FROM t");
1365
1366 let unqualified = scope.unqualified_columns();
1367 assert_eq!(unqualified.len(), 2);
1369 assert!(unqualified.iter().all(|c| c.table.is_none()));
1370 }
1371
1372 #[test]
1373 fn test_source_columns() {
1374 let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");
1375
1376 let t_cols = scope.source_columns("t");
1377 assert!(t_cols.len() >= 2);
1379 assert!(t_cols.iter().all(|c| c.table.as_deref() == Some("t")));
1380
1381 let s_cols = scope.source_columns("s");
1382 assert!(s_cols.len() >= 1);
1384 assert!(s_cols.iter().all(|c| c.table.as_deref() == Some("s")));
1385 }
1386
1387 #[test]
1388 fn test_rename_source() {
1389 let mut scope = parse_and_build_scope("SELECT a FROM t");
1390
1391 assert!(scope.sources.contains_key("t"));
1392 scope.rename_source("t", "new_name".to_string());
1393 assert!(!scope.sources.contains_key("t"));
1394 assert!(scope.sources.contains_key("new_name"));
1395 }
1396
1397 #[test]
1398 fn test_remove_source() {
1399 let mut scope = parse_and_build_scope("SELECT a FROM t");
1400
1401 assert!(scope.sources.contains_key("t"));
1402 scope.remove_source("t");
1403 assert!(!scope.sources.contains_key("t"));
1404 }
1405
1406 #[test]
1407 fn test_walk_in_scope() {
1408 let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
1409 let expr = &ast[0];
1410
1411 let walked: Vec<_> = walk_in_scope(expr, true).collect();
1413 assert!(!walked.is_empty());
1414
1415 assert!(walked.iter().any(|e| matches!(e, Expression::Select(_))));
1417 assert!(walked.iter().any(|e| matches!(e, Expression::Column(_))));
1419 }
1420
1421 #[test]
1422 fn test_find_in_scope() {
1423 let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
1424 let expr = &ast[0];
1425
1426 let found = find_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
1428 assert!(found.is_some());
1429 assert!(matches!(found.unwrap(), Expression::Column(_)));
1430 }
1431
1432 #[test]
1433 fn test_find_all_in_scope() {
1434 let ast = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
1435 let expr = &ast[0];
1436
1437 let found = find_all_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
1439 assert_eq!(found.len(), 3);
1440 }
1441
1442 #[test]
1443 fn test_traverse_scope() {
1444 let ast =
1445 Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
1446 let expr = &ast[0];
1447
1448 let scopes = traverse_scope(expr);
1449 assert!(!scopes.is_empty());
1452 assert!(scopes.iter().any(|s| s.is_root()));
1454 }
1455
1456 #[test]
1457 fn test_branch_with_options() {
1458 let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
1459 let scope = build_scope(&ast[0]);
1460
1461 let child = scope.branch_with_options(
1462 ast[0].clone(),
1463 ScopeType::Subquery, None,
1465 None,
1466 Some(vec!["col1".to_string(), "col2".to_string()]),
1467 );
1468
1469 assert_eq!(child.outer_columns, vec!["col1", "col2"]);
1470 assert!(child.can_be_correlated); }
1472
1473 #[test]
1474 fn test_is_udtf() {
1475 let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
1476 let scope = Scope::new(ast[0].clone());
1477 assert!(!scope.is_udtf());
1478
1479 let root = build_scope(&ast[0]);
1480 let udtf_scope = root.branch(ast[0].clone(), ScopeType::Udtf);
1481 assert!(udtf_scope.is_udtf());
1482 }
1483
1484 #[test]
1485 fn test_is_union() {
1486 let scope = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");
1487
1488 assert!(scope.is_root());
1489 assert_eq!(scope.union_scopes.len(), 2);
1490 assert!(scope.union_scopes[0].is_union());
1492 assert!(scope.union_scopes[1].is_union());
1493 }
1494
1495 #[test]
1496 fn test_union_output_columns() {
1497 let scope = parse_and_build_scope(
1498 "SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees",
1499 );
1500 assert_eq!(scope.output_columns(), vec!["id", "name"]);
1501 }
1502
1503 #[test]
1504 fn test_clear_cache() {
1505 let mut scope = parse_and_build_scope("SELECT t.a FROM t");
1506
1507 let _ = scope.columns();
1509 assert!(scope.columns_cache.is_some());
1510
1511 scope.clear_cache();
1513 assert!(scope.columns_cache.is_none());
1514 assert!(scope.external_columns_cache.is_none());
1515 }
1516
1517 #[test]
1518 fn test_scope_traverse() {
1519 let scope = parse_and_build_scope(
1520 "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
1521 );
1522
1523 let traversed = scope.traverse();
1524 assert!(traversed.len() >= 3);
1526 }
1527
1528 #[test]
1529 fn test_create_table_as_select_scope() {
1530 let scope = parse_and_build_scope("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
1532 assert!(
1533 scope.sources.contains_key("src"),
1534 "CTAS scope should contain the FROM table"
1535 );
1536 assert!(
1537 !scope.sources.contains_key("out_table"),
1538 "CTAS target table should not be treated as a source"
1539 );
1540
1541 let scope = parse_and_build_scope(
1543 "CREATE TABLE out_table AS SELECT a.id FROM foo AS a JOIN bar AS b ON a.id = b.id",
1544 );
1545 assert!(scope.sources.contains_key("a"));
1546 assert!(scope.sources.contains_key("b"));
1547 assert!(
1548 !scope.sources.contains_key("out_table"),
1549 "CTAS target table should not be treated as a source"
1550 );
1551
1552 let scope = parse_and_build_scope(
1554 "CREATE TABLE out_table AS WITH cte AS (SELECT 1 AS id FROM src) SELECT * FROM cte",
1555 );
1556 assert!(
1557 scope.sources.contains_key("cte"),
1558 "CTAS with CTE should resolve CTE as source"
1559 );
1560 assert!(
1561 !scope.sources.contains_key("out_table"),
1562 "CTAS target table should not be treated as a source"
1563 );
1564 assert_eq!(scope.cte_scopes.len(), 1);
1565 }
1566
1567 #[test]
1568 fn test_create_table_as_select_traverse() {
1569 let ast = Parser::parse_sql("CREATE TABLE t AS SELECT a FROM src").unwrap();
1570 let scopes = traverse_scope(&ast[0]);
1571 assert!(
1572 !scopes.is_empty(),
1573 "traverse_scope should return scopes for CTAS"
1574 );
1575 }
1576}