use anyhow::Result;
use std::sync::Arc;
use std::time::Instant;
use crate::data::data_view::DataView;
use crate::data::datatable::DataTable;
use crate::data::query_engine::QueryEngine;
use crate::query_plan::{create_pipeline_with_config, IntoClauseRemover};
use crate::sql::parser::ast::SelectStatement;
use super::config::ExecutionConfig;
use super::context::ExecutionContext;
/// Everything produced by one call to `StatementExecutor::execute`.
#[derive(Debug)]
pub struct ExecutionResult {
    /// The rows and columns returned by the query.
    pub dataview: DataView,
    /// Timing and size statistics gathered during execution.
    pub stats: ExecutionStats,
    /// The statement that was handed to the query engine, i.e. the input
    /// statement after preprocessing. `execute` always sets this to `Some`.
    pub transformed_ast: Option<SelectStatement>,
}
/// Timing and size statistics collected while executing one statement.
///
/// Timings are wall-clock milliseconds measured with `std::time::Instant`.
/// `Default` yields all-zero timings and counts with
/// `preprocessing_applied == false`.
#[derive(Debug, Clone, Default)]
pub struct ExecutionStats {
    /// Time spent preprocessing (transforming) the statement.
    pub preprocessing_time_ms: f64,
    /// Time spent running the statement in the query engine.
    pub execution_time_ms: f64,
    /// End-to-end time, including any `INTO` materialization and storage.
    pub total_time_ms: f64,
    /// Number of rows in the result view.
    pub row_count: usize,
    /// Number of columns in the result view.
    pub column_count: usize,
    /// Whether the preprocessing pipeline's output was used for execution.
    pub preprocessing_applied: bool,
}
impl ExecutionStats {
fn new() -> Self {
Self {
preprocessing_time_ms: 0.0,
execution_time_ms: 0.0,
total_time_ms: 0.0,
row_count: 0,
column_count: 0,
preprocessing_applied: false,
}
}
}
/// Executes parsed `SELECT` statements against an `ExecutionContext`.
///
/// The held `ExecutionConfig` selects case-insensitive matching and the
/// preprocessing pipeline's display/transformer options.
pub struct StatementExecutor {
    /// Settings applied to every statement this executor runs.
    config: ExecutionConfig,
}
impl StatementExecutor {
pub fn new() -> Self {
Self {
config: ExecutionConfig::default(),
}
}
pub fn with_config(config: ExecutionConfig) -> Self {
Self { config }
}
pub fn execute(
&self,
stmt: SelectStatement,
context: &mut ExecutionContext,
) -> Result<ExecutionResult> {
let total_start = Instant::now();
let mut stats = ExecutionStats::new();
let into_table_name = stmt.into_table.as_ref().map(|it| it.name.clone());
let source_table = if let Some(ref from_source) = stmt.from_source {
match from_source {
crate::sql::parser::ast::TableSource::Table(table_name) => {
context.resolve_table(table_name)
}
crate::sql::parser::ast::TableSource::DerivedTable { query, .. } => {
Self::extract_base_table(&**query, context)
}
crate::sql::parser::ast::TableSource::Pivot { source, .. } => {
Self::extract_base_table_from_source(source, context, &stmt)
}
}
} else {
#[allow(deprecated)]
if let Some(ref from_table) = stmt.from_table {
context.resolve_table(from_table)
} else {
Arc::new(DataTable::dual())
}
};
let preprocess_start = Instant::now();
let (transformed_stmt, preprocessing_applied) = self.apply_preprocessing(stmt)?;
stats.preprocessing_time_ms = preprocess_start.elapsed().as_secs_f64() * 1000.0;
stats.preprocessing_applied = preprocessing_applied;
let final_source_table = if !transformed_stmt.ctes.is_empty() {
Self::extract_base_table(&transformed_stmt, context)
} else {
source_table
};
let exec_start = Instant::now();
let result_view =
self.execute_ast(transformed_stmt.clone(), final_source_table, context)?;
stats.execution_time_ms = exec_start.elapsed().as_secs_f64() * 1000.0;
if let Some(table_name) = into_table_name {
let engine = QueryEngine::with_case_insensitive(self.config.case_insensitive);
let temp_table = engine.materialize_view(result_view.clone())?;
context.store_temp_table(table_name.clone(), Arc::new(temp_table))?;
tracing::debug!("Stored temp table: {}", table_name);
}
stats.total_time_ms = total_start.elapsed().as_secs_f64() * 1000.0;
stats.row_count = result_view.row_count();
stats.column_count = result_view.column_count();
Ok(ExecutionResult {
dataview: result_view,
stats,
transformed_ast: Some(transformed_stmt),
})
}
fn apply_preprocessing(&self, mut stmt: SelectStatement) -> Result<(SelectStatement, bool)> {
let has_from_clause = if stmt.from_source.is_some() {
true
} else {
#[allow(deprecated)]
{
stmt.from_table.is_some()
|| stmt.from_subquery.is_some()
|| stmt.from_function.is_some()
}
};
if !has_from_clause {
return Ok((stmt, false));
}
let mut pipeline = create_pipeline_with_config(
self.config.show_preprocessing,
self.config.show_sql_transformations,
self.config.transformer_config.clone(),
);
match pipeline.process(stmt.clone()) {
Ok(transformed) => {
let final_stmt = if transformed.into_table.is_some() {
IntoClauseRemover::remove_into_clause(transformed)
} else {
transformed
};
Ok((final_stmt, true))
}
Err(e) => {
tracing::debug!("Preprocessing failed: {}, using original statement", e);
let fallback = if stmt.into_table.is_some() {
IntoClauseRemover::remove_into_clause(stmt)
} else {
stmt
};
Ok((fallback, false))
}
}
}
fn execute_ast(
&self,
stmt: SelectStatement,
source_table: Arc<DataTable>,
context: &ExecutionContext,
) -> Result<DataView> {
let engine = QueryEngine::with_case_insensitive(self.config.case_insensitive);
engine.execute_statement_with_temp_tables(source_table, stmt, Some(&context.temp_tables))
}
pub fn config(&self) -> &ExecutionConfig {
&self.config
}
pub fn set_config(&mut self, config: ExecutionConfig) {
self.config = config;
}
fn extract_base_table(stmt: &SelectStatement, context: &ExecutionContext) -> Arc<DataTable> {
if !stmt.ctes.is_empty() {
return Self::extract_base_table_from_ctes(stmt, context);
}
if let Some(ref from_source) = stmt.from_source {
Self::extract_base_table_from_source(from_source, context, stmt)
} else {
#[allow(deprecated)]
if let Some(ref from_table) = stmt.from_table {
context.resolve_table(from_table)
} else {
Arc::new(DataTable::dual())
}
}
}
fn extract_base_table_from_ctes(
stmt: &SelectStatement,
context: &ExecutionContext,
) -> Arc<DataTable> {
use crate::sql::parser::ast::CTEType;
if let Some(ref from_source) = stmt.from_source {
match from_source {
crate::sql::parser::ast::TableSource::Table(table_name) => {
for cte in &stmt.ctes {
if &cte.name == table_name {
if let CTEType::Standard(cte_query) = &cte.cte_type {
return Self::extract_base_table(cte_query, context);
}
}
}
context.resolve_table(table_name)
}
crate::sql::parser::ast::TableSource::DerivedTable { query, .. } => {
Self::extract_base_table(&**query, context)
}
crate::sql::parser::ast::TableSource::Pivot { source, .. } => {
Self::extract_base_table_from_source(&**source, context, stmt)
}
}
} else {
Arc::new(DataTable::dual())
}
}
fn extract_base_table_from_source(
source: &crate::sql::parser::ast::TableSource,
context: &ExecutionContext,
stmt: &SelectStatement,
) -> Arc<DataTable> {
match source {
crate::sql::parser::ast::TableSource::Table(table_name) => {
for cte in &stmt.ctes {
if &cte.name == table_name {
use crate::sql::parser::ast::CTEType;
if let CTEType::Standard(cte_query) = &cte.cte_type {
return Self::extract_base_table(cte_query, context);
}
}
}
context.resolve_table(table_name)
}
crate::sql::parser::ast::TableSource::DerivedTable { query, .. } => {
Self::extract_base_table(&**query, context)
}
crate::sql::parser::ast::TableSource::Pivot { source, .. } => {
Self::extract_base_table_from_source(&**source, context, stmt)
}
}
}
}
impl Default for StatementExecutor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow, DataType, DataValue};
    use crate::sql::recursive_parser::Parser;

    /// Builds a table named `name` with columns `id` (0..rows) and
    /// `name` (`name_<id>`).
    fn create_test_table(name: &str, rows: usize) -> DataTable {
        let mut table = DataTable::new(name);
        table.add_column(DataColumn::new("id").with_type(DataType::Integer));
        table.add_column(DataColumn::new("name").with_type(DataType::String));
        for i in 0..rows {
            // Fail loudly: a silently dropped row would make every row-count
            // assertion below misleading.
            table
                .add_row(DataRow {
                    values: vec![
                        DataValue::Integer(i as i64),
                        DataValue::String(format!("name_{}", i)),
                    ],
                })
                .expect("test row should match the table schema");
        }
        table
    }

    /// Parses `sql` and executes it with `executor`, panicking on any failure.
    fn run(
        executor: &StatementExecutor,
        context: &mut ExecutionContext,
        sql: &str,
    ) -> ExecutionResult {
        let mut parser = Parser::new(sql);
        let stmt = parser.parse().expect("test SQL should parse");
        executor
            .execute(stmt, context)
            .expect("test query should execute")
    }

    #[test]
    fn test_new_executor() {
        let executor = StatementExecutor::new();
        assert!(!executor.config().case_insensitive);
        assert!(!executor.config().show_preprocessing);
    }

    #[test]
    fn test_executor_with_config() {
        let config = ExecutionConfig::new()
            .with_case_insensitive(true)
            .with_show_preprocessing(true);
        let executor = StatementExecutor::with_config(config);
        assert!(executor.config().case_insensitive);
        assert!(executor.config().show_preprocessing);
    }

    #[test]
    fn test_execute_simple_select() {
        let mut context = ExecutionContext::new(Arc::new(create_test_table("test", 10)));
        let executor = StatementExecutor::new();
        let result = run(&executor, &mut context, "SELECT id, name FROM test WHERE id < 5");
        assert_eq!(result.dataview.row_count(), 5);
        assert_eq!(result.dataview.column_count(), 2);
        assert!(result.stats.total_time_ms >= 0.0);
    }

    #[test]
    fn test_execute_select_star() {
        let mut context = ExecutionContext::new(Arc::new(create_test_table("test", 5)));
        let executor = StatementExecutor::new();
        let result = run(&executor, &mut context, "SELECT * FROM test");
        assert_eq!(result.dataview.row_count(), 5);
        assert_eq!(result.dataview.column_count(), 2);
    }

    #[test]
    fn test_execute_with_dual() {
        let mut context = ExecutionContext::new(Arc::new(create_test_table("test", 5)));
        let executor = StatementExecutor::new();
        let result = run(&executor, &mut context, "SELECT 1+1 as result");
        assert_eq!(result.dataview.row_count(), 1);
        assert_eq!(result.dataview.column_count(), 1);
    }

    #[test]
    fn test_execute_with_temp_table() {
        let mut context = ExecutionContext::new(Arc::new(create_test_table("base", 10)));
        context
            .store_temp_table("#temp".to_string(), Arc::new(create_test_table("#temp", 3)))
            .unwrap();
        let executor = StatementExecutor::new();
        let result = run(&executor, &mut context, "SELECT * FROM #temp");
        assert_eq!(result.dataview.row_count(), 3);
    }

    #[test]
    fn test_preprocessing_applied_with_from() {
        let mut context = ExecutionContext::new(Arc::new(create_test_table("test", 10)));
        let executor = StatementExecutor::new();
        let result = run(&executor, &mut context, "SELECT id FROM test WHERE id > 0");
        assert!(result.stats.preprocessing_time_ms >= 0.0);
        // ids 1..=9 survive the filter, whether or not the pipeline rewrote it.
        assert_eq!(result.dataview.row_count(), 9);
        assert_eq!(result.dataview.column_count(), 1);
    }

    #[test]
    fn test_no_preprocessing_without_from() {
        let mut context = ExecutionContext::new(Arc::new(create_test_table("test", 10)));
        let executor = StatementExecutor::new();
        let result = run(&executor, &mut context, "SELECT 42 as answer");
        assert!(!result.stats.preprocessing_applied);
    }

    #[test]
    fn test_execution_stats() {
        let mut context = ExecutionContext::new(Arc::new(create_test_table("test", 100)));
        let executor = StatementExecutor::new();
        let stats = run(&executor, &mut context, "SELECT * FROM test WHERE id < 50").stats;
        assert_eq!(stats.row_count, 50);
        assert_eq!(stats.column_count, 2);
        assert!(stats.total_time_ms >= 0.0);
        assert!(stats.total_time_ms >= stats.preprocessing_time_ms);
        assert!(stats.total_time_ms >= stats.execution_time_ms);
    }

    #[test]
    fn test_case_insensitive_execution() {
        let mut context = ExecutionContext::new(Arc::new(create_test_table("test", 10)));
        let config = ExecutionConfig::new().with_case_insensitive(true);
        let executor = StatementExecutor::with_config(config);
        // `ID` must match the lower-case `id` column and return every row.
        let result = run(&executor, &mut context, "SELECT ID FROM test");
        assert_eq!(result.dataview.row_count(), 10);
        assert_eq!(result.dataview.column_count(), 1);
    }
}