use anyhow::Result;
use std::collections::HashMap;
use std::sync::Arc;
use crate::data::datatable::DataTable;
use crate::data::temp_table_registry::TempTableRegistry;
/// Execution state used to resolve table names and look up variables.
///
/// Holds one source table, a registry of temporary tables, and a map of
/// string-valued variables. Cloning bumps the `Arc` on the source table and
/// deep-copies the variable map; how the temporary-table registry behaves on
/// clone is defined by `TempTableRegistry` (not visible here).
#[derive(Clone)]
pub struct ExecutionContext {
    /// Table returned by `resolve_table` for any name that is neither
    /// `#`-prefixed nor `DUAL`. It is also the fallback there when a
    /// `#`-prefixed temporary table is missing.
    pub source_table: Arc<DataTable>,
    /// Temporary tables; `resolve_table` only consults this for names that
    /// begin with `#`.
    pub temp_tables: TempTableRegistry,
    /// Named string variables (name -> value).
    pub variables: HashMap<String, String>,
}
impl ExecutionContext {
    /// Creates a context around `source_table` with an empty temporary-table
    /// registry and no variables.
    pub fn new(source_table: Arc<DataTable>) -> Self {
        let temp_tables = TempTableRegistry::new();
        let variables = HashMap::new();
        Self {
            source_table,
            temp_tables,
            variables,
        }
    }

    /// Creates a context whose source table is [`DataTable::dual`].
    pub fn with_dual() -> Self {
        let dual = Arc::new(DataTable::dual());
        Self::new(dual)
    }

    /// Resolves `name` to a table leniently; this never fails.
    ///
    /// Resolution rules:
    /// - Names starting with `#` are looked up in the temporary-table
    ///   registry. If no such temporary table exists, the source table is
    ///   returned instead. Use [`Self::resolve_table_strict`] to get an error
    ///   in that case.
    /// - `DUAL` (ASCII case-insensitive) yields a freshly built dual table.
    /// - Every other name resolves to the source table.
    pub fn resolve_table(&self, name: &str) -> Arc<DataTable> {
        if name.starts_with('#') {
            return match self.temp_tables.get(name) {
                Some(temp) => temp,
                None => Arc::clone(&self.source_table),
            };
        }
        if name.eq_ignore_ascii_case("DUAL") {
            return Arc::new(DataTable::dual());
        }
        Arc::clone(&self.source_table)
    }

    /// Resolves `name` like [`Self::resolve_table`], but without the silent
    /// fallback for missing temporary tables.
    ///
    /// # Errors
    ///
    /// Returns an error if `name` starts with `#` and no temporary table of
    /// that name is registered.
    pub fn resolve_table_strict(&self, name: &str) -> Result<Arc<DataTable>> {
        if name.starts_with('#') {
            return match self.temp_tables.get(name) {
                Some(temp) => Ok(temp),
                None => Err(anyhow::anyhow!("Temporary table '{}' not found", name)),
            };
        }
        if name.eq_ignore_ascii_case("DUAL") {
            return Ok(Arc::new(DataTable::dual()));
        }
        Ok(Arc::clone(&self.source_table))
    }

    /// Registers `table` in the temporary-table registry under `name`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `TempTableRegistry::insert`. The failure
    /// conditions are defined by the registry, not here.
    pub fn store_temp_table(&mut self, name: String, table: Arc<DataTable>) -> Result<()> {
        let registry = &mut self.temp_tables;
        registry.insert(name, table)
    }

    /// Reports whether a temporary table named `name` is registered.
    pub fn has_temp_table(&self, name: &str) -> bool {
        let registry = &self.temp_tables;
        registry.contains(name)
    }

    /// Lists the names of all registered temporary tables, in whatever
    /// order the registry reports them.
    pub fn temp_table_names(&self) -> Vec<String> {
        let registry = &self.temp_tables;
        registry.list_tables()
    }

    /// Sets variable `name` to `value`, replacing any previous value.
    pub fn set_variable(&mut self, name: String, value: String) {
        // Any displaced value is intentionally discarded.
        let _previous = self.variables.insert(name, value);
    }

    /// Returns the value of variable `name`, or `None` if it is unset.
    pub fn get_variable(&self, name: &str) -> Option<&String> {
        let vars = &self.variables;
        vars.get(name)
    }

    /// Replaces the temporary-table registry with a new, empty one.
    pub fn clear_temp_tables(&mut self) {
        let empty = TempTableRegistry::new();
        self.temp_tables = empty;
    }

    /// Removes every variable. The map keeps its allocated capacity.
    pub fn clear_variables(&mut self) {
        let vars = &mut self.variables;
        vars.clear();
    }

    /// Summarises the source table as `(name, row_count, column_count)`.
    pub fn source_table_info(&self) -> (String, usize, usize) {
        let table = &self.source_table;
        let name = table.name.clone();
        let rows = table.row_count();
        let cols = table.column_count();
        (name, rows, cols)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow, DataType, DataValue};

    /// Builds a one-column (`id: Integer`) table named `name` with `rows`
    /// rows whose ids run from `0` to `rows - 1`.
    fn create_test_table(name: &str, rows: usize) -> DataTable {
        let mut table = DataTable::new(name);
        table.add_column(DataColumn::new("id").with_type(DataType::Integer));
        for i in 0..rows {
            // Fail here on a fixture mistake instead of surfacing it later as
            // a confusing row-count mismatch in the test using the table.
            table
                .add_row(DataRow {
                    values: vec![DataValue::Integer(i as i64)],
                })
                .expect("fixture row should match the single-column schema");
        }
        table
    }

    #[test]
    fn test_new_context() {
        let table = create_test_table("test", 10);
        let ctx = ExecutionContext::new(Arc::new(table));
        assert_eq!(ctx.source_table.name, "test");
        assert_eq!(ctx.source_table.row_count(), 10);
        assert_eq!(ctx.temp_tables.list_tables().len(), 0);
    }

    #[test]
    fn test_dual_context() {
        let ctx = ExecutionContext::with_dual();
        assert_eq!(ctx.source_table.name, "DUAL");
        assert_eq!(ctx.source_table.row_count(), 1);
    }

    #[test]
    fn test_resolve_source_table() {
        let table = create_test_table("customers", 5);
        let ctx = ExecutionContext::new(Arc::new(table));
        let resolved = ctx.resolve_table("customers");
        assert_eq!(resolved.name, "customers");
        assert_eq!(resolved.row_count(), 5);
    }

    #[test]
    fn test_resolve_unknown_name_returns_source_table() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("base", 3)));
        let resolved = ctx.resolve_table("orders");
        assert_eq!(resolved.name, "base");
        assert_eq!(resolved.row_count(), 3);
    }

    #[test]
    fn test_resolve_dual_table() {
        let table = create_test_table("test", 10);
        let ctx = ExecutionContext::new(Arc::new(table));
        let resolved = ctx.resolve_table("DUAL");
        assert_eq!(resolved.name, "DUAL");
        assert_eq!(resolved.row_count(), 1);
    }

    #[test]
    fn test_resolve_dual_table_is_case_insensitive() {
        let ctx = ExecutionContext::new(Arc::new(create_test_table("test", 10)));
        for name in ["dual", "Dual", "dUaL"] {
            let resolved = ctx.resolve_table(name);
            assert_eq!(resolved.name, "DUAL");
            assert_eq!(resolved.row_count(), 1);
        }
    }

    #[test]
    fn test_store_and_resolve_temp_table() {
        let base_table = create_test_table("base", 10);
        let mut ctx = ExecutionContext::new(Arc::new(base_table));
        let temp_table = create_test_table("#temp1", 5);
        ctx.store_temp_table("#temp1".to_string(), Arc::new(temp_table))
            .unwrap();
        assert!(ctx.has_temp_table("#temp1"));
        assert_eq!(ctx.temp_table_names(), vec!["#temp1"]);
        let resolved = ctx.resolve_table("#temp1");
        assert_eq!(resolved.name, "#temp1");
        assert_eq!(resolved.row_count(), 5);
    }

    #[test]
    fn test_resolve_missing_temp_table_fallback() {
        let base_table = create_test_table("base", 10);
        let ctx = ExecutionContext::new(Arc::new(base_table));
        let resolved = ctx.resolve_table("#nonexistent");
        assert_eq!(resolved.name, "base");
    }

    #[test]
    fn test_resolve_missing_temp_table_strict() {
        let base_table = create_test_table("base", 10);
        let ctx = ExecutionContext::new(Arc::new(base_table));
        let result = ctx.resolve_table_strict("#nonexistent");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not found"));
    }

    #[test]
    fn test_resolve_table_strict_success_paths() {
        let mut ctx = ExecutionContext::new(Arc::new(create_test_table("base", 10)));
        ctx.store_temp_table(
            "#temp1".to_string(),
            Arc::new(create_test_table("#temp1", 5)),
        )
        .unwrap();

        let temp = ctx.resolve_table_strict("#temp1").unwrap();
        assert_eq!(temp.name, "#temp1");
        assert_eq!(temp.row_count(), 5);

        let dual = ctx.resolve_table_strict("Dual").unwrap();
        assert_eq!(dual.name, "DUAL");
        assert_eq!(dual.row_count(), 1);

        let source = ctx.resolve_table_strict("anything").unwrap();
        assert_eq!(source.name, "base");
    }

    #[test]
    fn test_variables() {
        let table = create_test_table("test", 5);
        let mut ctx = ExecutionContext::new(Arc::new(table));
        ctx.set_variable("user_id".to_string(), "123".to_string());
        ctx.set_variable("dept".to_string(), "sales".to_string());
        assert_eq!(ctx.get_variable("user_id").map(String::as_str), Some("123"));
        assert_eq!(ctx.get_variable("dept").map(String::as_str), Some("sales"));
        assert_eq!(ctx.get_variable("nonexistent"), None);
        ctx.clear_variables();
        assert_eq!(ctx.get_variable("user_id"), None);
    }

    #[test]
    fn test_set_variable_overwrites_previous_value() {
        let mut ctx = ExecutionContext::new(Arc::new(create_test_table("test", 1)));
        ctx.set_variable("dept".to_string(), "sales".to_string());
        ctx.set_variable("dept".to_string(), "support".to_string());
        assert_eq!(ctx.get_variable("dept").map(String::as_str), Some("support"));
    }

    #[test]
    fn test_cloned_context_has_independent_variables() {
        let mut ctx = ExecutionContext::new(Arc::new(create_test_table("test", 1)));
        ctx.set_variable("a".to_string(), "1".to_string());
        let mut copy = ctx.clone();
        copy.set_variable("b".to_string(), "2".to_string());
        assert_eq!(ctx.get_variable("b"), None);
        assert_eq!(copy.get_variable("a").map(String::as_str), Some("1"));
    }

    #[test]
    fn test_clear_temp_tables() {
        let base_table = create_test_table("base", 10);
        let mut ctx = ExecutionContext::new(Arc::new(base_table));
        ctx.store_temp_table(
            "#temp1".to_string(),
            Arc::new(create_test_table("#temp1", 5)),
        )
        .unwrap();
        ctx.store_temp_table(
            "#temp2".to_string(),
            Arc::new(create_test_table("#temp2", 3)),
        )
        .unwrap();
        assert_eq!(ctx.temp_table_names().len(), 2);
        ctx.clear_temp_tables();
        assert_eq!(ctx.temp_table_names().len(), 0);
    }

    #[test]
    fn test_source_table_info() {
        let table = create_test_table("sales", 100);
        let ctx = ExecutionContext::new(Arc::new(table));
        let (name, rows, cols) = ctx.source_table_info();
        assert_eq!(name, "sales");
        assert_eq!(rows, 100);
        assert_eq!(cols, 1);
    }
}