use std::collections::HashMap;
use std::collections::HashSet;
use super::accumulator::GroupAccumulator;
use super::ast::{AggregateOp, AggregateQueryAst, PlanError};
use super::scan::{AggregateRow, AggregateRowStream, ScanIterator};
use crate::storage::schema::{value_to_canonical_key, CanonicalKey, Value};
pub struct AggregateQueryPlanner;
impl AggregateQueryPlanner {
pub fn plan<S: ScanIterator>(
ast: &AggregateQueryAst,
mut scan: S,
) -> Result<AggregateRowStream, PlanError> {
validate_ast(ast)?;
let mut groups: HashMap<GroupKey, (Value, GroupAccumulator)> = HashMap::new();
while let Some(row) = scan.next_row() {
let key = canonical_group_key(&row.group_key);
let entry = groups.entry(key).or_insert_with(|| {
(
row.group_key.clone(),
GroupAccumulator::new(&ast.aggregates),
)
});
entry.1.accumulate(&ast.aggregates, &row.agg_inputs);
}
let mut emitted = Vec::with_capacity(groups.len());
for (_, (group_value, acc)) in groups {
let aggregate_values = acc.finalize();
emitted.push(AggregateRow {
group_key: group_value,
aggregate_values,
});
}
Ok(AggregateRowStream::from_rows(emitted))
}
}
/// Checks that `ast` is well-formed before any rows are scanned.
///
/// # Errors
///
/// * `PlanError::NoAggregates` if `ast.aggregates` is empty.
/// * `PlanError::DuplicateOutputName` carrying the output name of the first
///   aggregate whose name was already used by an earlier aggregate.
fn validate_ast(ast: &AggregateQueryAst) -> Result<(), PlanError> {
    let aggregates = &ast.aggregates;
    if aggregates.is_empty() {
        return Err(PlanError::NoAggregates);
    }

    let mut names: HashSet<&str> = HashSet::with_capacity(aggregates.len());
    // `try_for_each` stops at the first `Err`, so only the first repeat is reported.
    aggregates.iter().try_for_each(|agg| {
        let name = agg.output_name.as_str();
        if names.insert(name) {
            Ok(())
        } else {
            Err(PlanError::DuplicateOutputName(agg.output_name.clone()))
        }
    })
}
/// Hashable identity of a grouping value, used to look up its group.
///
/// Values that `value_to_canonical_key` can normalize are keyed by that
/// canonical form. All other values fall back to their `Debug` rendering.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum GroupKey {
    /// Key produced by `value_to_canonical_key`.
    Canonical(CanonicalKey),
    /// `Debug`-formatted text of a value that has no canonical key.
    Fallback(String),
}
/// Builds the [`GroupKey`] under which `value`'s group is accumulated.
///
/// Prefers the canonical key from `value_to_canonical_key`. When that yields
/// `None`, the value's `Debug` text is used instead. As a result, two
/// non-canonicalizable values with identical `Debug` output share a group.
fn canonical_group_key(value: &Value) -> GroupKey {
    value_to_canonical_key(value).map_or_else(
        || GroupKey::Fallback(format!("{:?}", value)),
        GroupKey::Canonical,
    )
}
/// Reports whether `op` is one of the aggregate operations listed here:
/// `CountStar`, `CountColumn`, `Sum`, `Avg`, `Min`, or `Max`.
pub(crate) fn op_is_supported(op: AggregateOp) -> bool {
    match op {
        AggregateOp::CountStar
        | AggregateOp::CountColumn
        | AggregateOp::Sum
        | AggregateOp::Avg
        | AggregateOp::Min
        | AggregateOp::Max => true,
        // Any other variant is unsupported. If `AggregateOp` has no variants
        // beyond those above, this arm is unreachable, so silence that lint.
        #[allow(unreachable_patterns)]
        _ => false,
    }
}