use indexmap::IndexMap;
use std::collections::HashSet;
use std::sync::Arc;
use arcstr::ArcStr;
use grafeo_common::types::{LogicalType, PropertyKey, Value};
use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
use super::{Operator, OperatorError, OperatorResult};
use crate::execution::DataChunk;
use crate::execution::chunk::DataChunkBuilder;
/// Running accumulator for a single aggregate function instance.
///
/// Each variant carries exactly the state its function needs; the
/// `*Distinct` variants additionally keep a `HashSet<HashableValue>` of
/// values already seen so duplicates are folded in only once.
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum AggregateState {
    /// COUNT: number of rows (or non-null values) seen.
    Count(i64),
    /// COUNT(DISTINCT x): count plus the set of values already counted.
    CountDistinct(i64, HashSet<HashableValue>),
    /// SUM over integers only: (sum, input count).
    SumInt(i64, i64),
    SumIntDistinct(i64, i64, HashSet<HashableValue>),
    /// SUM promoted to float: (sum, Kahan compensation term, input count).
    SumFloat(f64, f64, i64),
    SumFloatDistinct(f64, f64, i64, HashSet<HashableValue>),
    /// AVG: (running sum, input count).
    Avg(f64, i64),
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    Min(Option<Value>),
    Max(Option<Value>),
    /// First non-null value encountered.
    First(Option<Value>),
    /// Most recent non-null value encountered.
    Last(Option<Value>),
    /// COLLECT: all non-null values in arrival order.
    Collect(Vec<Value>),
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Welford accumulators: `m2` is the sum of squared deviations.
    StdDev { count: i64, mean: f64, m2: f64 },
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Buffers every numeric input; the percentile is resolved in `finalize`.
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    PercentileCont { values: Vec<f64>, percentile: f64 },
    /// (collected strings, separator).
    GroupConcat(Vec<String>, String),
    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
    /// An arbitrary (here: first non-null) value from the group.
    Sample(Option<Value>),
    Variance { count: i64, mean: f64, m2: f64 },
    VariancePop { count: i64, mean: f64, m2: f64 },
    /// Shared accumulator for all two-column (y, x) statistics
    /// (covariance, correlation, linear-regression family). `kind`
    /// remembers which function to evaluate in `finalize`.
    Bivariate {
        kind: AggregateFunction,
        count: i64,
        mean_x: f64,
        mean_y: f64,
        m2_x: f64,
        m2_y: f64,
        c_xy: f64,
    },
    /// A pre-computed result; `update` ignores further input.
    Frozen(Value),
}
impl AggregateState {
    /// Builds the initial accumulator for `function`.
    ///
    /// * `distinct` — selects the duplicate-tracking variant where one exists.
    /// * `percentile` — only used by the percentile functions; defaults to 0.5 (median).
    /// * `separator` — only used by GROUP_CONCAT; defaults to a single space.
    pub fn new(
        function: AggregateFunction,
        distinct: bool,
        percentile: Option<f64>,
        separator: Option<&str>,
    ) -> Self {
        match (function, distinct) {
            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
                AggregateState::Count(0)
            }
            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
                AggregateState::CountDistinct(0, HashSet::new())
            }
            // SUM starts in integer mode and is promoted to float mode by
            // `update` on the first non-integer input.
            (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
            // DISTINCT cannot change MIN/MAX/FIRST/LAST, so the flag is ignored.
            (AggregateFunction::Min, _) => AggregateState::Min(None),
            (AggregateFunction::Max, _) => AggregateState::Max(None),
            (AggregateFunction::First, _) => AggregateState::First(None),
            (AggregateFunction::Last, _) => AggregateState::Last(None),
            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
            (AggregateFunction::Collect, true) => {
                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
            }
            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
                count: 0,
                mean: 0.0,
                m2: 0.0,
            },
            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
                count: 0,
                mean: 0.0,
                m2: 0.0,
            },
            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
                values: Vec::new(),
                percentile: percentile.unwrap_or(0.5),
            },
            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
                values: Vec::new(),
                percentile: percentile.unwrap_or(0.5),
            },
            (AggregateFunction::GroupConcat, false) => {
                AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
            }
            (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
                Vec::new(),
                separator.unwrap_or(" ").to_string(),
                HashSet::new(),
            ),
            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
            // Every two-column statistic shares the Bivariate accumulator;
            // the concrete function is applied in `finalize`.
            (
                AggregateFunction::CovarSamp
                | AggregateFunction::CovarPop
                | AggregateFunction::Corr
                | AggregateFunction::RegrSlope
                | AggregateFunction::RegrIntercept
                | AggregateFunction::RegrR2
                | AggregateFunction::RegrCount
                | AggregateFunction::RegrSxx
                | AggregateFunction::RegrSyy
                | AggregateFunction::RegrSxy
                | AggregateFunction::RegrAvgx
                | AggregateFunction::RegrAvgy,
                _,
            ) => AggregateState::Bivariate {
                kind: function,
                count: 0,
                mean_x: 0.0,
                mean_y: 0.0,
                m2_x: 0.0,
                m2_y: 0.0,
                c_xy: 0.0,
            },
            (AggregateFunction::Variance, _) => AggregateState::Variance {
                count: 0,
                mean: 0.0,
                m2: 0.0,
            },
            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
                count: 0,
                mean: 0.0,
                m2: 0.0,
            },
        }
    }
    /// Folds one input value into the accumulator.
    ///
    /// Callers pass `None` only for plain COUNT (row counting); for every
    /// other aggregate they filter out missing/NULL values first. Inputs
    /// that cannot be coerced to the needed numeric type are silently
    /// skipped. `Bivariate` states are fed via [`Self::update_bivariate`]
    /// and ignore this method; `Frozen` states never change.
    pub fn update(&mut self, value: Option<Value>) {
        match self {
            AggregateState::Count(count) => {
                *count += 1;
            }
            AggregateState::CountDistinct(count, seen) => {
                if let Some(ref v) = value {
                    let hashable = HashableValue::from(v);
                    if seen.insert(hashable) {
                        *count += 1;
                    }
                }
            }
            AggregateState::SumInt(sum, count) => {
                // Binding `v: i64` copies, so `value` stays usable below.
                if let Some(Value::Int64(v)) = value {
                    *sum += v;
                    *count += 1;
                } else if let Some(Value::Float64(v)) = value {
                    // First non-integer input: promote to float summation.
                    *self = AggregateState::SumFloat(*sum as f64 + v, 0.0, *count + 1);
                } else if let Some(ref v) = value {
                    if let Some(num) = value_to_f64(v) {
                        // Any other numeric-coercible value also promotes.
                        *self = AggregateState::SumFloat(*sum as f64 + num, 0.0, *count + 1);
                    }
                }
            }
            AggregateState::SumIntDistinct(sum, count, seen) => {
                if let Some(ref v) = value {
                    let hashable = HashableValue::from(v);
                    if seen.insert(hashable) {
                        if let Value::Int64(i) = v {
                            *sum += i;
                            *count += 1;
                        } else if let Value::Float64(f) = v {
                            // Promote to float mode, carrying the seen-set over.
                            let moved_seen = std::mem::take(seen);
                            *self = AggregateState::SumFloatDistinct(
                                *sum as f64 + f,
                                0.0,
                                *count + 1,
                                moved_seen,
                            );
                        } else if let Some(num) = value_to_f64(v) {
                            let moved_seen = std::mem::take(seen);
                            *self = AggregateState::SumFloatDistinct(
                                *sum as f64 + num,
                                0.0,
                                *count + 1,
                                moved_seen,
                            );
                        }
                    }
                }
            }
            AggregateState::SumFloat(sum, comp, count) => {
                if let Some(ref v) = value {
                    if let Some(num) = value_to_f64(v) {
                        // Kahan compensated summation: `comp` carries the
                        // low-order bits lost by the previous addition.
                        let y = num - *comp;
                        let t = *sum + y;
                        *comp = (t - *sum) - y;
                        *sum = t;
                        *count += 1;
                    }
                }
            }
            AggregateState::SumFloatDistinct(sum, comp, count, seen) => {
                if let Some(ref v) = value {
                    let hashable = HashableValue::from(v);
                    if seen.insert(hashable)
                        && let Some(num) = value_to_f64(v)
                    {
                        // Kahan compensated summation (see SumFloat).
                        let y = num - *comp;
                        let t = *sum + y;
                        *comp = (t - *sum) - y;
                        *sum = t;
                        *count += 1;
                    }
                }
            }
            AggregateState::Avg(sum, count) => {
                if let Some(ref v) = value
                    && let Some(num) = value_to_f64(v)
                {
                    *sum += num;
                    *count += 1;
                }
            }
            AggregateState::AvgDistinct(sum, count, seen) => {
                if let Some(ref v) = value {
                    let hashable = HashableValue::from(v);
                    if seen.insert(hashable)
                        && let Some(num) = value_to_f64(v)
                    {
                        *sum += num;
                        *count += 1;
                    }
                }
            }
            AggregateState::Min(min) => {
                if let Some(v) = value {
                    match min {
                        None => *min = Some(v),
                        Some(current) => {
                            // Incomparable values (compare_values == None)
                            // keep the current minimum.
                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
                                *min = Some(v);
                            }
                        }
                    }
                }
            }
            AggregateState::Max(max) => {
                if let Some(v) = value {
                    match max {
                        None => *max = Some(v),
                        Some(current) => {
                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
                                *max = Some(v);
                            }
                        }
                    }
                }
            }
            AggregateState::First(first) => {
                // Only the first value sticks.
                if first.is_none() {
                    *first = value;
                }
            }
            AggregateState::Last(last) => {
                // Every non-missing value overwrites the previous one.
                if value.is_some() {
                    *last = value;
                }
            }
            AggregateState::Collect(list) => {
                if let Some(v) = value {
                    list.push(v);
                }
            }
            AggregateState::CollectDistinct(list, seen) => {
                if let Some(v) = value {
                    let hashable = HashableValue::from(&v);
                    if seen.insert(hashable) {
                        list.push(v);
                    }
                }
            }
            // Welford's online algorithm; all four variance-family states
            // share the same accumulation and differ only in `finalize`.
            AggregateState::StdDev { count, mean, m2 }
            | AggregateState::StdDevPop { count, mean, m2 }
            | AggregateState::Variance { count, mean, m2 }
            | AggregateState::VariancePop { count, mean, m2 } => {
                if let Some(ref v) = value
                    && let Some(x) = value_to_f64(v)
                {
                    *count += 1;
                    let delta = x - *mean;
                    *mean += delta / *count as f64;
                    let delta2 = x - *mean;
                    *m2 += delta * delta2;
                }
            }
            // Percentiles must buffer all inputs; resolved in `finalize`.
            AggregateState::PercentileDisc { values, .. }
            | AggregateState::PercentileCont { values, .. } => {
                if let Some(ref v) = value
                    && let Some(x) = value_to_f64(v)
                {
                    values.push(x);
                }
            }
            AggregateState::GroupConcat(list, _sep) => {
                if let Some(v) = value {
                    list.push(agg_value_to_string(&v));
                }
            }
            AggregateState::GroupConcatDistinct(list, _sep, seen) => {
                if let Some(v) = value {
                    let hashable = HashableValue::from(&v);
                    if seen.insert(hashable) {
                        list.push(agg_value_to_string(&v));
                    }
                }
            }
            AggregateState::Sample(sample) => {
                // Deterministic "sample": keep the first value seen.
                if sample.is_none() {
                    *sample = value;
                }
            }
            AggregateState::Bivariate { .. } => {
                // Two-column aggregates are fed via `update_bivariate`.
            }
            AggregateState::Frozen(_) => {}
        }
    }
    /// Folds one `(y, x)` pair into a `Bivariate` accumulator.
    ///
    /// Uses a Welford-style simultaneous update of both means, the second
    /// moments `m2_x`/`m2_y` and the co-moment `c_xy`. Pairs where either
    /// side is missing or non-numeric are skipped; calling this on a
    /// non-`Bivariate` state is a no-op.
    pub fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
        if let AggregateState::Bivariate {
            count,
            mean_x,
            mean_y,
            m2_x,
            m2_y,
            c_xy,
            ..
        } = self
        {
            if let (Some(y), Some(x)) = (&y_val, &x_val)
                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
            {
                *count += 1;
                let n = *count as f64;
                let dx = x_f - *mean_x;
                let dy = y_f - *mean_y;
                *mean_x += dx / n;
                *mean_y += dy / n;
                // Note: co-moment uses the pre-update dx and post-update dy.
                let dx2 = x_f - *mean_x;
                let dy2 = y_f - *mean_y;
                *m2_x += dx * dx2;
                *m2_y += dy * dy2;
                *c_xy += dx * dy2;
            }
        }
    }
    /// Produces the final aggregate value.
    ///
    /// Returns `Value::Null` whenever the aggregate saw no input (or too
    /// few inputs for a sample statistic to be defined); COUNT instead
    /// yields 0 for empty input.
    pub fn finalize(&self) -> Value {
        match self {
            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
                Value::Int64(*count)
            }
            AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
                if *count == 0 {
                    Value::Null
                } else {
                    Value::Int64(*sum)
                }
            }
            AggregateState::SumFloat(sum, _, count)
            | AggregateState::SumFloatDistinct(sum, _, count, _) => {
                if *count == 0 {
                    Value::Null
                } else {
                    Value::Float64(*sum)
                }
            }
            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
                if *count == 0 {
                    Value::Null
                } else {
                    Value::Float64(*sum / *count as f64)
                }
            }
            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
                Value::List(list.clone().into())
            }
            // Sample standard deviation: sqrt(m2 / (n - 1)); needs n >= 2.
            AggregateState::StdDev { count, m2, .. } => {
                if *count < 2 {
                    Value::Null
                } else {
                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
                }
            }
            // Population standard deviation: sqrt(m2 / n).
            AggregateState::StdDevPop { count, m2, .. } => {
                if *count == 0 {
                    Value::Null
                } else {
                    Value::Float64((*m2 / *count as f64).sqrt())
                }
            }
            AggregateState::Variance { count, m2, .. } => {
                if *count < 2 {
                    Value::Null
                } else {
                    Value::Float64(*m2 / (*count - 1) as f64)
                }
            }
            AggregateState::VariancePop { count, m2, .. } => {
                if *count == 0 {
                    Value::Null
                } else {
                    Value::Float64(*m2 / *count as f64)
                }
            }
            // Discrete percentile: lower nearest-rank over the sorted values.
            AggregateState::PercentileDisc { values, percentile } => {
                if values.is_empty() {
                    Value::Null
                } else {
                    let mut sorted = values.clone();
                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
                    Value::Float64(sorted[index])
                }
            }
            // Continuous percentile: linear interpolation between the two
            // neighbouring sorted values.
            AggregateState::PercentileCont { values, percentile } => {
                if values.is_empty() {
                    Value::Null
                } else {
                    let mut sorted = values.clone();
                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
                    let rank = percentile * (sorted.len() - 1) as f64;
                    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
                    let lower_idx = rank.floor() as usize;
                    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
                    let upper_idx = rank.ceil() as usize;
                    if lower_idx == upper_idx {
                        Value::Float64(sorted[lower_idx])
                    } else {
                        let fraction = rank - lower_idx as f64;
                        let result =
                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
                        Value::Float64(result)
                    }
                }
            }
            AggregateState::GroupConcat(list, sep)
            | AggregateState::GroupConcatDistinct(list, sep, _) => {
                Value::String(list.join(sep).into())
            }
            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
            AggregateState::Frozen(val) => val.clone(),
            // In the formulas below: m2_x = Sxx, m2_y = Syy, c_xy = Sxy
            // (sums of centred squares / cross-products).
            AggregateState::Bivariate {
                kind,
                count,
                mean_x,
                mean_y,
                m2_x,
                m2_y,
                c_xy,
            } => {
                let n = *count;
                match kind {
                    // Sample covariance: Sxy / (n - 1).
                    AggregateFunction::CovarSamp => {
                        if n < 2 {
                            Value::Null
                        } else {
                            Value::Float64(*c_xy / (n - 1) as f64)
                        }
                    }
                    // Population covariance: Sxy / n.
                    AggregateFunction::CovarPop => {
                        if n == 0 {
                            Value::Null
                        } else {
                            Value::Float64(*c_xy / n as f64)
                        }
                    }
                    // Pearson correlation: Sxy / sqrt(Sxx * Syy);
                    // undefined when either variable is constant.
                    AggregateFunction::Corr => {
                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
                            Value::Null
                        } else {
                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
                        }
                    }
                    AggregateFunction::RegrSlope => {
                        if n == 0 || *m2_x == 0.0 {
                            Value::Null
                        } else {
                            Value::Float64(*c_xy / *m2_x)
                        }
                    }
                    AggregateFunction::RegrIntercept => {
                        if n == 0 || *m2_x == 0.0 {
                            Value::Null
                        } else {
                            let slope = *c_xy / *m2_x;
                            Value::Float64(*mean_y - slope * *mean_x)
                        }
                    }
                    // Coefficient of determination: Sxy^2 / (Sxx * Syy).
                    AggregateFunction::RegrR2 => {
                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
                            Value::Null
                        } else {
                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
                        }
                    }
                    AggregateFunction::RegrCount => Value::Int64(n),
                    AggregateFunction::RegrSxx => {
                        if n == 0 {
                            Value::Null
                        } else {
                            Value::Float64(*m2_x)
                        }
                    }
                    AggregateFunction::RegrSyy => {
                        if n == 0 {
                            Value::Null
                        } else {
                            Value::Float64(*m2_y)
                        }
                    }
                    AggregateFunction::RegrSxy => {
                        if n == 0 {
                            Value::Null
                        } else {
                            Value::Float64(*c_xy)
                        }
                    }
                    AggregateFunction::RegrAvgx => {
                        if n == 0 {
                            Value::Null
                        } else {
                            Value::Float64(*mean_x)
                        }
                    }
                    AggregateFunction::RegrAvgy => {
                        if n == 0 {
                            Value::Null
                        } else {
                            Value::Float64(*mean_y)
                        }
                    }
                    // `new` never stores a non-bivariate kind here.
                    _ => Value::Null,
                }
            }
        }
    }
}
use super::value_utils::{compare_values, value_to_f64};
/// Renders a value as text for GROUP_CONCAT-style string aggregation.
///
/// Scalars use their natural textual form, `Null` becomes the empty
/// string, and anything else falls back to its `Debug` representation.
fn agg_value_to_string(val: &Value) -> String {
    if matches!(val, Value::Null) {
        return String::new();
    }
    match val {
        Value::String(s) => s.to_string(),
        Value::Int64(i) => i.to_string(),
        Value::Float64(f) => f.to_string(),
        Value::Bool(b) => b.to_string(),
        other => format!("{other:?}"),
    }
}
/// Hashable composite key identifying one group in a hash aggregation:
/// one `GroupKeyPart` per grouping column.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(Vec<GroupKeyPart>);
/// One component of a [`GroupKey`].
///
/// Mirrors `Value` but with `Eq + Hash` implementations so it can be used
/// as a hash-map key. Floats are keyed by their raw IEEE-754 bit pattern,
/// so a NaN groups with a bit-identical NaN and `-0.0` is a different key
/// from `0.0`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum GroupKeyPart {
    Null,
    Bool(bool),
    Int64(i64),
    /// Bit pattern of an `f64`; restored via `f64::from_bits` in `to_value`.
    ///
    /// BUG FIX: floats were previously stored as `Int64(f.to_bits())`, so
    /// `to_value` emitted the raw bit pattern as an integer instead of the
    /// original float in group-key output columns.
    Float64(u64),
    String(ArcStr),
    Bytes(Arc<[u8]>),
    Date(grafeo_common::types::Date),
    Time(grafeo_common::types::Time),
    Timestamp(grafeo_common::types::Timestamp),
    Duration(grafeo_common::types::Duration),
    ZonedDatetime(grafeo_common::types::ZonedDatetime),
    List(Vec<GroupKeyPart>),
    Map(Vec<(ArcStr, GroupKeyPart)>),
}
impl GroupKeyPart {
    /// Converts an owned `Value` into its hashable key form.
    ///
    /// Value kinds without a dedicated variant fall back to their `Debug`
    /// string, which keeps them groupable (if not round-trippable).
    fn from_value(v: Value) -> Self {
        match v {
            Value::Null => Self::Null,
            Value::Bool(b) => Self::Bool(b),
            Value::Int64(i) => Self::Int64(i),
            // Key floats by bit pattern so the part can be Eq + Hash.
            Value::Float64(f) => Self::Float64(f.to_bits()),
            Value::String(s) => Self::String(s),
            Value::Bytes(b) => Self::Bytes(b),
            Value::Date(d) => Self::Date(d),
            Value::Time(t) => Self::Time(t),
            Value::Timestamp(ts) => Self::Timestamp(ts),
            Value::Duration(d) => Self::Duration(d),
            Value::ZonedDatetime(zdt) => Self::ZonedDatetime(zdt),
            Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
            Value::Map(map) => {
                let entries: Vec<(ArcStr, GroupKeyPart)> = map
                    .iter()
                    .map(|(k, v)| (ArcStr::from(k.as_str()), Self::from_value(v.clone())))
                    .collect();
                Self::Map(entries)
            }
            other => Self::String(ArcStr::from(format!("{other:?}"))),
        }
    }
    /// Reconstructs the original `Value` for emitting group-key columns.
    fn to_value(&self) -> Value {
        match self {
            Self::Null => Value::Null,
            Self::Bool(b) => Value::Bool(*b),
            Self::Int64(i) => Value::Int64(*i),
            Self::Float64(bits) => Value::Float64(f64::from_bits(*bits)),
            Self::String(s) => Value::String(s.clone()),
            Self::Bytes(b) => Value::Bytes(Arc::clone(b)),
            Self::Date(d) => Value::Date(*d),
            Self::Time(t) => Value::Time(*t),
            Self::Timestamp(ts) => Value::Timestamp(*ts),
            Self::Duration(d) => Value::Duration(*d),
            Self::ZonedDatetime(zdt) => Value::ZonedDatetime(*zdt),
            Self::List(parts) => {
                let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
                Value::List(Arc::from(values.into_boxed_slice()))
            }
            Self::Map(entries) => {
                let map: std::collections::BTreeMap<PropertyKey, Value> = entries
                    .iter()
                    .map(|(k, v)| (PropertyKey::new(k.as_str()), v.to_value()))
                    .collect();
                Value::Map(Arc::new(map))
            }
        }
    }
}
impl GroupKey {
    /// Builds the key for `row` of `chunk` from the configured group columns.
    /// A missing column or missing cell contributes a `Null` part.
    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
        let mut parts = Vec::with_capacity(group_columns.len());
        for &col_idx in group_columns {
            let cell = chunk.column(col_idx).and_then(|col| col.get_value(row));
            parts.push(match cell {
                Some(value) => GroupKeyPart::from_value(value),
                None => GroupKeyPart::Null,
            });
        }
        GroupKey(parts)
    }
    /// Expands the key back into one `Value` per grouping column, in order.
    fn to_values(&self) -> Vec<Value> {
        self.0.iter().map(GroupKeyPart::to_value).collect()
    }
}
/// Blocking GROUP BY operator: hashes rows into groups, accumulates one
/// [`AggregateState`] per aggregate per group, then streams finalized rows.
pub struct HashAggregateOperator {
    child: Box<dyn Operator>,
    // Child column indices that form the group key.
    group_columns: Vec<usize>,
    aggregates: Vec<AggregateExpr>,
    // Group-key column types followed by one type per aggregate.
    output_schema: Vec<LogicalType>,
    // IndexMap preserves first-seen group order for deterministic output.
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    aggregation_complete: bool,
    // Drained groups, handed out chunk-by-chunk from `next`.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
impl HashAggregateOperator {
pub fn new(
child: Box<dyn Operator>,
group_columns: Vec<usize>,
aggregates: Vec<AggregateExpr>,
output_schema: Vec<LogicalType>,
) -> Self {
Self {
child,
group_columns,
aggregates,
output_schema,
groups: IndexMap::new(),
aggregation_complete: false,
results: None,
}
}
pub fn into_parts(self) -> (Box<dyn Operator>, Vec<usize>, Vec<AggregateExpr>) {
(self.child, self.group_columns, self.aggregates)
}
fn aggregate(&mut self) -> Result<(), OperatorError> {
while let Some(chunk) = self.child.next()? {
for row in chunk.selected_indices() {
let key = GroupKey::from_row(&chunk, row, &self.group_columns);
let states = self.groups.entry(key).or_insert_with(|| {
self.aggregates
.iter()
.map(|agg| {
AggregateState::new(
agg.function,
agg.distinct,
agg.percentile,
agg.separator.as_deref(),
)
})
.collect()
});
for (i, agg) in self.aggregates.iter().enumerate() {
if agg.column2.is_some() {
let y_val = agg
.column
.and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
let x_val = agg
.column2
.and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
states[i].update_bivariate(y_val, x_val);
continue;
}
let value = match (agg.function, agg.distinct) {
(AggregateFunction::Count, false) => None,
(AggregateFunction::Count, true) => agg
.column
.and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
_ => agg
.column
.and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
};
match (agg.function, agg.distinct) {
(AggregateFunction::Count, false) => states[i].update(None),
(AggregateFunction::Count, true) => {
if value.is_some() && !matches!(value, Some(Value::Null)) {
states[i].update(value);
}
}
(AggregateFunction::CountNonNull, _) => {
if value.is_some() && !matches!(value, Some(Value::Null)) {
states[i].update(value);
}
}
_ => {
if value.is_some() && !matches!(value, Some(Value::Null)) {
states[i].update(value);
}
}
}
}
}
}
self.aggregation_complete = true;
let results: Vec<_> = self.groups.drain(..).collect();
self.results = Some(results.into_iter());
Ok(())
}
}
impl Operator for HashAggregateOperator {
    /// Emits aggregated rows: group-key values followed by one finalized
    /// value per aggregate. The first call runs the blocking aggregation.
    fn next(&mut self) -> OperatorResult {
        if !self.aggregation_complete {
            self.aggregate()?;
            // A global aggregation (no GROUP BY) over zero input rows must
            // still produce one row of default results, e.g. COUNT(*) == 0
            // and SUM(x) == NULL.
            //
            // BUG FIX: this path was previously gated on
            // `self.results.is_none()`, which is never true after a
            // successful `aggregate()` (it always sets `results`), so the
            // row was never emitted; and every aggregate was pushed into
            // column `group_columns.len()` (always 0 here) instead of its
            // own column.
            if self.group_columns.is_empty()
                && self.results.as_ref().is_some_and(|r| r.len() == 0)
            {
                let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
                for (i, agg) in self.aggregates.iter().enumerate() {
                    let state = AggregateState::new(
                        agg.function,
                        agg.distinct,
                        agg.percentile,
                        agg.separator.as_deref(),
                    );
                    if let Some(col) = builder.column_mut(i) {
                        col.push_value(state.finalize());
                    }
                }
                builder.advance_row();
                // `results` stays Some(empty), so later calls return None.
                return Ok(Some(builder.finish()));
            }
        }
        let Some(results) = &mut self.results else {
            return Ok(None);
        };
        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
        for (key, states) in results.by_ref() {
            // Group-key columns come first...
            for (i, value) in key.to_values().into_iter().enumerate() {
                if let Some(col) = builder.column_mut(i) {
                    col.push_value(value);
                }
            }
            // ...followed by one finalized value per aggregate.
            for (i, state) in states.iter().enumerate() {
                if let Some(col) = builder.column_mut(self.group_columns.len() + i) {
                    col.push_value(state.finalize());
                }
            }
            builder.advance_row();
            // Emit a full chunk now; the iterator resumes on the next call.
            if builder.is_full() {
                return Ok(Some(builder.finish()));
            }
        }
        if builder.row_count() > 0 {
            Ok(Some(builder.finish()))
        } else {
            Ok(None)
        }
    }
    fn reset(&mut self) {
        self.child.reset();
        self.groups.clear();
        self.aggregation_complete = false;
        self.results = None;
    }
    fn name(&self) -> &'static str {
        "HashAggregate"
    }
    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
        self
    }
}
/// Ungrouped (global) aggregation: folds the entire child input into one
/// [`AggregateState`] per aggregate and emits a single one-row chunk.
pub struct SimpleAggregateOperator {
    child: Box<dyn Operator>,
    aggregates: Vec<AggregateExpr>,
    // One output column per aggregate.
    output_schema: Vec<LogicalType>,
    // Parallel to `aggregates`.
    states: Vec<AggregateState>,
    // True once the single result row has been emitted.
    done: bool,
}
impl SimpleAggregateOperator {
    /// Creates an ungrouped aggregation over `child`; `output_schema` has
    /// one entry per aggregate.
    pub fn new(
        child: Box<dyn Operator>,
        aggregates: Vec<AggregateExpr>,
        output_schema: Vec<LogicalType>,
    ) -> Self {
        // Seed one fresh accumulator per aggregate expression.
        let mut states = Vec::with_capacity(aggregates.len());
        for agg in &aggregates {
            states.push(AggregateState::new(
                agg.function,
                agg.distinct,
                agg.percentile,
                agg.separator.as_deref(),
            ));
        }
        Self {
            child,
            aggregates,
            output_schema,
            states,
            done: false,
        }
    }
}
impl Operator for SimpleAggregateOperator {
fn next(&mut self) -> OperatorResult {
if self.done {
return Ok(None);
}
while let Some(chunk) = self.child.next()? {
for row in chunk.selected_indices() {
for (i, agg) in self.aggregates.iter().enumerate() {
if agg.column2.is_some() {
let y_val = agg
.column
.and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
let x_val = agg
.column2
.and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
self.states[i].update_bivariate(y_val, x_val);
continue;
}
let value = match (agg.function, agg.distinct) {
(AggregateFunction::Count, false) => None,
(AggregateFunction::Count, true) => agg
.column
.and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
_ => agg
.column
.and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
};
match (agg.function, agg.distinct) {
(AggregateFunction::Count, false) => self.states[i].update(None),
(AggregateFunction::Count, true) => {
if value.is_some() && !matches!(value, Some(Value::Null)) {
self.states[i].update(value);
}
}
(AggregateFunction::CountNonNull, _) => {
if value.is_some() && !matches!(value, Some(Value::Null)) {
self.states[i].update(value);
}
}
_ => {
if value.is_some() && !matches!(value, Some(Value::Null)) {
self.states[i].update(value);
}
}
}
}
}
}
let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
for (i, state) in self.states.iter().enumerate() {
if let Some(col) = builder.column_mut(i) {
col.push_value(state.finalize());
}
}
builder.advance_row();
self.done = true;
Ok(Some(builder.finish()))
}
fn reset(&mut self) {
self.child.reset();
self.states = self
.aggregates
.iter()
.map(|agg| {
AggregateState::new(
agg.function,
agg.distinct,
agg.percentile,
agg.separator.as_deref(),
)
})
.collect();
self.done = false;
}
fn name(&self) -> &'static str {
"SimpleAggregate"
}
fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::execution::chunk::DataChunkBuilder;
/// Test double: replays a fixed list of chunks, then yields `None`.
struct MockOperator {
    chunks: Vec<DataChunk>,
    position: usize,
}
impl MockOperator {
    fn new(chunks: Vec<DataChunk>) -> Self {
        Self {
            chunks,
            position: 0,
        }
    }
}
impl Operator for MockOperator {
    fn next(&mut self) -> OperatorResult {
        if self.position < self.chunks.len() {
            // Hand each chunk out once, leaving an empty placeholder behind.
            let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
            self.position += 1;
            Ok(Some(chunk))
        } else {
            Ok(None)
        }
    }
    fn reset(&mut self) {
        // NOTE(review): chunks already handed out were replaced with empty
        // placeholders, so a reset replays empty chunks — fine for these
        // tests, which never reset after consuming.
        self.position = 0;
    }
    fn name(&self) -> &'static str {
        "Mock"
    }
    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
        self
    }
}
/// Two Int64 columns of (group, value) rows:
/// (1,10) (1,20) (2,30) (2,40) (2,50).
fn create_test_chunk() -> DataChunk {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
    let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
    for (group, value) in data {
        builder.column_mut(0).unwrap().push_int64(group);
        builder.column_mut(1).unwrap().push_int64(value);
        builder.advance_row();
    }
    builder.finish()
}
// COUNT(*) over 5 rows is 5, and the operator is exhausted after one chunk.
#[test]
fn test_simple_count() {
    let mock = MockOperator::new(vec![create_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::count_star()],
        vec![LogicalType::Int64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
    assert!(agg.next().unwrap().is_none());
}
// SUM of 10+20+30+40+50 stays integral.
#[test]
fn test_simple_sum() {
    let mock = MockOperator::new(vec![create_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::sum(1)],
        vec![LogicalType::Int64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
}
// AVG of the value column is 150 / 5 = 30.
#[test]
fn test_simple_avg() {
    let mock = MockOperator::new(vec![create_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::avg(1)],
        vec![LogicalType::Float64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let avg = result.column(0).unwrap().get_float64(0).unwrap();
    assert!((avg - 30.0).abs() < 0.001);
}
// MIN and MAX evaluated together over the same column.
#[test]
fn test_simple_min_max() {
    let mock = MockOperator::new(vec![create_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::min(1), AggregateExpr::max(1)],
        vec![LogicalType::Int64, LogicalType::Int64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
    assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
}
// SUM coerces numeric strings via value_to_f64 and yields a float total.
#[test]
fn test_sum_with_string_values() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
    builder.column_mut(0).unwrap().push_string("30");
    builder.advance_row();
    builder.column_mut(0).unwrap().push_string("25");
    builder.advance_row();
    builder.column_mut(0).unwrap().push_string("35");
    builder.advance_row();
    let chunk = builder.finish();
    let mock = MockOperator::new(vec![chunk]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::sum(0)],
        vec![LogicalType::Float64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
    assert!(
        (sum_val - 90.0).abs() < 0.001,
        "Expected 90.0, got {}",
        sum_val
    );
}
// GROUP BY column 0, SUM(column 1): group 1 -> 30, group 2 -> 120.
#[test]
fn test_grouped_aggregation() {
    let mock = MockOperator::new(vec![create_test_chunk()]);
    let mut agg = HashAggregateOperator::new(
        Box::new(mock),
        vec![0],
        vec![AggregateExpr::sum(1)],
        vec![LogicalType::Int64, LogicalType::Int64],
    );
    let mut results: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = agg.next().unwrap() {
        for row in chunk.selected_indices() {
            let group = chunk.column(0).unwrap().get_int64(row).unwrap();
            let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
            results.push((group, sum));
        }
    }
    // Sort to make the assertion independent of group emission order.
    results.sort_by_key(|(g, _)| *g);
    assert_eq!(results.len(), 2);
    assert_eq!(results[0], (1, 30));
    assert_eq!(results[1], (2, 120));
}
// GROUP BY + COUNT(*): group 1 has 2 rows, group 2 has 3.
#[test]
fn test_grouped_count() {
    let mock = MockOperator::new(vec![create_test_chunk()]);
    let mut agg = HashAggregateOperator::new(
        Box::new(mock),
        vec![0],
        vec![AggregateExpr::count_star()],
        vec![LogicalType::Int64, LogicalType::Int64],
    );
    let mut results: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = agg.next().unwrap() {
        for row in chunk.selected_indices() {
            let group = chunk.column(0).unwrap().get_int64(row).unwrap();
            let count = chunk.column(1).unwrap().get_int64(row).unwrap();
            results.push((group, count));
        }
    }
    results.sort_by_key(|(g, _)| *g);
    assert_eq!(results.len(), 2);
    assert_eq!(results[0], (1, 2));
    assert_eq!(results[1], (2, 3));
}
// COUNT, SUM and AVG evaluated together per group; output columns are
// group key, then one column per aggregate in declaration order.
#[test]
fn test_multiple_aggregates() {
    let mock = MockOperator::new(vec![create_test_chunk()]);
    let mut agg = HashAggregateOperator::new(
        Box::new(mock),
        vec![0],
        vec![
            AggregateExpr::count_star(),
            AggregateExpr::sum(1),
            AggregateExpr::avg(1),
        ],
        vec![
            LogicalType::Int64,
            LogicalType::Int64,
            LogicalType::Int64,
            LogicalType::Float64,
        ],
    );
    let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
    while let Some(chunk) = agg.next().unwrap() {
        for row in chunk.selected_indices() {
            let group = chunk.column(0).unwrap().get_int64(row).unwrap();
            let count = chunk.column(1).unwrap().get_int64(row).unwrap();
            let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
            let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
            results.push((group, count, sum, avg));
        }
    }
    results.sort_by_key(|(g, _, _, _)| *g);
    assert_eq!(results.len(), 2);
    assert_eq!(results[0].0, 1);
    assert_eq!(results[0].1, 2);
    assert_eq!(results[0].2, 30);
    assert!((results[0].3 - 15.0).abs() < 0.001);
    assert_eq!(results[1].0, 2);
    assert_eq!(results[1].1, 3);
    assert_eq!(results[1].2, 120);
    assert!((results[1].3 - 40.0).abs() < 0.001);
}
/// (group, value) rows with duplicate values inside each group, for
/// exercising the DISTINCT variants: group 1 -> {10, 10, 20},
/// group 2 -> {30, 30, 30}.
fn create_test_chunk_with_duplicates() -> DataChunk {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
    let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
    for (group, value) in data {
        builder.column_mut(0).unwrap().push_int64(group);
        builder.column_mut(1).unwrap().push_int64(value);
        builder.advance_row();
    }
    builder.finish()
}
// COUNT(DISTINCT value) over {10, 10, 20, 30, 30, 30} is 3.
#[test]
fn test_count_distinct() {
    let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::count(1).with_distinct()],
        vec![LogicalType::Int64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
}
// Per-group COUNT(DISTINCT): group 1 has {10, 20}, group 2 has {30}.
#[test]
fn test_grouped_count_distinct() {
    let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut agg = HashAggregateOperator::new(
        Box::new(mock),
        vec![0],
        vec![AggregateExpr::count(1).with_distinct()],
        vec![LogicalType::Int64, LogicalType::Int64],
    );
    let mut results: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = agg.next().unwrap() {
        for row in chunk.selected_indices() {
            let group = chunk.column(0).unwrap().get_int64(row).unwrap();
            let count = chunk.column(1).unwrap().get_int64(row).unwrap();
            results.push((group, count));
        }
    }
    results.sort_by_key(|(g, _)| *g);
    assert_eq!(results.len(), 2);
    assert_eq!(results[0], (1, 2));
    assert_eq!(results[1], (2, 1));
}
// SUM(DISTINCT) over {10, 20, 30} = 60 (duplicates folded once).
#[test]
fn test_sum_distinct() {
    let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::sum(1).with_distinct()],
        vec![LogicalType::Int64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
}
// AVG(DISTINCT) over {10, 20, 30} = 20.
#[test]
fn test_avg_distinct() {
    let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::avg(1).with_distinct()],
        vec![LogicalType::Float64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let avg = result.column(0).unwrap().get_float64(0).unwrap();
    assert!((avg - 20.0).abs() < 0.001);
}
/// One Int64 column with the classic variance example data set
/// [2, 4, 4, 4, 5, 5, 7, 9] (population stddev exactly 2).
fn create_statistical_test_chunk() -> DataChunk {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
        builder.column_mut(0).unwrap().push_int64(value);
        builder.advance_row();
    }
    builder.finish()
}
// Sample stddev of the data set is sqrt(32/7) ≈ 2.138.
#[test]
fn test_stdev_sample() {
    let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let stdev = result.column(0).unwrap().get_float64(0).unwrap();
    assert!((stdev - 2.138).abs() < 0.01);
}
// Population stddev of the data set is exactly 2.
#[test]
fn test_stdev_population() {
    let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::stdev_pop(0)],
        vec![LogicalType::Float64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let stdev = result.column(0).unwrap().get_float64(0).unwrap();
    assert!((stdev - 2.0).abs() < 0.01);
}
// Discrete median picks an actual input value (lower nearest rank): 4.
#[test]
fn test_percentile_disc() {
    let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::percentile_disc(0, 0.5)],
        vec![LogicalType::Float64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let percentile = result.column(0).unwrap().get_float64(0).unwrap();
    assert!((percentile - 4.0).abs() < 0.01);
}
// Continuous median interpolates between 4 and 5: 4.5.
#[test]
fn test_percentile_cont() {
    let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::percentile_cont(0, 0.5)],
        vec![LogicalType::Float64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let percentile = result.column(0).unwrap().get_float64(0).unwrap();
    assert!((percentile - 4.5).abs() < 0.01);
}
// p=0.0 and p=1.0 map to the minimum and maximum values respectively.
#[test]
fn test_percentile_extremes() {
    let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![
            AggregateExpr::percentile_disc(0, 0.0),
            AggregateExpr::percentile_disc(0, 1.0),
        ],
        vec![LogicalType::Float64, LogicalType::Float64],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let p0 = result.column(0).unwrap().get_float64(0).unwrap();
    assert!((p0 - 2.0).abs() < 0.01);
    let p100 = result.column(1).unwrap().get_float64(0).unwrap();
    assert!((p100 - 9.0).abs() < 0.01);
}
#[test]
fn test_stdev_single_value() {
    // Sample stdev is undefined for n == 1 (division by n-1), so the result is NULL.
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();
    let mock = MockOperator::new(vec![builder.finish()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );
    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let value = out.column(0).unwrap().get_value(0);
    assert!(matches!(value, Some(Value::Null)));
}
#[test]
fn test_first_and_last() {
    // FIRST keeps the earliest value seen (10) and LAST the latest (50).
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_test_chunk()])),
        vec![AggregateExpr::first(1), AggregateExpr::last(1)],
        vec![LogicalType::Int64, LogicalType::Int64],
    );
    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    assert_eq!(out.column(0).unwrap().get_int64(0), Some(10));
    assert_eq!(out.column(1).unwrap().get_int64(0), Some(50));
}
#[test]
fn test_collect() {
    // COLLECT gathers all 5 rows of column 1 into a single list value.
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_test_chunk()])),
        vec![AggregateExpr::collect(1)],
        vec![LogicalType::Any],
    );
    let out = agg.next().unwrap().unwrap();
    match out.column(0).unwrap().get_value(0).unwrap() {
        Value::List(items) => assert_eq!(items.len(), 5),
        _ => panic!("Expected List value"),
    }
}
#[test]
fn test_collect_distinct() {
    // DISTINCT collapses the duplicated input down to its 3 unique values.
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_test_chunk_with_duplicates()])),
        vec![AggregateExpr::collect(1).with_distinct()],
        vec![LogicalType::Any],
    );
    let out = agg.next().unwrap().unwrap();
    match out.column(0).unwrap().get_value(0).unwrap() {
        Value::List(items) => assert_eq!(items.len(), 3),
        _ => panic!("Expected List value"),
    }
}
#[test]
fn test_group_concat() {
    // With `separator: None`, GROUP_CONCAT joins strings with a single space.
    let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
    ["hello", "world", "foo"].iter().for_each(|s| {
        builder.column_mut(0).unwrap().push_string(s);
        builder.advance_row();
    });
    let mock = MockOperator::new(vec![builder.finish()]);
    let expr = AggregateExpr {
        function: AggregateFunction::GroupConcat,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut agg =
        SimpleAggregateOperator::new(Box::new(mock), vec![expr], vec![LogicalType::String]);
    let out = agg.next().unwrap().unwrap();
    let joined = out.column(0).unwrap().get_value(0).unwrap();
    assert_eq!(joined, Value::String("hello world foo".into()));
}
#[test]
fn test_sample() {
    // SAMPLE returns some value from the input; here it keeps the first (10).
    let expr = AggregateExpr {
        function: AggregateFunction::Sample,
        column: Some(1),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mock = MockOperator::new(vec![create_test_chunk()]);
    let mut agg =
        SimpleAggregateOperator::new(Box::new(mock), vec![expr], vec![LogicalType::Int64]);
    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.column(0).unwrap().get_int64(0), Some(10));
}
#[test]
fn test_variance_sample() {
    // Sample variance of the shared data set: 32 / (8 - 1) = 32/7.
    let expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_statistical_test_chunk()])),
        vec![expr],
        vec![LogicalType::Float64],
    );
    let out = agg.next().unwrap().unwrap();
    let variance = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((variance - 32.0 / 7.0).abs() < 0.01);
}
#[test]
fn test_variance_population() {
    // Population variance of the shared data set: 32 / 8 = 4.0.
    let expr = AggregateExpr {
        function: AggregateFunction::VariancePop,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_statistical_test_chunk()])),
        vec![expr],
        vec![LogicalType::Float64],
    );
    let out = agg.next().unwrap().unwrap();
    let variance = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((variance - 4.0).abs() < 0.01);
}
#[test]
fn test_variance_single_value() {
    // Sample variance needs at least two values; a single row yields NULL.
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();
    let mock = MockOperator::new(vec![builder.finish()]);
    let expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![expr],
        vec![LogicalType::Float64],
    );
    let out = agg.next().unwrap().unwrap();
    let value = out.column(0).unwrap().get_value(0);
    assert!(matches!(value, Some(Value::Null)));
}
#[test]
fn test_empty_aggregation() {
    // With no input chunks, COUNT(*) must report 0 while SUM/AVG/MIN/MAX are NULL.
    let mock = MockOperator::new(vec![]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![
            AggregateExpr::count_star(),
            AggregateExpr::sum(0),
            AggregateExpr::avg(0),
            AggregateExpr::min(0),
            AggregateExpr::max(0),
        ],
        vec![
            LogicalType::Int64,
            LogicalType::Int64,
            LogicalType::Float64,
            LogicalType::Int64,
            LogicalType::Int64,
        ],
    );
    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.column(0).unwrap().get_int64(0), Some(0));
    // Columns 1-4 (SUM, AVG, MIN, MAX) all produce NULL on empty input.
    for col in 1..=4 {
        assert!(
            matches!(result.column(col).unwrap().get_value(0), Some(Value::Null)),
            "column {col} should be NULL on empty input"
        );
    }
}
#[test]
fn test_stdev_pop_single_value() {
    // Population stdev is well-defined for n == 1 and equals 0 (no spread).
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();
    let mock = MockOperator::new(vec![builder.finish()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::stdev_pop(0)],
        vec![LogicalType::Float64],
    );
    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let got = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((got - 0.0).abs() < 0.01);
}
#[test]
fn test_hash_aggregate_into_any() {
    // A boxed HashAggregateOperator must downcast back to its concrete type.
    let op = HashAggregateOperator::new(
        Box::new(MockOperator::new(vec![])),
        vec![0],
        vec![AggregateExpr::count_star()],
        vec![LogicalType::Int64, LogicalType::Int64],
    );
    let erased = Box::new(op).into_any();
    assert!(erased.downcast::<HashAggregateOperator>().is_ok());
}
#[test]
fn test_simple_aggregate_into_any() {
    // A boxed SimpleAggregateOperator must downcast back to its concrete type.
    let op = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![])),
        vec![AggregateExpr::count_star()],
        vec![LogicalType::Int64],
    );
    let erased = Box::new(op).into_any();
    assert!(erased.downcast::<SimpleAggregateOperator>().is_ok());
}
#[test]
fn test_hash_aggregate_into_parts() {
    // into_parts must hand back the child operator plus the original
    // group-by columns and aggregate expressions, unmodified.
    let op = HashAggregateOperator::new(
        Box::new(MockOperator::new(vec![])),
        vec![0, 2],
        vec![AggregateExpr::sum(1), AggregateExpr::count_star()],
        vec![LogicalType::Int64, LogicalType::Int64, LogicalType::Int64],
    );
    let (mut child, group_columns, aggregates) = op.into_parts();
    assert_eq!(group_columns, vec![0, 2]);
    assert_eq!(aggregates.len(), 2);
    // The empty mock child is exhausted immediately.
    assert!(child.next().unwrap().is_none());
}
}