1use crate::expressions::Expression;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11#[cfg(feature = "bindings")]
12use ts_rs::TS;
13
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
16#[cfg_attr(feature = "bindings", derive(TS))]
17#[cfg_attr(feature = "bindings", ts(export))]
18pub enum ScopeType {
19 Root,
21 Subquery,
23 DerivedTable,
25 Cte,
27 SetOperation,
29 Udtf,
31}
32
33#[derive(Debug, Clone)]
35pub struct SourceInfo {
36 pub expression: Expression,
38 pub is_scope: bool,
40}
41
/// A column reference found in a scope, reduced to its names.
///
/// Two references are equal (and hash identically) when both the qualifier
/// and the column name match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
    /// Table qualifier, or `None` for an unqualified column.
    pub table: Option<String>,
    /// Column name.
    pub name: String,
}
50
51#[derive(Debug, Clone)]
56pub struct Scope {
57 pub expression: Expression,
59
60 pub scope_type: ScopeType,
62
63 pub sources: HashMap<String, SourceInfo>,
65
66 pub lateral_sources: HashMap<String, SourceInfo>,
68
69 pub cte_sources: HashMap<String, SourceInfo>,
71
72 pub outer_columns: Vec<String>,
75
76 pub can_be_correlated: bool,
79
80 pub subquery_scopes: Vec<Scope>,
82
83 pub derived_table_scopes: Vec<Scope>,
85
86 pub cte_scopes: Vec<Scope>,
88
89 pub udtf_scopes: Vec<Scope>,
91
92 pub table_scopes: Vec<Scope>,
94
95 pub union_scopes: Vec<Scope>,
97
98 columns_cache: Option<Vec<ColumnRef>>,
100
101 external_columns_cache: Option<Vec<ColumnRef>>,
103}
104
105impl Scope {
106 pub fn new(expression: Expression) -> Self {
108 Self {
109 expression,
110 scope_type: ScopeType::Root,
111 sources: HashMap::new(),
112 lateral_sources: HashMap::new(),
113 cte_sources: HashMap::new(),
114 outer_columns: Vec::new(),
115 can_be_correlated: false,
116 subquery_scopes: Vec::new(),
117 derived_table_scopes: Vec::new(),
118 cte_scopes: Vec::new(),
119 udtf_scopes: Vec::new(),
120 table_scopes: Vec::new(),
121 union_scopes: Vec::new(),
122 columns_cache: None,
123 external_columns_cache: None,
124 }
125 }
126
127 pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
129 self.branch_with_options(expression, scope_type, None, None, None)
130 }
131
132 pub fn branch_with_options(
134 &self,
135 expression: Expression,
136 scope_type: ScopeType,
137 sources: Option<HashMap<String, SourceInfo>>,
138 lateral_sources: Option<HashMap<String, SourceInfo>>,
139 outer_columns: Option<Vec<String>>,
140 ) -> Self {
141 let can_be_correlated = self.can_be_correlated
142 || scope_type == ScopeType::Subquery
143 || scope_type == ScopeType::Udtf;
144
145 Self {
146 expression,
147 scope_type,
148 sources: sources.unwrap_or_default(),
149 lateral_sources: lateral_sources.unwrap_or_default(),
150 cte_sources: self.cte_sources.clone(),
151 outer_columns: outer_columns.unwrap_or_default(),
152 can_be_correlated,
153 subquery_scopes: Vec::new(),
154 derived_table_scopes: Vec::new(),
155 cte_scopes: Vec::new(),
156 udtf_scopes: Vec::new(),
157 table_scopes: Vec::new(),
158 union_scopes: Vec::new(),
159 columns_cache: None,
160 external_columns_cache: None,
161 }
162 }
163
164 pub fn clear_cache(&mut self) {
166 self.columns_cache = None;
167 self.external_columns_cache = None;
168 }
169
170 pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
172 self.sources.insert(
173 name,
174 SourceInfo {
175 expression,
176 is_scope,
177 },
178 );
179 self.clear_cache();
180 }
181
182 pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
184 self.lateral_sources.insert(
185 name.clone(),
186 SourceInfo {
187 expression: expression.clone(),
188 is_scope,
189 },
190 );
191 self.sources.insert(
192 name,
193 SourceInfo {
194 expression,
195 is_scope,
196 },
197 );
198 self.clear_cache();
199 }
200
201 pub fn add_cte_source(&mut self, name: String, expression: Expression) {
203 self.cte_sources.insert(
204 name.clone(),
205 SourceInfo {
206 expression: expression.clone(),
207 is_scope: true,
208 },
209 );
210 self.sources.insert(
211 name,
212 SourceInfo {
213 expression,
214 is_scope: true,
215 },
216 );
217 self.clear_cache();
218 }
219
220 pub fn rename_source(&mut self, old_name: &str, new_name: String) {
222 if let Some(source) = self.sources.remove(old_name) {
223 self.sources.insert(new_name, source);
224 }
225 self.clear_cache();
226 }
227
228 pub fn remove_source(&mut self, name: &str) {
230 self.sources.remove(name);
231 self.clear_cache();
232 }
233
234 pub fn columns(&mut self) -> &[ColumnRef] {
236 if self.columns_cache.is_none() {
237 let mut columns = Vec::new();
238 collect_columns(&self.expression, &mut columns);
239 self.columns_cache = Some(columns);
240 }
241 self.columns_cache.as_ref().unwrap()
242 }
243
244 pub fn source_names(&self) -> HashSet<String> {
246 let mut names: HashSet<String> = self.sources.keys().cloned().collect();
247 names.extend(self.cte_sources.keys().cloned());
248 names
249 }
250
251 pub fn external_columns(&mut self) -> Vec<ColumnRef> {
253 if self.external_columns_cache.is_some() {
254 return self.external_columns_cache.clone().unwrap();
255 }
256
257 let source_names = self.source_names();
258 let columns = self.columns().to_vec();
259
260 let external: Vec<ColumnRef> = columns
261 .into_iter()
262 .filter(|col| {
263 match &col.table {
265 Some(table) => !source_names.contains(table),
266 None => false, }
268 })
269 .collect();
270
271 self.external_columns_cache = Some(external.clone());
272 external
273 }
274
275 pub fn local_columns(&mut self) -> Vec<ColumnRef> {
277 let external_set: HashSet<_> = self.external_columns().into_iter().collect();
278 let columns = self.columns().to_vec();
279
280 columns
281 .into_iter()
282 .filter(|col| !external_set.contains(col))
283 .collect()
284 }
285
286 pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
288 self.columns()
289 .iter()
290 .filter(|c| c.table.is_none())
291 .cloned()
292 .collect()
293 }
294
295 pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
297 self.columns()
298 .iter()
299 .filter(|col| col.table.as_deref() == Some(source_name))
300 .cloned()
301 .collect()
302 }
303
304 pub fn is_correlated_subquery(&mut self) -> bool {
310 self.can_be_correlated && !self.external_columns().is_empty()
311 }
312
313 pub fn is_subquery(&self) -> bool {
315 self.scope_type == ScopeType::Subquery
316 }
317
318 pub fn is_derived_table(&self) -> bool {
320 self.scope_type == ScopeType::DerivedTable
321 }
322
323 pub fn is_cte(&self) -> bool {
325 self.scope_type == ScopeType::Cte
326 }
327
328 pub fn is_root(&self) -> bool {
330 self.scope_type == ScopeType::Root
331 }
332
333 pub fn is_udtf(&self) -> bool {
335 self.scope_type == ScopeType::Udtf
336 }
337
338 pub fn is_union(&self) -> bool {
340 self.scope_type == ScopeType::SetOperation
341 }
342
343 pub fn traverse(&self) -> Vec<&Scope> {
345 let mut result = Vec::new();
346 self.traverse_impl(&mut result);
347 result
348 }
349
350 fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
351 for scope in &self.cte_scopes {
353 scope.traverse_impl(result);
354 }
355 for scope in &self.union_scopes {
356 scope.traverse_impl(result);
357 }
358 for scope in &self.table_scopes {
359 scope.traverse_impl(result);
360 }
361 for scope in &self.subquery_scopes {
362 scope.traverse_impl(result);
363 }
364 result.push(self);
366 }
367
368 pub fn ref_count(&self) -> HashMap<usize, usize> {
370 let mut counts: HashMap<usize, usize> = HashMap::new();
371
372 for scope in self.traverse() {
373 for (_, source_info) in scope.sources.iter() {
374 if source_info.is_scope {
375 let id = &source_info.expression as *const _ as usize;
376 *counts.entry(id).or_insert(0) += 1;
377 }
378 }
379 }
380
381 counts
382 }
383}
384
385fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
387 match expr {
388 Expression::Column(col) => {
389 columns.push(ColumnRef {
390 table: col.table.as_ref().map(|t| t.name.clone()),
391 name: col.name.name.clone(),
392 });
393 }
394 Expression::Select(select) => {
395 for e in &select.expressions {
397 collect_columns(e, columns);
398 }
399 for join in &select.joins {
401 if let Some(on) = &join.on {
402 collect_columns(on, columns);
403 }
404 if let Some(match_condition) = &join.match_condition {
405 collect_columns(match_condition, columns);
406 }
407 }
408 if let Some(where_clause) = &select.where_clause {
410 collect_columns(&where_clause.this, columns);
411 }
412 if let Some(having) = &select.having {
414 collect_columns(&having.this, columns);
415 }
416 if let Some(order_by) = &select.order_by {
418 for ord in &order_by.expressions {
419 collect_columns(&ord.this, columns);
420 }
421 }
422 if let Some(group_by) = &select.group_by {
424 for e in &group_by.expressions {
425 collect_columns(e, columns);
426 }
427 }
428 }
431 Expression::And(bin)
433 | Expression::Or(bin)
434 | Expression::Add(bin)
435 | Expression::Sub(bin)
436 | Expression::Mul(bin)
437 | Expression::Div(bin)
438 | Expression::Mod(bin)
439 | Expression::Eq(bin)
440 | Expression::Neq(bin)
441 | Expression::Lt(bin)
442 | Expression::Lte(bin)
443 | Expression::Gt(bin)
444 | Expression::Gte(bin)
445 | Expression::BitwiseAnd(bin)
446 | Expression::BitwiseOr(bin)
447 | Expression::BitwiseXor(bin)
448 | Expression::Concat(bin) => {
449 collect_columns(&bin.left, columns);
450 collect_columns(&bin.right, columns);
451 }
452 Expression::Like(like) | Expression::ILike(like) => {
454 collect_columns(&like.left, columns);
455 collect_columns(&like.right, columns);
456 if let Some(escape) = &like.escape {
457 collect_columns(escape, columns);
458 }
459 }
460 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
462 collect_columns(&un.this, columns);
463 }
464 Expression::Function(func) => {
465 for arg in &func.args {
466 collect_columns(arg, columns);
467 }
468 }
469 Expression::AggregateFunction(agg) => {
470 for arg in &agg.args {
471 collect_columns(arg, columns);
472 }
473 }
474 Expression::WindowFunction(wf) => {
475 collect_columns(&wf.this, columns);
476 for e in &wf.over.partition_by {
477 collect_columns(e, columns);
478 }
479 for e in &wf.over.order_by {
480 collect_columns(&e.this, columns);
481 }
482 }
483 Expression::Alias(alias) => {
484 collect_columns(&alias.this, columns);
485 }
486 Expression::Case(case) => {
487 if let Some(operand) = &case.operand {
488 collect_columns(operand, columns);
489 }
490 for (when_expr, then_expr) in &case.whens {
491 collect_columns(when_expr, columns);
492 collect_columns(then_expr, columns);
493 }
494 if let Some(else_clause) = &case.else_ {
495 collect_columns(else_clause, columns);
496 }
497 }
498 Expression::Paren(paren) => {
499 collect_columns(&paren.this, columns);
500 }
501 Expression::Ordered(ord) => {
502 collect_columns(&ord.this, columns);
503 }
504 Expression::In(in_expr) => {
505 collect_columns(&in_expr.this, columns);
506 for e in &in_expr.expressions {
507 collect_columns(e, columns);
508 }
509 }
511 Expression::Between(between) => {
512 collect_columns(&between.this, columns);
513 collect_columns(&between.low, columns);
514 collect_columns(&between.high, columns);
515 }
516 Expression::IsNull(is_null) => {
517 collect_columns(&is_null.this, columns);
518 }
519 Expression::Cast(cast) => {
520 collect_columns(&cast.this, columns);
521 }
522 Expression::Extract(extract) => {
523 collect_columns(&extract.this, columns);
524 }
525 Expression::Exists(_) | Expression::Subquery(_) => {
526 }
528 _ => {
529 }
531 }
532}
533
534pub fn build_scope(expression: &Expression) -> Scope {
539 let mut root = Scope::new(expression.clone());
540 build_scope_impl(expression, &mut root);
541 root
542}
543
544fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
545 match expression {
546 Expression::Select(select) => {
547 if let Some(with) = &select.with {
549 for cte in &with.ctes {
550 let cte_name = cte.alias.name.clone();
551 let mut cte_scope = current_scope
552 .branch(Expression::Cte(Box::new(cte.clone())), ScopeType::Cte);
553 build_scope_impl(&cte.this, &mut cte_scope);
554 current_scope.add_cte_source(cte_name, Expression::Cte(Box::new(cte.clone())));
555 current_scope.cte_scopes.push(cte_scope);
556 }
557 }
558
559 if let Some(from) = &select.from {
561 for table in &from.expressions {
562 add_table_to_scope(table, current_scope);
563 }
564 }
565
566 for join in &select.joins {
568 add_table_to_scope(&join.this, current_scope);
569 }
570
571 collect_subqueries(expression, current_scope);
573 }
574 Expression::Union(union) => {
575 let mut left_scope = current_scope.branch(union.left.clone(), ScopeType::SetOperation);
576 build_scope_impl(&union.left, &mut left_scope);
577
578 let mut right_scope =
579 current_scope.branch(union.right.clone(), ScopeType::SetOperation);
580 build_scope_impl(&union.right, &mut right_scope);
581
582 current_scope.union_scopes.push(left_scope);
583 current_scope.union_scopes.push(right_scope);
584 }
585 Expression::Intersect(intersect) => {
586 let mut left_scope =
587 current_scope.branch(intersect.left.clone(), ScopeType::SetOperation);
588 build_scope_impl(&intersect.left, &mut left_scope);
589
590 let mut right_scope =
591 current_scope.branch(intersect.right.clone(), ScopeType::SetOperation);
592 build_scope_impl(&intersect.right, &mut right_scope);
593
594 current_scope.union_scopes.push(left_scope);
595 current_scope.union_scopes.push(right_scope);
596 }
597 Expression::Except(except) => {
598 let mut left_scope = current_scope.branch(except.left.clone(), ScopeType::SetOperation);
599 build_scope_impl(&except.left, &mut left_scope);
600
601 let mut right_scope =
602 current_scope.branch(except.right.clone(), ScopeType::SetOperation);
603 build_scope_impl(&except.right, &mut right_scope);
604
605 current_scope.union_scopes.push(left_scope);
606 current_scope.union_scopes.push(right_scope);
607 }
608 _ => {}
609 }
610}
611
612fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
613 match expr {
614 Expression::Table(table) => {
615 let name = table
616 .alias
617 .as_ref()
618 .map(|a| a.name.clone())
619 .unwrap_or_else(|| table.name.name.clone());
620 let cte_source = if table.schema.is_none() && table.catalog.is_none() {
621 scope
622 .cte_sources
623 .get(&table.name.name)
624 .or_else(|| {
625 scope
626 .cte_sources
627 .iter()
628 .find(|(cte_name, _)| cte_name.eq_ignore_ascii_case(&table.name.name))
629 .map(|(_, source)| source)
630 })
631 } else {
632 None
633 };
634
635 if let Some(source) = cte_source {
636 scope.add_source(name, source.expression.clone(), true);
637 } else {
638 scope.add_source(name, expr.clone(), false);
639 }
640 }
641 Expression::Subquery(subquery) => {
642 let name = subquery
643 .alias
644 .as_ref()
645 .map(|a| a.name.clone())
646 .unwrap_or_default();
647
648 let mut derived_scope = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
649 build_scope_impl(&subquery.this, &mut derived_scope);
650
651 scope.add_source(name.clone(), expr.clone(), true);
652 scope.derived_table_scopes.push(derived_scope);
653 }
654 Expression::Paren(paren) => {
655 add_table_to_scope(&paren.this, scope);
656 }
657 _ => {}
658 }
659}
660
661fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
662 match expr {
663 Expression::Select(select) => {
664 if let Some(where_clause) = &select.where_clause {
666 collect_subqueries_in_expr(&where_clause.this, parent_scope);
667 }
668 for e in &select.expressions {
670 collect_subqueries_in_expr(e, parent_scope);
671 }
672 if let Some(having) = &select.having {
674 collect_subqueries_in_expr(&having.this, parent_scope);
675 }
676 }
677 _ => {}
678 }
679}
680
681fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
682 match expr {
683 Expression::Subquery(subquery) if subquery.alias.is_none() => {
684 let mut sub_scope = parent_scope.branch(subquery.this.clone(), ScopeType::Subquery);
686 build_scope_impl(&subquery.this, &mut sub_scope);
687 parent_scope.subquery_scopes.push(sub_scope);
688 }
689 Expression::In(in_expr) => {
690 collect_subqueries_in_expr(&in_expr.this, parent_scope);
691 if let Some(query) = &in_expr.query {
692 let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
693 build_scope_impl(query, &mut sub_scope);
694 parent_scope.subquery_scopes.push(sub_scope);
695 }
696 }
697 Expression::Exists(exists) => {
698 let mut sub_scope = parent_scope.branch(exists.this.clone(), ScopeType::Subquery);
699 build_scope_impl(&exists.this, &mut sub_scope);
700 parent_scope.subquery_scopes.push(sub_scope);
701 }
702 Expression::And(bin)
704 | Expression::Or(bin)
705 | Expression::Add(bin)
706 | Expression::Sub(bin)
707 | Expression::Mul(bin)
708 | Expression::Div(bin)
709 | Expression::Mod(bin)
710 | Expression::Eq(bin)
711 | Expression::Neq(bin)
712 | Expression::Lt(bin)
713 | Expression::Lte(bin)
714 | Expression::Gt(bin)
715 | Expression::Gte(bin)
716 | Expression::BitwiseAnd(bin)
717 | Expression::BitwiseOr(bin)
718 | Expression::BitwiseXor(bin)
719 | Expression::Concat(bin) => {
720 collect_subqueries_in_expr(&bin.left, parent_scope);
721 collect_subqueries_in_expr(&bin.right, parent_scope);
722 }
723 Expression::Like(like) | Expression::ILike(like) => {
725 collect_subqueries_in_expr(&like.left, parent_scope);
726 collect_subqueries_in_expr(&like.right, parent_scope);
727 if let Some(escape) = &like.escape {
728 collect_subqueries_in_expr(escape, parent_scope);
729 }
730 }
731 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
733 collect_subqueries_in_expr(&un.this, parent_scope);
734 }
735 Expression::Function(func) => {
736 for arg in &func.args {
737 collect_subqueries_in_expr(arg, parent_scope);
738 }
739 }
740 Expression::Case(case) => {
741 if let Some(operand) = &case.operand {
742 collect_subqueries_in_expr(operand, parent_scope);
743 }
744 for (when_expr, then_expr) in &case.whens {
745 collect_subqueries_in_expr(when_expr, parent_scope);
746 collect_subqueries_in_expr(then_expr, parent_scope);
747 }
748 if let Some(else_clause) = &case.else_ {
749 collect_subqueries_in_expr(else_clause, parent_scope);
750 }
751 }
752 Expression::Paren(paren) => {
753 collect_subqueries_in_expr(&paren.this, parent_scope);
754 }
755 Expression::Alias(alias) => {
756 collect_subqueries_in_expr(&alias.this, parent_scope);
757 }
758 _ => {}
759 }
760}
761
762pub fn walk_in_scope<'a>(
774 expression: &'a Expression,
775 bfs: bool,
776) -> impl Iterator<Item = &'a Expression> {
777 WalkInScopeIter::new(expression, bfs)
778}
779
780struct WalkInScopeIter<'a> {
782 queue: VecDeque<&'a Expression>,
783 bfs: bool,
784}
785
786impl<'a> WalkInScopeIter<'a> {
787 fn new(expression: &'a Expression, bfs: bool) -> Self {
788 let mut queue = VecDeque::new();
789 queue.push_back(expression);
790 Self { queue, bfs }
791 }
792
793 fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
794 if is_root {
795 return false;
796 }
797
798 if matches!(expr, Expression::Cte(_)) {
800 return true;
801 }
802
803 if let Expression::Subquery(subquery) = expr {
805 if subquery.alias.is_some() {
806 return true;
807 }
808 }
809
810 if matches!(
812 expr,
813 Expression::Select(_)
814 | Expression::Union(_)
815 | Expression::Intersect(_)
816 | Expression::Except(_)
817 ) {
818 return true;
819 }
820
821 false
822 }
823
824 fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
825 let mut children = Vec::new();
826
827 match expr {
828 Expression::Select(select) => {
829 for e in &select.expressions {
831 children.push(e);
832 }
833 if let Some(from) = &select.from {
835 for table in &from.expressions {
836 if !self.should_stop_at(table, false) {
837 children.push(table);
838 }
839 }
840 }
841 for join in &select.joins {
843 if let Some(on) = &join.on {
844 children.push(on);
845 }
846 }
848 if let Some(where_clause) = &select.where_clause {
850 children.push(&where_clause.this);
851 }
852 if let Some(group_by) = &select.group_by {
854 for e in &group_by.expressions {
855 children.push(e);
856 }
857 }
858 if let Some(having) = &select.having {
860 children.push(&having.this);
861 }
862 if let Some(order_by) = &select.order_by {
864 for ord in &order_by.expressions {
865 children.push(&ord.this);
866 }
867 }
868 if let Some(limit) = &select.limit {
870 children.push(&limit.this);
871 }
872 if let Some(offset) = &select.offset {
874 children.push(&offset.this);
875 }
876 }
877 Expression::And(bin)
878 | Expression::Or(bin)
879 | Expression::Add(bin)
880 | Expression::Sub(bin)
881 | Expression::Mul(bin)
882 | Expression::Div(bin)
883 | Expression::Mod(bin)
884 | Expression::Eq(bin)
885 | Expression::Neq(bin)
886 | Expression::Lt(bin)
887 | Expression::Lte(bin)
888 | Expression::Gt(bin)
889 | Expression::Gte(bin)
890 | Expression::BitwiseAnd(bin)
891 | Expression::BitwiseOr(bin)
892 | Expression::BitwiseXor(bin)
893 | Expression::Concat(bin) => {
894 children.push(&bin.left);
895 children.push(&bin.right);
896 }
897 Expression::Like(like) | Expression::ILike(like) => {
898 children.push(&like.left);
899 children.push(&like.right);
900 if let Some(escape) = &like.escape {
901 children.push(escape);
902 }
903 }
904 Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
905 children.push(&un.this);
906 }
907 Expression::Function(func) => {
908 for arg in &func.args {
909 children.push(arg);
910 }
911 }
912 Expression::AggregateFunction(agg) => {
913 for arg in &agg.args {
914 children.push(arg);
915 }
916 }
917 Expression::WindowFunction(wf) => {
918 children.push(&wf.this);
919 for e in &wf.over.partition_by {
920 children.push(e);
921 }
922 for e in &wf.over.order_by {
923 children.push(&e.this);
924 }
925 }
926 Expression::Alias(alias) => {
927 children.push(&alias.this);
928 }
929 Expression::Case(case) => {
930 if let Some(operand) = &case.operand {
931 children.push(operand);
932 }
933 for (when_expr, then_expr) in &case.whens {
934 children.push(when_expr);
935 children.push(then_expr);
936 }
937 if let Some(else_clause) = &case.else_ {
938 children.push(else_clause);
939 }
940 }
941 Expression::Paren(paren) => {
942 children.push(&paren.this);
943 }
944 Expression::Ordered(ord) => {
945 children.push(&ord.this);
946 }
947 Expression::In(in_expr) => {
948 children.push(&in_expr.this);
949 for e in &in_expr.expressions {
950 children.push(e);
951 }
952 }
954 Expression::Between(between) => {
955 children.push(&between.this);
956 children.push(&between.low);
957 children.push(&between.high);
958 }
959 Expression::IsNull(is_null) => {
960 children.push(&is_null.this);
961 }
962 Expression::Cast(cast) => {
963 children.push(&cast.this);
964 }
965 Expression::Extract(extract) => {
966 children.push(&extract.this);
967 }
968 Expression::Coalesce(coalesce) => {
969 for e in &coalesce.expressions {
970 children.push(e);
971 }
972 }
973 Expression::NullIf(nullif) => {
974 children.push(&nullif.this);
975 children.push(&nullif.expression);
976 }
977 Expression::Table(_table) => {
978 }
981 Expression::Column(_) | Expression::Literal(_) | Expression::Identifier(_) => {
982 }
984 Expression::Subquery(_) | Expression::Exists(_) => {}
986 _ => {
987 }
989 }
990
991 children
992 }
993}
994
995impl<'a> Iterator for WalkInScopeIter<'a> {
996 type Item = &'a Expression;
997
998 fn next(&mut self) -> Option<Self::Item> {
999 let expr = if self.bfs {
1000 self.queue.pop_front()?
1001 } else {
1002 self.queue.pop_back()?
1003 };
1004
1005 let children = self.get_children(expr);
1007
1008 if self.bfs {
1009 for child in children {
1010 if !self.should_stop_at(child, false) {
1011 self.queue.push_back(child);
1012 }
1013 }
1014 } else {
1015 for child in children.into_iter().rev() {
1016 if !self.should_stop_at(child, false) {
1017 self.queue.push_back(child);
1018 }
1019 }
1020 }
1021
1022 Some(expr)
1023 }
1024}
1025
1026pub fn find_in_scope<'a, F>(
1038 expression: &'a Expression,
1039 predicate: F,
1040 bfs: bool,
1041) -> Option<&'a Expression>
1042where
1043 F: Fn(&Expression) -> bool,
1044{
1045 walk_in_scope(expression, bfs).find(|e| predicate(e))
1046}
1047
1048pub fn find_all_in_scope<'a, F>(
1060 expression: &'a Expression,
1061 predicate: F,
1062 bfs: bool,
1063) -> Vec<&'a Expression>
1064where
1065 F: Fn(&Expression) -> bool,
1066{
1067 walk_in_scope(expression, bfs)
1068 .filter(|e| predicate(e))
1069 .collect()
1070}
1071
1072pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
1082 match expression {
1083 Expression::Select(_)
1084 | Expression::Union(_)
1085 | Expression::Intersect(_)
1086 | Expression::Except(_) => {
1087 let root = build_scope(expression);
1088 root.traverse().into_iter().cloned().collect()
1089 }
1090 _ => Vec::new(),
1091 }
1092}
1093
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and builds the scope tree of its first statement.
    fn parse_and_build_scope(sql: &str) -> Scope {
        let statements = Parser::parse_sql(sql).expect("Failed to parse SQL");
        build_scope(&statements[0])
    }

    /// A plain SELECT yields a non-correlatable root with one source.
    #[test]
    fn test_simple_select_scope() {
        let mut root = parse_and_build_scope("SELECT a, b FROM t");

        assert!(root.is_root());
        assert!(!root.can_be_correlated);
        assert!(root.sources.contains_key("t"));

        assert_eq!(root.columns().len(), 2);
    }

    /// A subquery in FROM becomes a derived-table child scope.
    #[test]
    fn test_derived_table_scope() {
        let mut root = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");

        assert!(root.sources.contains_key("x"));
        assert_eq!(root.derived_table_scopes.len(), 1);

        let inner = &mut root.derived_table_scopes[0];
        assert!(inner.is_derived_table());
        assert!(inner.sources.contains_key("t"));
    }

    /// EXISTS without outer references is correlatable but not correlated.
    #[test]
    fn test_non_correlated_subquery() {
        let mut root = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);

        assert!(inner.sources.contains_key("s"));
        assert!(!inner.is_correlated_subquery());
    }

    /// EXISTS referencing the outer table is reported as correlated.
    #[test]
    fn test_correlated_subquery() {
        let mut root =
            parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);

        let outer_refs = inner.external_columns();
        assert!(!outer_refs.is_empty());
        assert!(outer_refs.iter().any(|c| c.table.as_deref() == Some("t")));
        assert!(inner.is_correlated_subquery());
    }

    /// A CTE is registered both as a child scope and as a CTE source.
    #[test]
    fn test_cte_scope() {
        let root = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");

        assert_eq!(root.cte_scopes.len(), 1);
        assert!(root.cte_sources.contains_key("cte"));

        let cte_scope = &root.cte_scopes[0];
        assert!(cte_scope.is_cte());
    }

    /// FROM and JOIN tables are both registered as sources.
    #[test]
    fn test_multiple_sources() {
        let root = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        assert!(root.sources.contains_key("t"));
        assert!(root.sources.contains_key("s"));
        assert_eq!(root.sources.len(), 2);
    }

    /// An aliased table is registered under the alias only.
    #[test]
    fn test_aliased_table() {
        let root = parse_and_build_scope("SELECT x.a FROM t AS x");

        assert!(root.sources.contains_key("x"));
        assert!(!root.sources.contains_key("t"));
    }

    /// Projection columns and JOIN ON columns all count as local.
    #[test]
    fn test_local_columns() {
        let mut root = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let local = root.local_columns();
        assert_eq!(local.len(), 5);
        assert!(local.iter().all(|c| c.table.is_some()));
    }

    /// Columns referenced only in a JOIN ON clause are collected.
    #[test]
    fn test_columns_include_join_on_clause_references() {
        let mut root = parse_and_build_scope(
            "SELECT o.total FROM orders o JOIN customers c ON c.id = o.customer_id",
        );

        let rendered: Vec<String> = root
            .columns()
            .iter()
            .map(|c| match &c.table {
                Some(t) => format!("{}.{}", t, c.name),
                None => c.name.clone(),
            })
            .collect();

        assert!(rendered.contains(&"o.total".to_string()));
        assert!(rendered.contains(&"c.id".to_string()));
        assert!(rendered.contains(&"o.customer_id".to_string()));
    }

    /// Only columns without a qualifier are reported as unqualified.
    #[test]
    fn test_unqualified_columns() {
        let mut root = parse_and_build_scope("SELECT a, b, t.c FROM t");

        let bare = root.unqualified_columns();
        assert_eq!(bare.len(), 2);
        assert!(bare.iter().all(|c| c.table.is_none()));
    }

    /// `source_columns` filters by qualifier.
    #[test]
    fn test_source_columns() {
        let mut root = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let t_cols = root.source_columns("t");
        assert!(t_cols.len() >= 2);
        assert!(t_cols.iter().all(|c| c.table.as_deref() == Some("t")));

        let s_cols = root.source_columns("s");
        assert!(s_cols.len() >= 1);
        assert!(s_cols.iter().all(|c| c.table.as_deref() == Some("s")));
    }

    /// Renaming moves the source to the new key.
    #[test]
    fn test_rename_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");

        assert!(root.sources.contains_key("t"));
        root.rename_source("t", "new_name".to_string());
        assert!(!root.sources.contains_key("t"));
        assert!(root.sources.contains_key("new_name"));
    }

    /// Removing deletes the source.
    #[test]
    fn test_remove_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");

        assert!(root.sources.contains_key("t"));
        root.remove_source("t");
        assert!(!root.sources.contains_key("t"));
    }

    /// The walk yields the SELECT itself and its columns.
    #[test]
    fn test_walk_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let query = &statements[0];

        let visited: Vec<_> = walk_in_scope(query, true).collect();
        assert!(!visited.is_empty());

        assert!(visited.iter().any(|e| matches!(e, Expression::Select(_))));
        assert!(visited.iter().any(|e| matches!(e, Expression::Column(_))));
    }

    /// `find_in_scope` returns a node matching the predicate.
    #[test]
    fn test_find_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");
        let query = &statements[0];

        let hit = find_in_scope(query, |e| matches!(e, Expression::Column(_)), true);
        assert!(hit.is_some());
        assert!(matches!(hit.unwrap(), Expression::Column(_)));
    }

    /// `find_all_in_scope` returns every matching node.
    #[test]
    fn test_find_all_in_scope() {
        let statements = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");
        let query = &statements[0];

        let hits = find_all_in_scope(query, |e| matches!(e, Expression::Column(_)), true);
        assert_eq!(hits.len(), 3);
    }

    /// `traverse_scope` always includes the root scope.
    #[test]
    fn test_traverse_scope() {
        let statements =
            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");
        let query = &statements[0];

        let scopes = traverse_scope(query);
        assert!(!scopes.is_empty());
        assert!(scopes.iter().any(|s| s.is_root()));
    }

    /// Outer columns are passed through and subqueries are correlatable.
    #[test]
    fn test_branch_with_options() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let root = build_scope(&statements[0]);

        let child = root.branch_with_options(
            statements[0].clone(),
            ScopeType::Subquery,
            None,
            None,
            Some(vec!["col1".to_string(), "col2".to_string()]),
        );

        assert_eq!(child.outer_columns, vec!["col1", "col2"]);
        assert!(child.can_be_correlated);
    }

    /// Only scopes branched as UDTF report `is_udtf`.
    #[test]
    fn test_is_udtf() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let plain = Scope::new(statements[0].clone());
        assert!(!plain.is_udtf());

        let root = build_scope(&statements[0]);
        let udtf_scope = root.branch(statements[0].clone(), ScopeType::Udtf);
        assert!(udtf_scope.is_udtf());
    }

    /// Both sides of a UNION become set-operation scopes under the root.
    #[test]
    fn test_is_union() {
        let root = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(root.is_root());
        assert_eq!(root.union_scopes.len(), 2);
        assert!(root.union_scopes[0].is_union());
        assert!(root.union_scopes[1].is_union());
    }

    /// `clear_cache` resets both memoized column lists.
    #[test]
    fn test_clear_cache() {
        let mut root = parse_and_build_scope("SELECT t.a FROM t");

        let _ = root.columns();
        assert!(root.columns_cache.is_some());

        root.clear_cache();
        assert!(root.columns_cache.is_none());
        assert!(root.external_columns_cache.is_none());
    }

    /// Traversal reaches the CTE scope, the subquery scope and the root.
    #[test]
    fn test_scope_traverse() {
        let root = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        let visited = root.traverse();
        assert!(visited.len() >= 3);
    }
}
1378}