use super::eval::{compare_values, eval_expr};
use super::node::PlanNode;
use super::types::{ColumnMeta, Row, Schema};
use crate::soch_ql::SochValue;
use crate::sql::ast::Expr;
use sochdb_core::Result;
use std::collections::HashMap;
/// Aggregate functions that the hash aggregate operator can compute.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AggFunc {
/// Row count when no argument expression is given, otherwise the number
/// of non-null argument values.
Count,
/// Number of distinct non-null argument values.
CountDistinct,
/// Sum of the `Int`, `UInt` and `Float` argument values.
Sum,
/// Arithmetic mean of the argument values, produced as a `Float`.
Avg,
/// Smallest non-null argument value according to `compare_values`.
Min,
/// Largest non-null argument value according to `compare_values`.
Max,
}
/// One aggregate to compute: the function, its argument and its output name.
#[derive(Debug, Clone)]
pub struct AggDef {
/// Aggregate function to apply.
pub func: AggFunc,
/// Argument expression evaluated per input row. `None` combined with
/// `AggFunc::Count` is treated as a plain row count.
pub expr: Option<Expr>,
/// Name of the output column that holds the aggregate result.
pub alias: String,
}
/// Running state of a single aggregate function within one group.
///
/// All fields are always present; which ones are meaningful depends on `func`.
struct Accumulator {
// Function this accumulator computes.
func: AggFunc,
// Number of contributions recorded so far (used by Count, Sum and Avg).
count: u64,
// Integer running total used by Sum while the total is still integral.
sum_int: i64,
// Floating-point running total (Avg always, Sum once `is_float` is set).
sum_float: f64,
// True once the Sum total is tracked in `sum_float` instead of `sum_int`.
is_float: bool,
// Smallest value seen so far (Min only).
min_val: Option<SochValue>,
// Largest value seen so far (Max only).
max_val: Option<SochValue>,
// Distinct values seen so far; `Some` only for CountDistinct.
distinct_set: Option<Vec<SochValue>>,
}
impl Accumulator {
    /// Creates an empty accumulator for `func`.
    ///
    /// The distinct-value buffer is only allocated for `CountDistinct`.
    fn new(func: &AggFunc) -> Self {
        Self {
            func: func.clone(),
            count: 0,
            sum_int: 0,
            sum_float: 0.0,
            is_float: false,
            min_val: None,
            max_val: None,
            distinct_set: if matches!(func, AggFunc::CountDistinct) {
                Some(Vec::new())
            } else {
                None
            },
        }
    }

    /// Moves the integer running total into `sum_float` (once) so that all
    /// further SUM arithmetic happens in floating point.
    fn promote_to_float(&mut self) {
        if !self.is_float {
            self.sum_float = self.sum_int as f64;
            self.is_float = true;
        }
    }

    /// Adds a signed integer to the SUM total.
    ///
    /// Integer arithmetic is used for as long as it is exact; if the addition
    /// would overflow `i64` the total is promoted to floating point instead of
    /// panicking (debug builds) or wrapping around (release builds).
    fn add_int(&mut self, i: i64) {
        if self.is_float {
            self.sum_float += i as f64;
            return;
        }
        match self.sum_int.checked_add(i) {
            Some(total) => self.sum_int = total,
            None => {
                self.promote_to_float();
                self.sum_float += i as f64;
            }
        }
    }

    /// Adds an unsigned integer to the SUM total.
    ///
    /// Values above `i64::MAX` cannot be represented in the integer total, so
    /// they force a promotion to floating point rather than being reinterpreted
    /// as negative numbers by an `as i64` cast.
    fn add_uint(&mut self, u: u64) {
        if u <= i64::MAX as u64 {
            self.add_int(u as i64);
        } else {
            self.promote_to_float();
            self.sum_float += u as f64;
        }
    }

    /// Feeds one argument value into the aggregate.
    ///
    /// `Null` is ignored by every function. For `Sum` and `Avg` only `Int`,
    /// `UInt` and `Float` values contribute; other value kinds are skipped and
    /// are not counted, so they neither dilute an average nor turn an
    /// otherwise empty sum into `0`.
    fn accumulate(&mut self, val: &SochValue) {
        if matches!(val, SochValue::Null) {
            return;
        }
        match self.func {
            AggFunc::Count => {
                self.count += 1;
            }
            AggFunc::CountDistinct => {
                if let Some(set) = &mut self.distinct_set {
                    // Linear scan: values are compared with `compare_values`
                    // because this block cannot assume `SochValue` is hashable.
                    let already = set
                        .iter()
                        .any(|v| compare_values(v, val) == Some(std::cmp::Ordering::Equal));
                    if !already {
                        set.push(val.clone());
                    }
                }
            }
            AggFunc::Sum => {
                let contributed = match val {
                    SochValue::Int(i) => {
                        self.add_int(*i);
                        true
                    }
                    SochValue::UInt(u) => {
                        self.add_uint(*u);
                        true
                    }
                    SochValue::Float(f) => {
                        self.promote_to_float();
                        self.sum_float += f;
                        true
                    }
                    _ => false,
                };
                if contributed {
                    self.count += 1;
                }
            }
            AggFunc::Avg => {
                let addend = match val {
                    SochValue::Int(i) => Some(*i as f64),
                    SochValue::UInt(u) => Some(*u as f64),
                    SochValue::Float(f) => Some(*f),
                    _ => None,
                };
                if let Some(x) = addend {
                    self.sum_float += x;
                    self.count += 1;
                }
            }
            AggFunc::Min => {
                let update = match &self.min_val {
                    None => true,
                    Some(current) => {
                        compare_values(val, current) == Some(std::cmp::Ordering::Less)
                    }
                };
                if update {
                    self.min_val = Some(val.clone());
                }
            }
            AggFunc::Max => {
                let update = match &self.max_val {
                    None => true,
                    Some(current) => {
                        compare_values(val, current) == Some(std::cmp::Ordering::Greater)
                    }
                };
                if update {
                    self.max_val = Some(val.clone());
                }
            }
        }
    }

    /// Produces the aggregate result for everything accumulated so far.
    ///
    /// `Count` and `CountDistinct` yield an `Int` (zero when nothing was
    /// seen). `Sum`, `Avg`, `Min` and `Max` yield `Null` when no value
    /// contributed. `Sum` is an `Int` unless the total was promoted to
    /// floating point; `Avg` is always a `Float`.
    fn finalize(&self) -> SochValue {
        match self.func {
            AggFunc::Count => SochValue::Int(self.count as i64),
            AggFunc::CountDistinct => {
                SochValue::Int(self.distinct_set.as_ref().map_or(0, |s| s.len()) as i64)
            }
            AggFunc::Sum => {
                if self.count == 0 {
                    SochValue::Null
                } else if self.is_float {
                    SochValue::Float(self.sum_float)
                } else {
                    SochValue::Int(self.sum_int)
                }
            }
            AggFunc::Avg => {
                if self.count == 0 {
                    SochValue::Null
                } else {
                    SochValue::Float(self.sum_float / self.count as f64)
                }
            }
            AggFunc::Min => self.min_val.clone().unwrap_or(SochValue::Null),
            AggFunc::Max => self.max_val.clone().unwrap_or(SochValue::Null),
        }
    }
}
/// Hashable key identifying one group: one `GroupVal` per GROUP BY expression.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GroupKey(Vec<GroupVal>);
/// Hashable, equality-comparable projection of a `SochValue` used only for
/// building group keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum GroupVal {
// Null key; all null values fall into the same group.
Null,
Bool(bool),
Int(i64),
UInt(u64),
Text(String),
// Any other value kind, keyed by its `Debug` rendering.
Other(String),
}
impl From<&SochValue> for GroupVal {
    /// Projects a value onto its hashable grouping representation.
    ///
    /// Null, booleans, signed and unsigned integers and text map onto their
    /// dedicated variants; every other kind of value is keyed by its `Debug`
    /// output.
    fn from(value: &SochValue) -> Self {
        match value {
            SochValue::Text(text) => Self::Text(text.clone()),
            SochValue::UInt(n) => Self::UInt(*n),
            SochValue::Int(n) => Self::Int(*n),
            SochValue::Bool(flag) => Self::Bool(*flag),
            SochValue::Null => Self::Null,
            fallback => Self::Other(format!("{:?}", fallback)),
        }
    }
}
/// Everything tracked for one group while the input is being consumed.
struct GroupState {
// Evaluated GROUP BY values; they become the leading columns of the output row.
key_values: Vec<SochValue>,
// One accumulator per aggregate definition, in definition order.
accumulators: Vec<Accumulator>,
}
/// Plan node that groups its input rows and computes aggregates per group.
///
/// The whole input is consumed and the result rows are built on the first
/// call to `next`; afterwards the rows are handed out one at a time.
pub struct HashAggregateNode {
// Child node supplying the rows to aggregate.
input: Box<dyn PlanNode>,
// GROUP BY expressions; empty for a global aggregate.
group_by_exprs: Vec<Expr>,
// Aggregates to compute for every group.
agg_defs: Vec<AggDef>,
// Group-by columns followed by one column per aggregate alias.
output_schema: Schema,
// Materialized result rows; `None` until the input has been consumed.
groups: Option<Vec<Row>>,
// Index of the next row in `groups` to return.
pos: usize,
// True when there are no GROUP BY expressions.
is_global: bool,
}
impl HashAggregateNode {
    /// Builds an aggregate node over `input`.
    ///
    /// The output schema consists of one column per GROUP BY expression
    /// (named after the column for plain column references, otherwise after
    /// the expression's `Debug` rendering) followed by one column per
    /// aggregate, named by its alias. With no GROUP BY expressions the node
    /// acts as a global aggregate.
    pub fn new(input: Box<dyn PlanNode>, group_by_exprs: Vec<Expr>, agg_defs: Vec<AggDef>) -> Self {
        let is_global = group_by_exprs.is_empty();
        let mut cols: Vec<ColumnMeta> = group_by_exprs
            .iter()
            .map(|e| {
                let name = match e {
                    Expr::Column(c) => c.column.clone(),
                    _ => format!("{:?}", e),
                };
                ColumnMeta::new(name)
            })
            .collect();
        cols.extend(agg_defs.iter().map(|ad| ColumnMeta::new(ad.alias.clone())));
        let output_schema = Schema::new(cols);
        Self {
            input,
            group_by_exprs,
            agg_defs,
            output_schema,
            groups: None,
            pos: 0,
            is_global,
        }
    }

    /// Consumes the whole input and builds the result rows, once.
    ///
    /// Groups are emitted in the order in which their key was first seen.
    /// A global aggregate over an empty input still yields exactly one row
    /// holding the "empty" value of every aggregate.
    ///
    /// # Errors
    ///
    /// Returns any error raised by the input node or by evaluating a GROUP BY
    /// or aggregate argument expression. GROUP BY failures are propagated
    /// rather than being folded into the NULL group, matching how aggregate
    /// argument failures are handled.
    fn materialize(&mut self) -> Result<()> {
        if self.groups.is_some() {
            return Ok(());
        }
        let input_schema = self.input.schema().clone();

        // `group_index` maps a key to its slot in `states`; `states` itself
        // preserves first-seen order, so no separate ordering list (and no
        // per-row key clone) is needed.
        let mut group_index: HashMap<GroupKey, usize> = HashMap::new();
        let mut states: Vec<GroupState> = Vec::new();

        // COUNT with no argument counts rows and bypasses NULL filtering.
        let is_count_star: Vec<bool> = self
            .agg_defs
            .iter()
            .map(|ad| matches!(ad.func, AggFunc::Count) && ad.expr.is_none())
            .collect();

        while let Some(row) = self.input.next()? {
            let mut key_values: Vec<SochValue> = Vec::with_capacity(self.group_by_exprs.len());
            for expr in &self.group_by_exprs {
                key_values.push(eval_expr(expr, &row, &input_schema)?);
            }
            let group_key = GroupKey(key_values.iter().map(GroupVal::from).collect());

            let slot = *group_index.entry(group_key).or_insert_with(|| {
                states.push(GroupState {
                    key_values,
                    accumulators: self
                        .agg_defs
                        .iter()
                        .map(|ad| Accumulator::new(&ad.func))
                        .collect(),
                });
                states.len() - 1
            });
            let state = &mut states[slot];

            for ((acc, ad), &count_star) in state
                .accumulators
                .iter_mut()
                .zip(&self.agg_defs)
                .zip(&is_count_star)
            {
                if count_star {
                    acc.count += 1;
                } else if let Some(expr) = &ad.expr {
                    let val = eval_expr(expr, &row, &input_schema)?;
                    acc.accumulate(&val);
                }
            }
        }

        if self.is_global && states.is_empty() {
            let row: Row = self
                .agg_defs
                .iter()
                .map(|ad| Accumulator::new(&ad.func).finalize())
                .collect();
            self.groups = Some(vec![row]);
            return Ok(());
        }

        let result: Vec<Row> = states
            .into_iter()
            .map(|state| {
                // Key columns first, then one column per aggregate.
                let mut row: Row = state.key_values;
                row.extend(state.accumulators.iter().map(Accumulator::finalize));
                row
            })
            .collect();
        self.groups = Some(result);
        Ok(())
    }
}
impl PlanNode for HashAggregateNode {
    /// Schema of the rows this node produces.
    fn schema(&self) -> &Schema {
        &self.output_schema
    }

    /// Returns the next aggregated row, or `None` once all rows were handed out.
    ///
    /// The first call triggers materialization of the complete result.
    fn next(&mut self) -> Result<Option<Row>> {
        self.materialize()?;
        let row = self
            .groups
            .as_ref()
            .and_then(|rows| rows.get(self.pos))
            .cloned();
        if row.is_some() {
            self.pos += 1;
        }
        Ok(row)
    }

    /// Discards the materialized result, rewinds the cursor and resets the input.
    fn reset(&mut self) -> Result<()> {
        self.pos = 0;
        self.groups = None;
        self.input.reset()
    }
}