use anyhow::{anyhow, Result};
use fxhash::FxHashMap;
use std::sync::Arc;
use std::time::Duration;
use crate::data::arithmetic_evaluator::ArithmeticEvaluator;
use crate::data::data_view::DataView;
use crate::data::datatable::{DataColumn, DataRow, DataTable, DataValue};
use crate::data::query_engine::QueryEngine;
use crate::sql::aggregates::contains_aggregate;
use crate::sql::parser::ast::{SelectItem, SqlExpression};
use tracing::debug;
/// Timing and cardinality statistics describing one GROUP BY execution.
///
/// The per-field notes describe how `apply_group_by_expressions` in this file
/// populates the struct; other producers (if any) are not visible here.
#[derive(Debug, Clone)]
pub struct GroupByPhaseInfo {
/// Row count of the input view handed to the GROUP BY.
pub total_rows: usize,
/// Number of groups that were written to the result (i.e. not rejected by HAVING).
pub num_groups: usize,
/// Number of GROUP BY expressions.
pub num_expressions: usize,
/// Left at `Duration::ZERO` by `apply_group_by_expressions`.
pub phase1_cardinality_estimation: Duration,
/// Receives the elapsed time of the whole `group_by_expressions` call.
pub phase2_key_building: Duration,
/// Left at `Duration::ZERO` by `apply_group_by_expressions`.
pub phase2_expression_evaluation: Duration,
/// Left at `Duration::ZERO` by `apply_group_by_expressions`.
pub phase3_dataview_creation: Duration,
/// Time spent evaluating aggregate expressions, summed over all groups.
pub phase4_aggregation: Duration,
/// Time spent evaluating the HAVING predicate, summed over all groups.
pub phase4_having_evaluation: Duration,
/// Number of groups dropped because the HAVING predicate was not truthy.
pub groups_filtered_by_having: usize,
/// Wall-clock time of the whole `apply_group_by_expressions` call up to result construction.
pub total_time: Duration,
}
/// Hashable identity of a group: the evaluated GROUP BY expression values for a row,
/// stored in the same order as the GROUP BY expressions.
///
/// Rows whose evaluated values compare equal element-wise share a `GroupKey` and
/// therefore land in the same group.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(pub Vec<DataValue>);
/// GROUP BY support where the grouping keys are arbitrary SQL expressions
/// rather than plain column names.
pub trait GroupByExpressions {
/// Partitions the visible rows of `view` by the evaluated `group_by_exprs`.
///
/// Returns a map from each distinct key to a `DataView` restricted to the
/// rows that produced that key.
fn group_by_expressions(
&self,
view: DataView,
group_by_exprs: &[SqlExpression],
) -> Result<FxHashMap<GroupKey, DataView>>;
/// Groups `view` by `group_by_exprs`, evaluates the aggregate `select_items`
/// per group, optionally filters groups with `having`, and returns the
/// resulting view together with timing statistics.
///
/// `date_notation` is forwarded to the expression evaluator. The
/// `_case_insensitive` flag is part of the signature; the implementation in
/// this file does not read it.
fn apply_group_by_expressions(
&self,
view: DataView,
group_by_exprs: &[SqlExpression],
select_items: &[SelectItem],
having: Option<&SqlExpression>,
_case_insensitive: bool,
date_notation: String,
) -> Result<(DataView, GroupByPhaseInfo)>;
}
impl GroupByExpressions for QueryEngine {
fn group_by_expressions(
&self,
view: DataView,
group_by_exprs: &[SqlExpression],
) -> Result<FxHashMap<GroupKey, DataView>> {
use std::time::Instant;
let start = Instant::now();
let phase1_start = Instant::now();
let estimated_groups = self.estimate_group_cardinality(&view, group_by_exprs);
let mut groups = FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
let mut group_rows: FxHashMap<GroupKey, Vec<usize>> =
FxHashMap::with_capacity_and_hasher(estimated_groups, Default::default());
let phase1_time = phase1_start.elapsed();
debug!(
"GROUP BY Phase 1 (cardinality estimation): {:?}, estimated {} groups",
phase1_time, estimated_groups
);
let phase2_start = Instant::now();
let visible_rows = view.get_visible_rows();
let total_rows = visible_rows.len();
debug!("GROUP BY Phase 2 starting: processing {} rows", total_rows);
let mut evaluator = ArithmeticEvaluator::new(view.source());
let mut key_values = Vec::with_capacity(group_by_exprs.len());
for row_idx in visible_rows.iter().copied() {
key_values.clear();
for expr in group_by_exprs {
let value = evaluator.evaluate(expr, row_idx).unwrap_or(DataValue::Null);
key_values.push(value);
}
let key = GroupKey(key_values.clone()); group_rows.entry(key).or_default().push(row_idx);
}
let phase2_time = phase2_start.elapsed();
debug!(
"GROUP BY Phase 2 (expression evaluation & key building): {:?}, created {} unique keys",
phase2_time,
group_rows.len()
);
let phase3_start = Instant::now();
for (key, rows) in group_rows {
let mut group_view = DataView::new(view.source_arc());
group_view = group_view.with_rows(rows);
groups.insert(key, group_view);
}
let phase3_time = phase3_start.elapsed();
debug!("GROUP BY Phase 3 (DataView creation): {:?}", phase3_time);
let total_time = start.elapsed();
debug!(
"GROUP BY Total time: {:?} (P1: {:?}, P2: {:?}, P3: {:?})",
total_time, phase1_time, phase2_time, phase3_time
);
Ok(groups)
}
fn apply_group_by_expressions(
&self,
view: DataView,
group_by_exprs: &[SqlExpression],
select_items: &[SelectItem],
having: Option<&SqlExpression>,
_case_insensitive: bool,
date_notation: String,
) -> Result<(DataView, GroupByPhaseInfo)> {
use std::time::Instant;
let start = Instant::now();
debug!(
"apply_group_by_expressions - grouping by {} expressions, {} select items",
group_by_exprs.len(),
select_items.len()
);
let phase1_start = Instant::now();
let groups = self.group_by_expressions(view.clone(), group_by_exprs)?;
let phase1_time = phase1_start.elapsed();
debug!(
"apply_group_by_expressions Phase 1 (group building): {:?}, created {} groups",
phase1_time,
groups.len()
);
let mut result_table = DataTable::new("grouped_result");
let mut aggregate_columns = Vec::new();
let mut non_aggregate_exprs = Vec::new();
let mut group_by_aliases = Vec::new();
for (i, group_expr) in group_by_exprs.iter().enumerate() {
let mut found_alias = None;
for item in select_items {
if let SelectItem::Expression { expr, alias, .. } = item {
if !contains_aggregate(expr) && expressions_match(expr, group_expr) {
found_alias = Some(alias.clone());
break;
}
}
}
let alias = found_alias.unwrap_or_else(|| match group_expr {
SqlExpression::Column(column_ref) => column_ref.name.clone(),
_ => format!("group_expr_{}", i + 1),
});
result_table.add_column(DataColumn::new(&alias));
group_by_aliases.push(alias);
}
for item in select_items {
match item {
SelectItem::Expression { expr, alias, .. } => {
if contains_aggregate(expr) {
result_table.add_column(DataColumn::new(alias));
aggregate_columns.push((expr.clone(), alias.clone()));
} else {
let mut found = false;
for group_expr in group_by_exprs {
if expressions_match(expr, group_expr) {
found = true;
non_aggregate_exprs.push((expr.clone(), alias.clone()));
break;
}
}
if !found {
if let SqlExpression::Column(col) = expr {
let referenced = group_by_exprs
.iter()
.any(|ge| expression_references_column(ge, &col.name));
if !referenced {
return Err(anyhow!(
"Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
alias
));
}
} else {
return Err(anyhow!(
"Expression '{}' must appear in GROUP BY clause or be used in an aggregate function",
alias
));
}
}
}
}
SelectItem::Column {
column: col_ref, ..
} => {
let in_group_by = group_by_exprs.iter().any(
|expr| matches!(expr, SqlExpression::Column(name) if name.name == col_ref.name),
);
if !in_group_by {
return Err(anyhow!(
"Column '{}' must appear in GROUP BY clause or be used in an aggregate function",
col_ref.name
));
}
}
SelectItem::Star { .. } => {
}
SelectItem::StarExclude { .. } => {
}
}
}
let phase2_start = Instant::now();
let mut aggregation_time = std::time::Duration::ZERO;
let mut having_time = std::time::Duration::ZERO;
let mut groups_processed = 0;
let mut groups_filtered = 0;
for (group_key, group_view) in groups {
let mut row_values = Vec::new();
for value in &group_key.0 {
row_values.push(value.clone());
}
let agg_start = Instant::now();
for (expr, _col_name) in &aggregate_columns {
let group_rows = group_view.get_visible_rows();
let mut evaluator = ArithmeticEvaluator::with_date_notation(
group_view.source(),
date_notation.clone(),
)
.with_visible_rows(group_rows.clone());
let value = if group_view.row_count() > 0 && !group_rows.is_empty() {
evaluator
.evaluate(expr, group_rows[0])
.unwrap_or(DataValue::Null)
} else {
DataValue::Null
};
row_values.push(value);
}
aggregation_time += agg_start.elapsed();
let having_start = Instant::now();
if let Some(having_expr) = having {
let mut temp_table = DataTable::new("having_eval");
for alias in &group_by_aliases {
temp_table.add_column(DataColumn::new(alias));
}
for (_, alias) in &aggregate_columns {
temp_table.add_column(DataColumn::new(alias));
}
temp_table
.add_row(DataRow::new(row_values.clone()))
.map_err(|e| anyhow!("Failed to create temp table for HAVING: {}", e))?;
let mut evaluator =
ArithmeticEvaluator::with_date_notation(&temp_table, date_notation.clone());
let having_result = evaluator.evaluate(having_expr, 0)?;
if !is_truthy(&having_result) {
groups_filtered += 1;
having_time += having_start.elapsed();
continue;
}
}
having_time += having_start.elapsed();
groups_processed += 1;
result_table
.add_row(DataRow::new(row_values))
.map_err(|e| anyhow!("Failed to add grouped row: {}", e))?;
}
let phase2_time = phase2_start.elapsed();
let total_time = start.elapsed();
debug!(
"apply_group_by_expressions Phase 2 (aggregation): {:?}",
phase2_time
);
debug!(" - Aggregation time: {:?}", aggregation_time);
debug!(" - HAVING evaluation time: {:?}", having_time);
debug!(
" - Groups processed: {}, filtered by HAVING: {}",
groups_processed, groups_filtered
);
debug!(
"apply_group_by_expressions Total time: {:?} (P1: {:?}, P2: {:?})",
total_time, phase1_time, phase2_time
);
let phase_info = GroupByPhaseInfo {
total_rows: view.row_count(),
num_groups: groups_processed,
num_expressions: group_by_exprs.len(),
phase1_cardinality_estimation: Duration::ZERO, phase2_key_building: phase1_time, phase2_expression_evaluation: Duration::ZERO, phase3_dataview_creation: Duration::ZERO, phase4_aggregation: aggregation_time,
phase4_having_evaluation: having_time,
groups_filtered_by_having: groups_filtered,
total_time,
};
Ok((DataView::new(Arc::new(result_table)), phase_info))
}
}
/// Reports whether two expressions are structurally identical.
///
/// Equality is decided by comparing the `Debug` renderings of both
/// expressions, so two expressions match exactly when they print the same.
fn expressions_match(expr1: &SqlExpression, expr2: &SqlExpression) -> bool {
    let lhs = format!("{:?}", expr1);
    let rhs = format!("{:?}", expr2);
    lhs == rhs
}
/// Reports whether `expr` mentions the column named `column`.
///
/// The search descends through binary operations, function-call arguments
/// and BETWEEN operands. Every other expression kind is treated as not
/// referencing the column.
fn expression_references_column(expr: &SqlExpression, column: &str) -> bool {
    match expr {
        SqlExpression::Column(col_ref) => col_ref == column,
        SqlExpression::BinaryOp { left, right, .. } => {
            if expression_references_column(left, column) {
                return true;
            }
            expression_references_column(right, column)
        }
        SqlExpression::FunctionCall { args, .. } => {
            for arg in args.iter() {
                if expression_references_column(arg, column) {
                    return true;
                }
            }
            false
        }
        SqlExpression::Between { expr, lower, upper } => {
            if expression_references_column(expr, column) {
                return true;
            }
            if expression_references_column(lower, column) {
                return true;
            }
            expression_references_column(upper, column)
        }
        _ => false,
    }
}
/// Interprets a value as a boolean, as used for the HAVING result.
///
/// `Null` is false, booleans are themselves, integers are true when
/// non-zero, floats are true when neither zero nor NaN, and every other
/// kind of value is true.
fn is_truthy(value: &DataValue) -> bool {
    match value {
        DataValue::Null => false,
        DataValue::Boolean(flag) => *flag,
        DataValue::Integer(n) => *n != 0,
        DataValue::Float(x) => !(x.is_nan() || *x == 0.0),
        _ => true,
    }
}