1use std::collections::{HashMap, HashSet};
8
9use grafeo_common::types::{LogicalType, Value};
10
/// Hashable projection of a [`Value`], used by the `DISTINCT` aggregate
/// states to remember which inputs they have already folded in.
///
/// Every variant holds plain std data, so the type can derive `Hash` and `Eq`
/// and be stored in a `HashSet`.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
enum HashableValue {
    /// The NULL value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A 64-bit float, stored as a bit pattern produced by `f64::to_bits`.
    Float64Bits(u64),
    /// An owned copy of a string value.
    String(String),
    /// Fallback key: the `Debug` rendering of any other kind of value.
    Other(String),
}
21
22impl From<&Value> for HashableValue {
23 fn from(v: &Value) -> Self {
24 match v {
25 Value::Null => HashableValue::Null,
26 Value::Bool(b) => HashableValue::Bool(*b),
27 Value::Int64(i) => HashableValue::Int64(*i),
28 Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
29 Value::String(s) => HashableValue::String(s.to_string()),
30 other => HashableValue::Other(format!("{other:?}")),
31 }
32 }
33}
34
35impl From<Value> for HashableValue {
36 fn from(v: Value) -> Self {
37 Self::from(&v)
38 }
39}
40
41use super::{Operator, OperatorError, OperatorResult};
42use crate::execution::DataChunk;
43use crate::execution::chunk::DataChunkBuilder;
44
/// Aggregate functions supported by the aggregation operators in this module.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum AggregateFunction {
    /// `COUNT(*)`: counts every input row.
    Count,
    /// `COUNT(expr)`: counts rows whose column value is not NULL.
    CountNonNull,
    /// Sum of numeric inputs (integer until a float input is seen).
    Sum,
    /// Arithmetic mean of numeric inputs, as a float.
    Avg,
    /// Smallest input.
    Min,
    /// Largest input.
    Max,
    /// First non-NULL input in arrival order.
    First,
    /// Last non-NULL input in arrival order.
    Last,
    /// All non-NULL inputs gathered into a list.
    Collect,
}

/// One aggregate to compute: which function, over which column, and how.
#[derive(Clone, Debug)]
pub struct AggregateExpr {
    /// The aggregate function to apply.
    pub function: AggregateFunction,
    /// Input column index; `None` for `COUNT(*)`.
    pub column: Option<usize>,
    /// Whether duplicate inputs should be folded in only once.
    pub distinct: bool,
    /// Optional output name for the aggregate.
    pub alias: Option<String>,
}

impl AggregateExpr {
    /// Shared constructor: a non-distinct, unaliased aggregate.
    fn over(function: AggregateFunction, column: Option<usize>) -> Self {
        AggregateExpr {
            function,
            column,
            distinct: false,
            alias: None,
        }
    }

    /// `COUNT(*)`, counting every row.
    pub fn count_star() -> Self {
        Self::over(AggregateFunction::Count, None)
    }

    /// `COUNT(column)`, counting non-NULL values.
    pub fn count(column: usize) -> Self {
        Self::over(AggregateFunction::CountNonNull, Some(column))
    }

    /// `SUM(column)`.
    pub fn sum(column: usize) -> Self {
        Self::over(AggregateFunction::Sum, Some(column))
    }

    /// `AVG(column)`.
    pub fn avg(column: usize) -> Self {
        Self::over(AggregateFunction::Avg, Some(column))
    }

    /// `MIN(column)`.
    pub fn min(column: usize) -> Self {
        Self::over(AggregateFunction::Min, Some(column))
    }

    /// `MAX(column)`.
    pub fn max(column: usize) -> Self {
        Self::over(AggregateFunction::Max, Some(column))
    }

    /// First non-NULL value of `column`.
    pub fn first(column: usize) -> Self {
        Self::over(AggregateFunction::First, Some(column))
    }

    /// Last non-NULL value of `column`.
    pub fn last(column: usize) -> Self {
        Self::over(AggregateFunction::Last, Some(column))
    }

    /// `COLLECT(column)`: gathers values into a list.
    pub fn collect(column: usize) -> Self {
        Self::over(AggregateFunction::Collect, Some(column))
    }

    /// Marks this aggregate as `DISTINCT`.
    pub fn with_distinct(self) -> Self {
        AggregateExpr {
            distinct: true,
            ..self
        }
    }

    /// Attaches an output alias.
    pub fn with_alias(self, alias: impl Into<String>) -> Self {
        AggregateExpr {
            alias: Some(alias.into()),
            ..self
        }
    }
}
184
185#[derive(Debug, Clone)]
187enum AggregateState {
188 Count(i64),
190 CountDistinct(i64, HashSet<HashableValue>),
192 SumInt(i64),
194 SumIntDistinct(i64, HashSet<HashableValue>),
196 SumFloat(f64),
198 SumFloatDistinct(f64, HashSet<HashableValue>),
200 Avg(f64, i64),
202 AvgDistinct(f64, i64, HashSet<HashableValue>),
204 Min(Option<Value>),
206 Max(Option<Value>),
208 First(Option<Value>),
210 Last(Option<Value>),
212 Collect(Vec<Value>),
214 CollectDistinct(Vec<Value>, HashSet<HashableValue>),
216}
217
218impl AggregateState {
219 fn new(function: AggregateFunction, distinct: bool) -> Self {
221 match (function, distinct) {
222 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
223 AggregateState::Count(0)
224 }
225 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
226 AggregateState::CountDistinct(0, HashSet::new())
227 }
228 (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
229 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
230 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
231 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
232 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
234 (AggregateFunction::First, _) => AggregateState::First(None),
235 (AggregateFunction::Last, _) => AggregateState::Last(None),
236 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
237 (AggregateFunction::Collect, true) => {
238 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
239 }
240 }
241 }
242
243 fn update(&mut self, value: Option<Value>) {
245 match self {
246 AggregateState::Count(count) => {
247 *count += 1;
248 }
249 AggregateState::CountDistinct(count, seen) => {
250 if let Some(ref v) = value {
251 let hashable = HashableValue::from(v);
252 if seen.insert(hashable) {
253 *count += 1;
254 }
255 }
256 }
257 AggregateState::SumInt(sum) => {
258 if let Some(Value::Int64(v)) = value {
259 *sum += v;
260 } else if let Some(Value::Float64(v)) = value {
261 *self = AggregateState::SumFloat(*sum as f64 + v);
263 }
264 }
265 AggregateState::SumIntDistinct(sum, seen) => {
266 if let Some(ref v) = value {
267 let hashable = HashableValue::from(v);
268 if seen.insert(hashable) {
269 if let Value::Int64(i) = v {
270 *sum += i;
271 } else if let Value::Float64(f) = v {
272 let seen_clone = seen.clone();
274 *self = AggregateState::SumFloatDistinct(*sum as f64 + f, seen_clone);
275 }
276 }
277 }
278 }
279 AggregateState::SumFloat(sum) => {
280 if let Some(Value::Int64(v)) = value {
281 *sum += v as f64;
282 } else if let Some(Value::Float64(v)) = value {
283 *sum += v;
284 }
285 }
286 AggregateState::SumFloatDistinct(sum, seen) => {
287 if let Some(ref v) = value {
288 let hashable = HashableValue::from(v);
289 if seen.insert(hashable) {
290 if let Some(num) = value_to_f64(v) {
291 *sum += num;
292 }
293 }
294 }
295 }
296 AggregateState::Avg(sum, count) => {
297 if let Some(ref v) = value {
298 if let Some(num) = value_to_f64(v) {
299 *sum += num;
300 *count += 1;
301 }
302 }
303 }
304 AggregateState::AvgDistinct(sum, count, seen) => {
305 if let Some(ref v) = value {
306 let hashable = HashableValue::from(v);
307 if seen.insert(hashable) {
308 if let Some(num) = value_to_f64(v) {
309 *sum += num;
310 *count += 1;
311 }
312 }
313 }
314 }
315 AggregateState::Min(min) => {
316 if let Some(v) = value {
317 match min {
318 None => *min = Some(v),
319 Some(current) => {
320 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
321 *min = Some(v);
322 }
323 }
324 }
325 }
326 }
327 AggregateState::Max(max) => {
328 if let Some(v) = value {
329 match max {
330 None => *max = Some(v),
331 Some(current) => {
332 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
333 *max = Some(v);
334 }
335 }
336 }
337 }
338 }
339 AggregateState::First(first) => {
340 if first.is_none() {
341 *first = value;
342 }
343 }
344 AggregateState::Last(last) => {
345 if value.is_some() {
346 *last = value;
347 }
348 }
349 AggregateState::Collect(list) => {
350 if let Some(v) = value {
351 list.push(v);
352 }
353 }
354 AggregateState::CollectDistinct(list, seen) => {
355 if let Some(v) = value {
356 let hashable = HashableValue::from(&v);
357 if seen.insert(hashable) {
358 list.push(v);
359 }
360 }
361 }
362 }
363 }
364
365 fn finalize(&self) -> Value {
367 match self {
368 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
369 Value::Int64(*count)
370 }
371 AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
372 Value::Int64(*sum)
373 }
374 AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
375 Value::Float64(*sum)
376 }
377 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
378 if *count == 0 {
379 Value::Null
380 } else {
381 Value::Float64(*sum / *count as f64)
382 }
383 }
384 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
385 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
386 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
387 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
388 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
389 Value::List(list.clone().into())
390 }
391 }
392 }
393}
394
395fn value_to_f64(value: &Value) -> Option<f64> {
397 match value {
398 Value::Int64(i) => Some(*i as f64),
399 Value::Float64(f) => Some(*f),
400 _ => None,
401 }
402}
403
404fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
406 match (a, b) {
407 (Value::Int64(a), Value::Int64(b)) => Some(a.cmp(b)),
408 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b),
409 (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
410 (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
411 (Value::Int64(a), Value::Float64(b)) => (*a as f64).partial_cmp(b),
412 (Value::Float64(a), Value::Int64(b)) => a.partial_cmp(&(*b as f64)),
413 _ => None,
414 }
415}
416
417#[derive(Debug, Clone, PartialEq, Eq, Hash)]
419pub struct GroupKey(Vec<GroupKeyPart>);
420
421#[derive(Debug, Clone, PartialEq, Eq, Hash)]
422enum GroupKeyPart {
423 Null,
424 Bool(bool),
425 Int64(i64),
426 String(String),
427}
428
429impl GroupKey {
430 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
432 let parts: Vec<GroupKeyPart> = group_columns
433 .iter()
434 .map(|&col_idx| {
435 chunk
436 .column(col_idx)
437 .and_then(|col| col.get_value(row))
438 .map(|v| match v {
439 Value::Null => GroupKeyPart::Null,
440 Value::Bool(b) => GroupKeyPart::Bool(b),
441 Value::Int64(i) => GroupKeyPart::Int64(i),
442 Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
443 Value::String(s) => GroupKeyPart::String(s.to_string()),
444 _ => GroupKeyPart::String(format!("{v:?}")),
445 })
446 .unwrap_or(GroupKeyPart::Null)
447 })
448 .collect();
449 GroupKey(parts)
450 }
451
452 fn to_values(&self) -> Vec<Value> {
454 self.0
455 .iter()
456 .map(|part| match part {
457 GroupKeyPart::Null => Value::Null,
458 GroupKeyPart::Bool(b) => Value::Bool(*b),
459 GroupKeyPart::Int64(i) => Value::Int64(*i),
460 GroupKeyPart::String(s) => Value::String(s.clone().into()),
461 })
462 .collect()
463 }
464}
465
466pub struct HashAggregateOperator {
470 child: Box<dyn Operator>,
472 group_columns: Vec<usize>,
474 aggregates: Vec<AggregateExpr>,
476 output_schema: Vec<LogicalType>,
478 groups: HashMap<GroupKey, Vec<AggregateState>>,
480 aggregation_complete: bool,
482 results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
484}
485
486impl HashAggregateOperator {
487 pub fn new(
495 child: Box<dyn Operator>,
496 group_columns: Vec<usize>,
497 aggregates: Vec<AggregateExpr>,
498 output_schema: Vec<LogicalType>,
499 ) -> Self {
500 Self {
501 child,
502 group_columns,
503 aggregates,
504 output_schema,
505 groups: HashMap::new(),
506 aggregation_complete: false,
507 results: None,
508 }
509 }
510
511 fn aggregate(&mut self) -> Result<(), OperatorError> {
513 while let Some(chunk) = self.child.next()? {
514 for row in chunk.selected_indices() {
515 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
516
517 let states = self.groups.entry(key).or_insert_with(|| {
519 self.aggregates
520 .iter()
521 .map(|agg| AggregateState::new(agg.function, agg.distinct))
522 .collect()
523 });
524
525 for (i, agg) in self.aggregates.iter().enumerate() {
527 let value = match agg.function {
528 AggregateFunction::Count => None, _ => agg
530 .column
531 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
532 };
533
534 match agg.function {
536 AggregateFunction::Count => states[i].update(None),
537 AggregateFunction::CountNonNull => {
538 if value.is_some() && !matches!(value, Some(Value::Null)) {
539 states[i].update(value);
540 }
541 }
542 _ => {
543 if value.is_some() && !matches!(value, Some(Value::Null)) {
544 states[i].update(value);
545 }
546 }
547 }
548 }
549 }
550 }
551
552 self.aggregation_complete = true;
553
554 let results: Vec<_> = self.groups.drain().collect();
556 self.results = Some(results.into_iter());
557
558 Ok(())
559 }
560}
561
562impl Operator for HashAggregateOperator {
563 fn next(&mut self) -> OperatorResult {
564 if !self.aggregation_complete {
566 self.aggregate()?;
567 }
568
569 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
571 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
573
574 for agg in &self.aggregates {
575 let state = AggregateState::new(agg.function, agg.distinct);
576 let value = state.finalize();
577 if let Some(col) = builder.column_mut(self.group_columns.len()) {
578 col.push_value(value);
579 }
580 }
581 builder.advance_row();
582
583 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
585 }
586
587 let results = match &mut self.results {
588 Some(r) => r,
589 None => return Ok(None),
590 };
591
592 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
593
594 for (key, states) in results.by_ref() {
595 let key_values = key.to_values();
597 for (i, value) in key_values.into_iter().enumerate() {
598 if let Some(col) = builder.column_mut(i) {
599 col.push_value(value);
600 }
601 }
602
603 for (i, state) in states.iter().enumerate() {
605 let col_idx = self.group_columns.len() + i;
606 if let Some(col) = builder.column_mut(col_idx) {
607 col.push_value(state.finalize());
608 }
609 }
610
611 builder.advance_row();
612
613 if builder.is_full() {
614 return Ok(Some(builder.finish()));
615 }
616 }
617
618 if builder.row_count() > 0 {
619 Ok(Some(builder.finish()))
620 } else {
621 Ok(None)
622 }
623 }
624
625 fn reset(&mut self) {
626 self.child.reset();
627 self.groups.clear();
628 self.aggregation_complete = false;
629 self.results = None;
630 }
631
632 fn name(&self) -> &'static str {
633 "HashAggregate"
634 }
635}
636
637pub struct SimpleAggregateOperator {
641 child: Box<dyn Operator>,
643 aggregates: Vec<AggregateExpr>,
645 output_schema: Vec<LogicalType>,
647 states: Vec<AggregateState>,
649 done: bool,
651}
652
653impl SimpleAggregateOperator {
654 pub fn new(
656 child: Box<dyn Operator>,
657 aggregates: Vec<AggregateExpr>,
658 output_schema: Vec<LogicalType>,
659 ) -> Self {
660 let states = aggregates
661 .iter()
662 .map(|agg| AggregateState::new(agg.function, agg.distinct))
663 .collect();
664
665 Self {
666 child,
667 aggregates,
668 output_schema,
669 states,
670 done: false,
671 }
672 }
673}
674
675impl Operator for SimpleAggregateOperator {
676 fn next(&mut self) -> OperatorResult {
677 if self.done {
678 return Ok(None);
679 }
680
681 while let Some(chunk) = self.child.next()? {
683 for row in chunk.selected_indices() {
684 for (i, agg) in self.aggregates.iter().enumerate() {
685 let value = match agg.function {
686 AggregateFunction::Count => None,
687 _ => agg
688 .column
689 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
690 };
691
692 match agg.function {
693 AggregateFunction::Count => self.states[i].update(None),
694 AggregateFunction::CountNonNull => {
695 if value.is_some() && !matches!(value, Some(Value::Null)) {
696 self.states[i].update(value);
697 }
698 }
699 _ => {
700 if value.is_some() && !matches!(value, Some(Value::Null)) {
701 self.states[i].update(value);
702 }
703 }
704 }
705 }
706 }
707 }
708
709 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
711
712 for (i, state) in self.states.iter().enumerate() {
713 if let Some(col) = builder.column_mut(i) {
714 col.push_value(state.finalize());
715 }
716 }
717 builder.advance_row();
718
719 self.done = true;
720 Ok(Some(builder.finish()))
721 }
722
723 fn reset(&mut self) {
724 self.child.reset();
725 self.states = self
726 .aggregates
727 .iter()
728 .map(|agg| AggregateState::new(agg.function, agg.distinct))
729 .collect();
730 self.done = false;
731 }
732
733 fn name(&self) -> &'static str {
734 "SimpleAggregate"
735 }
736}
737
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Child operator that hands out a fixed list of chunks, one per call.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            MockOperator {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            match self.chunks.get_mut(self.position) {
                Some(slot) => {
                    let chunk = std::mem::replace(slot, DataChunk::empty());
                    self.position += 1;
                    Ok(Some(chunk))
                }
                None => Ok(None),
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    /// Builds a two-column Int64 chunk from `(group, value)` rows.
    fn int_pairs_chunk(rows: &[(i64, i64)]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
        for &(group, value) in rows {
            builder.column_mut(0).unwrap().push_int64(group);
            builder.column_mut(1).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    fn create_test_chunk() -> DataChunk {
        int_pairs_chunk(&[(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)])
    }

    fn create_test_chunk_with_duplicates() -> DataChunk {
        int_pairs_chunk(&[(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)])
    }

    /// Wraps a single chunk as a boxed child operator.
    fn source(chunk: DataChunk) -> Box<dyn Operator> {
        Box::new(MockOperator::new(vec![chunk]))
    }

    /// Drains `op` and returns columns 0 and 1 of every row, sorted by column 0.
    fn collect_int_pairs(op: &mut dyn Operator) -> Vec<(i64, i64)> {
        let mut rows = Vec::new();
        while let Some(chunk) = op.next().unwrap() {
            for row in chunk.selected_indices() {
                let first = chunk.column(0).unwrap().get_int64(row).unwrap();
                let second = chunk.column(1).unwrap().get_int64(row).unwrap();
                rows.push((first, second));
            }
        }
        rows.sort_by_key(|&(group, _)| group);
        rows
    }

    #[test]
    fn test_simple_count() {
        let mut agg = SimpleAggregateOperator::new(
            source(create_test_chunk()),
            vec![AggregateExpr::count_star()],
            vec![LogicalType::Int64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));

        assert!(agg.next().unwrap().is_none());
    }

    #[test]
    fn test_simple_sum() {
        let mut agg = SimpleAggregateOperator::new(
            source(create_test_chunk()),
            vec![AggregateExpr::sum(1)],
            vec![LogicalType::Int64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
    }

    #[test]
    fn test_simple_avg() {
        let mut agg = SimpleAggregateOperator::new(
            source(create_test_chunk()),
            vec![AggregateExpr::avg(1)],
            vec![LogicalType::Float64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        let avg = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((avg - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_simple_min_max() {
        let mut agg = SimpleAggregateOperator::new(
            source(create_test_chunk()),
            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
    }

    #[test]
    fn test_grouped_aggregation() {
        let mut agg = HashAggregateOperator::new(
            source(create_test_chunk()),
            vec![0],
            vec![AggregateExpr::sum(1)],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let results = collect_int_pairs(&mut agg);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], (1, 30));
        assert_eq!(results[1], (2, 120));
    }

    #[test]
    fn test_grouped_count() {
        let mut agg = HashAggregateOperator::new(
            source(create_test_chunk()),
            vec![0],
            vec![AggregateExpr::count_star()],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let results = collect_int_pairs(&mut agg);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], (1, 2));
        assert_eq!(results[1], (2, 3));
    }

    #[test]
    fn test_multiple_aggregates() {
        let mut agg = HashAggregateOperator::new(
            source(create_test_chunk()),
            vec![0],
            vec![
                AggregateExpr::count_star(),
                AggregateExpr::sum(1),
                AggregateExpr::avg(1),
            ],
            vec![
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Float64,
            ],
        );

        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
        while let Some(chunk) = agg.next().unwrap() {
            for row in chunk.selected_indices() {
                results.push((
                    chunk.column(0).unwrap().get_int64(row).unwrap(),
                    chunk.column(1).unwrap().get_int64(row).unwrap(),
                    chunk.column(2).unwrap().get_int64(row).unwrap(),
                    chunk.column(3).unwrap().get_float64(row).unwrap(),
                ));
            }
        }

        results.sort_by_key(|&(group, _, _, _)| group);
        assert_eq!(results.len(), 2);

        let (group, count, sum, avg) = results[0];
        assert_eq!(group, 1);
        assert_eq!(count, 2);
        assert_eq!(sum, 30);
        assert!((avg - 15.0).abs() < 0.001);

        let (group, count, sum, avg) = results[1];
        assert_eq!(group, 2);
        assert_eq!(count, 3);
        assert_eq!(sum, 120);
        assert!((avg - 40.0).abs() < 0.001);
    }

    #[test]
    fn test_count_distinct() {
        let mut agg = SimpleAggregateOperator::new(
            source(create_test_chunk_with_duplicates()),
            vec![AggregateExpr::count(1).with_distinct()],
            vec![LogicalType::Int64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
    }

    #[test]
    fn test_grouped_count_distinct() {
        let mut agg = HashAggregateOperator::new(
            source(create_test_chunk_with_duplicates()),
            vec![0],
            vec![AggregateExpr::count(1).with_distinct()],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let results = collect_int_pairs(&mut agg);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], (1, 2));
        assert_eq!(results[1], (2, 1));
    }

    #[test]
    fn test_sum_distinct() {
        let mut agg = SimpleAggregateOperator::new(
            source(create_test_chunk_with_duplicates()),
            vec![AggregateExpr::sum(1).with_distinct()],
            vec![LogicalType::Int64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
    }

    #[test]
    fn test_avg_distinct() {
        let mut agg = SimpleAggregateOperator::new(
            source(create_test_chunk_with_duplicates()),
            vec![AggregateExpr::avg(1).with_distinct()],
            vec![LogicalType::Float64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        let avg = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((avg - 20.0).abs() < 0.001);
    }
}