1use indexmap::IndexMap;
11use std::collections::HashSet;
12
13use grafeo_common::types::{LogicalType, Value};
14
15use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
16use super::{Operator, OperatorError, OperatorResult};
17use crate::execution::DataChunk;
18use crate::execution::chunk::DataChunkBuilder;
19
20#[derive(Debug, Clone)]
22enum AggregateState {
23 Count(i64),
25 CountDistinct(i64, HashSet<HashableValue>),
27 SumInt(i64),
29 SumIntDistinct(i64, HashSet<HashableValue>),
31 SumFloat(f64),
33 SumFloatDistinct(f64, HashSet<HashableValue>),
35 Avg(f64, i64),
37 AvgDistinct(f64, i64, HashSet<HashableValue>),
39 Min(Option<Value>),
41 Max(Option<Value>),
43 First(Option<Value>),
45 Last(Option<Value>),
47 Collect(Vec<Value>),
49 CollectDistinct(Vec<Value>, HashSet<HashableValue>),
51 StdDev { count: i64, mean: f64, m2: f64 },
53 StdDevPop { count: i64, mean: f64, m2: f64 },
55 PercentileDisc { values: Vec<f64>, percentile: f64 },
57 PercentileCont { values: Vec<f64>, percentile: f64 },
59}
60
61impl AggregateState {
62 fn new(function: AggregateFunction, distinct: bool, percentile: Option<f64>) -> Self {
64 match (function, distinct) {
65 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
66 AggregateState::Count(0)
67 }
68 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
69 AggregateState::CountDistinct(0, HashSet::new())
70 }
71 (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
72 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
73 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
74 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
75 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
77 (AggregateFunction::First, _) => AggregateState::First(None),
78 (AggregateFunction::Last, _) => AggregateState::Last(None),
79 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
80 (AggregateFunction::Collect, true) => {
81 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
82 }
83 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
85 count: 0,
86 mean: 0.0,
87 m2: 0.0,
88 },
89 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
90 count: 0,
91 mean: 0.0,
92 m2: 0.0,
93 },
94 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
95 values: Vec::new(),
96 percentile: percentile.unwrap_or(0.5),
97 },
98 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
99 values: Vec::new(),
100 percentile: percentile.unwrap_or(0.5),
101 },
102 }
103 }
104
105 fn update(&mut self, value: Option<Value>) {
107 match self {
108 AggregateState::Count(count) => {
109 *count += 1;
110 }
111 AggregateState::CountDistinct(count, seen) => {
112 if let Some(ref v) = value {
113 let hashable = HashableValue::from(v);
114 if seen.insert(hashable) {
115 *count += 1;
116 }
117 }
118 }
119 AggregateState::SumInt(sum) => {
120 if let Some(Value::Int64(v)) = value {
121 *sum += v;
122 } else if let Some(Value::Float64(v)) = value {
123 *self = AggregateState::SumFloat(*sum as f64 + v);
125 } else if let Some(ref v) = value {
126 if let Some(num) = value_to_f64(v) {
128 *self = AggregateState::SumFloat(*sum as f64 + num);
129 }
130 }
131 }
132 AggregateState::SumIntDistinct(sum, seen) => {
133 if let Some(ref v) = value {
134 let hashable = HashableValue::from(v);
135 if seen.insert(hashable) {
136 if let Value::Int64(i) = v {
137 *sum += i;
138 } else if let Value::Float64(f) = v {
139 let moved_seen = std::mem::take(seen);
141 *self = AggregateState::SumFloatDistinct(*sum as f64 + f, moved_seen);
142 } else if let Some(num) = value_to_f64(v) {
143 let moved_seen = std::mem::take(seen);
145 *self = AggregateState::SumFloatDistinct(*sum as f64 + num, moved_seen);
146 }
147 }
148 }
149 }
150 AggregateState::SumFloat(sum) => {
151 if let Some(ref v) = value {
152 if let Some(num) = value_to_f64(v) {
154 *sum += num;
155 }
156 }
157 }
158 AggregateState::SumFloatDistinct(sum, seen) => {
159 if let Some(ref v) = value {
160 let hashable = HashableValue::from(v);
161 if seen.insert(hashable)
162 && let Some(num) = value_to_f64(v)
163 {
164 *sum += num;
165 }
166 }
167 }
168 AggregateState::Avg(sum, count) => {
169 if let Some(ref v) = value
170 && let Some(num) = value_to_f64(v)
171 {
172 *sum += num;
173 *count += 1;
174 }
175 }
176 AggregateState::AvgDistinct(sum, count, seen) => {
177 if let Some(ref v) = value {
178 let hashable = HashableValue::from(v);
179 if seen.insert(hashable)
180 && let Some(num) = value_to_f64(v)
181 {
182 *sum += num;
183 *count += 1;
184 }
185 }
186 }
187 AggregateState::Min(min) => {
188 if let Some(v) = value {
189 match min {
190 None => *min = Some(v),
191 Some(current) => {
192 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
193 *min = Some(v);
194 }
195 }
196 }
197 }
198 }
199 AggregateState::Max(max) => {
200 if let Some(v) = value {
201 match max {
202 None => *max = Some(v),
203 Some(current) => {
204 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
205 *max = Some(v);
206 }
207 }
208 }
209 }
210 }
211 AggregateState::First(first) => {
212 if first.is_none() {
213 *first = value;
214 }
215 }
216 AggregateState::Last(last) => {
217 if value.is_some() {
218 *last = value;
219 }
220 }
221 AggregateState::Collect(list) => {
222 if let Some(v) = value {
223 list.push(v);
224 }
225 }
226 AggregateState::CollectDistinct(list, seen) => {
227 if let Some(v) = value {
228 let hashable = HashableValue::from(&v);
229 if seen.insert(hashable) {
230 list.push(v);
231 }
232 }
233 }
234 AggregateState::StdDev { count, mean, m2 }
236 | AggregateState::StdDevPop { count, mean, m2 } => {
237 if let Some(ref v) = value
238 && let Some(x) = value_to_f64(v)
239 {
240 *count += 1;
241 let delta = x - *mean;
242 *mean += delta / *count as f64;
243 let delta2 = x - *mean;
244 *m2 += delta * delta2;
245 }
246 }
247 AggregateState::PercentileDisc { values, .. }
248 | AggregateState::PercentileCont { values, .. } => {
249 if let Some(ref v) = value
250 && let Some(x) = value_to_f64(v)
251 {
252 values.push(x);
253 }
254 }
255 }
256 }
257
258 fn finalize(&self) -> Value {
260 match self {
261 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
262 Value::Int64(*count)
263 }
264 AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
265 Value::Int64(*sum)
266 }
267 AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
268 Value::Float64(*sum)
269 }
270 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
271 if *count == 0 {
272 Value::Null
273 } else {
274 Value::Float64(*sum / *count as f64)
275 }
276 }
277 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
278 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
279 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
280 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
281 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
282 Value::List(list.clone().into())
283 }
284 AggregateState::StdDev { count, m2, .. } => {
286 if *count < 2 {
287 Value::Null
288 } else {
289 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
290 }
291 }
292 AggregateState::StdDevPop { count, m2, .. } => {
294 if *count == 0 {
295 Value::Null
296 } else {
297 Value::Float64((*m2 / *count as f64).sqrt())
298 }
299 }
300 AggregateState::PercentileDisc { values, percentile } => {
302 if values.is_empty() {
303 Value::Null
304 } else {
305 let mut sorted = values.clone();
306 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
307 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
309 Value::Float64(sorted[index])
310 }
311 }
312 AggregateState::PercentileCont { values, percentile } => {
314 if values.is_empty() {
315 Value::Null
316 } else {
317 let mut sorted = values.clone();
318 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
319 let rank = percentile * (sorted.len() - 1) as f64;
321 let lower_idx = rank.floor() as usize;
322 let upper_idx = rank.ceil() as usize;
323 if lower_idx == upper_idx {
324 Value::Float64(sorted[lower_idx])
325 } else {
326 let fraction = rank - lower_idx as f64;
327 let result =
328 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
329 Value::Float64(result)
330 }
331 }
332 }
333 }
334 }
335}
336
337use super::value_utils::{compare_values, value_to_f64};
338
339#[derive(Debug, Clone, PartialEq, Eq, Hash)]
341pub struct GroupKey(Vec<GroupKeyPart>);
342
343#[derive(Debug, Clone, PartialEq, Eq, Hash)]
344enum GroupKeyPart {
345 Null,
346 Bool(bool),
347 Int64(i64),
348 String(String),
349}
350
351impl GroupKey {
352 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
354 let parts: Vec<GroupKeyPart> = group_columns
355 .iter()
356 .map(|&col_idx| {
357 chunk
358 .column(col_idx)
359 .and_then(|col| col.get_value(row))
360 .map_or(GroupKeyPart::Null, |v| match v {
361 Value::Null => GroupKeyPart::Null,
362 Value::Bool(b) => GroupKeyPart::Bool(b),
363 Value::Int64(i) => GroupKeyPart::Int64(i),
364 Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
365 Value::String(s) => GroupKeyPart::String(s.to_string()),
366 _ => GroupKeyPart::String(format!("{v:?}")),
367 })
368 })
369 .collect();
370 GroupKey(parts)
371 }
372
373 fn to_values(&self) -> Vec<Value> {
375 self.0
376 .iter()
377 .map(|part| match part {
378 GroupKeyPart::Null => Value::Null,
379 GroupKeyPart::Bool(b) => Value::Bool(*b),
380 GroupKeyPart::Int64(i) => Value::Int64(*i),
381 GroupKeyPart::String(s) => Value::String(s.clone().into()),
382 })
383 .collect()
384 }
385}
386
387pub struct HashAggregateOperator {
391 child: Box<dyn Operator>,
393 group_columns: Vec<usize>,
395 aggregates: Vec<AggregateExpr>,
397 output_schema: Vec<LogicalType>,
399 groups: IndexMap<GroupKey, Vec<AggregateState>>,
401 aggregation_complete: bool,
403 results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
405}
406
407impl HashAggregateOperator {
408 pub fn new(
416 child: Box<dyn Operator>,
417 group_columns: Vec<usize>,
418 aggregates: Vec<AggregateExpr>,
419 output_schema: Vec<LogicalType>,
420 ) -> Self {
421 Self {
422 child,
423 group_columns,
424 aggregates,
425 output_schema,
426 groups: IndexMap::new(),
427 aggregation_complete: false,
428 results: None,
429 }
430 }
431
432 fn aggregate(&mut self) -> Result<(), OperatorError> {
434 while let Some(chunk) = self.child.next()? {
435 for row in chunk.selected_indices() {
436 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
437
438 let states = self.groups.entry(key).or_insert_with(|| {
440 self.aggregates
441 .iter()
442 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
443 .collect()
444 });
445
446 for (i, agg) in self.aggregates.iter().enumerate() {
448 let value = match (agg.function, agg.distinct) {
449 (AggregateFunction::Count, false) => None,
451 (AggregateFunction::Count, true) => agg
453 .column
454 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
455 _ => agg
456 .column
457 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
458 };
459
460 match (agg.function, agg.distinct) {
462 (AggregateFunction::Count, false) => states[i].update(None),
463 (AggregateFunction::Count, true) => {
464 if value.is_some() && !matches!(value, Some(Value::Null)) {
466 states[i].update(value);
467 }
468 }
469 (AggregateFunction::CountNonNull, _) => {
470 if value.is_some() && !matches!(value, Some(Value::Null)) {
471 states[i].update(value);
472 }
473 }
474 _ => {
475 if value.is_some() && !matches!(value, Some(Value::Null)) {
476 states[i].update(value);
477 }
478 }
479 }
480 }
481 }
482 }
483
484 self.aggregation_complete = true;
485
486 let results: Vec<_> = self.groups.drain(..).collect();
488 self.results = Some(results.into_iter());
489
490 Ok(())
491 }
492}
493
494impl Operator for HashAggregateOperator {
495 fn next(&mut self) -> OperatorResult {
496 if !self.aggregation_complete {
498 self.aggregate()?;
499 }
500
501 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
503 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
505
506 for agg in &self.aggregates {
507 let state = AggregateState::new(agg.function, agg.distinct, agg.percentile);
508 let value = state.finalize();
509 if let Some(col) = builder.column_mut(self.group_columns.len()) {
510 col.push_value(value);
511 }
512 }
513 builder.advance_row();
514
515 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
517 }
518
519 let Some(results) = &mut self.results else {
520 return Ok(None);
521 };
522
523 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
524
525 for (key, states) in results.by_ref() {
526 let key_values = key.to_values();
528 for (i, value) in key_values.into_iter().enumerate() {
529 if let Some(col) = builder.column_mut(i) {
530 col.push_value(value);
531 }
532 }
533
534 for (i, state) in states.iter().enumerate() {
536 let col_idx = self.group_columns.len() + i;
537 if let Some(col) = builder.column_mut(col_idx) {
538 col.push_value(state.finalize());
539 }
540 }
541
542 builder.advance_row();
543
544 if builder.is_full() {
545 return Ok(Some(builder.finish()));
546 }
547 }
548
549 if builder.row_count() > 0 {
550 Ok(Some(builder.finish()))
551 } else {
552 Ok(None)
553 }
554 }
555
556 fn reset(&mut self) {
557 self.child.reset();
558 self.groups.clear();
559 self.aggregation_complete = false;
560 self.results = None;
561 }
562
563 fn name(&self) -> &'static str {
564 "HashAggregate"
565 }
566}
567
568pub struct SimpleAggregateOperator {
572 child: Box<dyn Operator>,
574 aggregates: Vec<AggregateExpr>,
576 output_schema: Vec<LogicalType>,
578 states: Vec<AggregateState>,
580 done: bool,
582}
583
584impl SimpleAggregateOperator {
585 pub fn new(
587 child: Box<dyn Operator>,
588 aggregates: Vec<AggregateExpr>,
589 output_schema: Vec<LogicalType>,
590 ) -> Self {
591 let states = aggregates
592 .iter()
593 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
594 .collect();
595
596 Self {
597 child,
598 aggregates,
599 output_schema,
600 states,
601 done: false,
602 }
603 }
604}
605
606impl Operator for SimpleAggregateOperator {
607 fn next(&mut self) -> OperatorResult {
608 if self.done {
609 return Ok(None);
610 }
611
612 while let Some(chunk) = self.child.next()? {
614 for row in chunk.selected_indices() {
615 for (i, agg) in self.aggregates.iter().enumerate() {
616 let value = match (agg.function, agg.distinct) {
617 (AggregateFunction::Count, false) => None,
619 (AggregateFunction::Count, true) => agg
621 .column
622 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
623 _ => agg
624 .column
625 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
626 };
627
628 match (agg.function, agg.distinct) {
629 (AggregateFunction::Count, false) => self.states[i].update(None),
630 (AggregateFunction::Count, true) => {
631 if value.is_some() && !matches!(value, Some(Value::Null)) {
633 self.states[i].update(value);
634 }
635 }
636 (AggregateFunction::CountNonNull, _) => {
637 if value.is_some() && !matches!(value, Some(Value::Null)) {
638 self.states[i].update(value);
639 }
640 }
641 _ => {
642 if value.is_some() && !matches!(value, Some(Value::Null)) {
643 self.states[i].update(value);
644 }
645 }
646 }
647 }
648 }
649 }
650
651 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
653
654 for (i, state) in self.states.iter().enumerate() {
655 if let Some(col) = builder.column_mut(i) {
656 col.push_value(state.finalize());
657 }
658 }
659 builder.advance_row();
660
661 self.done = true;
662 Ok(Some(builder.finish()))
663 }
664
665 fn reset(&mut self) {
666 self.child.reset();
667 self.states = self
668 .aggregates
669 .iter()
670 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
671 .collect();
672 self.done = false;
673 }
674
675 fn name(&self) -> &'static str {
676 "SimpleAggregate"
677 }
678}
679
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Test double that hands out a fixed sequence of chunks, then reports exhaustion.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            let Some(slot) = self.chunks.get_mut(self.position) else {
                return Ok(None);
            };
            // Hand the chunk out by value, leaving an empty placeholder behind.
            let chunk = std::mem::replace(slot, DataChunk::empty());
            self.position += 1;
            Ok(Some(chunk))
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    /// Builds an `(Int64, Int64)` chunk with one row per `(group, value)` pair.
    fn int_pairs_chunk(rows: &[(i64, i64)]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
        for &(group, value) in rows {
            builder.column_mut(0).unwrap().push_int64(group);
            builder.column_mut(1).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Builds a single `Int64` column chunk.
    fn int_column_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for &value in values {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Group 1 -> {10, 20}, group 2 -> {30, 40, 50}.
    fn create_test_chunk() -> DataChunk {
        int_pairs_chunk(&[(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)])
    }

    /// Group 1 -> {10, 10, 20}, group 2 -> {30, 30, 30}; exercises DISTINCT.
    fn create_test_chunk_with_duplicates() -> DataChunk {
        int_pairs_chunk(&[(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)])
    }

    /// Mean 5, population stdev 2, sample stdev ~2.138.
    fn create_statistical_test_chunk() -> DataChunk {
        int_column_chunk(&[2, 4, 4, 4, 5, 5, 7, 9])
    }

    /// Runs a global aggregate over `chunk` and returns its single output row.
    fn run_simple(
        chunk: DataChunk,
        aggregates: Vec<AggregateExpr>,
        schema: Vec<LogicalType>,
    ) -> DataChunk {
        let mut agg =
            SimpleAggregateOperator::new(Box::new(MockOperator::new(vec![chunk])), aggregates, schema);
        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        result
    }

    /// Groups `chunk` by column 0 with one aggregate, returning `(group, value)` sorted by group.
    fn grouped_int_pairs(chunk: DataChunk, aggregate: AggregateExpr) -> Vec<(i64, i64)> {
        let mut agg = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![chunk])),
            vec![0],
            vec![aggregate],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut rows = Vec::new();
        while let Some(out) = agg.next().unwrap() {
            for row in out.selected_indices() {
                let group = out.column(0).unwrap().get_int64(row).unwrap();
                let value = out.column(1).unwrap().get_int64(row).unwrap();
                rows.push((group, value));
            }
        }
        rows.sort_by_key(|&(group, _)| group);
        rows
    }

    /// Reads the `Float64` in row 0 of column `col`.
    fn float_at(chunk: &DataChunk, col: usize) -> f64 {
        chunk.column(col).unwrap().get_float64(0).unwrap()
    }

    fn assert_close(actual: f64, expected: f64, tolerance: f64) {
        assert!(
            (actual - expected).abs() < tolerance,
            "expected {expected}, got {actual}"
        );
    }

    #[test]
    fn test_simple_count() {
        let mut agg = SimpleAggregateOperator::new(
            Box::new(MockOperator::new(vec![create_test_chunk()])),
            vec![AggregateExpr::count_star()],
            vec![LogicalType::Int64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));

        // The single result row is produced exactly once.
        assert!(agg.next().unwrap().is_none());
    }

    #[test]
    fn test_simple_sum() {
        let result = run_simple(
            create_test_chunk(),
            vec![AggregateExpr::sum(1)],
            vec![LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
    }

    #[test]
    fn test_simple_avg() {
        let result = run_simple(
            create_test_chunk(),
            vec![AggregateExpr::avg(1)],
            vec![LogicalType::Float64],
        );
        assert_close(float_at(&result, 0), 30.0, 0.001);
    }

    #[test]
    fn test_simple_min_max() {
        let result = run_simple(
            create_test_chunk(),
            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
            vec![LogicalType::Int64, LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
    }

    #[test]
    fn test_sum_with_string_values() {
        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
        for text in ["30", "25", "35"] {
            builder.column_mut(0).unwrap().push_string(text);
            builder.advance_row();
        }

        let result = run_simple(
            builder.finish(),
            vec![AggregateExpr::sum(0)],
            vec![LogicalType::Float64],
        );
        let sum_val = float_at(&result, 0);
        assert!(
            (sum_val - 90.0).abs() < 0.001,
            "Expected 90.0, got {}",
            sum_val
        );
    }

    #[test]
    fn test_grouped_aggregation() {
        let results = grouped_int_pairs(create_test_chunk(), AggregateExpr::sum(1));
        assert_eq!(results, vec![(1, 30), (2, 120)]);
    }

    #[test]
    fn test_grouped_count() {
        let results = grouped_int_pairs(create_test_chunk(), AggregateExpr::count_star());
        assert_eq!(results, vec![(1, 2), (2, 3)]);
    }

    #[test]
    fn test_multiple_aggregates() {
        let mut agg = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![create_test_chunk()])),
            vec![0],
            vec![
                AggregateExpr::count_star(),
                AggregateExpr::sum(1),
                AggregateExpr::avg(1),
            ],
            vec![
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Float64,
            ],
        );

        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
        while let Some(chunk) = agg.next().unwrap() {
            for row in chunk.selected_indices() {
                let int_at = |col: usize| chunk.column(col).unwrap().get_int64(row).unwrap();
                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
                results.push((int_at(0), int_at(1), int_at(2), avg));
            }
        }

        results.sort_by_key(|&(group, ..)| group);
        assert_eq!(results.len(), 2);

        let (group, count, sum, avg) = results[0];
        assert_eq!((group, count, sum), (1, 2, 30));
        assert_close(avg, 15.0, 0.001);

        let (group, count, sum, avg) = results[1];
        assert_eq!((group, count, sum), (2, 3, 120));
        assert_close(avg, 40.0, 0.001);
    }

    #[test]
    fn test_count_distinct() {
        let result = run_simple(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::count(1).with_distinct()],
            vec![LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
    }

    #[test]
    fn test_grouped_count_distinct() {
        let results = grouped_int_pairs(
            create_test_chunk_with_duplicates(),
            AggregateExpr::count(1).with_distinct(),
        );
        assert_eq!(results, vec![(1, 2), (2, 1)]);
    }

    #[test]
    fn test_sum_distinct() {
        let result = run_simple(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::sum(1).with_distinct()],
            vec![LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
    }

    #[test]
    fn test_avg_distinct() {
        let result = run_simple(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::avg(1).with_distinct()],
            vec![LogicalType::Float64],
        );
        assert_close(float_at(&result, 0), 20.0, 0.001);
    }

    #[test]
    fn test_stdev_sample() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );
        assert_close(float_at(&result, 0), 2.138, 0.01);
    }

    #[test]
    fn test_stdev_population() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );
        assert_close(float_at(&result, 0), 2.0, 0.01);
    }

    #[test]
    fn test_percentile_disc() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_disc(0, 0.5)],
            vec![LogicalType::Float64],
        );
        assert_close(float_at(&result, 0), 4.0, 0.01);
    }

    #[test]
    fn test_percentile_cont() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_cont(0, 0.5)],
            vec![LogicalType::Float64],
        );
        assert_close(float_at(&result, 0), 4.5, 0.01);
    }

    #[test]
    fn test_percentile_extremes() {
        let result = run_simple(
            create_statistical_test_chunk(),
            vec![
                AggregateExpr::percentile_disc(0, 0.0),
                AggregateExpr::percentile_disc(0, 1.0),
            ],
            vec![LogicalType::Float64, LogicalType::Float64],
        );
        assert_close(float_at(&result, 0), 2.0, 0.01);
        assert_close(float_at(&result, 1), 9.0, 0.01);
    }

    #[test]
    fn test_stdev_single_value() {
        let result = run_simple(
            int_column_chunk(&[42]),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );
        // The sample standard deviation is undefined for a single value.
        assert!(matches!(
            result.column(0).unwrap().get_value(0),
            Some(Value::Null)
        ));
    }

    #[test]
    fn test_stdev_pop_single_value() {
        let result = run_simple(
            int_column_chunk(&[42]),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );
        assert_close(float_at(&result, 0), 0.0, 0.01);
    }
}