1use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
/// Classifies a [`Scope`] by the syntactic construct it was created for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "bindings", derive(TS))]
#[cfg_attr(feature = "bindings", ts(export))]
pub enum ScopeType {
    /// The outermost scope, as created by `Scope::new`.
    Root,
    /// A subquery used inside an expression (unaliased subquery, `IN` query, `EXISTS`).
    Subquery,
    /// A subquery used as a table source in `FROM` / `JOIN`.
    DerivedTable,
    /// The body of a common table expression from a `WITH` clause.
    Cte,
    /// One operand of a `UNION`, `INTERSECT` or `EXCEPT`.
    SetOperation,
    /// Never constructed by the builder in this file; branching into it marks the
    /// scope as possibly correlated. Presumably a user-defined table function
    /// scope -- confirm with callers.
    Udtf,
}
32
/// A named source registered in a [`Scope`].
#[derive(Debug, Clone)]
pub struct SourceInfo {
    /// The expression that introduced the source (a table, subquery or CTE node).
    pub expression: Expression,
    /// `true` when the source is itself a query with its own scope (derived table
    /// or CTE); `false` for plain tables.
    pub is_scope: bool,
}
41
/// A column reference found in a scope's expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table qualifier of the reference, or `None` when the column is unqualified.
    pub table: Option<String>,
    /// The column name.
    pub name: String,
}
50
/// A query scope: the expression it covers, the sources visible in it and the
/// child scopes nested inside it.
#[derive(Debug, Clone)]
pub struct Scope {
    /// The expression this scope was created for.
    pub expression: Expression,

    /// What kind of construct this scope represents.
    pub scope_type: ScopeType,

    /// Sources visible in this scope, keyed by alias (or table name when unaliased).
    pub sources: HashMap<String, SourceInfo>,

    /// Sources registered through `add_lateral_source`; each is also present in `sources`.
    pub lateral_sources: HashMap<String, SourceInfo>,

    /// CTE sources. Child scopes created with `branch` receive a copy of this map.
    pub cte_sources: HashMap<String, SourceInfo>,

    /// Column names supplied through `branch_with_options`; empty otherwise.
    /// Not read anywhere in this file -- presumably output column aliases; confirm.
    pub outer_columns: Vec<String>,

    /// `true` when this scope is, or was branched from, a `Subquery` or `Udtf`
    /// scope and may therefore reference columns of an enclosing query.
    pub can_be_correlated: bool,

    /// Scopes of subqueries found in expressions (`WHERE`, select list, `HAVING`).
    pub subquery_scopes: Vec<Scope>,

    /// Scopes of subqueries used as table sources in `FROM` / `JOIN`.
    pub derived_table_scopes: Vec<Scope>,

    /// Scopes of the CTEs declared in this scope's `WITH` clause.
    pub cte_scopes: Vec<Scope>,

    /// UDTF scopes. Not populated by the builder in this file.
    pub udtf_scopes: Vec<Scope>,

    /// Table scopes. Not populated by the builder in this file.
    pub table_scopes: Vec<Scope>,

    /// Operand scopes of a set operation, pushed left operand first.
    pub union_scopes: Vec<Scope>,

    /// Lazily computed result of `columns`; reset by `clear_cache`.
    columns_cache: Option<Vec<ColumnRef>>,

    /// Lazily computed result of `external_columns`; reset by `clear_cache`.
    external_columns_cache: Option<Vec<ColumnRef>>,
}
104
105impl Scope {
106 pub fn new(expression: Expression) -> Self {
108 Self {
109 expression,
110 scope_type: ScopeType::Root,
111 sources: HashMap::new(),
112 lateral_sources: HashMap::new(),
113 cte_sources: HashMap::new(),
114 outer_columns: Vec::new(),
115 can_be_correlated: false,
116 subquery_scopes: Vec::new(),
117 derived_table_scopes: Vec::new(),
118 cte_scopes: Vec::new(),
119 udtf_scopes: Vec::new(),
120 table_scopes: Vec::new(),
121 union_scopes: Vec::new(),
122 columns_cache: None,
123 external_columns_cache: None,
124 }
125 }
126
127 pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
129 self.branch_with_options(expression, scope_type, None, None, None)
130 }
131
132 pub fn branch_with_options(
134 &self,
135 expression: Expression,
136 scope_type: ScopeType,
137 sources: Option<HashMap<String, SourceInfo>>,
138 lateral_sources: Option<HashMap<String, SourceInfo>>,
139 outer_columns: Option<Vec<String>>,
140 ) -> Self {
141 let can_be_correlated = self.can_be_correlated
142 || scope_type == ScopeType::Subquery
143 || scope_type == ScopeType::Udtf;
144
145 Self {
146 expression,
147 scope_type,
148 sources: sources.unwrap_or_default(),
149 lateral_sources: lateral_sources.unwrap_or_default(),
150 cte_sources: self.cte_sources.clone(),
151 outer_columns: outer_columns.unwrap_or_default(),
152 can_be_correlated,
153 subquery_scopes: Vec::new(),
154 derived_table_scopes: Vec::new(),
155 cte_scopes: Vec::new(),
156 udtf_scopes: Vec::new(),
157 table_scopes: Vec::new(),
158 union_scopes: Vec::new(),
159 columns_cache: None,
160 external_columns_cache: None,
161 }
162 }
163
164 pub fn clear_cache(&mut self) {
166 self.columns_cache = None;
167 self.external_columns_cache = None;
168 }
169
170 pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
172 self.sources.insert(
173 name,
174 SourceInfo {
175 expression,
176 is_scope,
177 },
178 );
179 self.clear_cache();
180 }
181
182 pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
184 self.lateral_sources.insert(
185 name.clone(),
186 SourceInfo {
187 expression: expression.clone(),
188 is_scope,
189 },
190 );
191 self.sources.insert(
192 name,
193 SourceInfo {
194 expression,
195 is_scope,
196 },
197 );
198 self.clear_cache();
199 }
200
201 pub fn add_cte_source(&mut self, name: String, expression: Expression) {
203 self.cte_sources.insert(
204 name.clone(),
205 SourceInfo {
206 expression: expression.clone(),
207 is_scope: true,
208 },
209 );
210 self.sources.insert(
211 name,
212 SourceInfo {
213 expression,
214 is_scope: true,
215 },
216 );
217 self.clear_cache();
218 }
219
220 pub fn rename_source(&mut self, old_name: &str, new_name: String) {
222 if let Some(source) = self.sources.remove(old_name) {
223 self.sources.insert(new_name, source);
224 }
225 self.clear_cache();
226 }
227
228 pub fn remove_source(&mut self, name: &str) {
230 self.sources.remove(name);
231 self.clear_cache();
232 }
233
234 pub fn columns(&mut self) -> &[ColumnRef] {
236 if self.columns_cache.is_none() {
237 let mut columns = Vec::new();
238 collect_columns(&self.expression, &mut columns);
239 self.columns_cache = Some(columns);
240 }
241 self.columns_cache.as_ref().unwrap()
242 }
243
244 pub fn source_names(&self) -> HashSet<String> {
246 let mut names: HashSet<String> = self.sources.keys().cloned().collect();
247 names.extend(self.cte_sources.keys().cloned());
248 names
249 }
250
251 pub fn external_columns(&mut self) -> Vec<ColumnRef> {
253 if self.external_columns_cache.is_some() {
254 return self.external_columns_cache.clone().unwrap();
255 }
256
257 let source_names = self.source_names();
258 let columns = self.columns().to_vec();
259
260 let external: Vec<ColumnRef> = columns
261 .into_iter()
262 .filter(|col| {
263 match &col.table {
265 Some(table) => !source_names.contains(table),
266 None => false, }
268 })
269 .collect();
270
271 self.external_columns_cache = Some(external.clone());
272 external
273 }
274
275 pub fn local_columns(&mut self) -> Vec<ColumnRef> {
277 let external_set: HashSet<_> = self.external_columns().into_iter().collect();
278 let columns = self.columns().to_vec();
279
280 columns
281 .into_iter()
282 .filter(|col| !external_set.contains(col))
283 .collect()
284 }
285
286 pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
288 self.columns()
289 .iter()
290 .filter(|c| c.table.is_none())
291 .cloned()
292 .collect()
293 }
294
295 pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
297 self.columns()
298 .iter()
299 .filter(|col| col.table.as_deref() == Some(source_name))
300 .cloned()
301 .collect()
302 }
303
304 pub fn is_correlated_subquery(&mut self) -> bool {
310 self.can_be_correlated && !self.external_columns().is_empty()
311 }
312
313 pub fn is_subquery(&self) -> bool {
315 self.scope_type == ScopeType::Subquery
316 }
317
318 pub fn is_derived_table(&self) -> bool {
320 self.scope_type == ScopeType::DerivedTable
321 }
322
323 pub fn is_cte(&self) -> bool {
325 self.scope_type == ScopeType::Cte
326 }
327
328 pub fn is_root(&self) -> bool {
330 self.scope_type == ScopeType::Root
331 }
332
333 pub fn is_udtf(&self) -> bool {
335 self.scope_type == ScopeType::Udtf
336 }
337
338 pub fn is_union(&self) -> bool {
340 self.scope_type == ScopeType::SetOperation
341 }
342
343 pub fn traverse(&self) -> Vec<&Scope> {
345 let mut result = Vec::new();
346 self.traverse_impl(&mut result);
347 result
348 }
349
350 fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
351 for scope in &self.cte_scopes {
353 scope.traverse_impl(result);
354 }
355 for scope in &self.union_scopes {
356 scope.traverse_impl(result);
357 }
358 for scope in &self.table_scopes {
359 scope.traverse_impl(result);
360 }
361 for scope in &self.subquery_scopes {
362 scope.traverse_impl(result);
363 }
364 result.push(self);
366 }
367
368 pub fn ref_count(&self) -> HashMap<usize, usize> {
370 let mut counts: HashMap<usize, usize> = HashMap::new();
371
372 for scope in self.traverse() {
373 for (_, source_info) in scope.sources.iter() {
374 if source_info.is_scope {
375 let id = &source_info.expression as *const _ as usize;
376 *counts.entry(id).or_insert(0) += 1;
377 }
378 }
379 }
380
381 counts
382 }
383}
384
385fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
387 match expr {
388 Expression::Column(col) => {
389 columns.push(ColumnRef {
390 table: col.table.as_ref().map(|t| t.name.clone()),
391 name: col.name.name.clone(),
392 });
393 }
394 Expression::Select(select) => {
395 for e in &select.expressions {
397 collect_columns(e, columns);
398 }
399 if let Some(where_clause) = &select.where_clause {
401 collect_columns(&where_clause.this, columns);
402 }
403 if let Some(having) = &select.having {
405 collect_columns(&having.this, columns);
406 }
407 if let Some(order_by) = &select.order_by {
409 for ord in &order_by.expressions {
410 collect_columns(&ord.this, columns);
411 }
412 }
413 if let Some(group_by) = &select.group_by {
415 for e in &group_by.expressions {
416 collect_columns(e, columns);
417 }
418 }
419 }
422 Expression::And(bin)
424 | Expression::Or(bin)
425 | Expression::Add(bin)
426 | Expression::Sub(bin)
427 | Expression::Mul(bin)
428 | Expression::Div(bin)
429 | Expression::Mod(bin)
430 | Expression::Eq(bin)
431 | Expression::Neq(bin)
432 | Expression::Lt(bin)
433 | Expression::Lte(bin)
434 | Expression::Gt(bin)
435 | Expression::Gte(bin)
436 | Expression::BitwiseAnd(bin)
437 | Expression::BitwiseOr(bin)
438 | Expression::BitwiseXor(bin)
439 | Expression::Concat(bin) => {
440 collect_columns(&bin.left, columns);
441 collect_columns(&bin.right, columns);
442 }
443 Expression::Like(like) | Expression::ILike(like) => {
445 collect_columns(&like.left, columns);
446 collect_columns(&like.right, columns);
447 if let Some(escape) = &like.escape {
448 collect_columns(escape, columns);
449 }
450 }
451 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
453 collect_columns(&un.this, columns);
454 }
455 Expression::Function(func) => {
456 for arg in &func.args {
457 collect_columns(arg, columns);
458 }
459 }
460 Expression::AggregateFunction(agg) => {
461 for arg in &agg.args {
462 collect_columns(arg, columns);
463 }
464 }
465 Expression::WindowFunction(wf) => {
466 collect_columns(&wf.this, columns);
467 for e in &wf.over.partition_by {
468 collect_columns(e, columns);
469 }
470 for e in &wf.over.order_by {
471 collect_columns(&e.this, columns);
472 }
473 }
474 Expression::Alias(alias) => {
475 collect_columns(&alias.this, columns);
476 }
477 Expression::Case(case) => {
478 if let Some(operand) = &case.operand {
479 collect_columns(operand, columns);
480 }
481 for (when_expr, then_expr) in &case.whens {
482 collect_columns(when_expr, columns);
483 collect_columns(then_expr, columns);
484 }
485 if let Some(else_clause) = &case.else_ {
486 collect_columns(else_clause, columns);
487 }
488 }
489 Expression::Paren(paren) => {
490 collect_columns(&paren.this, columns);
491 }
492 Expression::Ordered(ord) => {
493 collect_columns(&ord.this, columns);
494 }
495 Expression::In(in_expr) => {
496 collect_columns(&in_expr.this, columns);
497 for e in &in_expr.expressions {
498 collect_columns(e, columns);
499 }
500 }
502 Expression::Between(between) => {
503 collect_columns(&between.this, columns);
504 collect_columns(&between.low, columns);
505 collect_columns(&between.high, columns);
506 }
507 Expression::IsNull(is_null) => {
508 collect_columns(&is_null.this, columns);
509 }
510 Expression::Cast(cast) => {
511 collect_columns(&cast.this, columns);
512 }
513 Expression::Extract(extract) => {
514 collect_columns(&extract.this, columns);
515 }
516 Expression::Exists(_) | Expression::Subquery(_) => {
517 }
519 _ => {
520 }
522 }
523}
524
/// Builds the scope tree for `expression` and returns its root scope.
///
/// The root scope stores a clone of `expression`. Only `SELECT`, `UNION`,
/// `INTERSECT` and `EXCEPT` expressions produce sources or child scopes; any
/// other expression yields a bare root scope.
pub fn build_scope(expression: &Expression) -> Scope {
    let mut root = Scope::new(expression.clone());
    build_scope_impl(expression, &mut root);
    root
}
534
535fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
536 match expression {
537 Expression::Select(select) => {
538 if let Some(with) = &select.with {
540 for cte in &with.ctes {
541 let cte_name = cte.alias.name.clone();
542 let mut cte_scope = current_scope
543 .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
544 build_scope_impl(&cte.this, &mut cte_scope);
545 current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
546 current_scope.cte_scopes.push(cte_scope);
547 }
548 }
549
550 if let Some(from) = &select.from {
552 for table in &from.expressions {
553 add_table_to_scope(table, current_scope);
554 }
555 }
556
557 for join in &select.joins {
559 add_table_to_scope(&join.this, current_scope);
560 }
561
562 collect_subqueries(expression, current_scope);
564 }
565 Expression::Union(union) => {
566 let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
567 build_scope_impl(&union.left, &mut left_scope);
568
569 let mut right_scope =
570 current_scope.branch(union.right.clone(), ScopeType::SetOperation);
571 build_scope_impl(&union.right, &mut right_scope);
572
573 current_scope.union_scopes.push(left_scope);
574 current_scope.union_scopes.push(right_scope);
575 }
576 Expression::Intersect(intersect) => {
577 let mut left_scope =
578 current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
579 build_scope_impl(&intersect.left, &mut left_scope);
580
581 let mut right_scope =
582 current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
583 build_scope_impl(&intersect.right, &mut right_scope);
584
585 current_scope.union_scopes.push(left_scope);
586 current_scope.union_scopes.push(right_scope);
587 }
588 Expression::Except(except) => {
589 let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
590 build_scope_impl(&except.left, &mut left_scope);
591
592 let mut right_scope =
593 current_scope.branch(except.right.clone(), ScopeType::SetOperation);
594 build_scope_impl(&except.right, &mut right_scope);
595
596 current_scope.union_scopes.push(left_scope);
597 current_scope.union_scopes.push(right_scope);
598 }
599 _ => {}
600 }
601}
602
603fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
604 match expr {
605 Expression::Table(table) => {
606 let name = table
607 .alias
608 .as_ref()
609 .map(|a| a.name.clone())
610 .unwrap_or_else(|| table.name.name.clone());
611 scope.add_source(name, expr.clone(), false);
612 }
613 Expression::Subquery(subquery) => {
614 let name = subquery
615 .alias
616 .as_ref()
617 .map(|a| a.name.clone())
618 .unwrap_or_default();
619
620 let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
621 build_scope_impl(&subquery.this, &mut derived_scope);
622
623 scope.add_source(name.clone(), expr.clone(), true);
624 scope.derived_table_scopes.push(derived_scope);
625 }
626 Expression::Paren(paren) => {
627 add_table_to_scope(&paren.this, scope);
628 }
629 _ => {}
630 }
631}
632
633fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
634 match expr {
635 Expression::Select(select) => {
636 if let Some(where_clause) = &select.where_clause {
638 collect_subqueries_in_expr(&where_clause.this, parent_scope);
639 }
640 for e in &select.expressions {
642 collect_subqueries_in_expr(e, parent_scope);
643 }
644 if let Some(having) = &select.having {
646 collect_subqueries_in_expr(&having.this, parent_scope);
647 }
648 }
649 _ => {}
650 }
651}
652
653fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
654 match expr {
655 Expression::Subquery(subquery) if subquery.alias.is_none() => {
656 let mut sub_scope = parent_scope.branch(subquery.this.clone(), ScopeType::Subquery);
658 build_scope_impl(&subquery.this, &mut sub_scope);
659 parent_scope.subquery_scopes.push(sub_scope);
660 }
661 Expression::In(in_expr) => {
662 collect_subqueries_in_expr(&in_expr.this, parent_scope);
663 if let Some(query) = &in_expr.query {
664 let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
665 build_scope_impl(query, &mut sub_scope);
666 parent_scope.subquery_scopes.push(sub_scope);
667 }
668 }
669 Expression::Exists(exists) => {
670 let mut sub_scope = parent_scope.branch(exists.this.clone(), ScopeType::Subquery);
671 build_scope_impl(&exists.this, &mut sub_scope);
672 parent_scope.subquery_scopes.push(sub_scope);
673 }
674 Expression::And(bin)
676 | Expression::Or(bin)
677 | Expression::Add(bin)
678 | Expression::Sub(bin)
679 | Expression::Mul(bin)
680 | Expression::Div(bin)
681 | Expression::Mod(bin)
682 | Expression::Eq(bin)
683 | Expression::Neq(bin)
684 | Expression::Lt(bin)
685 | Expression::Lte(bin)
686 | Expression::Gt(bin)
687 | Expression::Gte(bin)
688 | Expression::BitwiseAnd(bin)
689 | Expression::BitwiseOr(bin)
690 | Expression::BitwiseXor(bin)
691 | Expression::Concat(bin) => {
692 collect_subqueries_in_expr(&bin.left, parent_scope);
693 collect_subqueries_in_expr(&bin.right, parent_scope);
694 }
695 Expression::Like(like) | Expression::ILike(like) => {
697 collect_subqueries_in_expr(&like.left, parent_scope);
698 collect_subqueries_in_expr(&like.right, parent_scope);
699 if let Some(escape) = &like.escape {
700 collect_subqueries_in_expr(escape, parent_scope);
701 }
702 }
703 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
705 collect_subqueries_in_expr(&un.this, parent_scope);
706 }
707 Expression::Function(func) => {
708 for arg in &func.args {
709 collect_subqueries_in_expr(arg, parent_scope);
710 }
711 }
712 Expression::Case(case) => {
713 if let Some(operand) = &case.operand {
714 collect_subqueries_in_expr(operand, parent_scope);
715 }
716 for (when_expr, then_expr) in &case.whens {
717 collect_subqueries_in_expr(when_expr, parent_scope);
718 collect_subqueries_in_expr(then_expr, parent_scope);
719 }
720 if let Some(else_clause) = &case.else_ {
721 collect_subqueries_in_expr(else_clause, parent_scope);
722 }
723 }
724 Expression::Paren(paren) => {
725 collect_subqueries_in_expr(&paren.this, parent_scope);
726 }
727 Expression::Alias(alias) => {
728 collect_subqueries_in_expr(&alias.this, parent_scope);
729 }
730 _ => {}
731 }
732}
733
/// Iterates over `expression` and the nodes below it that belong to the same
/// scope.
///
/// The starting expression is always yielded. Nested `SELECT`s, set
/// operations, CTEs and aliased subqueries are neither yielded nor entered;
/// unaliased subquery and `EXISTS` nodes are yielded but not entered.
///
/// `bfs` selects breadth-first order; otherwise the walk is depth-first,
/// visiting a node before its children.
pub fn walk_in_scope<'a>(
    expression: &'a Expression,
    bfs: bool,
) -> impl Iterator<Item = &'a Expression> {
    WalkInScopeIter::new(expression, bfs)
}
751
/// Iterator state behind `walk_in_scope`.
struct WalkInScopeIter<'a> {
    /// Nodes waiting to be yielded; popped from the front for BFS and from the
    /// back for DFS.
    queue: VecDeque<&'a Expression>,
    /// `true` for breadth-first order, `false` for depth-first.
    bfs: bool,
}
757
758impl<'a> WalkInScopeIter<'a> {
759 fn new(expression: &'a Expression, bfs: bool) -> Self {
760 let mut queue = VecDeque::new();
761 queue.push_back(expression);
762 Self { queue, bfs }
763 }
764
765 fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
766 if is_root {
767 return false;
768 }
769
770 if matches!(expr, Expression::Cte(_)) {
772 return true;
773 }
774
775 if let Expression::Subquery(subquery) = expr {
777 if subquery.alias.is_some() {
778 return true;
779 }
780 }
781
782 if matches!(
784 expr,
785 Expression::Select(_)
786 | Expression::Union(_)
787 | Expression::Intersect(_)
788 | Expression::Except(_)
789 ) {
790 return true;
791 }
792
793 false
794 }
795
796 fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
797 let mut children = Vec::new();
798
799 match expr {
800 Expression::Select(select) => {
801 for e in &select.expressions {
803 children.push(e);
804 }
805 if let Some(from) = &select.from {
807 for table in &from.expressions {
808 if !self.should_stop_at(table, false) {
809 children.push(table);
810 }
811 }
812 }
813 for join in &select.joins {
815 if let Some(on) = &join.on {
816 children.push(on);
817 }
818 }
820 if let Some(where_clause) = &select.where_clause {
822 children.push(&where_clause.this);
823 }
824 if let Some(group_by) = &select.group_by {
826 for e in &group_by.expressions {
827 children.push(e);
828 }
829 }
830 if let Some(having) = &select.having {
832 children.push(&having.this);
833 }
834 if let Some(order_by) = &select.order_by {
836 for ord in &order_by.expressions {
837 children.push(&ord.this);
838 }
839 }
840 if let Some(limit) = &select.limit {
842 children.push(&limit.this);
843 }
844 if let Some(offset) = &select.offset {
846 children.push(&offset.this);
847 }
848 }
849 Expression::And(bin)
850 | Expression::Or(bin)
851 | Expression::Add(bin)
852 | Expression::Sub(bin)
853 | Expression::Mul(bin)
854 | Expression::Div(bin)
855 | Expression::Mod(bin)
856 | Expression::Eq(bin)
857 | Expression::Neq(bin)
858 | Expression::Lt(bin)
859 | Expression::Lte(bin)
860 | Expression::Gt(bin)
861 | Expression::Gte(bin)
862 | Expression::BitwiseAnd(bin)
863 | Expression::BitwiseOr(bin)
864 | Expression::BitwiseXor(bin)
865 | Expression::Concat(bin) => {
866 children.push(&bin.left);
867 children.push(&bin.right);
868 }
869 Expression::Like(like) | Expression::ILike(like) => {
870 children.push(&like.left);
871 children.push(&like.right);
872 if let Some(escape) = &like.escape {
873 children.push(escape);
874 }
875 }
876 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
877 children.push(&un.this);
878 }
879 Expression::Function(func) => {
880 for arg in &func.args {
881 children.push(arg);
882 }
883 }
884 Expression::AggregateFunction(agg) => {
885 for arg in &agg.args {
886 children.push(arg);
887 }
888 }
889 Expression::WindowFunction(wf) => {
890 children.push(&wf.this);
891 for e in &wf.over.partition_by {
892 children.push(e);
893 }
894 for e in &wf.over.order_by {
895 children.push(&e.this);
896 }
897 }
898 Expression::Alias(alias) => {
899 children.push(&alias.this);
900 }
901 Expression::Case(case) => {
902 if let Some(operand) = &case.operand {
903 children.push(operand);
904 }
905 for (when_expr, then_expr) in &case.whens {
906 children.push(when_expr);
907 children.push(then_expr);
908 }
909 if let Some(else_clause) = &case.else_ {
910 children.push(else_clause);
911 }
912 }
913 Expression::Paren(paren) => {
914 children.push(&paren.this);
915 }
916 Expression::Ordered(ord) => {
917 children.push(&ord.this);
918 }
919 Expression::In(in_expr) => {
920 children.push(&in_expr.this);
921 for e in &in_expr.expressions {
922 children.push(e);
923 }
924 }
926 Expression::Between(between) => {
927 children.push(&between.this);
928 children.push(&between.low);
929 children.push(&between.high);
930 }
931 Expression::IsNull(is_null) => {
932 children.push(&is_null.this);
933 }
934 Expression::Cast(cast) => {
935 children.push(&cast.this);
936 }
937 Expression::Extract(extract) => {
938 children.push(&extract.this);
939 }
940 Expression::Coalesce(coalesce) => {
941 for e in &coalesce.expressions {
942 children.push(e);
943 }
944 }
945 Expression::NullIf(nullif) => {
946 children.push(&nullif.this);
947 children.push(&nullif.expression);
948 }
949 Expression::Table(_table) => {
950 }
953 Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
954 }
956 Expression::Subquery(_) | Expression::Exists(_) => {}
958 _ => {
959 }
961 }
962
963 children
964 }
965}
966
967impl<'a> Iterator for WalkInScopeIter<'a> {
968 type Item = &'a Expression;
969
970 fn next(&mut self) -> Option<Self::Item> {
971 let expr = if self.bfs {
972 self.queue.pop_front()?
973 } else {
974 self.queue.pop_back()?
975 };
976
977 let children = self.get_children(expr);
979
980 if self.bfs {
981 for child in children {
982 if !self.should_stop_at(child, false) {
983 self.queue.push_back(child);
984 }
985 }
986 } else {
987 for child in children.into_iter().rev() {
988 if !self.should_stop_at(child, false) {
989 self.queue.push_back(child);
990 }
991 }
992 }
993
994 Some(expr)
995 }
996}
997
998pub fn find_in_scope<'a, F>(
1010 expression: &'a Expression,
1011 predicate: F,
1012 bfs: bool,
1013) -> Option<&'a Expression>
1014where
1015 F: Fn(&Expression) -> bool,
1016{
1017 walk_in_scope(expression, bfs).find(|e| predicate(e))
1018}
1019
1020pub fn find_all_in_scope<'a, F>(
1032 expression: &'a Expression,
1033 predicate: F,
1034 bfs: bool,
1035) -> Vec<&'a Expression>
1036where
1037 F: Fn(&Expression) -> bool,
1038{
1039 walk_in_scope(expression, bfs)
1040 .filter(|e| predicate(e))
1041 .collect()
1042}
1043
1044pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1054 match expression {
1055 Expression::Select(_)
1056 | Expression::Union(_)
1057 | Expression::Intersect(_)
1058 | Expression::Except(_) => {
1059 let root = build_scope(expression);
1060 root.traverse().into_iter().cloned().collect()
1061 }
1062 _ => Vec::new(),
1063 }
1064}
1065
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and builds the scope tree of its first statement.
    fn parse_and_build_scope(sql: &str) -> Scope {
        let statements = Parser::parse_sql(sql).expect("Failed to parse SQL");
        build_scope(&statements[0])
    }

    /// A plain SELECT yields an uncorrelated root scope with one source.
    #[test]
    fn test_simple_select_scope() {
        let mut root = parse_and_build_scope("SELECT a, b FROM t");

        assert!(root.is_root());
        assert!(!root.can_be_correlated);
        assert!(root.sources.contains_key("t"));

        assert_eq!(root.columns().len(), 2);
    }

    /// A subquery in FROM becomes a source plus a derived-table child scope.
    #[test]
    fn test_derived_table_scope() {
        let root = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");

        assert!(root.sources.contains_key("x"));
        assert_eq!(root.derived_table_scopes.len(), 1);

        let inner = &root.derived_table_scopes[0];
        assert!(inner.is_derived_table());
        assert!(inner.sources.contains_key("t"));
    }

    /// EXISTS without outer references may be correlated but is not.
    #[test]
    fn test_non_correlated_subquery() {
        let mut root = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);

        assert!(inner.sources.contains_key("s"));
        assert!(!inner.is_correlated_subquery());
    }

    /// EXISTS referencing an outer table is reported as correlated.
    #[test]
    fn test_correlated_subquery() {
        let mut root =
            parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);

        let outer_refs = inner.external_columns();
        assert!(!outer_refs.is_empty());
        assert!(outer_refs.iter().any(|c| c.table.as_deref() == Some("t")));
        assert!(inner.is_correlated_subquery());
    }

    /// A WITH clause registers a CTE source and a CTE child scope.
    #[test]
    fn test_cte_scope() {
        let root = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");

        assert_eq!(root.cte_scopes.len(), 1);
        assert!(root.cte_sources.contains_key("cte"));

        assert!(root.cte_scopes[0].is_cte());
    }

    /// FROM and JOIN tables are both registered as sources.
    #[test]
    fn test_multiple_sources() {
        let root = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        assert!(root.sources.contains_key("t"));
        assert!(root.sources.contains_key("s"));
        assert_eq!(root.sources.len(), 2);
    }

    /// An aliased table is registered under its alias only.
    #[test]
    fn test_aliased_table() {
        let root = parse_and_build_scope("SELECT x.a FROM t AS x");

        assert!(root.sources.contains_key("x"));
        assert!(!root.sources.contains_key("t"));
    }

    /// Columns qualified with local sources are all local.
    #[test]
    fn test_local_columns() {
        let mut root = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let local_cols = root.local_columns();
        assert_eq!(local_cols.len(), 3);
        assert!(local_cols.iter().all(|c| c.table.is_some()));
    }

    /// Only columns without a table qualifier are returned.
    #[test]
    fn test_unqualified_columns() {
        let mut root = parse_and_build_scope("SELECT a, b, t.c FROM t");

        let bare_cols = root.unqualified_columns();
        assert_eq!(bare_cols.len(), 2);
        assert!(bare_cols.iter().all(|c| c.table.is_none()));
    }

    /// Columns are filtered by their table qualifier.
    #[test]
    fn test_source_columns() {
        let mut root = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let from_t = root.source_columns("t");
        assert!(from_t.len() >= 2);
        assert!(from_t.iter().all(|c| c.table.as_deref() == Some("t")));

        let from_s = root.source_columns("s");
        assert!(from_s.len() >= 1);
        assert!(from_s.iter().all(|c| c.table.as_deref() == Some("s")));
    }

    /// Renaming moves the source entry to the new key.
    #[test]
    fn test_rename_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");

        assert!(root.sources.contains_key("t"));
        root.rename_source("t", "new_name".to_string());
        assert!(!root.sources.contains_key("t"));
        assert!(root.sources.contains_key("new_name"));
    }

    /// Removing drops the source entry.
    #[test]
    fn test_remove_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");

        assert!(root.sources.contains_key("t"));
        root.remove_source("t");
        assert!(!root.sources.contains_key("t"));
    }

    /// The walk yields the SELECT itself and the columns below it.
    #[test]
    fn test_walk_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let query = &statements[0];

        let visited: Vec<_> = walk_in_scope(query, true).collect();
        assert!(!visited.is_empty());

        assert!(visited.iter().any(|e| matches!(e, Expression::Select(_))));
        assert!(visited.iter().any(|e| matches!(e, Expression::Column(_))));
    }

    /// The first matching node is returned.
    #[test]
    fn test_find_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let query = &statements[0];

        let first_column = find_in_scope(query, |e| matches!(e, Expression::Column(_)), true);
        assert!(first_column.is_some());
        assert!(matches!(first_column.unwrap(), Expression::Column(_)));
    }

    /// Every matching node is returned.
    #[test]
    fn test_find_all_in_scope() {
        let statements = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
        let query = &statements[0];

        let all_columns = find_all_in_scope(query, |e| matches!(e, Expression::Column(_)), true);
        assert_eq!(all_columns.len(), 3);
    }

    /// Traversing a query always includes its root scope.
    #[test]
    fn test_traverse_scope() {
        let statements =
            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
        let query = &statements[0];

        let all_scopes = traverse_scope(query);
        assert!(!all_scopes.is_empty());
        assert!(all_scopes.iter().any(|s| s.is_root()));
    }

    /// Outer columns are stored and a Subquery branch may be correlated.
    #[test]
    fn test_branch_with_options() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let root = build_scope(&statements[0]);

        let outer_columns = Some(vec!["col1".to_string(), "col2".to_string()]);
        let branched = root.branch_with_options(
            statements[0].clone(),
            ScopeType::Subquery,
            None,
            None,
            outer_columns,
        );

        assert_eq!(branched.outer_columns, vec!["col1", "col2"]);
        assert!(branched.can_be_correlated);
    }

    /// Only scopes branched as Udtf report `is_udtf`.
    #[test]
    fn test_is_udtf() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let plain = Scope::new(statements[0].clone());
        assert!(!plain.is_udtf());

        let root = build_scope(&statements[0]);
        let branched = root.branch(statements[0].clone(), ScopeType::Udtf);
        assert!(branched.is_udtf());
    }

    /// A UNION produces two set-operation child scopes under the root.
    #[test]
    fn test_is_union() {
        let root = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(root.is_root());
        assert_eq!(root.union_scopes.len(), 2);
        assert!(root.union_scopes[0].is_union());
        assert!(root.union_scopes[1].is_union());
    }

    /// Clearing the cache drops both cached column lists.
    #[test]
    fn test_clear_cache() {
        let mut root = parse_and_build_scope("SELECT t.a FROM t");

        let _ = root.columns();
        assert!(root.columns_cache.is_some());

        root.clear_cache();
        assert!(root.columns_cache.is_none());
        assert!(root.external_columns_cache.is_none());
    }

    /// CTE scope, subquery scope and root are all part of the traversal.
    #[test]
    fn test_scope_traverse() {
        let root = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        let visited = root.traverse();
        assert!(visited.len() >= 3);
    }
}