use crate::sql::parser::ast::{
CTEType, Condition, SelectItem, SelectStatement, SqlExpression, WhereClause, CTE,
};
use std::collections::{HashMap, HashSet};
/// Collects CTE definitions found while walking a query so they can be
/// emitted as a single flat list on the outermost statement.
///
/// `Default` yields the same empty state as [`CTEHoister::new`].
#[derive(Default)]
pub struct CTEHoister {
    /// Every CTE registered so far, in registration order.
    hoisted_ctes: Vec<CTE>,
    /// Initialised to zero by `new`; nothing in this file reads or updates it.
    _cte_counter: usize,
    /// For each registered CTE name, the names of the CTEs that were already
    /// registered at that moment and that its body was found to reference.
    dependency_graph: HashMap<String, HashSet<String>>,
}
impl CTEHoister {
/// Creates a hoister with no registered CTEs and an empty dependency graph.
pub fn new() -> Self {
    let hoisted_ctes = Vec::new();
    let dependency_graph = HashMap::new();
    Self {
        hoisted_ctes,
        _cte_counter: 0,
        dependency_graph,
    }
}
/// Returns `statement` with the CTEs found during traversal collected into
/// its top-level `ctes` list.
///
/// The statement's own CTEs are registered first, exactly as given (their
/// bodies are not traversed here). The rest of the statement is then walked
/// by `hoist_from_statement`, and the resulting list is ordered by
/// `get_ordered_ctes`.
pub fn hoist_ctes(mut statement: SelectStatement) -> SelectStatement {
    let mut hoister = CTEHoister::new();

    // Register the existing top-level CTEs before descending, so references
    // to them from nested definitions can be recognised.
    let top_level = std::mem::take(&mut statement.ctes);
    for cte in top_level {
        hoister.add_cte(cte);
    }

    let mut flattened = hoister.hoist_from_statement(statement);
    flattened.ctes = hoister.get_ordered_ctes();
    flattened
}
/// Strips every CTE out of `statement` (recursively) and registers it with
/// this hoister, returning the statement with an empty `ctes` list.
///
/// Traversal covers the statement's own CTEs, its `from_subquery`, its select
/// items and its WHERE conditions.
///
/// The statement's own CTEs are registered *before* the FROM subquery is
/// visited. Dependency detection in `add_cte` only recognises CTEs that have
/// already been registered, and a FROM subquery may reference the enclosing
/// statement's CTEs (never the other way round), so the enclosing scope has
/// to go first for those references to be recorded and ordered correctly.
fn hoist_from_statement(&mut self, mut statement: SelectStatement) -> SelectStatement {
    // Own CTEs: hoist anything nested inside each body, then register the
    // CTE itself so that its nested definitions precede it.
    for mut cte in std::mem::take(&mut statement.ctes) {
        cte.cte_type = match cte.cte_type {
            CTEType::Standard(query) => CTEType::Standard(self.hoist_from_statement(query)),
            other => other,
        };
        self.add_cte(cte);
    }

    // The recursive call already returns a statement without CTEs, so the
    // rewritten subquery can be stored back as-is.
    if let Some(subquery) = statement.from_subquery.take() {
        statement.from_subquery = Some(Box::new(self.hoist_from_statement(*subquery)));
    }

    statement.select_items = statement
        .select_items
        .into_iter()
        .map(|item| self.hoist_from_select_item(item))
        .collect();

    if let Some(where_clause) = &mut statement.where_clause {
        self.hoist_from_where_clause(where_clause);
    }

    // `ctes` was emptied by the `take` above.
    statement
}
/// Rewrites the expression of a `SelectItem::Expression` through
/// `hoist_from_expression`; every other kind of select item is returned
/// unchanged.
fn hoist_from_select_item(&mut self, item: SelectItem) -> SelectItem {
    match item {
        SelectItem::Expression {
            expr,
            alias,
            leading_comments,
            trailing_comment,
        } => {
            let expr = self.hoist_from_expression(expr);
            SelectItem::Expression {
                expr,
                alias,
                leading_comments,
                trailing_comment,
            }
        }
        untouched => untouched,
    }
}
/// Rebuilds `expr`, passing every subquery it contains through
/// `hoist_from_statement` and recursing into the operands of the compound
/// expression kinds listed below. Any other expression kind is returned
/// unchanged.
///
/// Operands are processed in a fixed order because that order determines
/// the order in which CTEs get registered.
fn hoist_from_expression(&mut self, expr: SqlExpression) -> SqlExpression {
    match expr {
        SqlExpression::ScalarSubquery { query } => {
            let query = Box::new(self.hoist_from_statement(*query));
            SqlExpression::ScalarSubquery { query }
        }
        SqlExpression::BinaryOp { left, op, right } => {
            let left = self.hoist_boxed(left);
            let right = self.hoist_boxed(right);
            SqlExpression::BinaryOp { left, op, right }
        }
        SqlExpression::FunctionCall {
            name,
            args,
            distinct,
        } => {
            let args = args
                .into_iter()
                .map(|arg| self.hoist_from_expression(arg))
                .collect();
            SqlExpression::FunctionCall {
                name,
                args,
                distinct,
            }
        }
        SqlExpression::CaseExpression {
            when_branches,
            else_branch,
        } => {
            let when_branches = when_branches
                .into_iter()
                .map(|branch| crate::sql::parser::ast::WhenBranch {
                    condition: self.hoist_boxed(branch.condition),
                    result: self.hoist_boxed(branch.result),
                })
                .collect();
            let else_branch = else_branch.map(|fallback| self.hoist_boxed(fallback));
            SqlExpression::CaseExpression {
                when_branches,
                else_branch,
            }
        }
        SqlExpression::InList { expr, values } => {
            let expr = self.hoist_boxed(expr);
            let values = values
                .into_iter()
                .map(|value| self.hoist_from_expression(value))
                .collect();
            SqlExpression::InList { expr, values }
        }
        SqlExpression::NotInList { expr, values } => {
            let expr = self.hoist_boxed(expr);
            let values = values
                .into_iter()
                .map(|value| self.hoist_from_expression(value))
                .collect();
            SqlExpression::NotInList { expr, values }
        }
        SqlExpression::InSubquery { expr, subquery } => {
            // The subquery is rewritten before the tested expression.
            let subquery = Box::new(self.hoist_from_statement(*subquery));
            let expr = self.hoist_boxed(expr);
            SqlExpression::InSubquery { expr, subquery }
        }
        SqlExpression::NotInSubquery { expr, subquery } => {
            // Same order as `InSubquery`: subquery first, then the expression.
            let subquery = Box::new(self.hoist_from_statement(*subquery));
            let expr = self.hoist_boxed(expr);
            SqlExpression::NotInSubquery { expr, subquery }
        }
        SqlExpression::Between { expr, lower, upper } => {
            let expr = self.hoist_boxed(expr);
            let lower = self.hoist_boxed(lower);
            let upper = self.hoist_boxed(upper);
            SqlExpression::Between { expr, lower, upper }
        }
        SqlExpression::Not { expr } => {
            let expr = self.hoist_boxed(expr);
            SqlExpression::Not { expr }
        }
        SqlExpression::SimpleCaseExpression {
            expr,
            when_branches,
            else_branch,
        } => {
            let expr = self.hoist_boxed(expr);
            let when_branches = when_branches
                .into_iter()
                .map(|branch| crate::sql::parser::ast::SimpleWhenBranch {
                    value: self.hoist_boxed(branch.value),
                    result: self.hoist_boxed(branch.result),
                })
                .collect();
            let else_branch = else_branch.map(|fallback| self.hoist_boxed(fallback));
            SqlExpression::SimpleCaseExpression {
                expr,
                when_branches,
                else_branch,
            }
        }
        unchanged => unchanged,
    }
}

/// Applies `hoist_from_expression` to a boxed expression and re-boxes the result.
fn hoist_boxed(&mut self, boxed: Box<SqlExpression>) -> Box<SqlExpression> {
    Box::new(self.hoist_from_expression(*boxed))
}
/// Replaces the expression of every condition in `where_clause` with its
/// hoisted rewrite, in place.
fn hoist_from_where_clause(&mut self, where_clause: &mut WhereClause) {
    for entry in where_clause.conditions.iter_mut() {
        // `hoist_from_expression` takes ownership, so it is given a copy of
        // the current expression and the result is written back.
        let current = entry.expr.clone();
        entry.expr = self.hoist_from_expression(current);
    }
}
/// Replaces `condition.expr` with its hoisted rewrite, in place.
///
/// Nothing in the visible part of this file calls this method.
fn hoist_from_condition(&mut self, condition: &mut Condition) {
    let current = condition.expr.clone();
    condition.expr = self.hoist_from_expression(current);
}
fn add_cte(&mut self, cte: CTE) {
self.analyze_cte_dependencies(&cte);
self.hoisted_ctes.push(cte);
}
/// Stores, under `cte.name`, the set of already-registered CTE names that
/// the CTE's query references.
///
/// Only `CTEType::Standard` bodies are scanned; any other CTE type gets an
/// empty set. An existing entry with the same name is overwritten.
fn analyze_cte_dependencies(&mut self, cte: &CTE) {
    let mut referenced: HashSet<String> = HashSet::new();
    if let CTEType::Standard(body) = &cte.cte_type {
        self.find_cte_references(body, &mut referenced);
    }
    let key = cte.name.clone();
    self.dependency_graph.insert(key, referenced);
}
/// Adds to `deps` the name of every already-registered CTE that `statement`
/// refers to.
///
/// Looks at `from_table`, recurses into `from_subquery`, checks joins whose
/// source is a plain `TableSource::Table`, and scans the expressions of the
/// select items and WHERE conditions.
fn find_cte_references(&self, statement: &SelectStatement, deps: &mut HashSet<String>) {
    // A table name counts as a reference only if a CTE of that name has
    // already been registered.
    let is_registered = |candidate: &str| {
        self.hoisted_ctes
            .iter()
            .any(|known| known.name.as_str() == candidate)
    };

    if let Some(table) = &statement.from_table {
        if is_registered(table.as_str()) {
            deps.insert(table.clone());
        }
    }

    if let Some(subquery) = &statement.from_subquery {
        self.find_cte_references(subquery, deps);
    }

    for join in &statement.joins {
        if let crate::sql::parser::ast::TableSource::Table(table_name) = &join.table {
            if is_registered(table_name.as_str()) {
                deps.insert(table_name.clone());
            }
        }
    }

    let projected = statement.select_items.iter().filter_map(|item| match item {
        SelectItem::Expression { expr, .. } => Some(expr),
        _ => None,
    });
    for expr in projected {
        self.find_cte_refs_in_expression(expr, deps);
    }

    if let Some(where_clause) = &statement.where_clause {
        for condition in &where_clause.conditions {
            self.find_cte_refs_in_expression(&condition.expr, deps);
        }
    }
}
/// Walks `expr` and adds to `deps` the registered CTE names referenced by any
/// subquery found inside it.
///
/// Subquery expressions are handed to `find_cte_references`; compound
/// expressions are searched operand by operand; every other expression kind
/// is ignored. For `InSubquery` / `NotInSubquery` only the subquery is
/// examined, not the tested expression.
fn find_cte_refs_in_expression(&self, expr: &SqlExpression, deps: &mut HashSet<String>) {
    match expr {
        SqlExpression::ScalarSubquery { query }
        | SqlExpression::InSubquery {
            subquery: query, ..
        }
        | SqlExpression::NotInSubquery {
            subquery: query, ..
        } => {
            self.find_cte_references(query, deps);
        }
        SqlExpression::BinaryOp { left, right, .. } => {
            for side in [left, right] {
                self.find_cte_refs_in_expression(side, deps);
            }
        }
        SqlExpression::Not { expr } => {
            self.find_cte_refs_in_expression(expr, deps);
        }
        SqlExpression::Between { expr, lower, upper } => {
            for operand in [expr, lower, upper] {
                self.find_cte_refs_in_expression(operand, deps);
            }
        }
        SqlExpression::FunctionCall { args, .. } => {
            for arg in args {
                self.find_cte_refs_in_expression(arg, deps);
            }
        }
        SqlExpression::InList { expr, values } | SqlExpression::NotInList { expr, values } => {
            self.find_cte_refs_in_expression(expr, deps);
            for value in values {
                self.find_cte_refs_in_expression(value, deps);
            }
        }
        SqlExpression::CaseExpression {
            when_branches,
            else_branch,
        } => {
            for branch in when_branches {
                self.find_cte_refs_in_expression(&branch.condition, deps);
                self.find_cte_refs_in_expression(&branch.result, deps);
            }
            if let Some(fallback) = else_branch.as_deref() {
                self.find_cte_refs_in_expression(fallback, deps);
            }
        }
        SqlExpression::SimpleCaseExpression {
            expr,
            when_branches,
            else_branch,
        } => {
            self.find_cte_refs_in_expression(expr, deps);
            for branch in when_branches {
                self.find_cte_refs_in_expression(&branch.value, deps);
                self.find_cte_refs_in_expression(&branch.result, deps);
            }
            if let Some(fallback) = else_branch.as_deref() {
                self.find_cte_refs_in_expression(fallback, deps);
            }
        }
        _ => {}
    }
}
fn get_ordered_ctes(self) -> Vec<CTE> {
let mut result = Vec::new();
let mut visited = HashSet::new();
let mut temp_mark = HashSet::new();
fn visit(
name: &str,
graph: &HashMap<String, HashSet<String>>,
ctes: &[CTE],
visited: &mut HashSet<String>,
temp_mark: &mut HashSet<String>,
result: &mut Vec<CTE>,
) {
if visited.contains(name) {
return;
}
if temp_mark.contains(name) {
return;
}
temp_mark.insert(name.to_string());
if let Some(deps) = graph.get(name) {
for dep in deps {
visit(dep, graph, ctes, visited, temp_mark, result);
}
}
temp_mark.remove(name);
visited.insert(name.to_string());
if let Some(cte) = ctes.iter().find(|c| c.name == name) {
result.push(cte.clone());
}
}
for cte in &self.hoisted_ctes {
visit(
&cte.name,
&self.dependency_graph,
&self.hoisted_ctes,
&mut visited,
&mut temp_mark,
&mut result,
);
}
result
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SelectStatement` in which every field is empty, `None` or
    /// `false`, so each test only spells out the fields it cares about.
    // The legacy `from_*` fields are deprecated but still have to be set.
    #[allow(deprecated)]
    fn empty_statement() -> SelectStatement {
        SelectStatement {
            distinct: false,
            columns: vec![],
            select_items: vec![],
            from_source: None,
            from_table: None,
            from_subquery: None,
            from_function: None,
            from_alias: None,
            joins: vec![],
            where_clause: None,
            order_by: None,
            group_by: None,
            having: None,
            qualify: None,
            limit: None,
            offset: None,
            ctes: vec![],
            into_table: None,
            set_operations: vec![],
            leading_comments: vec![],
            trailing_comment: None,
        }
    }

    /// A CTE declared inside a FROM subquery ends up on the outer statement,
    /// and the subquery is left without CTEs.
    #[test]
    // The test sets and reads the deprecated `from_table` / `from_subquery` fields.
    #[allow(deprecated)]
    fn test_simple_cte_hoisting() {
        // Body of the CTE: reads `col1` from `table1`.
        let cte_body = SelectStatement {
            columns: vec!["col1".to_string()],
            from_table: Some("table1".to_string()),
            ..empty_statement()
        };

        // Subquery that declares the CTE `inner` and selects from it.
        let subquery_with_cte = SelectStatement {
            ctes: vec![CTE {
                name: "inner".to_string(),
                column_list: None,
                cte_type: CTEType::Standard(cte_body),
            }],
            from_table: Some("inner".to_string()),
            ..empty_statement()
        };

        // Outer query whose FROM clause is that subquery.
        let outer = SelectStatement {
            from_subquery: Some(Box::new(subquery_with_cte)),
            ..empty_statement()
        };

        let result = CTEHoister::hoist_ctes(outer);

        assert_eq!(result.ctes.len(), 1);
        assert_eq!(result.ctes[0].name, "inner");
        assert!(result.from_subquery.as_ref().unwrap().ctes.is_empty());
    }
}