use anyhow::{anyhow, Result};
use std::collections::{HashMap, HashSet, VecDeque};
use tracing::{debug, info};
use crate::sql::recursive_parser::{Parser, SelectStatement, SqlExpression, TableSource};
/// Tables referenced by a single statement, split into those it reads and
/// those it writes.
///
/// Both lists keep first-seen order and never contain the same name twice
/// when populated through [`TableReferences::add_read`] and
/// [`TableReferences::add_write`].
#[derive(Debug, Clone, PartialEq)]
pub struct TableReferences {
    /// Names of tables the statement reads from.
    pub reads: Vec<String>,
    /// Names of tables the statement writes to.
    pub writes: Vec<String>,
}

impl TableReferences {
    /// Creates an empty reference set.
    fn new() -> Self {
        Self {
            reads: vec![],
            writes: vec![],
        }
    }

    /// Records `table` as read, unless it is already recorded.
    fn add_read(&mut self, table: String) {
        Self::push_unique(&mut self.reads, table);
    }

    /// Records `table` as written, unless it is already recorded.
    fn add_write(&mut self, table: String) {
        Self::push_unique(&mut self.writes, table);
    }

    /// Appends `table` to `names` only when no equal entry exists yet.
    fn push_unique(names: &mut Vec<String>, table: String) {
        let already_known = names.iter().any(|existing| *existing == table);
        if !already_known {
            names.push(table);
        }
    }
}
/// One SQL statement of a script together with the table references
/// extracted from it.
#[derive(Debug, Clone)]
pub struct DependencyStatement {
/// Position of the statement in the script; `analyze_statements` assigns
/// these starting at 1.
pub number: usize,
/// The statement's SQL text, stored as an owned copy of the input string.
pub sql: String,
/// Tables the statement reads from and writes to.
pub references: TableReferences,
/// `true` when `analyze_statements` found an `INTO` target in the parsed
/// statement or recognised a `CREATE TEMP[ORARY] TABLE` phrase in the SQL text.
pub creates_temp_table: bool,
}
/// Result of dependency analysis for one target statement: which statements
/// have to run, which can be skipped, and why.
#[derive(Debug, Clone)]
pub struct ExecutionPlan {
    /// Numbers of the statements that must be executed.
    pub statements_to_execute: Vec<usize>,
    /// Numbers of the statements that do not need to run.
    pub statements_to_skip: Vec<usize>,
    /// Number of the statement the plan was computed for.
    pub target_statement: usize,
    /// Maps a statement number to the numbers of the statements it depends on.
    pub dependency_graph: HashMap<usize, Vec<usize>>,
}

impl ExecutionPlan {
    /// Renders the plan as a multi-line, human-readable trace.
    ///
    /// `statements` supplies the details (reads, writes, SQL text) for the
    /// statement numbers stored in the plan. A number with no matching entry
    /// in `statements` is silently left out of the trace. When several
    /// entries share a number, the first one is used.
    pub fn format_debug_trace(&self, statements: &[DependencyStatement]) -> String {
        let lookup = |wanted: usize| statements.iter().find(|s| s.number == wanted);
        let mut trace = String::new();

        // Header: target plus the raw execute/skip lists.
        trace.push_str("=== Execution Plan Debug Trace ===\n");
        trace.push_str(&format!("Target Statement: #{}\n", self.target_statement));
        trace.push_str(&format!(
            "Statements to Execute: {:?}\n",
            self.statements_to_execute
        ));
        trace.push_str(&format!(
            "Statements to Skip: {:?}\n\n",
            self.statements_to_skip
        ));

        // One entry per statement that will run.
        trace.push_str("--- Dependency Graph ---\n");
        for &id in &self.statements_to_execute {
            let stmt = match lookup(id) {
                Some(found) => found,
                None => continue,
            };

            trace.push_str(&format!("\nStatement #{}: ", id));
            if stmt.creates_temp_table {
                trace.push_str("[TEMP TABLE] ");
            }
            trace.push_str(&format!("\n Reads: {:?}", stmt.references.reads));
            trace.push_str(&format!("\n Writes: {:?}", stmt.references.writes));

            // The dependency line is omitted when there is nothing to list.
            let dependencies = self
                .dependency_graph
                .get(&id)
                .filter(|deps| !deps.is_empty());
            if let Some(deps) = dependencies {
                trace.push_str(&format!("\n Depends on: {:?}", deps));
            }

            // Every line of the SQL text gets the same one-space prefix.
            trace.push_str("\n SQL: ");
            let indented_sql: Vec<String> = stmt
                .sql
                .lines()
                .map(|line| format!(" {}", line))
                .collect();
            trace.push_str(&indented_sql.join("\n"));
        }

        // One short entry per statement that will not run.
        trace.push_str("\n\n--- Skipped Statements ---\n");
        for &id in &self.statements_to_skip {
            let stmt = match lookup(id) {
                Some(found) => found,
                None => continue,
            };
            trace.push_str(&format!("\nStatement #{}: [SKIPPED]\n", id));
            trace.push_str(&format!(" Reads: {:?}\n", stmt.references.reads));
            trace.push_str(&format!(" Writes: {:?}\n", stmt.references.writes));
        }

        trace
    }
}
/// Stateless entry point for statement dependency analysis; all functionality
/// is exposed through associated functions that take no `self`.
pub struct DependencyAnalyzer;
impl DependencyAnalyzer {
/// Parses every SQL string and records which tables it reads and writes.
///
/// Statements are numbered from 1 in input order. A statement is flagged as
/// creating a temp table when its parsed form has an `INTO` target or its
/// text matches `is_create_temp_table`.
///
/// # Errors
///
/// Stops at the first statement that fails to parse (the error message names
/// the statement number) or whose table references cannot be extracted.
pub fn analyze_statements(statements: &[String]) -> Result<Vec<DependencyStatement>> {
    statements
        .iter()
        .enumerate()
        .map(|(position, sql)| -> Result<DependencyStatement> {
            // Statement numbers shown to users are 1-based.
            let number = position + 1;

            let mut parser = Parser::new(sql);
            let ast = parser
                .parse()
                .map_err(|e| anyhow!("Failed to parse statement {}: {}", number, e))?;

            let creates_temp_table =
                ast.into_table.is_some() || Self::is_create_temp_table(sql);
            let references = Self::extract_table_references(&ast)?;

            Ok(DependencyStatement {
                number,
                sql: sql.clone(),
                references,
                creates_temp_table,
            })
        })
        .collect()
}
/// Collects the tables a parsed statement touches.
///
/// The `INTO` target, if any, is recorded as a write. Reads come from the
/// `FROM` table, a `FROM` subquery, joined sources, standard CTE bodies and
/// subqueries found in `WHERE` conditions. From nested statements only the
/// reads are merged; their writes are not carried over.
///
/// # Errors
///
/// Propagates any error raised while walking nested statements or expressions.
fn extract_table_references(ast: &SelectStatement) -> Result<TableReferences> {
    let mut collected = TableReferences::new();

    if let Some(target) = &ast.into_table {
        collected.add_write(target.name.clone());
    }

    if let Some(name) = &ast.from_table {
        collected.add_read(name.clone());
    }

    if let Some(nested) = &ast.from_subquery {
        for name in Self::extract_table_references(nested)?.reads {
            collected.add_read(name);
        }
    }

    // `ast.from_function` is deliberately not inspected: a function source in
    // FROM contributes no table reads here.

    for join in &ast.joins {
        Self::extract_from_table_source(&join.table, &mut collected)?;
    }

    for cte in &ast.ctes {
        // Only standard CTEs are walked; every other CTE kind is ignored.
        if let crate::sql::parser::ast::CTEType::Standard(body) = &cte.cte_type {
            for name in Self::extract_table_references(body)?.reads {
                collected.add_read(name);
            }
        }
    }

    if let Some(filter) = &ast.where_clause {
        for condition in &filter.conditions {
            Self::extract_from_expression(&condition.expr, &mut collected)?;
        }
    }

    Ok(collected)
}
/// Adds the tables read by one table source (for example a join target) to `refs`.
///
/// A plain table contributes its name, a derived table contributes the reads
/// of its inner query, and a pivot contributes whatever its underlying
/// source reads.
///
/// # Errors
///
/// Propagates errors from analysing a derived table's query or a pivot's source.
fn extract_from_table_source(
    table_source: &TableSource,
    refs: &mut TableReferences,
) -> Result<()> {
    match table_source {
        TableSource::Table(name) => {
            refs.add_read(name.clone());
            Ok(())
        }
        TableSource::DerivedTable { query, .. } => {
            for name in Self::extract_table_references(query)?.reads {
                refs.add_read(name);
            }
            Ok(())
        }
        // A pivot wraps another source; recurse into it.
        TableSource::Pivot { source, .. } => Self::extract_from_table_source(source, refs),
    }
}
/// Walks an expression and adds to `refs` every table read by a subquery
/// nested inside it.
///
/// Scalar subqueries and `IN` / `NOT IN` subqueries contribute the reads of
/// their query. Binary operations, function calls, window functions, method
/// calls and chained method calls are descended into. Every other expression
/// kind is ignored.
///
/// # Errors
///
/// Propagates errors from analysing nested queries or sub-expressions.
fn extract_from_expression(expr: &SqlExpression, refs: &mut TableReferences) -> Result<()> {
    match expr {
        SqlExpression::ScalarSubquery { query } => {
            for name in Self::extract_table_references(query)?.reads {
                refs.add_read(name);
            }
        }
        SqlExpression::InSubquery {
            expr: operand,
            subquery,
        } => {
            // The tested operand may itself contain subqueries.
            Self::extract_from_expression(operand, refs)?;
            for name in Self::extract_table_references(subquery)?.reads {
                refs.add_read(name);
            }
        }
        SqlExpression::NotInSubquery {
            expr: operand,
            subquery,
        } => {
            Self::extract_from_expression(operand, refs)?;
            for name in Self::extract_table_references(subquery)?.reads {
                refs.add_read(name);
            }
        }
        SqlExpression::BinaryOp { left, right, .. } => {
            Self::extract_from_expression(left, refs)?;
            Self::extract_from_expression(right, refs)?;
        }
        SqlExpression::FunctionCall { args, .. } => {
            args.iter()
                .try_for_each(|arg| Self::extract_from_expression(arg, refs))?;
        }
        SqlExpression::WindowFunction { args, .. } => {
            args.iter()
                .try_for_each(|arg| Self::extract_from_expression(arg, refs))?;
        }
        SqlExpression::MethodCall { args, .. } => {
            args.iter()
                .try_for_each(|arg| Self::extract_from_expression(arg, refs))?;
        }
        SqlExpression::ChainedMethodCall { base, args, .. } => {
            Self::extract_from_expression(base, refs)?;
            args.iter()
                .try_for_each(|arg| Self::extract_from_expression(arg, refs))?;
        }
        // All remaining expression kinds are not inspected.
        _ => {}
    }
    Ok(())
}
/// Reports whether `sql` contains a `CREATE TEMP TABLE` or
/// `CREATE TEMPORARY TABLE` phrase.
///
/// Matching is case-insensitive and tolerant of how the keywords are
/// separated: any run of whitespace (several spaces, tabs, newlines) between
/// them counts as a single space. This is a plain text search, so the phrase
/// is also found when it appears inside a comment or a string literal.
fn is_create_temp_table(sql: &str) -> bool {
    // Collapse every whitespace run to one space so that multi-line or
    // column-aligned SQL still matches the single-spaced phrases below.
    let normalized = sql
        .to_lowercase()
        .split_whitespace()
        .collect::<Vec<_>>()
        .join(" ");
    normalized.contains("create temp table") || normalized.contains("create temporary table")
}
/// Works out which statements must run so that the target statement can be
/// executed, and which can be skipped.
///
/// A statement depends on every statement that (a) is listed before it in
/// `statements` with a smaller `number` and (b) writes a table it reads. The
/// scan over candidates stops at the first entry whose number is not
/// smaller, so `statements` is expected to be ordered by ascending number.
/// The plan contains the target plus everything reachable through those
/// dependencies; all other numbers in `1..=statements.len()` are skipped.
///
/// # Errors
///
/// Returns an error when `target_statement_number` is 0 or greater than
/// `statements.len()`.
pub fn compute_execution_plan(
    statements: &[DependencyStatement],
    target_statement_number: usize,
) -> Result<ExecutionPlan> {
    if target_statement_number == 0 || target_statement_number > statements.len() {
        return Err(anyhow!(
            "Invalid target statement number: {}. Must be 1-{}",
            target_statement_number,
            statements.len()
        ));
    }
    info!(
        "Computing execution plan for statement #{}",
        target_statement_number
    );

    // Direct dependencies: statement number -> earlier statements that write
    // one of the tables it reads. Statements without dependencies get no entry.
    let mut dependency_graph: HashMap<usize, Vec<usize>> = HashMap::new();
    for stmt in statements {
        let mut depends_on: Vec<usize> = Vec::new();
        for table in &stmt.references.reads {
            let earlier_writers = statements
                .iter()
                .take_while(|candidate| candidate.number < stmt.number)
                .filter(|candidate| candidate.references.writes.contains(table));
            for writer in earlier_writers {
                // Keep first-seen order while avoiding duplicate entries.
                if !depends_on.contains(&writer.number) {
                    depends_on.push(writer.number);
                }
            }
        }
        if !depends_on.is_empty() {
            dependency_graph.insert(stmt.number, depends_on);
        }
    }
    debug!("Dependency graph: {:?}", dependency_graph);

    // Breadth-first walk from the target collects the transitive closure.
    let mut to_execute: HashSet<usize> = HashSet::new();
    let mut queue: VecDeque<usize> = VecDeque::new();
    queue.push_back(target_statement_number);
    while let Some(stmt_num) = queue.pop_front() {
        // `insert` returns false for already-visited statements.
        if to_execute.insert(stmt_num) {
            if let Some(deps) = dependency_graph.get(&stmt_num) {
                queue.extend(deps.iter().copied());
            }
        }
    }

    // Derive the skip list from the set (constant-time membership) before the
    // set is consumed to build the sorted execution list.
    let statements_to_skip: Vec<usize> = (1..=statements.len())
        .filter(|n| !to_execute.contains(n))
        .collect();
    let mut statements_to_execute: Vec<usize> = to_execute.into_iter().collect();
    statements_to_execute.sort_unstable();

    info!(
        "Execution plan: execute {:?}, skip {:?}",
        statements_to_execute, statements_to_skip
    );
    Ok(ExecutionPlan {
        statements_to_execute,
        statements_to_skip,
        target_statement: target_statement_number,
        dependency_graph,
    })
}
}
// Unit tests for statement analysis and execution-plan computation.
#[cfg(test)]
mod tests {
use super::*;
// An `INTO` target is recorded as a write, `FROM` tables as reads.
#[test]
fn test_simple_dependency() {
let statements = vec![
"SELECT * FROM sales INTO #raw_data".to_string(),
"SELECT COUNT(*) FROM customers".to_string(),
"SELECT * FROM #raw_data WHERE amount > 100".to_string(),
];
let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();
assert_eq!(analyzed.len(), 3);
assert_eq!(analyzed[0].references.writes, vec!["#raw_data"]);
assert_eq!(analyzed[0].references.reads, vec!["sales"]);
assert_eq!(analyzed[1].references.reads, vec!["customers"]);
assert!(analyzed[1].references.writes.is_empty());
assert_eq!(analyzed[2].references.reads, vec!["#raw_data"]);
}
// Statement 3 reads the table written by statement 1; statement 2 is unrelated
// and must be skipped.
#[test]
fn test_execution_plan() {
let statements = vec![
"SELECT * FROM sales INTO #raw_data".to_string(),
"SELECT COUNT(*) FROM customers".to_string(),
"SELECT * FROM #raw_data WHERE amount > 100".to_string(),
];
let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();
let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 3).unwrap();
assert_eq!(plan.statements_to_execute, vec![1, 3]);
assert_eq!(plan.statements_to_skip, vec![2]);
assert_eq!(plan.target_statement, 3);
}
// A chain #t1 -> #t2 -> #t3 pulls in every link, not just the direct dependency.
#[test]
fn test_transitive_dependencies() {
let statements = vec![
"SELECT * FROM base INTO #t1".to_string(),
"SELECT * FROM #t1 INTO #t2".to_string(),
"SELECT * FROM #t2 INTO #t3".to_string(),
"SELECT * FROM unrelated".to_string(),
"SELECT * FROM #t3".to_string(),
];
let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();
let plan = DependencyAnalyzer::compute_execution_plan(&analyzed, 5).unwrap();
assert_eq!(plan.statements_to_execute, vec![1, 2, 3, 5]);
assert_eq!(plan.statements_to_skip, vec![4]);
}
// Target numbers are 1-based: 0 and anything past the last statement are rejected.
#[test]
fn test_invalid_statement_number() {
let statements = vec!["SELECT 1".to_string()];
let analyzed = DependencyAnalyzer::analyze_statements(&statements).unwrap();
assert!(DependencyAnalyzer::compute_execution_plan(&analyzed, 0).is_err());
assert!(DependencyAnalyzer::compute_execution_plan(&analyzed, 5).is_err());
}
}