1use indexmap::IndexMap;
8use std::collections::HashSet;
9
10use grafeo_common::types::{LogicalType, Value};
11
/// Hashable identity of a `Value`, used as the element type of the "already seen" sets of
/// the DISTINCT aggregates.
///
/// Floats are kept as their raw IEEE-754 bit pattern so the type can derive `Eq` and `Hash`;
/// values without a dedicated variant are identified by their `Debug` text.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum HashableValue {
    /// The null value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A 64-bit float, stored as `f64::to_bits`.
    Float64Bits(u64),
    /// String contents.
    String(String),
    /// Any other value, identified by its `Debug` rendering.
    Other(String),
}
22
23impl From<&Value> for HashableValue {
24 fn from(v: &Value) -> Self {
25 match v {
26 Value::Null => HashableValue::Null,
27 Value::Bool(b) => HashableValue::Bool(*b),
28 Value::Int64(i) => HashableValue::Int64(*i),
29 Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
30 Value::String(s) => HashableValue::String(s.to_string()),
31 other => HashableValue::Other(format!("{other:?}")),
32 }
33 }
34}
35
36impl From<Value> for HashableValue {
37 fn from(v: Value) -> Self {
38 Self::from(&v)
39 }
40}
41
42use super::{Operator, OperatorError, OperatorResult};
43use crate::execution::DataChunk;
44use crate::execution::chunk::DataChunkBuilder;
45
/// Aggregate functions supported by the aggregate operators in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Row count. Without DISTINCT it counts every input row and reads no column
    /// (`count_star`). With DISTINCT it counts the distinct non-null values of its column.
    Count,
    /// Number of non-null values of the input column.
    CountNonNull,
    /// Sum of numeric inputs. The result stays integral until a non-`Int64` numeric input
    /// arrives.
    Sum,
    /// Arithmetic mean of numeric inputs (null when there are none).
    Avg,
    /// Smallest input value.
    Min,
    /// Largest input value.
    Max,
    /// First non-null input value.
    First,
    /// Last non-null input value.
    Last,
    /// All non-null input values, as a list.
    Collect,
    /// Sample standard deviation (divides by `n - 1`).
    StdDev,
    /// Population standard deviation (divides by `n`).
    StdDevPop,
    /// Discrete percentile: an actual input value at the requested rank.
    PercentileDisc,
    /// Continuous percentile: linear interpolation between neighbouring ranks.
    PercentileCont,
}
76
77#[derive(Debug, Clone)]
79pub struct AggregateExpr {
80 pub function: AggregateFunction,
82 pub column: Option<usize>,
84 pub distinct: bool,
86 pub alias: Option<String>,
88 pub percentile: Option<f64>,
90}
91
92impl AggregateExpr {
93 pub fn count_star() -> Self {
95 Self {
96 function: AggregateFunction::Count,
97 column: None,
98 distinct: false,
99 alias: None,
100 percentile: None,
101 }
102 }
103
104 pub fn count(column: usize) -> Self {
106 Self {
107 function: AggregateFunction::CountNonNull,
108 column: Some(column),
109 distinct: false,
110 alias: None,
111 percentile: None,
112 }
113 }
114
115 pub fn sum(column: usize) -> Self {
117 Self {
118 function: AggregateFunction::Sum,
119 column: Some(column),
120 distinct: false,
121 alias: None,
122 percentile: None,
123 }
124 }
125
126 pub fn avg(column: usize) -> Self {
128 Self {
129 function: AggregateFunction::Avg,
130 column: Some(column),
131 distinct: false,
132 alias: None,
133 percentile: None,
134 }
135 }
136
137 pub fn min(column: usize) -> Self {
139 Self {
140 function: AggregateFunction::Min,
141 column: Some(column),
142 distinct: false,
143 alias: None,
144 percentile: None,
145 }
146 }
147
148 pub fn max(column: usize) -> Self {
150 Self {
151 function: AggregateFunction::Max,
152 column: Some(column),
153 distinct: false,
154 alias: None,
155 percentile: None,
156 }
157 }
158
159 pub fn first(column: usize) -> Self {
161 Self {
162 function: AggregateFunction::First,
163 column: Some(column),
164 distinct: false,
165 alias: None,
166 percentile: None,
167 }
168 }
169
170 pub fn last(column: usize) -> Self {
172 Self {
173 function: AggregateFunction::Last,
174 column: Some(column),
175 distinct: false,
176 alias: None,
177 percentile: None,
178 }
179 }
180
181 pub fn collect(column: usize) -> Self {
183 Self {
184 function: AggregateFunction::Collect,
185 column: Some(column),
186 distinct: false,
187 alias: None,
188 percentile: None,
189 }
190 }
191
192 pub fn stdev(column: usize) -> Self {
194 Self {
195 function: AggregateFunction::StdDev,
196 column: Some(column),
197 distinct: false,
198 alias: None,
199 percentile: None,
200 }
201 }
202
203 pub fn stdev_pop(column: usize) -> Self {
205 Self {
206 function: AggregateFunction::StdDevPop,
207 column: Some(column),
208 distinct: false,
209 alias: None,
210 percentile: None,
211 }
212 }
213
214 pub fn percentile_disc(column: usize, percentile: f64) -> Self {
220 Self {
221 function: AggregateFunction::PercentileDisc,
222 column: Some(column),
223 distinct: false,
224 alias: None,
225 percentile: Some(percentile.clamp(0.0, 1.0)),
226 }
227 }
228
229 pub fn percentile_cont(column: usize, percentile: f64) -> Self {
235 Self {
236 function: AggregateFunction::PercentileCont,
237 column: Some(column),
238 distinct: false,
239 alias: None,
240 percentile: Some(percentile.clamp(0.0, 1.0)),
241 }
242 }
243
244 pub fn with_distinct(mut self) -> Self {
246 self.distinct = true;
247 self
248 }
249
250 pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
252 self.alias = Some(alias.into());
253 self
254 }
255}
256
257#[derive(Debug, Clone)]
259enum AggregateState {
260 Count(i64),
262 CountDistinct(i64, HashSet<HashableValue>),
264 SumInt(i64),
266 SumIntDistinct(i64, HashSet<HashableValue>),
268 SumFloat(f64),
270 SumFloatDistinct(f64, HashSet<HashableValue>),
272 Avg(f64, i64),
274 AvgDistinct(f64, i64, HashSet<HashableValue>),
276 Min(Option<Value>),
278 Max(Option<Value>),
280 First(Option<Value>),
282 Last(Option<Value>),
284 Collect(Vec<Value>),
286 CollectDistinct(Vec<Value>, HashSet<HashableValue>),
288 StdDev { count: i64, mean: f64, m2: f64 },
290 StdDevPop { count: i64, mean: f64, m2: f64 },
292 PercentileDisc { values: Vec<f64>, percentile: f64 },
294 PercentileCont { values: Vec<f64>, percentile: f64 },
296}
297
298impl AggregateState {
299 fn new(function: AggregateFunction, distinct: bool, percentile: Option<f64>) -> Self {
301 match (function, distinct) {
302 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
303 AggregateState::Count(0)
304 }
305 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
306 AggregateState::CountDistinct(0, HashSet::new())
307 }
308 (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
309 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
310 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
311 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
312 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
314 (AggregateFunction::First, _) => AggregateState::First(None),
315 (AggregateFunction::Last, _) => AggregateState::Last(None),
316 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
317 (AggregateFunction::Collect, true) => {
318 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
319 }
320 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
322 count: 0,
323 mean: 0.0,
324 m2: 0.0,
325 },
326 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
327 count: 0,
328 mean: 0.0,
329 m2: 0.0,
330 },
331 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
332 values: Vec::new(),
333 percentile: percentile.unwrap_or(0.5),
334 },
335 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
336 values: Vec::new(),
337 percentile: percentile.unwrap_or(0.5),
338 },
339 }
340 }
341
342 fn update(&mut self, value: Option<Value>) {
344 match self {
345 AggregateState::Count(count) => {
346 *count += 1;
347 }
348 AggregateState::CountDistinct(count, seen) => {
349 if let Some(ref v) = value {
350 let hashable = HashableValue::from(v);
351 if seen.insert(hashable) {
352 *count += 1;
353 }
354 }
355 }
356 AggregateState::SumInt(sum) => {
357 if let Some(Value::Int64(v)) = value {
358 *sum += v;
359 } else if let Some(Value::Float64(v)) = value {
360 *self = AggregateState::SumFloat(*sum as f64 + v);
362 } else if let Some(ref v) = value {
363 if let Some(num) = value_to_f64(v) {
365 *self = AggregateState::SumFloat(*sum as f64 + num);
366 }
367 }
368 }
369 AggregateState::SumIntDistinct(sum, seen) => {
370 if let Some(ref v) = value {
371 let hashable = HashableValue::from(v);
372 if seen.insert(hashable) {
373 if let Value::Int64(i) = v {
374 *sum += i;
375 } else if let Value::Float64(f) = v {
376 let seen_clone = seen.clone();
378 *self = AggregateState::SumFloatDistinct(*sum as f64 + f, seen_clone);
379 } else if let Some(num) = value_to_f64(v) {
380 let seen_clone = seen.clone();
382 *self = AggregateState::SumFloatDistinct(*sum as f64 + num, seen_clone);
383 }
384 }
385 }
386 }
387 AggregateState::SumFloat(sum) => {
388 if let Some(ref v) = value {
389 if let Some(num) = value_to_f64(v) {
391 *sum += num;
392 }
393 }
394 }
395 AggregateState::SumFloatDistinct(sum, seen) => {
396 if let Some(ref v) = value {
397 let hashable = HashableValue::from(v);
398 if seen.insert(hashable) {
399 if let Some(num) = value_to_f64(v) {
400 *sum += num;
401 }
402 }
403 }
404 }
405 AggregateState::Avg(sum, count) => {
406 if let Some(ref v) = value {
407 if let Some(num) = value_to_f64(v) {
408 *sum += num;
409 *count += 1;
410 }
411 }
412 }
413 AggregateState::AvgDistinct(sum, count, seen) => {
414 if let Some(ref v) = value {
415 let hashable = HashableValue::from(v);
416 if seen.insert(hashable) {
417 if let Some(num) = value_to_f64(v) {
418 *sum += num;
419 *count += 1;
420 }
421 }
422 }
423 }
424 AggregateState::Min(min) => {
425 if let Some(v) = value {
426 match min {
427 None => *min = Some(v),
428 Some(current) => {
429 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
430 *min = Some(v);
431 }
432 }
433 }
434 }
435 }
436 AggregateState::Max(max) => {
437 if let Some(v) = value {
438 match max {
439 None => *max = Some(v),
440 Some(current) => {
441 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
442 *max = Some(v);
443 }
444 }
445 }
446 }
447 }
448 AggregateState::First(first) => {
449 if first.is_none() {
450 *first = value;
451 }
452 }
453 AggregateState::Last(last) => {
454 if value.is_some() {
455 *last = value;
456 }
457 }
458 AggregateState::Collect(list) => {
459 if let Some(v) = value {
460 list.push(v);
461 }
462 }
463 AggregateState::CollectDistinct(list, seen) => {
464 if let Some(v) = value {
465 let hashable = HashableValue::from(&v);
466 if seen.insert(hashable) {
467 list.push(v);
468 }
469 }
470 }
471 AggregateState::StdDev { count, mean, m2 }
473 | AggregateState::StdDevPop { count, mean, m2 } => {
474 if let Some(ref v) = value {
475 if let Some(x) = value_to_f64(v) {
476 *count += 1;
477 let delta = x - *mean;
478 *mean += delta / *count as f64;
479 let delta2 = x - *mean;
480 *m2 += delta * delta2;
481 }
482 }
483 }
484 AggregateState::PercentileDisc { values, .. }
485 | AggregateState::PercentileCont { values, .. } => {
486 if let Some(ref v) = value {
487 if let Some(x) = value_to_f64(v) {
488 values.push(x);
489 }
490 }
491 }
492 }
493 }
494
495 fn finalize(&self) -> Value {
497 match self {
498 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
499 Value::Int64(*count)
500 }
501 AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
502 Value::Int64(*sum)
503 }
504 AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
505 Value::Float64(*sum)
506 }
507 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
508 if *count == 0 {
509 Value::Null
510 } else {
511 Value::Float64(*sum / *count as f64)
512 }
513 }
514 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
515 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
516 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
517 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
518 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
519 Value::List(list.clone().into())
520 }
521 AggregateState::StdDev { count, m2, .. } => {
523 if *count < 2 {
524 Value::Null
525 } else {
526 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
527 }
528 }
529 AggregateState::StdDevPop { count, m2, .. } => {
531 if *count == 0 {
532 Value::Null
533 } else {
534 Value::Float64((*m2 / *count as f64).sqrt())
535 }
536 }
537 AggregateState::PercentileDisc { values, percentile } => {
539 if values.is_empty() {
540 Value::Null
541 } else {
542 let mut sorted = values.clone();
543 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
544 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
546 Value::Float64(sorted[index])
547 }
548 }
549 AggregateState::PercentileCont { values, percentile } => {
551 if values.is_empty() {
552 Value::Null
553 } else {
554 let mut sorted = values.clone();
555 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
556 let rank = percentile * (sorted.len() - 1) as f64;
558 let lower_idx = rank.floor() as usize;
559 let upper_idx = rank.ceil() as usize;
560 if lower_idx == upper_idx {
561 Value::Float64(sorted[lower_idx])
562 } else {
563 let fraction = rank - lower_idx as f64;
564 let result =
565 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
566 Value::Float64(result)
567 }
568 }
569 }
570 }
571 }
572}
573
574fn value_to_f64(value: &Value) -> Option<f64> {
577 match value {
578 Value::Int64(i) => Some(*i as f64),
579 Value::Float64(f) => Some(*f),
580 Value::String(s) => s.parse::<f64>().ok(),
582 _ => None,
583 }
584}
585
586fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
589 match (a, b) {
590 (Value::Int64(a), Value::Int64(b)) => Some(a.cmp(b)),
591 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b),
592 (Value::String(a), Value::String(b)) => {
593 if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
595 a_num.partial_cmp(&b_num)
596 } else {
597 Some(a.cmp(b))
598 }
599 }
600 (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
601 (Value::Int64(a), Value::Float64(b)) => (*a as f64).partial_cmp(b),
602 (Value::Float64(a), Value::Int64(b)) => a.partial_cmp(&(*b as f64)),
603 (Value::String(s), Value::Int64(i)) => s.parse::<f64>().ok()?.partial_cmp(&(*i as f64)),
605 (Value::String(s), Value::Float64(f)) => s.parse::<f64>().ok()?.partial_cmp(f),
606 (Value::Int64(i), Value::String(s)) => (*i as f64).partial_cmp(&s.parse::<f64>().ok()?),
607 (Value::Float64(f), Value::String(s)) => f.partial_cmp(&s.parse::<f64>().ok()?),
608 _ => None,
609 }
610}
611
612#[derive(Debug, Clone, PartialEq, Eq, Hash)]
614pub struct GroupKey(Vec<GroupKeyPart>);
615
616#[derive(Debug, Clone, PartialEq, Eq, Hash)]
617enum GroupKeyPart {
618 Null,
619 Bool(bool),
620 Int64(i64),
621 String(String),
622}
623
624impl GroupKey {
625 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
627 let parts: Vec<GroupKeyPart> = group_columns
628 .iter()
629 .map(|&col_idx| {
630 chunk
631 .column(col_idx)
632 .and_then(|col| col.get_value(row))
633 .map(|v| match v {
634 Value::Null => GroupKeyPart::Null,
635 Value::Bool(b) => GroupKeyPart::Bool(b),
636 Value::Int64(i) => GroupKeyPart::Int64(i),
637 Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
638 Value::String(s) => GroupKeyPart::String(s.to_string()),
639 _ => GroupKeyPart::String(format!("{v:?}")),
640 })
641 .unwrap_or(GroupKeyPart::Null)
642 })
643 .collect();
644 GroupKey(parts)
645 }
646
647 fn to_values(&self) -> Vec<Value> {
649 self.0
650 .iter()
651 .map(|part| match part {
652 GroupKeyPart::Null => Value::Null,
653 GroupKeyPart::Bool(b) => Value::Bool(*b),
654 GroupKeyPart::Int64(i) => Value::Int64(*i),
655 GroupKeyPart::String(s) => Value::String(s.clone().into()),
656 })
657 .collect()
658 }
659}
660
661pub struct HashAggregateOperator {
665 child: Box<dyn Operator>,
667 group_columns: Vec<usize>,
669 aggregates: Vec<AggregateExpr>,
671 output_schema: Vec<LogicalType>,
673 groups: IndexMap<GroupKey, Vec<AggregateState>>,
675 aggregation_complete: bool,
677 results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
679}
680
681impl HashAggregateOperator {
682 pub fn new(
690 child: Box<dyn Operator>,
691 group_columns: Vec<usize>,
692 aggregates: Vec<AggregateExpr>,
693 output_schema: Vec<LogicalType>,
694 ) -> Self {
695 Self {
696 child,
697 group_columns,
698 aggregates,
699 output_schema,
700 groups: IndexMap::new(),
701 aggregation_complete: false,
702 results: None,
703 }
704 }
705
706 fn aggregate(&mut self) -> Result<(), OperatorError> {
708 while let Some(chunk) = self.child.next()? {
709 for row in chunk.selected_indices() {
710 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
711
712 let states = self.groups.entry(key).or_insert_with(|| {
714 self.aggregates
715 .iter()
716 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
717 .collect()
718 });
719
720 for (i, agg) in self.aggregates.iter().enumerate() {
722 let value = match (agg.function, agg.distinct) {
723 (AggregateFunction::Count, false) => None,
725 (AggregateFunction::Count, true) => agg
727 .column
728 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
729 _ => agg
730 .column
731 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
732 };
733
734 match (agg.function, agg.distinct) {
736 (AggregateFunction::Count, false) => states[i].update(None),
737 (AggregateFunction::Count, true) => {
738 if value.is_some() && !matches!(value, Some(Value::Null)) {
740 states[i].update(value);
741 }
742 }
743 (AggregateFunction::CountNonNull, _) => {
744 if value.is_some() && !matches!(value, Some(Value::Null)) {
745 states[i].update(value);
746 }
747 }
748 _ => {
749 if value.is_some() && !matches!(value, Some(Value::Null)) {
750 states[i].update(value);
751 }
752 }
753 }
754 }
755 }
756 }
757
758 self.aggregation_complete = true;
759
760 let results: Vec<_> = self.groups.drain(..).collect();
762 self.results = Some(results.into_iter());
763
764 Ok(())
765 }
766}
767
768impl Operator for HashAggregateOperator {
769 fn next(&mut self) -> OperatorResult {
770 if !self.aggregation_complete {
772 self.aggregate()?;
773 }
774
775 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
777 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
779
780 for agg in &self.aggregates {
781 let state = AggregateState::new(agg.function, agg.distinct, agg.percentile);
782 let value = state.finalize();
783 if let Some(col) = builder.column_mut(self.group_columns.len()) {
784 col.push_value(value);
785 }
786 }
787 builder.advance_row();
788
789 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
791 }
792
793 let results = match &mut self.results {
794 Some(r) => r,
795 None => return Ok(None),
796 };
797
798 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
799
800 for (key, states) in results.by_ref() {
801 let key_values = key.to_values();
803 for (i, value) in key_values.into_iter().enumerate() {
804 if let Some(col) = builder.column_mut(i) {
805 col.push_value(value);
806 }
807 }
808
809 for (i, state) in states.iter().enumerate() {
811 let col_idx = self.group_columns.len() + i;
812 if let Some(col) = builder.column_mut(col_idx) {
813 col.push_value(state.finalize());
814 }
815 }
816
817 builder.advance_row();
818
819 if builder.is_full() {
820 return Ok(Some(builder.finish()));
821 }
822 }
823
824 if builder.row_count() > 0 {
825 Ok(Some(builder.finish()))
826 } else {
827 Ok(None)
828 }
829 }
830
831 fn reset(&mut self) {
832 self.child.reset();
833 self.groups.clear();
834 self.aggregation_complete = false;
835 self.results = None;
836 }
837
838 fn name(&self) -> &'static str {
839 "HashAggregate"
840 }
841}
842
843pub struct SimpleAggregateOperator {
847 child: Box<dyn Operator>,
849 aggregates: Vec<AggregateExpr>,
851 output_schema: Vec<LogicalType>,
853 states: Vec<AggregateState>,
855 done: bool,
857}
858
859impl SimpleAggregateOperator {
860 pub fn new(
862 child: Box<dyn Operator>,
863 aggregates: Vec<AggregateExpr>,
864 output_schema: Vec<LogicalType>,
865 ) -> Self {
866 let states = aggregates
867 .iter()
868 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
869 .collect();
870
871 Self {
872 child,
873 aggregates,
874 output_schema,
875 states,
876 done: false,
877 }
878 }
879}
880
881impl Operator for SimpleAggregateOperator {
882 fn next(&mut self) -> OperatorResult {
883 if self.done {
884 return Ok(None);
885 }
886
887 while let Some(chunk) = self.child.next()? {
889 for row in chunk.selected_indices() {
890 for (i, agg) in self.aggregates.iter().enumerate() {
891 let value = match (agg.function, agg.distinct) {
892 (AggregateFunction::Count, false) => None,
894 (AggregateFunction::Count, true) => agg
896 .column
897 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
898 _ => agg
899 .column
900 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
901 };
902
903 match (agg.function, agg.distinct) {
904 (AggregateFunction::Count, false) => self.states[i].update(None),
905 (AggregateFunction::Count, true) => {
906 if value.is_some() && !matches!(value, Some(Value::Null)) {
908 self.states[i].update(value);
909 }
910 }
911 (AggregateFunction::CountNonNull, _) => {
912 if value.is_some() && !matches!(value, Some(Value::Null)) {
913 self.states[i].update(value);
914 }
915 }
916 _ => {
917 if value.is_some() && !matches!(value, Some(Value::Null)) {
918 self.states[i].update(value);
919 }
920 }
921 }
922 }
923 }
924 }
925
926 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
928
929 for (i, state) in self.states.iter().enumerate() {
930 if let Some(col) = builder.column_mut(i) {
931 col.push_value(state.finalize());
932 }
933 }
934 builder.advance_row();
935
936 self.done = true;
937 Ok(Some(builder.finish()))
938 }
939
940 fn reset(&mut self) {
941 self.child.reset();
942 self.states = self
943 .aggregates
944 .iter()
945 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
946 .collect();
947 self.done = false;
948 }
949
950 fn name(&self) -> &'static str {
951 "SimpleAggregate"
952 }
953}
954
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Child operator that hands out a fixed list of chunks, one per `next()` call.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            match self.chunks.get_mut(self.position) {
                Some(slot) => {
                    let chunk = std::mem::replace(slot, DataChunk::empty());
                    self.position += 1;
                    Ok(Some(chunk))
                }
                None => Ok(None),
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    /// Builds a one-column Int64 chunk holding `values` in order.
    fn int_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for &value in values {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Builds a two-column Int64 chunk of `(group, value)` rows.
    fn grouped_chunk(rows: &[(i64, i64)]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
        for &(group, value) in rows {
            builder.column_mut(0).unwrap().push_int64(group);
            builder.column_mut(1).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    fn create_test_chunk() -> DataChunk {
        grouped_chunk(&[(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)])
    }

    fn create_test_chunk_with_duplicates() -> DataChunk {
        grouped_chunk(&[(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)])
    }

    fn create_statistical_test_chunk() -> DataChunk {
        int_chunk(&[2, 4, 4, 4, 5, 5, 7, 9])
    }

    /// Runs an ungrouped aggregate over `chunk`, checks that it yields one row, and returns
    /// that chunk.
    fn simple_result(
        chunk: DataChunk,
        aggregates: Vec<AggregateExpr>,
        schema: Vec<LogicalType>,
    ) -> DataChunk {
        let mut agg = SimpleAggregateOperator::new(
            Box::new(MockOperator::new(vec![chunk])),
            aggregates,
            schema,
        );
        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        result
    }

    /// Drains a grouped aggregate (grouped by column 0) producing `(Int64 group, Int64
    /// aggregate)` rows, and returns them sorted by group.
    fn grouped_int_pairs(chunk: DataChunk, aggregate: AggregateExpr) -> Vec<(i64, i64)> {
        let mut agg = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![chunk])),
            vec![0],
            vec![aggregate],
            vec![LogicalType::Int64, LogicalType::Int64],
        );
        let mut pairs = Vec::new();
        while let Some(out) = agg.next().unwrap() {
            for row in out.selected_indices() {
                let group = out.column(0).unwrap().get_int64(row).unwrap();
                let value = out.column(1).unwrap().get_int64(row).unwrap();
                pairs.push((group, value));
            }
        }
        pairs.sort_by_key(|&(group, _)| group);
        pairs
    }

    #[test]
    fn test_simple_count() {
        let mut agg = SimpleAggregateOperator::new(
            Box::new(MockOperator::new(vec![create_test_chunk()])),
            vec![AggregateExpr::count_star()],
            vec![LogicalType::Int64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));

        // The single result row is produced exactly once.
        assert!(agg.next().unwrap().is_none());
    }

    #[test]
    fn test_simple_sum() {
        let result = simple_result(
            create_test_chunk(),
            vec![AggregateExpr::sum(1)],
            vec![LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
    }

    #[test]
    fn test_simple_avg() {
        let result = simple_result(
            create_test_chunk(),
            vec![AggregateExpr::avg(1)],
            vec![LogicalType::Float64],
        );
        let avg = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((avg - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_simple_min_max() {
        let result = simple_result(
            create_test_chunk(),
            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
            vec![LogicalType::Int64, LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
    }

    #[test]
    fn test_sum_with_string_values() {
        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
        for text in ["30", "25", "35"] {
            builder.column_mut(0).unwrap().push_string(text);
            builder.advance_row();
        }

        let result = simple_result(
            builder.finish(),
            vec![AggregateExpr::sum(0)],
            vec![LogicalType::Float64],
        );
        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
        assert!(
            (sum_val - 90.0).abs() < 0.001,
            "Expected 90.0, got {}",
            sum_val
        );
    }

    #[test]
    fn test_grouped_aggregation() {
        let results = grouped_int_pairs(create_test_chunk(), AggregateExpr::sum(1));
        assert_eq!(results, vec![(1, 30), (2, 120)]);
    }

    #[test]
    fn test_grouped_count() {
        let results = grouped_int_pairs(create_test_chunk(), AggregateExpr::count_star());
        assert_eq!(results, vec![(1, 2), (2, 3)]);
    }

    #[test]
    fn test_multiple_aggregates() {
        let mut agg = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![create_test_chunk()])),
            vec![0],
            vec![
                AggregateExpr::count_star(),
                AggregateExpr::sum(1),
                AggregateExpr::avg(1),
            ],
            vec![
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Float64,
            ],
        );

        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
        while let Some(out) = agg.next().unwrap() {
            for row in out.selected_indices() {
                results.push((
                    out.column(0).unwrap().get_int64(row).unwrap(),
                    out.column(1).unwrap().get_int64(row).unwrap(),
                    out.column(2).unwrap().get_int64(row).unwrap(),
                    out.column(3).unwrap().get_float64(row).unwrap(),
                ));
            }
        }
        results.sort_by_key(|&(group, _, _, _)| group);
        assert_eq!(results.len(), 2);

        let (group, count, sum, avg) = results[0];
        assert_eq!((group, count, sum), (1, 2, 30));
        assert!((avg - 15.0).abs() < 0.001);

        let (group, count, sum, avg) = results[1];
        assert_eq!((group, count, sum), (2, 3, 120));
        assert!((avg - 40.0).abs() < 0.001);
    }

    #[test]
    fn test_count_distinct() {
        let result = simple_result(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::count(1).with_distinct()],
            vec![LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
    }

    #[test]
    fn test_grouped_count_distinct() {
        let results = grouped_int_pairs(
            create_test_chunk_with_duplicates(),
            AggregateExpr::count(1).with_distinct(),
        );
        assert_eq!(results, vec![(1, 2), (2, 1)]);
    }

    #[test]
    fn test_sum_distinct() {
        let result = simple_result(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::sum(1).with_distinct()],
            vec![LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
    }

    #[test]
    fn test_avg_distinct() {
        let result = simple_result(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::avg(1).with_distinct()],
            vec![LogicalType::Float64],
        );
        let avg = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((avg - 20.0).abs() < 0.001);
    }

    #[test]
    fn test_stdev_sample() {
        let result = simple_result(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );
        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((stdev - 2.138).abs() < 0.01);
    }

    #[test]
    fn test_stdev_population() {
        let result = simple_result(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );
        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((stdev - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_percentile_disc() {
        let result = simple_result(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_disc(0, 0.5)],
            vec![LogicalType::Float64],
        );
        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((percentile - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_percentile_cont() {
        let result = simple_result(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_cont(0, 0.5)],
            vec![LogicalType::Float64],
        );
        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((percentile - 4.5).abs() < 0.01);
    }

    #[test]
    fn test_percentile_extremes() {
        let result = simple_result(
            create_statistical_test_chunk(),
            vec![
                AggregateExpr::percentile_disc(0, 0.0),
                AggregateExpr::percentile_disc(0, 1.0),
            ],
            vec![LogicalType::Float64, LogicalType::Float64],
        );
        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((p0 - 2.0).abs() < 0.01);
        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
        assert!((p100 - 9.0).abs() < 0.01);
    }

    #[test]
    fn test_stdev_single_value() {
        let result = simple_result(
            int_chunk(&[42]),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );
        // Sample standard deviation is undefined for a single value.
        assert!(matches!(
            result.column(0).unwrap().get_value(0),
            Some(Value::Null)
        ));
    }

    #[test]
    fn test_stdev_pop_single_value() {
        let result = simple_result(
            int_chunk(&[42]),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );
        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((stdev - 0.0).abs() < 0.01);
    }
}