use std::sync::atomic::Ordering;
use super::accumulator::MATERIALIZED_COUNT;
use super::ast::{AggregateExpr, AggregateOp, AggregateQueryAst, PlanError};
use super::planner::AggregateQueryPlanner;
use super::scan::{ScanIterator, ScanRow};
use crate::storage::schema::Value;
struct VecScan(std::vec::IntoIter<ScanRow>);
impl VecScan {
fn new(rows: Vec<ScanRow>) -> Self {
Self(rows.into_iter())
}
}
impl ScanIterator for VecScan {
fn next_row(&mut self) -> Option<ScanRow> {
self.0.next()
}
}
fn agg(op: AggregateOp, input_index: usize, name: &str) -> AggregateExpr {
AggregateExpr {
op,
input_index,
output_name: name.to_string(),
}
}
fn ast(group_name: &str, aggregates: Vec<AggregateExpr>) -> AggregateQueryAst {
AggregateQueryAst {
group_by_output_name: group_name.to_string(),
aggregates,
}
}
#[test]
fn empty_aggregates_rejected() {
let plan = ast("g", vec![]);
let scan = VecScan::new(vec![]);
assert_eq!(
AggregateQueryPlanner::plan(&plan, scan).unwrap_err(),
PlanError::NoAggregates
);
}
#[test]
fn duplicate_output_names_rejected() {
let plan = ast(
"g",
vec![
agg(AggregateOp::CountStar, 0, "n"),
agg(AggregateOp::Sum, 0, "n"),
],
);
let scan = VecScan::new(vec![]);
let err = AggregateQueryPlanner::plan(&plan, scan).unwrap_err();
assert!(matches!(err, PlanError::DuplicateOutputName(name) if name == "n"));
}
#[test]
fn count_star_per_group() {
let plan = ast("dept", vec![agg(AggregateOp::CountStar, 0, "n")]);
let rows = vec![
ScanRow {
group_key: Value::Text("eng".into()),
agg_inputs: vec![Value::Null],
},
ScanRow {
group_key: Value::Text("eng".into()),
agg_inputs: vec![Value::Null],
},
ScanRow {
group_key: Value::Text("ops".into()),
agg_inputs: vec![Value::Null],
},
];
let stream = AggregateQueryPlanner::plan(&plan, VecScan::new(rows)).unwrap();
let mut got: Vec<(String, i64)> = stream
.into_iter()
.map(
|r| match (r.group_key, r.aggregate_values.into_iter().next()) {
(Value::Text(k), Some(Value::Integer(n))) => (k.to_string(), n),
other => panic!("unexpected row shape: {other:?}"),
},
)
.collect();
got.sort();
assert_eq!(got, vec![("eng".to_string(), 2), ("ops".to_string(), 1)]);
}
#[test]
fn count_column_skips_nulls() {
let plan = ast("k", vec![agg(AggregateOp::CountColumn, 0, "n")]);
let rows = vec![
ScanRow {
group_key: Value::Integer(1),
agg_inputs: vec![Value::Integer(10)],
},
ScanRow {
group_key: Value::Integer(1),
agg_inputs: vec![Value::Null],
},
ScanRow {
group_key: Value::Integer(1),
agg_inputs: vec![Value::Integer(20)],
},
];
let stream = AggregateQueryPlanner::plan(&plan, VecScan::new(rows)).unwrap();
let rows = stream.into_vec();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].aggregate_values, vec![Value::Integer(2)]);
}
#[test]
fn sum_avg_min_max_basic() {
let plan = ast(
"k",
vec![
agg(AggregateOp::Sum, 0, "s"),
agg(AggregateOp::Avg, 0, "a"),
agg(AggregateOp::Min, 0, "lo"),
agg(AggregateOp::Max, 0, "hi"),
],
);
let rows = vec![
ScanRow {
group_key: Value::Text("g".into()),
agg_inputs: vec![Value::Integer(2)],
},
ScanRow {
group_key: Value::Text("g".into()),
agg_inputs: vec![Value::Integer(4)],
},
ScanRow {
group_key: Value::Text("g".into()),
agg_inputs: vec![Value::Integer(6)],
},
];
let stream = AggregateQueryPlanner::plan(&plan, VecScan::new(rows)).unwrap();
let row = stream.into_vec().pop().unwrap();
assert_eq!(row.aggregate_values[0], Value::Float(12.0));
assert_eq!(row.aggregate_values[1], Value::Float(4.0));
assert_eq!(row.aggregate_values[2], Value::Integer(2));
assert_eq!(row.aggregate_values[3], Value::Integer(6));
}
#[test]
fn all_null_group_returns_null_for_sum_avg() {
let plan = ast(
"k",
vec![agg(AggregateOp::Sum, 0, "s"), agg(AggregateOp::Avg, 0, "a")],
);
let rows = vec![
ScanRow {
group_key: Value::Integer(1),
agg_inputs: vec![Value::Null],
},
ScanRow {
group_key: Value::Integer(1),
agg_inputs: vec![Value::Null],
},
];
let stream = AggregateQueryPlanner::plan(&plan, VecScan::new(rows)).unwrap();
let row = stream.into_vec().pop().unwrap();
assert_eq!(row.aggregate_values[0], Value::Null);
assert_eq!(row.aggregate_values[1], Value::Null);
}
#[test]
fn empty_scan_yields_empty_stream() {
let plan = ast("k", vec![agg(AggregateOp::CountStar, 0, "n")]);
let stream = AggregateQueryPlanner::plan(&plan, VecScan::new(vec![])).unwrap();
assert!(stream.is_empty());
}
#[test]
fn single_row_group() {
let plan = ast("k", vec![agg(AggregateOp::Sum, 0, "s")]);
let rows = vec![ScanRow {
group_key: Value::Boolean(true),
agg_inputs: vec![Value::Integer(42)],
}];
let stream = AggregateQueryPlanner::plan(&plan, VecScan::new(rows)).unwrap();
let row = stream.into_vec().pop().unwrap();
assert_eq!(row.group_key, Value::Boolean(true));
assert_eq!(row.aggregate_values[0], Value::Float(42.0));
}
#[test]
fn null_grouping_key_collapses_into_one_group() {
let plan = ast("k", vec![agg(AggregateOp::CountStar, 0, "n")]);
let rows = vec![
ScanRow {
group_key: Value::Null,
agg_inputs: vec![Value::Null],
},
ScanRow {
group_key: Value::Null,
agg_inputs: vec![Value::Null],
},
];
let stream = AggregateQueryPlanner::plan(&plan, VecScan::new(rows)).unwrap();
let rows = stream.into_vec();
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].group_key, Value::Null);
assert_eq!(rows[0].aggregate_values[0], Value::Integer(2));
}
#[test]
fn hashable_keys_collide_across_int_text_bool() {
let plan = ast("k", vec![agg(AggregateOp::CountStar, 0, "n")]);
let rows = vec![
ScanRow {
group_key: Value::Integer(1),
agg_inputs: vec![Value::Null],
},
ScanRow {
group_key: Value::Text("1".into()),
agg_inputs: vec![Value::Null],
},
ScanRow {
group_key: Value::Boolean(true),
agg_inputs: vec![Value::Null],
},
];
let stream = AggregateQueryPlanner::plan(&plan, VecScan::new(rows)).unwrap();
assert_eq!(stream.len(), 3);
}
#[test]
fn materializes_one_row_per_group_not_per_input() {
let before = MATERIALIZED_COUNT.load(Ordering::Relaxed);
let plan = ast(
"bucket",
vec![
agg(AggregateOp::CountStar, 0, "n"),
agg(AggregateOp::Sum, 0, "s"),
],
);
let groups = 50usize;
let total_rows = 10_000usize;
let mut rows = Vec::with_capacity(total_rows);
for i in 0..total_rows {
rows.push(ScanRow {
group_key: Value::Integer((i % groups) as i64),
agg_inputs: vec![Value::Integer(i as i64)],
});
}
let stream = AggregateQueryPlanner::plan(&plan, VecScan::new(rows)).unwrap();
assert_eq!(
stream.len(),
groups,
"planner emitted {} rows; expected exactly one per group ({})",
stream.len(),
groups,
);
let mut by_group: Vec<(i64, i64, f64)> = stream
.into_iter()
.map(|r| match (r.group_key, &r.aggregate_values[..]) {
(Value::Integer(k), [Value::Integer(n), Value::Float(s)]) => (k, *n, *s),
other => panic!("unexpected planner row: {other:?}"),
})
.collect();
by_group.sort_by_key(|(k, _, _)| *k);
for (k, n, _s) in &by_group {
assert_eq!(*n, (total_rows / groups) as i64, "group {k} count");
}
let after = MATERIALIZED_COUNT.load(Ordering::Relaxed);
let delta = after - before;
assert!(
delta < total_rows / 2,
"regression: push-down materialised {delta} rows across all parallel callers; \
expected far fewer than the {total_rows}-row legacy baseline",
);
}