use crate::sql::parser::ast::{SelectStatement, SqlExpression, WhereClause};
use std::collections::HashSet;
/// Clause of the analysed `SELECT` statement in which a subquery was found.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubqueryLocation {
    /// A derived table in the `FROM` clause.
    FromClause,
    /// Inside a `WHERE` condition.
    WhereClause,
    /// Inside an expression of the `SELECT` list.
    SelectList,
    /// Inside the `HAVING` expression.
    HavingClause,
    /// Inside a join condition.
    // NOTE(review): no code in this file currently produces this variant.
    JoinCondition,
}
/// Syntactic form of a detected subquery.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum SubqueryType {
    /// A subquery used as a single value (`SqlExpression::ScalarSubquery`).
    Scalar,
    /// `expr IN (SELECT ...)`; `negated` is `true` for `NOT IN`.
    InList { negated: bool },
    /// `EXISTS (SELECT ...)`; `negated` presumably marks `NOT EXISTS`.
    // NOTE(review): no code in this file currently constructs this variant.
    Exists { negated: bool },
    /// A subquery in the `FROM` clause.
    DerivedTable,
}
/// A single subquery found by the analyzer, together with its classification.
#[derive(Debug, Clone)]
pub struct SubqueryInfo {
/// Clause of the analysed statement in which the subquery was found.
pub location: SubqueryLocation,
/// Syntactic form of the subquery.
pub subquery_type: SubqueryType,
/// As populated by `CorrelatedSubqueryAnalyzer`, this is `true` exactly
/// when `outer_references` is non-empty.
pub is_correlated: bool,
/// Table names/aliases used as `table.column` prefixes inside the subquery
/// that the analyzer matched against an enclosing scope. Kept sorted and
/// free of duplicates.
pub outer_references: Vec<String>,
/// An owned copy (clone) of the subquery's statement.
pub statement: SelectStatement,
}
/// Result of `CorrelatedSubqueryAnalyzer::analyze`: every subquery found plus
/// summary counts.
#[derive(Debug, Default)]
pub struct CorrelationAnalysis {
/// Subqueries in the order the analyzer discovered them.
pub subqueries: Vec<SubqueryInfo>,
/// Number of entries in `subqueries`; filled in by `analyze`.
pub total_count: usize,
/// Number of entries in `subqueries` with `is_correlated == true`.
pub correlated_count: usize,
/// `total_count - correlated_count`, as computed by `analyze`.
pub non_correlated_count: usize,
}
impl CorrelationAnalysis {
    /// Renders a human-readable, multi-line summary of this analysis.
    ///
    /// The header prints `total_count`, `correlated_count` and
    /// `non_correlated_count` exactly as stored (they are not recomputed from
    /// `subqueries`), followed by a blank line. When `subqueries` is empty the
    /// report ends with `No subqueries detected.`. Otherwise each subquery gets
    /// one line of the form `Subquery #<n>: <location> - <type> [CORRELATED]`
    /// (or `[NON-CORRELATED]`), with `<n>` starting at 1. Correlated entries are
    /// followed by an extra line listing their outer references.
    pub fn report(&self) -> String {
        let mut report = String::new();
        // `fmt::Write` for `String` never returns an error, so there is nothing
        // meaningful to do with the result.
        let _ = self.write_report(&mut report);
        report
    }

    /// Appends the text produced by [`Self::report`] to `out`.
    fn write_report(&self, out: &mut String) -> std::fmt::Result {
        // Enables `write!`/`writeln!` on `String`. Appending in place avoids the
        // temporary `String` that every `format!` call would allocate.
        use std::fmt::Write as _;

        writeln!(out, "=== Subquery Analysis ===")?;
        writeln!(out, "Total subqueries: {}", self.total_count)?;
        writeln!(out, " Correlated: {}", self.correlated_count)?;
        writeln!(out, " Non-correlated: {}", self.non_correlated_count)?;
        writeln!(out)?;

        if self.subqueries.is_empty() {
            writeln!(out, "No subqueries detected.")?;
            return Ok(());
        }

        for (idx, info) in self.subqueries.iter().enumerate() {
            write!(
                out,
                "Subquery #{}: {:?} - {:?}",
                idx + 1,
                info.location,
                info.subquery_type
            )?;
            if info.is_correlated {
                writeln!(out, " [CORRELATED]")?;
                writeln!(out, " Outer references: {:?}", info.outer_references)?;
            } else {
                writeln!(out, " [NON-CORRELATED]")?;
            }
        }
        Ok(())
    }
}
/// Walks a `SelectStatement` and classifies the subqueries it contains as
/// correlated or non-correlated.
pub struct CorrelatedSubqueryAnalyzer {
/// Stack of name scopes (table names and aliases), innermost last. `new`
/// seeds it with one empty scope; `analyze` pushes the analysed statement's
/// FROM table/alias on entry and pops it before returning.
scope_stack: Vec<HashSet<String>>,
}
impl CorrelatedSubqueryAnalyzer {
pub fn new() -> Self {
Self {
scope_stack: vec![HashSet::new()],
}
}
pub fn analyze(&mut self, stmt: &SelectStatement) -> CorrelationAnalysis {
let mut analysis = CorrelationAnalysis::default();
let mut current_scope = HashSet::new();
if let Some(ref table) = stmt.from_table {
current_scope.insert(table.clone());
}
if let Some(ref alias) = stmt.from_alias {
current_scope.insert(alias.clone());
}
self.scope_stack.push(current_scope);
self.analyze_from_clause(stmt, &mut analysis);
self.analyze_select_list(stmt, &mut analysis);
self.analyze_where_clause(stmt, &mut analysis);
self.analyze_having_clause(stmt, &mut analysis);
self.scope_stack.pop();
analysis.total_count = analysis.subqueries.len();
analysis.correlated_count = analysis
.subqueries
.iter()
.filter(|s| s.is_correlated)
.count();
analysis.non_correlated_count = analysis.total_count - analysis.correlated_count;
analysis
}
fn analyze_from_clause(&mut self, stmt: &SelectStatement, analysis: &mut CorrelationAnalysis) {
if let Some(ref subquery) = stmt.from_subquery {
let outer_refs = self.find_outer_references(subquery);
analysis.subqueries.push(SubqueryInfo {
location: SubqueryLocation::FromClause,
subquery_type: SubqueryType::DerivedTable,
is_correlated: !outer_refs.is_empty(),
outer_references: outer_refs,
statement: (**subquery).clone(),
});
}
}
fn analyze_select_list(&mut self, stmt: &SelectStatement, analysis: &mut CorrelationAnalysis) {
for item in &stmt.select_items {
if let crate::sql::parser::ast::SelectItem::Expression { expr, .. } = item {
self.analyze_expression_for_subqueries(
expr,
SubqueryLocation::SelectList,
analysis,
);
}
}
}
fn analyze_where_clause(&mut self, stmt: &SelectStatement, analysis: &mut CorrelationAnalysis) {
if let Some(ref where_clause) = stmt.where_clause {
for condition in &where_clause.conditions {
self.analyze_expression_for_subqueries(
&condition.expr,
SubqueryLocation::WhereClause,
analysis,
);
}
}
}
fn analyze_having_clause(
&mut self,
stmt: &SelectStatement,
analysis: &mut CorrelationAnalysis,
) {
if let Some(ref having_expr) = stmt.having {
self.analyze_expression_for_subqueries(
having_expr,
SubqueryLocation::HavingClause,
analysis,
);
}
}
fn analyze_expression_for_subqueries(
&mut self,
expr: &SqlExpression,
location: SubqueryLocation,
analysis: &mut CorrelationAnalysis,
) {
match expr {
SqlExpression::ScalarSubquery { query } => {
let outer_refs = self.find_outer_references(query);
analysis.subqueries.push(SubqueryInfo {
location: location.clone(),
subquery_type: SubqueryType::Scalar,
is_correlated: !outer_refs.is_empty(),
outer_references: outer_refs,
statement: (**query).clone(),
});
}
SqlExpression::InSubquery { expr: _, subquery } => {
let outer_refs = self.find_outer_references(subquery);
analysis.subqueries.push(SubqueryInfo {
location: location.clone(),
subquery_type: SubqueryType::InList { negated: false },
is_correlated: !outer_refs.is_empty(),
outer_references: outer_refs,
statement: (**subquery).clone(),
});
}
SqlExpression::NotInSubquery { expr: _, subquery } => {
let outer_refs = self.find_outer_references(subquery);
analysis.subqueries.push(SubqueryInfo {
location: location.clone(),
subquery_type: SubqueryType::InList { negated: true },
is_correlated: !outer_refs.is_empty(),
outer_references: outer_refs,
statement: (**subquery).clone(),
});
}
SqlExpression::BinaryOp { left, right, .. } => {
self.analyze_expression_for_subqueries(left, location.clone(), analysis);
self.analyze_expression_for_subqueries(right, location, analysis);
}
SqlExpression::Not { expr } => {
self.analyze_expression_for_subqueries(expr, location, analysis);
}
SqlExpression::FunctionCall { args, .. } => {
for arg in args {
self.analyze_expression_for_subqueries(arg, location.clone(), analysis);
}
}
_ => {
}
}
}
fn find_outer_references(&self, subquery: &SelectStatement) -> Vec<String> {
let mut outer_refs = Vec::new();
let mut referenced_tables = HashSet::new();
self.collect_table_references(subquery, &mut referenced_tables);
for table in &referenced_tables {
for scope in self.scope_stack.iter().rev().skip(1) {
if scope.contains(table) {
outer_refs.push(table.clone());
break;
}
}
}
outer_refs.sort();
outer_refs.dedup();
outer_refs
}
fn collect_table_references(&self, stmt: &SelectStatement, refs: &mut HashSet<String>) {
if let Some(ref where_clause) = stmt.where_clause {
self.collect_references_from_where(where_clause, refs);
}
for item in &stmt.select_items {
if let crate::sql::parser::ast::SelectItem::Expression { expr, .. } = item {
self.collect_references_from_expr(expr, refs);
}
}
if let Some(ref having) = stmt.having {
self.collect_references_from_expr(having, refs);
}
}
fn collect_references_from_where(
&self,
where_clause: &WhereClause,
refs: &mut HashSet<String>,
) {
for condition in &where_clause.conditions {
self.collect_references_from_expr(&condition.expr, refs);
}
}
fn collect_references_from_expr(&self, expr: &SqlExpression, refs: &mut HashSet<String>) {
match expr {
SqlExpression::Column(col_ref) => {
if let Some(ref table) = col_ref.table_prefix {
refs.insert(table.clone());
}
}
SqlExpression::BinaryOp { left, right, .. } => {
self.collect_references_from_expr(left, refs);
self.collect_references_from_expr(right, refs);
}
SqlExpression::Not { expr } => {
self.collect_references_from_expr(expr, refs);
}
SqlExpression::FunctionCall { args, .. } => {
for arg in args {
self.collect_references_from_expr(arg, refs);
}
}
SqlExpression::InList { expr, values } => {
self.collect_references_from_expr(expr, refs);
for val in values {
self.collect_references_from_expr(val, refs);
}
}
SqlExpression::NotInList { expr, values } => {
self.collect_references_from_expr(expr, refs);
for val in values {
self.collect_references_from_expr(val, refs);
}
}
_ => {
}
}
}
}
impl Default for CorrelatedSubqueryAnalyzer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A statement with no subqueries yields an empty analysis.
    #[test]
    fn test_statement_without_subqueries() {
        let mut analyzer = CorrelatedSubqueryAnalyzer::new();
        let main_stmt = SelectStatement {
            from_table: Some("trades".to_string()),
            ..Default::default()
        };
        let analysis = analyzer.analyze(&main_stmt);
        assert_eq!(analysis.total_count, 0);
        assert_eq!(analysis.correlated_count, 0);
        assert_eq!(analysis.non_correlated_count, 0);
        assert!(analysis.report().contains("No subqueries detected"));
    }

    /// A FROM-clause subquery is reported as a non-correlated derived table.
    #[test]
    fn test_from_clause_subquery() {
        let mut analyzer = CorrelatedSubqueryAnalyzer::new();
        let subquery = SelectStatement {
            from_table: Some("inner_table".to_string()),
            ..Default::default()
        };
        let main_stmt = SelectStatement {
            from_subquery: Some(Box::new(subquery)),
            from_alias: Some("sub".to_string()),
            ..Default::default()
        };
        let analysis = analyzer.analyze(&main_stmt);
        assert_eq!(analysis.total_count, 1);
        assert_eq!(analysis.non_correlated_count, 1);
        assert_eq!(
            analysis.subqueries[0].location,
            SubqueryLocation::FromClause
        );
        assert_eq!(
            analysis.subqueries[0].subquery_type,
            SubqueryType::DerivedTable
        );
        assert!(!analysis.subqueries[0].is_correlated);
        assert!(analysis.subqueries[0].outer_references.is_empty());
        assert!(analysis
            .report()
            .contains("Subquery #1: FromClause - DerivedTable [NON-CORRELATED]"));
    }

    /// The report for an empty analysis has the header and the empty notice.
    #[test]
    fn test_analysis_report_format() {
        let analysis = CorrelationAnalysis {
            subqueries: vec![],
            total_count: 0,
            correlated_count: 0,
            non_correlated_count: 0,
        };
        let report = analysis.report();
        assert!(report.contains("Subquery Analysis"));
        assert!(report.contains("No subqueries detected"));
    }
}