1use crate::expressions::Expression;
9use crate::traversal::ExpressionWalk;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12#[cfg(feature = "bindings")]
13use ts_rs::TS;
14
/// Classifies the query construct a [`Scope`] was created for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "bindings", derive(TS))]
#[cfg_attr(feature = "bindings", ts(export))]
pub enum ScopeType {
    /// The outermost scope; this is what [`Scope::new`] assigns.
    Root,
    /// A subquery used inside an expression (unaliased subquery, `IN (query)`, `EXISTS`).
    /// Scopes of this type are always marked `can_be_correlated`.
    Subquery,
    /// An aliased subquery used as a table source, or the subquery feeding a PIVOT/UNPIVOT.
    DerivedTable,
    /// The body of a common table expression.
    Cte,
    /// One operand (left or right) of a UNION / INTERSECT / EXCEPT.
    SetOperation,
    /// A table-function scope. Scopes of this type are always marked `can_be_correlated`.
    // NOTE(review): nothing in the visible part of this file creates a `Udtf` scope — confirm
    // where (or whether) it is produced.
    Udtf,
}
33
/// Describes where a [`SourceInfo`] entry came from.
///
/// Serialized with serde using snake_case variant names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SourceKind {
    /// The root of a query.
    // NOTE(review): not assigned anywhere in the visible part of this file — confirm usage.
    Root,
    /// A plain table reference.
    Table,
    /// A derived table (subquery used as a source); also used for PIVOT/UNPIVOT sources.
    DerivedTable,
    /// A reference to a common table expression.
    Cte,
    /// A source without a backing table: UNNEST, LATERAL, or LATERAL VIEW.
    Virtual,
    /// The kind has not been determined; this is the `Default` value.
    Unknown,
}
51
52impl Default for SourceKind {
53 fn default() -> Self {
54 Self::Unknown
55 }
56}
57
/// Everything a [`Scope`] records about one named source.
#[derive(Debug, Clone)]
pub struct SourceInfo {
    /// The expression that introduced the source (table, subquery, CTE, UNNEST, ...).
    pub expression: Expression,
    /// Whether the source is itself a scope; set to `true` for CTEs and derived tables.
    pub is_scope: bool,
    /// Classification of the source.
    pub kind: SourceKind,
    /// The alias the source was registered under, when one was recorded.
    pub alias: Option<String>,
    /// Synthetic name (`_0`, `_1`, ...) assigned to virtual sources by
    /// `Scope::add_virtual_source`; `None` for every other source built in this file.
    pub lineage_name: Option<String>,
}
72
73impl SourceInfo {
74 pub fn new(expression: Expression, is_scope: bool, kind: SourceKind) -> Self {
75 Self {
76 expression,
77 is_scope,
78 kind,
79 alias: None,
80 lineage_name: None,
81 }
82 }
83
84 pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
85 self.alias = Some(alias.into());
86 self
87 }
88
89 pub fn with_lineage_name(mut self, lineage_name: impl Into<String>) -> Self {
90 self.lineage_name = Some(lineage_name.into());
91 self
92 }
93}
94
/// A column reference found in an expression: an optional table qualifier plus the column name.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// The table qualifier (`t` in `t.a`), or `None` for an unqualified column.
    pub table: Option<String>,
    /// The column name.
    pub name: String,
}
103
/// One query scope: the expression it covers, the sources visible in it, and its child scopes.
///
/// Child scopes are owned by value, so a `Scope` forms a tree rooted at the scope returned by
/// `build_scope`.
#[derive(Debug, Clone)]
pub struct Scope {
    /// The expression this scope was built for.
    pub expression: Expression,

    /// What kind of construct this scope represents.
    pub scope_type: ScopeType,

    /// Every source registered in this scope, keyed by the name it is referenced by
    /// (alias if present, otherwise the table/CTE name).
    pub sources: HashMap<String, SourceInfo>,

    /// Sources registered through `add_lateral_source`; each is also present in `sources`.
    pub lateral_sources: HashMap<String, SourceInfo>,

    /// CTEs visible from this scope, keyed by CTE name. Branch scopes start with a copy of
    /// their parent's map.
    pub cte_sources: HashMap<String, SourceInfo>,

    /// Column names supplied through `branch_with_options`; empty otherwise.
    // NOTE(review): never read in the visible part of this file — confirm intended use.
    pub outer_columns: Vec<String>,

    /// `true` when this scope is a subquery/UDTF scope or was branched from a scope that
    /// could itself be correlated.
    pub can_be_correlated: bool,

    /// Scopes of subqueries found inside expressions of this scope.
    pub subquery_scopes: Vec<Scope>,

    /// Scopes of derived tables (and PIVOT/UNPIVOT input subqueries) in this scope.
    pub derived_table_scopes: Vec<Scope>,

    /// Scopes of the CTEs declared on this scope's expression.
    pub cte_scopes: Vec<Scope>,

    /// Scopes of table functions.
    // NOTE(review): not populated in the visible part of this file.
    pub udtf_scopes: Vec<Scope>,

    /// Table scopes; visited by `traverse`.
    // NOTE(review): not populated in the visible part of this file — confirm whether it is
    // meant to mirror `derived_table_scopes` + `udtf_scopes`.
    pub table_scopes: Vec<Scope>,

    /// Scopes of the left and right operands of a set operation, in that order.
    pub union_scopes: Vec<Scope>,

    /// Memoized result of `columns()`; reset by `clear_cache`.
    columns_cache: Option<Vec<ColumnRef>>,

    /// Memoized result of `external_columns()`; reset by `clear_cache`.
    external_columns_cache: Option<Vec<ColumnRef>>,
}
157
158impl Scope {
159 pub fn new(expression: Expression) -> Self {
161 Self {
162 expression,
163 scope_type: ScopeType::Root,
164 sources: HashMap::new(),
165 lateral_sources: HashMap::new(),
166 cte_sources: HashMap::new(),
167 outer_columns: Vec::new(),
168 can_be_correlated: false,
169 subquery_scopes: Vec::new(),
170 derived_table_scopes: Vec::new(),
171 cte_scopes: Vec::new(),
172 udtf_scopes: Vec::new(),
173 table_scopes: Vec::new(),
174 union_scopes: Vec::new(),
175 columns_cache: None,
176 external_columns_cache: None,
177 }
178 }
179
180 pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
182 self.branch_with_options(expression, scope_type, None, None, None)
183 }
184
185 pub fn branch_with_options(
187 &self,
188 expression: Expression,
189 scope_type: ScopeType,
190 sources: Option<HashMap<String, SourceInfo>>,
191 lateral_sources: Option<HashMap<String, SourceInfo>>,
192 outer_columns: Option<Vec<String>>,
193 ) -> Self {
194 let can_be_correlated = self.can_be_correlated
195 || scope_type == ScopeType::Subquery
196 || scope_type == ScopeType::Udtf;
197
198 Self {
199 expression,
200 scope_type,
201 sources: sources.unwrap_or_default(),
202 lateral_sources: lateral_sources.unwrap_or_default(),
203 cte_sources: self.cte_sources.clone(),
204 outer_columns: outer_columns.unwrap_or_default(),
205 can_be_correlated,
206 subquery_scopes: Vec::new(),
207 derived_table_scopes: Vec::new(),
208 cte_scopes: Vec::new(),
209 udtf_scopes: Vec::new(),
210 table_scopes: Vec::new(),
211 union_scopes: Vec::new(),
212 columns_cache: None,
213 external_columns_cache: None,
214 }
215 }
216
217 pub fn clear_cache(&mut self) {
219 self.columns_cache = None;
220 self.external_columns_cache = None;
221 }
222
223 pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
225 let kind = if is_scope {
226 SourceKind::DerivedTable
227 } else {
228 SourceKind::Table
229 };
230 self.add_source_info(name, SourceInfo::new(expression, is_scope, kind));
231 }
232
233 pub fn add_source_info(&mut self, name: String, info: SourceInfo) {
235 self.sources.insert(name, info);
236 self.clear_cache();
237 }
238
239 pub fn add_virtual_source(&mut self, alias: String, expression: Expression) {
241 let lineage_name = self.next_virtual_source_name();
242 let info = SourceInfo::new(expression, false, SourceKind::Virtual)
243 .with_alias(alias.clone())
244 .with_lineage_name(lineage_name);
245 self.add_source_info(alias, info);
246 }
247
248 fn next_virtual_source_name(&self) -> String {
249 let count = self
250 .sources
251 .values()
252 .filter(|source| source.kind == SourceKind::Virtual)
253 .count();
254 format!("_{}", count)
255 }
256
257 pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
259 let kind = if is_scope {
260 SourceKind::DerivedTable
261 } else {
262 SourceKind::Table
263 };
264 let info = SourceInfo::new(expression.clone(), is_scope, kind);
265 self.sources.insert(name.clone(), info.clone());
266 self.lateral_sources.insert(name, info);
267 self.clear_cache();
268 }
269
270 pub fn add_cte_source(&mut self, name: String, expression: Expression) {
272 let info = SourceInfo::new(expression, true, SourceKind::Cte);
273 self.cte_sources.insert(name.clone(), info.clone());
274 self.sources.insert(name, info);
275 self.clear_cache();
276 }
277
278 pub fn rename_source(&mut self, old_name: &str, new_name: String) {
280 if let Some(source) = self.sources.remove(old_name) {
281 self.sources.insert(new_name, source);
282 }
283 self.clear_cache();
284 }
285
286 pub fn remove_source(&mut self, name: &str) {
288 self.sources.remove(name);
289 self.clear_cache();
290 }
291
292 pub fn columns(&mut self) -> &[ColumnRef] {
294 if self.columns_cache.is_none() {
295 let mut columns = Vec::new();
296 collect_columns(&self.expression, &mut columns);
297 self.columns_cache = Some(columns);
298 }
299 self.columns_cache.as_ref().unwrap()
300 }
301
302 pub fn output_columns(&self) -> Vec<String> {
307 crate::ast_transforms::get_output_column_names(&self.expression)
308 }
309
310 pub fn source_names(&self) -> HashSet<String> {
312 let mut names: HashSet<String> = self.sources.keys().cloned().collect();
313 names.extend(self.cte_sources.keys().cloned());
314 names
315 }
316
317 pub fn external_columns(&mut self) -> Vec<ColumnRef> {
319 if self.external_columns_cache.is_some() {
320 return self.external_columns_cache.clone().unwrap();
321 }
322
323 let source_names = self.source_names();
324 let columns = self.columns().to_vec();
325
326 let external: Vec<ColumnRef> = columns
327 .into_iter()
328 .filter(|col| {
329 match &col.table {
331 Some(table) => !source_names.contains(table),
332 None => false, }
334 })
335 .collect();
336
337 self.external_columns_cache = Some(external.clone());
338 external
339 }
340
341 pub fn local_columns(&mut self) -> Vec<ColumnRef> {
343 let external_set: HashSet<_> = self.external_columns().into_iter().collect();
344 let columns = self.columns().to_vec();
345
346 columns
347 .into_iter()
348 .filter(|col| !external_set.contains(col))
349 .collect()
350 }
351
352 pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
354 self.columns()
355 .iter()
356 .filter(|c| c.table.is_none())
357 .cloned()
358 .collect()
359 }
360
361 pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
363 self.columns()
364 .iter()
365 .filter(|col| col.table.as_deref() == Some(source_name))
366 .cloned()
367 .collect()
368 }
369
370 pub fn is_correlated_subquery(&mut self) -> bool {
376 self.can_be_correlated && !self.external_columns().is_empty()
377 }
378
379 pub fn is_subquery(&self) -> bool {
381 self.scope_type == ScopeType::Subquery
382 }
383
384 pub fn is_derived_table(&self) -> bool {
386 self.scope_type == ScopeType::DerivedTable
387 }
388
389 pub fn is_cte(&self) -> bool {
391 self.scope_type == ScopeType::Cte
392 }
393
394 pub fn is_root(&self) -> bool {
396 self.scope_type == ScopeType::Root
397 }
398
399 pub fn is_udtf(&self) -> bool {
401 self.scope_type == ScopeType::Udtf
402 }
403
404 pub fn is_union(&self) -> bool {
406 self.scope_type == ScopeType::SetOperation
407 }
408
409 pub fn traverse(&self) -> Vec<&Scope> {
411 let mut result = Vec::new();
412 self.traverse_impl(&mut result);
413 result
414 }
415
416 fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
417 for scope in &self.cte_scopes {
419 scope.traverse_impl(result);
420 }
421 for scope in &self.union_scopes {
422 scope.traverse_impl(result);
423 }
424 for scope in &self.table_scopes {
425 scope.traverse_impl(result);
426 }
427 for scope in &self.subquery_scopes {
428 scope.traverse_impl(result);
429 }
430 result.push(self);
432 }
433
434 pub fn ref_count(&self) -> HashMap<usize, usize> {
436 let mut counts: HashMap<usize, usize> = HashMap::new();
437
438 for scope in self.traverse() {
439 for (_, source_info) in scope.sources.iter() {
440 if source_info.is_scope {
441 let id = &source_info.expression as *const _ as usize;
442 *counts.entry(id).or_insert(0) += 1;
443 }
444 }
445 }
446
447 counts
448 }
449}
450
451fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
453 match expr {
454 Expression::Column(col) => {
455 columns.push(ColumnRef {
456 table: col.table.as_ref().map(|t| t.name.clone()),
457 name: col.name.name.clone(),
458 });
459 }
460 Expression::Select(select) => {
461 for e in &select.expressions {
463 collect_columns(e, columns);
464 }
465 for join in &select.joins {
467 if let Some(on) = &join.on {
468 collect_columns(on, columns);
469 }
470 if let Some(match_condition) = &join.match_condition {
471 collect_columns(match_condition, columns);
472 }
473 }
474 if let Some(where_clause) = &select.where_clause {
476 collect_columns(&where_clause.this, columns);
477 }
478 if let Some(having) = &select.having {
480 collect_columns(&having.this, columns);
481 }
482 if let Some(order_by) = &select.order_by {
484 for ord in &order_by.expressions {
485 collect_columns(&ord.this, columns);
486 }
487 }
488 if let Some(group_by) = &select.group_by {
490 for e in &group_by.expressions {
491 collect_columns(e, columns);
492 }
493 }
494 }
497 Expression::And(bin)
499 | Expression::Or(bin)
500 | Expression::Add(bin)
501 | Expression::Sub(bin)
502 | Expression::Mul(bin)
503 | Expression::Div(bin)
504 | Expression::Mod(bin)
505 | Expression::Eq(bin)
506 | Expression::Neq(bin)
507 | Expression::Lt(bin)
508 | Expression::Lte(bin)
509 | Expression::Gt(bin)
510 | Expression::Gte(bin)
511 | Expression::BitwiseAnd(bin)
512 | Expression::BitwiseOr(bin)
513 | Expression::BitwiseXor(bin)
514 | Expression::Concat(bin) => {
515 collect_columns(&bin.left, columns);
516 collect_columns(&bin.right, columns);
517 }
518 Expression::Like(like) | Expression::ILike(like) => {
520 collect_columns(&like.left, columns);
521 collect_columns(&like.right, columns);
522 if let Some(escape) = &like.escape {
523 collect_columns(escape, columns);
524 }
525 }
526 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
528 collect_columns(&un.this, columns);
529 }
530 Expression::Function(func) => {
531 for arg in &func.args {
532 collect_columns(arg, columns);
533 }
534 }
535 Expression::AggregateFunction(agg) => {
536 for arg in &agg.args {
537 collect_columns(arg, columns);
538 }
539 }
540 Expression::WindowFunction(wf) => {
541 collect_columns(&wf.this, columns);
542 for e in &wf.over.partition_by {
543 collect_columns(e, columns);
544 }
545 for e in &wf.over.order_by {
546 collect_columns(&e.this, columns);
547 }
548 }
549 Expression::Alias(alias) => {
550 collect_columns(&alias.this, columns);
551 }
552 Expression::Case(case) => {
553 if let Some(operand) = &case.operand {
554 collect_columns(operand, columns);
555 }
556 for (when_expr, then_expr) in &case.whens {
557 collect_columns(when_expr, columns);
558 collect_columns(then_expr, columns);
559 }
560 if let Some(else_clause) = &case.else_ {
561 collect_columns(else_clause, columns);
562 }
563 }
564 Expression::Paren(paren) => {
565 collect_columns(&paren.this, columns);
566 }
567 Expression::Ordered(ord) => {
568 collect_columns(&ord.this, columns);
569 }
570 Expression::In(in_expr) => {
571 collect_columns(&in_expr.this, columns);
572 for e in &in_expr.expressions {
573 collect_columns(e, columns);
574 }
575 }
577 Expression::Between(between) => {
578 collect_columns(&between.this, columns);
579 collect_columns(&between.low, columns);
580 collect_columns(&between.high, columns);
581 }
582 Expression::IsNull(is_null) => {
583 collect_columns(&is_null.this, columns);
584 }
585 Expression::Cast(cast) => {
586 collect_columns(&cast.this, columns);
587 }
588 Expression::Extract(extract) => {
589 collect_columns(&extract.this, columns);
590 }
591 Expression::Exists(_) | Expression::Subquery(_) => {
592 }
594 Expression::Prepare(prepare) => {
595 collect_columns(&prepare.statement, columns);
596 }
597 _ => {
598 }
600 }
601}
602
603pub fn build_scope(expression: &Expression) -> Scope {
608 let mut root = Scope::new(expression.clone());
609 build_scope_impl(expression, &mut root);
610 root
611}
612
613fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
614 match expression {
615 Expression::Prepare(prepare) => {
616 build_scope_impl(&prepare.statement, current_scope);
617 }
618 Expression::Select(select) => {
619 if let Some(with) = &select.with {
621 process_ctes(with, current_scope);
622 }
623
624 if let Some(from) = &select.from {
626 for table in &from.expressions {
627 add_table_to_scope(table, current_scope);
628 }
629 }
630
631 for join in &select.joins {
633 add_table_to_scope(&join.this, current_scope);
634 }
635
636 for lateral_view in &select.lateral_views {
638 add_lateral_view_to_scope(lateral_view, current_scope);
639 }
640
641 collect_subqueries(expression, current_scope);
643 }
644 Expression::Union(union) => {
645 if let Some(with) = &union.with {
646 process_ctes(with, current_scope);
647 }
648
649 let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
650 build_scope_impl(&union.left, &mut left_scope);
651
652 let mut right_scope =
653 current_scope.branch(union.right.clone(), ScopeType::SetOperation);
654 build_scope_impl(&union.right, &mut right_scope);
655
656 current_scope.union_scopes.push(left_scope);
657 current_scope.union_scopes.push(right_scope);
658 }
659 Expression::Intersect(intersect) => {
660 if let Some(with) = &intersect.with {
661 process_ctes(with, current_scope);
662 }
663
664 let mut left_scope =
665 current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
666 build_scope_impl(&intersect.left, &mut left_scope);
667
668 let mut right_scope =
669 current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
670 build_scope_impl(&intersect.right, &mut right_scope);
671
672 current_scope.union_scopes.push(left_scope);
673 current_scope.union_scopes.push(right_scope);
674 }
675 Expression::Except(except) => {
676 if let Some(with) = &except.with {
677 process_ctes(with, current_scope);
678 }
679
680 let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
681 build_scope_impl(&except.left, &mut left_scope);
682
683 let mut right_scope =
684 current_scope.branch(except.right.clone(), ScopeType::SetOperation);
685 build_scope_impl(&except.right, &mut right_scope);
686
687 current_scope.union_scopes.push(left_scope);
688 current_scope.union_scopes.push(right_scope);
689 }
690 Expression::CreateTable(create) => {
691 if let Some(with) = &create.with_cte {
694 process_ctes(with, current_scope);
695 }
696 if let Some(as_select) = &create.as_select {
698 build_scope_impl(as_select, current_scope);
699 }
700 }
701 _ => {}
702 }
703}
704
705fn process_ctes(with: &crate::expressions::With, current_scope: &mut Scope) {
706 for cte in &with.ctes {
707 let cte_name = cte.alias.name.clone();
708 let cte_expr = Expression::Cte(Box::new(cte.clone()));
709 let mut cte_scope = current_scope.branch(cte_expr.clone(), ScopeType::Cte);
710
711 if with.recursive && cte_body_self_references(cte) {
712 cte_scope.add_cte_source(cte_name.clone(), cte_expr.clone());
713 }
714
715 build_scope_impl(&cte.this, &mut cte_scope);
716 current_scope.add_cte_source(cte_name, cte_expr);
717 current_scope.cte_scopes.push(cte_scope);
718 }
719}
720
721fn cte_body_self_references(cte: &crate::expressions::Cte) -> bool {
722 let cte_name = cte.alias.name.as_str();
723 !cte.this
724 .find_all(|expr| match expr {
725 Expression::Table(table) if table.schema.is_none() && table.catalog.is_none() => {
726 table.name.name.eq_ignore_ascii_case(cte_name)
727 }
728 _ => false,
729 })
730 .is_empty()
731}
732
733fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
734 match expr {
735 Expression::Table(table) => {
736 let name = table
737 .alias
738 .as_ref()
739 .map(|a| a.name.clone())
740 .unwrap_or_else(|| table.name.name.clone());
741 let cte_source = if table.schema.is_none() && table.catalog.is_none() {
742 scope.cte_sources.get(&table.name.name).or_else(|| {
743 scope
744 .cte_sources
745 .iter()
746 .find(|(cte_name, _)| cte_name.eq_ignore_ascii_case(&table.name.name))
747 .map(|(_, source)| source)
748 })
749 } else {
750 None
751 };
752
753 if let Some(source) = cte_source {
754 scope.add_source_info(name, source.clone());
755 } else {
756 let mut source = SourceInfo::new(expr.clone(), false, SourceKind::Table);
757 if let Some(alias) = &table.alias {
758 source = source.with_alias(alias.name.clone());
759 }
760 scope.add_source_info(name, source);
761 }
762 }
763 Expression::Subquery(subquery) => {
764 let name = subquery
765 .alias
766 .as_ref()
767 .map(|a| a.name.clone())
768 .unwrap_or_default();
769
770 let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
771 build_scope_impl(&subquery.this, &mut derived_scope);
772
773 scope.add_source(name.clone(), expr.clone(), true);
774 scope.derived_table_scopes.push(derived_scope);
775 }
776 Expression::Unnest(unnest) => {
777 if let Some(alias) = &unnest.alias {
778 scope.add_virtual_source(alias.name.clone(), expr.clone());
779 }
780 }
781 Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
782 scope.add_virtual_source(alias.alias.name.clone(), expr.clone());
783 }
784 Expression::Lateral(lateral) => {
785 if let Some(alias) = &lateral.alias {
786 scope.add_virtual_source(alias.clone(), expr.clone());
787 }
788 }
789 Expression::LateralView(lateral_view) => {
790 add_lateral_view_to_scope(lateral_view, scope);
791 }
792 Expression::Pivot(pivot) => {
793 let name =
794 pivot_source_name(&pivot.this, pivot.alias.as_ref().map(|a| a.name.as_str()));
795 scope.add_source_info(
796 name,
797 SourceInfo::new(expr.clone(), false, SourceKind::DerivedTable),
798 );
799 add_pivot_inner_scope(&pivot.this, scope);
800 }
801 Expression::Unpivot(unpivot) => {
802 let name = pivot_source_name(
803 &unpivot.this,
804 unpivot.alias.as_ref().map(|a| a.name.as_str()),
805 );
806 scope.add_source_info(
807 name,
808 SourceInfo::new(expr.clone(), false, SourceKind::DerivedTable),
809 );
810 add_pivot_inner_scope(&unpivot.this, scope);
811 }
812 Expression::Paren(paren) => {
813 add_table_to_scope(&paren.this, scope);
814 }
815 _ => {}
816 }
817}
818
819fn pivot_source_name(source: &Expression, explicit_alias: Option<&str>) -> String {
820 if let Some(alias) = explicit_alias {
821 return alias.to_string();
822 }
823
824 match source {
825 Expression::Table(table) => table
826 .alias
827 .as_ref()
828 .map(|alias| alias.name.clone())
829 .unwrap_or_else(|| table.name.name.clone()),
830 Expression::Subquery(subquery) => subquery
831 .alias
832 .as_ref()
833 .map(|alias| alias.name.clone())
834 .unwrap_or_else(|| "_0".to_string()),
835 Expression::Paren(paren) => pivot_source_name(&paren.this, explicit_alias),
836 _ => "_0".to_string(),
837 }
838}
839
840fn add_pivot_inner_scope(source: &Expression, scope: &mut Scope) {
841 match source {
842 Expression::Subquery(subquery) => {
843 let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
844 build_scope_impl(&subquery.this, &mut derived_scope);
845 scope.derived_table_scopes.push(derived_scope);
846 }
847 Expression::Paren(paren) => add_pivot_inner_scope(&paren.this, scope),
848 _ => {}
849 }
850}
851
852fn add_lateral_view_to_scope(lateral_view: &crate::expressions::LateralView, scope: &mut Scope) {
853 let alias = lateral_view
854 .table_alias
855 .as_ref()
856 .or_else(|| lateral_view.column_aliases.first())
857 .map(|alias| alias.name.clone());
858
859 if let Some(alias) = alias {
860 scope.add_virtual_source(
861 alias,
862 Expression::LateralView(Box::new(lateral_view.clone())),
863 );
864 }
865}
866
867fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
868 match expr {
869 Expression::Select(select) => {
870 if let Some(where_clause) = &select.where_clause {
872 collect_subqueries_in_expr(&where_clause.this, parent_scope);
873 }
874 for e in &select.expressions {
876 collect_subqueries_in_expr(e, parent_scope);
877 }
878 if let Some(having) = &select.having {
880 collect_subqueries_in_expr(&having.this, parent_scope);
881 }
882 }
883 _ => {}
884 }
885}
886
887fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
888 match expr {
889 Expression::Subquery(subquery) if subquery.alias.is_none() => {
890 let mut sub_scope = parent_scope.branch(subquery.this.clone(), ScopeType::Subquery);
892 build_scope_impl(&subquery.this, &mut sub_scope);
893 parent_scope.subquery_scopes.push(sub_scope);
894 }
895 Expression::In(in_expr) => {
896 collect_subqueries_in_expr(&in_expr.this, parent_scope);
897 if let Some(query) = &in_expr.query {
898 let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
899 build_scope_impl(query, &mut sub_scope);
900 parent_scope.subquery_scopes.push(sub_scope);
901 }
902 }
903 Expression::Exists(exists) => {
904 let mut sub_scope = parent_scope.branch(exists.this.clone(), ScopeType::Subquery);
905 build_scope_impl(&exists.this, &mut sub_scope);
906 parent_scope.subquery_scopes.push(sub_scope);
907 }
908 Expression::And(bin)
910 | Expression::Or(bin)
911 | Expression::Add(bin)
912 | Expression::Sub(bin)
913 | Expression::Mul(bin)
914 | Expression::Div(bin)
915 | Expression::Mod(bin)
916 | Expression::Eq(bin)
917 | Expression::Neq(bin)
918 | Expression::Lt(bin)
919 | Expression::Lte(bin)
920 | Expression::Gt(bin)
921 | Expression::Gte(bin)
922 | Expression::BitwiseAnd(bin)
923 | Expression::BitwiseOr(bin)
924 | Expression::BitwiseXor(bin)
925 | Expression::Concat(bin) => {
926 collect_subqueries_in_expr(&bin.left, parent_scope);
927 collect_subqueries_in_expr(&bin.right, parent_scope);
928 }
929 Expression::Like(like) | Expression::ILike(like) => {
931 collect_subqueries_in_expr(&like.left, parent_scope);
932 collect_subqueries_in_expr(&like.right, parent_scope);
933 if let Some(escape) = &like.escape {
934 collect_subqueries_in_expr(escape, parent_scope);
935 }
936 }
937 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
939 collect_subqueries_in_expr(&un.this, parent_scope);
940 }
941 Expression::Function(func) => {
942 for arg in &func.args {
943 collect_subqueries_in_expr(arg, parent_scope);
944 }
945 }
946 Expression::Case(case) => {
947 if let Some(operand) = &case.operand {
948 collect_subqueries_in_expr(operand, parent_scope);
949 }
950 for (when_expr, then_expr) in &case.whens {
951 collect_subqueries_in_expr(when_expr, parent_scope);
952 collect_subqueries_in_expr(then_expr, parent_scope);
953 }
954 if let Some(else_clause) = &case.else_ {
955 collect_subqueries_in_expr(else_clause, parent_scope);
956 }
957 }
958 Expression::Paren(paren) => {
959 collect_subqueries_in_expr(&paren.this, parent_scope);
960 }
961 Expression::Alias(alias) => {
962 collect_subqueries_in_expr(&alias.this, parent_scope);
963 }
964 _ => {}
965 }
966}
967
/// Returns an iterator over `expression` and the nodes reachable from it without
/// crossing into a nested scope.
///
/// `expression` itself is always yielded first. With `bfs == true` nodes are visited
/// breadth-first, otherwise depth-first in pre-order. CTEs, aliased subqueries, nested
/// `SELECT`s and set operations are not yielded or descended into.
pub fn walk_in_scope<'a>(
    expression: &'a Expression,
    bfs: bool,
) -> impl Iterator<Item = &'a Expression> {
    WalkInScopeIter::new(expression, bfs)
}
985
/// Iterator state behind `walk_in_scope`.
struct WalkInScopeIter<'a> {
    /// Nodes waiting to be yielded; consumed from the front for BFS and from the back
    /// for DFS.
    queue: VecDeque<&'a Expression>,
    /// `true` for breadth-first order, `false` for depth-first order.
    bfs: bool,
}
991
992impl<'a> WalkInScopeIter<'a> {
993 fn new(expression: &'a Expression, bfs: bool) -> Self {
994 let mut queue = VecDeque::new();
995 queue.push_back(expression);
996 Self { queue, bfs }
997 }
998
999 fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
1000 if is_root {
1001 return false;
1002 }
1003
1004 if matches!(expr, Expression::Cte(_)) {
1006 return true;
1007 }
1008
1009 if let Expression::Subquery(subquery) = expr {
1011 if subquery.alias.is_some() {
1012 return true;
1013 }
1014 }
1015
1016 if matches!(
1018 expr,
1019 Expression::Select(_)
1020 | Expression::Union(_)
1021 | Expression::Intersect(_)
1022 | Expression::Except(_)
1023 ) {
1024 return true;
1025 }
1026
1027 false
1028 }
1029
1030 fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
1031 let mut children = Vec::new();
1032
1033 match expr {
1034 Expression::Prepare(prepare) => {
1035 children.push(&prepare.statement);
1036 }
1037 Expression::Select(select) => {
1038 for e in &select.expressions {
1040 children.push(e);
1041 }
1042 if let Some(from) = &select.from {
1044 for table in &from.expressions {
1045 if !self.should_stop_at(table, false) {
1046 children.push(table);
1047 }
1048 }
1049 }
1050 for join in &select.joins {
1052 if let Some(on) = &join.on {
1053 children.push(on);
1054 }
1055 }
1057 if let Some(where_clause) = &select.where_clause {
1059 children.push(&where_clause.this);
1060 }
1061 if let Some(group_by) = &select.group_by {
1063 for e in &group_by.expressions {
1064 children.push(e);
1065 }
1066 }
1067 if let Some(having) = &select.having {
1069 children.push(&having.this);
1070 }
1071 if let Some(order_by) = &select.order_by {
1073 for ord in &order_by.expressions {
1074 children.push(&ord.this);
1075 }
1076 }
1077 if let Some(limit) = &select.limit {
1079 children.push(&limit.this);
1080 }
1081 if let Some(offset) = &select.offset {
1083 children.push(&offset.this);
1084 }
1085 }
1086 Expression::And(bin)
1087 | Expression::Or(bin)
1088 | Expression::Add(bin)
1089 | Expression::Sub(bin)
1090 | Expression::Mul(bin)
1091 | Expression::Div(bin)
1092 | Expression::Mod(bin)
1093 | Expression::Eq(bin)
1094 | Expression::Neq(bin)
1095 | Expression::Lt(bin)
1096 | Expression::Lte(bin)
1097 | Expression::Gt(bin)
1098 | Expression::Gte(bin)
1099 | Expression::BitwiseAnd(bin)
1100 | Expression::BitwiseOr(bin)
1101 | Expression::BitwiseXor(bin)
1102 | Expression::Concat(bin) => {
1103 children.push(&bin.left);
1104 children.push(&bin.right);
1105 }
1106 Expression::Like(like) | Expression::ILike(like) => {
1107 children.push(&like.left);
1108 children.push(&like.right);
1109 if let Some(escape) = &like.escape {
1110 children.push(escape);
1111 }
1112 }
1113 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
1114 children.push(&un.this);
1115 }
1116 Expression::Function(func) => {
1117 for arg in &func.args {
1118 children.push(arg);
1119 }
1120 }
1121 Expression::AggregateFunction(agg) => {
1122 for arg in &agg.args {
1123 children.push(arg);
1124 }
1125 }
1126 Expression::WindowFunction(wf) => {
1127 children.push(&wf.this);
1128 for e in &wf.over.partition_by {
1129 children.push(e);
1130 }
1131 for e in &wf.over.order_by {
1132 children.push(&e.this);
1133 }
1134 }
1135 Expression::Alias(alias) => {
1136 children.push(&alias.this);
1137 }
1138 Expression::Case(case) => {
1139 if let Some(operand) = &case.operand {
1140 children.push(operand);
1141 }
1142 for (when_expr, then_expr) in &case.whens {
1143 children.push(when_expr);
1144 children.push(then_expr);
1145 }
1146 if let Some(else_clause) = &case.else_ {
1147 children.push(else_clause);
1148 }
1149 }
1150 Expression::Paren(paren) => {
1151 children.push(&paren.this);
1152 }
1153 Expression::Ordered(ord) => {
1154 children.push(&ord.this);
1155 }
1156 Expression::In(in_expr) => {
1157 children.push(&in_expr.this);
1158 for e in &in_expr.expressions {
1159 children.push(e);
1160 }
1161 }
1163 Expression::Between(between) => {
1164 children.push(&between.this);
1165 children.push(&between.low);
1166 children.push(&between.high);
1167 }
1168 Expression::IsNull(is_null) => {
1169 children.push(&is_null.this);
1170 }
1171 Expression::Cast(cast) => {
1172 children.push(&cast.this);
1173 }
1174 Expression::Extract(extract) => {
1175 children.push(&extract.this);
1176 }
1177 Expression::Coalesce(coalesce) => {
1178 for e in &coalesce.expressions {
1179 children.push(e);
1180 }
1181 }
1182 Expression::NullIf(nullif) => {
1183 children.push(&nullif.this);
1184 children.push(&nullif.expression);
1185 }
1186 Expression::Table(_table) => {
1187 }
1190 Expression::TryCatch(try_catch) => {
1191 for stmt in &try_catch.try_body {
1192 children.push(stmt);
1193 }
1194 if let Some(catch_body) = &try_catch.catch_body {
1195 for stmt in catch_body {
1196 children.push(stmt);
1197 }
1198 }
1199 }
1200 Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
1201 }
1203 Expression::Subquery(_) | Expression::Exists(_) => {}
1205 _ => {
1206 }
1208 }
1209
1210 children
1211 }
1212}
1213
1214impl<'a> Iterator for WalkInScopeIter<'a> {
1215 type Item = &'a Expression;
1216
1217 fn next(&mut self) -> Option<Self::Item> {
1218 let expr = if self.bfs {
1219 self.queue.pop_front()?
1220 } else {
1221 self.queue.pop_back()?
1222 };
1223
1224 let children = self.get_children(expr);
1226
1227 if self.bfs {
1228 for child in children {
1229 if !self.should_stop_at(child, false) {
1230 self.queue.push_back(child);
1231 }
1232 }
1233 } else {
1234 for child in children.into_iter().rev() {
1235 if !self.should_stop_at(child, false) {
1236 self.queue.push_back(child);
1237 }
1238 }
1239 }
1240
1241 Some(expr)
1242 }
1243}
1244
1245pub fn find_in_scope<'a, F>(
1257 expression: &'a Expression,
1258 predicate: F,
1259 bfs: bool,
1260) -> Option<&'a Expression>
1261where
1262 F: Fn(&Expression) -> bool,
1263{
1264 walk_in_scope(expression, bfs).find(|e| predicate(e))
1265}
1266
1267pub fn find_all_in_scope<'a, F>(
1279 expression: &'a Expression,
1280 predicate: F,
1281 bfs: bool,
1282) -> Vec<&'a Expression>
1283where
1284 F: Fn(&Expression) -> bool,
1285{
1286 walk_in_scope(expression, bfs)
1287 .filter(|e| predicate(e))
1288 .collect()
1289}
1290
1291pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1301 match expression {
1302 Expression::Select(_)
1303 | Expression::Union(_)
1304 | Expression::Intersect(_)
1305 | Expression::Except(_)
1306 | Expression::Prepare(_)
1307 | Expression::CreateTable(_) => {
1308 let root = build_scope(expression);
1309 root.traverse().into_iter().cloned().collect()
1310 }
1311 _ => Vec::new(),
1312 }
1313}
1314
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and builds the scope for its first statement.
    ///
    /// Panics (failing the calling test) when the SQL does not parse.
    fn parse_and_build_scope(sql: &str) -> Scope {
        let ast = Parser::parse_sql(sql).expect("Failed to parse SQL");
        build_scope(&ast[0])
    }

    /// A plain `SELECT` yields a root scope with one source and two columns.
    #[test]
    fn test_simple_select_scope() {
        let mut scope = parse_and_build_scope("SELECT a, b FROM t");

        assert!(scope.is_root());
        assert!(!scope.can_be_correlated);
        assert!(scope.sources.contains_key("t"));

        let columns = scope.columns();
        assert_eq!(columns.len(), 2);
    }

    /// A subquery in `FROM` is registered under its alias and gets its own
    /// derived-table scope containing the inner table.
    #[test]
    fn test_derived_table_scope() {
        let mut scope = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");

        assert!(scope.sources.contains_key("x"));
        assert_eq!(scope.derived_table_scopes.len(), 1);

        let derived = &mut scope.derived_table_scopes[0];
        assert!(derived.is_derived_table());
        assert!(derived.sources.contains_key("t"));
    }

    /// An `EXISTS` subquery that references no outer column may be correlated
    /// in principle but is not reported as a correlated subquery.
    #[test]
    fn test_non_correlated_subquery() {
        let mut scope = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        assert!(subquery.sources.contains_key("s"));
        assert!(!subquery.is_correlated_subquery());
    }

    /// An `EXISTS` subquery that references `t.y` from the outer query exposes
    /// it as an external column and is reported as correlated.
    #[test]
    fn test_correlated_subquery() {
        let mut scope =
            parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        let external = subquery.external_columns();
        assert!(!external.is_empty());
        assert!(external.iter().any(|c| c.table.as_deref() == Some("t")));
        assert!(subquery.is_correlated_subquery());
    }

    /// A `WITH` clause produces one CTE scope and a matching CTE source entry.
    #[test]
    fn test_cte_scope() {
        let scope = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");

        assert_eq!(scope.cte_scopes.len(), 1);
        assert!(scope.cte_sources.contains_key("cte"));

        let cte = &scope.cte_scopes[0];
        assert!(cte.is_cte());
    }

    /// Both sides of a `JOIN` are registered as sources.
    #[test]
    fn test_multiple_sources() {
        let scope = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        assert!(scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("s"));
        assert_eq!(scope.sources.len(), 2);
    }

    /// An aliased table is keyed by its alias only, not by the table name.
    #[test]
    fn test_aliased_table() {
        let scope = parse_and_build_scope("SELECT x.a FROM t AS x");

        assert!(scope.sources.contains_key("x"));
        assert!(!scope.sources.contains_key("t"));
    }

    /// All five qualified column references in the query (three projected,
    /// two in the `ON` clause) are reported as local columns.
    #[test]
    fn test_local_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let local = scope.local_columns();
        assert_eq!(local.len(), 5);
        assert!(local.iter().all(|c| c.table.is_some()));
    }

    /// `columns()` includes references that only appear in a `JOIN ... ON`
    /// condition, not just the projected ones.
    #[test]
    fn test_columns_include_join_on_clause_references() {
        let mut scope = parse_and_build_scope(
            "SELECT o.total FROM orders o JOIN customers c ON c.id = o.customer_id",
        );

        // Render each reference as `table.name` (or bare `name`) for lookup.
        let cols: Vec<String> = scope
            .columns()
            .iter()
            .map(|c| match &c.table {
                Some(t) => format!("{}.{}", t, c.name),
                None => c.name.clone(),
            })
            .collect();

        assert!(cols.contains(&"o.total".to_string()));
        assert!(cols.contains(&"c.id".to_string()));
        assert!(cols.contains(&"o.customer_id".to_string()));
    }

    /// Only the two references without a table qualifier are returned.
    #[test]
    fn test_unqualified_columns() {
        let mut scope = parse_and_build_scope("SELECT a, b, t.c FROM t");

        let unqualified = scope.unqualified_columns();
        assert_eq!(unqualified.len(), 2);
        assert!(unqualified.iter().all(|c| c.table.is_none()));
    }

    /// `source_columns` returns only references qualified with the requested
    /// source name.
    #[test]
    fn test_source_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let t_cols = scope.source_columns("t");
        assert!(t_cols.len() >= 2);
        assert!(t_cols.iter().all(|c| c.table.as_deref() == Some("t")));

        let s_cols = scope.source_columns("s");
        assert!(!s_cols.is_empty());
        assert!(s_cols.iter().all(|c| c.table.as_deref() == Some("s")));
    }

    /// Renaming a source moves its entry to the new key.
    #[test]
    fn test_rename_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.rename_source("t", "new_name".to_string());
        assert!(!scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("new_name"));
    }

    /// Removing a source drops its entry from `sources`.
    #[test]
    fn test_remove_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.remove_source("t");
        assert!(!scope.sources.contains_key("t"));
    }

    /// A breadth-first walk visits the `SELECT` itself and its columns.
    #[test]
    fn test_walk_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let expr = &ast[0];

        let walked: Vec<_> = walk_in_scope(expr, true).collect();
        assert!(!walked.is_empty());

        assert!(walked.iter().any(|e| matches!(e, Expression::Select(_))));
        assert!(walked.iter().any(|e| matches!(e, Expression::Column(_))));
    }

    /// `find_in_scope` returns a match that satisfies the predicate.
    #[test]
    fn test_find_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let expr = &ast[0];

        let found = find_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
        // One assertion covers both "something was found" and "it is a column".
        assert!(matches!(found, Some(Expression::Column(_))));
    }

    /// `find_all_in_scope` returns every matching expression.
    #[test]
    fn test_find_all_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
        let expr = &ast[0];

        let found = find_all_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
        assert_eq!(found.len(), 3);
    }

    /// `traverse_scope` returns a non-empty list that contains the root scope.
    #[test]
    fn test_traverse_scope() {
        let ast =
            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
        let expr = &ast[0];

        let scopes = traverse_scope(expr);
        assert!(!scopes.is_empty());
        assert!(scopes.iter().any(|s| s.is_root()));
    }

    /// A subquery branch carries the supplied outer columns and can be
    /// correlated.
    #[test]
    fn test_branch_with_options() {
        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let scope = build_scope(&ast[0]);

        let child = scope.branch_with_options(
            ast[0].clone(),
            ScopeType::Subquery,
            None,
            None,
            Some(vec!["col1".to_string(), "col2".to_string()]),
        );

        assert_eq!(child.outer_columns, vec!["col1", "col2"]);
        assert!(child.can_be_correlated);
    }

    /// `is_udtf` is false for a fresh scope and true for a `Udtf` branch.
    #[test]
    fn test_is_udtf() {
        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let scope = Scope::new(ast[0].clone());
        assert!(!scope.is_udtf());

        let root = build_scope(&ast[0]);
        let udtf_scope = root.branch(ast[0].clone(), ScopeType::Udtf);
        assert!(udtf_scope.is_udtf());
    }

    /// A `UNION` root has two child scopes, each reporting `is_union`.
    #[test]
    fn test_is_union() {
        let scope = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(scope.is_root());
        assert_eq!(scope.union_scopes.len(), 2);
        assert!(scope.union_scopes[0].is_union());
        assert!(scope.union_scopes[1].is_union());
    }

    /// The output columns of a `UNION ALL` are the projected column names.
    #[test]
    fn test_union_output_columns() {
        let scope = parse_and_build_scope(
            "SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees",
        );
        assert_eq!(scope.output_columns(), vec!["id", "name"]);
    }

    /// `columns()` fills the cache and `clear_cache()` empties it again.
    #[test]
    fn test_clear_cache() {
        let mut scope = parse_and_build_scope("SELECT t.a FROM t");

        // Called only for its side effect of populating `columns_cache`.
        let _ = scope.columns();
        assert!(scope.columns_cache.is_some());

        scope.clear_cache();
        assert!(scope.columns_cache.is_none());
        assert!(scope.external_columns_cache.is_none());
    }

    /// Traversing a query with a CTE and a subquery yields at least three
    /// scopes.
    #[test]
    fn test_scope_traverse() {
        let scope = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        let traversed = scope.traverse();
        assert!(traversed.len() >= 3);
    }

    /// For `CREATE TABLE ... AS SELECT`, the scope's sources come from the
    /// query and never include the table being created.
    #[test]
    fn test_create_table_as_select_scope() {
        // Plain CTAS over a single table.
        let scope = parse_and_build_scope("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        assert!(
            scope.sources.contains_key("src"),
            "CTAS scope should contain the FROM table"
        );
        assert!(
            !scope.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );

        // CTAS over a join of aliased tables.
        let scope = parse_and_build_scope(
            "CREATE TABLE out_table AS SELECT a.id FROM foo AS a JOIN bar AS b ON a.id = b.id",
        );
        assert!(scope.sources.contains_key("a"));
        assert!(scope.sources.contains_key("b"));
        assert!(
            !scope.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );

        // CTAS whose query starts with a CTE.
        let scope = parse_and_build_scope(
            "CREATE TABLE out_table AS WITH cte AS (SELECT 1 AS id FROM src) SELECT * FROM cte",
        );
        assert!(
            scope.sources.contains_key("cte"),
            "CTAS with CTE should resolve CTE as source"
        );
        assert!(
            !scope.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );
        assert_eq!(scope.cte_scopes.len(), 1);
    }

    /// `traverse_scope` accepts a `CREATE TABLE ... AS SELECT` statement.
    #[test]
    fn test_create_table_as_select_traverse() {
        let ast = Parser::parse_sql("CREATE TABLE t AS SELECT a FROM src").unwrap();
        let scopes = traverse_scope(&ast[0]);
        assert!(
            !scopes.is_empty(),
            "traverse_scope should return scopes for CTAS"
        );
    }
}