1use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
/// Kind of query scope a [`Scope`] represents.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "bindings", derive(TS))]
#[cfg_attr(feature = "bindings", ts(export))]
pub enum ScopeType {
    /// Top-level scope created by `Scope::new` (and therefore by `build_scope`).
    Root,
    /// Scope of an unaliased subquery found inside an expression (e.g. `EXISTS (...)`,
    /// `x IN (SELECT ...)`). Such scopes are marked `can_be_correlated`.
    Subquery,
    /// Scope of a subquery used as a table source in `FROM` / `JOIN`.
    DerivedTable,
    /// Scope of a common table expression defined in a `WITH` clause.
    Cte,
    /// Scope of one operand of a `UNION`, `INTERSECT` or `EXCEPT`.
    SetOperation,
    /// Scope of a user-defined table function. Such scopes are marked `can_be_correlated`.
    /// NOTE(review): `build_scope` does not create this kind itself; confirm where it is produced.
    Udtf,
}
32
/// A named source registered in a [`Scope`].
#[derive(Debug, Clone)]
pub struct SourceInfo {
    /// The expression the source resolves to: the table/unnest expression itself, the
    /// derived-table subquery, or the `Expression::Cte` for CTE references.
    pub expression: Expression,
    /// `true` when the source has its own child scope (derived tables and CTEs),
    /// `false` for plain tables and `UNNEST` sources.
    pub is_scope: bool,
}
41
/// A column reference found in a scope, e.g. `t.a` or `a`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table qualifier as written, or `None` for an unqualified column.
    pub table: Option<String>,
    /// Column name.
    pub name: String,
}
50
/// A query scope: one query node, the sources it selects from, and the child scopes
/// nested inside it. Built by `build_scope`.
#[derive(Debug, Clone)]
pub struct Scope {
    /// The expression this scope was created for. For CTE scopes this is the
    /// `Expression::Cte` wrapper; for derived-table, subquery and set-operation scopes it
    /// is the inner query expression.
    pub expression: Expression,

    /// Kind of scope.
    pub scope_type: ScopeType,

    /// Sources visible in this scope, keyed by alias (or by table name when unaliased).
    pub sources: HashMap<String, SourceInfo>,

    /// Sources registered through `add_lateral_source` (they are also in `sources`).
    pub lateral_sources: HashMap<String, SourceInfo>,

    /// CTEs resolvable by name: those defined in this scope plus copies inherited from
    /// the parent when this scope was branched.
    pub cte_sources: HashMap<String, SourceInfo>,

    /// Outer column names; empty unless supplied through `branch_with_options`.
    pub outer_columns: Vec<String>,

    /// Whether this scope may reference columns of an enclosing scope. Set for
    /// `Subquery` and `Udtf` scopes and inherited by everything branched from them.
    pub can_be_correlated: bool,

    /// Child scopes for unaliased subqueries found inside expressions.
    pub subquery_scopes: Vec<Scope>,

    /// Child scopes for subqueries used as `FROM` / `JOIN` sources.
    pub derived_table_scopes: Vec<Scope>,

    /// Child scopes for the CTEs defined by this scope's `WITH` clause.
    pub cte_scopes: Vec<Scope>,

    /// Child scopes for user-defined table functions.
    pub udtf_scopes: Vec<Scope>,

    /// NOTE(review): not populated by `build_scope`; presumably filled by other code — confirm.
    pub table_scopes: Vec<Scope>,

    /// Child scopes for the operands of a set operation (left operand pushed first).
    pub union_scopes: Vec<Scope>,

    /// Cached result of `Scope::columns`; reset by `clear_cache`.
    columns_cache: Option<Vec<ColumnRef>>,

    /// Cached result of `Scope::external_columns`; reset by `clear_cache`.
    external_columns_cache: Option<Vec<ColumnRef>>,
}
104
105impl Scope {
106 pub fn new(expression: Expression) -> Self {
108 Self {
109 expression,
110 scope_type: ScopeType::Root,
111 sources: HashMap::new(),
112 lateral_sources: HashMap::new(),
113 cte_sources: HashMap::new(),
114 outer_columns: Vec::new(),
115 can_be_correlated: false,
116 subquery_scopes: Vec::new(),
117 derived_table_scopes: Vec::new(),
118 cte_scopes: Vec::new(),
119 udtf_scopes: Vec::new(),
120 table_scopes: Vec::new(),
121 union_scopes: Vec::new(),
122 columns_cache: None,
123 external_columns_cache: None,
124 }
125 }
126
127 pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
129 self.branch_with_options(expression, scope_type, None, None, None)
130 }
131
132 pub fn branch_with_options(
134 &self,
135 expression: Expression,
136 scope_type: ScopeType,
137 sources: Option<HashMap<String, SourceInfo>>,
138 lateral_sources: Option<HashMap<String, SourceInfo>>,
139 outer_columns: Option<Vec<String>>,
140 ) -> Self {
141 let can_be_correlated = self.can_be_correlated
142 || scope_type == ScopeType::Subquery
143 || scope_type == ScopeType::Udtf;
144
145 Self {
146 expression,
147 scope_type,
148 sources: sources.unwrap_or_default(),
149 lateral_sources: lateral_sources.unwrap_or_default(),
150 cte_sources: self.cte_sources.clone(),
151 outer_columns: outer_columns.unwrap_or_default(),
152 can_be_correlated,
153 subquery_scopes: Vec::new(),
154 derived_table_scopes: Vec::new(),
155 cte_scopes: Vec::new(),
156 udtf_scopes: Vec::new(),
157 table_scopes: Vec::new(),
158 union_scopes: Vec::new(),
159 columns_cache: None,
160 external_columns_cache: None,
161 }
162 }
163
164 pub fn clear_cache(&mut self) {
166 self.columns_cache = None;
167 self.external_columns_cache = None;
168 }
169
170 pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
172 self.sources.insert(
173 name,
174 SourceInfo {
175 expression,
176 is_scope,
177 },
178 );
179 self.clear_cache();
180 }
181
182 pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
184 self.lateral_sources.insert(
185 name.clone(),
186 SourceInfo {
187 expression: expression.clone(),
188 is_scope,
189 },
190 );
191 self.sources.insert(
192 name,
193 SourceInfo {
194 expression,
195 is_scope,
196 },
197 );
198 self.clear_cache();
199 }
200
201 pub fn add_cte_source(&mut self, name: String, expression: Expression) {
203 self.cte_sources.insert(
204 name.clone(),
205 SourceInfo {
206 expression: expression.clone(),
207 is_scope: true,
208 },
209 );
210 self.sources.insert(
211 name,
212 SourceInfo {
213 expression,
214 is_scope: true,
215 },
216 );
217 self.clear_cache();
218 }
219
220 pub fn rename_source(&mut self, old_name: &str, new_name: String) {
222 if let Some(source) = self.sources.remove(old_name) {
223 self.sources.insert(new_name, source);
224 }
225 self.clear_cache();
226 }
227
228 pub fn remove_source(&mut self, name: &str) {
230 self.sources.remove(name);
231 self.clear_cache();
232 }
233
234 pub fn columns(&mut self) -> &[ColumnRef] {
236 if self.columns_cache.is_none() {
237 let mut columns = Vec::new();
238 collect_columns(&self.expression, &mut columns);
239 self.columns_cache = Some(columns);
240 }
241 self.columns_cache.as_ref().unwrap()
242 }
243
244 pub fn output_columns(&self) -> Vec<String> {
249 crate::ast_transforms::get_output_column_names(&self.expression)
250 }
251
252 pub fn source_names(&self) -> HashSet<String> {
254 let mut names: HashSet<String> = self.sources.keys().cloned().collect();
255 names.extend(self.cte_sources.keys().cloned());
256 names
257 }
258
259 pub fn external_columns(&mut self) -> Vec<ColumnRef> {
261 if self.external_columns_cache.is_some() {
262 return self.external_columns_cache.clone().unwrap();
263 }
264
265 let source_names = self.source_names();
266 let columns = self.columns().to_vec();
267
268 let external: Vec<ColumnRef> = columns
269 .into_iter()
270 .filter(|col| {
271 match &col.table {
273 Some(table) => !source_names.contains(table),
274 None => false, }
276 })
277 .collect();
278
279 self.external_columns_cache = Some(external.clone());
280 external
281 }
282
283 pub fn local_columns(&mut self) -> Vec<ColumnRef> {
285 let external_set: HashSet<_> = self.external_columns().into_iter().collect();
286 let columns = self.columns().to_vec();
287
288 columns
289 .into_iter()
290 .filter(|col| !external_set.contains(col))
291 .collect()
292 }
293
294 pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
296 self.columns()
297 .iter()
298 .filter(|c| c.table.is_none())
299 .cloned()
300 .collect()
301 }
302
303 pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
305 self.columns()
306 .iter()
307 .filter(|col| col.table.as_deref() == Some(source_name))
308 .cloned()
309 .collect()
310 }
311
312 pub fn is_correlated_subquery(&mut self) -> bool {
318 self.can_be_correlated && !self.external_columns().is_empty()
319 }
320
321 pub fn is_subquery(&self) -> bool {
323 self.scope_type == ScopeType::Subquery
324 }
325
326 pub fn is_derived_table(&self) -> bool {
328 self.scope_type == ScopeType::DerivedTable
329 }
330
331 pub fn is_cte(&self) -> bool {
333 self.scope_type == ScopeType::Cte
334 }
335
336 pub fn is_root(&self) -> bool {
338 self.scope_type == ScopeType::Root
339 }
340
341 pub fn is_udtf(&self) -> bool {
343 self.scope_type == ScopeType::Udtf
344 }
345
346 pub fn is_union(&self) -> bool {
348 self.scope_type == ScopeType::SetOperation
349 }
350
351 pub fn traverse(&self) -> Vec<&Scope> {
353 let mut result = Vec::new();
354 self.traverse_impl(&mut result);
355 result
356 }
357
358 fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
359 for scope in &self.cte_scopes {
361 scope.traverse_impl(result);
362 }
363 for scope in &self.union_scopes {
364 scope.traverse_impl(result);
365 }
366 for scope in &self.table_scopes {
367 scope.traverse_impl(result);
368 }
369 for scope in &self.subquery_scopes {
370 scope.traverse_impl(result);
371 }
372 result.push(self);
374 }
375
376 pub fn ref_count(&self) -> HashMap<usize, usize> {
378 let mut counts: HashMap<usize, usize> = HashMap::new();
379
380 for scope in self.traverse() {
381 for (_, source_info) in scope.sources.iter() {
382 if source_info.is_scope {
383 let id = &source_info.expression as *const _ as usize;
384 *counts.entry(id).or_insert(0) += 1;
385 }
386 }
387 }
388
389 counts
390 }
391}
392
393fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
395 match expr {
396 Expression::Column(col) => {
397 columns.push(ColumnRef {
398 table: col.table.as_ref().map(|t| t.name.clone()),
399 name: col.name.name.clone(),
400 });
401 }
402 Expression::Select(select) => {
403 for e in &select.expressions {
405 collect_columns(e, columns);
406 }
407 for join in &select.joins {
409 if let Some(on) = &join.on {
410 collect_columns(on, columns);
411 }
412 if let Some(match_condition) = &join.match_condition {
413 collect_columns(match_condition, columns);
414 }
415 }
416 if let Some(where_clause) = &select.where_clause {
418 collect_columns(&where_clause.this, columns);
419 }
420 if let Some(having) = &select.having {
422 collect_columns(&having.this, columns);
423 }
424 if let Some(order_by) = &select.order_by {
426 for ord in &order_by.expressions {
427 collect_columns(&ord.this, columns);
428 }
429 }
430 if let Some(group_by) = &select.group_by {
432 for e in &group_by.expressions {
433 collect_columns(e, columns);
434 }
435 }
436 }
439 Expression::And(bin)
441 | Expression::Or(bin)
442 | Expression::Add(bin)
443 | Expression::Sub(bin)
444 | Expression::Mul(bin)
445 | Expression::Div(bin)
446 | Expression::Mod(bin)
447 | Expression::Eq(bin)
448 | Expression::Neq(bin)
449 | Expression::Lt(bin)
450 | Expression::Lte(bin)
451 | Expression::Gt(bin)
452 | Expression::Gte(bin)
453 | Expression::BitwiseAnd(bin)
454 | Expression::BitwiseOr(bin)
455 | Expression::BitwiseXor(bin)
456 | Expression::Concat(bin) => {
457 collect_columns(&bin.left, columns);
458 collect_columns(&bin.right, columns);
459 }
460 Expression::Like(like) | Expression::ILike(like) => {
462 collect_columns(&like.left, columns);
463 collect_columns(&like.right, columns);
464 if let Some(escape) = &like.escape {
465 collect_columns(escape, columns);
466 }
467 }
468 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
470 collect_columns(&un.this, columns);
471 }
472 Expression::Function(func) => {
473 for arg in &func.args {
474 collect_columns(arg, columns);
475 }
476 }
477 Expression::AggregateFunction(agg) => {
478 for arg in &agg.args {
479 collect_columns(arg, columns);
480 }
481 }
482 Expression::WindowFunction(wf) => {
483 collect_columns(&wf.this, columns);
484 for e in &wf.over.partition_by {
485 collect_columns(e, columns);
486 }
487 for e in &wf.over.order_by {
488 collect_columns(&e.this, columns);
489 }
490 }
491 Expression::Alias(alias) => {
492 collect_columns(&alias.this, columns);
493 }
494 Expression::Case(case) => {
495 if let Some(operand) = &case.operand {
496 collect_columns(operand, columns);
497 }
498 for (when_expr, then_expr) in &case.whens {
499 collect_columns(when_expr, columns);
500 collect_columns(then_expr, columns);
501 }
502 if let Some(else_clause) = &case.else_ {
503 collect_columns(else_clause, columns);
504 }
505 }
506 Expression::Paren(paren) => {
507 collect_columns(&paren.this, columns);
508 }
509 Expression::Ordered(ord) => {
510 collect_columns(&ord.this, columns);
511 }
512 Expression::In(in_expr) => {
513 collect_columns(&in_expr.this, columns);
514 for e in &in_expr.expressions {
515 collect_columns(e, columns);
516 }
517 }
519 Expression::Between(between) => {
520 collect_columns(&between.this, columns);
521 collect_columns(&between.low, columns);
522 collect_columns(&between.high, columns);
523 }
524 Expression::IsNull(is_null) => {
525 collect_columns(&is_null.this, columns);
526 }
527 Expression::Cast(cast) => {
528 collect_columns(&cast.this, columns);
529 }
530 Expression::Extract(extract) => {
531 collect_columns(&extract.this, columns);
532 }
533 Expression::Exists(_) | Expression::Subquery(_) => {
534 }
536 Expression::Prepare(prepare) => {
537 collect_columns(&prepare.statement, columns);
538 }
539 _ => {
540 }
542 }
543}
544
545pub fn build_scope(expression: &Expression) -> Scope {
550 let mut root = Scope::new(expression.clone());
551 build_scope_impl(expression, &mut root);
552 root
553}
554
555fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
556 match expression {
557 Expression::Prepare(prepare) => {
558 build_scope_impl(&prepare.statement, current_scope);
559 }
560 Expression::Select(select) => {
561 if let Some(with) = &select.with {
563 for cte in &with.ctes {
564 let cte_name = cte.alias.name.clone();
565 let mut cte_scope = current_scope
566 .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
567 build_scope_impl(&cte.this, &mut cte_scope);
568 current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
569 current_scope.cte_scopes.push(cte_scope);
570 }
571 }
572
573 if let Some(from) = &select.from {
575 for table in &from.expressions {
576 add_table_to_scope(table, current_scope);
577 }
578 }
579
580 for join in &select.joins {
582 add_table_to_scope(&join.this, current_scope);
583 }
584
585 collect_subqueries(expression, current_scope);
587 }
588 Expression::Union(union) => {
589 let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
590 build_scope_impl(&union.left, &mut left_scope);
591
592 let mut right_scope =
593 current_scope.branch(union.right.clone(), ScopeType::SetOperation);
594 build_scope_impl(&union.right, &mut right_scope);
595
596 current_scope.union_scopes.push(left_scope);
597 current_scope.union_scopes.push(right_scope);
598 }
599 Expression::Intersect(intersect) => {
600 let mut left_scope =
601 current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
602 build_scope_impl(&intersect.left, &mut left_scope);
603
604 let mut right_scope =
605 current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
606 build_scope_impl(&intersect.right, &mut right_scope);
607
608 current_scope.union_scopes.push(left_scope);
609 current_scope.union_scopes.push(right_scope);
610 }
611 Expression::Except(except) => {
612 let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
613 build_scope_impl(&except.left, &mut left_scope);
614
615 let mut right_scope =
616 current_scope.branch(except.right.clone(), ScopeType::SetOperation);
617 build_scope_impl(&except.right, &mut right_scope);
618
619 current_scope.union_scopes.push(left_scope);
620 current_scope.union_scopes.push(right_scope);
621 }
622 Expression::CreateTable(create) => {
623 if let Some(with) = &create.with_cte {
626 for cte in &with.ctes {
627 let cte_name = cte.alias.name.clone();
628 let mut cte_scope = current_scope
629 .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
630 build_scope_impl(&cte.this, &mut cte_scope);
631 current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
632 current_scope.cte_scopes.push(cte_scope);
633 }
634 }
635 if let Some(as_select) = &create.as_select {
637 build_scope_impl(as_select, current_scope);
638 }
639 }
640 _ => {}
641 }
642}
643
644fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
645 match expr {
646 Expression::Table(table) => {
647 let name = table
648 .alias
649 .as_ref()
650 .map(|a| a.name.clone())
651 .unwrap_or_else(|| table.name.name.clone());
652 let cte_source = if table.schema.is_none() && table.catalog.is_none() {
653 scope.cte_sources.get(&table.name.name).or_else(|| {
654 scope
655 .cte_sources
656 .iter()
657 .find(|(cte_name, _)| cte_name.eq_ignore_ascii_case(&table.name.name))
658 .map(|(_, source)| source)
659 })
660 } else {
661 None
662 };
663
664 if let Some(source) = cte_source {
665 scope.add_source(name, source.expression.clone(), true);
666 } else {
667 scope.add_source(name, expr.clone(), false);
668 }
669 }
670 Expression::Subquery(subquery) => {
671 let name = subquery
672 .alias
673 .as_ref()
674 .map(|a| a.name.clone())
675 .unwrap_or_default();
676
677 let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
678 build_scope_impl(&subquery.this, &mut derived_scope);
679
680 scope.add_source(name.clone(), expr.clone(), true);
681 scope.derived_table_scopes.push(derived_scope);
682 }
683 Expression::Unnest(unnest) => {
684 if let Some(alias) = &unnest.alias {
685 scope.add_source(alias.name.clone(), expr.clone(), false);
686 }
687 }
688 Expression::Alias(alias) if matches!(&alias.this, Expression::Unnest(_)) => {
689 scope.add_source(alias.alias.name.clone(), expr.clone(), false);
690 }
691 Expression::Paren(paren) => {
692 add_table_to_scope(&paren.this, scope);
693 }
694 _ => {}
695 }
696}
697
698fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
699 match expr {
700 Expression::Select(select) => {
701 if let Some(where_clause) = &select.where_clause {
703 collect_subqueries_in_expr(&where_clause.this, parent_scope);
704 }
705 for e in &select.expressions {
707 collect_subqueries_in_expr(e, parent_scope);
708 }
709 if let Some(having) = &select.having {
711 collect_subqueries_in_expr(&having.this, parent_scope);
712 }
713 }
714 _ => {}
715 }
716}
717
718fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
719 match expr {
720 Expression::Subquery(subquery) if subquery.alias.is_none() => {
721 let mut sub_scope = parent_scope.branch(subquery.this.clone(), ScopeType::Subquery);
723 build_scope_impl(&subquery.this, &mut sub_scope);
724 parent_scope.subquery_scopes.push(sub_scope);
725 }
726 Expression::In(in_expr) => {
727 collect_subqueries_in_expr(&in_expr.this, parent_scope);
728 if let Some(query) = &in_expr.query {
729 let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
730 build_scope_impl(query, &mut sub_scope);
731 parent_scope.subquery_scopes.push(sub_scope);
732 }
733 }
734 Expression::Exists(exists) => {
735 let mut sub_scope = parent_scope.branch(exists.this.clone(), ScopeType::Subquery);
736 build_scope_impl(&exists.this, &mut sub_scope);
737 parent_scope.subquery_scopes.push(sub_scope);
738 }
739 Expression::And(bin)
741 | Expression::Or(bin)
742 | Expression::Add(bin)
743 | Expression::Sub(bin)
744 | Expression::Mul(bin)
745 | Expression::Div(bin)
746 | Expression::Mod(bin)
747 | Expression::Eq(bin)
748 | Expression::Neq(bin)
749 | Expression::Lt(bin)
750 | Expression::Lte(bin)
751 | Expression::Gt(bin)
752 | Expression::Gte(bin)
753 | Expression::BitwiseAnd(bin)
754 | Expression::BitwiseOr(bin)
755 | Expression::BitwiseXor(bin)
756 | Expression::Concat(bin) => {
757 collect_subqueries_in_expr(&bin.left, parent_scope);
758 collect_subqueries_in_expr(&bin.right, parent_scope);
759 }
760 Expression::Like(like) | Expression::ILike(like) => {
762 collect_subqueries_in_expr(&like.left, parent_scope);
763 collect_subqueries_in_expr(&like.right, parent_scope);
764 if let Some(escape) = &like.escape {
765 collect_subqueries_in_expr(escape, parent_scope);
766 }
767 }
768 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
770 collect_subqueries_in_expr(&un.this, parent_scope);
771 }
772 Expression::Function(func) => {
773 for arg in &func.args {
774 collect_subqueries_in_expr(arg, parent_scope);
775 }
776 }
777 Expression::Case(case) => {
778 if let Some(operand) = &case.operand {
779 collect_subqueries_in_expr(operand, parent_scope);
780 }
781 for (when_expr, then_expr) in &case.whens {
782 collect_subqueries_in_expr(when_expr, parent_scope);
783 collect_subqueries_in_expr(then_expr, parent_scope);
784 }
785 if let Some(else_clause) = &case.else_ {
786 collect_subqueries_in_expr(else_clause, parent_scope);
787 }
788 }
789 Expression::Paren(paren) => {
790 collect_subqueries_in_expr(&paren.this, parent_scope);
791 }
792 Expression::Alias(alias) => {
793 collect_subqueries_in_expr(&alias.this, parent_scope);
794 }
795 _ => {}
796 }
797}
798
799pub fn walk_in_scope<'a>(
811 expression: &'a Expression,
812 bfs: bool,
813) -> impl Iterator<Item = &'a Expression> {
814 WalkInScopeIter::new(expression, bfs)
815}
816
/// Work-list iterator that implements `walk_in_scope`.
struct WalkInScopeIter<'a> {
    /// Nodes waiting to be yielded. Breadth-first walks pop from the front (FIFO);
    /// depth-first walks pop from the back (LIFO).
    queue: VecDeque<&'a Expression>,
    /// `true` for breadth-first order, `false` for depth-first order.
    bfs: bool,
}
823impl<'a> WalkInScopeIter<'a> {
824 fn new(expression: &'a Expression, bfs: bool) -> Self {
825 let mut queue = VecDeque::new();
826 queue.push_back(expression);
827 Self { queue, bfs }
828 }
829
830 fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
831 if is_root {
832 return false;
833 }
834
835 if matches!(expr, Expression::Cte(_)) {
837 return true;
838 }
839
840 if let Expression::Subquery(subquery) = expr {
842 if subquery.alias.is_some() {
843 return true;
844 }
845 }
846
847 if matches!(
849 expr,
850 Expression::Select(_)
851 | Expression::Union(_)
852 | Expression::Intersect(_)
853 | Expression::Except(_)
854 ) {
855 return true;
856 }
857
858 false
859 }
860
861 fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
862 let mut children = Vec::new();
863
864 match expr {
865 Expression::Prepare(prepare) => {
866 children.push(&prepare.statement);
867 }
868 Expression::Select(select) => {
869 for e in &select.expressions {
871 children.push(e);
872 }
873 if let Some(from) = &select.from {
875 for table in &from.expressions {
876 if !self.should_stop_at(table, false) {
877 children.push(table);
878 }
879 }
880 }
881 for join in &select.joins {
883 if let Some(on) = &join.on {
884 children.push(on);
885 }
886 }
888 if let Some(where_clause) = &select.where_clause {
890 children.push(&where_clause.this);
891 }
892 if let Some(group_by) = &select.group_by {
894 for e in &group_by.expressions {
895 children.push(e);
896 }
897 }
898 if let Some(having) = &select.having {
900 children.push(&having.this);
901 }
902 if let Some(order_by) = &select.order_by {
904 for ord in &order_by.expressions {
905 children.push(&ord.this);
906 }
907 }
908 if let Some(limit) = &select.limit {
910 children.push(&limit.this);
911 }
912 if let Some(offset) = &select.offset {
914 children.push(&offset.this);
915 }
916 }
917 Expression::And(bin)
918 | Expression::Or(bin)
919 | Expression::Add(bin)
920 | Expression::Sub(bin)
921 | Expression::Mul(bin)
922 | Expression::Div(bin)
923 | Expression::Mod(bin)
924 | Expression::Eq(bin)
925 | Expression::Neq(bin)
926 | Expression::Lt(bin)
927 | Expression::Lte(bin)
928 | Expression::Gt(bin)
929 | Expression::Gte(bin)
930 | Expression::BitwiseAnd(bin)
931 | Expression::BitwiseOr(bin)
932 | Expression::BitwiseXor(bin)
933 | Expression::Concat(bin) => {
934 children.push(&bin.left);
935 children.push(&bin.right);
936 }
937 Expression::Like(like) | Expression::ILike(like) => {
938 children.push(&like.left);
939 children.push(&like.right);
940 if let Some(escape) = &like.escape {
941 children.push(escape);
942 }
943 }
944 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
945 children.push(&un.this);
946 }
947 Expression::Function(func) => {
948 for arg in &func.args {
949 children.push(arg);
950 }
951 }
952 Expression::AggregateFunction(agg) => {
953 for arg in &agg.args {
954 children.push(arg);
955 }
956 }
957 Expression::WindowFunction(wf) => {
958 children.push(&wf.this);
959 for e in &wf.over.partition_by {
960 children.push(e);
961 }
962 for e in &wf.over.order_by {
963 children.push(&e.this);
964 }
965 }
966 Expression::Alias(alias) => {
967 children.push(&alias.this);
968 }
969 Expression::Case(case) => {
970 if let Some(operand) = &case.operand {
971 children.push(operand);
972 }
973 for (when_expr, then_expr) in &case.whens {
974 children.push(when_expr);
975 children.push(then_expr);
976 }
977 if let Some(else_clause) = &case.else_ {
978 children.push(else_clause);
979 }
980 }
981 Expression::Paren(paren) => {
982 children.push(&paren.this);
983 }
984 Expression::Ordered(ord) => {
985 children.push(&ord.this);
986 }
987 Expression::In(in_expr) => {
988 children.push(&in_expr.this);
989 for e in &in_expr.expressions {
990 children.push(e);
991 }
992 }
994 Expression::Between(between) => {
995 children.push(&between.this);
996 children.push(&between.low);
997 children.push(&between.high);
998 }
999 Expression::IsNull(is_null) => {
1000 children.push(&is_null.this);
1001 }
1002 Expression::Cast(cast) => {
1003 children.push(&cast.this);
1004 }
1005 Expression::Extract(extract) => {
1006 children.push(&extract.this);
1007 }
1008 Expression::Coalesce(coalesce) => {
1009 for e in &coalesce.expressions {
1010 children.push(e);
1011 }
1012 }
1013 Expression::NullIf(nullif) => {
1014 children.push(&nullif.this);
1015 children.push(&nullif.expression);
1016 }
1017 Expression::Table(_table) => {
1018 }
1021 Expression::TryCatch(try_catch) => {
1022 for stmt in &try_catch.try_body {
1023 children.push(stmt);
1024 }
1025 if let Some(catch_body) = &try_catch.catch_body {
1026 for stmt in catch_body {
1027 children.push(stmt);
1028 }
1029 }
1030 }
1031 Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
1032 }
1034 Expression::Subquery(_) | Expression::Exists(_) => {}
1036 _ => {
1037 }
1039 }
1040
1041 children
1042 }
1043}
1044
1045impl<'a> Iterator for WalkInScopeIter<'a> {
1046 type Item = &'a Expression;
1047
1048 fn next(&mut self) -> Option<Self::Item> {
1049 let expr = if self.bfs {
1050 self.queue.pop_front()?
1051 } else {
1052 self.queue.pop_back()?
1053 };
1054
1055 let children = self.get_children(expr);
1057
1058 if self.bfs {
1059 for child in children {
1060 if !self.should_stop_at(child, false) {
1061 self.queue.push_back(child);
1062 }
1063 }
1064 } else {
1065 for child in children.into_iter().rev() {
1066 if !self.should_stop_at(child, false) {
1067 self.queue.push_back(child);
1068 }
1069 }
1070 }
1071
1072 Some(expr)
1073 }
1074}
1075
1076pub fn find_in_scope<'a, F>(
1088 expression: &'a Expression,
1089 predicate: F,
1090 bfs: bool,
1091) -> Option<&'a Expression>
1092where
1093 F: Fn(&Expression) -> bool,
1094{
1095 walk_in_scope(expression, bfs).find(|e| predicate(e))
1096}
1097
1098pub fn find_all_in_scope<'a, F>(
1110 expression: &'a Expression,
1111 predicate: F,
1112 bfs: bool,
1113) -> Vec<&'a Expression>
1114where
1115 F: Fn(&Expression) -> bool,
1116{
1117 walk_in_scope(expression, bfs)
1118 .filter(|e| predicate(e))
1119 .collect()
1120}
1121
1122pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1132 match expression {
1133 Expression::Select(_)
1134 | Expression::Union(_)
1135 | Expression::Intersect(_)
1136 | Expression::Except(_)
1137 | Expression::Prepare(_)
1138 | Expression::CreateTable(_) => {
1139 let root = build_scope(expression);
1140 root.traverse().into_iter().cloned().collect()
1141 }
1142 _ => Vec::new(),
1143 }
1144}
1145
1146#[cfg(test)]
1147mod tests {
1148 use super::*;
1149 use crate::parser::Parser;
1150
1151 fn parse_and_build_scope(sql: &str) -> Scope {
1152 let ast = Parser::parse_sql(sql).expect("Failed to parse SQL");
1153 build_scope(&ast[0])
1154 }
1155
1156 #[test]
1157 fn test_simple_select_scope() {
1158 let mut scope = parse_and_build_scope("SELECT a, b FROM t");
1159
1160 assert!(scope.is_root());
1161 assert!(!scope.can_be_correlated);
1162 assert!(scope.sources.contains_key("t"));
1163
1164 let columns = scope.columns();
1165 assert_eq!(columns.len(), 2);
1166 }
1167
1168 #[test]
1169 fn test_derived_table_scope() {
1170 let mut scope = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");
1171
1172 assert!(scope.sources.contains_key("x"));
1173 assert_eq!(scope.derived_table_scopes.len(), 1);
1174
1175 let derived = &mut scope.derived_table_scopes[0];
1176 assert!(derived.is_derived_table());
1177 assert!(derived.sources.contains_key("t"));
1178 }
1179
1180 #[test]
1181 fn test_non_correlated_subquery() {
1182 let mut scope = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");
1183
1184 assert_eq!(scope.subquery_scopes.len(), 1);
1185
1186 let subquery = &mut scope.subquery_scopes[0];
1187 assert!(subquery.is_subquery());
1188 assert!(subquery.can_be_correlated);
1189
1190 assert!(subquery.sources.contains_key("s"));
1192 assert!(!subquery.is_correlated_subquery());
1193 }
1194
1195 #[test]
1196 fn test_correlated_subquery() {
1197 let mut scope =
1198 parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");
1199
1200 assert_eq!(scope.subquery_scopes.len(), 1);
1201
1202 let subquery = &mut scope.subquery_scopes[0];
1203 assert!(subquery.is_subquery());
1204 assert!(subquery.can_be_correlated);
1205
1206 let external = subquery.external_columns();
1208 assert!(!external.is_empty());
1209 assert!(external.iter().any(|c| c.table.as_deref() == Some("t")));
1210 assert!(subquery.is_correlated_subquery());
1211 }
1212
1213 #[test]
1214 fn test_cte_scope() {
1215 let scope = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");
1216
1217 assert_eq!(scope.cte_scopes.len(), 1);
1218 assert!(scope.cte_sources.contains_key("cte"));
1219
1220 let cte = &scope.cte_scopes[0];
1221 assert!(cte.is_cte());
1222 }
1223
1224 #[test]
1225 fn test_multiple_sources() {
1226 let scope = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");
1227
1228 assert!(scope.sources.contains_key("t"));
1229 assert!(scope.sources.contains_key("s"));
1230 assert_eq!(scope.sources.len(), 2);
1231 }
1232
1233 #[test]
1234 fn test_aliased_table() {
1235 let scope = parse_and_build_scope("SELECT x.a FROM t AS x");
1236
1237 assert!(scope.sources.contains_key("x"));
1239 assert!(!scope.sources.contains_key("t"));
1240 }
1241
1242 #[test]
1243 fn test_local_columns() {
1244 let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");
1245
1246 let local = scope.local_columns();
1247 assert_eq!(local.len(), 5);
1250 assert!(local.iter().all(|c| c.table.is_some()));
1251 }
1252
1253 #[test]
1254 fn test_columns_include_join_on_clause_references() {
1255 let mut scope = parse_and_build_scope(
1256 "SELECT o.total FROM orders o JOIN customers c ON c.id = o.customer_id",
1257 );
1258
1259 let cols: Vec<String> = scope
1260 .columns()
1261 .iter()
1262 .map(|c| match &c.table {
1263 Some(t) => format!("{}.{}", t, c.name),
1264 None => c.name.clone(),
1265 })
1266 .collect();
1267
1268 assert!(cols.contains(&"o.total".to_string()));
1269 assert!(cols.contains(&"c.id".to_string()));
1270 assert!(cols.contains(&"o.customer_id".to_string()));
1271 }
1272
1273 #[test]
1274 fn test_unqualified_columns() {
1275 let mut scope = parse_and_build_scope("SELECT a, b, t.c FROM t");
1276
1277 let unqualified = scope.unqualified_columns();
1278 assert_eq!(unqualified.len(), 2);
1280 assert!(unqualified.iter().all(|c| c.table.is_none()));
1281 }
1282
1283 #[test]
1284 fn test_source_columns() {
1285 let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");
1286
1287 let t_cols = scope.source_columns("t");
1288 assert!(t_cols.len() >= 2);
1290 assert!(t_cols.iter().all(|c| c.table.as_deref() == Some("t")));
1291
1292 let s_cols = scope.source_columns("s");
1293 assert!(s_cols.len() >= 1);
1295 assert!(s_cols.iter().all(|c| c.table.as_deref() == Some("s")));
1296 }
1297
1298 #[test]
1299 fn test_rename_source() {
1300 let mut scope = parse_and_build_scope("SELECT a FROM t");
1301
1302 assert!(scope.sources.contains_key("t"));
1303 scope.rename_source("t", "new_name".to_string());
1304 assert!(!scope.sources.contains_key("t"));
1305 assert!(scope.sources.contains_key("new_name"));
1306 }
1307
1308 #[test]
1309 fn test_remove_source() {
1310 let mut scope = parse_and_build_scope("SELECT a FROM t");
1311
1312 assert!(scope.sources.contains_key("t"));
1313 scope.remove_source("t");
1314 assert!(!scope.sources.contains_key("t"));
1315 }
1316
1317 #[test]
1318 fn test_walk_in_scope() {
1319 let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
1320 let expr = &ast[0];
1321
1322 let walked: Vec<_> = walk_in_scope(expr, true).collect();
1324 assert!(!walked.is_empty());
1325
1326 assert!(walked.iter().any(|e| matches!(e, Expression::Select(_))));
1328 assert!(walked.iter().any(|e| matches!(e, Expression::Column(_))));
1330 }
1331
#[test]
fn test_find_in_scope() {
    let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
    let expr = &ast[0];

    // The query references columns, so the first match must exist and be a Column.
    // A single pattern assertion replaces the `is_some()` + `unwrap()` pair and
    // reports what was actually found on failure.
    let found = find_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
    assert!(
        matches!(found, Some(Expression::Column(_))),
        "expected find_in_scope to return a Column, got {:?}",
        found
    );
}
1342
#[test]
fn test_find_all_in_scope() {
    let statements = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");

    // One match per projected column: a, b and c.
    let columns = find_all_in_scope(
        &statements[0],
        |node| matches!(node, Expression::Column(_)),
        true,
    );
    assert_eq!(columns.len(), 3);
}
1352
#[test]
fn test_traverse_scope() {
    let statements =
        Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");

    // The outer query selects from a derived table; whatever else is returned,
    // the traversal must include the root scope.
    let all_scopes = traverse_scope(&statements[0]);
    assert!(!all_scopes.is_empty());
    assert!(all_scopes.iter().any(|candidate| candidate.is_root()));
}
1366
#[test]
fn test_branch_with_options() {
    let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
    let parent = build_scope(&statements[0]);
    let outer = vec!["col1".to_string(), "col2".to_string()];

    let child = parent.branch_with_options(
        statements[0].clone(),
        ScopeType::Subquery,
        None,
        None,
        Some(outer),
    );

    // The child carries the supplied outer columns and is flagged as correlatable.
    assert_eq!(child.outer_columns, vec!["col1", "col2"]);
    assert!(child.can_be_correlated);
}
1383
#[test]
fn test_is_udtf() {
    let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
    let query = &statements[0];

    // `Scope::new` produces a Root scope, which is not a UDTF.
    assert!(!Scope::new(query.clone()).is_udtf());

    // Branching with `ScopeType::Udtf` produces a UDTF scope.
    let udtf_scope = build_scope(query).branch(query.clone(), ScopeType::Udtf);
    assert!(udtf_scope.is_udtf());
}
1394
#[test]
fn test_is_union() {
    let scope = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

    // The UNION statement is the root; each side becomes a union branch scope.
    assert!(scope.is_root());
    assert_eq!(scope.union_scopes.len(), 2);
    for branch in &scope.union_scopes {
        assert!(branch.is_union());
    }
}
1405
#[test]
fn test_union_output_columns() {
    // Both branches project `id` and `name`, so the union exposes them in that order.
    let sql = "SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees";
    let scope = parse_and_build_scope(sql);
    let expected = vec!["id", "name"];
    assert_eq!(scope.output_columns(), expected);
}
1413
#[test]
fn test_clear_cache() {
    let mut scope = parse_and_build_scope("SELECT t.a FROM t");

    // Calling `columns()` populates the column cache...
    let _ = scope.columns();
    assert!(scope.columns_cache.is_some());

    // ...and `clear_cache()` resets both cached column lists.
    scope.clear_cache();
    assert!(scope.columns_cache.is_none());
    assert!(scope.external_columns_cache.is_none());
}
1427
#[test]
fn test_scope_traverse() {
    // Three SELECTs appear in the query: the CTE body, the EXISTS subquery and
    // the outer query, so traversal is expected to yield at least three scopes.
    let sql = "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)";
    let scope = parse_and_build_scope(sql);
    assert!(scope.traverse().len() >= 3);
}
1438
#[test]
fn test_create_table_as_select_scope() {
    // Shared failure message for the "target is not a source" checks below.
    const TARGET_MSG: &str = "CTAS target table should not be treated as a source";

    // Plain CTAS: the FROM table is a source, the created table is not.
    let plain = parse_and_build_scope("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
    assert!(
        plain.sources.contains_key("src"),
        "CTAS scope should contain the FROM table"
    );
    assert!(!plain.sources.contains_key("out_table"), "{}", TARGET_MSG);

    // CTAS over a join: both aliased tables are sources.
    let joined = parse_and_build_scope(
        "CREATE TABLE out_table AS SELECT a.id FROM foo AS a JOIN bar AS b ON a.id = b.id",
    );
    for alias in ["a", "b"] {
        assert!(joined.sources.contains_key(alias));
    }
    assert!(!joined.sources.contains_key("out_table"), "{}", TARGET_MSG);

    // CTAS wrapping a WITH clause: the CTE is the source and gets exactly one CTE scope.
    let with_cte = parse_and_build_scope(
        "CREATE TABLE out_table AS WITH cte AS (SELECT 1 AS id FROM src) SELECT * FROM cte",
    );
    assert!(
        with_cte.sources.contains_key("cte"),
        "CTAS with CTE should resolve CTE as source"
    );
    assert!(!with_cte.sources.contains_key("out_table"), "{}", TARGET_MSG);
    assert_eq!(with_cte.cte_scopes.len(), 1);
}
1477
#[test]
fn test_create_table_as_select_traverse() {
    // `expect` (instead of a bare `unwrap`) matches the other tests and reports
    // a clear cause if parsing fails.
    let ast = Parser::parse_sql("CREATE TABLE t AS SELECT a FROM src").expect("Failed to parse");
    let scopes = traverse_scope(&ast[0]);
    assert!(
        !scopes.is_empty(),
        "traverse_scope should return scopes for CTAS"
    );
    // The SELECT's FROM table must be visible as a source in some traversed scope.
    assert!(
        scopes.iter().any(|s| s.sources.contains_key("src")),
        "traverse_scope should expose the CTAS FROM table as a source"
    );
}
1487}