1use indexmap::IndexMap;
8use std::collections::HashSet;
9
10use grafeo_common::types::{LogicalType, Value};
11
/// Hashable stand-in for a [`Value`], used as the element type of the
/// `HashSet`s that back the DISTINCT aggregate states.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum HashableValue {
    /// Stands in for `Value::Null`.
    Null,
    /// Stands in for `Value::Bool`.
    Bool(bool),
    /// Stands in for `Value::Int64`.
    Int64(i64),
    /// Stands in for `Value::Float64`, stored as its `f64::to_bits` pattern
    /// because `f64` itself implements neither `Eq` nor `Hash`.
    Float64Bits(u64),
    /// Stands in for `Value::String`, copied into an owned `String`.
    String(String),
    /// Any other `Value` variant, keyed by its `Debug` rendering.
    Other(String),
}
22
23impl From<&Value> for HashableValue {
24 fn from(v: &Value) -> Self {
25 match v {
26 Value::Null => HashableValue::Null,
27 Value::Bool(b) => HashableValue::Bool(*b),
28 Value::Int64(i) => HashableValue::Int64(*i),
29 Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
30 Value::String(s) => HashableValue::String(s.to_string()),
31 other => HashableValue::Other(format!("{other:?}")),
32 }
33 }
34}
35
36impl From<Value> for HashableValue {
37 fn from(v: Value) -> Self {
38 Self::from(&v)
39 }
40}
41
42use super::{Operator, OperatorError, OperatorResult};
43use crate::execution::DataChunk;
44use crate::execution::chunk::DataChunkBuilder;
45
/// The aggregate functions the operators in this module can evaluate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Row count; see [`AggregateExpr::count_star`].
    Count,
    /// Count of the non-null values of a column; see [`AggregateExpr::count`].
    CountNonNull,
    /// Sum of a column.
    Sum,
    /// Arithmetic mean of a column.
    Avg,
    /// Smallest value of a column.
    Min,
    /// Largest value of a column.
    Max,
    /// First value of a column.
    First,
    /// Last value of a column.
    Last,
    /// All values of a column gathered into a list.
    Collect,
    /// Sample standard deviation.
    StdDev,
    /// Population standard deviation.
    StdDevPop,
    /// Discrete percentile (returns one of the input values).
    PercentileDisc,
    /// Continuous percentile (interpolates between input values).
    PercentileCont,
}

/// One aggregate to compute: the function, its input column and modifiers.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
    /// Which aggregate function to apply.
    pub function: AggregateFunction,
    /// Index of the input column; `None` for `count_star`.
    pub column: Option<usize>,
    /// Whether duplicate input values are folded only once.
    pub distinct: bool,
    /// Optional output name.
    pub alias: Option<String>,
    /// Requested percentile; only set by the percentile constructors.
    pub percentile: Option<f64>,
}

impl AggregateExpr {
    /// Common starting point of every constructor: not distinct, no alias.
    fn plain(function: AggregateFunction, column: Option<usize>, percentile: Option<f64>) -> Self {
        Self {
            function,
            column,
            distinct: false,
            alias: None,
            percentile,
        }
    }

    /// A column-bound aggregate without a percentile argument.
    fn over(function: AggregateFunction, column: usize) -> Self {
        Self::plain(function, Some(column), None)
    }

    /// A column-bound percentile aggregate; `percentile` is clamped to `[0.0, 1.0]`.
    fn quantile(function: AggregateFunction, column: usize, percentile: f64) -> Self {
        Self::plain(function, Some(column), Some(percentile.clamp(0.0, 1.0)))
    }

    /// `Count` with no input column.
    pub fn count_star() -> Self {
        Self::plain(AggregateFunction::Count, None, None)
    }

    /// `CountNonNull` over `column`.
    pub fn count(column: usize) -> Self {
        Self::over(AggregateFunction::CountNonNull, column)
    }

    /// `Sum` over `column`.
    pub fn sum(column: usize) -> Self {
        Self::over(AggregateFunction::Sum, column)
    }

    /// `Avg` over `column`.
    pub fn avg(column: usize) -> Self {
        Self::over(AggregateFunction::Avg, column)
    }

    /// `Min` over `column`.
    pub fn min(column: usize) -> Self {
        Self::over(AggregateFunction::Min, column)
    }

    /// `Max` over `column`.
    pub fn max(column: usize) -> Self {
        Self::over(AggregateFunction::Max, column)
    }

    /// `First` over `column`.
    pub fn first(column: usize) -> Self {
        Self::over(AggregateFunction::First, column)
    }

    /// `Last` over `column`.
    pub fn last(column: usize) -> Self {
        Self::over(AggregateFunction::Last, column)
    }

    /// `Collect` over `column`.
    pub fn collect(column: usize) -> Self {
        Self::over(AggregateFunction::Collect, column)
    }

    /// `StdDev` (sample standard deviation) over `column`.
    pub fn stdev(column: usize) -> Self {
        Self::over(AggregateFunction::StdDev, column)
    }

    /// `StdDevPop` (population standard deviation) over `column`.
    pub fn stdev_pop(column: usize) -> Self {
        Self::over(AggregateFunction::StdDevPop, column)
    }

    /// `PercentileDisc` over `column`; `percentile` is clamped to `[0.0, 1.0]`.
    pub fn percentile_disc(column: usize, percentile: f64) -> Self {
        Self::quantile(AggregateFunction::PercentileDisc, column, percentile)
    }

    /// `PercentileCont` over `column`; `percentile` is clamped to `[0.0, 1.0]`.
    pub fn percentile_cont(column: usize, percentile: f64) -> Self {
        Self::quantile(AggregateFunction::PercentileCont, column, percentile)
    }

    /// Returns the expression with the DISTINCT flag set.
    pub fn with_distinct(self) -> Self {
        Self {
            distinct: true,
            ..self
        }
    }

    /// Returns the expression with its output alias set to `alias`.
    pub fn with_alias(self, alias: impl Into<String>) -> Self {
        Self {
            alias: Some(alias.into()),
            ..self
        }
    }
}
256
/// Running accumulator for one aggregate expression within one group.
#[derive(Debug, Clone)]
enum AggregateState {
    /// Number of `update` calls received.
    Count(i64),
    /// Number of distinct values received, plus the set of values already seen.
    CountDistinct(i64, HashSet<HashableValue>),
    /// Integer sum; replaced by `SumFloat` once a non-integer numeric arrives.
    SumInt(i64),
    /// Integer sum over distinct values, plus the set of values already seen.
    SumIntDistinct(i64, HashSet<HashableValue>),
    /// Floating-point sum.
    SumFloat(f64),
    /// Floating-point sum over distinct values, plus the set of values already seen.
    SumFloatDistinct(f64, HashSet<HashableValue>),
    /// Running `(sum, count)` for the mean.
    Avg(f64, i64),
    /// Running `(sum, count)` over distinct values, plus the set of values already seen.
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// Smallest value so far, `None` until the first value arrives.
    Min(Option<Value>),
    /// Largest value so far, `None` until the first value arrives.
    Max(Option<Value>),
    /// First value received.
    First(Option<Value>),
    /// Most recent value received.
    Last(Option<Value>),
    /// Every value received, in arrival order.
    Collect(Vec<Value>),
    /// Distinct values in first-arrival order, plus the set of values already seen.
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Incrementally maintained count, mean and sum of squared deviations (`m2`)
    /// for the sample standard deviation.
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Same accumulator as `StdDev`, finalized as the population standard deviation.
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Buffered numeric samples and the requested percentile (discrete).
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Buffered numeric samples and the requested percentile (continuous).
    PercentileCont { values: Vec<f64>, percentile: f64 },
}
297
298impl AggregateState {
299 fn new(function: AggregateFunction, distinct: bool, percentile: Option<f64>) -> Self {
301 match (function, distinct) {
302 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
303 AggregateState::Count(0)
304 }
305 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
306 AggregateState::CountDistinct(0, HashSet::new())
307 }
308 (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
309 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
310 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
311 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
312 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
314 (AggregateFunction::First, _) => AggregateState::First(None),
315 (AggregateFunction::Last, _) => AggregateState::Last(None),
316 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
317 (AggregateFunction::Collect, true) => {
318 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
319 }
320 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
322 count: 0,
323 mean: 0.0,
324 m2: 0.0,
325 },
326 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
327 count: 0,
328 mean: 0.0,
329 m2: 0.0,
330 },
331 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
332 values: Vec::new(),
333 percentile: percentile.unwrap_or(0.5),
334 },
335 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
336 values: Vec::new(),
337 percentile: percentile.unwrap_or(0.5),
338 },
339 }
340 }
341
342 fn update(&mut self, value: Option<Value>) {
344 match self {
345 AggregateState::Count(count) => {
346 *count += 1;
347 }
348 AggregateState::CountDistinct(count, seen) => {
349 if let Some(ref v) = value {
350 let hashable = HashableValue::from(v);
351 if seen.insert(hashable) {
352 *count += 1;
353 }
354 }
355 }
356 AggregateState::SumInt(sum) => {
357 if let Some(Value::Int64(v)) = value {
358 *sum += v;
359 } else if let Some(Value::Float64(v)) = value {
360 *self = AggregateState::SumFloat(*sum as f64 + v);
362 } else if let Some(ref v) = value {
363 if let Some(num) = value_to_f64(v) {
365 *self = AggregateState::SumFloat(*sum as f64 + num);
366 }
367 }
368 }
369 AggregateState::SumIntDistinct(sum, seen) => {
370 if let Some(ref v) = value {
371 let hashable = HashableValue::from(v);
372 if seen.insert(hashable) {
373 if let Value::Int64(i) = v {
374 *sum += i;
375 } else if let Value::Float64(f) = v {
376 let moved_seen = std::mem::take(seen);
378 *self = AggregateState::SumFloatDistinct(*sum as f64 + f, moved_seen);
379 } else if let Some(num) = value_to_f64(v) {
380 let moved_seen = std::mem::take(seen);
382 *self = AggregateState::SumFloatDistinct(*sum as f64 + num, moved_seen);
383 }
384 }
385 }
386 }
387 AggregateState::SumFloat(sum) => {
388 if let Some(ref v) = value {
389 if let Some(num) = value_to_f64(v) {
391 *sum += num;
392 }
393 }
394 }
395 AggregateState::SumFloatDistinct(sum, seen) => {
396 if let Some(ref v) = value {
397 let hashable = HashableValue::from(v);
398 if seen.insert(hashable)
399 && let Some(num) = value_to_f64(v)
400 {
401 *sum += num;
402 }
403 }
404 }
405 AggregateState::Avg(sum, count) => {
406 if let Some(ref v) = value
407 && let Some(num) = value_to_f64(v)
408 {
409 *sum += num;
410 *count += 1;
411 }
412 }
413 AggregateState::AvgDistinct(sum, count, seen) => {
414 if let Some(ref v) = value {
415 let hashable = HashableValue::from(v);
416 if seen.insert(hashable)
417 && let Some(num) = value_to_f64(v)
418 {
419 *sum += num;
420 *count += 1;
421 }
422 }
423 }
424 AggregateState::Min(min) => {
425 if let Some(v) = value {
426 match min {
427 None => *min = Some(v),
428 Some(current) => {
429 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
430 *min = Some(v);
431 }
432 }
433 }
434 }
435 }
436 AggregateState::Max(max) => {
437 if let Some(v) = value {
438 match max {
439 None => *max = Some(v),
440 Some(current) => {
441 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
442 *max = Some(v);
443 }
444 }
445 }
446 }
447 }
448 AggregateState::First(first) => {
449 if first.is_none() {
450 *first = value;
451 }
452 }
453 AggregateState::Last(last) => {
454 if value.is_some() {
455 *last = value;
456 }
457 }
458 AggregateState::Collect(list) => {
459 if let Some(v) = value {
460 list.push(v);
461 }
462 }
463 AggregateState::CollectDistinct(list, seen) => {
464 if let Some(v) = value {
465 let hashable = HashableValue::from(&v);
466 if seen.insert(hashable) {
467 list.push(v);
468 }
469 }
470 }
471 AggregateState::StdDev { count, mean, m2 }
473 | AggregateState::StdDevPop { count, mean, m2 } => {
474 if let Some(ref v) = value
475 && let Some(x) = value_to_f64(v)
476 {
477 *count += 1;
478 let delta = x - *mean;
479 *mean += delta / *count as f64;
480 let delta2 = x - *mean;
481 *m2 += delta * delta2;
482 }
483 }
484 AggregateState::PercentileDisc { values, .. }
485 | AggregateState::PercentileCont { values, .. } => {
486 if let Some(ref v) = value
487 && let Some(x) = value_to_f64(v)
488 {
489 values.push(x);
490 }
491 }
492 }
493 }
494
495 fn finalize(&self) -> Value {
497 match self {
498 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
499 Value::Int64(*count)
500 }
501 AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
502 Value::Int64(*sum)
503 }
504 AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
505 Value::Float64(*sum)
506 }
507 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
508 if *count == 0 {
509 Value::Null
510 } else {
511 Value::Float64(*sum / *count as f64)
512 }
513 }
514 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
515 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
516 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
517 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
518 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
519 Value::List(list.clone().into())
520 }
521 AggregateState::StdDev { count, m2, .. } => {
523 if *count < 2 {
524 Value::Null
525 } else {
526 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
527 }
528 }
529 AggregateState::StdDevPop { count, m2, .. } => {
531 if *count == 0 {
532 Value::Null
533 } else {
534 Value::Float64((*m2 / *count as f64).sqrt())
535 }
536 }
537 AggregateState::PercentileDisc { values, percentile } => {
539 if values.is_empty() {
540 Value::Null
541 } else {
542 let mut sorted = values.clone();
543 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
544 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
546 Value::Float64(sorted[index])
547 }
548 }
549 AggregateState::PercentileCont { values, percentile } => {
551 if values.is_empty() {
552 Value::Null
553 } else {
554 let mut sorted = values.clone();
555 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
556 let rank = percentile * (sorted.len() - 1) as f64;
558 let lower_idx = rank.floor() as usize;
559 let upper_idx = rank.ceil() as usize;
560 if lower_idx == upper_idx {
561 Value::Float64(sorted[lower_idx])
562 } else {
563 let fraction = rank - lower_idx as f64;
564 let result =
565 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
566 Value::Float64(result)
567 }
568 }
569 }
570 }
571 }
572}
573
574use super::value_utils::{compare_values, value_to_f64};
575
576#[derive(Debug, Clone, PartialEq, Eq, Hash)]
578pub struct GroupKey(Vec<GroupKeyPart>);
579
580#[derive(Debug, Clone, PartialEq, Eq, Hash)]
581enum GroupKeyPart {
582 Null,
583 Bool(bool),
584 Int64(i64),
585 String(String),
586}
587
588impl GroupKey {
589 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
591 let parts: Vec<GroupKeyPart> = group_columns
592 .iter()
593 .map(|&col_idx| {
594 chunk
595 .column(col_idx)
596 .and_then(|col| col.get_value(row))
597 .map_or(GroupKeyPart::Null, |v| match v {
598 Value::Null => GroupKeyPart::Null,
599 Value::Bool(b) => GroupKeyPart::Bool(b),
600 Value::Int64(i) => GroupKeyPart::Int64(i),
601 Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
602 Value::String(s) => GroupKeyPart::String(s.to_string()),
603 _ => GroupKeyPart::String(format!("{v:?}")),
604 })
605 })
606 .collect();
607 GroupKey(parts)
608 }
609
610 fn to_values(&self) -> Vec<Value> {
612 self.0
613 .iter()
614 .map(|part| match part {
615 GroupKeyPart::Null => Value::Null,
616 GroupKeyPart::Bool(b) => Value::Bool(*b),
617 GroupKeyPart::Int64(i) => Value::Int64(*i),
618 GroupKeyPart::String(s) => Value::String(s.clone().into()),
619 })
620 .collect()
621 }
622}
623
/// Grouping aggregate operator: consumes its whole child, groups rows by the
/// values of `group_columns`, and emits one output row per group.
pub struct HashAggregateOperator {
    /// Input operator supplying the rows to aggregate.
    child: Box<dyn Operator>,
    /// Indices of the input columns that form the group key.
    group_columns: Vec<usize>,
    /// Aggregates computed for every group.
    aggregates: Vec<AggregateExpr>,
    /// Column types handed to `DataChunkBuilder` when building output chunks.
    output_schema: Vec<LogicalType>,
    /// Per-group accumulator states, filled while the child is consumed.
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Set once the child has been fully consumed.
    aggregation_complete: bool,
    /// Iterator over the drained groups; `None` until aggregation has run.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
643
644impl HashAggregateOperator {
645 pub fn new(
653 child: Box<dyn Operator>,
654 group_columns: Vec<usize>,
655 aggregates: Vec<AggregateExpr>,
656 output_schema: Vec<LogicalType>,
657 ) -> Self {
658 Self {
659 child,
660 group_columns,
661 aggregates,
662 output_schema,
663 groups: IndexMap::new(),
664 aggregation_complete: false,
665 results: None,
666 }
667 }
668
669 fn aggregate(&mut self) -> Result<(), OperatorError> {
671 while let Some(chunk) = self.child.next()? {
672 for row in chunk.selected_indices() {
673 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
674
675 let states = self.groups.entry(key).or_insert_with(|| {
677 self.aggregates
678 .iter()
679 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
680 .collect()
681 });
682
683 for (i, agg) in self.aggregates.iter().enumerate() {
685 let value = match (agg.function, agg.distinct) {
686 (AggregateFunction::Count, false) => None,
688 (AggregateFunction::Count, true) => agg
690 .column
691 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
692 _ => agg
693 .column
694 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
695 };
696
697 match (agg.function, agg.distinct) {
699 (AggregateFunction::Count, false) => states[i].update(None),
700 (AggregateFunction::Count, true) => {
701 if value.is_some() && !matches!(value, Some(Value::Null)) {
703 states[i].update(value);
704 }
705 }
706 (AggregateFunction::CountNonNull, _) => {
707 if value.is_some() && !matches!(value, Some(Value::Null)) {
708 states[i].update(value);
709 }
710 }
711 _ => {
712 if value.is_some() && !matches!(value, Some(Value::Null)) {
713 states[i].update(value);
714 }
715 }
716 }
717 }
718 }
719 }
720
721 self.aggregation_complete = true;
722
723 let results: Vec<_> = self.groups.drain(..).collect();
725 self.results = Some(results.into_iter());
726
727 Ok(())
728 }
729}
730
731impl Operator for HashAggregateOperator {
732 fn next(&mut self) -> OperatorResult {
733 if !self.aggregation_complete {
735 self.aggregate()?;
736 }
737
738 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
740 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
742
743 for agg in &self.aggregates {
744 let state = AggregateState::new(agg.function, agg.distinct, agg.percentile);
745 let value = state.finalize();
746 if let Some(col) = builder.column_mut(self.group_columns.len()) {
747 col.push_value(value);
748 }
749 }
750 builder.advance_row();
751
752 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
754 }
755
756 let Some(results) = &mut self.results else {
757 return Ok(None);
758 };
759
760 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
761
762 for (key, states) in results.by_ref() {
763 let key_values = key.to_values();
765 for (i, value) in key_values.into_iter().enumerate() {
766 if let Some(col) = builder.column_mut(i) {
767 col.push_value(value);
768 }
769 }
770
771 for (i, state) in states.iter().enumerate() {
773 let col_idx = self.group_columns.len() + i;
774 if let Some(col) = builder.column_mut(col_idx) {
775 col.push_value(state.finalize());
776 }
777 }
778
779 builder.advance_row();
780
781 if builder.is_full() {
782 return Ok(Some(builder.finish()));
783 }
784 }
785
786 if builder.row_count() > 0 {
787 Ok(Some(builder.finish()))
788 } else {
789 Ok(None)
790 }
791 }
792
793 fn reset(&mut self) {
794 self.child.reset();
795 self.groups.clear();
796 self.aggregation_complete = false;
797 self.results = None;
798 }
799
800 fn name(&self) -> &'static str {
801 "HashAggregate"
802 }
803}
804
/// Global (ungrouped) aggregate operator: consumes its whole child and emits
/// a single row with one value per aggregate.
pub struct SimpleAggregateOperator {
    /// Input operator supplying the rows to aggregate.
    child: Box<dyn Operator>,
    /// Aggregates to compute.
    aggregates: Vec<AggregateExpr>,
    /// Column types handed to `DataChunkBuilder` when building the output chunk.
    output_schema: Vec<LogicalType>,
    /// One accumulator per entry of `aggregates`, in the same order.
    states: Vec<AggregateState>,
    /// Set once the result row has been emitted.
    done: bool,
}
820
821impl SimpleAggregateOperator {
822 pub fn new(
824 child: Box<dyn Operator>,
825 aggregates: Vec<AggregateExpr>,
826 output_schema: Vec<LogicalType>,
827 ) -> Self {
828 let states = aggregates
829 .iter()
830 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
831 .collect();
832
833 Self {
834 child,
835 aggregates,
836 output_schema,
837 states,
838 done: false,
839 }
840 }
841}
842
843impl Operator for SimpleAggregateOperator {
844 fn next(&mut self) -> OperatorResult {
845 if self.done {
846 return Ok(None);
847 }
848
849 while let Some(chunk) = self.child.next()? {
851 for row in chunk.selected_indices() {
852 for (i, agg) in self.aggregates.iter().enumerate() {
853 let value = match (agg.function, agg.distinct) {
854 (AggregateFunction::Count, false) => None,
856 (AggregateFunction::Count, true) => agg
858 .column
859 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
860 _ => agg
861 .column
862 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
863 };
864
865 match (agg.function, agg.distinct) {
866 (AggregateFunction::Count, false) => self.states[i].update(None),
867 (AggregateFunction::Count, true) => {
868 if value.is_some() && !matches!(value, Some(Value::Null)) {
870 self.states[i].update(value);
871 }
872 }
873 (AggregateFunction::CountNonNull, _) => {
874 if value.is_some() && !matches!(value, Some(Value::Null)) {
875 self.states[i].update(value);
876 }
877 }
878 _ => {
879 if value.is_some() && !matches!(value, Some(Value::Null)) {
880 self.states[i].update(value);
881 }
882 }
883 }
884 }
885 }
886 }
887
888 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
890
891 for (i, state) in self.states.iter().enumerate() {
892 if let Some(col) = builder.column_mut(i) {
893 col.push_value(state.finalize());
894 }
895 }
896 builder.advance_row();
897
898 self.done = true;
899 Ok(Some(builder.finish()))
900 }
901
902 fn reset(&mut self) {
903 self.child.reset();
904 self.states = self
905 .aggregates
906 .iter()
907 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
908 .collect();
909 self.done = false;
910 }
911
912 fn name(&self) -> &'static str {
913 "SimpleAggregate"
914 }
915}
916
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Operator stub that hands out a fixed list of chunks, one per call.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            match self.chunks.get_mut(self.position) {
                Some(slot) => {
                    // Move the chunk out, leaving an empty one in its place.
                    let chunk = std::mem::replace(slot, DataChunk::empty());
                    self.position += 1;
                    Ok(Some(chunk))
                }
                None => Ok(None),
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    // ---- fixtures -------------------------------------------------------

    /// Builds a chunk with two Int64 columns from `(column 0, column 1)` rows.
    fn int_pair_chunk(rows: &[(i64, i64)]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
        for &(group, value) in rows {
            builder.column_mut(0).unwrap().push_int64(group);
            builder.column_mut(1).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Builds a chunk with a single Int64 column.
    fn int_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for &value in values {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Group 1 holds 10 and 20; group 2 holds 30, 40 and 50.
    fn create_test_chunk() -> DataChunk {
        int_pair_chunk(&[(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)])
    }

    /// Group 1 holds 10, 10 and 20; group 2 holds 30 three times.
    fn create_test_chunk_with_duplicates() -> DataChunk {
        int_pair_chunk(&[(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)])
    }

    /// Eight samples with mean 5.
    fn create_statistical_test_chunk() -> DataChunk {
        int_chunk(&[2, 4, 4, 4, 5, 5, 7, 9])
    }

    // ---- drivers --------------------------------------------------------

    /// Global aggregate over a single input chunk.
    fn simple_over(
        chunk: DataChunk,
        aggregates: Vec<AggregateExpr>,
        schema: Vec<LogicalType>,
    ) -> SimpleAggregateOperator {
        SimpleAggregateOperator::new(Box::new(MockOperator::new(vec![chunk])), aggregates, schema)
    }

    /// Aggregate grouped by column 0 over a single input chunk.
    fn grouped_over(
        chunk: DataChunk,
        aggregates: Vec<AggregateExpr>,
        schema: Vec<LogicalType>,
    ) -> HashAggregateOperator {
        HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![chunk])),
            vec![0],
            aggregates,
            schema,
        )
    }

    /// Pulls the result chunk and checks that it holds exactly one row.
    fn only_row(op: &mut SimpleAggregateOperator) -> DataChunk {
        let result = op.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        result
    }

    /// Row 0 of `column` as an integer.
    fn int_at(chunk: &DataChunk, column: usize) -> Option<i64> {
        chunk.column(column).unwrap().get_int64(0)
    }

    /// Row 0 of `column` as a float.
    fn float_at(chunk: &DataChunk, column: usize) -> f64 {
        chunk.column(column).unwrap().get_float64(0).unwrap()
    }

    /// Drains `op` into `(group, aggregate)` pairs sorted by group.
    fn sorted_pairs(op: &mut HashAggregateOperator) -> Vec<(i64, i64)> {
        let mut pairs = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
                let aggregate = chunk.column(1).unwrap().get_int64(row).unwrap();
                pairs.push((group, aggregate));
            }
        }
        pairs.sort_by_key(|(group, _)| *group);
        pairs
    }

    // ---- global aggregates ----------------------------------------------

    #[test]
    fn test_simple_count() {
        let mut agg = simple_over(
            create_test_chunk(),
            vec![AggregateExpr::count_star()],
            vec![LogicalType::Int64],
        );

        let result = only_row(&mut agg);
        assert_eq!(int_at(&result, 0), Some(5));

        // The single result row is emitted only once.
        assert!(agg.next().unwrap().is_none());
    }

    #[test]
    fn test_simple_sum() {
        let mut agg = simple_over(
            create_test_chunk(),
            vec![AggregateExpr::sum(1)],
            vec![LogicalType::Int64],
        );

        let result = only_row(&mut agg);
        // 10 + 20 + 30 + 40 + 50
        assert_eq!(int_at(&result, 0), Some(150));
    }

    #[test]
    fn test_simple_avg() {
        let mut agg = simple_over(
            create_test_chunk(),
            vec![AggregateExpr::avg(1)],
            vec![LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        // 150 / 5
        let avg = float_at(&result, 0);
        assert!((avg - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_simple_min_max() {
        let mut agg = simple_over(
            create_test_chunk(),
            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let result = only_row(&mut agg);
        assert_eq!(int_at(&result, 0), Some(10));
        assert_eq!(int_at(&result, 1), Some(50));
    }

    #[test]
    fn test_sum_with_string_values() {
        // Numeric strings are expected to be summed as floats.
        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
        for text in ["30", "25", "35"] {
            builder.column_mut(0).unwrap().push_string(text);
            builder.advance_row();
        }
        let chunk = builder.finish();

        let mut agg = simple_over(
            chunk,
            vec![AggregateExpr::sum(0)],
            vec![LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        let sum_val = float_at(&result, 0);
        assert!(
            (sum_val - 90.0).abs() < 0.001,
            "Expected 90.0, got {}",
            sum_val
        );
    }

    // ---- grouped aggregates ---------------------------------------------

    #[test]
    fn test_grouped_aggregation() {
        let mut agg = grouped_over(
            create_test_chunk(),
            vec![AggregateExpr::sum(1)],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let results = sorted_pairs(&mut agg);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], (1, 30));
        assert_eq!(results[1], (2, 120));
    }

    #[test]
    fn test_grouped_count() {
        let mut agg = grouped_over(
            create_test_chunk(),
            vec![AggregateExpr::count_star()],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let results = sorted_pairs(&mut agg);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], (1, 2));
        assert_eq!(results[1], (2, 3));
    }

    #[test]
    fn test_multiple_aggregates() {
        let mut agg = grouped_over(
            create_test_chunk(),
            vec![
                AggregateExpr::count_star(),
                AggregateExpr::sum(1),
                AggregateExpr::avg(1),
            ],
            vec![
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Float64,
            ],
        );

        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
        while let Some(chunk) = agg.next().unwrap() {
            for row in chunk.selected_indices() {
                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
                results.push((group, count, sum, avg));
            }
        }

        results.sort_by_key(|(group, _, _, _)| *group);
        assert_eq!(results.len(), 2);

        // Group 1: two rows, 10 + 20.
        assert_eq!(results[0].0, 1);
        assert_eq!(results[0].1, 2);
        assert_eq!(results[0].2, 30);
        assert!((results[0].3 - 15.0).abs() < 0.001);

        // Group 2: three rows, 30 + 40 + 50.
        assert_eq!(results[1].0, 2);
        assert_eq!(results[1].1, 3);
        assert_eq!(results[1].2, 120);
        assert!((results[1].3 - 40.0).abs() < 0.001);
    }

    // ---- DISTINCT ---------------------------------------------------------

    #[test]
    fn test_count_distinct() {
        let mut agg = simple_over(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::count(1).with_distinct()],
            vec![LogicalType::Int64],
        );

        let result = only_row(&mut agg);
        // Distinct values: 10, 20, 30.
        assert_eq!(int_at(&result, 0), Some(3));
    }

    #[test]
    fn test_grouped_count_distinct() {
        let mut agg = grouped_over(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::count(1).with_distinct()],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let results = sorted_pairs(&mut agg);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], (1, 2));
        assert_eq!(results[1], (2, 1));
    }

    #[test]
    fn test_sum_distinct() {
        let mut agg = simple_over(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::sum(1).with_distinct()],
            vec![LogicalType::Int64],
        );

        let result = only_row(&mut agg);
        // 10 + 20 + 30
        assert_eq!(int_at(&result, 0), Some(60));
    }

    #[test]
    fn test_avg_distinct() {
        let mut agg = simple_over(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::avg(1).with_distinct()],
            vec![LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        // (10 + 20 + 30) / 3
        let avg = float_at(&result, 0);
        assert!((avg - 20.0).abs() < 0.001);
    }

    // ---- statistics -------------------------------------------------------

    #[test]
    fn test_stdev_sample() {
        let mut agg = simple_over(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        // Squared deviations sum to 32; sqrt(32 / 7) is about 2.138.
        let stdev = float_at(&result, 0);
        assert!((stdev - 2.138).abs() < 0.01);
    }

    #[test]
    fn test_stdev_population() {
        let mut agg = simple_over(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        // sqrt(32 / 8) = 2.0
        let stdev = float_at(&result, 0);
        assert!((stdev - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_percentile_disc() {
        let mut agg = simple_over(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_disc(0, 0.5)],
            vec![LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        // floor(0.5 * 7) = 3, and the fourth sorted sample is 4.
        let percentile = float_at(&result, 0);
        assert!((percentile - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_percentile_cont() {
        let mut agg = simple_over(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_cont(0, 0.5)],
            vec![LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        // Rank 3.5 lies halfway between the samples 4 and 5.
        let percentile = float_at(&result, 0);
        assert!((percentile - 4.5).abs() < 0.01);
    }

    #[test]
    fn test_percentile_extremes() {
        let mut agg = simple_over(
            create_statistical_test_chunk(),
            vec![
                AggregateExpr::percentile_disc(0, 0.0),
                AggregateExpr::percentile_disc(0, 1.0),
            ],
            vec![LogicalType::Float64, LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        // Percentile 0 is the minimum, percentile 1 the maximum.
        let p0 = float_at(&result, 0);
        assert!((p0 - 2.0).abs() < 0.01);
        let p100 = float_at(&result, 1);
        assert!((p100 - 9.0).abs() < 0.01);
    }

    #[test]
    fn test_stdev_single_value() {
        let mut agg = simple_over(
            int_chunk(&[42]),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        // Sample stddev is undefined for a single value.
        assert!(matches!(
            result.column(0).unwrap().get_value(0),
            Some(Value::Null)
        ));
    }

    #[test]
    fn test_stdev_pop_single_value() {
        let mut agg = simple_over(
            int_chunk(&[42]),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );

        let result = only_row(&mut agg);
        // Population stddev of a single value is zero.
        let stdev = float_at(&result, 0);
        assert!((stdev - 0.0).abs() < 0.01);
    }
}