1use indexmap::IndexMap;
8use std::collections::HashSet;
9
10use grafeo_common::types::{LogicalType, Value};
11
/// Hashable stand-in for a [`Value`], used as the element type of the `HashSet`s
/// that back the DISTINCT aggregate states in this file.
///
/// Every variant holds only `Eq + Hash` data so the whole enum can derive both traits.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum HashableValue {
    /// The null value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A 64-bit float stored as its raw bit pattern (`f64` itself is neither `Eq` nor `Hash`).
    Float64Bits(u64),
    /// A string, stored as an owned copy.
    String(String),
    /// Any other value, stored as text.
    Other(String),
}
22
23impl From<&Value> for HashableValue {
24 fn from(v: &Value) -> Self {
25 match v {
26 Value::Null => HashableValue::Null,
27 Value::Bool(b) => HashableValue::Bool(*b),
28 Value::Int64(i) => HashableValue::Int64(*i),
29 Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
30 Value::String(s) => HashableValue::String(s.to_string()),
31 other => HashableValue::Other(format!("{other:?}")),
32 }
33 }
34}
35
36impl From<Value> for HashableValue {
37 fn from(v: Value) -> Self {
38 Self::from(&v)
39 }
40}
41
42use super::{Operator, OperatorError, OperatorResult};
43use crate::execution::DataChunk;
44use crate::execution::chunk::DataChunkBuilder;
45
/// The aggregate functions supported by the aggregate operators in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Row count; without DISTINCT every row is counted regardless of its values.
    Count,
    /// Count of the non-null values of a column.
    CountNonNull,
    /// Sum of the values.
    Sum,
    /// Arithmetic mean of the values.
    Avg,
    /// Smallest value.
    Min,
    /// Largest value.
    Max,
    /// First value encountered.
    First,
    /// Last value encountered.
    Last,
    /// All values gathered into a list.
    Collect,
    /// Sample standard deviation (divides by `n - 1`).
    StdDev,
    /// Population standard deviation (divides by `n`).
    StdDevPop,
    /// Discrete percentile: returns one of the input values.
    PercentileDisc,
    /// Continuous percentile: linearly interpolates between neighbouring values.
    PercentileCont,
}
76
77#[derive(Debug, Clone)]
79pub struct AggregateExpr {
80 pub function: AggregateFunction,
82 pub column: Option<usize>,
84 pub distinct: bool,
86 pub alias: Option<String>,
88 pub percentile: Option<f64>,
90}
91
92impl AggregateExpr {
93 pub fn count_star() -> Self {
95 Self {
96 function: AggregateFunction::Count,
97 column: None,
98 distinct: false,
99 alias: None,
100 percentile: None,
101 }
102 }
103
104 pub fn count(column: usize) -> Self {
106 Self {
107 function: AggregateFunction::CountNonNull,
108 column: Some(column),
109 distinct: false,
110 alias: None,
111 percentile: None,
112 }
113 }
114
115 pub fn sum(column: usize) -> Self {
117 Self {
118 function: AggregateFunction::Sum,
119 column: Some(column),
120 distinct: false,
121 alias: None,
122 percentile: None,
123 }
124 }
125
126 pub fn avg(column: usize) -> Self {
128 Self {
129 function: AggregateFunction::Avg,
130 column: Some(column),
131 distinct: false,
132 alias: None,
133 percentile: None,
134 }
135 }
136
137 pub fn min(column: usize) -> Self {
139 Self {
140 function: AggregateFunction::Min,
141 column: Some(column),
142 distinct: false,
143 alias: None,
144 percentile: None,
145 }
146 }
147
148 pub fn max(column: usize) -> Self {
150 Self {
151 function: AggregateFunction::Max,
152 column: Some(column),
153 distinct: false,
154 alias: None,
155 percentile: None,
156 }
157 }
158
159 pub fn first(column: usize) -> Self {
161 Self {
162 function: AggregateFunction::First,
163 column: Some(column),
164 distinct: false,
165 alias: None,
166 percentile: None,
167 }
168 }
169
170 pub fn last(column: usize) -> Self {
172 Self {
173 function: AggregateFunction::Last,
174 column: Some(column),
175 distinct: false,
176 alias: None,
177 percentile: None,
178 }
179 }
180
181 pub fn collect(column: usize) -> Self {
183 Self {
184 function: AggregateFunction::Collect,
185 column: Some(column),
186 distinct: false,
187 alias: None,
188 percentile: None,
189 }
190 }
191
192 pub fn stdev(column: usize) -> Self {
194 Self {
195 function: AggregateFunction::StdDev,
196 column: Some(column),
197 distinct: false,
198 alias: None,
199 percentile: None,
200 }
201 }
202
203 pub fn stdev_pop(column: usize) -> Self {
205 Self {
206 function: AggregateFunction::StdDevPop,
207 column: Some(column),
208 distinct: false,
209 alias: None,
210 percentile: None,
211 }
212 }
213
214 pub fn percentile_disc(column: usize, percentile: f64) -> Self {
220 Self {
221 function: AggregateFunction::PercentileDisc,
222 column: Some(column),
223 distinct: false,
224 alias: None,
225 percentile: Some(percentile.clamp(0.0, 1.0)),
226 }
227 }
228
229 pub fn percentile_cont(column: usize, percentile: f64) -> Self {
235 Self {
236 function: AggregateFunction::PercentileCont,
237 column: Some(column),
238 distinct: false,
239 alias: None,
240 percentile: Some(percentile.clamp(0.0, 1.0)),
241 }
242 }
243
244 pub fn with_distinct(mut self) -> Self {
246 self.distinct = true;
247 self
248 }
249
250 pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
252 self.alias = Some(alias.into());
253 self
254 }
255}
256
257#[derive(Debug, Clone)]
259enum AggregateState {
260 Count(i64),
262 CountDistinct(i64, HashSet<HashableValue>),
264 SumInt(i64),
266 SumIntDistinct(i64, HashSet<HashableValue>),
268 SumFloat(f64),
270 SumFloatDistinct(f64, HashSet<HashableValue>),
272 Avg(f64, i64),
274 AvgDistinct(f64, i64, HashSet<HashableValue>),
276 Min(Option<Value>),
278 Max(Option<Value>),
280 First(Option<Value>),
282 Last(Option<Value>),
284 Collect(Vec<Value>),
286 CollectDistinct(Vec<Value>, HashSet<HashableValue>),
288 StdDev { count: i64, mean: f64, m2: f64 },
290 StdDevPop { count: i64, mean: f64, m2: f64 },
292 PercentileDisc { values: Vec<f64>, percentile: f64 },
294 PercentileCont { values: Vec<f64>, percentile: f64 },
296}
297
298impl AggregateState {
299 fn new(function: AggregateFunction, distinct: bool, percentile: Option<f64>) -> Self {
301 match (function, distinct) {
302 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
303 AggregateState::Count(0)
304 }
305 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
306 AggregateState::CountDistinct(0, HashSet::new())
307 }
308 (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
309 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
310 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
311 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
312 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
314 (AggregateFunction::First, _) => AggregateState::First(None),
315 (AggregateFunction::Last, _) => AggregateState::Last(None),
316 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
317 (AggregateFunction::Collect, true) => {
318 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
319 }
320 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
322 count: 0,
323 mean: 0.0,
324 m2: 0.0,
325 },
326 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
327 count: 0,
328 mean: 0.0,
329 m2: 0.0,
330 },
331 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
332 values: Vec::new(),
333 percentile: percentile.unwrap_or(0.5),
334 },
335 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
336 values: Vec::new(),
337 percentile: percentile.unwrap_or(0.5),
338 },
339 }
340 }
341
342 fn update(&mut self, value: Option<Value>) {
344 match self {
345 AggregateState::Count(count) => {
346 *count += 1;
347 }
348 AggregateState::CountDistinct(count, seen) => {
349 if let Some(ref v) = value {
350 let hashable = HashableValue::from(v);
351 if seen.insert(hashable) {
352 *count += 1;
353 }
354 }
355 }
356 AggregateState::SumInt(sum) => {
357 if let Some(Value::Int64(v)) = value {
358 *sum += v;
359 } else if let Some(Value::Float64(v)) = value {
360 *self = AggregateState::SumFloat(*sum as f64 + v);
362 } else if let Some(ref v) = value {
363 if let Some(num) = value_to_f64(v) {
365 *self = AggregateState::SumFloat(*sum as f64 + num);
366 }
367 }
368 }
369 AggregateState::SumIntDistinct(sum, seen) => {
370 if let Some(ref v) = value {
371 let hashable = HashableValue::from(v);
372 if seen.insert(hashable) {
373 if let Value::Int64(i) = v {
374 *sum += i;
375 } else if let Value::Float64(f) = v {
376 let moved_seen = std::mem::take(seen);
378 *self = AggregateState::SumFloatDistinct(*sum as f64 + f, moved_seen);
379 } else if let Some(num) = value_to_f64(v) {
380 let moved_seen = std::mem::take(seen);
382 *self = AggregateState::SumFloatDistinct(*sum as f64 + num, moved_seen);
383 }
384 }
385 }
386 }
387 AggregateState::SumFloat(sum) => {
388 if let Some(ref v) = value {
389 if let Some(num) = value_to_f64(v) {
391 *sum += num;
392 }
393 }
394 }
395 AggregateState::SumFloatDistinct(sum, seen) => {
396 if let Some(ref v) = value {
397 let hashable = HashableValue::from(v);
398 if seen.insert(hashable)
399 && let Some(num) = value_to_f64(v)
400 {
401 *sum += num;
402 }
403 }
404 }
405 AggregateState::Avg(sum, count) => {
406 if let Some(ref v) = value
407 && let Some(num) = value_to_f64(v)
408 {
409 *sum += num;
410 *count += 1;
411 }
412 }
413 AggregateState::AvgDistinct(sum, count, seen) => {
414 if let Some(ref v) = value {
415 let hashable = HashableValue::from(v);
416 if seen.insert(hashable)
417 && let Some(num) = value_to_f64(v)
418 {
419 *sum += num;
420 *count += 1;
421 }
422 }
423 }
424 AggregateState::Min(min) => {
425 if let Some(v) = value {
426 match min {
427 None => *min = Some(v),
428 Some(current) => {
429 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
430 *min = Some(v);
431 }
432 }
433 }
434 }
435 }
436 AggregateState::Max(max) => {
437 if let Some(v) = value {
438 match max {
439 None => *max = Some(v),
440 Some(current) => {
441 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
442 *max = Some(v);
443 }
444 }
445 }
446 }
447 }
448 AggregateState::First(first) => {
449 if first.is_none() {
450 *first = value;
451 }
452 }
453 AggregateState::Last(last) => {
454 if value.is_some() {
455 *last = value;
456 }
457 }
458 AggregateState::Collect(list) => {
459 if let Some(v) = value {
460 list.push(v);
461 }
462 }
463 AggregateState::CollectDistinct(list, seen) => {
464 if let Some(v) = value {
465 let hashable = HashableValue::from(&v);
466 if seen.insert(hashable) {
467 list.push(v);
468 }
469 }
470 }
471 AggregateState::StdDev { count, mean, m2 }
473 | AggregateState::StdDevPop { count, mean, m2 } => {
474 if let Some(ref v) = value
475 && let Some(x) = value_to_f64(v)
476 {
477 *count += 1;
478 let delta = x - *mean;
479 *mean += delta / *count as f64;
480 let delta2 = x - *mean;
481 *m2 += delta * delta2;
482 }
483 }
484 AggregateState::PercentileDisc { values, .. }
485 | AggregateState::PercentileCont { values, .. } => {
486 if let Some(ref v) = value
487 && let Some(x) = value_to_f64(v)
488 {
489 values.push(x);
490 }
491 }
492 }
493 }
494
495 fn finalize(&self) -> Value {
497 match self {
498 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
499 Value::Int64(*count)
500 }
501 AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
502 Value::Int64(*sum)
503 }
504 AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
505 Value::Float64(*sum)
506 }
507 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
508 if *count == 0 {
509 Value::Null
510 } else {
511 Value::Float64(*sum / *count as f64)
512 }
513 }
514 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
515 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
516 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
517 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
518 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
519 Value::List(list.clone().into())
520 }
521 AggregateState::StdDev { count, m2, .. } => {
523 if *count < 2 {
524 Value::Null
525 } else {
526 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
527 }
528 }
529 AggregateState::StdDevPop { count, m2, .. } => {
531 if *count == 0 {
532 Value::Null
533 } else {
534 Value::Float64((*m2 / *count as f64).sqrt())
535 }
536 }
537 AggregateState::PercentileDisc { values, percentile } => {
539 if values.is_empty() {
540 Value::Null
541 } else {
542 let mut sorted = values.clone();
543 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
544 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
546 Value::Float64(sorted[index])
547 }
548 }
549 AggregateState::PercentileCont { values, percentile } => {
551 if values.is_empty() {
552 Value::Null
553 } else {
554 let mut sorted = values.clone();
555 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
556 let rank = percentile * (sorted.len() - 1) as f64;
558 let lower_idx = rank.floor() as usize;
559 let upper_idx = rank.ceil() as usize;
560 if lower_idx == upper_idx {
561 Value::Float64(sorted[lower_idx])
562 } else {
563 let fraction = rank - lower_idx as f64;
564 let result =
565 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
566 Value::Float64(result)
567 }
568 }
569 }
570 }
571 }
572}
573
574fn value_to_f64(value: &Value) -> Option<f64> {
577 match value {
578 Value::Int64(i) => Some(*i as f64),
579 Value::Float64(f) => Some(*f),
580 Value::String(s) => s.parse::<f64>().ok(),
582 _ => None,
583 }
584}
585
586fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
589 match (a, b) {
590 (Value::Int64(a), Value::Int64(b)) => Some(a.cmp(b)),
591 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b),
592 (Value::String(a), Value::String(b)) => {
593 if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
595 a_num.partial_cmp(&b_num)
596 } else {
597 Some(a.cmp(b))
598 }
599 }
600 (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
601 (Value::Int64(a), Value::Float64(b)) => (*a as f64).partial_cmp(b),
602 (Value::Float64(a), Value::Int64(b)) => a.partial_cmp(&(*b as f64)),
603 (Value::String(s), Value::Int64(i)) => s.parse::<f64>().ok()?.partial_cmp(&(*i as f64)),
605 (Value::String(s), Value::Float64(f)) => s.parse::<f64>().ok()?.partial_cmp(f),
606 (Value::Int64(i), Value::String(s)) => (*i as f64).partial_cmp(&s.parse::<f64>().ok()?),
607 (Value::Float64(f), Value::String(s)) => f.partial_cmp(&s.parse::<f64>().ok()?),
608 _ => None,
609 }
610}
611
612#[derive(Debug, Clone, PartialEq, Eq, Hash)]
614pub struct GroupKey(Vec<GroupKeyPart>);
615
616#[derive(Debug, Clone, PartialEq, Eq, Hash)]
617enum GroupKeyPart {
618 Null,
619 Bool(bool),
620 Int64(i64),
621 String(String),
622}
623
624impl GroupKey {
625 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
627 let parts: Vec<GroupKeyPart> = group_columns
628 .iter()
629 .map(|&col_idx| {
630 chunk
631 .column(col_idx)
632 .and_then(|col| col.get_value(row))
633 .map_or(GroupKeyPart::Null, |v| match v {
634 Value::Null => GroupKeyPart::Null,
635 Value::Bool(b) => GroupKeyPart::Bool(b),
636 Value::Int64(i) => GroupKeyPart::Int64(i),
637 Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
638 Value::String(s) => GroupKeyPart::String(s.to_string()),
639 _ => GroupKeyPart::String(format!("{v:?}")),
640 })
641 })
642 .collect();
643 GroupKey(parts)
644 }
645
646 fn to_values(&self) -> Vec<Value> {
648 self.0
649 .iter()
650 .map(|part| match part {
651 GroupKeyPart::Null => Value::Null,
652 GroupKeyPart::Bool(b) => Value::Bool(*b),
653 GroupKeyPart::Int64(i) => Value::Int64(*i),
654 GroupKeyPart::String(s) => Value::String(s.clone().into()),
655 })
656 .collect()
657 }
658}
659
660pub struct HashAggregateOperator {
664 child: Box<dyn Operator>,
666 group_columns: Vec<usize>,
668 aggregates: Vec<AggregateExpr>,
670 output_schema: Vec<LogicalType>,
672 groups: IndexMap<GroupKey, Vec<AggregateState>>,
674 aggregation_complete: bool,
676 results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
678}
679
680impl HashAggregateOperator {
681 pub fn new(
689 child: Box<dyn Operator>,
690 group_columns: Vec<usize>,
691 aggregates: Vec<AggregateExpr>,
692 output_schema: Vec<LogicalType>,
693 ) -> Self {
694 Self {
695 child,
696 group_columns,
697 aggregates,
698 output_schema,
699 groups: IndexMap::new(),
700 aggregation_complete: false,
701 results: None,
702 }
703 }
704
705 fn aggregate(&mut self) -> Result<(), OperatorError> {
707 while let Some(chunk) = self.child.next()? {
708 for row in chunk.selected_indices() {
709 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
710
711 let states = self.groups.entry(key).or_insert_with(|| {
713 self.aggregates
714 .iter()
715 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
716 .collect()
717 });
718
719 for (i, agg) in self.aggregates.iter().enumerate() {
721 let value = match (agg.function, agg.distinct) {
722 (AggregateFunction::Count, false) => None,
724 (AggregateFunction::Count, true) => agg
726 .column
727 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
728 _ => agg
729 .column
730 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
731 };
732
733 match (agg.function, agg.distinct) {
735 (AggregateFunction::Count, false) => states[i].update(None),
736 (AggregateFunction::Count, true) => {
737 if value.is_some() && !matches!(value, Some(Value::Null)) {
739 states[i].update(value);
740 }
741 }
742 (AggregateFunction::CountNonNull, _) => {
743 if value.is_some() && !matches!(value, Some(Value::Null)) {
744 states[i].update(value);
745 }
746 }
747 _ => {
748 if value.is_some() && !matches!(value, Some(Value::Null)) {
749 states[i].update(value);
750 }
751 }
752 }
753 }
754 }
755 }
756
757 self.aggregation_complete = true;
758
759 let results: Vec<_> = self.groups.drain(..).collect();
761 self.results = Some(results.into_iter());
762
763 Ok(())
764 }
765}
766
767impl Operator for HashAggregateOperator {
768 fn next(&mut self) -> OperatorResult {
769 if !self.aggregation_complete {
771 self.aggregate()?;
772 }
773
774 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
776 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
778
779 for agg in &self.aggregates {
780 let state = AggregateState::new(agg.function, agg.distinct, agg.percentile);
781 let value = state.finalize();
782 if let Some(col) = builder.column_mut(self.group_columns.len()) {
783 col.push_value(value);
784 }
785 }
786 builder.advance_row();
787
788 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
790 }
791
792 let Some(results) = &mut self.results else {
793 return Ok(None);
794 };
795
796 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
797
798 for (key, states) in results.by_ref() {
799 let key_values = key.to_values();
801 for (i, value) in key_values.into_iter().enumerate() {
802 if let Some(col) = builder.column_mut(i) {
803 col.push_value(value);
804 }
805 }
806
807 for (i, state) in states.iter().enumerate() {
809 let col_idx = self.group_columns.len() + i;
810 if let Some(col) = builder.column_mut(col_idx) {
811 col.push_value(state.finalize());
812 }
813 }
814
815 builder.advance_row();
816
817 if builder.is_full() {
818 return Ok(Some(builder.finish()));
819 }
820 }
821
822 if builder.row_count() > 0 {
823 Ok(Some(builder.finish()))
824 } else {
825 Ok(None)
826 }
827 }
828
829 fn reset(&mut self) {
830 self.child.reset();
831 self.groups.clear();
832 self.aggregation_complete = false;
833 self.results = None;
834 }
835
836 fn name(&self) -> &'static str {
837 "HashAggregate"
838 }
839}
840
841pub struct SimpleAggregateOperator {
845 child: Box<dyn Operator>,
847 aggregates: Vec<AggregateExpr>,
849 output_schema: Vec<LogicalType>,
851 states: Vec<AggregateState>,
853 done: bool,
855}
856
857impl SimpleAggregateOperator {
858 pub fn new(
860 child: Box<dyn Operator>,
861 aggregates: Vec<AggregateExpr>,
862 output_schema: Vec<LogicalType>,
863 ) -> Self {
864 let states = aggregates
865 .iter()
866 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
867 .collect();
868
869 Self {
870 child,
871 aggregates,
872 output_schema,
873 states,
874 done: false,
875 }
876 }
877}
878
879impl Operator for SimpleAggregateOperator {
880 fn next(&mut self) -> OperatorResult {
881 if self.done {
882 return Ok(None);
883 }
884
885 while let Some(chunk) = self.child.next()? {
887 for row in chunk.selected_indices() {
888 for (i, agg) in self.aggregates.iter().enumerate() {
889 let value = match (agg.function, agg.distinct) {
890 (AggregateFunction::Count, false) => None,
892 (AggregateFunction::Count, true) => agg
894 .column
895 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
896 _ => agg
897 .column
898 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
899 };
900
901 match (agg.function, agg.distinct) {
902 (AggregateFunction::Count, false) => self.states[i].update(None),
903 (AggregateFunction::Count, true) => {
904 if value.is_some() && !matches!(value, Some(Value::Null)) {
906 self.states[i].update(value);
907 }
908 }
909 (AggregateFunction::CountNonNull, _) => {
910 if value.is_some() && !matches!(value, Some(Value::Null)) {
911 self.states[i].update(value);
912 }
913 }
914 _ => {
915 if value.is_some() && !matches!(value, Some(Value::Null)) {
916 self.states[i].update(value);
917 }
918 }
919 }
920 }
921 }
922 }
923
924 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
926
927 for (i, state) in self.states.iter().enumerate() {
928 if let Some(col) = builder.column_mut(i) {
929 col.push_value(state.finalize());
930 }
931 }
932 builder.advance_row();
933
934 self.done = true;
935 Ok(Some(builder.finish()))
936 }
937
938 fn reset(&mut self) {
939 self.child.reset();
940 self.states = self
941 .aggregates
942 .iter()
943 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
944 .collect();
945 self.done = false;
946 }
947
948 fn name(&self) -> &'static str {
949 "SimpleAggregate"
950 }
951}
952
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Test source that replays a fixed list of chunks, one per `next` call.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            let Some(slot) = self.chunks.get_mut(self.position) else {
                return Ok(None);
            };
            self.position += 1;
            // Hand the chunk out by value, leaving an empty chunk in its slot.
            Ok(Some(std::mem::replace(slot, DataChunk::empty())))
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    // ---- fixtures -------------------------------------------------------------

    /// Builds a chunk of `W` Int64 columns from row-major data.
    fn int64_chunk<const W: usize>(rows: &[[i64; W]]) -> DataChunk {
        let schema: Vec<LogicalType> = (0..W).map(|_| LogicalType::Int64).collect();
        let mut builder = DataChunkBuilder::new(&schema);

        for row in rows {
            for (col, &cell) in row.iter().enumerate() {
                builder.column_mut(col).unwrap().push_int64(cell);
            }
            builder.advance_row();
        }

        builder.finish()
    }

    /// Two columns (group, value): group 1 holds 10 and 20, group 2 holds 30, 40 and 50.
    fn create_test_chunk() -> DataChunk {
        int64_chunk(&[[1i64, 10i64], [1, 20], [2, 30], [2, 40], [2, 50]])
    }

    /// Two columns (group, value) with repeated values inside each group.
    fn create_test_chunk_with_duplicates() -> DataChunk {
        int64_chunk(&[[1i64, 10i64], [1, 10], [1, 20], [2, 30], [2, 30], [2, 30]])
    }

    /// One column holding 2, 4, 4, 4, 5, 5, 7, 9.
    fn create_statistical_test_chunk() -> DataChunk {
        int64_chunk(&[[2i64], [4], [4], [4], [5], [5], [7], [9]])
    }

    // ---- helpers --------------------------------------------------------------

    /// Wraps `input` in a mock source and builds a global aggregate over it.
    fn simple_aggregate(
        input: DataChunk,
        aggregates: Vec<AggregateExpr>,
        schema: Vec<LogicalType>,
    ) -> SimpleAggregateOperator {
        SimpleAggregateOperator::new(Box::new(MockOperator::new(vec![input])), aggregates, schema)
    }

    /// Pulls the result chunk and checks that it holds exactly one row.
    fn single_row(op: &mut SimpleAggregateOperator) -> DataChunk {
        let result = op.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        result
    }

    /// Runs a global aggregate over `input` and returns its single result row.
    fn run_simple(
        input: DataChunk,
        aggregates: Vec<AggregateExpr>,
        schema: Vec<LogicalType>,
    ) -> DataChunk {
        let mut op = simple_aggregate(input, aggregates, schema);
        single_row(&mut op)
    }

    /// Groups `input` by column 0, applies one Int64-valued aggregate and returns the
    /// `(group, aggregate)` pairs sorted by group.
    fn run_grouped_pairs(input: DataChunk, aggregate: AggregateExpr) -> Vec<(i64, i64)> {
        let mut op = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![input])),
            vec![0],
            vec![aggregate],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut pairs: Vec<(i64, i64)> = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
                let value = chunk.column(1).unwrap().get_int64(row).unwrap();
                pairs.push((group, value));
            }
        }

        pairs.sort_by_key(|(g, _)| *g);
        pairs
    }

    /// Reads row 0 of column `col` as a float.
    fn float_at(chunk: &DataChunk, col: usize) -> f64 {
        chunk.column(col).unwrap().get_float64(0).unwrap()
    }

    /// Asserts that `actual` lies within `tolerance` of `expected`.
    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!((actual - expected).abs() < tolerance);
    }

    // ---- global aggregates ----------------------------------------------------

    #[test]
    fn test_simple_count() {
        let mut agg = simple_aggregate(
            create_test_chunk(),
            vec![AggregateExpr::count_star()],
            vec![LogicalType::Int64],
        );

        let result = single_row(&mut agg);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));

        // The operator emits its row once and is then exhausted.
        assert!(agg.next().unwrap().is_none());
    }

    #[test]
    fn test_simple_sum() {
        let result = run_simple(
            create_test_chunk(),
            vec![AggregateExpr::sum(1)],
            vec![LogicalType::Int64],
        );

        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
    }

    #[test]
    fn test_simple_avg() {
        let result = run_simple(
            create_test_chunk(),
            vec![AggregateExpr::avg(1)],
            vec![LogicalType::Float64],
        );

        assert_close(float_at(&result, 0), 30.0, 0.001);
    }

    #[test]
    fn test_simple_min_max() {
        let result = run_simple(
            create_test_chunk(),
            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
    }

    #[test]
    fn test_sum_with_string_values() {
        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
        for text in ["30", "25", "35"] {
            builder.column_mut(0).unwrap().push_string(text);
            builder.advance_row();
        }

        let result = run_simple(
            builder.finish(),
            vec![AggregateExpr::sum(0)],
            vec![LogicalType::Float64],
        );

        let sum_val = float_at(&result, 0);
        assert!(
            (sum_val - 90.0).abs() < 0.001,
            "Expected 90.0, got {}",
            sum_val
        );
    }

    // ---- grouped aggregates ---------------------------------------------------

    #[test]
    fn test_grouped_aggregation() {
        let results = run_grouped_pairs(create_test_chunk(), AggregateExpr::sum(1));

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], (1, 30));
        assert_eq!(results[1], (2, 120));
    }

    #[test]
    fn test_grouped_count() {
        let results = run_grouped_pairs(create_test_chunk(), AggregateExpr::count_star());

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], (1, 2));
        assert_eq!(results[1], (2, 3));
    }

    #[test]
    fn test_multiple_aggregates() {
        let mut agg = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![create_test_chunk()])),
            vec![0],
            vec![
                AggregateExpr::count_star(),
                AggregateExpr::sum(1),
                AggregateExpr::avg(1),
            ],
            vec![
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Float64,
            ],
        );

        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
        while let Some(chunk) = agg.next().unwrap() {
            for row in chunk.selected_indices() {
                results.push((
                    chunk.column(0).unwrap().get_int64(row).unwrap(),
                    chunk.column(1).unwrap().get_int64(row).unwrap(),
                    chunk.column(2).unwrap().get_int64(row).unwrap(),
                    chunk.column(3).unwrap().get_float64(row).unwrap(),
                ));
            }
        }

        results.sort_by_key(|(g, _, _, _)| *g);
        assert_eq!(results.len(), 2);

        let (group, count, sum, avg) = results[0];
        assert_eq!(group, 1);
        assert_eq!(count, 2);
        assert_eq!(sum, 30);
        assert_close(avg, 15.0, 0.001);

        let (group, count, sum, avg) = results[1];
        assert_eq!(group, 2);
        assert_eq!(count, 3);
        assert_eq!(sum, 120);
        assert_close(avg, 40.0, 0.001);
    }

    // ---- DISTINCT -------------------------------------------------------------

    #[test]
    fn test_count_distinct() {
        let result = run_simple(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::count(1).with_distinct()],
            vec![LogicalType::Int64],
        );

        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
    }

    #[test]
    fn test_grouped_count_distinct() {
        let results = run_grouped_pairs(
            create_test_chunk_with_duplicates(),
            AggregateExpr::count(1).with_distinct(),
        );

        assert_eq!(results.len(), 2);
        assert_eq!(results[0], (1, 2));
        assert_eq!(results[1], (2, 1));
    }

    #[test]
    fn test_sum_distinct() {
        let result = run_simple(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::sum(1).with_distinct()],
            vec![LogicalType::Int64],
        );

        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
    }

    #[test]
    fn test_avg_distinct() {
        let result = run_simple(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::avg(1).with_distinct()],
            vec![LogicalType::Float64],
        );

        assert_close(float_at(&result, 0), 20.0, 0.001);
    }

    // ---- statistics -----------------------------------------------------------

    #[test]
    fn test_stdev_sample() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );

        assert_close(float_at(&result, 0), 2.138, 0.01);
    }

    #[test]
    fn test_stdev_population() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );

        assert_close(float_at(&result, 0), 2.0, 0.01);
    }

    #[test]
    fn test_percentile_disc() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_disc(0, 0.5)],
            vec![LogicalType::Float64],
        );

        assert_close(float_at(&result, 0), 4.0, 0.01);
    }

    #[test]
    fn test_percentile_cont() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_cont(0, 0.5)],
            vec![LogicalType::Float64],
        );

        assert_close(float_at(&result, 0), 4.5, 0.01);
    }

    #[test]
    fn test_percentile_extremes() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![
                AggregateExpr::percentile_disc(0, 0.0),
                AggregateExpr::percentile_disc(0, 1.0),
            ],
            vec![LogicalType::Float64, LogicalType::Float64],
        );

        assert_close(float_at(&result, 0), 2.0, 0.01);
        assert_close(float_at(&result, 1), 9.0, 0.01);
    }

    #[test]
    fn test_stdev_single_value() {
        let result = run_simple(
            int64_chunk(&[[42i64]]),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );

        // A single sample has no sample standard deviation.
        assert!(matches!(
            result.column(0).unwrap().get_value(0),
            Some(Value::Null)
        ));
    }

    #[test]
    fn test_stdev_pop_single_value() {
        let result = run_simple(
            int64_chunk(&[[42i64]]),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );

        assert_close(float_at(&result, 0), 0.0, 0.01);
    }
}