1use indexmap::IndexMap;
8use std::collections::HashSet;
9
10use grafeo_common::types::{LogicalType, Value};
11
/// Hashable projection of a `Value`, used as the element type of the "seen" sets kept by the
/// DISTINCT aggregates.
///
/// `f64` is neither `Eq` nor `Hash`, so floats are stored as their bit pattern; value kinds
/// without a dedicated variant are keyed by their `Debug` rendering.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum HashableValue {
    /// SQL / Cypher null.
    Null,
    /// Boolean value.
    Bool(bool),
    /// 64-bit integer.
    Int64(i64),
    /// `f64::to_bits` of a float value.
    Float64Bits(u64),
    /// String contents.
    String(String),
    /// `Debug` rendering of any other value kind.
    Other(String),
}
22
23impl From<&Value> for HashableValue {
24 fn from(v: &Value) -> Self {
25 match v {
26 Value::Null => HashableValue::Null,
27 Value::Bool(b) => HashableValue::Bool(*b),
28 Value::Int64(i) => HashableValue::Int64(*i),
29 Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
30 Value::String(s) => HashableValue::String(s.to_string()),
31 other => HashableValue::Other(format!("{other:?}")),
32 }
33 }
34}
35
36impl From<Value> for HashableValue {
37 fn from(v: Value) -> Self {
38 Self::from(&v)
39 }
40}
41
42use super::{Operator, OperatorError, OperatorResult};
43use crate::execution::DataChunk;
44use crate::execution::chunk::DataChunkBuilder;
45
/// Aggregate functions understood by [`HashAggregateOperator`] and [`SimpleAggregateOperator`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT(*)`: counts input rows. With `distinct`, it instead counts distinct non-null
    /// values of the aggregate's column.
    Count,
    /// `COUNT(expr)`: counts non-null values.
    CountNonNull,
    /// Sum of numeric inputs. The result stays `Int64` while every input is an integer and
    /// becomes `Float64` otherwise.
    Sum,
    /// Arithmetic mean as `Float64`; null when there is no numeric input.
    Avg,
    /// Smallest input value. Values that cannot be compared with the current minimum are
    /// ignored.
    Min,
    /// Largest input value. Values that cannot be compared with the current maximum are
    /// ignored.
    Max,
    /// First non-null value in input order.
    First,
    /// Last non-null value in input order.
    Last,
    /// List of non-null values in input order.
    Collect,
    /// Sample standard deviation (`n - 1` denominator); null for fewer than two values.
    StdDev,
    /// Population standard deviation (`n` denominator); null when there is no input.
    StdDevPop,
    /// Discrete percentile: returns one of the (sorted) input values.
    PercentileDisc,
    /// Continuous percentile: interpolates linearly between neighbouring sorted inputs.
    PercentileCont,
}
76
77#[derive(Debug, Clone)]
79pub struct AggregateExpr {
80 pub function: AggregateFunction,
82 pub column: Option<usize>,
84 pub distinct: bool,
86 pub alias: Option<String>,
88 pub percentile: Option<f64>,
90}
91
92impl AggregateExpr {
93 pub fn count_star() -> Self {
95 Self {
96 function: AggregateFunction::Count,
97 column: None,
98 distinct: false,
99 alias: None,
100 percentile: None,
101 }
102 }
103
104 pub fn count(column: usize) -> Self {
106 Self {
107 function: AggregateFunction::CountNonNull,
108 column: Some(column),
109 distinct: false,
110 alias: None,
111 percentile: None,
112 }
113 }
114
115 pub fn sum(column: usize) -> Self {
117 Self {
118 function: AggregateFunction::Sum,
119 column: Some(column),
120 distinct: false,
121 alias: None,
122 percentile: None,
123 }
124 }
125
126 pub fn avg(column: usize) -> Self {
128 Self {
129 function: AggregateFunction::Avg,
130 column: Some(column),
131 distinct: false,
132 alias: None,
133 percentile: None,
134 }
135 }
136
137 pub fn min(column: usize) -> Self {
139 Self {
140 function: AggregateFunction::Min,
141 column: Some(column),
142 distinct: false,
143 alias: None,
144 percentile: None,
145 }
146 }
147
148 pub fn max(column: usize) -> Self {
150 Self {
151 function: AggregateFunction::Max,
152 column: Some(column),
153 distinct: false,
154 alias: None,
155 percentile: None,
156 }
157 }
158
159 pub fn first(column: usize) -> Self {
161 Self {
162 function: AggregateFunction::First,
163 column: Some(column),
164 distinct: false,
165 alias: None,
166 percentile: None,
167 }
168 }
169
170 pub fn last(column: usize) -> Self {
172 Self {
173 function: AggregateFunction::Last,
174 column: Some(column),
175 distinct: false,
176 alias: None,
177 percentile: None,
178 }
179 }
180
181 pub fn collect(column: usize) -> Self {
183 Self {
184 function: AggregateFunction::Collect,
185 column: Some(column),
186 distinct: false,
187 alias: None,
188 percentile: None,
189 }
190 }
191
192 pub fn stdev(column: usize) -> Self {
194 Self {
195 function: AggregateFunction::StdDev,
196 column: Some(column),
197 distinct: false,
198 alias: None,
199 percentile: None,
200 }
201 }
202
203 pub fn stdev_pop(column: usize) -> Self {
205 Self {
206 function: AggregateFunction::StdDevPop,
207 column: Some(column),
208 distinct: false,
209 alias: None,
210 percentile: None,
211 }
212 }
213
214 pub fn percentile_disc(column: usize, percentile: f64) -> Self {
220 Self {
221 function: AggregateFunction::PercentileDisc,
222 column: Some(column),
223 distinct: false,
224 alias: None,
225 percentile: Some(percentile.clamp(0.0, 1.0)),
226 }
227 }
228
229 pub fn percentile_cont(column: usize, percentile: f64) -> Self {
235 Self {
236 function: AggregateFunction::PercentileCont,
237 column: Some(column),
238 distinct: false,
239 alias: None,
240 percentile: Some(percentile.clamp(0.0, 1.0)),
241 }
242 }
243
244 pub fn with_distinct(mut self) -> Self {
246 self.distinct = true;
247 self
248 }
249
250 pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
252 self.alias = Some(alias.into());
253 self
254 }
255}
256
257#[derive(Debug, Clone)]
259enum AggregateState {
260 Count(i64),
262 CountDistinct(i64, HashSet<HashableValue>),
264 SumInt(i64),
266 SumIntDistinct(i64, HashSet<HashableValue>),
268 SumFloat(f64),
270 SumFloatDistinct(f64, HashSet<HashableValue>),
272 Avg(f64, i64),
274 AvgDistinct(f64, i64, HashSet<HashableValue>),
276 Min(Option<Value>),
278 Max(Option<Value>),
280 First(Option<Value>),
282 Last(Option<Value>),
284 Collect(Vec<Value>),
286 CollectDistinct(Vec<Value>, HashSet<HashableValue>),
288 StdDev { count: i64, mean: f64, m2: f64 },
290 StdDevPop { count: i64, mean: f64, m2: f64 },
292 PercentileDisc { values: Vec<f64>, percentile: f64 },
294 PercentileCont { values: Vec<f64>, percentile: f64 },
296}
297
298impl AggregateState {
299 fn new(function: AggregateFunction, distinct: bool, percentile: Option<f64>) -> Self {
301 match (function, distinct) {
302 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
303 AggregateState::Count(0)
304 }
305 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
306 AggregateState::CountDistinct(0, HashSet::new())
307 }
308 (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
309 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
310 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
311 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
312 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
314 (AggregateFunction::First, _) => AggregateState::First(None),
315 (AggregateFunction::Last, _) => AggregateState::Last(None),
316 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
317 (AggregateFunction::Collect, true) => {
318 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
319 }
320 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
322 count: 0,
323 mean: 0.0,
324 m2: 0.0,
325 },
326 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
327 count: 0,
328 mean: 0.0,
329 m2: 0.0,
330 },
331 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
332 values: Vec::new(),
333 percentile: percentile.unwrap_or(0.5),
334 },
335 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
336 values: Vec::new(),
337 percentile: percentile.unwrap_or(0.5),
338 },
339 }
340 }
341
342 fn update(&mut self, value: Option<Value>) {
344 match self {
345 AggregateState::Count(count) => {
346 *count += 1;
347 }
348 AggregateState::CountDistinct(count, seen) => {
349 if let Some(ref v) = value {
350 let hashable = HashableValue::from(v);
351 if seen.insert(hashable) {
352 *count += 1;
353 }
354 }
355 }
356 AggregateState::SumInt(sum) => {
357 if let Some(Value::Int64(v)) = value {
358 *sum += v;
359 } else if let Some(Value::Float64(v)) = value {
360 *self = AggregateState::SumFloat(*sum as f64 + v);
362 } else if let Some(ref v) = value {
363 if let Some(num) = value_to_f64(v) {
365 *self = AggregateState::SumFloat(*sum as f64 + num);
366 }
367 }
368 }
369 AggregateState::SumIntDistinct(sum, seen) => {
370 if let Some(ref v) = value {
371 let hashable = HashableValue::from(v);
372 if seen.insert(hashable) {
373 if let Value::Int64(i) = v {
374 *sum += i;
375 } else if let Value::Float64(f) = v {
376 let seen_clone = seen.clone();
378 *self = AggregateState::SumFloatDistinct(*sum as f64 + f, seen_clone);
379 } else if let Some(num) = value_to_f64(v) {
380 let seen_clone = seen.clone();
382 *self = AggregateState::SumFloatDistinct(*sum as f64 + num, seen_clone);
383 }
384 }
385 }
386 }
387 AggregateState::SumFloat(sum) => {
388 if let Some(ref v) = value {
389 if let Some(num) = value_to_f64(v) {
391 *sum += num;
392 }
393 }
394 }
395 AggregateState::SumFloatDistinct(sum, seen) => {
396 if let Some(ref v) = value {
397 let hashable = HashableValue::from(v);
398 if seen.insert(hashable)
399 && let Some(num) = value_to_f64(v)
400 {
401 *sum += num;
402 }
403 }
404 }
405 AggregateState::Avg(sum, count) => {
406 if let Some(ref v) = value
407 && let Some(num) = value_to_f64(v)
408 {
409 *sum += num;
410 *count += 1;
411 }
412 }
413 AggregateState::AvgDistinct(sum, count, seen) => {
414 if let Some(ref v) = value {
415 let hashable = HashableValue::from(v);
416 if seen.insert(hashable)
417 && let Some(num) = value_to_f64(v)
418 {
419 *sum += num;
420 *count += 1;
421 }
422 }
423 }
424 AggregateState::Min(min) => {
425 if let Some(v) = value {
426 match min {
427 None => *min = Some(v),
428 Some(current) => {
429 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
430 *min = Some(v);
431 }
432 }
433 }
434 }
435 }
436 AggregateState::Max(max) => {
437 if let Some(v) = value {
438 match max {
439 None => *max = Some(v),
440 Some(current) => {
441 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
442 *max = Some(v);
443 }
444 }
445 }
446 }
447 }
448 AggregateState::First(first) => {
449 if first.is_none() {
450 *first = value;
451 }
452 }
453 AggregateState::Last(last) => {
454 if value.is_some() {
455 *last = value;
456 }
457 }
458 AggregateState::Collect(list) => {
459 if let Some(v) = value {
460 list.push(v);
461 }
462 }
463 AggregateState::CollectDistinct(list, seen) => {
464 if let Some(v) = value {
465 let hashable = HashableValue::from(&v);
466 if seen.insert(hashable) {
467 list.push(v);
468 }
469 }
470 }
471 AggregateState::StdDev { count, mean, m2 }
473 | AggregateState::StdDevPop { count, mean, m2 } => {
474 if let Some(ref v) = value
475 && let Some(x) = value_to_f64(v)
476 {
477 *count += 1;
478 let delta = x - *mean;
479 *mean += delta / *count as f64;
480 let delta2 = x - *mean;
481 *m2 += delta * delta2;
482 }
483 }
484 AggregateState::PercentileDisc { values, .. }
485 | AggregateState::PercentileCont { values, .. } => {
486 if let Some(ref v) = value
487 && let Some(x) = value_to_f64(v)
488 {
489 values.push(x);
490 }
491 }
492 }
493 }
494
495 fn finalize(&self) -> Value {
497 match self {
498 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
499 Value::Int64(*count)
500 }
501 AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
502 Value::Int64(*sum)
503 }
504 AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
505 Value::Float64(*sum)
506 }
507 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
508 if *count == 0 {
509 Value::Null
510 } else {
511 Value::Float64(*sum / *count as f64)
512 }
513 }
514 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
515 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
516 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
517 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
518 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
519 Value::List(list.clone().into())
520 }
521 AggregateState::StdDev { count, m2, .. } => {
523 if *count < 2 {
524 Value::Null
525 } else {
526 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
527 }
528 }
529 AggregateState::StdDevPop { count, m2, .. } => {
531 if *count == 0 {
532 Value::Null
533 } else {
534 Value::Float64((*m2 / *count as f64).sqrt())
535 }
536 }
537 AggregateState::PercentileDisc { values, percentile } => {
539 if values.is_empty() {
540 Value::Null
541 } else {
542 let mut sorted = values.clone();
543 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
544 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
546 Value::Float64(sorted[index])
547 }
548 }
549 AggregateState::PercentileCont { values, percentile } => {
551 if values.is_empty() {
552 Value::Null
553 } else {
554 let mut sorted = values.clone();
555 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
556 let rank = percentile * (sorted.len() - 1) as f64;
558 let lower_idx = rank.floor() as usize;
559 let upper_idx = rank.ceil() as usize;
560 if lower_idx == upper_idx {
561 Value::Float64(sorted[lower_idx])
562 } else {
563 let fraction = rank - lower_idx as f64;
564 let result =
565 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
566 Value::Float64(result)
567 }
568 }
569 }
570 }
571 }
572}
573
574fn value_to_f64(value: &Value) -> Option<f64> {
577 match value {
578 Value::Int64(i) => Some(*i as f64),
579 Value::Float64(f) => Some(*f),
580 Value::String(s) => s.parse::<f64>().ok(),
582 _ => None,
583 }
584}
585
586fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
589 match (a, b) {
590 (Value::Int64(a), Value::Int64(b)) => Some(a.cmp(b)),
591 (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b),
592 (Value::String(a), Value::String(b)) => {
593 if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
595 a_num.partial_cmp(&b_num)
596 } else {
597 Some(a.cmp(b))
598 }
599 }
600 (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
601 (Value::Int64(a), Value::Float64(b)) => (*a as f64).partial_cmp(b),
602 (Value::Float64(a), Value::Int64(b)) => a.partial_cmp(&(*b as f64)),
603 (Value::String(s), Value::Int64(i)) => s.parse::<f64>().ok()?.partial_cmp(&(*i as f64)),
605 (Value::String(s), Value::Float64(f)) => s.parse::<f64>().ok()?.partial_cmp(f),
606 (Value::Int64(i), Value::String(s)) => (*i as f64).partial_cmp(&s.parse::<f64>().ok()?),
607 (Value::Float64(f), Value::String(s)) => f.partial_cmp(&s.parse::<f64>().ok()?),
608 _ => None,
609 }
610}
611
612#[derive(Debug, Clone, PartialEq, Eq, Hash)]
614pub struct GroupKey(Vec<GroupKeyPart>);
615
616#[derive(Debug, Clone, PartialEq, Eq, Hash)]
617enum GroupKeyPart {
618 Null,
619 Bool(bool),
620 Int64(i64),
621 String(String),
622}
623
624impl GroupKey {
625 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
627 let parts: Vec<GroupKeyPart> = group_columns
628 .iter()
629 .map(|&col_idx| {
630 chunk
631 .column(col_idx)
632 .and_then(|col| col.get_value(row))
633 .map(|v| match v {
634 Value::Null => GroupKeyPart::Null,
635 Value::Bool(b) => GroupKeyPart::Bool(b),
636 Value::Int64(i) => GroupKeyPart::Int64(i),
637 Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
638 Value::String(s) => GroupKeyPart::String(s.to_string()),
639 _ => GroupKeyPart::String(format!("{v:?}")),
640 })
641 .unwrap_or(GroupKeyPart::Null)
642 })
643 .collect();
644 GroupKey(parts)
645 }
646
647 fn to_values(&self) -> Vec<Value> {
649 self.0
650 .iter()
651 .map(|part| match part {
652 GroupKeyPart::Null => Value::Null,
653 GroupKeyPart::Bool(b) => Value::Bool(*b),
654 GroupKeyPart::Int64(i) => Value::Int64(*i),
655 GroupKeyPart::String(s) => Value::String(s.clone().into()),
656 })
657 .collect()
658 }
659}
660
661pub struct HashAggregateOperator {
665 child: Box<dyn Operator>,
667 group_columns: Vec<usize>,
669 aggregates: Vec<AggregateExpr>,
671 output_schema: Vec<LogicalType>,
673 groups: IndexMap<GroupKey, Vec<AggregateState>>,
675 aggregation_complete: bool,
677 results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
679}
680
681impl HashAggregateOperator {
682 pub fn new(
690 child: Box<dyn Operator>,
691 group_columns: Vec<usize>,
692 aggregates: Vec<AggregateExpr>,
693 output_schema: Vec<LogicalType>,
694 ) -> Self {
695 Self {
696 child,
697 group_columns,
698 aggregates,
699 output_schema,
700 groups: IndexMap::new(),
701 aggregation_complete: false,
702 results: None,
703 }
704 }
705
706 fn aggregate(&mut self) -> Result<(), OperatorError> {
708 while let Some(chunk) = self.child.next()? {
709 for row in chunk.selected_indices() {
710 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
711
712 let states = self.groups.entry(key).or_insert_with(|| {
714 self.aggregates
715 .iter()
716 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
717 .collect()
718 });
719
720 for (i, agg) in self.aggregates.iter().enumerate() {
722 let value = match (agg.function, agg.distinct) {
723 (AggregateFunction::Count, false) => None,
725 (AggregateFunction::Count, true) => agg
727 .column
728 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
729 _ => agg
730 .column
731 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
732 };
733
734 match (agg.function, agg.distinct) {
736 (AggregateFunction::Count, false) => states[i].update(None),
737 (AggregateFunction::Count, true) => {
738 if value.is_some() && !matches!(value, Some(Value::Null)) {
740 states[i].update(value);
741 }
742 }
743 (AggregateFunction::CountNonNull, _) => {
744 if value.is_some() && !matches!(value, Some(Value::Null)) {
745 states[i].update(value);
746 }
747 }
748 _ => {
749 if value.is_some() && !matches!(value, Some(Value::Null)) {
750 states[i].update(value);
751 }
752 }
753 }
754 }
755 }
756 }
757
758 self.aggregation_complete = true;
759
760 let results: Vec<_> = self.groups.drain(..).collect();
762 self.results = Some(results.into_iter());
763
764 Ok(())
765 }
766}
767
768impl Operator for HashAggregateOperator {
769 fn next(&mut self) -> OperatorResult {
770 if !self.aggregation_complete {
772 self.aggregate()?;
773 }
774
775 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
777 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
779
780 for agg in &self.aggregates {
781 let state = AggregateState::new(agg.function, agg.distinct, agg.percentile);
782 let value = state.finalize();
783 if let Some(col) = builder.column_mut(self.group_columns.len()) {
784 col.push_value(value);
785 }
786 }
787 builder.advance_row();
788
789 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
791 }
792
793 let Some(results) = &mut self.results else {
794 return Ok(None);
795 };
796
797 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
798
799 for (key, states) in results.by_ref() {
800 let key_values = key.to_values();
802 for (i, value) in key_values.into_iter().enumerate() {
803 if let Some(col) = builder.column_mut(i) {
804 col.push_value(value);
805 }
806 }
807
808 for (i, state) in states.iter().enumerate() {
810 let col_idx = self.group_columns.len() + i;
811 if let Some(col) = builder.column_mut(col_idx) {
812 col.push_value(state.finalize());
813 }
814 }
815
816 builder.advance_row();
817
818 if builder.is_full() {
819 return Ok(Some(builder.finish()));
820 }
821 }
822
823 if builder.row_count() > 0 {
824 Ok(Some(builder.finish()))
825 } else {
826 Ok(None)
827 }
828 }
829
830 fn reset(&mut self) {
831 self.child.reset();
832 self.groups.clear();
833 self.aggregation_complete = false;
834 self.results = None;
835 }
836
837 fn name(&self) -> &'static str {
838 "HashAggregate"
839 }
840}
841
842pub struct SimpleAggregateOperator {
846 child: Box<dyn Operator>,
848 aggregates: Vec<AggregateExpr>,
850 output_schema: Vec<LogicalType>,
852 states: Vec<AggregateState>,
854 done: bool,
856}
857
858impl SimpleAggregateOperator {
859 pub fn new(
861 child: Box<dyn Operator>,
862 aggregates: Vec<AggregateExpr>,
863 output_schema: Vec<LogicalType>,
864 ) -> Self {
865 let states = aggregates
866 .iter()
867 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
868 .collect();
869
870 Self {
871 child,
872 aggregates,
873 output_schema,
874 states,
875 done: false,
876 }
877 }
878}
879
880impl Operator for SimpleAggregateOperator {
881 fn next(&mut self) -> OperatorResult {
882 if self.done {
883 return Ok(None);
884 }
885
886 while let Some(chunk) = self.child.next()? {
888 for row in chunk.selected_indices() {
889 for (i, agg) in self.aggregates.iter().enumerate() {
890 let value = match (agg.function, agg.distinct) {
891 (AggregateFunction::Count, false) => None,
893 (AggregateFunction::Count, true) => agg
895 .column
896 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
897 _ => agg
898 .column
899 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
900 };
901
902 match (agg.function, agg.distinct) {
903 (AggregateFunction::Count, false) => self.states[i].update(None),
904 (AggregateFunction::Count, true) => {
905 if value.is_some() && !matches!(value, Some(Value::Null)) {
907 self.states[i].update(value);
908 }
909 }
910 (AggregateFunction::CountNonNull, _) => {
911 if value.is_some() && !matches!(value, Some(Value::Null)) {
912 self.states[i].update(value);
913 }
914 }
915 _ => {
916 if value.is_some() && !matches!(value, Some(Value::Null)) {
917 self.states[i].update(value);
918 }
919 }
920 }
921 }
922 }
923 }
924
925 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
927
928 for (i, state) in self.states.iter().enumerate() {
929 if let Some(col) = builder.column_mut(i) {
930 col.push_value(state.finalize());
931 }
932 }
933 builder.advance_row();
934
935 self.done = true;
936 Ok(Some(builder.finish()))
937 }
938
939 fn reset(&mut self) {
940 self.child.reset();
941 self.states = self
942 .aggregates
943 .iter()
944 .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
945 .collect();
946 self.done = false;
947 }
948
949 fn name(&self) -> &'static str {
950 "SimpleAggregate"
951 }
952}
953
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Child operator that hands out a fixed list of chunks, one per call, then reports
    /// exhaustion. `reset` rewinds the cursor but does not restore chunks already handed out.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            Self {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            let Some(slot) = self.chunks.get_mut(self.position) else {
                return Ok(None);
            };
            let chunk = std::mem::replace(slot, DataChunk::empty());
            self.position += 1;
            Ok(Some(chunk))
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    /// Builds a two-column `(Int64, Int64)` chunk from `(group, value)` rows.
    fn pair_chunk(rows: &[(i64, i64)]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
        for &(group, value) in rows {
            builder.column_mut(0).unwrap().push_int64(group);
            builder.column_mut(1).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Builds a single-column `Int64` chunk.
    fn int_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for &value in values {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Group 1 -> {10, 20}, group 2 -> {30, 40, 50}.
    fn create_test_chunk() -> DataChunk {
        pair_chunk(&[(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)])
    }

    /// Group 1 -> {10, 10, 20}, group 2 -> {30, 30, 30}.
    fn create_test_chunk_with_duplicates() -> DataChunk {
        pair_chunk(&[(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)])
    }

    /// Mean 5; population stdev 2; sample stdev ~2.138.
    fn create_statistical_test_chunk() -> DataChunk {
        int_chunk(&[2, 4, 4, 4, 5, 5, 7, 9])
    }

    /// Runs a global aggregation over `chunk` and returns its (single-row) output.
    fn simple_aggregate(
        chunk: DataChunk,
        aggregates: Vec<AggregateExpr>,
        schema: Vec<LogicalType>,
    ) -> DataChunk {
        let mut op =
            SimpleAggregateOperator::new(Box::new(MockOperator::new(vec![chunk])), aggregates, schema);
        let result = op.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        result
    }

    /// Groups `chunk` on column 0 with one aggregate and returns `(group, aggregate)` pairs
    /// sorted by group.
    fn grouped_pairs(chunk: DataChunk, aggregate: AggregateExpr) -> Vec<(i64, i64)> {
        let mut op = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![chunk])),
            vec![0],
            vec![aggregate],
            vec![LogicalType::Int64, LogicalType::Int64],
        );
        let mut pairs = Vec::new();
        while let Some(out) = op.next().unwrap() {
            for row in out.selected_indices() {
                let group = out.column(0).unwrap().get_int64(row).unwrap();
                let value = out.column(1).unwrap().get_int64(row).unwrap();
                pairs.push((group, value));
            }
        }
        pairs.sort_by_key(|&(group, _)| group);
        pairs
    }

    #[test]
    fn test_simple_count() {
        let mut agg = SimpleAggregateOperator::new(
            Box::new(MockOperator::new(vec![create_test_chunk()])),
            vec![AggregateExpr::count_star()],
            vec![LogicalType::Int64],
        );

        let result = agg.next().unwrap().unwrap();
        assert_eq!(result.row_count(), 1);
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));

        // The single result row is produced only once.
        assert!(agg.next().unwrap().is_none());
    }

    #[test]
    fn test_simple_sum() {
        let result = simple_aggregate(
            create_test_chunk(),
            vec![AggregateExpr::sum(1)],
            vec![LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
    }

    #[test]
    fn test_simple_avg() {
        let result = simple_aggregate(
            create_test_chunk(),
            vec![AggregateExpr::avg(1)],
            vec![LogicalType::Float64],
        );
        let avg = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((avg - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_simple_min_max() {
        let result = simple_aggregate(
            create_test_chunk(),
            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
            vec![LogicalType::Int64, LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
    }

    #[test]
    fn test_sum_with_string_values() {
        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
        for text in ["30", "25", "35"] {
            builder.column_mut(0).unwrap().push_string(text);
            builder.advance_row();
        }

        let result = simple_aggregate(
            builder.finish(),
            vec![AggregateExpr::sum(0)],
            vec![LogicalType::Float64],
        );
        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
        assert!(
            (sum_val - 90.0).abs() < 0.001,
            "Expected 90.0, got {}",
            sum_val
        );
    }

    #[test]
    fn test_grouped_aggregation() {
        let pairs = grouped_pairs(create_test_chunk(), AggregateExpr::sum(1));
        assert_eq!(pairs, vec![(1, 30), (2, 120)]);
    }

    #[test]
    fn test_grouped_count() {
        let pairs = grouped_pairs(create_test_chunk(), AggregateExpr::count_star());
        assert_eq!(pairs, vec![(1, 2), (2, 3)]);
    }

    #[test]
    fn test_multiple_aggregates() {
        let mut agg = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![create_test_chunk()])),
            vec![0],
            vec![
                AggregateExpr::count_star(),
                AggregateExpr::sum(1),
                AggregateExpr::avg(1),
            ],
            vec![
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Int64,
                LogicalType::Float64,
            ],
        );

        let mut rows: Vec<(i64, i64, i64, f64)> = Vec::new();
        while let Some(chunk) = agg.next().unwrap() {
            for row in chunk.selected_indices() {
                let int_at = |col: usize| chunk.column(col).unwrap().get_int64(row).unwrap();
                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
                rows.push((int_at(0), int_at(1), int_at(2), avg));
            }
        }

        rows.sort_by_key(|&(group, ..)| group);
        assert_eq!(rows.len(), 2);

        let (group, count, sum, avg) = rows[0];
        assert_eq!((group, count, sum), (1, 2, 30));
        assert!((avg - 15.0).abs() < 0.001);

        let (group, count, sum, avg) = rows[1];
        assert_eq!((group, count, sum), (2, 3, 120));
        assert!((avg - 40.0).abs() < 0.001);
    }

    #[test]
    fn test_count_distinct() {
        // Distinct values across all rows: 10, 20, 30.
        let result = simple_aggregate(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::count(1).with_distinct()],
            vec![LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
    }

    #[test]
    fn test_grouped_count_distinct() {
        let pairs = grouped_pairs(
            create_test_chunk_with_duplicates(),
            AggregateExpr::count(1).with_distinct(),
        );
        assert_eq!(pairs, vec![(1, 2), (2, 1)]);
    }

    #[test]
    fn test_sum_distinct() {
        // 10 + 20 + 30.
        let result = simple_aggregate(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::sum(1).with_distinct()],
            vec![LogicalType::Int64],
        );
        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
    }

    #[test]
    fn test_avg_distinct() {
        // (10 + 20 + 30) / 3.
        let result = simple_aggregate(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::avg(1).with_distinct()],
            vec![LogicalType::Float64],
        );
        let avg = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((avg - 20.0).abs() < 0.001);
    }

    #[test]
    fn test_stdev_sample() {
        let result = simple_aggregate(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );
        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((stdev - 2.138).abs() < 0.01);
    }

    #[test]
    fn test_stdev_population() {
        let result = simple_aggregate(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );
        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((stdev - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_percentile_disc() {
        let result = simple_aggregate(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_disc(0, 0.5)],
            vec![LogicalType::Float64],
        );
        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((percentile - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_percentile_cont() {
        // Halfway between the 4th (4) and 5th (5) sorted values.
        let result = simple_aggregate(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_cont(0, 0.5)],
            vec![LogicalType::Float64],
        );
        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((percentile - 4.5).abs() < 0.01);
    }

    #[test]
    fn test_percentile_extremes() {
        let result = simple_aggregate(
            create_statistical_test_chunk(),
            vec![
                AggregateExpr::percentile_disc(0, 0.0),
                AggregateExpr::percentile_disc(0, 1.0),
            ],
            vec![LogicalType::Float64, LogicalType::Float64],
        );
        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((p0 - 2.0).abs() < 0.01);
        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
        assert!((p100 - 9.0).abs() < 0.01);
    }

    #[test]
    fn test_stdev_single_value() {
        // Sample stdev is undefined for a single value.
        let result = simple_aggregate(
            int_chunk(&[42]),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );
        assert!(matches!(
            result.column(0).unwrap().get_value(0),
            Some(Value::Null)
        ));
    }

    #[test]
    fn test_stdev_pop_single_value() {
        // Population stdev of a single value is zero.
        let result = simple_aggregate(
            int_chunk(&[42]),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );
        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
        assert!((stdev - 0.0).abs() < 0.01);
    }
}
1459}