use crate::expressions::Expression;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
#[cfg(feature = "bindings")]
use ts_rs::TS;
/// Kind of query construct a [`Scope`] was created for.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "bindings", derive(TS))]
#[cfg_attr(feature = "bindings", ts(export))]
pub enum ScopeType {
/// Outermost scope; this is the type assigned by `Scope::new`.
Root,
/// Nested query used inside an expression: an unaliased subquery, the query of an `IN`,
/// or the body of an `EXISTS`. Branching into this type marks the scope as possibly correlated.
Subquery,
/// Subquery that appears as a table source (FROM / JOIN).
DerivedTable,
/// Body of a common table expression.
Cte,
/// One side (left or right) of a `UNION`, `INTERSECT` or `EXCEPT`.
SetOperation,
/// Not produced by the builder in this file; branching into this type marks the scope as
/// possibly correlated. NOTE(review): presumably "user-defined table function" - confirm.
Udtf,
}
/// A source registered in a [`Scope`] under a name or alias.
#[derive(Debug, Clone)]
pub struct SourceInfo {
/// Expression that defines the source. The builder in this file stores the table expression
/// for plain tables, the `Cte` wrapper for CTE references, and the subquery expression for
/// derived tables.
pub expression: Expression,
/// `true` when the source is itself a query (CTE or derived table); `false` for plain tables.
pub is_scope: bool,
}
/// A column reference found in a scope's expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ColumnRef {
/// Table qualifier as written in the column reference, or `None` when unqualified.
pub table: Option<String>,
/// Column name.
pub name: String,
}
/// A node in the scope tree: one query-like expression together with the sources visible
/// to it and the child scopes nested inside it.
#[derive(Debug, Clone)]
pub struct Scope {
/// Expression this scope was created for.
pub expression: Expression,
/// Kind of construct this scope represents.
pub scope_type: ScopeType,
/// Sources visible in this scope, keyed by alias (or table name when there is no alias).
pub sources: HashMap<String, SourceInfo>,
/// Sources registered through `add_lateral_source`; each is also present in `sources`.
pub lateral_sources: HashMap<String, SourceInfo>,
/// CTEs registered through `add_cte_source`; copied into child scopes by `branch`.
pub cte_sources: HashMap<String, SourceInfo>,
/// Set only through `branch_with_options`; not otherwise read in this file.
pub outer_columns: Vec<String>,
/// `true` when this scope or any ancestor was branched as `Subquery` or `Udtf`.
pub can_be_correlated: bool,
/// Child scopes for subqueries found inside expressions (scalar subqueries, `IN`, `EXISTS`).
pub subquery_scopes: Vec<Scope>,
/// Child scopes for subqueries used as table sources.
pub derived_table_scopes: Vec<Scope>,
/// Child scopes for CTE bodies, in definition order.
pub cte_scopes: Vec<Scope>,
/// Child scopes of type `Udtf`; not populated by the builder in this file.
pub udtf_scopes: Vec<Scope>,
/// Visited by `traverse`; not populated by the builder in this file.
/// NOTE(review): presumably the combined derived-table + UDTF list - confirm with other callers.
pub table_scopes: Vec<Scope>,
/// Child scopes for the left and right side of a set operation (pushed left first).
pub union_scopes: Vec<Scope>,
// Lazily computed result of `columns()`; reset by `clear_cache`.
columns_cache: Option<Vec<ColumnRef>>,
// Lazily computed result of `external_columns()`; reset by `clear_cache`.
external_columns_cache: Option<Vec<ColumnRef>>,
}
impl Scope {
    /// Creates a root scope for `expression` with no sources, no child scopes and empty caches.
    pub fn new(expression: Expression) -> Self {
        Self {
            expression,
            scope_type: ScopeType::Root,
            sources: HashMap::new(),
            lateral_sources: HashMap::new(),
            cte_sources: HashMap::new(),
            outer_columns: Vec::new(),
            can_be_correlated: false,
            subquery_scopes: Vec::new(),
            derived_table_scopes: Vec::new(),
            cte_scopes: Vec::new(),
            udtf_scopes: Vec::new(),
            table_scopes: Vec::new(),
            union_scopes: Vec::new(),
            columns_cache: None,
            external_columns_cache: None,
        }
    }

    /// Creates a child scope of the given type with empty sources and outer columns.
    ///
    /// Shorthand for [`Scope::branch_with_options`] with every option set to `None`.
    pub fn branch(&self, expression: Expression, scope_type: ScopeType) -> Self {
        self.branch_with_options(expression, scope_type, None, None, None)
    }

    /// Creates a child scope for `expression`.
    ///
    /// The child inherits a copy of this scope's CTE sources. It can be correlated when this
    /// scope can be, or when `scope_type` is `Subquery` or `Udtf`. `sources`,
    /// `lateral_sources` and `outer_columns` default to empty when `None`. The child is *not*
    /// attached to `self`; the caller decides which child list it belongs to.
    pub fn branch_with_options(
        &self,
        expression: Expression,
        scope_type: ScopeType,
        sources: Option<HashMap<String, SourceInfo>>,
        lateral_sources: Option<HashMap<String, SourceInfo>>,
        outer_columns: Option<Vec<String>>,
    ) -> Self {
        let can_be_correlated = self.can_be_correlated
            || matches!(scope_type, ScopeType::Subquery | ScopeType::Udtf);
        Self {
            expression,
            scope_type,
            sources: sources.unwrap_or_default(),
            lateral_sources: lateral_sources.unwrap_or_default(),
            cte_sources: self.cte_sources.clone(),
            outer_columns: outer_columns.unwrap_or_default(),
            can_be_correlated,
            subquery_scopes: Vec::new(),
            derived_table_scopes: Vec::new(),
            cte_scopes: Vec::new(),
            udtf_scopes: Vec::new(),
            table_scopes: Vec::new(),
            union_scopes: Vec::new(),
            columns_cache: None,
            external_columns_cache: None,
        }
    }

    /// Drops the cached column lists so they are recomputed on next access.
    pub fn clear_cache(&mut self) {
        self.columns_cache = None;
        self.external_columns_cache = None;
    }

    /// Registers (or replaces) a source under `name` and invalidates the caches.
    pub fn add_source(&mut self, name: String, expression: Expression, is_scope: bool) {
        self.sources.insert(
            name,
            SourceInfo {
                expression,
                is_scope,
            },
        );
        self.clear_cache();
    }

    /// Registers a source in both `lateral_sources` and `sources`, then invalidates the caches.
    pub fn add_lateral_source(&mut self, name: String, expression: Expression, is_scope: bool) {
        self.lateral_sources.insert(
            name.clone(),
            SourceInfo {
                expression: expression.clone(),
                is_scope,
            },
        );
        self.sources.insert(
            name,
            SourceInfo {
                expression,
                is_scope,
            },
        );
        self.clear_cache();
    }

    /// Registers a CTE in both `cte_sources` and `sources` (always with `is_scope = true`),
    /// then invalidates the caches.
    pub fn add_cte_source(&mut self, name: String, expression: Expression) {
        self.cte_sources.insert(
            name.clone(),
            SourceInfo {
                expression: expression.clone(),
                is_scope: true,
            },
        );
        self.sources.insert(
            name,
            SourceInfo {
                expression,
                is_scope: true,
            },
        );
        self.clear_cache();
    }

    /// Moves the entry `old_name` in `sources` to `new_name`.
    ///
    /// Does nothing to `sources` when `old_name` is absent. Only `sources` is touched:
    /// `lateral_sources` and `cte_sources` keep their existing keys. The caches are cleared
    /// in either case.
    pub fn rename_source(&mut self, old_name: &str, new_name: String) {
        if let Some(source) = self.sources.remove(old_name) {
            self.sources.insert(new_name, source);
        }
        self.clear_cache();
    }

    /// Removes `name` from `sources` (if present) and invalidates the caches.
    pub fn remove_source(&mut self, name: &str) {
        self.sources.remove(name);
        self.clear_cache();
    }

    /// Returns every column reference found in this scope's expression.
    ///
    /// Nested subqueries are not descended into. The list is computed on first use and
    /// cached until [`Scope::clear_cache`] runs.
    pub fn columns(&mut self) -> &[ColumnRef] {
        // Borrow the expression separately so the closure does not need all of `self`.
        let expression = &self.expression;
        self.columns_cache
            .get_or_insert_with(|| {
                let mut columns = Vec::new();
                collect_columns(expression, &mut columns);
                columns
            })
            .as_slice()
    }

    /// Returns the output column names of this scope's expression, as computed by
    /// `crate::ast_transforms::get_output_column_names`.
    pub fn output_columns(&self) -> Vec<String> {
        crate::ast_transforms::get_output_column_names(&self.expression)
    }

    /// Returns the names of all sources and all CTE sources known to this scope.
    pub fn source_names(&self) -> HashSet<String> {
        self.sources
            .keys()
            .chain(self.cte_sources.keys())
            .cloned()
            .collect()
    }

    /// Returns the qualified columns whose table qualifier is not one of
    /// [`Scope::source_names`].
    ///
    /// Unqualified columns are never reported as external. The result is cached until
    /// [`Scope::clear_cache`] runs.
    pub fn external_columns(&mut self) -> Vec<ColumnRef> {
        if let Some(cached) = &self.external_columns_cache {
            return cached.clone();
        }
        let source_names = self.source_names();
        let external: Vec<ColumnRef> = self
            .columns()
            .iter()
            .filter(|col| matches!(&col.table, Some(table) if !source_names.contains(table)))
            .cloned()
            .collect();
        self.external_columns_cache = Some(external.clone());
        external
    }

    /// Returns the columns that are not external: unqualified columns plus columns qualified
    /// with a known source name.
    pub fn local_columns(&mut self) -> Vec<ColumnRef> {
        let external: HashSet<ColumnRef> = self.external_columns().into_iter().collect();
        self.columns()
            .iter()
            .filter(|col| !external.contains(*col))
            .cloned()
            .collect()
    }

    /// Returns the columns that carry no table qualifier.
    pub fn unqualified_columns(&mut self) -> Vec<ColumnRef> {
        self.columns()
            .iter()
            .filter(|col| col.table.is_none())
            .cloned()
            .collect()
    }

    /// Returns the columns whose table qualifier equals `source_name` exactly.
    pub fn source_columns(&mut self, source_name: &str) -> Vec<ColumnRef> {
        self.columns()
            .iter()
            .filter(|col| col.table.as_deref() == Some(source_name))
            .cloned()
            .collect()
    }

    /// Returns `true` when this scope may be correlated and references at least one
    /// external column.
    pub fn is_correlated_subquery(&mut self) -> bool {
        self.can_be_correlated && !self.external_columns().is_empty()
    }

    /// Returns `true` when this scope has type [`ScopeType::Subquery`].
    pub fn is_subquery(&self) -> bool {
        self.scope_type == ScopeType::Subquery
    }

    /// Returns `true` when this scope has type [`ScopeType::DerivedTable`].
    pub fn is_derived_table(&self) -> bool {
        self.scope_type == ScopeType::DerivedTable
    }

    /// Returns `true` when this scope has type [`ScopeType::Cte`].
    pub fn is_cte(&self) -> bool {
        self.scope_type == ScopeType::Cte
    }

    /// Returns `true` when this scope has type [`ScopeType::Root`].
    pub fn is_root(&self) -> bool {
        self.scope_type == ScopeType::Root
    }

    /// Returns `true` when this scope has type [`ScopeType::Udtf`].
    pub fn is_udtf(&self) -> bool {
        self.scope_type == ScopeType::Udtf
    }

    /// Returns `true` when this scope has type [`ScopeType::SetOperation`].
    pub fn is_union(&self) -> bool {
        self.scope_type == ScopeType::SetOperation
    }

    /// Returns this scope and all of its descendants, children before parents.
    ///
    /// Within one scope the order is: CTE scopes, set-operation scopes, table scopes,
    /// subquery scopes, and finally the scope itself.
    pub fn traverse(&self) -> Vec<&Scope> {
        let mut result = Vec::new();
        self.traverse_impl(&mut result);
        result
    }

    /// Post-order worker for [`Scope::traverse`].
    fn traverse_impl<'a>(&'a self, result: &mut Vec<&'a Scope>) {
        for scope in &self.cte_scopes {
            scope.traverse_impl(result);
        }
        for scope in &self.union_scopes {
            scope.traverse_impl(result);
        }
        if self.table_scopes.is_empty() {
            // The builder in this file records derived tables only in `derived_table_scopes`
            // and never fills `table_scopes`; without this fallback those scopes (and
            // everything nested in them) would be skipped entirely.
            for scope in self.derived_table_scopes.iter().chain(&self.udtf_scopes) {
                scope.traverse_impl(result);
            }
        } else {
            // NOTE(review): assumes a populated `table_scopes` already covers the derived
            // table and UDTF scopes, so they are not visited a second time - confirm.
            for scope in &self.table_scopes {
                scope.traverse_impl(result);
            }
        }
        for scope in &self.subquery_scopes {
            scope.traverse_impl(result);
        }
        result.push(self);
    }

    /// Counts scope-like sources (`is_scope == true`) across the whole tree.
    ///
    /// The map key is the memory address of the `SourceInfo::expression` stored in each
    /// scope's `sources` map.
    /// NOTE(review): every `SourceInfo` owns its own expression, so each address is seen
    /// once per traversal and each count is 1; two scopes referencing the same CTE are not
    /// merged under one key. Confirm whether callers expect a structural identity instead.
    pub fn ref_count(&self) -> HashMap<usize, usize> {
        let mut counts: HashMap<usize, usize> = HashMap::new();
        for scope in self.traverse() {
            for source in scope.sources.values().filter(|source| source.is_scope) {
                let id = &source.expression as *const _ as usize;
                *counts.entry(id).or_insert(0) += 1;
            }
        }
        counts
    }
}
/// Appends every column reference reachable from `expr` to `columns`, in visit order.
///
/// The walk stays inside one query level: `EXISTS` bodies and subquery expressions are not
/// entered, and a SELECT's FROM sources and WITH clause are not visited. Expression kinds
/// that are not listed below contribute nothing.
fn collect_columns(expr: &Expression, columns: &mut Vec<ColumnRef>) {
    match expr {
        Expression::Column(col) => {
            columns.push(ColumnRef {
                table: col.table.as_ref().map(|t| t.name.clone()),
                name: col.name.name.clone(),
            });
        }
        // A CTE scope is built around the `Cte` wrapper, so look through it to the body;
        // otherwise such scopes would always report zero columns.
        Expression::Cte(cte) => {
            collect_columns(&cte.this, columns);
        }
        // The scope builder merges the CTAS query into the CREATE TABLE scope, so the
        // scope's columns are those of that query.
        Expression::CreateTable(create) => {
            if let Some(as_select) = &create.as_select {
                collect_columns(as_select, columns);
            }
        }
        Expression::Select(select) => {
            for e in &select.expressions {
                collect_columns(e, columns);
            }
            for join in &select.joins {
                if let Some(on) = &join.on {
                    collect_columns(on, columns);
                }
                if let Some(match_condition) = &join.match_condition {
                    collect_columns(match_condition, columns);
                }
            }
            if let Some(where_clause) = &select.where_clause {
                collect_columns(&where_clause.this, columns);
            }
            if let Some(having) = &select.having {
                collect_columns(&having.this, columns);
            }
            if let Some(order_by) = &select.order_by {
                for ord in &order_by.expressions {
                    collect_columns(&ord.this, columns);
                }
            }
            if let Some(group_by) = &select.group_by {
                for e in &group_by.expressions {
                    collect_columns(e, columns);
                }
            }
        }
        Expression::And(bin)
        | Expression::Or(bin)
        | Expression::Add(bin)
        | Expression::Sub(bin)
        | Expression::Mul(bin)
        | Expression::Div(bin)
        | Expression::Mod(bin)
        | Expression::Eq(bin)
        | Expression::Neq(bin)
        | Expression::Lt(bin)
        | Expression::Lte(bin)
        | Expression::Gt(bin)
        | Expression::Gte(bin)
        | Expression::BitwiseAnd(bin)
        | Expression::BitwiseOr(bin)
        | Expression::BitwiseXor(bin)
        | Expression::Concat(bin) => {
            collect_columns(&bin.left, columns);
            collect_columns(&bin.right, columns);
        }
        Expression::Like(like) | Expression::ILike(like) => {
            collect_columns(&like.left, columns);
            collect_columns(&like.right, columns);
            if let Some(escape) = &like.escape {
                collect_columns(escape, columns);
            }
        }
        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
            collect_columns(&un.this, columns);
        }
        Expression::Function(func) => {
            for arg in &func.args {
                collect_columns(arg, columns);
            }
        }
        Expression::AggregateFunction(agg) => {
            for arg in &agg.args {
                collect_columns(arg, columns);
            }
        }
        Expression::WindowFunction(wf) => {
            collect_columns(&wf.this, columns);
            for e in &wf.over.partition_by {
                collect_columns(e, columns);
            }
            for e in &wf.over.order_by {
                collect_columns(&e.this, columns);
            }
        }
        Expression::Alias(alias) => {
            collect_columns(&alias.this, columns);
        }
        Expression::Case(case) => {
            if let Some(operand) = &case.operand {
                collect_columns(operand, columns);
            }
            for (when_expr, then_expr) in &case.whens {
                collect_columns(when_expr, columns);
                collect_columns(then_expr, columns);
            }
            if let Some(else_clause) = &case.else_ {
                collect_columns(else_clause, columns);
            }
        }
        Expression::Paren(paren) => {
            collect_columns(&paren.this, columns);
        }
        Expression::Ordered(ord) => {
            collect_columns(&ord.this, columns);
        }
        Expression::In(in_expr) => {
            // Only the tested value and the literal list; an `IN (query)` is another scope.
            collect_columns(&in_expr.this, columns);
            for e in &in_expr.expressions {
                collect_columns(e, columns);
            }
        }
        Expression::Between(between) => {
            collect_columns(&between.this, columns);
            collect_columns(&between.low, columns);
            collect_columns(&between.high, columns);
        }
        Expression::IsNull(is_null) => {
            collect_columns(&is_null.this, columns);
        }
        Expression::Cast(cast) => {
            collect_columns(&cast.this, columns);
        }
        Expression::Extract(extract) => {
            collect_columns(&extract.this, columns);
        }
        // Kept in step with the in-scope walker, which also descends into these two.
        Expression::Coalesce(coalesce) => {
            for e in &coalesce.expressions {
                collect_columns(e, columns);
            }
        }
        Expression::NullIf(nullif) => {
            collect_columns(&nullif.this, columns);
            collect_columns(&nullif.expression, columns);
        }
        // Nested queries get their own scope; their columns are not part of this one.
        Expression::Exists(_) | Expression::Subquery(_) => {}
        _ => {}
    }
}
/// Builds the scope tree for `expression` and returns its root scope.
///
/// The root owns a clone of `expression`; child scopes are attached for the CTEs, derived
/// tables, set-operation sides and subqueries found underneath it.
pub fn build_scope(expression: &Expression) -> Scope {
    let mut scope = Scope::new(expression.clone());
    build_scope_impl(expression, &mut scope);
    scope
}
/// Populates `current_scope` from `expression`.
///
/// * `Select`: registers CTEs, then FROM and JOIN sources, then subqueries.
/// * `Union` / `Intersect` / `Except`: adds one `SetOperation` child scope per side.
/// * `CreateTable`: registers CTEs, then treats the `AS <query>` part as this scope's query.
/// * Anything else is ignored.
fn build_scope_impl(expression: &Expression, current_scope: &mut Scope) {
    match expression {
        Expression::Select(select) => {
            if let Some(with) = &select.with {
                for cte in &with.ctes {
                    register_cte_scope(
                        cte.alias.name.clone(),
                        Expression::Cte(Box::new(cte.clone())),
                        &cte.this,
                        current_scope,
                    );
                }
            }
            if let Some(from) = &select.from {
                for table in &from.expressions {
                    add_table_to_scope(table, current_scope);
                }
            }
            for join in &select.joins {
                add_table_to_scope(&join.this, current_scope);
            }
            collect_subqueries(expression, current_scope);
        }
        Expression::Union(union) => {
            push_set_operation_scopes(&union.left, &union.right, current_scope);
        }
        Expression::Intersect(intersect) => {
            push_set_operation_scopes(&intersect.left, &intersect.right, current_scope);
        }
        Expression::Except(except) => {
            push_set_operation_scopes(&except.left, &except.right, current_scope);
        }
        Expression::CreateTable(create) => {
            if let Some(with) = &create.with_cte {
                for cte in &with.ctes {
                    register_cte_scope(
                        cte.alias.name.clone(),
                        Expression::Cte(Box::new(cte.clone())),
                        &cte.this,
                        current_scope,
                    );
                }
            }
            if let Some(as_select) = &create.as_select {
                build_scope_impl(as_select, current_scope);
            }
        }
        _ => {}
    }
}

/// Builds the child scope for one CTE and registers the CTE as a source of `scope`.
///
/// The child is branched before the CTE itself is registered, so it sees only the CTEs
/// registered earlier.
fn register_cte_scope(name: String, cte_expr: Expression, body: &Expression, scope: &mut Scope) {
    let mut cte_scope = scope.branch(cte_expr.clone(), ScopeType::Cte);
    build_scope_impl(body, &mut cte_scope);
    scope.add_cte_source(name, cte_expr);
    scope.cte_scopes.push(cte_scope);
}

/// Builds one `SetOperation` child scope per operand and appends them left first.
fn push_set_operation_scopes(left: &Expression, right: &Expression, scope: &mut Scope) {
    let mut left_scope = scope.branch(left.clone(), ScopeType::SetOperation);
    build_scope_impl(left, &mut left_scope);
    let mut right_scope = scope.branch(right.clone(), ScopeType::SetOperation);
    build_scope_impl(right, &mut right_scope);
    scope.union_scopes.push(left_scope);
    scope.union_scopes.push(right_scope);
}
/// Registers one FROM/JOIN item as a source of `scope`.
///
/// * Table: keyed by alias, or by table name when unaliased. An unqualified name (no schema,
///   no catalog) that matches a known CTE - exactly first, then ignoring ASCII case - is
///   registered with the CTE's expression and `is_scope = true`; otherwise the table
///   expression itself is registered with `is_scope = false`.
/// * Subquery: gets a `DerivedTable` child scope and is registered under its alias (the
///   empty string when unaliased) with `is_scope = true`.
/// * Parenthesised item: unwrapped and handled recursively.
/// * Anything else is ignored.
fn add_table_to_scope(expr: &Expression, scope: &mut Scope) {
    match expr {
        Expression::Table(table) => {
            let source_name = match &table.alias {
                Some(alias) => alias.name.clone(),
                None => table.name.name.clone(),
            };
            let is_unqualified = table.schema.is_none() && table.catalog.is_none();
            // Copy the CTE expression out so the shared borrow of `scope` ends before
            // `add_source` needs it mutably.
            let cte_expression = if is_unqualified {
                let wanted = &table.name.name;
                scope
                    .cte_sources
                    .get(wanted)
                    .or_else(|| {
                        scope
                            .cte_sources
                            .iter()
                            .find(|(cte_name, _)| cte_name.eq_ignore_ascii_case(wanted))
                            .map(|(_, source)| source)
                    })
                    .map(|source| source.expression.clone())
            } else {
                None
            };
            match cte_expression {
                Some(cte_expression) => scope.add_source(source_name, cte_expression, true),
                None => scope.add_source(source_name, expr.clone(), false),
            }
        }
        Expression::Subquery(subquery) => {
            let source_name = match &subquery.alias {
                Some(alias) => alias.name.clone(),
                None => String::new(),
            };
            let mut derived = scope.branch(subquery.this.clone(), ScopeType::DerivedTable);
            build_scope_impl(&subquery.this, &mut derived);
            scope.add_source(source_name, expr.clone(), true);
            scope.derived_table_scopes.push(derived);
        }
        Expression::Paren(paren) => add_table_to_scope(&paren.this, scope),
        _ => {}
    }
}
/// Creates child scopes for the subqueries used inside the clauses of a SELECT.
///
/// Clauses are searched in this order: WHERE, the select list, HAVING, then each JOIN's
/// `ON` and match condition. Join conditions are searched last so that subqueries found in
/// the other clauses keep their position in `subquery_scopes`. Non-SELECT expressions are
/// ignored.
fn collect_subqueries(expr: &Expression, parent_scope: &mut Scope) {
    if let Expression::Select(select) = expr {
        if let Some(where_clause) = &select.where_clause {
            collect_subqueries_in_expr(&where_clause.this, parent_scope);
        }
        for e in &select.expressions {
            collect_subqueries_in_expr(e, parent_scope);
        }
        if let Some(having) = &select.having {
            collect_subqueries_in_expr(&having.this, parent_scope);
        }
        // A join condition belongs to the same query level as WHERE; a subquery written
        // there would otherwise never get a scope of its own.
        for join in &select.joins {
            if let Some(on) = &join.on {
                collect_subqueries_in_expr(on, parent_scope);
            }
            if let Some(match_condition) = &join.match_condition {
                collect_subqueries_in_expr(match_condition, parent_scope);
            }
        }
    }
}
/// Walks `expr` and adds a `Subquery` child scope to `parent_scope` for every nested query
/// it finds: unaliased subquery expressions, `IN (query)` and `EXISTS (query)`.
///
/// The walk descends through the same operator, function and wrapper expressions that the
/// in-scope walker understands, so a subquery wrapped in e.g. `COALESCE(...)`, `CAST(...)`,
/// `... IS NULL` or `BETWEEN` is found as well. It does not descend into a nested query once
/// that query has been given its own scope. Aliased subquery expressions are skipped.
fn collect_subqueries_in_expr(expr: &Expression, parent_scope: &mut Scope) {
    match expr {
        Expression::Subquery(subquery) if subquery.alias.is_none() => {
            push_subquery_scope(&subquery.this, parent_scope);
        }
        Expression::In(in_expr) => {
            collect_subqueries_in_expr(&in_expr.this, parent_scope);
            // The value list may itself contain scalar subqueries.
            for e in &in_expr.expressions {
                collect_subqueries_in_expr(e, parent_scope);
            }
            if let Some(query) = &in_expr.query {
                push_subquery_scope(query, parent_scope);
            }
        }
        Expression::Exists(exists) => {
            push_subquery_scope(&exists.this, parent_scope);
        }
        Expression::And(bin)
        | Expression::Or(bin)
        | Expression::Add(bin)
        | Expression::Sub(bin)
        | Expression::Mul(bin)
        | Expression::Div(bin)
        | Expression::Mod(bin)
        | Expression::Eq(bin)
        | Expression::Neq(bin)
        | Expression::Lt(bin)
        | Expression::Lte(bin)
        | Expression::Gt(bin)
        | Expression::Gte(bin)
        | Expression::BitwiseAnd(bin)
        | Expression::BitwiseOr(bin)
        | Expression::BitwiseXor(bin)
        | Expression::Concat(bin) => {
            collect_subqueries_in_expr(&bin.left, parent_scope);
            collect_subqueries_in_expr(&bin.right, parent_scope);
        }
        Expression::Like(like) | Expression::ILike(like) => {
            collect_subqueries_in_expr(&like.left, parent_scope);
            collect_subqueries_in_expr(&like.right, parent_scope);
            if let Some(escape) = &like.escape {
                collect_subqueries_in_expr(escape, parent_scope);
            }
        }
        Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
            collect_subqueries_in_expr(&un.this, parent_scope);
        }
        Expression::Function(func) => {
            for arg in &func.args {
                collect_subqueries_in_expr(arg, parent_scope);
            }
        }
        Expression::AggregateFunction(agg) => {
            for arg in &agg.args {
                collect_subqueries_in_expr(arg, parent_scope);
            }
        }
        Expression::WindowFunction(wf) => {
            collect_subqueries_in_expr(&wf.this, parent_scope);
            for e in &wf.over.partition_by {
                collect_subqueries_in_expr(e, parent_scope);
            }
            for e in &wf.over.order_by {
                collect_subqueries_in_expr(&e.this, parent_scope);
            }
        }
        Expression::Case(case) => {
            if let Some(operand) = &case.operand {
                collect_subqueries_in_expr(operand, parent_scope);
            }
            for (when_expr, then_expr) in &case.whens {
                collect_subqueries_in_expr(when_expr, parent_scope);
                collect_subqueries_in_expr(then_expr, parent_scope);
            }
            if let Some(else_clause) = &case.else_ {
                collect_subqueries_in_expr(else_clause, parent_scope);
            }
        }
        Expression::Paren(paren) => {
            collect_subqueries_in_expr(&paren.this, parent_scope);
        }
        Expression::Alias(alias) => {
            collect_subqueries_in_expr(&alias.this, parent_scope);
        }
        Expression::Ordered(ord) => {
            collect_subqueries_in_expr(&ord.this, parent_scope);
        }
        Expression::Between(between) => {
            collect_subqueries_in_expr(&between.this, parent_scope);
            collect_subqueries_in_expr(&between.low, parent_scope);
            collect_subqueries_in_expr(&between.high, parent_scope);
        }
        Expression::IsNull(is_null) => {
            collect_subqueries_in_expr(&is_null.this, parent_scope);
        }
        Expression::Cast(cast) => {
            collect_subqueries_in_expr(&cast.this, parent_scope);
        }
        Expression::Extract(extract) => {
            collect_subqueries_in_expr(&extract.this, parent_scope);
        }
        Expression::Coalesce(coalesce) => {
            for e in &coalesce.expressions {
                collect_subqueries_in_expr(e, parent_scope);
            }
        }
        Expression::NullIf(nullif) => {
            collect_subqueries_in_expr(&nullif.this, parent_scope);
            collect_subqueries_in_expr(&nullif.expression, parent_scope);
        }
        _ => {}
    }
}

/// Builds a `Subquery` child scope for `query` and appends it to `parent_scope`.
fn push_subquery_scope(query: &Expression, parent_scope: &mut Scope) {
    let mut sub_scope = parent_scope.branch(query.clone(), ScopeType::Subquery);
    build_scope_impl(query, &mut sub_scope);
    parent_scope.subquery_scopes.push(sub_scope);
}
/// Returns an iterator over `expression` and the nodes beneath it that belong to the same
/// scope.
///
/// `expression` itself is always yielded first. With `bfs = true` nodes are visited
/// breadth-first, otherwise depth-first in pre-order. Nested queries, set operations, CTEs
/// and aliased subqueries below the root are neither yielded nor entered.
pub fn walk_in_scope<'a>(
    expression: &'a Expression,
    bfs: bool,
) -> impl Iterator<Item = &'a Expression> {
    let walker = WalkInScopeIter::new(expression, bfs);
    walker
}
/// Iterator state behind `walk_in_scope`.
struct WalkInScopeIter<'a> {
// Nodes waiting to be yielded: consumed from the front in BFS mode, from the back in DFS mode.
queue: VecDeque<&'a Expression>,
// `true` for breadth-first order, `false` for depth-first order.
bfs: bool,
}
impl<'a> WalkInScopeIter<'a> {
    /// Creates a walker whose queue initially holds only `expression`.
    fn new(expression: &'a Expression, bfs: bool) -> Self {
        let mut queue = VecDeque::new();
        queue.push_back(expression);
        Self { queue, bfs }
    }

    /// Returns `true` when `expr` starts a different scope and must not be walked.
    ///
    /// That is the case for CTEs, aliased subqueries, SELECTs and set operations. The root
    /// of the walk (`is_root = true`) is never a boundary.
    fn should_stop_at(&self, expr: &Expression, is_root: bool) -> bool {
        if is_root {
            return false;
        }
        match expr {
            Expression::Cte(_)
            | Expression::Select(_)
            | Expression::Union(_)
            | Expression::Intersect(_)
            | Expression::Except(_) => true,
            Expression::Subquery(subquery) => subquery.alias.is_some(),
            _ => false,
        }
    }

    /// Returns the direct children of `expr` that the walk may continue into.
    ///
    /// Tables, columns, literals, identifiers, subqueries, `EXISTS` and every expression kind
    /// not listed here are leaves. FROM and JOIN sources that start another scope are left
    /// out here; all other boundary checks happen in the iterator's `next`.
    fn get_children(&self, expr: &'a Expression) -> Vec<&'a Expression> {
        let mut children = Vec::new();
        match expr {
            Expression::Select(select) => {
                for e in &select.expressions {
                    children.push(e);
                }
                if let Some(from) = &select.from {
                    for table in &from.expressions {
                        if !self.should_stop_at(table, false) {
                            children.push(table);
                        }
                    }
                }
                for join in &select.joins {
                    // Joined sources are part of the scope exactly like FROM sources, so
                    // they are yielded under the same boundary rule.
                    if !self.should_stop_at(&join.this, false) {
                        children.push(&join.this);
                    }
                    if let Some(on) = &join.on {
                        children.push(on);
                    }
                    if let Some(match_condition) = &join.match_condition {
                        children.push(match_condition);
                    }
                }
                if let Some(where_clause) = &select.where_clause {
                    children.push(&where_clause.this);
                }
                if let Some(group_by) = &select.group_by {
                    for e in &group_by.expressions {
                        children.push(e);
                    }
                }
                if let Some(having) = &select.having {
                    children.push(&having.this);
                }
                if let Some(order_by) = &select.order_by {
                    for ord in &order_by.expressions {
                        children.push(&ord.this);
                    }
                }
                if let Some(limit) = &select.limit {
                    children.push(&limit.this);
                }
                if let Some(offset) = &select.offset {
                    children.push(&offset.this);
                }
            }
            Expression::And(bin)
            | Expression::Or(bin)
            | Expression::Add(bin)
            | Expression::Sub(bin)
            | Expression::Mul(bin)
            | Expression::Div(bin)
            | Expression::Mod(bin)
            | Expression::Eq(bin)
            | Expression::Neq(bin)
            | Expression::Lt(bin)
            | Expression::Lte(bin)
            | Expression::Gt(bin)
            | Expression::Gte(bin)
            | Expression::BitwiseAnd(bin)
            | Expression::BitwiseOr(bin)
            | Expression::BitwiseXor(bin)
            | Expression::Concat(bin) => {
                children.push(&bin.left);
                children.push(&bin.right);
            }
            Expression::Like(like) | Expression::ILike(like) => {
                children.push(&like.left);
                children.push(&like.right);
                if let Some(escape) = &like.escape {
                    children.push(escape);
                }
            }
            Expression::Not(un) | Expression::Neg(un) | Expression::BitwiseNot(un) => {
                children.push(&un.this);
            }
            Expression::Function(func) => {
                for arg in &func.args {
                    children.push(arg);
                }
            }
            Expression::AggregateFunction(agg) => {
                for arg in &agg.args {
                    children.push(arg);
                }
            }
            Expression::WindowFunction(wf) => {
                children.push(&wf.this);
                for e in &wf.over.partition_by {
                    children.push(e);
                }
                for e in &wf.over.order_by {
                    children.push(&e.this);
                }
            }
            Expression::Alias(alias) => {
                children.push(&alias.this);
            }
            Expression::Case(case) => {
                if let Some(operand) = &case.operand {
                    children.push(operand);
                }
                for (when_expr, then_expr) in &case.whens {
                    children.push(when_expr);
                    children.push(then_expr);
                }
                if let Some(else_clause) = &case.else_ {
                    children.push(else_clause);
                }
            }
            Expression::Paren(paren) => {
                children.push(&paren.this);
            }
            Expression::Ordered(ord) => {
                children.push(&ord.this);
            }
            Expression::In(in_expr) => {
                // The `IN (query)` form is not followed: that query is another scope.
                children.push(&in_expr.this);
                for e in &in_expr.expressions {
                    children.push(e);
                }
            }
            Expression::Between(between) => {
                children.push(&between.this);
                children.push(&between.low);
                children.push(&between.high);
            }
            Expression::IsNull(is_null) => {
                children.push(&is_null.this);
            }
            Expression::Cast(cast) => {
                children.push(&cast.this);
            }
            Expression::Extract(extract) => {
                children.push(&extract.this);
            }
            Expression::Coalesce(coalesce) => {
                for e in &coalesce.expressions {
                    children.push(e);
                }
            }
            Expression::NullIf(nullif) => {
                children.push(&nullif.this);
                children.push(&nullif.expression);
            }
            // Leaves: yielded by the walk but never expanded.
            Expression::Table(_)
            | Expression::Column(_)
            | Expression::Literal(_)
            | Expression::Identifier(_)
            | Expression::Subquery(_)
            | Expression::Exists(_) => {}
            _ => {}
        }
        children
    }
}
impl<'a> Iterator for WalkInScopeIter<'a> {
    type Item = &'a Expression;

    /// Yields the next node and schedules its in-scope children.
    ///
    /// BFS takes from the front of the queue and appends children in order. DFS takes from
    /// the back and appends children in reverse, so the first child is yielded next.
    fn next(&mut self) -> Option<Self::Item> {
        let current = (if self.bfs {
            self.queue.pop_front()
        } else {
            self.queue.pop_back()
        })?;

        // Children that start another scope are dropped before they are queued.
        let mut pending = self.get_children(current);
        pending.retain(|child| !self.should_stop_at(child, false));
        if !self.bfs {
            pending.reverse();
        }
        for child in pending {
            self.queue.push_back(child);
        }

        Some(current)
    }
}
/// Returns the first node in the scope of `expression` for which `predicate` holds, or
/// `None` when there is none.
///
/// Nodes are tested in `walk_in_scope` order (`bfs` selects breadth- or depth-first) and
/// the walk stops at the first match.
pub fn find_in_scope<'a, F>(
    expression: &'a Expression,
    predicate: F,
    bfs: bool,
) -> Option<&'a Expression>
where
    F: Fn(&Expression) -> bool,
{
    for candidate in walk_in_scope(expression, bfs) {
        if predicate(candidate) {
            return Some(candidate);
        }
    }
    None
}
/// Returns every node in the scope of `expression` for which `predicate` holds, in
/// `walk_in_scope` order (`bfs` selects breadth- or depth-first).
pub fn find_all_in_scope<'a, F>(
    expression: &'a Expression,
    predicate: F,
    bfs: bool,
) -> Vec<&'a Expression>
where
    F: Fn(&Expression) -> bool,
{
    let mut found = Vec::new();
    for candidate in walk_in_scope(expression, bfs) {
        if predicate(candidate) {
            found.push(candidate);
        }
    }
    found
}
/// Builds the scope tree for `expression` and returns owned copies of every scope that
/// `Scope::traverse` yields, in the same order.
///
/// Only SELECT, UNION, INTERSECT, EXCEPT and CREATE TABLE expressions are supported; any
/// other expression yields an empty list.
pub fn traverse_scope(expression: &Expression) -> Vec<Scope> {
    let is_supported = matches!(
        expression,
        Expression::Select(_)
            | Expression::Union(_)
            | Expression::Intersect(_)
            | Expression::Except(_)
            | Expression::CreateTable(_)
    );
    if !is_supported {
        return Vec::new();
    }
    let root = build_scope(expression);
    root.traverse().into_iter().cloned().collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::Parser;

    /// Parses `sql` and builds the scope tree for its first statement.
    fn parse_and_build_scope(sql: &str) -> Scope {
        let statements = Parser::parse_sql(sql).expect("Failed to parse SQL");
        build_scope(&statements[0])
    }

    #[test]
    fn test_simple_select_scope() {
        let mut root = parse_and_build_scope("SELECT a, b FROM t");

        assert!(root.is_root());
        assert!(!root.can_be_correlated);
        assert!(root.sources.contains_key("t"));
        assert_eq!(root.columns().len(), 2);
    }

    #[test]
    fn test_derived_table_scope() {
        let root = parse_and_build_scope("SELECT x.a FROM (SELECT a FROM t) AS x");

        assert!(root.sources.contains_key("x"));
        assert_eq!(root.derived_table_scopes.len(), 1);

        let inner = &root.derived_table_scopes[0];
        assert!(inner.is_derived_table());
        assert!(inner.sources.contains_key("t"));
    }

    #[test]
    fn test_non_correlated_subquery() {
        let mut root = parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s)");

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);
        assert!(inner.sources.contains_key("s"));
        assert!(!inner.is_correlated_subquery());
    }

    #[test]
    fn test_correlated_subquery() {
        let mut root =
            parse_and_build_scope("SELECT * FROM t WHERE EXISTS (SELECT b FROM s WHERE s.x = t.y)");

        assert_eq!(root.subquery_scopes.len(), 1);

        let inner = &mut root.subquery_scopes[0];
        assert!(inner.is_subquery());
        assert!(inner.can_be_correlated);

        // `t` belongs to the outer query, so `t.y` must be reported as external.
        let outer_refs = inner.external_columns();
        assert!(!outer_refs.is_empty());
        assert!(outer_refs.iter().any(|c| c.table.as_deref() == Some("t")));
        assert!(inner.is_correlated_subquery());
    }

    #[test]
    fn test_cte_scope() {
        let root = parse_and_build_scope("WITH cte AS (SELECT a FROM t) SELECT * FROM cte");

        assert_eq!(root.cte_scopes.len(), 1);
        assert!(root.cte_sources.contains_key("cte"));
        assert!(root.cte_scopes[0].is_cte());
    }

    #[test]
    fn test_multiple_sources() {
        let root = parse_and_build_scope("SELECT t.a, s.b FROM t JOIN s ON t.id = s.id");

        assert_eq!(root.sources.len(), 2);
        assert!(root.sources.contains_key("t"));
        assert!(root.sources.contains_key("s"));
    }

    #[test]
    fn test_aliased_table() {
        let root = parse_and_build_scope("SELECT x.a FROM t AS x");

        assert!(root.sources.contains_key("x"));
        assert!(!root.sources.contains_key("t"));
    }

    #[test]
    fn test_local_columns() {
        let mut root = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        // Three select-list columns plus the two columns of the join condition.
        let local = root.local_columns();
        assert_eq!(local.len(), 5);
        assert!(local.iter().all(|c| c.table.is_some()));
    }

    #[test]
    fn test_columns_include_join_on_clause_references() {
        let mut root = parse_and_build_scope(
            "SELECT o.total FROM orders o JOIN customers c ON c.id = o.customer_id",
        );

        let rendered: Vec<String> = root
            .columns()
            .iter()
            .map(|c| match &c.table {
                Some(t) => format!("{}.{}", t, c.name),
                None => c.name.clone(),
            })
            .collect();

        assert!(rendered.contains(&"o.total".to_string()));
        assert!(rendered.contains(&"c.id".to_string()));
        assert!(rendered.contains(&"o.customer_id".to_string()));
    }

    #[test]
    fn test_unqualified_columns() {
        let mut root = parse_and_build_scope("SELECT a, b, t.c FROM t");

        let bare = root.unqualified_columns();
        assert_eq!(bare.len(), 2);
        assert!(bare.iter().all(|c| c.table.is_none()));
    }

    #[test]
    fn test_source_columns() {
        let mut root = parse_and_build_scope("SELECT t.a, t.b, s.c FROM t JOIN s ON t.id = s.id");

        let from_t = root.source_columns("t");
        assert!(from_t.len() >= 2);
        assert!(from_t.iter().all(|c| c.table.as_deref() == Some("t")));

        let from_s = root.source_columns("s");
        assert!(!from_s.is_empty());
        assert!(from_s.iter().all(|c| c.table.as_deref() == Some("s")));
    }

    #[test]
    fn test_rename_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");
        assert!(root.sources.contains_key("t"));

        root.rename_source("t", "new_name".to_string());

        assert!(!root.sources.contains_key("t"));
        assert!(root.sources.contains_key("new_name"));
    }

    #[test]
    fn test_remove_source() {
        let mut root = parse_and_build_scope("SELECT a FROM t");
        assert!(root.sources.contains_key("t"));

        root.remove_source("t");

        assert!(!root.sources.contains_key("t"));
    }

    #[test]
    fn test_walk_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");

        let visited: Vec<_> = walk_in_scope(&statements[0], true).collect();

        assert!(!visited.is_empty());
        assert!(visited.iter().any(|e| matches!(e, Expression::Select(_))));
        assert!(visited.iter().any(|e| matches!(e, Expression::Column(_))));
    }

    #[test]
    fn test_find_in_scope() {
        let statements =
            Parser::parse_sql("SELECT a, b FROM t WHERE a > 1").expect("Failed to parse");

        let first_column = find_in_scope(
            &statements[0],
            |e| matches!(e, Expression::Column(_)),
            true,
        );

        assert!(first_column.is_some());
        assert!(matches!(first_column.unwrap(), Expression::Column(_)));
    }

    #[test]
    fn test_find_all_in_scope() {
        let statements = Parser::parse_sql("SELECT a, b, c FROM t").expect("Failed to parse");

        let all_columns = find_all_in_scope(
            &statements[0],
            |e| matches!(e, Expression::Column(_)),
            true,
        );

        assert_eq!(all_columns.len(), 3);
    }

    #[test]
    fn test_traverse_scope() {
        let statements =
            Parser::parse_sql("SELECT a FROM (SELECT b FROM t) AS x").expect("Failed to parse");

        let scopes = traverse_scope(&statements[0]);

        assert!(!scopes.is_empty());
        assert!(scopes.iter().any(|s| s.is_root()));
    }

    #[test]
    fn test_branch_with_options() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");
        let parent = build_scope(&statements[0]);

        let child = parent.branch_with_options(
            statements[0].clone(),
            ScopeType::Subquery,
            None,
            None,
            Some(vec!["col1".to_string(), "col2".to_string()]),
        );

        assert_eq!(child.outer_columns, vec!["col1", "col2"]);
        // Branching as a subquery is what makes the child potentially correlated.
        assert!(child.can_be_correlated);
    }

    #[test]
    fn test_is_udtf() {
        let statements = Parser::parse_sql("SELECT a FROM t").expect("Failed to parse");

        let plain = Scope::new(statements[0].clone());
        assert!(!plain.is_udtf());

        let parent = build_scope(&statements[0]);
        let branched = parent.branch(statements[0].clone(), ScopeType::Udtf);
        assert!(branched.is_udtf());
    }

    #[test]
    fn test_is_union() {
        let root = parse_and_build_scope("SELECT a FROM t UNION SELECT b FROM s");

        assert!(root.is_root());
        assert_eq!(root.union_scopes.len(), 2);
        assert!(root.union_scopes[0].is_union());
        assert!(root.union_scopes[1].is_union());
    }

    #[test]
    fn test_union_output_columns() {
        let root = parse_and_build_scope(
            "SELECT id, name FROM customers UNION ALL SELECT id, name FROM employees",
        );

        assert_eq!(root.output_columns(), vec!["id", "name"]);
    }

    #[test]
    fn test_clear_cache() {
        let mut root = parse_and_build_scope("SELECT t.a FROM t");

        let _ = root.columns();
        assert!(root.columns_cache.is_some());

        root.clear_cache();

        assert!(root.columns_cache.is_none());
        assert!(root.external_columns_cache.is_none());
    }

    #[test]
    fn test_scope_traverse() {
        let root = parse_and_build_scope(
            "WITH cte AS (SELECT a FROM t) SELECT * FROM cte WHERE EXISTS (SELECT b FROM s)",
        );

        // At least: the CTE scope, the EXISTS subquery scope and the root itself.
        let visited = root.traverse();
        assert!(visited.len() >= 3);
    }

    #[test]
    fn test_create_table_as_select_scope() {
        // Plain CTAS.
        let root = parse_and_build_scope("CREATE TABLE out_table AS SELECT 1 AS id FROM src");
        assert!(
            root.sources.contains_key("src"),
            "CTAS scope should contain the FROM table"
        );
        assert!(
            !root.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );

        // CTAS over a join with aliases.
        let root = parse_and_build_scope(
            "CREATE TABLE out_table AS SELECT a.id FROM foo AS a JOIN bar AS b ON a.id = b.id",
        );
        assert!(root.sources.contains_key("a"));
        assert!(root.sources.contains_key("b"));
        assert!(
            !root.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );

        // CTAS whose query starts with a CTE.
        let root = parse_and_build_scope(
            "CREATE TABLE out_table AS WITH cte AS (SELECT 1 AS id FROM src) SELECT * FROM cte",
        );
        assert!(
            root.sources.contains_key("cte"),
            "CTAS with CTE should resolve CTE as source"
        );
        assert!(
            !root.sources.contains_key("out_table"),
            "CTAS target table should not be treated as a source"
        );
        assert_eq!(root.cte_scopes.len(), 1);
    }

    #[test]
    fn test_create_table_as_select_traverse() {
        let statements = Parser::parse_sql("CREATE TABLE t AS SELECT a FROM src").unwrap();

        let scopes = traverse_scope(&statements[0]);

        assert!(
            !scopes.is_empty(),
            "traverse_scope should return scopes for CTAS"
        );
    }
}