use crate::config::global::get_date_notation;
use crate::data::data_view::DataView;
use crate::data::datatable::{DataTable, DataValue};
use crate::data::value_comparisons::compare_with_op;
use crate::sql::aggregate_functions::AggregateFunctionRegistry; use crate::sql::aggregates::AggregateRegistry; use crate::sql::functions::FunctionRegistry;
use crate::sql::parser::ast::{ColumnRef, WindowSpec};
use crate::sql::recursive_parser::SqlExpression;
use crate::sql::window_context::WindowContext;
use crate::sql::window_functions::{ExpressionEvaluator, WindowFunctionRegistry};
use anyhow::{anyhow, Result};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use std::time::Instant;
use tracing::{debug, info};
/// Row-at-a-time evaluator for SQL expressions over a borrowed [`DataTable`].
///
/// Handles literals, column references, binary operators, scalar and aggregate
/// function calls, window functions, method calls, CASE expressions and
/// date/time constructors (dispatched from `evaluate`).
pub struct ArithmeticEvaluator<'a> {
    /// Table whose rows are evaluated.
    table: &'a DataTable,
    /// Date notation captured at construction. NOTE(review): not read by any
    /// code in this part of the file (hence the leading underscore).
    _date_notation: String,
    /// Scalar functions; also used to resolve method calls.
    function_registry: Arc<FunctionRegistry>,
    /// Aggregates using the `init` / `accumulate(&mut state, ..)` / `finalize` API.
    aggregate_registry: Arc<AggregateRegistry>,
    /// Aggregates using the `create_state` / `state.accumulate` / `state.finalize`
    /// API; consulted before `aggregate_registry`.
    new_aggregate_registry: Arc<AggregateFunctionRegistry>,
    /// Window functions consulted before the built-in window implementations.
    window_function_registry: Arc<WindowFunctionRegistry>,
    /// Row indices that aggregates iterate over; `None` means every table row.
    visible_rows: Option<Vec<usize>>,
    /// Window contexts cached by `WindowSpec::compute_hash()`.
    window_contexts: HashMap<u64, Arc<WindowContext>>,
    /// Alias -> actual table name, used when resolving qualified columns.
    table_aliases: HashMap<String, String>,
}
impl<'a> ArithmeticEvaluator<'a> {
#[must_use]
pub fn new(table: &'a DataTable) -> Self {
Self {
table,
_date_notation: get_date_notation(),
function_registry: Arc::new(FunctionRegistry::new()),
aggregate_registry: Arc::new(AggregateRegistry::new()),
new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
window_function_registry: Arc::new(WindowFunctionRegistry::new()),
visible_rows: None,
window_contexts: HashMap::new(),
table_aliases: HashMap::new(),
}
}
#[must_use]
pub fn with_date_notation(table: &'a DataTable, date_notation: String) -> Self {
Self {
table,
_date_notation: date_notation,
function_registry: Arc::new(FunctionRegistry::new()),
aggregate_registry: Arc::new(AggregateRegistry::new()),
new_aggregate_registry: Arc::new(AggregateFunctionRegistry::new()),
window_function_registry: Arc::new(WindowFunctionRegistry::new()),
visible_rows: None,
window_contexts: HashMap::new(),
table_aliases: HashMap::new(),
}
}
#[must_use]
pub fn with_visible_rows(mut self, rows: Vec<usize>) -> Self {
self.visible_rows = Some(rows);
self
}
#[must_use]
pub fn with_table_aliases(mut self, aliases: HashMap<String, String>) -> Self {
self.table_aliases = aliases;
self
}
/// Creates an evaluator over `table` with an explicit date notation and a
/// caller-supplied (possibly shared) scalar function registry. Aggregate and
/// window registries are freshly constructed; the visible-row restriction,
/// window-context cache and alias map start empty.
#[must_use]
pub fn with_date_notation_and_registry(
    table: &'a DataTable,
    date_notation: String,
    function_registry: Arc<FunctionRegistry>,
) -> Self {
    // Built in the same order as the original struct literal.
    let aggregate_registry = Arc::new(AggregateRegistry::new());
    let new_aggregate_registry = Arc::new(AggregateFunctionRegistry::new());
    let window_function_registry = Arc::new(WindowFunctionRegistry::new());
    Self {
        table,
        _date_notation: date_notation,
        function_registry,
        aggregate_registry,
        new_aggregate_registry,
        window_function_registry,
        visible_rows: None,
        window_contexts: HashMap::new(),
        table_aliases: HashMap::new(),
    }
}
/// Suggests the table column whose name is closest to `name`.
///
/// Comparison is case-insensitive edit distance. A candidate qualifies when
/// its distance is at most 2, or at most 3 when `name` is longer than 10 bytes.
/// Among qualifying columns the smallest distance wins; ties keep the first
/// column in `column_names()` order. Returns `None` when nothing qualifies.
fn find_similar_column(&self, name: &str) -> Option<String> {
    // Loop-invariant: compute once instead of once per column.
    let target = name.to_lowercase();
    let max_distance = if name.len() > 10 { 3 } else { 2 };
    self.table
        .column_names()
        .into_iter()
        .filter_map(|col| {
            let distance = self.edit_distance(&col.to_lowercase(), &target);
            if distance <= max_distance {
                Some((col, distance))
            } else {
                None
            }
        })
        // `min_by_key` returns the first of several equal minima, matching the
        // previous strictly-less-than replacement rule.
        .min_by_key(|(_, distance)| *distance)
        .map(|(col, _)| col)
}
/// Edit distance between two strings, delegated to the shared
/// `EditDistanceFunction` implementation.
fn edit_distance(&self, a: &str, b: &str) -> usize {
    use crate::sql::functions::string_methods::EditDistanceFunction;
    EditDistanceFunction::calculate_edit_distance(a, b)
}
/// Evaluates `expr` against row `row_index` and returns the resulting value.
///
/// Literals are converted directly. Every other supported expression kind is
/// dispatched to its `evaluate_*` helper. Unhandled kinds return an
/// "Unsupported expression type" error.
pub fn evaluate(&mut self, expr: &SqlExpression, row_index: usize) -> Result<DataValue> {
    debug!(
        "ArithmeticEvaluator: evaluating {:?} for row {}",
        expr, row_index
    );
    match expr {
        // Literals.
        SqlExpression::Null => Ok(DataValue::Null),
        SqlExpression::BooleanLiteral(flag) => Ok(DataValue::Boolean(*flag)),
        SqlExpression::StringLiteral(text) => Ok(DataValue::String(text.clone())),
        SqlExpression::NumberLiteral(digits) => self.evaluate_number_literal(digits),
        // Column access.
        SqlExpression::Column(reference) => self.evaluate_column_ref(reference, row_index),
        // Operators.
        SqlExpression::BinaryOp {
            left: lhs,
            op,
            right: rhs,
        } => self.evaluate_binary_op(lhs, op, rhs, row_index),
        // Calls.
        SqlExpression::FunctionCall {
            name,
            args,
            distinct,
        } => self.evaluate_function_with_distinct(name, args, *distinct, row_index),
        SqlExpression::WindowFunction {
            name,
            args,
            window_spec,
        } => self.evaluate_window_function(name, args, window_spec, row_index),
        SqlExpression::MethodCall {
            object,
            method,
            args,
        } => self.evaluate_method_call(object, method, args, row_index),
        SqlExpression::ChainedMethodCall { base, method, args } => {
            let receiver = self.evaluate(base, row_index)?;
            self.evaluate_method_on_value(&receiver, method, args, row_index)
        }
        // Conditionals.
        SqlExpression::CaseExpression {
            when_branches,
            else_branch,
        } => self.evaluate_case_expression(when_branches, else_branch, row_index),
        SqlExpression::SimpleCaseExpression {
            expr: subject,
            when_branches,
            else_branch,
        } => self.evaluate_simple_case_expression(subject, when_branches, else_branch, row_index),
        // Date/time constructors.
        SqlExpression::DateTimeConstructor {
            year,
            month,
            day,
            hour,
            minute,
            second,
        } => self.evaluate_datetime_constructor(*year, *month, *day, *hour, *minute, *second),
        SqlExpression::DateTimeToday {
            hour,
            minute,
            second,
        } => self.evaluate_datetime_today(*hour, *minute, *second),
        unsupported => Err(anyhow!(
            "Unsupported expression type for arithmetic evaluation: {:?}",
            unsupported
        )),
    }
}
/// Reads the value of a (possibly table-qualified) column reference at
/// `row_index`.
///
/// Unqualified references go to `evaluate_column`. For qualified references
/// the prefix is first mapped through `table_aliases`. The lookup then tries
/// the qualified name `<table>.<column>`, and falls back to the bare column
/// name.
///
/// # Errors
/// Fails when neither lookup finds the column, or when the row is out of
/// bounds.
fn evaluate_column_ref(&self, column_ref: &ColumnRef, row_index: usize) -> Result<DataValue> {
    let table_prefix = match &column_ref.table_prefix {
        Some(prefix) => prefix,
        None => return self.evaluate_column(&column_ref.name, row_index),
    };
    let actual_table: &str = match self.table_aliases.get(table_prefix) {
        Some(real_name) => real_name,
        None => table_prefix,
    };
    let qualified_name = format!("{}.{}", actual_table, column_ref.name);

    let col_idx = match self.table.find_column_by_qualified_name(&qualified_name) {
        Some(idx) => {
            debug!(
                "Resolved {}.{} -> '{}' at index {}",
                table_prefix, column_ref.name, qualified_name, idx
            );
            Some(idx)
        }
        None => self.table.get_column_index(&column_ref.name).map(|idx| {
            debug!(
                "Resolved {}.{} -> unqualified '{}' at index {}",
                table_prefix, column_ref.name, column_ref.name, idx
            );
            idx
        }),
    };

    match col_idx {
        Some(idx) => self
            .table
            .get_value(row_index, idx)
            .cloned()
            .ok_or_else(|| anyhow!("Row {} out of bounds", row_index)),
        None => Err(anyhow!(
            "Column '{}' not found. Table '{}' may not support qualified column names",
            qualified_name,
            actual_table
        )),
    }
}
/// Reads column `column_name` at `row_index`.
///
/// For a dotted name such as `alias.col`, the part after the last dot is
/// tried first. If that is unknown, the full dotted name is tried. The
/// alias/table part is otherwise ignored here.
///
/// # Errors
/// Returns "Column '…' not found" (with a "Did you mean" suggestion when a
/// close match exists), or an out-of-bounds error for the row or cell.
fn evaluate_column(&self, column_name: &str, row_index: usize) -> Result<DataValue> {
    let resolved_column = match column_name.rfind('.') {
        Some(dot_pos) => {
            let col_name = &column_name[dot_pos + 1..];
            debug!(
                "Resolving qualified column: {} -> {}",
                column_name, col_name
            );
            col_name.to_string()
        }
        None => column_name.to_string(),
    };

    let mut found = self.table.get_column_index(&resolved_column);
    if found.is_none() && resolved_column != column_name {
        // The short name is unknown; the table may store the dotted name itself.
        found = self.table.get_column_index(column_name);
    }
    let col_index = match found {
        Some(idx) => idx,
        None => {
            return Err(match self.find_similar_column(&resolved_column) {
                Some(similar) => anyhow!(
                    "Column '{}' not found. Did you mean '{}'?",
                    column_name,
                    similar
                ),
                None => anyhow!("Column '{}' not found", column_name),
            });
        }
    };

    if row_index >= self.table.row_count() {
        return Err(anyhow!("Row index {} out of bounds", row_index));
    }
    let row = self
        .table
        .get_row(row_index)
        .ok_or_else(|| anyhow!("Row {} not found", row_index))?;
    let cell = row
        .get(col_index)
        .ok_or_else(|| anyhow!("Column index {} out of bounds for row", col_index))?;
    Ok(cell.clone())
}
/// Parses a numeric literal: `i64` when it parses as one, otherwise `f64`.
///
/// # Errors
/// Returns "Invalid number literal" when neither parse succeeds.
fn evaluate_number_literal(&self, number_str: &str) -> Result<DataValue> {
    match number_str.parse::<i64>() {
        Ok(whole) => Ok(DataValue::Integer(whole)),
        Err(_) => number_str
            .parse::<f64>()
            .map(DataValue::Float)
            .map_err(|_| anyhow!("Invalid number literal: {}", number_str)),
    }
}
/// Evaluates `left <op> right` for row `row_index`.
///
/// Supported operators:
/// - arithmetic `+ - * /`;
/// - `%`, delegated to the `MOD` scalar function;
/// - comparisons `> < >= <= = != <>`, via `compare_with_op`;
/// - `IS NULL` / `IS NOT NULL`, which inspect only the left value;
/// - `AND` / `OR`;
/// - `LIKE`.
///
/// `AND` / `OR` short-circuit: the right operand is evaluated only when it can
/// change the result. A guard such as `b <> 0 AND a / b > 1` therefore no
/// longer fails with "Division by zero" on rows where the guard is false.
///
/// # Errors
/// Propagates operand-evaluation errors, arithmetic errors and boolean
/// conversion errors for `AND`/`OR`. Rejects unknown operators.
fn evaluate_binary_op(
    &mut self,
    left: &SqlExpression,
    op: &str,
    right: &SqlExpression,
    row_index: usize,
) -> Result<DataValue> {
    let left_val = self.evaluate(left, row_index)?;

    if op == "AND" || op == "OR" {
        let left_bool = self.to_bool(&left_val)?;
        // AND is decided by a false left side, OR by a true one.
        let decided = if op == "AND" { !left_bool } else { left_bool };
        if decided {
            debug!(
                "ArithmeticEvaluator: {} {} <short-circuited>",
                self.format_value(&left_val),
                op
            );
            return Ok(DataValue::Boolean(left_bool));
        }
        let right_val = self.evaluate(right, row_index)?;
        debug!(
            "ArithmeticEvaluator: {} {} {}",
            self.format_value(&left_val),
            op,
            self.format_value(&right_val)
        );
        // The left side did not decide the result, so the right side does.
        return Ok(DataValue::Boolean(self.to_bool(&right_val)?));
    }

    let right_val = self.evaluate(right, row_index)?;
    debug!(
        "ArithmeticEvaluator: {} {} {}",
        self.format_value(&left_val),
        op,
        self.format_value(&right_val)
    );
    match op {
        "+" => self.add_values(&left_val, &right_val),
        "-" => self.subtract_values(&left_val, &right_val),
        "*" => self.multiply_values(&left_val, &right_val),
        "/" => self.divide_values(&left_val, &right_val),
        "%" => {
            // Call MOD on the operands already evaluated above, instead of
            // re-evaluating both sub-expressions through `evaluate_function`.
            let operands = vec![left_val, right_val];
            match self.function_registry.get("MOD") {
                Some(func) => func.evaluate(&operands),
                None => Err(anyhow!("Unknown function: {}", "MOD")),
            }
        }
        ">" | "<" | ">=" | "<=" | "=" | "!=" | "<>" => {
            let result = compare_with_op(&left_val, &right_val, op, false);
            Ok(DataValue::Boolean(result))
        }
        "IS NULL" => Ok(DataValue::Boolean(matches!(left_val, DataValue::Null))),
        "IS NOT NULL" => Ok(DataValue::Boolean(!matches!(left_val, DataValue::Null))),
        "LIKE" => {
            // NOTE(review): NULL operands become "" here, so `NULL LIKE '%'`
            // is true rather than SQL NULL.
            let text = self.value_to_string(&left_val);
            let pattern = self.value_to_string(&right_val);
            Ok(DataValue::Boolean(self.sql_like_match(&text, &pattern)))
        }
        _ => Err(anyhow!("Unsupported arithmetic operator: {}", op)),
    }
}
/// Adds two numeric values. NULL on either side yields NULL.
///
/// `Integer + Integer` stays an integer. Any `Float` operand promotes the
/// result to `Float`.
///
/// # Errors
/// Fails for non-numeric operands. Also fails for integer overflow, which
/// previously panicked in debug builds and silently wrapped in release builds.
fn add_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
    if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
        return Ok(DataValue::Null);
    }
    match (left, right) {
        (DataValue::Integer(a), DataValue::Integer(b)) => a
            .checked_add(*b)
            .map(DataValue::Integer)
            .ok_or_else(|| anyhow!("Integer overflow: {} + {}", a, b)),
        (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 + b)),
        (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a + *b as f64)),
        (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a + b)),
        _ => Err(anyhow!("Cannot add {:?} and {:?}", left, right)),
    }
}
/// Subtracts `right` from `left`. NULL on either side yields NULL.
///
/// `Integer - Integer` stays an integer. Any `Float` operand promotes the
/// result to `Float`.
///
/// # Errors
/// Fails for non-numeric operands. Also fails for integer overflow, which
/// previously panicked in debug builds and silently wrapped in release builds.
fn subtract_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
    if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
        return Ok(DataValue::Null);
    }
    match (left, right) {
        (DataValue::Integer(a), DataValue::Integer(b)) => a
            .checked_sub(*b)
            .map(DataValue::Integer)
            .ok_or_else(|| anyhow!("Integer overflow: {} - {}", a, b)),
        (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 - b)),
        (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a - *b as f64)),
        (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a - b)),
        _ => Err(anyhow!("Cannot subtract {:?} and {:?}", left, right)),
    }
}
/// Multiplies two numeric values. NULL on either side yields NULL.
///
/// `Integer * Integer` stays an integer. Any `Float` operand promotes the
/// result to `Float`.
///
/// # Errors
/// Fails for non-numeric operands. Also fails for integer overflow, which
/// previously panicked in debug builds and silently wrapped in release builds.
fn multiply_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
    if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
        return Ok(DataValue::Null);
    }
    match (left, right) {
        (DataValue::Integer(a), DataValue::Integer(b)) => a
            .checked_mul(*b)
            .map(DataValue::Integer)
            .ok_or_else(|| anyhow!("Integer overflow: {} * {}", a, b)),
        (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 * b)),
        (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a * *b as f64)),
        (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a * b)),
        _ => Err(anyhow!("Cannot multiply {:?} and {:?}", left, right)),
    }
}
/// Divides `left` by `right`. NULL on either side yields NULL.
///
/// `Integer / Integer` returns an `Integer` when the division is exact, and
/// otherwise a `Float` quotient. Any `Float` operand yields a `Float`.
///
/// # Errors
/// - A zero divisor (integer `0` or float `±0.0`) gives "Division by zero".
/// - Non-numeric operands are rejected.
/// - `i64::MIN / -1` gives an integer-overflow error. Previously this case
///   panicked in every build profile, because `%` and `/` overflow-check
///   unconditionally.
fn divide_values(&self, left: &DataValue, right: &DataValue) -> Result<DataValue> {
    if matches!(left, DataValue::Null) || matches!(right, DataValue::Null) {
        return Ok(DataValue::Null);
    }
    let is_zero = match right {
        DataValue::Integer(0) => true,
        DataValue::Float(f) if *f == 0.0 => true,
        _ => false,
    };
    if is_zero {
        return Err(anyhow!("Division by zero"));
    }
    match (left, right) {
        (DataValue::Integer(a), DataValue::Integer(b)) => match a.checked_rem(*b) {
            Some(0) => a
                .checked_div(*b)
                .map(DataValue::Integer)
                .ok_or_else(|| anyhow!("Integer overflow: {} / {}", a, b)),
            Some(_) => Ok(DataValue::Float(*a as f64 / *b as f64)),
            // Divisor is non-zero here, so `None` can only mean overflow.
            None => Err(anyhow!("Integer overflow: {} / {}", a, b)),
        },
        (DataValue::Integer(a), DataValue::Float(b)) => Ok(DataValue::Float(*a as f64 / b)),
        (DataValue::Float(a), DataValue::Integer(b)) => Ok(DataValue::Float(a / *b as f64)),
        (DataValue::Float(a), DataValue::Float(b)) => Ok(DataValue::Float(a / b)),
        _ => Err(anyhow!("Cannot divide {:?} and {:?}", left, right)),
    }
}
/// Renders a value for debug logging. Strings are single-quoted, numbers are
/// shown plainly, and everything else uses its `Debug` form.
fn format_value(&self, value: &DataValue) -> String {
    match value {
        DataValue::String(text) => format!("'{}'", text),
        DataValue::Integer(n) => n.to_string(),
        DataValue::Float(x) => x.to_string(),
        other => format!("{:?}", other),
    }
}
/// Coerces a value to a boolean.
///
/// - Booleans are taken as-is.
/// - Numbers are true when non-zero.
/// - NULL is false.
///
/// # Errors
/// Any other variant cannot be converted.
fn to_bool(&self, value: &DataValue) -> Result<bool> {
    let truthy = match value {
        DataValue::Null => false,
        DataValue::Boolean(flag) => *flag,
        DataValue::Integer(n) => *n != 0,
        DataValue::Float(x) => *x != 0.0,
        other => return Err(anyhow!("Cannot convert {:?} to boolean", other)),
    };
    Ok(truthy)
}
/// Converts a value to its plain text form, as used by `LIKE`.
///
/// - NULL becomes the empty string.
/// - Vectors render as `[a,b,c]`.
/// - Every other variant uses its `to_string`/clone.
fn value_to_string(&self, value: &DataValue) -> String {
    match value {
        DataValue::Null => String::new(),
        DataValue::Boolean(flag) => flag.to_string(),
        DataValue::Integer(n) => n.to_string(),
        DataValue::Float(x) => x.to_string(),
        DataValue::String(text) => text.clone(),
        DataValue::InternedString(text) => text.to_string(),
        DataValue::DateTime(stamp) => stamp.to_string(),
        DataValue::Vector(items) => {
            let joined = items
                .iter()
                .map(|component| component.to_string())
                .collect::<Vec<String>>()
                .join(",");
            format!("[{joined}]")
        }
    }
}
/// Returns whether `text` matches the SQL `LIKE` pattern.
///
/// `%` matches any run of characters and `_` matches exactly one. Matching
/// works on `char`s, so multi-byte characters count as one.
fn sql_like_match(&self, text: &str, pattern: &str) -> bool {
    let haystack = text.chars().collect::<Vec<char>>();
    let glob = pattern.chars().collect::<Vec<char>>();
    self.like_match_recursive(&haystack, 0, &glob, 0)
}
/// Returns whether `text[text_pos..]` matches the LIKE pattern
/// `pattern[pattern_pos..]`.
///
/// Pattern rules:
/// - `%` matches any run of characters, including an empty one.
/// - `_` matches exactly one character.
/// - Every other pattern character must match literally: case-sensitive,
///   with no escape character.
/// - A start position past the end of a slice is treated as an empty remainder.
///
/// Despite its name, this is an iterative greedy matcher with single-point
/// backtracking. The former recursive version re-explored every possible `%`
/// split and took exponential time on patterns such as `%a%a%a%b`. This
/// version is O(len(text) * len(pattern)) in the worst case and uses no stack.
fn like_match_recursive(
    &self,
    text: &[char],
    text_pos: usize,
    pattern: &[char],
    pattern_pos: usize,
) -> bool {
    let text = text.get(text_pos..).unwrap_or(&[]);
    let pattern = pattern.get(pattern_pos..).unwrap_or(&[]);

    let (mut t, mut p) = (0usize, 0usize);
    // Most recent `%`: (pattern index just after it, text index where the `%`
    // currently stops absorbing characters).
    let mut backtrack: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p) {
            Some(&'%') => {
                backtrack = Some((p + 1, t));
                p += 1;
            }
            Some(&c) if c == '_' || c == text[t] => {
                p += 1;
                t += 1;
            }
            _ => match backtrack {
                // Let the last `%` absorb one more character and retry.
                Some((after_percent, absorbed)) => {
                    backtrack = Some((after_percent, absorbed + 1));
                    p = after_percent;
                    t = absorbed + 1;
                }
                None => return false,
            },
        }
    }
    // Text exhausted: only trailing `%`s may remain in the pattern.
    pattern[p..].iter().all(|&c| c == '%')
}
/// Evaluates a function call, routing `DISTINCT` calls to the distinct
/// aggregate path.
///
/// # Errors
/// `DISTINCT` on a name unknown to both aggregate registries is rejected.
/// Otherwise, errors from the delegated evaluation are propagated.
fn evaluate_function_with_distinct(
    &mut self,
    name: &str,
    args: &[SqlExpression],
    distinct: bool,
    row_index: usize,
) -> Result<DataValue> {
    if !distinct {
        return self.evaluate_function(name, args, row_index);
    }
    let name_upper = name.to_uppercase();
    let known_aggregate = self.aggregate_registry.is_aggregate(&name_upper)
        || self.new_aggregate_registry.contains(&name_upper);
    if !known_aggregate {
        return Err(anyhow!(
            "DISTINCT can only be used with aggregate functions"
        ));
    }
    self.evaluate_aggregate_with_distinct(&name_upper, args, row_index)
}
fn evaluate_aggregate_with_distinct(
&mut self,
name: &str,
args: &[SqlExpression],
_row_index: usize,
) -> Result<DataValue> {
let name_upper = name.to_uppercase();
if self.new_aggregate_registry.get(&name_upper).is_some() {
let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
visible.clone()
} else {
(0..self.table.rows.len()).collect()
};
let mut vals = Vec::new();
for &row_idx in &rows_to_process {
if !args.is_empty() {
let value = self.evaluate(&args[0], row_idx)?;
vals.push(value);
}
}
let mut seen = HashSet::new();
let unique_values: Vec<_> = vals
.into_iter()
.filter(|v| {
let key = format!("{:?}", v);
seen.insert(key)
})
.collect();
let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
let mut state = agg_func.create_state();
for value in &unique_values {
state.accumulate(value)?;
}
return Ok(state.finalize());
}
if self.aggregate_registry.get(&name_upper).is_some() {
let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
visible.clone()
} else {
(0..self.table.rows.len()).collect()
};
if name_upper == "STRING_AGG" && args.len() >= 2 {
let mut state = crate::sql::aggregates::AggregateState::StringAgg(
if args.len() >= 2 {
let separator = self.evaluate(&args[1], 0)?; match separator {
DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
DataValue::InternedString(s) => {
crate::sql::aggregates::StringAggState::new(&s)
}
_ => crate::sql::aggregates::StringAggState::new(","), }
} else {
crate::sql::aggregates::StringAggState::new(",")
},
);
let mut seen_values = HashSet::new();
for &row_idx in &rows_to_process {
let value = self.evaluate(&args[0], row_idx)?;
if !seen_values.insert(value.clone()) {
continue; }
let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
agg_func.accumulate(&mut state, &value)?;
}
let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
return Ok(agg_func.finalize(state));
}
let mut vals = Vec::new();
for &row_idx in &rows_to_process {
if !args.is_empty() {
let value = self.evaluate(&args[0], row_idx)?;
vals.push(value);
}
}
let mut seen = HashSet::new();
let mut unique_values = Vec::new();
for value in vals {
if seen.insert(value.clone()) {
unique_values.push(value);
}
}
let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
let mut state = agg_func.init();
for value in &unique_values {
agg_func.accumulate(&mut state, value)?;
}
return Ok(agg_func.finalize(state));
}
Err(anyhow!("Unknown aggregate function: {}", name))
}
fn evaluate_function(
&mut self,
name: &str,
args: &[SqlExpression],
row_index: usize,
) -> Result<DataValue> {
let name_upper = name.to_uppercase();
if self.new_aggregate_registry.get(&name_upper).is_some() {
let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
visible.clone()
} else {
(0..self.table.rows.len()).collect()
};
let agg_func = self.new_aggregate_registry.get(&name_upper).unwrap();
let mut state = agg_func.create_state();
if name_upper == "COUNT" || name_upper == "COUNT_STAR" {
if args.is_empty()
|| (args.len() == 1
&& matches!(&args[0], SqlExpression::Column(col) if col.name == "*"))
|| (args.len() == 1
&& matches!(&args[0], SqlExpression::StringLiteral(s) if s == "*"))
{
for _ in &rows_to_process {
state.accumulate(&DataValue::Integer(1))?;
}
} else {
for &row_idx in &rows_to_process {
let value = self.evaluate(&args[0], row_idx)?;
state.accumulate(&value)?;
}
}
} else {
if !args.is_empty() {
for &row_idx in &rows_to_process {
let value = self.evaluate(&args[0], row_idx)?;
state.accumulate(&value)?;
}
}
}
return Ok(state.finalize());
}
if self.aggregate_registry.get(&name_upper).is_some() {
let rows_to_process: Vec<usize> = if let Some(ref visible) = self.visible_rows {
visible.clone()
} else {
(0..self.table.rows.len()).collect()
};
if name_upper == "STRING_AGG" && args.len() >= 2 {
let mut state = crate::sql::aggregates::AggregateState::StringAgg(
if args.len() >= 2 {
let separator = self.evaluate(&args[1], 0)?; match separator {
DataValue::String(s) => crate::sql::aggregates::StringAggState::new(&s),
DataValue::InternedString(s) => {
crate::sql::aggregates::StringAggState::new(&s)
}
_ => crate::sql::aggregates::StringAggState::new(","), }
} else {
crate::sql::aggregates::StringAggState::new(",")
},
);
for &row_idx in &rows_to_process {
let value = self.evaluate(&args[0], row_idx)?;
let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
agg_func.accumulate(&mut state, &value)?;
}
let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
return Ok(agg_func.finalize(state));
}
let values = if !args.is_empty()
&& !(args.len() == 1
&& matches!(&args[0], SqlExpression::Column(c) if c.name == "*"))
{
let mut vals = Vec::new();
for &row_idx in &rows_to_process {
let value = self.evaluate(&args[0], row_idx)?;
vals.push(value);
}
Some(vals)
} else {
None
};
let agg_func = self.aggregate_registry.get(&name_upper).unwrap();
let mut state = agg_func.init();
if let Some(values) = values {
for value in &values {
agg_func.accumulate(&mut state, value)?;
}
} else {
for _ in &rows_to_process {
agg_func.accumulate(&mut state, &DataValue::Integer(1))?;
}
}
return Ok(agg_func.finalize(state));
}
if self.function_registry.get(name).is_some() {
let mut evaluated_args = Vec::new();
for arg in args {
evaluated_args.push(self.evaluate(arg, row_index)?);
}
let func = self.function_registry.get(name).unwrap();
return func.evaluate(&evaluated_args);
}
Err(anyhow!("Unknown function: {}", name))
}
/// Returns the cached [`WindowContext`] for `spec`, building and caching it on
/// a miss.
///
/// Contexts are cached by `spec.compute_hash()`. A miss clones the whole
/// table into a new `DataView`, then builds the context with
/// `WindowContext::new_with_spec`. Timings are logged at `info` level.
///
/// # Errors
/// Propagates errors from `WindowContext::new_with_spec`.
pub fn get_or_create_window_context(
    &mut self,
    spec: &WindowSpec,
) -> Result<Arc<WindowContext>> {
    let overall_start = Instant::now();
    // NOTE(review): the cache is keyed only by the u64 hash, so two distinct
    // specs with colliding hashes would share a context. Consider keying on
    // the spec itself if `WindowSpec` implements Eq + Hash.
    let key = spec.compute_hash();
    if let Some(context) = self.window_contexts.get(&key) {
        info!(
            "WindowContext cache hit for spec (lookup: {}μs)",
            overall_start.elapsed().as_micros()
        );
        return Ok(Arc::clone(context));
    }
    info!("WindowContext cache miss - creating new context");

    let dataview_start = Instant::now();
    // NOTE(review): `visible_rows` is not applied here. The former `if let`
    // over `visible_rows` built the exact same full-table view in both arms,
    // so window functions see every row of `self.table`. Confirm whether the
    // view should be restricted to the visible rows.
    let data_view = DataView::new(Arc::new(self.table.clone()));
    info!(
        "DataView creation took {}μs",
        dataview_start.elapsed().as_micros()
    );

    let context_start = Instant::now();
    let context = Arc::new(WindowContext::new_with_spec(
        Arc::new(data_view),
        spec.clone(),
    )?);
    info!(
        "WindowContext::new_with_spec took {:.2}ms (rows: {})",
        context_start.elapsed().as_secs_f64() * 1000.0,
        self.table.row_count()
    );

    self.window_contexts.insert(key, Arc::clone(&context));
    info!(
        "Total WindowContext creation (cache miss) took {:.2}ms",
        overall_start.elapsed().as_secs_f64() * 1000.0
    );
    Ok(context)
}
fn evaluate_window_function(
&mut self,
name: &str,
args: &[SqlExpression],
spec: &WindowSpec,
row_index: usize,
) -> Result<DataValue> {
let func_start = Instant::now();
let name_upper = name.to_uppercase();
debug!("Looking for window function {} in registry", name_upper);
if let Some(window_fn_arc) = self.window_function_registry.get(&name_upper) {
debug!("Found window function {} in registry", name_upper);
let window_fn = window_fn_arc.as_ref();
window_fn.validate_args(args)?;
let transformed_spec = window_fn.transform_window_spec(spec, args)?;
let context = self.get_or_create_window_context(&transformed_spec)?;
struct EvaluatorAdapter<'a, 'b> {
evaluator: &'a mut ArithmeticEvaluator<'b>,
row_index: usize,
}
impl<'a, 'b> ExpressionEvaluator for EvaluatorAdapter<'a, 'b> {
fn evaluate(
&mut self,
expr: &SqlExpression,
row_index: usize,
) -> Result<DataValue> {
self.evaluator.evaluate(expr, row_index)
}
}
let mut adapter = EvaluatorAdapter {
evaluator: self,
row_index,
};
let compute_start = Instant::now();
let result = window_fn.compute(&context, row_index, args, &mut adapter);
info!(
"{} (registry) evaluation: total={:.2}μs, compute={:.2}μs",
name_upper,
func_start.elapsed().as_micros(),
compute_start.elapsed().as_micros()
);
return result;
}
let context_start = Instant::now();
let context = self.get_or_create_window_context(spec)?;
let context_time = context_start.elapsed();
let eval_start = Instant::now();
let result = match name_upper.as_str() {
"LAG" => {
if args.is_empty() {
return Err(anyhow!("LAG requires at least 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("LAG first argument must be a column")),
};
let offset = if args.len() > 1 {
match self.evaluate(&args[1], row_index)? {
DataValue::Integer(i) => i as i32,
_ => return Err(anyhow!("LAG offset must be an integer")),
}
} else {
1
};
let offset_start = Instant::now();
let value = context
.get_offset_value(row_index, -offset, &column.name)
.unwrap_or(DataValue::Null);
debug!(
"LAG offset access took {:.2}μs (offset={})",
offset_start.elapsed().as_micros(),
offset
);
Ok(value)
}
"LEAD" => {
if args.is_empty() {
return Err(anyhow!("LEAD requires at least 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("LEAD first argument must be a column")),
};
let offset = if args.len() > 1 {
match self.evaluate(&args[1], row_index)? {
DataValue::Integer(i) => i as i32,
_ => return Err(anyhow!("LEAD offset must be an integer")),
}
} else {
1
};
let offset_start = Instant::now();
let value = context
.get_offset_value(row_index, offset, &column.name)
.unwrap_or(DataValue::Null);
debug!(
"LEAD offset access took {:.2}μs (offset={})",
offset_start.elapsed().as_micros(),
offset
);
Ok(value)
}
"ROW_NUMBER" => {
Ok(DataValue::Integer(context.get_row_number(row_index) as i64))
}
"RANK" => {
Ok(DataValue::Integer(context.get_rank(row_index)))
}
"DENSE_RANK" => {
Ok(DataValue::Integer(context.get_dense_rank(row_index)))
}
"FIRST_VALUE" => {
if args.is_empty() {
return Err(anyhow!("FIRST_VALUE requires 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("FIRST_VALUE argument must be a column")),
};
if context.has_frame() {
Ok(context
.get_frame_first_value(row_index, &column.name)
.unwrap_or(DataValue::Null))
} else {
Ok(context
.get_first_value(row_index, &column.name)
.unwrap_or(DataValue::Null))
}
}
"LAST_VALUE" => {
if args.is_empty() {
return Err(anyhow!("LAST_VALUE requires 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("LAST_VALUE argument must be a column")),
};
if context.has_frame() {
Ok(context
.get_frame_last_value(row_index, &column.name)
.unwrap_or(DataValue::Null))
} else {
Ok(context
.get_last_value(row_index, &column.name)
.unwrap_or(DataValue::Null))
}
}
"SUM" => {
if args.is_empty() {
return Err(anyhow!("SUM requires 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("SUM argument must be a column")),
};
if context.has_frame() {
Ok(context
.get_frame_sum(row_index, &column.name)
.unwrap_or(DataValue::Null))
} else {
Ok(context
.get_partition_sum(row_index, &column.name)
.unwrap_or(DataValue::Null))
}
}
"AVG" => {
if args.is_empty() {
return Err(anyhow!("AVG requires 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("AVG argument must be a column")),
};
if context.has_frame() {
Ok(context
.get_frame_avg(row_index, &column.name)
.unwrap_or(DataValue::Null))
} else {
Ok(context
.get_partition_avg(row_index, &column.name)
.unwrap_or(DataValue::Null))
}
}
"STDDEV" | "STDEV" => {
if args.is_empty() {
return Err(anyhow!("STDDEV requires 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("STDDEV argument must be a column")),
};
Ok(context
.get_frame_stddev(row_index, &column.name)
.unwrap_or(DataValue::Null))
}
"VARIANCE" | "VAR" => {
if args.is_empty() {
return Err(anyhow!("VARIANCE requires 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("VARIANCE argument must be a column")),
};
Ok(context
.get_frame_variance(row_index, &column.name)
.unwrap_or(DataValue::Null))
}
"MIN" => {
if args.is_empty() {
return Err(anyhow!("MIN requires 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("MIN argument must be a column")),
};
let frame_rows = context.get_frame_rows(row_index);
if frame_rows.is_empty() {
return Ok(DataValue::Null);
}
let source_table = context.source();
let col_idx = source_table
.get_column_index(&column.name)
.ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
let mut min_value: Option<DataValue> = None;
for &row_idx in &frame_rows {
if let Some(value) = source_table.get_value(row_idx, col_idx) {
if !matches!(value, DataValue::Null) {
match &min_value {
None => min_value = Some(value.clone()),
Some(current_min) => {
if value < current_min {
min_value = Some(value.clone());
}
}
}
}
}
}
Ok(min_value.unwrap_or(DataValue::Null))
}
"MAX" => {
if args.is_empty() {
return Err(anyhow!("MAX requires 1 argument"));
}
let column = match &args[0] {
SqlExpression::Column(col) => col.clone(),
_ => return Err(anyhow!("MAX argument must be a column")),
};
let frame_rows = context.get_frame_rows(row_index);
if frame_rows.is_empty() {
return Ok(DataValue::Null);
}
let source_table = context.source();
let col_idx = source_table
.get_column_index(&column.name)
.ok_or_else(|| anyhow!("Column '{}' not found", column.name))?;
let mut max_value: Option<DataValue> = None;
for &row_idx in &frame_rows {
if let Some(value) = source_table.get_value(row_idx, col_idx) {
if !matches!(value, DataValue::Null) {
match &max_value {
None => max_value = Some(value.clone()),
Some(current_max) => {
if value > current_max {
max_value = Some(value.clone());
}
}
}
}
}
}
Ok(max_value.unwrap_or(DataValue::Null))
}
"COUNT" => {
if args.is_empty() {
if context.has_frame() {
Ok(context
.get_frame_count(row_index, None)
.unwrap_or(DataValue::Null))
} else {
Ok(context
.get_partition_count(row_index, None)
.unwrap_or(DataValue::Null))
}
} else {
let column = match &args[0] {
SqlExpression::Column(col) => {
if col.name == "*" {
if context.has_frame() {
return Ok(context
.get_frame_count(row_index, None)
.unwrap_or(DataValue::Null));
} else {
return Ok(context
.get_partition_count(row_index, None)
.unwrap_or(DataValue::Null));
}
}
col.clone()
}
SqlExpression::StringLiteral(s) if s == "*" => {
if context.has_frame() {
return Ok(context
.get_frame_count(row_index, None)
.unwrap_or(DataValue::Null));
} else {
return Ok(context
.get_partition_count(row_index, None)
.unwrap_or(DataValue::Null));
}
}
_ => return Err(anyhow!("COUNT argument must be a column or *")),
};
if context.has_frame() {
Ok(context
.get_frame_count(row_index, Some(&column.name))
.unwrap_or(DataValue::Null))
} else {
Ok(context
.get_partition_count(row_index, Some(&column.name))
.unwrap_or(DataValue::Null))
}
}
}
_ => Err(anyhow!("Unknown window function: {}", name)),
};
let eval_time = eval_start.elapsed();
info!(
"{} (built-in) evaluation: total={:.2}μs, context={:.2}μs, eval={:.2}μs",
name_upper,
func_start.elapsed().as_micros(),
context_time.as_micros(),
eval_time.as_micros()
);
result
}
/// Evaluates `object.method(args)`, where `object` names a column.
///
/// The cell at `row_index` (NULL if absent) becomes the receiver passed to
/// `evaluate_method_on_value`.
///
/// # Errors
/// Unknown column, with a "Did you mean" suggestion when one exists, or any
/// error from the method evaluation.
fn evaluate_method_call(
    &mut self,
    object: &str,
    method: &str,
    args: &[SqlExpression],
    row_index: usize,
) -> Result<DataValue> {
    let col_index = match self.table.get_column_index(object) {
        Some(idx) => idx,
        None => {
            return Err(match self.find_similar_column(object) {
                Some(similar) => {
                    anyhow!("Column '{}' not found. Did you mean '{}'?", object, similar)
                }
                None => anyhow!("Column '{}' not found", object),
            });
        }
    };
    let receiver = self
        .table
        .get_value(row_index, col_index)
        .cloned()
        .unwrap_or(DataValue::Null);
    self.evaluate_method_on_value(&receiver, method, args, row_index)
}
fn evaluate_method_on_value(
&mut self,
value: &DataValue,
method: &str,
args: &[SqlExpression],
row_index: usize,
) -> Result<DataValue> {
let function_name = match method.to_lowercase().as_str() {
"trim" => "TRIM",
"trimstart" | "trimbegin" => "TRIMSTART",
"trimend" => "TRIMEND",
"length" | "len" => "LENGTH",
"contains" => "CONTAINS",
"startswith" => "STARTSWITH",
"endswith" => "ENDSWITH",
"indexof" => "INDEXOF",
_ => method, };
if self.function_registry.get(function_name).is_some() {
debug!(
"Proxying method '{}' through function registry as '{}'",
method, function_name
);
let mut func_args = vec![value.clone()];
for arg in args {
func_args.push(self.evaluate(arg, row_index)?);
}
let func = self.function_registry.get(function_name).unwrap();
return func.evaluate(&func_args);
}
Err(anyhow!(
"Method '{}' not found. It should be registered in the function registry.",
method
))
}
/// Evaluates a searched `CASE WHEN cond THEN result ... [ELSE e] END`.
///
/// The result of the first branch whose condition holds is returned.
/// Otherwise the ELSE expression is returned, or NULL when there is no ELSE.
/// Conditions are checked in order, and evaluation stops at the first match.
fn evaluate_case_expression(
    &mut self,
    when_branches: &[crate::sql::recursive_parser::WhenBranch],
    else_branch: &Option<Box<SqlExpression>>,
    row_index: usize,
) -> Result<DataValue> {
    debug!(
        "ArithmeticEvaluator: evaluating CASE expression for row {}",
        row_index
    );
    for branch in when_branches {
        if !self.evaluate_condition_as_bool(&branch.condition, row_index)? {
            continue;
        }
        debug!("CASE: WHEN condition matched, evaluating result expression");
        return self.evaluate(&branch.result, row_index);
    }
    match else_branch {
        Some(else_expr) => {
            debug!("CASE: No WHEN matched, evaluating ELSE expression");
            self.evaluate(else_expr, row_index)
        }
        None => {
            debug!("CASE: No WHEN matched and no ELSE, returning NULL");
            Ok(DataValue::Null)
        }
    }
}
/// Evaluates a simple `CASE expr WHEN value THEN result ... [ELSE default] END`
/// for one row.
///
/// `expr` is evaluated once. Each WHEN value is then evaluated in order and
/// compared to it with `values_equal`. The first match has its result
/// evaluated and returned, and later WHEN values are not evaluated. If nothing
/// matches, the ELSE expression is evaluated; without an ELSE, NULL is returned.
///
/// `expr` is taken as `&SqlExpression` rather than `&Box<SqlExpression>`.
/// Callers holding a `&Box<SqlExpression>` still compile via deref coercion.
///
/// NOTE(review): `values_equal` treats NULL = NULL as a match, so
/// `CASE x WHEN NULL ...` matches a NULL `x`. Standard SQL never matches NULL
/// in a simple CASE; confirm this leniency is intended.
///
/// # Errors
/// Propagates any error from evaluating `expr`, a WHEN value, the selected
/// result, or the ELSE expression.
fn evaluate_simple_case_expression(
    &mut self,
    expr: &SqlExpression,
    when_branches: &[crate::sql::parser::ast::SimpleWhenBranch],
    else_branch: &Option<Box<SqlExpression>>,
    row_index: usize,
) -> Result<DataValue> {
    debug!(
        "ArithmeticEvaluator: evaluating simple CASE expression for row {}",
        row_index
    );
    let case_value = self.evaluate(expr, row_index)?;
    debug!("Simple CASE: evaluated expression to {:?}", case_value);

    for branch in when_branches {
        let when_value = self.evaluate(&branch.value, row_index)?;
        if self.values_equal(&case_value, &when_value)? {
            debug!("Simple CASE: WHEN value matched, evaluating result expression");
            return self.evaluate(&branch.result, row_index);
        }
    }

    if let Some(else_expr) = else_branch {
        debug!("Simple CASE: No WHEN matched, evaluating ELSE expression");
        self.evaluate(else_expr, row_index)
    } else {
        debug!("Simple CASE: No WHEN matched and no ELSE, returning NULL");
        Ok(DataValue::Null)
    }
}
/// Compares two values for equality under the loose rules used by simple
/// `CASE` matching.
///
/// - `NULL` equals `NULL`, and nothing else equals `NULL`.
/// - Integers and floats compare numerically, including across the two types.
///   Floats count as equal when their absolute difference is below
///   `f64::EPSILON`.
/// - Strings compare by content, whether each side is a plain `String` or an
///   `InternedString`.
/// - Booleans and date-times compare only with their own variant.
/// - Any other pairing of different variants is unequal.
///
/// This never fails. The `Result` return type is kept for call-site
/// compatibility.
fn values_equal(&self, left: &DataValue, right: &DataValue) -> Result<bool> {
    match (left, right) {
        (DataValue::Null, DataValue::Null) => Ok(true),
        (DataValue::Null, _) | (_, DataValue::Null) => Ok(false),
        (DataValue::Integer(a), DataValue::Integer(b)) => Ok(a == b),
        (DataValue::Float(a), DataValue::Float(b)) => Ok((a - b).abs() < f64::EPSILON),
        (DataValue::String(a), DataValue::String(b)) => Ok(a == b),
        // Interned and owned strings carry the same logical text. Compare the
        // contents instead of letting the differing variants fall through to
        // "unequal", which made e.g. `CASE col WHEN 'x'` never match an
        // interned column value.
        (DataValue::String(a), DataValue::InternedString(b)) => Ok(*a == **b),
        (DataValue::InternedString(a), DataValue::String(b)) => Ok(**a == *b),
        (DataValue::InternedString(a), DataValue::InternedString(b)) => Ok(**a == **b),
        (DataValue::Boolean(a), DataValue::Boolean(b)) => Ok(a == b),
        (DataValue::DateTime(a), DataValue::DateTime(b)) => Ok(a == b),
        (DataValue::Integer(a), DataValue::Float(b)) => {
            Ok((*a as f64 - b).abs() < f64::EPSILON)
        }
        (DataValue::Float(a), DataValue::Integer(b)) => {
            Ok((a - *b as f64).abs() < f64::EPSILON)
        }
        _ => Ok(false),
    }
}
/// Evaluates `expr` for one row and interprets the result as a boolean.
///
/// Truthiness rules:
/// - NULL is false.
/// - Booleans are used as-is.
/// - Integers and floats are true when non-zero.
/// - Strings (plain or interned) are true when non-empty.
/// - Every other kind of value is true.
///
/// # Errors
/// Propagates any error from evaluating `expr`.
fn evaluate_condition_as_bool(
    &mut self,
    expr: &SqlExpression,
    row_index: usize,
) -> Result<bool> {
    let truthy = match self.evaluate(expr, row_index)? {
        DataValue::Null => false,
        DataValue::Boolean(flag) => flag,
        DataValue::Integer(number) => number != 0,
        DataValue::Float(number) => number != 0.0,
        DataValue::String(text) => !text.is_empty(),
        DataValue::InternedString(text) => !text.is_empty(),
        // Remaining variants (e.g. date-times) are treated as true.
        _ => true,
    };
    Ok(truthy)
}
/// Builds a UTC timestamp from calendar components and returns it as a
/// `DataValue::String` formatted `%Y-%m-%d %H:%M:%S%.3f`
/// (e.g. `2024-02-29 13:05:00.000`).
///
/// Missing `hour`, `minute` or `second` components default to 0.
///
/// # Errors
/// - Fails with "Invalid date: ..." when `year`/`month`/`day` is not a real
///   calendar date.
/// - Fails with "Invalid time: ..." when chrono's `and_hms_opt` rejects the
///   time components.
fn evaluate_datetime_constructor(
    &self,
    year: i32,
    month: u32,
    day: u32,
    hour: Option<u32>,
    minute: Option<u32>,
    second: Option<u32>,
) -> Result<DataValue> {
    use chrono::{NaiveDate, TimeZone, Utc};

    let calendar_day = match NaiveDate::from_ymd_opt(year, month, day) {
        Some(valid) => valid,
        None => return Err(anyhow!("Invalid date: {}-{}-{}", year, month, day)),
    };

    let (h, m, s) = (
        hour.unwrap_or(0),
        minute.unwrap_or(0),
        second.unwrap_or(0),
    );

    let moment = calendar_day
        .and_hms_opt(h, m, s)
        .map(|naive| Utc.from_utc_datetime(&naive))
        .ok_or_else(|| anyhow!("Invalid time: {}:{}:{}", h, m, s))?;

    Ok(DataValue::String(
        moment.format("%Y-%m-%d %H:%M:%S%.3f").to_string(),
    ))
}
/// Builds a timestamp for the current UTC date at the given time of day and
/// returns it as a `DataValue::String` formatted `%Y-%m-%d %H:%M:%S%.3f`.
///
/// Missing `hour`, `minute` or `second` components default to 0, so calling
/// with all `None` yields today's midnight (UTC).
///
/// # Errors
/// Fails with "Invalid time: ..." when chrono's `and_hms_opt` rejects the
/// time components.
fn evaluate_datetime_today(
    &self,
    hour: Option<u32>,
    minute: Option<u32>,
    second: Option<u32>,
) -> Result<DataValue> {
    use chrono::{TimeZone, Utc};

    let h = hour.unwrap_or(0);
    let m = minute.unwrap_or(0);
    let s = second.unwrap_or(0);

    let stamp = match Utc::now().date_naive().and_hms_opt(h, m, s) {
        Some(naive) => Utc.from_utc_datetime(&naive),
        None => return Err(anyhow!("Invalid time: {}:{}:{}", h, m, s)),
    };

    let rendered = stamp.format("%Y-%m-%d %H:%M:%S%.3f").to_string();
    Ok(DataValue::String(rendered))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::datatable::{DataColumn, DataRow};

    /// Builds a one-row table: `a` = 10 (Integer), `b` = 2.5 (Float),
    /// `c` = 4 (Integer).
    fn create_test_table() -> DataTable {
        let mut table = DataTable::new("test");
        table.add_column(DataColumn::new("a"));
        table.add_column(DataColumn::new("b"));
        table.add_column(DataColumn::new("c"));
        table
            .add_row(DataRow::new(vec![
                DataValue::Integer(10),
                DataValue::Float(2.5),
                DataValue::Integer(4),
            ]))
            .unwrap();
        table
    }

    #[test]
    fn test_evaluate_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let expr = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(10));
    }

    #[test]
    fn test_evaluate_number_literal() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let expr = SqlExpression::NumberLiteral("42".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Integer(42));
        // 2.75 is exactly representable and, unlike 3.14, does not trip
        // clippy's deny-by-default `approx_constant` lint.
        let expr = SqlExpression::NumberLiteral("2.75".to_string());
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(2.75));
    }

    #[test]
    fn test_add_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Integer(8));
        let result = evaluator
            .add_values(&DataValue::Integer(5), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(7.5));
    }

    #[test]
    fn test_multiply_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let result = evaluator
            .multiply_values(&DataValue::Integer(4), &DataValue::Float(2.5))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0));
    }

    #[test]
    fn test_divide_values() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(2))
            .unwrap();
        assert_eq!(result, DataValue::Integer(5));
        let result = evaluator
            .divide_values(&DataValue::Integer(10), &DataValue::Integer(3))
            .unwrap();
        assert_eq!(result, DataValue::Float(10.0 / 3.0));
    }

    #[test]
    fn test_division_by_zero() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let result = evaluator.divide_values(&DataValue::Integer(10), &DataValue::Integer(0));
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Division by zero"));
    }

    #[test]
    fn test_binary_op_expression() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let expr = SqlExpression::BinaryOp {
            left: Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string()))),
            op: "*".to_string(),
            right: Box::new(SqlExpression::Column(ColumnRef::unquoted("b".to_string()))),
        };
        let result = evaluator.evaluate(&expr, 0).unwrap();
        assert_eq!(result, DataValue::Float(25.0));
    }

    #[test]
    fn test_values_equal() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);
        assert!(evaluator
            .values_equal(&DataValue::Null, &DataValue::Null)
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::Null, &DataValue::Integer(0))
            .unwrap());
        assert!(evaluator
            .values_equal(&DataValue::Integer(2), &DataValue::Float(2.0))
            .unwrap());
        assert!(evaluator
            .values_equal(
                &DataValue::String("abc".to_string()),
                &DataValue::String("abc".to_string())
            )
            .unwrap());
        assert!(!evaluator
            .values_equal(&DataValue::String("1".to_string()), &DataValue::Integer(1))
            .unwrap());
    }

    #[test]
    fn test_case_expression_without_branches() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let result = evaluator.evaluate_case_expression(&[], &None, 0).unwrap();
        assert_eq!(result, DataValue::Null);
        let else_branch = Some(Box::new(SqlExpression::NumberLiteral("7".to_string())));
        let result = evaluator
            .evaluate_case_expression(&[], &else_branch, 0)
            .unwrap();
        assert_eq!(result, DataValue::Integer(7));
    }

    #[test]
    fn test_simple_case_expression_without_branches() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let subject = Box::new(SqlExpression::Column(ColumnRef::unquoted("a".to_string())));
        let result = evaluator
            .evaluate_simple_case_expression(&subject, &[], &None, 0)
            .unwrap();
        assert_eq!(result, DataValue::Null);
    }

    #[test]
    fn test_condition_truthiness() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let column_a = SqlExpression::Column(ColumnRef::unquoted("a".to_string()));
        assert!(evaluator.evaluate_condition_as_bool(&column_a, 0).unwrap());
        let zero = SqlExpression::NumberLiteral("0".to_string());
        assert!(!evaluator.evaluate_condition_as_bool(&zero, 0).unwrap());
    }

    #[test]
    fn test_datetime_constructor() {
        let table = create_test_table();
        let evaluator = ArithmeticEvaluator::new(&table);
        let result = evaluator
            .evaluate_datetime_constructor(2024, 2, 29, Some(13), Some(5), None)
            .unwrap();
        assert_eq!(
            result,
            DataValue::String("2024-02-29 13:05:00.000".to_string())
        );
        let bad_date = evaluator.evaluate_datetime_constructor(2023, 2, 29, None, None, None);
        assert!(bad_date.unwrap_err().to_string().contains("Invalid date"));
        let bad_time = evaluator.evaluate_datetime_constructor(2024, 1, 1, Some(24), None, None);
        assert!(bad_time.unwrap_err().to_string().contains("Invalid time"));
    }

    #[test]
    fn test_method_call_on_missing_column() {
        let table = create_test_table();
        let mut evaluator = ArithmeticEvaluator::new(&table);
        let err = evaluator
            .evaluate_method_call("zzz", "trim", &[], 0)
            .unwrap_err();
        assert!(err.to_string().contains("Column 'zzz' not found"));
    }
}