1use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
/// The kind of query construct a [`Scope`] was created for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "bindings", derive(TS))]
#[cfg_attr(feature = "bindings", ts(export))]
pub enum ScopeType {
    /// The scope produced by `Scope::new`.
    Root,
    /// An unaliased subquery, an `IN (<query>)` operand or an `EXISTS` operand
    /// found inside an expression.
    Subquery,
    /// A subquery used as a table source (FROM / JOIN).
    DerivedTable,
    /// The body of a `WITH` entry.
    Cte,
    /// One operand of a UNION / INTERSECT / EXCEPT.
    SetOperation,
    /// A table-function scope. Not created by `build_scope` in this file; scopes
    /// branched with this type are marked as potentially correlated.
    Udtf,
}
32
/// A named source registered in a [`Scope`].
#[derive(Debug, Clone)]
pub struct SourceInfo {
    /// The expression that defines the source (table, subquery or CTE node).
    pub expression: Expression,
    /// `true` when the source is itself backed by a scope (derived tables and
    /// CTEs), `false` for plain tables.
    pub is_scope: bool,
}
41
/// A column reference reduced to its optional table qualifier and its name.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table qualifier as written in the query, `None` for unqualified columns.
    pub table: Option<String>,
    /// Column name.
    pub name: String,
}
50
/// One node of the scope tree built for a query.
///
/// A scope owns a clone of the expression it was created for, the sources
/// visible in it, and its child scopes grouped by kind.
#[derive(Debug, Clone)]
pub struct Scope {
    /// The expression this scope was created for.
    pub expression: Expression,

    /// The construct that introduced this scope.
    pub scope_type: ScopeType,

    /// Sources visible in this scope, keyed by alias (or table name when there
    /// is no alias).
    pub sources: HashMap<String, SourceInfo>,

    /// Sources registered through `add_lateral_source`; they are also present
    /// in `sources`.
    pub lateral_sources: HashMap<String, SourceInfo>,

    /// CTE definitions visible in this scope. Child scopes created with
    /// `branch` start with a copy of their parent's map.
    pub cte_sources: HashMap<String, SourceInfo>,

    /// Only set through `branch_with_options`; never read in this file.
    /// NOTE(review): presumably the column alias list of the enclosing
    /// source — confirm against callers.
    pub outer_columns: Vec<String>,

    /// Whether this scope may reference columns of an enclosing scope
    /// (inherited from the parent, or set for `Subquery` / `Udtf` scopes).
    pub can_be_correlated: bool,

    /// Child scopes for subqueries found inside expressions.
    pub subquery_scopes: Vec<Scope>,

    /// Child scopes for subqueries used as table sources.
    pub derived_table_scopes: Vec<Scope>,

    /// Child scopes for `WITH` entries.
    pub cte_scopes: Vec<Scope>,

    /// Child scopes for table functions (not populated in this file).
    pub udtf_scopes: Vec<Scope>,

    /// Additional table scopes (not populated in this file).
    pub table_scopes: Vec<Scope>,

    /// Child scopes for the operands of a set operation, left then right.
    pub union_scopes: Vec<Scope>,

    /// Lazily computed result of `columns()`; reset by `clear_cache`.
    columns_cache: Option<Vec<ColumnRef>>,

    /// Lazily computed result of `external_columns()`; reset by `clear_cache`.
    external_columns_cache: Option<Vec<ColumnRef>>,
}
104
105impl Scope {
106 pub fn new(expression: Expression) -> Self {
108 Self {
109 expression,
110 scope_type: ScopeType::Root,
111 sources: HashMap::new(),
112 lateral_sources: HashMap::new(),
113 cte_sources: HashMap::new(),
114 outer_columns: Vec::new(),
115 can_be_correlated: false,
116 subquery_scopes: Vec::new(),
117 derived_table_scopes: Vec::new(),
118 cte_scopes: Vec::new(),
119 udtf_scopes: Vec::new(),
120 table_scopes: Vec::new(),
121 union_scopes: Vec::new(),
122 columns_cache: None,
123 external_columns_cache: None,
124 }
125 }
126
127 pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
129 self.branch_with_options(expression, scope_type, None, None, None)
130 }
131
132 pub fn branch_with_options(
134 &self,
135 expression: Expression,
136 scope_type: ScopeType,
137 sources: Option<HashMap<String, SourceInfo>>,
138 lateral_sources: Option<HashMap<String, SourceInfo>>,
139 outer_columns: Option<Vec<String>>,
140 ) -> Self {
141 let can_be_correlated = self.can_be_correlated
142 || scope_type == ScopeType::Subquery
143 || scope_type == ScopeType::Udtf;
144
145 Self {
146 expression,
147 scope_type,
148 sources: sources.unwrap_or_default(),
149 lateral_sources: lateral_sources.unwrap_or_default(),
150 cte_sources: self.cte_sources.clone(),
151 outer_columns: outer_columns.unwrap_or_default(),
152 can_be_correlated,
153 subquery_scopes: Vec::new(),
154 derived_table_scopes: Vec::new(),
155 cte_scopes: Vec::new(),
156 udtf_scopes: Vec::new(),
157 table_scopes: Vec::new(),
158 union_scopes: Vec::new(),
159 columns_cache: None,
160 external_columns_cache: None,
161 }
162 }
163
164 pub fn clear_cache(&mut self) {
166 self.columns_cache = None;
167 self.external_columns_cache = None;
168 }
169
170 pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
172 self.sources
173 .insert(name, SourceInfo { expression, is_scope });
174 self.clear_cache();
175 }
176
177 pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
179 self.lateral_sources.insert(
180 name.clone(),
181 SourceInfo {
182 expression: expression.clone(),
183 is_scope,
184 },
185 );
186 self.sources.insert(name, SourceInfo { expression, is_scope });
187 self.clear_cache();
188 }
189
190 pub fn add_cte_source(&mut self, name: String, expression: Expression) {
192 self.cte_sources.insert(
193 name.clone(),
194 SourceInfo {
195 expression: expression.clone(),
196 is_scope: true,
197 },
198 );
199 self.sources.insert(
200 name,
201 SourceInfo {
202 expression,
203 is_scope: true,
204 },
205 );
206 self.clear_cache();
207 }
208
209 pub fn rename_source(&mut self, old_name: &str, new_name: String) {
211 if let Some(source) = self.sources.remove(old_name) {
212 self.sources.insert(new_name, source);
213 }
214 self.clear_cache();
215 }
216
217 pub fn remove_source(&mut self, name: &str) {
219 self.sources.remove(name);
220 self.clear_cache();
221 }
222
223 pub fn columns(&mut self) -> &[ColumnRef] {
225 if self.columns_cache.is_none() {
226 let mut columns = Vec::new();
227 collect_columns(&self.expression, &mut columns);
228 self.columns_cache = Some(columns);
229 }
230 self.columns_cache.as_ref().unwrap()
231 }
232
233 pub fn source_names(&self) -> HashSet<String> {
235 let mut names: HashSet<String> = self.sources.keys().cloned().collect();
236 names.extend(self.cte_sources.keys().cloned());
237 names
238 }
239
240 pub fn external_columns(&mut self) -> Vec<ColumnRef> {
242 if self.external_columns_cache.is_some() {
243 return self.external_columns_cache.clone().unwrap();
244 }
245
246 let source_names = self.source_names();
247 let columns = self.columns().to_vec();
248
249 let external: Vec<ColumnRef> = columns
250 .into_iter()
251 .filter(|col| {
252 match &col.table {
254 Some(table) => !source_names.contains(table),
255 None => false, }
257 })
258 .collect();
259
260 self.external_columns_cache = Some(external.clone());
261 external
262 }
263
264 pub fn local_columns(&mut self) -> Vec<ColumnRef> {
266 let external_set: HashSet<_> = self.external_columns().into_iter().collect();
267 let columns = self.columns().to_vec();
268
269 columns
270 .into_iter()
271 .filter(|col| !external_set.contains(col))
272 .collect()
273 }
274
275 pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
277 self.columns()
278 .iter()
279 .filter(|c| c.table.is_none())
280 .cloned()
281 .collect()
282 }
283
284 pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
286 self.columns()
287 .iter()
288 .filter(|col| col.table.as_deref() == Some(source_name))
289 .cloned()
290 .collect()
291 }
292
293 pub fn is_correlated_subquery(&mut self) -> bool {
299 self.can_be_correlated && !self.external_columns().is_empty()
300 }
301
302 pub fn is_subquery(&self) -> bool {
304 self.scope_type == ScopeType::Subquery
305 }
306
307 pub fn is_derived_table(&self) -> bool {
309 self.scope_type == ScopeType::DerivedTable
310 }
311
312 pub fn is_cte(&self) -> bool {
314 self.scope_type == ScopeType::Cte
315 }
316
317 pub fn is_root(&self) -> bool {
319 self.scope_type == ScopeType::Root
320 }
321
322 pub fn is_udtf(&self) -> bool {
324 self.scope_type == ScopeType::Udtf
325 }
326
327 pub fn is_union(&self) -> bool {
329 self.scope_type == ScopeType::SetOperation
330 }
331
332 pub fn traverse(&self) -> Vec<&Scope> {
334 let mut result = Vec::new();
335 self.traverse_impl(&mut result);
336 result
337 }
338
339 fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
340 for scope in &self.cte_scopes {
342 scope.traverse_impl(result);
343 }
344 for scope in &self.union_scopes {
345 scope.traverse_impl(result);
346 }
347 for scope in &self.table_scopes {
348 scope.traverse_impl(result);
349 }
350 for scope in &self.subquery_scopes {
351 scope.traverse_impl(result);
352 }
353 result.push(self);
355 }
356
357 pub fn ref_count(&self) -> HashMap<usize, usize> {
359 let mut counts: HashMap<usize, usize> = HashMap::new();
360
361 for scope in self.traverse() {
362 for (_, source_info) in scope.sources.iter() {
363 if source_info.is_scope {
364 let id = &source_info.expression as *const _ as usize;
365 *counts.entry(id).or_insert(0) += 1;
366 }
367 }
368 }
369
370 counts
371 }
372}
373
374fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
376 match expr {
377 Expression::Column(col) => {
378 columns.push(ColumnRef {
379 table: col.table.as_ref().map(|t| t.name.clone()),
380 name: col.name.name.clone(),
381 });
382 }
383 Expression::Select(select) => {
384 for e in &select.expressions {
386 collect_columns(e, columns);
387 }
388 if let Some(where_clause) = &select.where_clause {
390 collect_columns(&where_clause.this, columns);
391 }
392 if let Some(having) = &select.having {
394 collect_columns(&having.this, columns);
395 }
396 if let Some(order_by) = &select.order_by {
398 for ord in &order_by.expressions {
399 collect_columns(&ord.this, columns);
400 }
401 }
402 if let Some(group_by) = &select.group_by {
404 for e in &group_by.expressions {
405 collect_columns(e, columns);
406 }
407 }
408 }
411 Expression::And(bin) | Expression::Or(bin) |
413 Expression::Add(bin) | Expression::Sub(bin) |
414 Expression::Mul(bin) | Expression::Div(bin) |
415 Expression::Mod(bin) | Expression::Eq(bin) |
416 Expression::Neq(bin) | Expression::Lt(bin) |
417 Expression::Lte(bin) | Expression::Gt(bin) |
418 Expression::Gte(bin) | Expression::BitwiseAnd(bin) |
419 Expression::BitwiseOr(bin) | Expression::BitwiseXor(bin) |
420 Expression::Concat(bin) => {
421 collect_columns(&bin.left, columns);
422 collect_columns(&bin.right, columns);
423 }
424 Expression::Like(like) | Expression::ILike(like) => {
426 collect_columns(&like.left, columns);
427 collect_columns(&like.right, columns);
428 if let Some(escape) = &like.escape {
429 collect_columns(escape, columns);
430 }
431 }
432 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
434 collect_columns(&un.this, columns);
435 }
436 Expression::Function(func) => {
437 for arg in &func.args {
438 collect_columns(arg, columns);
439 }
440 }
441 Expression::AggregateFunction(agg) => {
442 for arg in &agg.args {
443 collect_columns(arg, columns);
444 }
445 }
446 Expression::WindowFunction(wf) => {
447 collect_columns(&wf.this, columns);
448 for e in &wf.over.partition_by {
449 collect_columns(e, columns);
450 }
451 for e in &wf.over.order_by {
452 collect_columns(&e.this, columns);
453 }
454 }
455 Expression::Alias(alias) => {
456 collect_columns(&alias.this, columns);
457 }
458 Expression::Case(case) => {
459 if let Some(operand) = &case.operand {
460 collect_columns(operand, columns);
461 }
462 for (when_expr, then_expr) in &case.whens {
463 collect_columns(when_expr, columns);
464 collect_columns(then_expr, columns);
465 }
466 if let Some(else_clause) = &case.else_ {
467 collect_columns(else_clause, columns);
468 }
469 }
470 Expression::Paren(paren) => {
471 collect_columns(&paren.this, columns);
472 }
473 Expression::Ordered(ord) => {
474 collect_columns(&ord.this, columns);
475 }
476 Expression::In(in_expr) => {
477 collect_columns(&in_expr.this, columns);
478 for e in &in_expr.expressions {
479 collect_columns(e, columns);
480 }
481 }
483 Expression::Between(between) => {
484 collect_columns(&between.this, columns);
485 collect_columns(&between.low, columns);
486 collect_columns(&between.high, columns);
487 }
488 Expression::IsNull(is_null) => {
489 collect_columns(&is_null.this, columns);
490 }
491 Expression::Cast(cast) => {
492 collect_columns(&cast.this, columns);
493 }
494 Expression::Extract(extract) => {
495 collect_columns(&extract.this, columns);
496 }
497 Expression::Exists(_) | Expression::Subquery(_) => {
498 }
500 _ => {
501 }
503 }
504}
505
506pub fn build_scope(expression: &Expression) -> Scope {
511 let mut root = Scope::new(expression.clone());
512 build_scope_impl(expression, &mut root);
513 root
514}
515
516fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
517 match expression {
518 Expression::Select(select) => {
519 if let Some(with) = &select.with {
521 for cte in &with.ctes {
522 let cte_name = cte.alias.name.clone();
523 let mut cte_scope = current_scope.branch(
524 Expression::Cte(Box::new(cte.clone())),
525 ScopeType::Cte,
526 );
527 build_scope_impl(&cte.this, &mut cte_scope);
528 current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
529 current_scope.cte_scopes.push(cte_scope);
530 }
531 }
532
533 if let Some(from) = &select.from {
535 for table in &from.expressions {
536 add_table_to_scope(table, current_scope);
537 }
538 }
539
540 for join in &select.joins {
542 add_table_to_scope(&join.this, current_scope);
543 }
544
545 collect_subqueries(expression, current_scope);
547 }
548 Expression::Union(union) => {
549 let mut left_scope = current_scope.branch(
550 union.left.clone(),
551 ScopeType::SetOperation,
552 );
553 build_scope_impl(&union.left, &mut left_scope);
554
555 let mut right_scope = current_scope.branch(
556 union.right.clone(),
557 ScopeType::SetOperation,
558 );
559 build_scope_impl(&union.right, &mut right_scope);
560
561 current_scope.union_scopes.push(left_scope);
562 current_scope.union_scopes.push(right_scope);
563 }
564 Expression::Intersect(intersect) => {
565 let mut left_scope = current_scope.branch(
566 intersect.left.clone(),
567 ScopeType::SetOperation,
568 );
569 build_scope_impl(&intersect.left, &mut left_scope);
570
571 let mut right_scope = current_scope.branch(
572 intersect.right.clone(),
573 ScopeType::SetOperation,
574 );
575 build_scope_impl(&intersect.right, &mut right_scope);
576
577 current_scope.union_scopes.push(left_scope);
578 current_scope.union_scopes.push(right_scope);
579 }
580 Expression::Except(except) => {
581 let mut left_scope = current_scope.branch(
582 except.left.clone(),
583 ScopeType::SetOperation,
584 );
585 build_scope_impl(&except.left, &mut left_scope);
586
587 let mut right_scope = current_scope.branch(
588 except.right.clone(),
589 ScopeType::SetOperation,
590 );
591 build_scope_impl(&except.right, &mut right_scope);
592
593 current_scope.union_scopes.push(left_scope);
594 current_scope.union_scopes.push(right_scope);
595 }
596 _ => {}
597 }
598}
599
600fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
601 match expr {
602 Expression::Table(table) => {
603 let name = table.alias.as_ref()
604 .map(|a| a.name.clone())
605 .unwrap_or_else(|| table.name.name.clone());
606 scope.add_source(name, expr.clone(), false);
607 }
608 Expression::Subquery(subquery) => {
609 let name = subquery.alias.as_ref()
610 .map(|a| a.name.clone())
611 .unwrap_or_default();
612
613 let mut derived_scope = scope.branch(
614 subquery.this.clone(),
615 ScopeType::DerivedTable,
616 );
617 build_scope_impl(&subquery.this, &mut derived_scope);
618
619 scope.add_source(name.clone(), expr.clone(), true);
620 scope.derived_table_scopes.push(derived_scope);
621 }
622 Expression::Paren(paren) => {
623 add_table_to_scope(&paren.this, scope);
624 }
625 _ => {}
626 }
627}
628
629fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
630 match expr {
631 Expression::Select(select) => {
632 if let Some(where_clause) = &select.where_clause {
634 collect_subqueries_in_expr(&where_clause.this, parent_scope);
635 }
636 for e in &select.expressions {
638 collect_subqueries_in_expr(e, parent_scope);
639 }
640 if let Some(having) = &select.having {
642 collect_subqueries_in_expr(&having.this, parent_scope);
643 }
644 }
645 _ => {}
646 }
647}
648
649fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
650 match expr {
651 Expression::Subquery(subquery) if subquery.alias.is_none() => {
652 let mut sub_scope = parent_scope.branch(
654 subquery.this.clone(),
655 ScopeType::Subquery,
656 );
657 build_scope_impl(&subquery.this, &mut sub_scope);
658 parent_scope.subquery_scopes.push(sub_scope);
659 }
660 Expression::In(in_expr) => {
661 collect_subqueries_in_expr(&in_expr.this, parent_scope);
662 if let Some(query) = &in_expr.query {
663 let mut sub_scope = parent_scope.branch(
664 query.clone(),
665 ScopeType::Subquery,
666 );
667 build_scope_impl(query, &mut sub_scope);
668 parent_scope.subquery_scopes.push(sub_scope);
669 }
670 }
671 Expression::Exists(exists) => {
672 let mut sub_scope = parent_scope.branch(
673 exists.this.clone(),
674 ScopeType::Subquery,
675 );
676 build_scope_impl(&exists.this, &mut sub_scope);
677 parent_scope.subquery_scopes.push(sub_scope);
678 }
679 Expression::And(bin) | Expression::Or(bin) |
681 Expression::Add(bin) | Expression::Sub(bin) |
682 Expression::Mul(bin) | Expression::Div(bin) |
683 Expression::Mod(bin) | Expression::Eq(bin) |
684 Expression::Neq(bin) | Expression::Lt(bin) |
685 Expression::Lte(bin) | Expression::Gt(bin) |
686 Expression::Gte(bin) | Expression::BitwiseAnd(bin) |
687 Expression::BitwiseOr(bin) | Expression::BitwiseXor(bin) |
688 Expression::Concat(bin) => {
689 collect_subqueries_in_expr(&bin.left, parent_scope);
690 collect_subqueries_in_expr(&bin.right, parent_scope);
691 }
692 Expression::Like(like) | Expression::ILike(like) => {
694 collect_subqueries_in_expr(&like.left, parent_scope);
695 collect_subqueries_in_expr(&like.right, parent_scope);
696 if let Some(escape) = &like.escape {
697 collect_subqueries_in_expr(escape, parent_scope);
698 }
699 }
700 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
702 collect_subqueries_in_expr(&un.this, parent_scope);
703 }
704 Expression::Function(func) => {
705 for arg in &func.args {
706 collect_subqueries_in_expr(arg, parent_scope);
707 }
708 }
709 Expression::Case(case) => {
710 if let Some(operand) = &case.operand {
711 collect_subqueries_in_expr(operand, parent_scope);
712 }
713 for (when_expr, then_expr) in &case.whens {
714 collect_subqueries_in_expr(when_expr, parent_scope);
715 collect_subqueries_in_expr(then_expr, parent_scope);
716 }
717 if let Some(else_clause) = &case.else_ {
718 collect_subqueries_in_expr(else_clause, parent_scope);
719 }
720 }
721 Expression::Paren(paren) => {
722 collect_subqueries_in_expr(&paren.this, parent_scope);
723 }
724 Expression::Alias(alias) => {
725 collect_subqueries_in_expr(&alias.this, parent_scope);
726 }
727 _ => {}
728 }
729}
730
731pub fn walk_in_scope<'a>(
743 expression: &'a Expression,
744 bfs: bool,
745) -> impl Iterator<Item = &'a Expression> {
746 WalkInScopeIter::new(expression, bfs)
747}
748
/// Iterator state behind `walk_in_scope`.
struct WalkInScopeIter<'a> {
    /// Nodes still to be yielded; used as a FIFO queue for BFS and as a stack
    /// (popping from the back) for DFS.
    queue: VecDeque<&'a Expression>,
    /// `true` for breadth-first order, `false` for depth-first.
    bfs: bool,
}
754
755impl<'a> WalkInScopeIter<'a> {
756 fn new(expression: &'a Expression, bfs: bool) -> Self {
757 let mut queue = VecDeque::new();
758 queue.push_back(expression);
759 Self { queue, bfs }
760 }
761
762 fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
763 if is_root {
764 return false;
765 }
766
767 if matches!(expr, Expression::Cte(_)) {
769 return true;
770 }
771
772 if let Expression::Subquery(subquery) = expr {
774 if subquery.alias.is_some() {
775 return true;
776 }
777 }
778
779 if matches!(
781 expr,
782 Expression::Select(_)
783 | Expression::Union(_)
784 | Expression::Intersect(_)
785 | Expression::Except(_)
786 ) {
787 return true;
788 }
789
790 false
791 }
792
793 fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
794 let mut children = Vec::new();
795
796 match expr {
797 Expression::Select(select) => {
798 for e in &select.expressions {
800 children.push(e);
801 }
802 if let Some(from) = &select.from {
804 for table in &from.expressions {
805 if !self.should_stop_at(table, false) {
806 children.push(table);
807 }
808 }
809 }
810 for join in &select.joins {
812 if let Some(on) = &join.on {
813 children.push(on);
814 }
815 }
817 if let Some(where_clause) = &select.where_clause {
819 children.push(&where_clause.this);
820 }
821 if let Some(group_by) = &select.group_by {
823 for e in &group_by.expressions {
824 children.push(e);
825 }
826 }
827 if let Some(having) = &select.having {
829 children.push(&having.this);
830 }
831 if let Some(order_by) = &select.order_by {
833 for ord in &order_by.expressions {
834 children.push(&ord.this);
835 }
836 }
837 if let Some(limit) = &select.limit {
839 children.push(&limit.this);
840 }
841 if let Some(offset) = &select.offset {
843 children.push(&offset.this);
844 }
845 }
846 Expression::And(bin)
847 | Expression::Or(bin)
848 | Expression::Add(bin)
849 | Expression::Sub(bin)
850 | Expression::Mul(bin)
851 | Expression::Div(bin)
852 | Expression::Mod(bin)
853 | Expression::Eq(bin)
854 | Expression::Neq(bin)
855 | Expression::Lt(bin)
856 | Expression::Lte(bin)
857 | Expression::Gt(bin)
858 | Expression::Gte(bin)
859 | Expression::BitwiseAnd(bin)
860 | Expression::BitwiseOr(bin)
861 | Expression::BitwiseXor(bin)
862 | Expression::Concat(bin) => {
863 children.push(&bin.left);
864 children.push(&bin.right);
865 }
866 Expression::Like(like) | Expression::ILike(like) => {
867 children.push(&like.left);
868 children.push(&like.right);
869 if let Some(escape) = &like.escape {
870 children.push(escape);
871 }
872 }
873 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
874 children.push(&un.this);
875 }
876 Expression::Function(func) => {
877 for arg in &func.args {
878 children.push(arg);
879 }
880 }
881 Expression::AggregateFunction(agg) => {
882 for arg in &agg.args {
883 children.push(arg);
884 }
885 }
886 Expression::WindowFunction(wf) => {
887 children.push(&wf.this);
888 for e in &wf.over.partition_by {
889 children.push(e);
890 }
891 for e in &wf.over.order_by {
892 children.push(&e.this);
893 }
894 }
895 Expression::Alias(alias) => {
896 children.push(&alias.this);
897 }
898 Expression::Case(case) => {
899 if let Some(operand) = &case.operand {
900 children.push(operand);
901 }
902 for (when_expr, then_expr) in &case.whens {
903 children.push(when_expr);
904 children.push(then_expr);
905 }
906 if let Some(else_clause) = &case.else_ {
907 children.push(else_clause);
908 }
909 }
910 Expression::Paren(paren) => {
911 children.push(&paren.this);
912 }
913 Expression::Ordered(ord) => {
914 children.push(&ord.this);
915 }
916 Expression::In(in_expr) => {
917 children.push(&in_expr.this);
918 for e in &in_expr.expressions {
919 children.push(e);
920 }
921 }
923 Expression::Between(between) => {
924 children.push(&between.this);
925 children.push(&between.low);
926 children.push(&between.high);
927 }
928 Expression::IsNull(is_null) => {
929 children.push(&is_null.this);
930 }
931 Expression::Cast(cast) => {
932 children.push(&cast.this);
933 }
934 Expression::Extract(extract) => {
935 children.push(&extract.this);
936 }
937 Expression::Coalesce(coalesce) => {
938 for e in &coalesce.expressions {
939 children.push(e);
940 }
941 }
942 Expression::NullIf(nullif) => {
943 children.push(&nullif.this);
944 children.push(&nullif.expression);
945 }
946 Expression::Table(_table) => {
947 }
950 Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
951 }
953 Expression::Subquery(_) | Expression::Exists(_) => {}
955 _ => {
956 }
958 }
959
960 children
961 }
962}
963
964impl<'a> Iterator for WalkInScopeIter<'a> {
965 type Item = &'a Expression;
966
967 fn next(&mut self) -> Option<Self::Item> {
968 let expr = if self.bfs {
969 self.queue.pop_front()?
970 } else {
971 self.queue.pop_back()?
972 };
973
974 let children = self.get_children(expr);
976
977 if self.bfs {
978 for child in children {
979 if !self.should_stop_at(child, false) {
980 self.queue.push_back(child);
981 }
982 }
983 } else {
984 for child in children.into_iter().rev() {
985 if !self.should_stop_at(child, false) {
986 self.queue.push_back(child);
987 }
988 }
989 }
990
991 Some(expr)
992 }
993}
994
995pub fn find_in_scope<'a, F>(expression: &'a Expression, predicate: F, bfs: bool) -> Option<&'a Expression>
1007where
1008 F: Fn(&Expression) -> bool,
1009{
1010 walk_in_scope(expression, bfs).find(|e| predicate(e))
1011}
1012
1013pub fn find_all_in_scope<'a, F>(expression: &'a Expression, predicate: F, bfs: bool) -> Vec<&'a Expression>
1025where
1026 F: Fn(&Expression) -> bool,
1027{
1028 walk_in_scope(expression, bfs).filter(|e| predicate(e)).collect()
1029}
1030
1031pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1041 match expression {
1042 Expression::Select(_)
1043 | Expression::Union(_)
1044 | Expression::Intersect(_)
1045 | Expression::Except(_) => {
1046 let root = build_scope(expression);
1047 root.traverse().into_iter().cloned().collect()
1048 }
1049 _ => Vec::new(),
1050 }
1051}
1052
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and builds the scope tree of its first statement.
    fn parse_and_build_scope(sql: &str) -> Scope {
        let ast = Parser::parse_sql(sql).expect("Failed to parse SQL");
        build_scope(&ast[0])
    }

    // A plain SELECT yields an uncorrelated root scope with one source and
    // one column per projected column.
    #[test]
    fn test_simple_select_scope() {
        let mut scope = parse_and_build_scope("SELECT a, b FROM t");

        assert!(scope.is_root());
        assert!(!scope.can_be_correlated);
        assert!(scope.sources.contains_key("t"));

        let columns = scope.columns();
        assert_eq!(columns.len(), 2);
    }

    // A subquery in FROM is registered under its alias and gets its own
    // derived-table scope.
    #[test]
    fn test_derived_table_scope() {
        let mut scope = parse_and_build_scope(
            "SELECT x.a FROM (SELECT a FROM t) AS x"
        );

        assert!(scope.sources.contains_key("x"));
        assert_eq!(scope.derived_table_scopes.len(), 1);

        let derived = &mut scope.derived_table_scopes[0];
        assert!(derived.is_derived_table());
        assert!(derived.sources.contains_key("t"));
    }

    // An EXISTS subquery may be correlated, but is not when it only uses its
    // own sources.
    #[test]
    fn test_non_correlated_subquery() {
        let mut scope = parse_and_build_scope(
            "SELECT * FROM t WHERE EXISTS (SELECT b FROM s)"
        );

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        assert!(subquery.sources.contains_key("s"));
        assert!(!subquery.is_correlated_subquery());
    }

    // Referencing the outer table `t` inside the subquery makes `t.y` an
    // external column and the subquery correlated.
    #[test]
    fn test_correlated_subquery() {
        let mut scope = parse_and_build_scope(
            "SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)"
        );

        assert_eq!(scope.subquery_scopes.len(), 1);

        let subquery = &mut scope.subquery_scopes[0];
        assert!(subquery.is_subquery());
        assert!(subquery.can_be_correlated);

        let external = subquery.external_columns();
        assert!(!external.is_empty());
        assert!(external.iter().any(|c| c.table.as_deref() == Some("t")));
        assert!(subquery.is_correlated_subquery());
    }

    // A WITH entry produces a CTE scope and a `cte_sources` entry.
    #[test]
    fn test_cte_scope() {
        let scope = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte"
        );

        assert_eq!(scope.cte_scopes.len(), 1);
        assert!(scope.cte_sources.contains_key("cte"));

        let cte = &scope.cte_scopes[0];
        assert!(cte.is_cte());
    }

    // FROM and JOIN tables are both registered as sources.
    #[test]
    fn test_multiple_sources() {
        let scope = parse_and_build_scope(
            "SELECT t.a, s.b FROM t JOIN s ON t.id = s.id"
        );

        assert!(scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("s"));
        assert_eq!(scope.sources.len(), 2);
    }

    // An aliased table is keyed by its alias only.
    #[test]
    fn test_aliased_table() {
        let scope = parse_and_build_scope("SELECT x.a FROM t AS x");

        assert!(scope.sources.contains_key("x"));
        assert!(!scope.sources.contains_key("t"));
    }

    // Only the three projected columns are counted: JOIN conditions are not
    // visited when collecting columns.
    #[test]
    fn test_local_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let local = scope.local_columns();
        assert_eq!(local.len(), 3);
        assert!(local.iter().all(|c| c.table.is_some()));
    }

    // `a` and `b` are unqualified, `t.c` is not.
    #[test]
    fn test_unqualified_columns() {
        let mut scope = parse_and_build_scope("SELECT a, b, t.c FROM t");

        let unqualified = scope.unqualified_columns();
        assert_eq!(unqualified.len(), 2);
        assert!(unqualified.iter().all(|c| c.table.is_none()));
    }

    // Columns can be filtered by their table qualifier.
    #[test]
    fn test_source_columns() {
        let mut scope = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let t_cols = scope.source_columns("t");
        assert!(t_cols.len() >= 2);
        assert!(t_cols.iter().all(|c| c.table.as_deref() == Some("t")));

        let s_cols = scope.source_columns("s");
        assert!(s_cols.len() >= 1);
        assert!(s_cols.iter().all(|c| c.table.as_deref() == Some("s")));
    }

    // Renaming moves the entry to the new key.
    #[test]
    fn test_rename_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.rename_source("t", "new_name".to_string());
        assert!(!scope.sources.contains_key("t"));
        assert!(scope.sources.contains_key("new_name"));
    }

    // Removing drops the entry from `sources`.
    #[test]
    fn test_remove_source() {
        let mut scope = parse_and_build_scope("SELECT a FROM t");

        assert!(scope.sources.contains_key("t"));
        scope.remove_source("t");
        assert!(!scope.sources.contains_key("t"));
    }

    // A breadth-first walk yields the root SELECT and the columns inside it.
    #[test]
    fn test_walk_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let expr = &ast[0];

        let walked: Vec<_> = walk_in_scope(expr, true).collect();
        assert!(!walked.is_empty());

        assert!(walked.iter().any(|e| matches!(e, Expression::Select(_))));
        assert!(walked.iter().any(|e| matches!(e, Expression::Column(_))));
    }

    // `find_in_scope` returns the first node matching the predicate.
    #[test]
    fn test_find_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let expr = &ast[0];

        let found = find_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
        assert!(found.is_some());
        assert!(matches!(found.unwrap(), Expression::Column(_)));
    }

    // `find_all_in_scope` returns every matching node.
    #[test]
    fn test_find_all_in_scope() {
        let ast = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
        let expr = &ast[0];

        let found = find_all_in_scope(expr, |e| matches!(e, Expression::Column(_)), true);
        assert_eq!(found.len(), 3);
    }

    // The flattened scope list is never empty for a query and contains the root.
    #[test]
    fn test_traverse_scope() {
        let ast =
            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
        let expr = &ast[0];

        let scopes = traverse_scope(expr);
        assert!(!scopes.is_empty());
        assert!(scopes.iter().any(|s| s.is_root()));
    }

    // Outer columns are passed through, and a `Subquery` branch can be correlated.
    #[test]
    fn test_branch_with_options() {
        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let scope = build_scope(&ast[0]);

        let child = scope.branch_with_options(
            ast[0].clone(),
            ScopeType::Subquery,
            None,
            None,
            Some(vec!["col1".to_string(), "col2".to_string()]),
        );

        assert_eq!(child.outer_columns, vec!["col1", "col2"]);
        assert!(child.can_be_correlated);
    }

    // Only scopes branched with `ScopeType::Udtf` report `is_udtf()`.
    #[test]
    fn test_is_udtf() {
        let ast = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let scope = Scope::new(ast[0].clone());
        assert!(!scope.is_udtf());

        let root = build_scope(&ast[0]);
        let udtf_scope = root.branch(ast[0].clone(), ScopeType::Udtf);
        assert!(udtf_scope.is_udtf());
    }

    // A UNION keeps a root scope with one set-operation child per operand.
    #[test]
    fn test_is_union() {
        let scope = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(scope.is_root());
        assert_eq!(scope.union_scopes.len(), 2);
        assert!(scope.union_scopes[0].is_union());
        assert!(scope.union_scopes[1].is_union());
    }

    // `columns()` fills the cache and `clear_cache()` empties both caches.
    #[test]
    fn test_clear_cache() {
        let mut scope = parse_and_build_scope("SELECT t.a FROM t");

        let _ = scope.columns();
        assert!(scope.columns_cache.is_some());

        scope.clear_cache();
        assert!(scope.columns_cache.is_none());
        assert!(scope.external_columns_cache.is_none());
    }

    // The traversal covers at least the CTE scope, the EXISTS subquery scope
    // and the root.
    #[test]
    fn test_scope_traverse() {
        let scope = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        let traversed = scope.traverse();
        assert!(traversed.len() >= 3);
    }
}