use std::collections::HashMap;
use std::sync::Arc;
use manifoldb_core::Value;
use crate::error::ParseError;
use crate::exec::context::ExecutionContext;
use crate::exec::operator::{BoxedOperator, Operator, OperatorBase, OperatorResult, OperatorState};
use crate::exec::operators::filter::evaluate_expr;
use crate::exec::row::{Row, Schema};
use crate::plan::logical::{AggregateFunction, LogicalExpr};
/// Blocking, hash-based GROUP BY operator.
///
/// The whole input is consumed on the first `next` call, building one [`GroupState`]
/// per distinct encoded group key; finished groups are then streamed one per call.
/// Output columns are the group-by expressions followed by the aggregate expressions.
pub struct HashAggregateOp {
    // ---- plan ----
    /// Grouping expressions, evaluated against every input row.
    group_by: Vec<LogicalExpr>,
    /// Aggregate expressions; only `LogicalExpr::AggregateFunction` entries are updated.
    aggregates: Vec<LogicalExpr>,
    /// Optional post-aggregation filter; a group is kept only when it yields `Bool(true)`.
    having: Option<LogicalExpr>,
    /// Child operator supplying the rows to aggregate.
    input: BoxedOperator,

    // ---- runtime state ----
    base: OperatorBase,
    /// Group table keyed by the byte encoding of the group-by values.
    groups: HashMap<Vec<u8>, GroupState>,
    /// Finished result rows, drained by `next`.
    results_iter: std::vec::IntoIter<Row>,
    /// Whether the input has already been aggregated for this execution.
    aggregated: bool,
    /// Reusable scratch buffer for encoding each row's group key.
    key_buffer: Vec<u8>,
    /// Upper bound on the number of distinct groups; 0 means unlimited.
    max_rows_in_memory: usize,
}
impl HashAggregateOp {
#[must_use]
pub fn new(
group_by: Vec<LogicalExpr>,
aggregates: Vec<LogicalExpr>,
having: Option<LogicalExpr>,
input: BoxedOperator,
) -> Self {
let mut columns = Vec::with_capacity(group_by.len() + aggregates.len());
for (i, expr) in group_by.iter().enumerate() {
columns.push(expr_to_name(expr, i));
}
for (i, expr) in aggregates.iter().enumerate() {
columns.push(expr_to_name(expr, group_by.len() + i));
}
let schema = Arc::new(Schema::new(columns));
const INITIAL_GROUPS_CAPACITY: usize = 1000;
Self {
base: OperatorBase::new(schema),
group_by,
aggregates,
having,
input,
groups: HashMap::with_capacity(INITIAL_GROUPS_CAPACITY),
results_iter: Vec::new().into_iter(),
aggregated: false,
key_buffer: Vec::with_capacity(64), max_rows_in_memory: 0, }
}
fn compute_group_values(&self, row: &Row) -> OperatorResult<Vec<Value>> {
self.group_by.iter().map(|expr| evaluate_expr(expr, row)).collect()
}
fn aggregate_all(&mut self) -> OperatorResult<()> {
let mut key_buffer = std::mem::take(&mut self.key_buffer);
while let Some(row) = self.input.next()? {
key_buffer.clear();
for expr in &self.group_by {
let value = evaluate_expr(expr, &row)?;
encode_value(&value, &mut key_buffer);
}
let is_new_group = !self.groups.contains_key(&key_buffer);
if is_new_group
&& self.max_rows_in_memory > 0
&& self.groups.len() >= self.max_rows_in_memory
{
self.key_buffer = key_buffer;
return Err(ParseError::QueryTooLarge {
actual: self.groups.len() + 1,
limit: self.max_rows_in_memory,
});
}
let state = if let Some(state) = self.groups.get_mut(&key_buffer) {
state
} else {
let group_values = self.compute_group_values(&row)?;
self.groups
.entry(key_buffer.clone())
.or_insert_with(|| GroupState::new(group_values, self.aggregates.len()))
};
for (i, agg_expr) in self.aggregates.iter().enumerate() {
if let LogicalExpr::AggregateFunction { func, arg, distinct: _ } = agg_expr {
let is_wildcard = matches!(arg.as_ref(), LogicalExpr::Wildcard);
let arg_value = evaluate_expr(arg, &row)?;
state.accumulators[i].update(func, &arg_value, is_wildcard);
}
}
}
self.key_buffer = key_buffer;
let schema = self.base.schema();
let mut results = Vec::with_capacity(self.groups.len());
for state in self.groups.values() {
let mut values = state.group_values.clone();
for acc in &state.accumulators {
values.push(acc.result());
}
let row = Row::new(Arc::clone(&schema), values);
if let Some(having) = &self.having {
let result = evaluate_expr(having, &row)?;
if !matches!(result, Value::Bool(true)) {
continue;
}
}
results.push(row);
}
if self.group_by.is_empty() && self.groups.is_empty() {
let mut values = Vec::new();
for agg_expr in &self.aggregates {
if let LogicalExpr::AggregateFunction { func, .. } = agg_expr {
values.push(Accumulator::new().default_for(func));
} else {
values.push(Value::Null);
}
}
let row = Row::new(Arc::clone(&schema), values);
results.push(row);
}
self.results_iter = results.into_iter();
self.aggregated = true;
Ok(())
}
}
impl Operator for HashAggregateOp {
    /// Opens the child operator and resets per-execution state. The group limit is
    /// re-read from `ctx` on every open.
    fn open(&mut self, ctx: &ExecutionContext) -> OperatorResult<()> {
        self.input.open(ctx)?;
        self.max_rows_in_memory = ctx.max_rows_in_memory();
        self.aggregated = false;
        self.results_iter = Vec::new().into_iter();
        self.groups.clear();
        self.base.set_open();
        Ok(())
    }

    /// Aggregates the entire input on the first call, then hands out one finished
    /// group per call until none remain.
    fn next(&mut self) -> OperatorResult<Option<Row>> {
        if !self.aggregated {
            self.aggregate_all()?;
        }
        if let Some(row) = self.results_iter.next() {
            self.base.inc_rows_produced();
            Ok(Some(row))
        } else {
            self.base.set_finished();
            Ok(None)
        }
    }

    /// Closes the child and drops buffered groups and pending result rows.
    fn close(&mut self) -> OperatorResult<()> {
        self.input.close()?;
        self.results_iter = Vec::new().into_iter();
        self.groups.clear();
        self.base.set_closed();
        Ok(())
    }

    fn schema(&self) -> Arc<Schema> {
        self.base.schema()
    }

    fn state(&self) -> OperatorState {
        self.base.state()
    }

    fn name(&self) -> &'static str {
        "HashAggregate"
    }
}
/// Streaming GROUP BY operator that merges consecutive rows sharing a group key.
///
/// Only the group currently being built is held in memory. Rows with equal keys are
/// therefore merged only when they arrive adjacently, e.g. when the input is sorted
/// on the group-by expressions.
pub struct SortMergeAggregateOp {
    // ---- plan ----
    /// Grouping expressions, evaluated against every input row.
    group_by: Vec<LogicalExpr>,
    /// Aggregate expressions; only `LogicalExpr::AggregateFunction` entries are updated.
    aggregates: Vec<LogicalExpr>,
    /// Optional post-aggregation filter; a group is kept only when it yields `Bool(true)`.
    having: Option<LogicalExpr>,
    /// Child operator supplying the rows to aggregate.
    input: BoxedOperator,

    // ---- streaming state ----
    base: OperatorBase,
    /// Encoded key of the group currently being accumulated, if any.
    current_key: Option<Vec<u8>>,
    /// Group-by values of the current group, evaluated from its first row.
    current_values: Vec<Value>,
    /// One accumulator per aggregate expression for the current group.
    accumulators: Vec<Accumulator>,
    /// A row held back to be processed before pulling from `input` again.
    pending_row: Option<Row>,
    /// Set once the input has been exhausted.
    finished: bool,
    /// Reusable scratch buffer for encoding each row's group key.
    key_buffer: Vec<u8>,
}
impl SortMergeAggregateOp {
    /// Creates a streaming aggregation over `input`.
    ///
    /// The output schema is the `group_by` expressions followed by the `aggregates`,
    /// each named with `expr_to_name`.
    #[must_use]
    pub fn new(
        group_by: Vec<LogicalExpr>,
        aggregates: Vec<LogicalExpr>,
        having: Option<LogicalExpr>,
        input: BoxedOperator,
    ) -> Self {
        // Positions are shared across both lists: group-by columns first, then aggregates.
        let names: Vec<String> = group_by
            .iter()
            .chain(&aggregates)
            .enumerate()
            .map(|(position, expr)| expr_to_name(expr, position))
            .collect();
        Self {
            base: OperatorBase::new(Arc::new(Schema::new(names))),
            group_by,
            aggregates,
            having,
            input,
            current_key: None,
            current_values: Vec::with_capacity(8),
            accumulators: Vec::new(),
            pending_row: None,
            finished: false,
            key_buffer: Vec::with_capacity(64),
        }
    }

    /// Overwrites `buf` with the encoded group key of `row`.
    fn compute_group_key_into(&self, row: &Row, buf: &mut Vec<u8>) -> OperatorResult<()> {
        buf.clear();
        self.group_by
            .iter()
            .try_for_each(|expr| -> OperatorResult<()> {
                encode_value(&evaluate_expr(expr, row)?, buf);
                Ok(())
            })
    }

    /// Evaluates the group-by expressions against `row`.
    fn compute_group_values(&self, row: &Row) -> OperatorResult<Vec<Value>> {
        self.group_by
            .iter()
            .map(|expr| evaluate_expr(expr, row))
            .collect::<OperatorResult<Vec<Value>>>()
    }

    /// Resets the accumulators to one fresh accumulator per aggregate expression.
    fn init_accumulators(&mut self) {
        let wanted = self.aggregates.len();
        self.accumulators.clear();
        self.accumulators
            .extend(std::iter::repeat_with(Accumulator::new).take(wanted));
    }

    /// Folds `row` into every aggregate of the current group.
    fn update_accumulators(&mut self, row: &Row) -> OperatorResult<()> {
        let pairs = self.accumulators.iter_mut().zip(&self.aggregates);
        for (acc, agg_expr) in pairs {
            if let LogicalExpr::AggregateFunction { func, arg, .. } = agg_expr {
                let is_wildcard = matches!(arg.as_ref(), LogicalExpr::Wildcard);
                let arg_value = evaluate_expr(arg, row)?;
                acc.update(func, &arg_value, is_wildcard);
            }
        }
        Ok(())
    }

    /// Materialises the current group as an output row: group values, then aggregates.
    fn build_result(&self) -> Row {
        let values: Vec<Value> = self
            .current_values
            .iter()
            .cloned()
            .chain(self.accumulators.iter().map(Accumulator::result))
            .collect();
        Row::new(self.base.schema(), values)
    }
}
impl Operator for SortMergeAggregateOp {
    /// Opens the child operator and resets all streaming state so the operator can be
    /// executed again.
    fn open(&mut self, ctx: &ExecutionContext) -> OperatorResult<()> {
        self.input.open(ctx)?;
        self.current_key = None;
        self.current_values.clear();
        self.accumulators.clear();
        self.pending_row = None;
        self.finished = false;
        self.base.set_open();
        Ok(())
    }

    /// Returns the next finished group, or `None` once the input is exhausted.
    fn next(&mut self) -> OperatorResult<Option<Row>> {
        if self.finished {
            // The last group can be returned on the same call that reached end of input,
            // without the operator being marked finished. Record completion here, the
            // way `HashAggregateOp` does once it is exhausted.
            self.base.set_finished();
            return Ok(None);
        }
        // Lend the scratch key buffer to `next_inner` and always restore it, including
        // on error, so its capacity is reused across calls.
        let mut key_buffer = std::mem::take(&mut self.key_buffer);
        let result = self.next_inner(&mut key_buffer);
        self.key_buffer = key_buffer;
        result
    }

    fn close(&mut self) -> OperatorResult<()> {
        self.input.close()?;
        self.base.set_closed();
        Ok(())
    }

    fn schema(&self) -> Arc<Schema> {
        self.base.schema()
    }

    fn state(&self) -> OperatorState {
        self.base.state()
    }

    fn name(&self) -> &'static str {
        "SortMergeAggregate"
    }
}
impl SortMergeAggregateOp {
fn next_inner(&mut self, key_buffer: &mut Vec<u8>) -> OperatorResult<Option<Row>> {
loop {
let row =
if let Some(r) = self.pending_row.take() { Some(r) } else { self.input.next()? };
match row {
Some(row) => {
self.compute_group_key_into(&row, key_buffer)?;
if self.current_key.as_deref() == Some(key_buffer.as_slice()) {
self.update_accumulators(&row)?;
} else if self.current_key.is_some() {
self.pending_row = Some(row.clone());
let result = self.build_result();
self.current_key = Some(key_buffer.clone());
self.current_values = self.compute_group_values(&row)?;
self.init_accumulators();
self.update_accumulators(&row)?;
if let Some(having) = &self.having {
let check = evaluate_expr(having, &result)?;
if !matches!(check, Value::Bool(true)) {
continue;
}
}
self.base.inc_rows_produced();
return Ok(Some(result));
} else {
self.current_key = Some(key_buffer.clone());
self.current_values = self.compute_group_values(&row)?;
self.init_accumulators();
self.update_accumulators(&row)?;
}
}
None => {
self.finished = true;
if self.current_key.is_some() {
let result = self.build_result();
if let Some(having) = &self.having {
let check = evaluate_expr(having, &result)?;
if !matches!(check, Value::Bool(true)) {
self.base.set_finished();
return Ok(None);
}
}
self.base.inc_rows_produced();
return Ok(Some(result));
}
self.base.set_finished();
return Ok(None);
}
}
}
}
}
/// Per-group state used by [`HashAggregateOp`]: the group's key values plus one
/// accumulator per aggregate expression.
#[derive(Debug)]
struct GroupState {
    /// Evaluated group-by values, taken from the first row seen for the group.
    group_values: Vec<Value>,
    /// Accumulators, positionally matching the operator's aggregate expressions.
    accumulators: Vec<Accumulator>,
}

impl GroupState {
    /// Creates a group holding `group_values` and `num_aggregates` empty accumulators.
    fn new(group_values: Vec<Value>, num_aggregates: usize) -> Self {
        let accumulators = std::iter::repeat_with(Accumulator::new)
            .take(num_aggregates)
            .collect();
        Self {
            group_values,
            accumulators,
        }
    }
}
/// Running state for one aggregate expression within one group.
///
/// The aggregate function is latched from the first `update` call. `result` maps the
/// collected state to the function's output value.
#[derive(Debug, Default)]
struct Accumulator {
    /// Function seen on the first `update`; `None` until then.
    func: Option<AggregateFunction>,
    /// Rows counted: every row for `COUNT(*)`, otherwise only non-NULL values.
    count: i64,
    /// Running sum for SUM/AVG. Non-numeric values contribute 0.0 (see `value_to_f64`).
    sum: f64,
    /// Smallest value seen so far (MIN only).
    min: Option<Value>,
    /// Largest value seen so far (MAX only).
    max: Option<Value>,
}

impl Accumulator {
    /// Creates an empty accumulator.
    fn new() -> Self {
        Self::default()
    }

    /// Folds one input value into the running state.
    ///
    /// `is_wildcard` marks `COUNT(*)`, which counts every row including NULLs. In all
    /// other cases, NULL inputs are skipped entirely.
    fn update(&mut self, func: &AggregateFunction, value: &Value, is_wildcard: bool) {
        if self.func.is_none() {
            self.func = Some(*func);
        }
        if matches!(func, AggregateFunction::Count) && is_wildcard {
            self.count += 1;
            return;
        }
        if matches!(value, Value::Null) {
            return;
        }
        self.count += 1;
        match func {
            // COUNT(expr) only needs the non-NULL tally above.
            AggregateFunction::Count => {}
            AggregateFunction::Sum | AggregateFunction::Avg => {
                self.sum += value_to_f64(value);
            }
            AggregateFunction::Min => {
                // Clone only when the minimum actually changes, not on every row.
                let is_new_min = match &self.min {
                    None => true,
                    Some(current) => compare_values(value, current) < 0,
                };
                if is_new_min {
                    self.min = Some(value.clone());
                }
            }
            AggregateFunction::Max => {
                let is_new_max = match &self.max {
                    None => true,
                    Some(current) => compare_values(value, current) > 0,
                };
                if is_new_max {
                    self.max = Some(value.clone());
                }
            }
            _ => {}
        }
    }

    /// Produces the aggregate's output value from the accumulated state.
    ///
    /// COUNT yields `Int`. SUM/AVG yield `Float`, or NULL when no non-NULL input was
    /// seen. MIN/MAX yield NULL when nothing was seen.
    fn result(&self) -> Value {
        match &self.func {
            Some(AggregateFunction::Count) => Value::Int(self.count),
            Some(AggregateFunction::Sum) => {
                if self.count > 0 {
                    Value::Float(self.sum)
                } else {
                    Value::Null
                }
            }
            Some(AggregateFunction::Avg) => {
                if self.count > 0 {
                    Value::Float(self.sum / self.count as f64)
                } else {
                    Value::Null
                }
            }
            Some(AggregateFunction::Min) => self.min.clone().unwrap_or(Value::Null),
            Some(AggregateFunction::Max) => self.max.clone().unwrap_or(Value::Null),
            // Unset or unhandled function: report any recorded extremum, else the count.
            _ => self
                .min
                .as_ref()
                .or(self.max.as_ref())
                .cloned()
                .unwrap_or(Value::Int(self.count)),
        }
    }

    /// Value an aggregate produces over zero input rows: `Int(0)` for COUNT, NULL otherwise.
    fn default_for(&self, func: &AggregateFunction) -> Value {
        match func {
            AggregateFunction::Count => Value::Int(0),
            _ => Value::Null,
        }
    }
}
/// Appends an encoding of `value` to `buf` for use as (part of) a group key.
///
/// Each value is written as a one-byte type tag followed by its payload, so the
/// concatenated encoding of several group-by values is unambiguous:
/// - Strings are length-prefixed. A NUL terminator would let strings containing `\0`
///   collide across column boundaries.
/// - Floats are canonicalised, so `-0.0` groups with `0.0` and all NaNs share one group.
/// - Variants without a dedicated encoding are keyed by their `Debug` output. Previously
///   they collapsed into the NULL group.
fn encode_value(value: &Value, buf: &mut Vec<u8>) {
    /// Writes `bytes` preceded by a little-endian `u64` length.
    fn length_prefixed(bytes: &[u8], buf: &mut Vec<u8>) {
        // usize -> u64 is lossless on all supported (<= 64-bit) targets.
        buf.extend_from_slice(&(bytes.len() as u64).to_le_bytes());
        buf.extend_from_slice(bytes);
    }

    match value {
        Value::Null => buf.push(0),
        Value::Bool(b) => {
            buf.push(1);
            buf.push(u8::from(*b));
        }
        Value::Int(i) => {
            buf.push(2);
            buf.extend_from_slice(&i.to_le_bytes());
        }
        Value::Float(f) => {
            buf.push(3);
            let canonical = if f.is_nan() {
                f64::NAN
            } else if *f == 0.0 {
                // Folds -0.0 into +0.0.
                0.0
            } else {
                *f
            };
            buf.extend_from_slice(&canonical.to_le_bytes());
        }
        Value::String(s) => {
            buf.push(4);
            length_prefixed(s.as_bytes(), buf);
        }
        other => {
            // NOTE(review): assumes `Value`'s Debug output distinguishes unequal values
            // of these variants; switch to a dedicated encoding if one becomes available.
            buf.push(5);
            length_prefixed(format!("{other:?}").as_bytes(), buf);
        }
    }
}
/// Numeric view of `value` used by SUM/AVG: floats as-is, integers widened,
/// and anything else treated as `0.0`.
fn value_to_f64(value: &Value) -> f64 {
    match *value {
        Value::Float(f) => f,
        Value::Int(i) => i as f64,
        _ => 0.0,
    }
}
/// Three-way comparison used by MIN/MAX: returns -1, 0 or 1 for `a < b`, `a == b`,
/// `a > b`.
///
/// Int and Float values compare numerically with each other; previously a mixed pair
/// compared as equal, so MIN/MAX kept whichever arrived first. Incomparable pairs
/// (NaN, or mismatched non-numeric types) return 0, so the existing extremum is kept.
fn compare_values(a: &Value, b: &Value) -> i32 {
    use std::cmp::Ordering;

    let ordering = match (a, b) {
        (Value::Int(x), Value::Int(y)) => Some(x.cmp(y)),
        (Value::Float(x), Value::Float(y)) => x.partial_cmp(y),
        // Widening to f64 loses precision only beyond 2^53, which is acceptable here.
        (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
        (Value::String(x), Value::String(y)) => Some(x.cmp(y)),
        (Value::Bool(x), Value::Bool(y)) => Some(x.cmp(y)),
        _ => None,
    };
    match ordering {
        Some(Ordering::Less) => -1,
        Some(Ordering::Greater) => 1,
        Some(Ordering::Equal) | None => 0,
    }
}
/// Derives an output column name for `expr`.
///
/// Aggregates use their function's `Display` form. Aliases and column references
/// reuse their names. Anything else becomes `col_{index}`.
fn expr_to_name(expr: &LogicalExpr, index: usize) -> String {
    match expr {
        LogicalExpr::AggregateFunction { func, .. } => func.to_string(),
        LogicalExpr::Alias { alias, .. } => alias.to_owned(),
        LogicalExpr::Column { name, .. } => name.to_owned(),
        _ => format!("col_{index}"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exec::operators::values::ValuesOp;

    /// `(dept, salary)` fixture: dept "A" has 100, 150, 125 and dept "B" has 200, 180.
    fn make_input() -> BoxedOperator {
        let columns = vec!["dept".to_string(), "salary".to_string()];
        let rows = vec![
            vec![Value::from("A"), Value::Int(100)],
            vec![Value::from("A"), Value::Int(150)],
            vec![Value::from("B"), Value::Int(200)],
            vec![Value::from("A"), Value::Int(125)],
            vec![Value::from("B"), Value::Int(180)],
        ];
        Box::new(ValuesOp::with_columns(columns, rows))
    }

    /// Builds a one-column input named `column` with one row per entry of `values`.
    fn single_column(column: &str, values: Vec<Value>) -> BoxedOperator {
        let rows: Vec<Vec<Value>> = values.into_iter().map(|v| vec![v]).collect();
        Box::new(ValuesOp::with_columns(vec![column.to_string()], rows))
    }

    /// Opens `op`, pulls rows until it reports exhaustion, then closes it.
    fn drain(op: &mut HashAggregateOp) -> Vec<Row> {
        let ctx = ExecutionContext::new();
        op.open(&ctx).unwrap();
        let mut out = Vec::new();
        while let Some(row) = op.next().unwrap() {
            out.push(row);
        }
        op.close().unwrap();
        out
    }

    #[test]
    fn hash_aggregate_count() {
        let mut op = HashAggregateOp::new(
            vec![LogicalExpr::column("dept")],
            vec![LogicalExpr::count(LogicalExpr::wildcard(), false)],
            None,
            make_input(),
        );
        let rows = drain(&mut op);
        assert_eq!(rows.len(), 2);
        for row in &rows {
            let dept = row.get(0).unwrap();
            let expected = if *dept == Value::from("A") {
                Value::Int(3)
            } else if *dept == Value::from("B") {
                Value::Int(2)
            } else {
                continue;
            };
            assert_eq!(row.get(1).unwrap(), &expected);
        }
    }

    #[test]
    fn hash_aggregate_sum() {
        let mut op = HashAggregateOp::new(
            vec![LogicalExpr::column("dept")],
            vec![LogicalExpr::sum(LogicalExpr::column("salary"), false)],
            None,
            make_input(),
        );
        let rows = drain(&mut op);
        assert_eq!(rows.len(), 2);
        for row in &rows {
            let dept = row.get(0).unwrap();
            let expected = if *dept == Value::from("A") {
                Value::Float(375.0)
            } else if *dept == Value::from("B") {
                Value::Float(380.0)
            } else {
                continue;
            };
            assert_eq!(row.get(1).unwrap(), &expected);
        }
    }

    #[test]
    fn hash_aggregate_min_max() {
        let mut op = HashAggregateOp::new(
            vec![LogicalExpr::column("dept")],
            vec![
                LogicalExpr::min(LogicalExpr::column("salary")),
                LogicalExpr::max(LogicalExpr::column("salary")),
            ],
            None,
            make_input(),
        );
        let rows = drain(&mut op);
        let a_rows: Vec<&Row> = rows
            .iter()
            .filter(|row| row.get(0) == Some(&Value::from("A")))
            .collect();
        assert!(!a_rows.is_empty());
        for row in a_rows {
            assert_eq!(row.get(1), Some(&Value::Int(100)));
            assert_eq!(row.get(2), Some(&Value::Int(150)));
        }
    }

    #[test]
    fn hash_aggregate_no_groups() {
        let input = single_column("n", vec![Value::Int(1), Value::Int(2), Value::Int(3)]);
        let mut op = HashAggregateOp::new(
            vec![],
            vec![
                LogicalExpr::count(LogicalExpr::wildcard(), false),
                LogicalExpr::sum(LogicalExpr::column("n"), false),
            ],
            None,
            input,
        );
        let rows = drain(&mut op);
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get(0), Some(&Value::Int(3)));
        assert_eq!(rows[0].get(1), Some(&Value::Float(6.0)));
    }

    #[test]
    fn hash_aggregate_avg_with_nulls() {
        let input = single_column("val", vec![Value::Int(10), Value::Null, Value::Int(20)]);
        let mut op = HashAggregateOp::new(
            vec![],
            vec![LogicalExpr::avg(LogicalExpr::column("val"), false)],
            None,
            input,
        );
        let rows = drain(&mut op);
        assert_eq!(rows.len(), 1);
        // The NULL is ignored: (10 + 20) / 2.
        assert_eq!(rows[0].get(0), Some(&Value::Float(15.0)));
    }

    #[test]
    fn hash_aggregate_avg_vs_sum() {
        let input = single_column("n", vec![Value::Int(10), Value::Int(20), Value::Int(30)]);
        let mut op = HashAggregateOp::new(
            vec![],
            vec![
                LogicalExpr::sum(LogicalExpr::column("n"), false),
                LogicalExpr::avg(LogicalExpr::column("n"), false),
            ],
            None,
            input,
        );
        let rows = drain(&mut op);
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].get(0), Some(&Value::Float(60.0)));
        assert_eq!(rows[0].get(1), Some(&Value::Float(20.0)));
    }
}