1use crate::expressions::Expression;
9use crate::traversal::ExpressionWalk;
10use serde::{Deserialize, Serialize};
11use std::collections::{HashMap, HashSet, VecDeque};
12#[cfg(feature = "bindings")]
13use ts_rs::TS;
14
/// Classifies the query construct a [`Scope`] was created for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "bindings", derive(TS))]
#[cfg_attr(feature = "bindings", ts(export))]
pub enum ScopeType {
    /// Type assigned by `Scope::new`, i.e. the scope a build starts from.
    Root,
    /// Scope branched for a subquery found inside an expression
    /// (unaliased subquery, `EXISTS`, `IN`, `ANY`/`ALL`).
    Subquery,
    /// Scope branched for a query used as a relation (derived table).
    DerivedTable,
    /// Scope branched for a common table expression.
    Cte,
    /// Scope branched for one side of a `UNION` / `INTERSECT` / `EXCEPT`.
    SetOperation,
    /// Like `Subquery`, marks a branched scope as possibly correlated.
    /// Presumably "user-defined table function" — TODO confirm; the code
    /// reviewed here never constructs a scope of this type.
    Udtf,
}
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
36#[serde(rename_all = "snake_case")]
37pub enum SourceKind {
38 Root,
40 Table,
42 DerivedTable,
44 Cte,
46 Virtual,
48 Unknown,
50}
51
52impl Default for SourceKind {
53 fn default() -> Self {
54 Self::Unknown
55 }
56}
57
/// A single named source (table, derived table, CTE, ...) visible in a [`Scope`].
#[derive(Debug, Clone)]
pub struct SourceInfo {
    /// The expression that introduces the source.
    pub expression: Expression,
    /// `true` when the source is itself backed by a query scope
    /// (set for derived tables and CTEs by the scope builder).
    pub is_scope: bool,
    /// Classification of the source.
    pub kind: SourceKind,
    /// Alias the source was registered under, when one was recorded.
    pub alias: Option<String>,
    /// Synthetic name of the form `_N`, assigned to virtual sources by
    /// `Scope::add_virtual_source`; `None` for other sources.
    pub lineage_name: Option<String>,
}
72
73impl SourceInfo {
74 pub fn new(expression: Expression, is_scope: bool, kind: SourceKind) -> Self {
75 Self {
76 expression,
77 is_scope,
78 kind,
79 alias: None,
80 lineage_name: None,
81 }
82 }
83
84 pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
85 self.alias = Some(alias.into());
86 self
87 }
88
89 pub fn with_lineage_name(mut self, lineage_name: impl Into<String>) -> Self {
90 self.lineage_name = Some(lineage_name.into());
91 self
92 }
93}
94
/// A column reference found in a scope's expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table qualifier of the column, or `None` when the column is unqualified.
    pub table: Option<String>,
    /// Column name.
    pub name: String,
}
103
/// One query scope: the expression it covers, the sources visible inside it,
/// and the child scopes discovered while building it.
#[derive(Debug, Clone)]
pub struct Scope {
    /// The expression this scope was created for.
    pub expression: Expression,

    /// What kind of construct this scope represents.
    pub scope_type: ScopeType,

    /// Sources visible in this scope, keyed by the name they are referenced by.
    /// Lateral and CTE sources added through the `add_*` helpers are also
    /// inserted here.
    pub sources: HashMap<String, SourceInfo>,

    /// Sources registered through `add_lateral_source`.
    pub lateral_sources: HashMap<String, SourceInfo>,

    /// CTE sources known to this scope; copied into every scope created with
    /// `branch` / `branch_with_options`.
    pub cte_sources: HashMap<String, SourceInfo>,

    /// Column names supplied from outside the scope's query, e.g. the column
    /// alias list of an aliased derived table.
    pub outer_columns: Vec<String>,

    /// `true` when this scope is a subquery/UDTF scope or was branched from
    /// a scope that can be correlated.
    pub can_be_correlated: bool,

    /// Scopes of subqueries found inside expressions of this scope.
    pub subquery_scopes: Vec<Scope>,

    /// Scopes of queries used as relations (derived tables, pivot inputs).
    pub derived_table_scopes: Vec<Scope>,

    /// Scopes of the CTEs declared on this scope's query.
    pub cte_scopes: Vec<Scope>,

    /// NOTE(review): not populated by the builder code reviewed here —
    /// presumably reserved for table-function scopes; confirm with callers.
    pub udtf_scopes: Vec<Scope>,

    /// Child scopes visited by `traverse`.
    /// NOTE(review): not populated by the builder code reviewed here.
    pub table_scopes: Vec<Scope>,

    /// Left and right branch scopes of a set operation, pushed in that order.
    pub union_scopes: Vec<Scope>,

    /// Memoized result of `columns`; reset by `clear_cache`.
    columns_cache: Option<Vec<ColumnRef>>,

    /// Memoized result of `external_columns`; reset by `clear_cache`.
    external_columns_cache: Option<Vec<ColumnRef>>,
}
157
158impl Scope {
159 pub fn new(expression: Expression) -> Self {
161 Self {
162 expression,
163 scope_type: ScopeType::Root,
164 sources: HashMap::new(),
165 lateral_sources: HashMap::new(),
166 cte_sources: HashMap::new(),
167 outer_columns: Vec::new(),
168 can_be_correlated: false,
169 subquery_scopes: Vec::new(),
170 derived_table_scopes: Vec::new(),
171 cte_scopes: Vec::new(),
172 udtf_scopes: Vec::new(),
173 table_scopes: Vec::new(),
174 union_scopes: Vec::new(),
175 columns_cache: None,
176 external_columns_cache: None,
177 }
178 }
179
180 pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
182 self.branch_with_options(expression, scope_type, None, None, None)
183 }
184
185 pub fn branch_with_options(
187 &self,
188 expression: Expression,
189 scope_type: ScopeType,
190 sources: Option<HashMap<String, SourceInfo>>,
191 lateral_sources: Option<HashMap<String, SourceInfo>>,
192 outer_columns: Option<Vec<String>>,
193 ) -> Self {
194 let can_be_correlated = self.can_be_correlated
195 || scope_type == ScopeType::Subquery
196 || scope_type == ScopeType::Udtf;
197
198 Self {
199 expression,
200 scope_type,
201 sources: sources.unwrap_or_default(),
202 lateral_sources: lateral_sources.unwrap_or_default(),
203 cte_sources: self.cte_sources.clone(),
204 outer_columns: outer_columns.unwrap_or_default(),
205 can_be_correlated,
206 subquery_scopes: Vec::new(),
207 derived_table_scopes: Vec::new(),
208 cte_scopes: Vec::new(),
209 udtf_scopes: Vec::new(),
210 table_scopes: Vec::new(),
211 union_scopes: Vec::new(),
212 columns_cache: None,
213 external_columns_cache: None,
214 }
215 }
216
217 pub fn clear_cache(&mut self) {
219 self.columns_cache = None;
220 self.external_columns_cache = None;
221 }
222
223 pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
225 let kind = if is_scope {
226 SourceKind::DerivedTable
227 } else {
228 SourceKind::Table
229 };
230 self.add_source_info(name, SourceInfo::new(expression, is_scope, kind));
231 }
232
233 pub fn add_source_info(&mut self, name: String, info: SourceInfo) {
235 self.sources.insert(name, info);
236 self.clear_cache();
237 }
238
239 pub fn add_virtual_source(&mut self, alias: String, expression: Expression) {
241 let lineage_name = self.next_virtual_source_name();
242 let info = SourceInfo::new(expression, false, SourceKind::Virtual)
243 .with_alias(alias.clone())
244 .with_lineage_name(lineage_name);
245 self.add_source_info(alias, info);
246 }
247
248 fn next_virtual_source_name(&self) -> String {
249 let count = self
250 .sources
251 .values()
252 .filter(|source| source.kind == SourceKind::Virtual)
253 .count();
254 format!("_{}", count)
255 }
256
257 pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
259 let kind = if is_scope {
260 SourceKind::DerivedTable
261 } else {
262 SourceKind::Table
263 };
264 let info = SourceInfo::new(expression.clone(), is_scope, kind);
265 self.sources.insert(name.clone(), info.clone());
266 self.lateral_sources.insert(name, info);
267 self.clear_cache();
268 }
269
270 pub fn add_cte_source(&mut self, name: String, expression: Expression) {
272 let info = SourceInfo::new(expression, true, SourceKind::Cte);
273 self.cte_sources.insert(name.clone(), info.clone());
274 self.sources.insert(name, info);
275 self.clear_cache();
276 }
277
278 pub fn rename_source(&mut self, old_name: &str, new_name: String) {
280 if let Some(source) = self.sources.remove(old_name) {
281 self.sources.insert(new_name, source);
282 }
283 self.clear_cache();
284 }
285
286 pub fn remove_source(&mut self, name: &str) {
288 self.sources.remove(name);
289 self.clear_cache();
290 }
291
292 pub fn columns(&mut self) -> &[ColumnRef] {
294 if self.columns_cache.is_none() {
295 let mut columns = Vec::new();
296 collect_columns(&self.expression, &mut columns);
297 self.columns_cache = Some(columns);
298 }
299 self.columns_cache.as_ref().unwrap()
300 }
301
302 pub fn output_columns(&self) -> Vec<String> {
307 crate::ast_transforms::get_output_column_names(&self.expression)
308 }
309
310 pub fn source_names(&self) -> HashSet<String> {
312 let mut names: HashSet<String> = self.sources.keys().cloned().collect();
313 names.extend(self.cte_sources.keys().cloned());
314 names
315 }
316
317 pub fn external_columns(&mut self) -> Vec<ColumnRef> {
319 if self.external_columns_cache.is_some() {
320 return self.external_columns_cache.clone().unwrap();
321 }
322
323 let source_names = self.source_names();
324 let columns = self.columns().to_vec();
325
326 let external: Vec<ColumnRef> = columns
327 .into_iter()
328 .filter(|col| {
329 match &col.table {
331 Some(table) => !source_names.contains(table),
332 None => false, }
334 })
335 .collect();
336
337 self.external_columns_cache = Some(external.clone());
338 external
339 }
340
341 pub fn local_columns(&mut self) -> Vec<ColumnRef> {
343 let external_set: HashSet<_> = self.external_columns().into_iter().collect();
344 let columns = self.columns().to_vec();
345
346 columns
347 .into_iter()
348 .filter(|col| !external_set.contains(col))
349 .collect()
350 }
351
352 pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
354 self.columns()
355 .iter()
356 .filter(|c| c.table.is_none())
357 .cloned()
358 .collect()
359 }
360
361 pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
363 self.columns()
364 .iter()
365 .filter(|col| col.table.as_deref() == Some(source_name))
366 .cloned()
367 .collect()
368 }
369
370 pub fn is_correlated_subquery(&mut self) -> bool {
376 self.can_be_correlated && !self.external_columns().is_empty()
377 }
378
379 pub fn is_subquery(&self) -> bool {
381 self.scope_type == ScopeType::Subquery
382 }
383
384 pub fn is_derived_table(&self) -> bool {
386 self.scope_type == ScopeType::DerivedTable
387 }
388
389 pub fn is_cte(&self) -> bool {
391 self.scope_type == ScopeType::Cte
392 }
393
394 pub fn is_root(&self) -> bool {
396 self.scope_type == ScopeType::Root
397 }
398
399 pub fn is_udtf(&self) -> bool {
401 self.scope_type == ScopeType::Udtf
402 }
403
404 pub fn is_union(&self) -> bool {
406 self.scope_type == ScopeType::SetOperation
407 }
408
409 pub fn traverse(&self) -> Vec<&Scope> {
411 let mut result = Vec::new();
412 self.traverse_impl(&mut result);
413 result
414 }
415
416 fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
417 for scope in &self.cte_scopes {
419 scope.traverse_impl(result);
420 }
421 for scope in &self.union_scopes {
422 scope.traverse_impl(result);
423 }
424 for scope in &self.table_scopes {
425 scope.traverse_impl(result);
426 }
427 for scope in &self.subquery_scopes {
428 scope.traverse_impl(result);
429 }
430 result.push(self);
432 }
433
434 pub fn ref_count(&self) -> HashMap<usize, usize> {
436 let mut counts: HashMap<usize, usize> = HashMap::new();
437
438 for scope in self.traverse() {
439 for (_, source_info) in scope.sources.iter() {
440 if source_info.is_scope {
441 let id = &source_info.expression as *const _ as usize;
442 *counts.entry(id).or_insert(0) += 1;
443 }
444 }
445 }
446
447 counts
448 }
449}
450
451fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
453 match expr {
454 Expression::Column(col) => {
455 columns.push(ColumnRef {
456 table: col.table.as_ref().map(|t| t.name.clone()),
457 name: col.name.name.clone(),
458 });
459 }
460 Expression::Select(select) => {
461 for e in &select.expressions {
463 collect_columns(e, columns);
464 }
465 for join in &select.joins {
467 if let Some(on) = &join.on {
468 collect_columns(on, columns);
469 }
470 if let Some(match_condition) = &join.match_condition {
471 collect_columns(match_condition, columns);
472 }
473 }
474 if let Some(where_clause) = &select.where_clause {
476 collect_columns(&where_clause.this, columns);
477 }
478 if let Some(having) = &select.having {
480 collect_columns(&having.this, columns);
481 }
482 if let Some(order_by) = &select.order_by {
484 for ord in &order_by.expressions {
485 collect_columns(&ord.this, columns);
486 }
487 }
488 if let Some(group_by) = &select.group_by {
490 for e in &group_by.expressions {
491 collect_columns(e, columns);
492 }
493 }
494 }
497 Expression::And(bin)
499 | Expression::Or(bin)
500 | Expression::Add(bin)
501 | Expression::Sub(bin)
502 | Expression::Mul(bin)
503 | Expression::Div(bin)
504 | Expression::Mod(bin)
505 | Expression::Eq(bin)
506 | Expression::Neq(bin)
507 | Expression::Lt(bin)
508 | Expression::Lte(bin)
509 | Expression::Gt(bin)
510 | Expression::Gte(bin)
511 | Expression::BitwiseAnd(bin)
512 | Expression::BitwiseOr(bin)
513 | Expression::BitwiseXor(bin)
514 | Expression::Concat(bin) => {
515 collect_columns(&bin.left, columns);
516 collect_columns(&bin.right, columns);
517 }
518 Expression::Like(like) | Expression::ILike(like) => {
520 collect_columns(&like.left, columns);
521 collect_columns(&like.right, columns);
522 if let Some(escape) = &like.escape {
523 collect_columns(escape, columns);
524 }
525 }
526 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
528 collect_columns(&un.this, columns);
529 }
530 Expression::Function(func) => {
531 for arg in &func.args {
532 collect_columns(arg, columns);
533 }
534 }
535 Expression::AggregateFunction(agg) => {
536 for arg in &agg.args {
537 collect_columns(arg, columns);
538 }
539 }
540 Expression::WindowFunction(wf) => {
541 collect_columns(&wf.this, columns);
542 for e in &wf.over.partition_by {
543 collect_columns(e, columns);
544 }
545 for e in &wf.over.order_by {
546 collect_columns(&e.this, columns);
547 }
548 }
549 Expression::Alias(alias) => {
550 collect_columns(&alias.this, columns);
551 }
552 Expression::Case(case) => {
553 if let Some(operand) = &case.operand {
554 collect_columns(operand, columns);
555 }
556 for (when_expr, then_expr) in &case.whens {
557 collect_columns(when_expr, columns);
558 collect_columns(then_expr, columns);
559 }
560 if let Some(else_clause) = &case.else_ {
561 collect_columns(else_clause, columns);
562 }
563 }
564 Expression::Paren(paren) => {
565 collect_columns(&paren.this, columns);
566 }
567 Expression::Ordered(ord) => {
568 collect_columns(&ord.this, columns);
569 }
570 Expression::In(in_expr) => {
571 collect_columns(&in_expr.this, columns);
572 for e in &in_expr.expressions {
573 collect_columns(e, columns);
574 }
575 }
577 Expression::Between(between) => {
578 collect_columns(&between.this, columns);
579 collect_columns(&between.low, columns);
580 collect_columns(&between.high, columns);
581 }
582 Expression::IsNull(is_null) => {
583 collect_columns(&is_null.this, columns);
584 }
585 Expression::Cast(cast) => {
586 collect_columns(&cast.this, columns);
587 }
588 Expression::Extract(extract) => {
589 collect_columns(&extract.this, columns);
590 }
591 Expression::Exists(_) | Expression::Subquery(_) => {
592 }
594 Expression::Prepare(prepare) => {
595 collect_columns(&prepare.statement, columns);
596 }
597 _ => {
598 }
600 }
601}
602
603pub fn build_scope(expression: &Expression) -> Scope {
608 let mut root = Scope::new(expression.clone());
609 build_scope_impl(expression, &mut root);
610 root
611}
612
613fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
614 match expression {
615 Expression::Prepare(prepare) => {
616 build_scope_impl(&prepare.statement, current_scope);
617 }
618 Expression::Select(select) => {
619 if let Some(with) = &select.with {
621 process_ctes(with, current_scope);
622 }
623
624 if let Some(from) = &select.from {
626 for table in &from.expressions {
627 add_table_to_scope(table, current_scope);
628 }
629 }
630
631 for join in &select.joins {
633 add_table_to_scope(&join.this, current_scope);
634 }
635
636 for lateral_view in &select.lateral_views {
638 add_lateral_view_to_scope(lateral_view, current_scope);
639 }
640
641 collect_subqueries(expression, current_scope);
643 }
644 Expression::Union(union) => {
645 if let Some(with) = &union.with {
646 process_ctes(with, current_scope);
647 }
648
649 let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
650 build_scope_impl(&union.left, &mut left_scope);
651
652 let mut right_scope =
653 current_scope.branch(union.right.clone(), ScopeType::SetOperation);
654 build_scope_impl(&union.right, &mut right_scope);
655
656 current_scope.union_scopes.push(left_scope);
657 current_scope.union_scopes.push(right_scope);
658 }
659 Expression::Intersect(intersect) => {
660 if let Some(with) = &intersect.with {
661 process_ctes(with, current_scope);
662 }
663
664 let mut left_scope =
665 current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
666 build_scope_impl(&intersect.left, &mut left_scope);
667
668 let mut right_scope =
669 current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
670 build_scope_impl(&intersect.right, &mut right_scope);
671
672 current_scope.union_scopes.push(left_scope);
673 current_scope.union_scopes.push(right_scope);
674 }
675 Expression::Except(except) => {
676 if let Some(with) = &except.with {
677 process_ctes(with, current_scope);
678 }
679
680 let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
681 build_scope_impl(&except.left, &mut left_scope);
682
683 let mut right_scope =
684 current_scope.branch(except.right.clone(), ScopeType::SetOperation);
685 build_scope_impl(&except.right, &mut right_scope);
686
687 current_scope.union_scopes.push(left_scope);
688 current_scope.union_scopes.push(right_scope);
689 }
690 Expression::CreateTable(create) => {
691 if let Some(with) = &create.with_cte {
694 process_ctes(with, current_scope);
695 }
696 if let Some(as_select) = &create.as_select {
698 build_scope_impl(as_select, current_scope);
699 }
700 }
701 Expression::Subquery(subquery) => {
702 build_scope_impl(&subquery.this, current_scope);
703 }
704 Expression::Paren(paren) => {
705 build_scope_impl(&paren.this, current_scope);
706 }
707 _ => {}
708 }
709}
710
711fn process_ctes(with: &crate::expressions::With, current_scope: &mut Scope) {
712 for cte in &with.ctes {
713 let cte_name = cte.alias.name.clone();
714 let cte_expr = Expression::Cte(Box::new(cte.clone()));
715 let mut cte_scope = current_scope.branch(cte_expr.clone(), ScopeType::Cte);
716
717 if with.recursive && cte_body_self_references(cte) {
718 cte_scope.add_cte_source(cte_name.clone(), cte_expr.clone());
719 }
720
721 build_scope_impl(&cte.this, &mut cte_scope);
722 current_scope.add_cte_source(cte_name, cte_expr);
723 current_scope.cte_scopes.push(cte_scope);
724 }
725}
726
727fn cte_body_self_references(cte: &crate::expressions::Cte) -> bool {
728 let cte_name = cte.alias.name.as_str();
729 !cte.this
730 .find_all(|expr| match expr {
731 Expression::Table(table) if table.schema.is_none() && table.catalog.is_none() => {
732 table.name.name.eq_ignore_ascii_case(cte_name)
733 }
734 _ => false,
735 })
736 .is_empty()
737}
738
739fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
740 match expr {
741 Expression::Table(table) => {
742 let name = table
743 .alias
744 .as_ref()
745 .map(|a| a.name.clone())
746 .unwrap_or_else(|| table.name.name.clone());
747 let cte_source = if table.schema.is_none() && table.catalog.is_none() {
748 scope.cte_sources.get(&table.name.name).or_else(|| {
749 scope
750 .cte_sources
751 .iter()
752 .find(|(cte_name, _)| cte_name.eq_ignore_ascii_case(&table.name.name))
753 .map(|(_, source)| source)
754 })
755 } else {
756 None
757 };
758
759 if let Some(source) = cte_source {
760 scope.add_source_info(name, source.clone());
761 } else {
762 let mut source = SourceInfo::new(expr.clone(), false, SourceKind::Table);
763 if let Some(alias) = &table.alias {
764 source = source.with_alias(alias.name.clone());
765 }
766 scope.add_source_info(name, source);
767 }
768 }
769 Expression::Subquery(subquery) => {
770 let name = subquery
771 .alias
772 .as_ref()
773 .map(|a| a.name.clone())
774 .unwrap_or_default();
775
776 let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
777 build_scope_impl(&subquery.this, &mut derived_scope);
778
779 scope.add_source(name.clone(), expr.clone(), true);
780 scope.derived_table_scopes.push(derived_scope);
781 }
782 Expression::Unnest(unnest) => {
783 if let Some(alias) = &unnest.alias {
784 scope.add_virtual_source(alias.name.clone(), expr.clone());
785 }
786 }
787 Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
788 scope.add_virtual_source(alias.alias.name.clone(), expr.clone());
789 }
790 Expression::Alias(alias) if is_query_like_relation(&alias.this) => {
791 let outer_columns = alias
792 .column_aliases
793 .iter()
794 .map(|column| column.name.clone())
795 .collect::<Vec<_>>();
796 let mut derived_scope = scope.branch_with_options(
797 alias.this.clone(),
798 ScopeType::DerivedTable,
799 None,
800 None,
801 Some(outer_columns),
802 );
803 build_scope_impl(&alias.this, &mut derived_scope);
804
805 scope.add_source(alias.alias.name.clone(), expr.clone(), true);
806 scope.derived_table_scopes.push(derived_scope);
807 }
808 Expression::Lateral(lateral) => {
809 if let Some(alias) = &lateral.alias {
810 scope.add_virtual_source(alias.clone(), expr.clone());
811 }
812 }
813 Expression::LateralView(lateral_view) => {
814 add_lateral_view_to_scope(lateral_view, scope);
815 }
816 Expression::Pivot(pivot) => {
817 let name =
818 pivot_source_name(&pivot.this, pivot.alias.as_ref().map(|a| a.name.as_str()));
819 scope.add_source_info(
820 name,
821 SourceInfo::new(expr.clone(), false, SourceKind::DerivedTable),
822 );
823 add_pivot_inner_scope(&pivot.this, scope);
824 }
825 Expression::Unpivot(unpivot) => {
826 let name = pivot_source_name(
827 &unpivot.this,
828 unpivot.alias.as_ref().map(|a| a.name.as_str()),
829 );
830 scope.add_source_info(
831 name,
832 SourceInfo::new(expr.clone(), false, SourceKind::DerivedTable),
833 );
834 add_pivot_inner_scope(&unpivot.this, scope);
835 }
836 Expression::Paren(paren) => {
837 add_table_to_scope(&paren.this, scope);
838 }
839 _ => {}
840 }
841}
842
843fn is_query_like_relation(expr: &Expression) -> bool {
844 match expr {
845 Expression::Select(_)
846 | Expression::Subquery(_)
847 | Expression::Union(_)
848 | Expression::Intersect(_)
849 | Expression::Except(_) => true,
850 Expression::Paren(paren) => is_query_like_relation(&paren.this),
851 _ => false,
852 }
853}
854
855fn pivot_source_name(source: &Expression, explicit_alias: Option<&str>) -> String {
856 if let Some(alias) = explicit_alias {
857 return alias.to_string();
858 }
859
860 match source {
861 Expression::Table(table) => table
862 .alias
863 .as_ref()
864 .map(|alias| alias.name.clone())
865 .unwrap_or_else(|| table.name.name.clone()),
866 Expression::Subquery(subquery) => subquery
867 .alias
868 .as_ref()
869 .map(|alias| alias.name.clone())
870 .unwrap_or_else(|| "_0".to_string()),
871 Expression::Paren(paren) => pivot_source_name(&paren.this, explicit_alias),
872 _ => "_0".to_string(),
873 }
874}
875
876fn add_pivot_inner_scope(source: &Expression, scope: &mut Scope) {
877 match source {
878 Expression::Subquery(subquery) => {
879 let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
880 build_scope_impl(&subquery.this, &mut derived_scope);
881 scope.derived_table_scopes.push(derived_scope);
882 }
883 Expression::Paren(paren) => add_pivot_inner_scope(&paren.this, scope),
884 _ => {}
885 }
886}
887
888fn add_lateral_view_to_scope(lateral_view: &crate::expressions::LateralView, scope: &mut Scope) {
889 let alias = lateral_view
890 .table_alias
891 .as_ref()
892 .or_else(|| lateral_view.column_aliases.first())
893 .map(|alias| alias.name.clone());
894
895 if let Some(alias) = alias {
896 scope.add_virtual_source(
897 alias,
898 Expression::LateralView(Box::new(lateral_view.clone())),
899 );
900 }
901}
902
903fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
904 match expr {
905 Expression::Select(select) => {
906 if let Some(where_clause) = &select.where_clause {
908 collect_subqueries_in_expr(&where_clause.this, parent_scope);
909 }
910 for e in &select.expressions {
912 collect_subqueries_in_expr(e, parent_scope);
913 }
914 if let Some(having) = &select.having {
916 collect_subqueries_in_expr(&having.this, parent_scope);
917 }
918 }
919 _ => {}
920 }
921}
922
923fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
924 let mut seen = HashSet::new();
925 for node in walk_in_scope(expr, false) {
926 let query = match node {
927 Expression::Subquery(subquery) if subquery.alias.is_none() => Some(&subquery.this),
928 Expression::Exists(exists) => Some(&exists.this),
929 Expression::In(in_expr) => in_expr.query.as_ref(),
930 Expression::Any(quantified) | Expression::All(quantified) => Some(&quantified.subquery),
931 _ => None,
932 };
933
934 let Some(query) = query else {
935 continue;
936 };
937
938 let key = query as *const Expression as usize;
939 if !seen.insert(key) {
940 continue;
941 }
942
943 let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
944 build_scope_impl(query, &mut sub_scope);
945 parent_scope.subquery_scopes.push(sub_scope);
946 }
947}
948
/// Returns an iterator over `expression` and the expressions reachable from
/// it without crossing into a nested query scope.
///
/// `expression` itself is always yielded first. With `bfs = true` nodes are
/// visited breadth-first; otherwise depth-first, parents before children.
/// Only the expression kinds handled by the walker's child enumeration are
/// descended into.
pub fn walk_in_scope<'a>(
    expression: &'a Expression,
    bfs: bool,
) -> impl Iterator<Item = &'a Expression> {
    WalkInScopeIter::new(expression, bfs)
}
966
/// Iterator state behind [`walk_in_scope`].
struct WalkInScopeIter<'a> {
    /// Expressions waiting to be yielded; consumed from the front in
    /// breadth-first mode and from the back in depth-first mode.
    queue: VecDeque<&'a Expression>,
    /// `true` for breadth-first traversal, `false` for depth-first.
    bfs: bool,
}
972
impl<'a> WalkInScopeIter<'a> {
    /// Creates a walker whose queue initially holds only `expression`.
    fn new(expression: &'a Expression, bfs: bool) -> Self {
        let mut queue = VecDeque::new();
        queue.push_back(expression);
        Self { queue, bfs }
    }

    /// Returns `true` when `expr` starts a new scope and must not be entered:
    /// a CTE, an aliased subquery, or a `SELECT` / set operation.
    ///
    /// `is_root` disables the check for the expression the walk started from.
    fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
        if is_root {
            return false;
        }

        // CTEs are separate scopes.
        if matches!(expr, Expression::Cte(_)) {
            return true;
        }

        // An aliased subquery is a derived table, i.e. its own scope.
        // Unaliased subqueries are still yielded by the walk.
        if let Expression::Subquery(subquery) = expr {
            if subquery.alias.is_some() {
                return true;
            }
        }

        // A nested query statement is a new scope.
        if matches!(
            expr,
            Expression::Select(_)
                | Expression::Union(_)
                | Expression::Intersect(_)
                | Expression::Except(_)
        ) {
            return true;
        }

        false
    }

    /// Returns the direct children of `expr` that the walk considers.
    ///
    /// Expression kinds without a dedicated arm yield no children, so the
    /// walk does not look inside them.
    fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
        let mut children = Vec::new();

        match expr {
            Expression::Prepare(prepare) => {
                children.push(&prepare.statement);
            }
            Expression::Select(select) => {
                // Projection.
                for e in &select.expressions {
                    children.push(e);
                }
                // FROM items, except those that open a new scope.
                if let Some(from) = &select.from {
                    for table in &from.expressions {
                        if !self.should_stop_at(table, false) {
                            children.push(table);
                        }
                    }
                }
                // Join conditions; only `on` is followed here.
                for join in &select.joins {
                    if let Some(on) = &join.on {
                        children.push(on);
                    }
                }
                // WHERE.
                if let Some(where_clause) = &select.where_clause {
                    children.push(&where_clause.this);
                }
                // GROUP BY.
                if let Some(group_by) = &select.group_by {
                    for e in &group_by.expressions {
                        children.push(e);
                    }
                }
                // HAVING.
                if let Some(having) = &select.having {
                    children.push(&having.this);
                }
                // ORDER BY.
                if let Some(order_by) = &select.order_by {
                    for ord in &order_by.expressions {
                        children.push(&ord.this);
                    }
                }
                // LIMIT.
                if let Some(limit) = &select.limit {
                    children.push(&limit.this);
                }
                // OFFSET.
                if let Some(offset) = &select.offset {
                    children.push(&offset.this);
                }
            }
            Expression::And(bin)
            | Expression::Or(bin)
            | Expression::Add(bin)
            | Expression::Sub(bin)
            | Expression::Mul(bin)
            | Expression::Div(bin)
            | Expression::Mod(bin)
            | Expression::Eq(bin)
            | Expression::Neq(bin)
            | Expression::Lt(bin)
            | Expression::Lte(bin)
            | Expression::Gt(bin)
            | Expression::Gte(bin)
            | Expression::BitwiseAnd(bin)
            | Expression::BitwiseOr(bin)
            | Expression::BitwiseXor(bin)
            | Expression::Concat(bin) => {
                children.push(&bin.left);
                children.push(&bin.right);
            }
            Expression::Like(like) | Expression::ILike(like) => {
                children.push(&like.left);
                children.push(&like.right);
                if let Some(escape) = &like.escape {
                    children.push(escape);
                }
            }
            Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
                children.push(&un.this);
            }
            Expression::Function(func) => {
                for arg in &func.args {
                    children.push(arg);
                }
            }
            Expression::AggregateFunction(agg) => {
                for arg in &agg.args {
                    children.push(arg);
                }
            }
            Expression::WindowFunction(wf) => {
                children.push(&wf.this);
                for e in &wf.over.partition_by {
                    children.push(e);
                }
                for e in &wf.over.order_by {
                    children.push(&e.this);
                }
            }
            Expression::Alias(alias) => {
                children.push(&alias.this);
            }
            Expression::Case(case) => {
                if let Some(operand) = &case.operand {
                    children.push(operand);
                }
                for (when_expr, then_expr) in &case.whens {
                    children.push(when_expr);
                    children.push(then_expr);
                }
                if let Some(else_clause) = &case.else_ {
                    children.push(else_clause);
                }
            }
            Expression::Paren(paren) => {
                children.push(&paren.this);
            }
            Expression::Ordered(ord) => {
                children.push(&ord.this);
            }
            Expression::In(in_expr) => {
                // The tested value and the literal list only; a query on the
                // right-hand side is not a child of this walk.
                children.push(&in_expr.this);
                for e in &in_expr.expressions {
                    children.push(e);
                }
            }
            Expression::Between(between) => {
                children.push(&between.this);
                children.push(&between.low);
                children.push(&between.high);
            }
            Expression::IsNull(is_null) => {
                children.push(&is_null.this);
            }
            Expression::Cast(cast) => {
                children.push(&cast.this);
            }
            Expression::Extract(extract) => {
                children.push(&extract.this);
            }
            Expression::Coalesce(coalesce) => {
                for e in &coalesce.expressions {
                    children.push(e);
                }
            }
            Expression::NullIf(nullif) => {
                children.push(&nullif.this);
                children.push(&nullif.expression);
            }
            Expression::Table(_table) => {
                // Table references are leaves for this walk.
            }
            Expression::TryCatch(try_catch) => {
                for stmt in &try_catch.try_body {
                    children.push(stmt);
                }
                if let Some(catch_body) = &try_catch.catch_body {
                    for stmt in catch_body {
                        children.push(stmt);
                    }
                }
            }
            Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
                // Leaf nodes.
            }
            // Nested queries are yielded by the walk but never entered.
            Expression::Subquery(_) | Expression::Exists(_) => {}
            _ => {
                // Every other expression kind is treated as a leaf.
            }
        }

        children
    }
}
1194
1195impl<'a> Iterator for WalkInScopeIter<'a> {
1196 type Item = &'a Expression;
1197
1198 fn next(&mut self) -> Option<Self::Item> {
1199 let expr = if self.bfs {
1200 self.queue.pop_front()?
1201 } else {
1202 self.queue.pop_back()?
1203 };
1204
1205 let children = self.get_children(expr);
1207
1208 if self.bfs {
1209 for child in children {
1210 if !self.should_stop_at(child, false) {
1211 self.queue.push_back(child);
1212 }
1213 }
1214 } else {
1215 for child in children.into_iter().rev() {
1216 if !self.should_stop_at(child, false) {
1217 self.queue.push_back(child);
1218 }
1219 }
1220 }
1221
1222 Some(expr)
1223 }
1224}
1225
1226pub fn find_in_scope<'a, F>(
1238 expression: &'a Expression,
1239 predicate: F,
1240 bfs: bool,
1241) -> Option<&'a Expression>
1242where
1243 F: Fn(&Expression) -> bool,
1244{
1245 walk_in_scope(expression, bfs).find(|e| predicate(e))
1246}
1247
1248pub fn find_all_in_scope<'a, F>(
1260 expression: &'a Expression,
1261 predicate: F,
1262 bfs: bool,
1263) -> Vec<&'a Expression>
1264where
1265 F: Fn(&Expression) -> bool,
1266{
1267 walk_in_scope(expression, bfs)
1268 .filter(|e| predicate(e))
1269 .collect()
1270}
1271
1272pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1282 match expression {
1283 Expression::Select(_)
1284 | Expression::Union(_)
1285 | Expression::Intersect(_)
1286 | Expression::Except(_)
1287 | Expression::Prepare(_)
1288 | Expression::CreateTable(_) => {
1289 let root = build_scope(expression);
1290 root.traverse().into_iter().cloned().collect()
1291 }
1292 _ => Vec::new(),
1293 }
1294}
1295
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and returns its first statement.
    ///
    /// Panics with a descriptive message when parsing fails or yields no statements, so a broken
    /// fixture is reported as such instead of as an index-out-of-bounds error.
    fn parse_statement(sql: &str) -> Expression {
        Parser::parse_sql(sql)
            .expect("Failed to parse SQL")
            .into_iter()
            .next()
            .expect("SQL produced no statements")
    }

    /// Parses `sql` and builds the scope tree for its first statement.
    fn parse_and_build_scope(sql: &str) -> Scope {
        build_scope(&parse_statement(sql))
    }

    /// A plain `SELECT` produces a root scope that cannot be correlated, exposes its table as a
    /// source, and reports both projected column references.
    #[test]
    fn test_simple_select_scope() {
        let mut scope = parse_and_build_scope("SELECT a, b FROM t");

        assert!(scope.is_root());
        assert!(!scope.can_be_correlated);
        assert!(scope.sources.contains_key("t"));

        let columns = scope.columns();
        assert_eq!(columns.len(), 2);
    }

    /// A subquery in `FROM` is registered under its alias and gets its own derived-table scope.
    #[test]
    fn test_derived_table_scope() {
        let mut scope = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");

        assert!(scope.sources.contains_key("x"));
        assert_eq!(scope.derived_table_scopes.len(), 1);

        let derived = &mut scope.derived_table_scopes[0];
        assert!(derived.is_derived_table());
        assert!(derived.sources.contains_key("t"));
    }

    /// An `EXISTS` subquery that only references its own table may be correlated in principle
    /// but is not reported as a correlated subquery.
    #[test]
    fn test_non_correlated_subquery() {
        let mut scope = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        assert!(subquery.sources.contains_key("s"));
        // The inner query never mentions the outer table `t`.
        assert!(!subquery.is_correlated_subquery());
    }

    /// A subquery that references the outer table `t` reports that reference as an external
    /// column and is flagged as correlated.
    #[test]
    fn test_correlated_subquery() {
        let mut scope =
            parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        let external = subquery.external_columns();
        // `t.y` is the reference that comes from outside the subquery.
        assert!(!external.is_empty());
        assert!(external.iter().any(|c| c.table.as_deref() == Some("t")));
        assert!(subquery.is_correlated_subquery());
    }

    /// A `WITH` clause yields one CTE scope and registers the CTE name in `cte_sources`.
    #[test]
    fn test_cte_scope() {
        let scope = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");

        assert_eq!(scope.cte_scopes.len(), 1);
        assert!(scope.cte_sources.contains_key("cte"));

        let cte = &scope.cte_scopes[0];
        assert!(cte.is_cte());
    }

    /// Both sides of a `JOIN` are registered as sources, and nothing else is.
    #[test]
    fn test_multiple_sources() {
        let scope = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        assert!(scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("s"));
        assert_eq!(scope.sources.len(), 2);
    }

    /// An aliased table is keyed by its alias, not by the underlying table name.
    #[test]
    fn test_aliased_table() {
        let scope = parse_and_build_scope("SELECT x.a FROM t AS x");

        assert!(scope.sources.contains_key("x"));
        assert!(!scope.sources.contains_key("t"));
    }

    /// `local_columns` returns the five qualified references in the query: three in the
    /// projection and two in the `ON` condition.
    #[test]
    fn test_local_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let local = scope.local_columns();
        assert_eq!(local.len(), 5);
        assert!(local.iter().all(|c| c.table.is_some()));
    }

    /// `columns` includes references that appear only in a `JOIN ... ON` condition.
    #[test]
    fn test_columns_include_join_on_clause_references() {
        let mut scope = parse_and_build_scope(
            "SELECT o.total FROM orders o JOIN customers c ON c.id = o.customer_id",
        );

        // Render each reference as `table.name` (or bare `name`) for easy membership checks.
        let cols: Vec<String> = scope
            .columns()
            .iter()
            .map(|c| match &c.table {
                Some(t) => format!("{}.{}", t, c.name),
                None => c.name.clone(),
            })
            .collect();

        assert!(cols.contains(&"o.total".to_string()));
        assert!(cols.contains(&"c.id".to_string()));
        assert!(cols.contains(&"o.customer_id".to_string()));
    }

    /// `unqualified_columns` returns only references without a table qualifier (`a` and `b`).
    #[test]
    fn test_unqualified_columns() {
        let mut scope = parse_and_build_scope("SELECT a, b, t.c FROM t");

        let unqualified = scope.unqualified_columns();
        assert_eq!(unqualified.len(), 2);
        assert!(unqualified.iter().all(|c| c.table.is_none()));
    }

    /// `source_columns` returns only the references qualified with the requested source name.
    #[test]
    fn test_source_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let t_cols = scope.source_columns("t");
        assert!(t_cols.len() >= 2);
        assert!(t_cols.iter().all(|c| c.table.as_deref() == Some("t")));

        let s_cols = scope.source_columns("s");
        assert!(!s_cols.is_empty());
        assert!(s_cols.iter().all(|c| c.table.as_deref() == Some("s")));
    }

    /// `rename_source` moves a source from its old key to the new one.
    #[test]
    fn test_rename_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.rename_source("t", "new_name".to_string());
        assert!(!scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("new_name"));
    }

    /// `remove_source` drops the named source from the scope.
    #[test]
    fn test_remove_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.remove_source("t");
        assert!(!scope.sources.contains_key("t"));
    }

    /// `walk_in_scope` visits the `SELECT` node itself as well as the column nodes beneath it.
    #[test]
    fn test_walk_in_scope() {
        let expr = parse_statement("SELECT a, b FROM t WHERE a > 1");

        let walked: Vec<_> = walk_in_scope(&expr, true).collect();
        assert!(!walked.is_empty());

        assert!(walked.iter().any(|e| matches!(e, Expression::Select(_))));
        assert!(walked.iter().any(|e| matches!(e, Expression::Column(_))));
    }

    /// `find_in_scope` returns a node that satisfies the predicate.
    #[test]
    fn test_find_in_scope() {
        let expr = parse_statement("SELECT a, b FROM t WHERE a > 1");

        let found = find_in_scope(&expr, |e| matches!(e, Expression::Column(_)), true);
        assert!(found.is_some());
        assert!(matches!(found.unwrap(), Expression::Column(_)));
    }

    /// `find_all_in_scope` returns every matching node: one per projected column here.
    #[test]
    fn test_find_all_in_scope() {
        let expr = parse_statement("SELECT a, b, c FROM t");

        let found = find_all_in_scope(&expr, |e| matches!(e, Expression::Column(_)), true);
        assert_eq!(found.len(), 3);
    }

    /// `traverse_scope` returns a non-empty list of scopes that includes the root.
    #[test]
    fn test_traverse_scope() {
        let expr = parse_statement("SELECT a FROM (SELECT b FROM t) AS x");

        let scopes = traverse_scope(&expr);
        assert!(!scopes.is_empty());
        assert!(scopes.iter().any(|s| s.is_root()));
    }

    /// `branch_with_options` propagates the supplied outer columns, and a `Subquery` branch is
    /// marked as able to be correlated.
    #[test]
    fn test_branch_with_options() {
        let expr = parse_statement("SELECT a FROM t");
        let scope = build_scope(&expr);

        let child = scope.branch_with_options(
            expr,
            ScopeType::Subquery,
            None,
            None,
            Some(vec!["col1".to_string(), "col2".to_string()]),
        );

        assert_eq!(child.outer_columns, vec!["col1", "col2"]);
        assert!(child.can_be_correlated);
    }

    /// `is_udtf` is false for a freshly constructed scope and true for a `Udtf` branch.
    #[test]
    fn test_is_udtf() {
        let expr = parse_statement("SELECT a FROM t");
        let scope = Scope::new(expr.clone());
        assert!(!scope.is_udtf());

        let root = build_scope(&expr);
        let udtf_scope = root.branch(expr, ScopeType::Udtf);
        assert!(udtf_scope.is_udtf());
    }

    /// A `UNION` produces a root scope with one union scope per operand.
    #[test]
    fn test_is_union() {
        let scope = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(scope.is_root());
        assert_eq!(scope.union_scopes.len(), 2);
        assert!(scope.union_scopes[0].is_union());
        assert!(scope.union_scopes[1].is_union());
    }

    /// The output columns of a set operation are the projected names `id` and `name`.
    #[test]
    fn test_union_output_columns() {
        let scope = parse_and_build_scope(
            "SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees",
        );
        assert_eq!(scope.output_columns(), vec!["id", "name"]);
    }

    /// `columns` populates the column cache and `clear_cache` resets the cached values.
    #[test]
    fn test_clear_cache() {
        let mut scope = parse_and_build_scope("SELECT t.a FROM t");

        // Called only for its side effect of filling the cache.
        let _ = scope.columns();
        assert!(scope.columns_cache.is_some());

        scope.clear_cache();
        assert!(scope.columns_cache.is_none());
        assert!(scope.external_columns_cache.is_none());
    }

    /// `Scope::traverse` yields at least three scopes for a query with a CTE and an `EXISTS`
    /// subquery.
    #[test]
    fn test_scope_traverse() {
        let scope = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        let traversed = scope.traverse();
        assert!(traversed.len() >= 3);
    }

    /// For `CREATE TABLE ... AS SELECT`, the scope is built from the query part: its tables
    /// and CTEs are sources, while the table being created is not.
    #[test]
    fn test_create_table_as_select_scope() {
        // Single source table.
        let scope = parse_and_build_scope("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        assert!(
            scope.sources.contains_key("src"),
            "CTAS scope should contain the FROM table"
        );
        assert!(
            !scope.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );

        // Aliased tables joined together.
        let scope = parse_and_build_scope(
            "CREATE TABLE out_table AS SELECT a.id FROM foo AS a JOIN bar AS b ON a.id = b.id",
        );
        assert!(scope.sources.contains_key("a"));
        assert!(scope.sources.contains_key("b"));
        assert!(
            !scope.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );

        // Query that reads from a CTE.
        let scope = parse_and_build_scope(
            "CREATE TABLE out_table AS WITH cte AS (SELECT 1 AS id FROM src) SELECT * FROM cte",
        );
        assert!(
            scope.sources.contains_key("cte"),
            "CTAS with CTE should resolve CTE as source"
        );
        assert!(
            !scope.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );
        assert_eq!(scope.cte_scopes.len(), 1);
    }

    /// `traverse_scope` accepts a `CREATE TABLE ... AS SELECT` statement and returns scopes.
    #[test]
    fn test_create_table_as_select_traverse() {
        let expr = parse_statement("CREATE TABLE t AS SELECT a FROM src");
        let scopes = traverse_scope(&expr);
        assert!(
            !scopes.is_empty(),
            "traverse_scope should return scopes for CTAS"
        );
    }
}