1use indexmap::IndexMap;
11use std::collections::HashSet;
12
13use grafeo_common::types::{LogicalType, Value};
14
15use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
16use super::{Operator, OperatorError, OperatorResult};
17use crate::execution::DataChunk;
18use crate::execution::chunk::DataChunkBuilder;
19
/// Running accumulator for one aggregate function within one group.
///
/// Variants suffixed `Distinct` additionally keep the set of values already seen so that
/// each distinct value contributes only once.
#[derive(Debug, Clone)]
pub(crate) enum AggregateState {
    /// COUNT: number of `update` calls received.
    Count(i64),
    /// COUNT(DISTINCT): number of distinct values, plus the values seen so far.
    CountDistinct(i64, HashSet<HashableValue>),
    /// SUM while every input has been an integer; becomes `SumFloat` once a float or other
    /// numeric-convertible value arrives.
    SumInt(i64),
    /// SUM(DISTINCT) over integers; becomes `SumFloatDistinct` on a non-integer input.
    SumIntDistinct(i64, HashSet<HashableValue>),
    /// Floating-point SUM.
    SumFloat(f64),
    /// Floating-point SUM(DISTINCT).
    SumFloatDistinct(f64, HashSet<HashableValue>),
    /// AVG: running sum and number of numeric inputs.
    Avg(f64, i64),
    /// AVG(DISTINCT): running sum, count, and the values seen so far.
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// Smallest value seen so far (per `compare_values`).
    Min(Option<Value>),
    /// Largest value seen so far (per `compare_values`).
    Max(Option<Value>),
    /// First value received.
    First(Option<Value>),
    /// Most recently received value.
    Last(Option<Value>),
    /// All values, in arrival order.
    Collect(Vec<Value>),
    /// First occurrence of each distinct value, in arrival order.
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Sample standard deviation via Welford's algorithm: count, running mean, and `m2`
    /// (sum of squared deviations from the mean).
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Population standard deviation; same running state as `StdDev`.
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Discrete percentile: collected numeric inputs and the requested fraction.
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Interpolated percentile: collected numeric inputs and the requested fraction.
    PercentileCont { values: Vec<f64>, percentile: f64 },
    /// GROUP_CONCAT: rendered values and the separator.
    GroupConcat(Vec<String>, String),
    /// GROUP_CONCAT(DISTINCT): rendered values, separator, and the values seen so far.
    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
    /// First value received.
    Sample(Option<Value>),
    /// Sample variance; same running state as `StdDev`.
    Variance { count: i64, mean: f64, m2: f64 },
    /// Population variance; same running state as `StdDev`.
    VariancePop { count: i64, mean: f64, m2: f64 },
    /// Two-argument statistics (covariance, correlation, linear regression).
    Bivariate {
        /// Which statistic `finalize` computes from the shared running state.
        kind: AggregateFunction,
        /// Number of (y, x) pairs where both sides were numeric.
        count: i64,
        mean_x: f64,
        mean_y: f64,
        /// Sum of squared deviations of x.
        m2_x: f64,
        /// Sum of squared deviations of y.
        m2_y: f64,
        /// Co-moment: sum of (x - mean_x) * (y - mean_y).
        c_xy: f64,
    },
}
81
82impl AggregateState {
83 pub(crate) fn new(
85 function: AggregateFunction,
86 distinct: bool,
87 percentile: Option<f64>,
88 separator: Option<&str>,
89 ) -> Self {
90 match (function, distinct) {
91 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
92 AggregateState::Count(0)
93 }
94 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
95 AggregateState::CountDistinct(0, HashSet::new())
96 }
97 (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
98 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
99 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
100 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
101 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
103 (AggregateFunction::First, _) => AggregateState::First(None),
104 (AggregateFunction::Last, _) => AggregateState::Last(None),
105 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
106 (AggregateFunction::Collect, true) => {
107 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
108 }
109 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
111 count: 0,
112 mean: 0.0,
113 m2: 0.0,
114 },
115 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
116 count: 0,
117 mean: 0.0,
118 m2: 0.0,
119 },
120 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
121 values: Vec::new(),
122 percentile: percentile.unwrap_or(0.5),
123 },
124 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
125 values: Vec::new(),
126 percentile: percentile.unwrap_or(0.5),
127 },
128 (AggregateFunction::GroupConcat, false) => {
129 AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
130 }
131 (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
132 Vec::new(),
133 separator.unwrap_or(" ").to_string(),
134 HashSet::new(),
135 ),
136 (AggregateFunction::Sample, _) => AggregateState::Sample(None),
137 (
139 AggregateFunction::CovarSamp
140 | AggregateFunction::CovarPop
141 | AggregateFunction::Corr
142 | AggregateFunction::RegrSlope
143 | AggregateFunction::RegrIntercept
144 | AggregateFunction::RegrR2
145 | AggregateFunction::RegrCount
146 | AggregateFunction::RegrSxx
147 | AggregateFunction::RegrSyy
148 | AggregateFunction::RegrSxy
149 | AggregateFunction::RegrAvgx
150 | AggregateFunction::RegrAvgy,
151 _,
152 ) => AggregateState::Bivariate {
153 kind: function,
154 count: 0,
155 mean_x: 0.0,
156 mean_y: 0.0,
157 m2_x: 0.0,
158 m2_y: 0.0,
159 c_xy: 0.0,
160 },
161 (AggregateFunction::Variance, _) => AggregateState::Variance {
162 count: 0,
163 mean: 0.0,
164 m2: 0.0,
165 },
166 (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
167 count: 0,
168 mean: 0.0,
169 m2: 0.0,
170 },
171 }
172 }
173
174 pub(crate) fn update(&mut self, value: Option<Value>) {
176 match self {
177 AggregateState::Count(count) => {
178 *count += 1;
179 }
180 AggregateState::CountDistinct(count, seen) => {
181 if let Some(ref v) = value {
182 let hashable = HashableValue::from(v);
183 if seen.insert(hashable) {
184 *count += 1;
185 }
186 }
187 }
188 AggregateState::SumInt(sum) => {
189 if let Some(Value::Int64(v)) = value {
190 *sum += v;
191 } else if let Some(Value::Float64(v)) = value {
192 *self = AggregateState::SumFloat(*sum as f64 + v);
194 } else if let Some(ref v) = value {
195 if let Some(num) = value_to_f64(v) {
197 *self = AggregateState::SumFloat(*sum as f64 + num);
198 }
199 }
200 }
201 AggregateState::SumIntDistinct(sum, seen) => {
202 if let Some(ref v) = value {
203 let hashable = HashableValue::from(v);
204 if seen.insert(hashable) {
205 if let Value::Int64(i) = v {
206 *sum += i;
207 } else if let Value::Float64(f) = v {
208 let moved_seen = std::mem::take(seen);
210 *self = AggregateState::SumFloatDistinct(*sum as f64 + f, moved_seen);
211 } else if let Some(num) = value_to_f64(v) {
212 let moved_seen = std::mem::take(seen);
214 *self = AggregateState::SumFloatDistinct(*sum as f64 + num, moved_seen);
215 }
216 }
217 }
218 }
219 AggregateState::SumFloat(sum) => {
220 if let Some(ref v) = value {
221 if let Some(num) = value_to_f64(v) {
223 *sum += num;
224 }
225 }
226 }
227 AggregateState::SumFloatDistinct(sum, seen) => {
228 if let Some(ref v) = value {
229 let hashable = HashableValue::from(v);
230 if seen.insert(hashable)
231 && let Some(num) = value_to_f64(v)
232 {
233 *sum += num;
234 }
235 }
236 }
237 AggregateState::Avg(sum, count) => {
238 if let Some(ref v) = value
239 && let Some(num) = value_to_f64(v)
240 {
241 *sum += num;
242 *count += 1;
243 }
244 }
245 AggregateState::AvgDistinct(sum, count, seen) => {
246 if let Some(ref v) = value {
247 let hashable = HashableValue::from(v);
248 if seen.insert(hashable)
249 && let Some(num) = value_to_f64(v)
250 {
251 *sum += num;
252 *count += 1;
253 }
254 }
255 }
256 AggregateState::Min(min) => {
257 if let Some(v) = value {
258 match min {
259 None => *min = Some(v),
260 Some(current) => {
261 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
262 *min = Some(v);
263 }
264 }
265 }
266 }
267 }
268 AggregateState::Max(max) => {
269 if let Some(v) = value {
270 match max {
271 None => *max = Some(v),
272 Some(current) => {
273 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
274 *max = Some(v);
275 }
276 }
277 }
278 }
279 }
280 AggregateState::First(first) => {
281 if first.is_none() {
282 *first = value;
283 }
284 }
285 AggregateState::Last(last) => {
286 if value.is_some() {
287 *last = value;
288 }
289 }
290 AggregateState::Collect(list) => {
291 if let Some(v) = value {
292 list.push(v);
293 }
294 }
295 AggregateState::CollectDistinct(list, seen) => {
296 if let Some(v) = value {
297 let hashable = HashableValue::from(&v);
298 if seen.insert(hashable) {
299 list.push(v);
300 }
301 }
302 }
303 AggregateState::StdDev { count, mean, m2 }
305 | AggregateState::StdDevPop { count, mean, m2 }
306 | AggregateState::Variance { count, mean, m2 }
307 | AggregateState::VariancePop { count, mean, m2 } => {
308 if let Some(ref v) = value
309 && let Some(x) = value_to_f64(v)
310 {
311 *count += 1;
312 let delta = x - *mean;
313 *mean += delta / *count as f64;
314 let delta2 = x - *mean;
315 *m2 += delta * delta2;
316 }
317 }
318 AggregateState::PercentileDisc { values, .. }
319 | AggregateState::PercentileCont { values, .. } => {
320 if let Some(ref v) = value
321 && let Some(x) = value_to_f64(v)
322 {
323 values.push(x);
324 }
325 }
326 AggregateState::GroupConcat(list, _sep) => {
327 if let Some(v) = value {
328 list.push(agg_value_to_string(&v));
329 }
330 }
331 AggregateState::GroupConcatDistinct(list, _sep, seen) => {
332 if let Some(v) = value {
333 let hashable = HashableValue::from(&v);
334 if seen.insert(hashable) {
335 list.push(agg_value_to_string(&v));
336 }
337 }
338 }
339 AggregateState::Sample(sample) => {
340 if sample.is_none() {
341 *sample = value;
342 }
343 }
344 AggregateState::Bivariate { .. } => {
345 }
348 }
349 }
350
351 fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
356 if let AggregateState::Bivariate {
357 count,
358 mean_x,
359 mean_y,
360 m2_x,
361 m2_y,
362 c_xy,
363 ..
364 } = self
365 {
366 if let (Some(y), Some(x)) = (&y_val, &x_val)
368 && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
369 {
370 *count += 1;
371 let n = *count as f64;
372 let dx = x_f - *mean_x;
373 let dy = y_f - *mean_y;
374 *mean_x += dx / n;
375 *mean_y += dy / n;
376 let dx2 = x_f - *mean_x; let dy2 = y_f - *mean_y; *m2_x += dx * dx2;
379 *m2_y += dy * dy2;
380 *c_xy += dx * dy2;
381 }
382 }
383 }
384
385 pub(crate) fn finalize(&self) -> Value {
387 match self {
388 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
389 Value::Int64(*count)
390 }
391 AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
392 Value::Int64(*sum)
393 }
394 AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
395 Value::Float64(*sum)
396 }
397 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
398 if *count == 0 {
399 Value::Null
400 } else {
401 Value::Float64(*sum / *count as f64)
402 }
403 }
404 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
405 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
406 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
407 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
408 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
409 Value::List(list.clone().into())
410 }
411 AggregateState::StdDev { count, m2, .. } => {
413 if *count < 2 {
414 Value::Null
415 } else {
416 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
417 }
418 }
419 AggregateState::StdDevPop { count, m2, .. } => {
421 if *count == 0 {
422 Value::Null
423 } else {
424 Value::Float64((*m2 / *count as f64).sqrt())
425 }
426 }
427 AggregateState::Variance { count, m2, .. } => {
429 if *count < 2 {
430 Value::Null
431 } else {
432 Value::Float64(*m2 / (*count - 1) as f64)
433 }
434 }
435 AggregateState::VariancePop { count, m2, .. } => {
437 if *count == 0 {
438 Value::Null
439 } else {
440 Value::Float64(*m2 / *count as f64)
441 }
442 }
443 AggregateState::PercentileDisc { values, percentile } => {
445 if values.is_empty() {
446 Value::Null
447 } else {
448 let mut sorted = values.clone();
449 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
450 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
452 Value::Float64(sorted[index])
453 }
454 }
455 AggregateState::PercentileCont { values, percentile } => {
457 if values.is_empty() {
458 Value::Null
459 } else {
460 let mut sorted = values.clone();
461 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
462 let rank = percentile * (sorted.len() - 1) as f64;
464 let lower_idx = rank.floor() as usize;
465 let upper_idx = rank.ceil() as usize;
466 if lower_idx == upper_idx {
467 Value::Float64(sorted[lower_idx])
468 } else {
469 let fraction = rank - lower_idx as f64;
470 let result =
471 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
472 Value::Float64(result)
473 }
474 }
475 }
476 AggregateState::GroupConcat(list, sep)
478 | AggregateState::GroupConcatDistinct(list, sep, _) => {
479 Value::String(list.join(sep).into())
480 }
481 AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
483 AggregateState::Bivariate {
485 kind,
486 count,
487 mean_x,
488 mean_y,
489 m2_x,
490 m2_y,
491 c_xy,
492 } => {
493 let n = *count;
494 match kind {
495 AggregateFunction::CovarSamp => {
496 if n < 2 {
497 Value::Null
498 } else {
499 Value::Float64(*c_xy / (n - 1) as f64)
500 }
501 }
502 AggregateFunction::CovarPop => {
503 if n == 0 {
504 Value::Null
505 } else {
506 Value::Float64(*c_xy / n as f64)
507 }
508 }
509 AggregateFunction::Corr => {
510 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
511 Value::Null
512 } else {
513 Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
514 }
515 }
516 AggregateFunction::RegrSlope => {
517 if n == 0 || *m2_x == 0.0 {
518 Value::Null
519 } else {
520 Value::Float64(*c_xy / *m2_x)
521 }
522 }
523 AggregateFunction::RegrIntercept => {
524 if n == 0 || *m2_x == 0.0 {
525 Value::Null
526 } else {
527 let slope = *c_xy / *m2_x;
528 Value::Float64(*mean_y - slope * *mean_x)
529 }
530 }
531 AggregateFunction::RegrR2 => {
532 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
533 Value::Null
534 } else {
535 Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
536 }
537 }
538 AggregateFunction::RegrCount => Value::Int64(n),
539 AggregateFunction::RegrSxx => {
540 if n == 0 {
541 Value::Null
542 } else {
543 Value::Float64(*m2_x)
544 }
545 }
546 AggregateFunction::RegrSyy => {
547 if n == 0 {
548 Value::Null
549 } else {
550 Value::Float64(*m2_y)
551 }
552 }
553 AggregateFunction::RegrSxy => {
554 if n == 0 {
555 Value::Null
556 } else {
557 Value::Float64(*c_xy)
558 }
559 }
560 AggregateFunction::RegrAvgx => {
561 if n == 0 {
562 Value::Null
563 } else {
564 Value::Float64(*mean_x)
565 }
566 }
567 AggregateFunction::RegrAvgy => {
568 if n == 0 {
569 Value::Null
570 } else {
571 Value::Float64(*mean_y)
572 }
573 }
574 _ => Value::Null, }
576 }
577 }
578 }
579}
580
581use super::value_utils::{compare_values, value_to_f64};
582
583fn agg_value_to_string(val: &Value) -> String {
585 match val {
586 Value::String(s) => s.to_string(),
587 Value::Int64(i) => i.to_string(),
588 Value::Float64(f) => f.to_string(),
589 Value::Bool(b) => b.to_string(),
590 Value::Null => String::new(),
591 other => format!("{other:?}"),
592 }
593}
594
595#[derive(Debug, Clone, PartialEq, Eq, Hash)]
597pub struct GroupKey(Vec<GroupKeyPart>);
598
599#[derive(Debug, Clone, PartialEq, Eq, Hash)]
600enum GroupKeyPart {
601 Null,
602 Bool(bool),
603 Int64(i64),
604 String(String),
605}
606
607impl GroupKey {
608 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
610 let parts: Vec<GroupKeyPart> = group_columns
611 .iter()
612 .map(|&col_idx| {
613 chunk
614 .column(col_idx)
615 .and_then(|col| col.get_value(row))
616 .map_or(GroupKeyPart::Null, |v| match v {
617 Value::Null => GroupKeyPart::Null,
618 Value::Bool(b) => GroupKeyPart::Bool(b),
619 Value::Int64(i) => GroupKeyPart::Int64(i),
620 Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
621 Value::String(s) => GroupKeyPart::String(s.to_string()),
622 _ => GroupKeyPart::String(format!("{v:?}")),
623 })
624 })
625 .collect();
626 GroupKey(parts)
627 }
628
629 fn to_values(&self) -> Vec<Value> {
631 self.0
632 .iter()
633 .map(|part| match part {
634 GroupKeyPart::Null => Value::Null,
635 GroupKeyPart::Bool(b) => Value::Bool(*b),
636 GroupKeyPart::Int64(i) => Value::Int64(*i),
637 GroupKeyPart::String(s) => Value::String(s.clone().into()),
638 })
639 .collect()
640 }
641}
642
/// Grouped aggregation operator backed by a hash table keyed on the grouping columns.
///
/// Each output row holds the group-key columns followed by one column per aggregate.
pub struct HashAggregateOperator {
    /// Input operator.
    child: Box<dyn Operator>,
    /// Indices of the input columns that form the grouping key.
    group_columns: Vec<usize>,
    /// Aggregate expressions evaluated per group.
    aggregates: Vec<AggregateExpr>,
    /// Logical types of the output columns.
    output_schema: Vec<LogicalType>,
    /// Per-group accumulators; `IndexMap` keeps groups in first-seen order.
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Set once the child has been fully consumed.
    aggregation_complete: bool,
    /// Finished groups still waiting to be emitted by `next`.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
662
663impl HashAggregateOperator {
664 pub fn new(
672 child: Box<dyn Operator>,
673 group_columns: Vec<usize>,
674 aggregates: Vec<AggregateExpr>,
675 output_schema: Vec<LogicalType>,
676 ) -> Self {
677 Self {
678 child,
679 group_columns,
680 aggregates,
681 output_schema,
682 groups: IndexMap::new(),
683 aggregation_complete: false,
684 results: None,
685 }
686 }
687
688 fn aggregate(&mut self) -> Result<(), OperatorError> {
690 while let Some(chunk) = self.child.next()? {
691 for row in chunk.selected_indices() {
692 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
693
694 let states = self.groups.entry(key).or_insert_with(|| {
696 self.aggregates
697 .iter()
698 .map(|agg| {
699 AggregateState::new(
700 agg.function,
701 agg.distinct,
702 agg.percentile,
703 agg.separator.as_deref(),
704 )
705 })
706 .collect()
707 });
708
709 for (i, agg) in self.aggregates.iter().enumerate() {
711 if agg.column2.is_some() {
713 let y_val = agg
714 .column
715 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
716 let x_val = agg
717 .column2
718 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
719 states[i].update_bivariate(y_val, x_val);
720 continue;
721 }
722
723 let value = match (agg.function, agg.distinct) {
724 (AggregateFunction::Count, false) => None,
726 (AggregateFunction::Count, true) => agg
728 .column
729 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
730 _ => agg
731 .column
732 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
733 };
734
735 match (agg.function, agg.distinct) {
737 (AggregateFunction::Count, false) => states[i].update(None),
738 (AggregateFunction::Count, true) => {
739 if value.is_some() && !matches!(value, Some(Value::Null)) {
741 states[i].update(value);
742 }
743 }
744 (AggregateFunction::CountNonNull, _) => {
745 if value.is_some() && !matches!(value, Some(Value::Null)) {
746 states[i].update(value);
747 }
748 }
749 _ => {
750 if value.is_some() && !matches!(value, Some(Value::Null)) {
751 states[i].update(value);
752 }
753 }
754 }
755 }
756 }
757 }
758
759 self.aggregation_complete = true;
760
761 let results: Vec<_> = self.groups.drain(..).collect();
763 self.results = Some(results.into_iter());
764
765 Ok(())
766 }
767}
768
769impl Operator for HashAggregateOperator {
770 fn next(&mut self) -> OperatorResult {
771 if !self.aggregation_complete {
773 self.aggregate()?;
774 }
775
776 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
778 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
780
781 for agg in &self.aggregates {
782 let state = AggregateState::new(
783 agg.function,
784 agg.distinct,
785 agg.percentile,
786 agg.separator.as_deref(),
787 );
788 let value = state.finalize();
789 if let Some(col) = builder.column_mut(self.group_columns.len()) {
790 col.push_value(value);
791 }
792 }
793 builder.advance_row();
794
795 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
797 }
798
799 let Some(results) = &mut self.results else {
800 return Ok(None);
801 };
802
803 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
804
805 for (key, states) in results.by_ref() {
806 let key_values = key.to_values();
808 for (i, value) in key_values.into_iter().enumerate() {
809 if let Some(col) = builder.column_mut(i) {
810 col.push_value(value);
811 }
812 }
813
814 for (i, state) in states.iter().enumerate() {
816 let col_idx = self.group_columns.len() + i;
817 if let Some(col) = builder.column_mut(col_idx) {
818 col.push_value(state.finalize());
819 }
820 }
821
822 builder.advance_row();
823
824 if builder.is_full() {
825 return Ok(Some(builder.finish()));
826 }
827 }
828
829 if builder.row_count() > 0 {
830 Ok(Some(builder.finish()))
831 } else {
832 Ok(None)
833 }
834 }
835
836 fn reset(&mut self) {
837 self.child.reset();
838 self.groups.clear();
839 self.aggregation_complete = false;
840 self.results = None;
841 }
842
843 fn name(&self) -> &'static str {
844 "HashAggregate"
845 }
846}
847
/// Ungrouped aggregation: folds the entire input into a single output row.
pub struct SimpleAggregateOperator {
    /// Input operator.
    child: Box<dyn Operator>,
    /// Aggregate expressions, one output column each.
    aggregates: Vec<AggregateExpr>,
    /// Logical types of the output columns.
    output_schema: Vec<LogicalType>,
    /// One accumulator per entry in `aggregates`.
    states: Vec<AggregateState>,
    /// Set once the single result row has been emitted.
    done: bool,
}
863
864impl SimpleAggregateOperator {
865 pub fn new(
867 child: Box<dyn Operator>,
868 aggregates: Vec<AggregateExpr>,
869 output_schema: Vec<LogicalType>,
870 ) -> Self {
871 let states = aggregates
872 .iter()
873 .map(|agg| {
874 AggregateState::new(
875 agg.function,
876 agg.distinct,
877 agg.percentile,
878 agg.separator.as_deref(),
879 )
880 })
881 .collect();
882
883 Self {
884 child,
885 aggregates,
886 output_schema,
887 states,
888 done: false,
889 }
890 }
891}
892
893impl Operator for SimpleAggregateOperator {
894 fn next(&mut self) -> OperatorResult {
895 if self.done {
896 return Ok(None);
897 }
898
899 while let Some(chunk) = self.child.next()? {
901 for row in chunk.selected_indices() {
902 for (i, agg) in self.aggregates.iter().enumerate() {
903 if agg.column2.is_some() {
905 let y_val = agg
906 .column
907 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
908 let x_val = agg
909 .column2
910 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
911 self.states[i].update_bivariate(y_val, x_val);
912 continue;
913 }
914
915 let value = match (agg.function, agg.distinct) {
916 (AggregateFunction::Count, false) => None,
918 (AggregateFunction::Count, true) => agg
920 .column
921 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
922 _ => agg
923 .column
924 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
925 };
926
927 match (agg.function, agg.distinct) {
928 (AggregateFunction::Count, false) => self.states[i].update(None),
929 (AggregateFunction::Count, true) => {
930 if value.is_some() && !matches!(value, Some(Value::Null)) {
932 self.states[i].update(value);
933 }
934 }
935 (AggregateFunction::CountNonNull, _) => {
936 if value.is_some() && !matches!(value, Some(Value::Null)) {
937 self.states[i].update(value);
938 }
939 }
940 _ => {
941 if value.is_some() && !matches!(value, Some(Value::Null)) {
942 self.states[i].update(value);
943 }
944 }
945 }
946 }
947 }
948 }
949
950 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
952
953 for (i, state) in self.states.iter().enumerate() {
954 if let Some(col) = builder.column_mut(i) {
955 col.push_value(state.finalize());
956 }
957 }
958 builder.advance_row();
959
960 self.done = true;
961 Ok(Some(builder.finish()))
962 }
963
964 fn reset(&mut self) {
965 self.child.reset();
966 self.states = self
967 .aggregates
968 .iter()
969 .map(|agg| {
970 AggregateState::new(
971 agg.function,
972 agg.distinct,
973 agg.percentile,
974 agg.separator.as_deref(),
975 )
976 })
977 .collect();
978 self.done = false;
979 }
980
981 fn name(&self) -> &'static str {
982 "SimpleAggregate"
983 }
984}
985
986#[cfg(test)]
987mod tests {
988 use super::*;
989 use crate::execution::chunk::DataChunkBuilder;
990
    /// Test double that replays a fixed list of chunks, one per `next` call.
    ///
    /// Delivered chunks are moved out and replaced with `DataChunk::empty()`.
    struct MockOperator {
        /// Chunks to hand out, in order.
        chunks: Vec<DataChunk>,
        /// Index of the next chunk to deliver.
        position: usize,
    }
995
996 impl MockOperator {
997 fn new(chunks: Vec<DataChunk>) -> Self {
998 Self {
999 chunks,
1000 position: 0,
1001 }
1002 }
1003 }
1004
1005 impl Operator for MockOperator {
1006 fn next(&mut self) -> OperatorResult {
1007 if self.position < self.chunks.len() {
1008 let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1009 self.position += 1;
1010 Ok(Some(chunk))
1011 } else {
1012 Ok(None)
1013 }
1014 }
1015
1016 fn reset(&mut self) {
1017 self.position = 0;
1018 }
1019
1020 fn name(&self) -> &'static str {
1021 "Mock"
1022 }
1023 }
1024
1025 fn create_test_chunk() -> DataChunk {
1026 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1028
1029 let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1030 for (group, value) in data {
1031 builder.column_mut(0).unwrap().push_int64(group);
1032 builder.column_mut(1).unwrap().push_int64(value);
1033 builder.advance_row();
1034 }
1035
1036 builder.finish()
1037 }
1038
1039 #[test]
1040 fn test_simple_count() {
1041 let mock = MockOperator::new(vec![create_test_chunk()]);
1042
1043 let mut agg = SimpleAggregateOperator::new(
1044 Box::new(mock),
1045 vec![AggregateExpr::count_star()],
1046 vec![LogicalType::Int64],
1047 );
1048
1049 let result = agg.next().unwrap().unwrap();
1050 assert_eq!(result.row_count(), 1);
1051 assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1052
1053 assert!(agg.next().unwrap().is_none());
1055 }
1056
1057 #[test]
1058 fn test_simple_sum() {
1059 let mock = MockOperator::new(vec![create_test_chunk()]);
1060
1061 let mut agg = SimpleAggregateOperator::new(
1062 Box::new(mock),
1063 vec![AggregateExpr::sum(1)], vec![LogicalType::Int64],
1065 );
1066
1067 let result = agg.next().unwrap().unwrap();
1068 assert_eq!(result.row_count(), 1);
1069 assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1071 }
1072
1073 #[test]
1074 fn test_simple_avg() {
1075 let mock = MockOperator::new(vec![create_test_chunk()]);
1076
1077 let mut agg = SimpleAggregateOperator::new(
1078 Box::new(mock),
1079 vec![AggregateExpr::avg(1)],
1080 vec![LogicalType::Float64],
1081 );
1082
1083 let result = agg.next().unwrap().unwrap();
1084 assert_eq!(result.row_count(), 1);
1085 let avg = result.column(0).unwrap().get_float64(0).unwrap();
1087 assert!((avg - 30.0).abs() < 0.001);
1088 }
1089
1090 #[test]
1091 fn test_simple_min_max() {
1092 let mock = MockOperator::new(vec![create_test_chunk()]);
1093
1094 let mut agg = SimpleAggregateOperator::new(
1095 Box::new(mock),
1096 vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1097 vec![LogicalType::Int64, LogicalType::Int64],
1098 );
1099
1100 let result = agg.next().unwrap().unwrap();
1101 assert_eq!(result.row_count(), 1);
1102 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); }
1105
1106 #[test]
1107 fn test_sum_with_string_values() {
1108 let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1110 builder.column_mut(0).unwrap().push_string("30");
1111 builder.advance_row();
1112 builder.column_mut(0).unwrap().push_string("25");
1113 builder.advance_row();
1114 builder.column_mut(0).unwrap().push_string("35");
1115 builder.advance_row();
1116 let chunk = builder.finish();
1117
1118 let mock = MockOperator::new(vec![chunk]);
1119 let mut agg = SimpleAggregateOperator::new(
1120 Box::new(mock),
1121 vec![AggregateExpr::sum(0)],
1122 vec![LogicalType::Float64],
1123 );
1124
1125 let result = agg.next().unwrap().unwrap();
1126 assert_eq!(result.row_count(), 1);
1127 let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1129 assert!(
1130 (sum_val - 90.0).abs() < 0.001,
1131 "Expected 90.0, got {}",
1132 sum_val
1133 );
1134 }
1135
1136 #[test]
1137 fn test_grouped_aggregation() {
1138 let mock = MockOperator::new(vec![create_test_chunk()]);
1139
1140 let mut agg = HashAggregateOperator::new(
1142 Box::new(mock),
1143 vec![0], vec![AggregateExpr::sum(1)], vec![LogicalType::Int64, LogicalType::Int64],
1146 );
1147
1148 let mut results: Vec<(i64, i64)> = Vec::new();
1149 while let Some(chunk) = agg.next().unwrap() {
1150 for row in chunk.selected_indices() {
1151 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1152 let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1153 results.push((group, sum));
1154 }
1155 }
1156
1157 results.sort_by_key(|(g, _)| *g);
1158 assert_eq!(results.len(), 2);
1159 assert_eq!(results[0], (1, 30)); assert_eq!(results[1], (2, 120)); }
1162
1163 #[test]
1164 fn test_grouped_count() {
1165 let mock = MockOperator::new(vec![create_test_chunk()]);
1166
1167 let mut agg = HashAggregateOperator::new(
1169 Box::new(mock),
1170 vec![0],
1171 vec![AggregateExpr::count_star()],
1172 vec![LogicalType::Int64, LogicalType::Int64],
1173 );
1174
1175 let mut results: Vec<(i64, i64)> = Vec::new();
1176 while let Some(chunk) = agg.next().unwrap() {
1177 for row in chunk.selected_indices() {
1178 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1179 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1180 results.push((group, count));
1181 }
1182 }
1183
1184 results.sort_by_key(|(g, _)| *g);
1185 assert_eq!(results.len(), 2);
1186 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 3)); }
1189
1190 #[test]
1191 fn test_multiple_aggregates() {
1192 let mock = MockOperator::new(vec![create_test_chunk()]);
1193
1194 let mut agg = HashAggregateOperator::new(
1196 Box::new(mock),
1197 vec![0],
1198 vec![
1199 AggregateExpr::count_star(),
1200 AggregateExpr::sum(1),
1201 AggregateExpr::avg(1),
1202 ],
1203 vec![
1204 LogicalType::Int64, LogicalType::Int64, LogicalType::Int64, LogicalType::Float64, ],
1209 );
1210
1211 let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1212 while let Some(chunk) = agg.next().unwrap() {
1213 for row in chunk.selected_indices() {
1214 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1215 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1216 let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1217 let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1218 results.push((group, count, sum, avg));
1219 }
1220 }
1221
1222 results.sort_by_key(|(g, _, _, _)| *g);
1223 assert_eq!(results.len(), 2);
1224
1225 assert_eq!(results[0].0, 1);
1227 assert_eq!(results[0].1, 2);
1228 assert_eq!(results[0].2, 30);
1229 assert!((results[0].3 - 15.0).abs() < 0.001);
1230
1231 assert_eq!(results[1].0, 2);
1233 assert_eq!(results[1].1, 3);
1234 assert_eq!(results[1].2, 120);
1235 assert!((results[1].3 - 40.0).abs() < 0.001);
1236 }
1237
1238 fn create_test_chunk_with_duplicates() -> DataChunk {
1239 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1244
1245 let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1246 for (group, value) in data {
1247 builder.column_mut(0).unwrap().push_int64(group);
1248 builder.column_mut(1).unwrap().push_int64(value);
1249 builder.advance_row();
1250 }
1251
1252 builder.finish()
1253 }
1254
1255 #[test]
1256 fn test_count_distinct() {
1257 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1258
1259 let mut agg = SimpleAggregateOperator::new(
1261 Box::new(mock),
1262 vec![AggregateExpr::count(1).with_distinct()],
1263 vec![LogicalType::Int64],
1264 );
1265
1266 let result = agg.next().unwrap().unwrap();
1267 assert_eq!(result.row_count(), 1);
1268 assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1270 }
1271
1272 #[test]
1273 fn test_grouped_count_distinct() {
1274 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1275
1276 let mut agg = HashAggregateOperator::new(
1278 Box::new(mock),
1279 vec![0],
1280 vec![AggregateExpr::count(1).with_distinct()],
1281 vec![LogicalType::Int64, LogicalType::Int64],
1282 );
1283
1284 let mut results: Vec<(i64, i64)> = Vec::new();
1285 while let Some(chunk) = agg.next().unwrap() {
1286 for row in chunk.selected_indices() {
1287 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1288 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1289 results.push((group, count));
1290 }
1291 }
1292
1293 results.sort_by_key(|(g, _)| *g);
1294 assert_eq!(results.len(), 2);
1295 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 1)); }
1298
1299 #[test]
1300 fn test_sum_distinct() {
1301 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1302
1303 let mut agg = SimpleAggregateOperator::new(
1305 Box::new(mock),
1306 vec![AggregateExpr::sum(1).with_distinct()],
1307 vec![LogicalType::Int64],
1308 );
1309
1310 let result = agg.next().unwrap().unwrap();
1311 assert_eq!(result.row_count(), 1);
1312 assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1314 }
1315
/// AVG(DISTINCT col1) averages only the distinct values of the duplicate fixture.
#[test]
fn test_avg_distinct() {
    let input = Box::new(MockOperator::new(vec![create_test_chunk_with_duplicates()]));
    let distinct_avg = AggregateExpr::avg(1).with_distinct();
    let mut op =
        SimpleAggregateOperator::new(input, vec![distinct_avg], vec![LogicalType::Float64]);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let mean = out.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (mean - 20.0).abs();
    assert!(deviation < 0.001);
}
1333
1334 fn create_statistical_test_chunk() -> DataChunk {
1335 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1338
1339 for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1340 builder.column_mut(0).unwrap().push_int64(value);
1341 builder.advance_row();
1342 }
1343
1344 builder.finish()
1345 }
1346
/// Sample standard deviation (n - 1 denominator) of the statistical fixture.
#[test]
fn test_stdev_sample() {
    let mock = MockOperator::new(vec![create_statistical_test_chunk()]);

    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );

    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);
    let stdev = result.column(0).unwrap().get_float64(0).unwrap();
    // The squared deviations from the mean (5) sum to 32 over 8 values, so the
    // sample standard deviation is sqrt(32 / 7) ~= 2.1381. Derive it exactly
    // rather than comparing against a rounded literal.
    let expected = (32.0_f64 / 7.0).sqrt();
    assert!(
        (stdev - expected).abs() < 1e-9,
        "expected sample stdev {expected}, got {stdev}"
    );
}
1364
/// Population standard deviation (n denominator) of the statistical fixture is 2.
#[test]
fn test_stdev_population() {
    let input = Box::new(MockOperator::new(vec![create_statistical_test_chunk()]));
    let mut op = SimpleAggregateOperator::new(
        input,
        vec![AggregateExpr::stdev_pop(0)],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let sigma = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((sigma - 2.0).abs() < 0.01);
}
1382
/// Discrete median (`percentile_disc` at 0.5) returns an actual input value.
#[test]
fn test_percentile_disc() {
    let input = Box::new(MockOperator::new(vec![create_statistical_test_chunk()]));
    let median = AggregateExpr::percentile_disc(0, 0.5);
    let mut op = SimpleAggregateOperator::new(input, vec![median], vec![LogicalType::Float64]);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    // Sorted input is [2, 4, 4, 4, 5, 5, 7, 9]; the expected discrete median is 4.
    let picked = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((picked - 4.0).abs() < 0.01);
}
1400
/// Continuous median (`percentile_cont` at 0.5) interpolates between neighbours.
#[test]
fn test_percentile_cont() {
    let input = Box::new(MockOperator::new(vec![create_statistical_test_chunk()]));
    let median = AggregateExpr::percentile_cont(0, 0.5);
    let mut op = SimpleAggregateOperator::new(input, vec![median], vec![LogicalType::Float64]);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    // Halfway between the 4th and 5th sorted values (4 and 5).
    let interpolated = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((interpolated - 4.5).abs() < 0.01);
}
1419
/// `percentile_disc` at 0.0 and 1.0 yields the minimum and maximum of the input.
#[test]
fn test_percentile_extremes() {
    let input = Box::new(MockOperator::new(vec![create_statistical_test_chunk()]));
    let mut op = SimpleAggregateOperator::new(
        input,
        vec![
            AggregateExpr::percentile_disc(0, 0.0),
            AggregateExpr::percentile_disc(0, 1.0),
        ],
        vec![LogicalType::Float64, LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    // Column 0 holds the 0th percentile (minimum 2), column 1 the 100th (maximum 9).
    for (col, expected) in [(0, 2.0), (1, 9.0)] {
        let got = out.column(col).unwrap().get_float64(0).unwrap();
        assert!((got - expected).abs() < 0.01);
    }
}
1443
/// With a single input value the sample (n - 1) standard deviation is
/// undefined, so the aggregate yields NULL.
#[test]
fn test_stdev_single_value() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();

    let mut op = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![builder.finish()])),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let cell = out.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1468
/// FIRST and LAST over column 1 of the basic fixture return its boundary values.
#[test]
fn test_first_and_last() {
    let input = Box::new(MockOperator::new(vec![create_test_chunk()]));
    let exprs = vec![AggregateExpr::first(1), AggregateExpr::last(1)];
    let mut op =
        SimpleAggregateOperator::new(input, exprs, vec![LogicalType::Int64, LogicalType::Int64]);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let first = out.column(0).unwrap().get_int64(0);
    assert_eq!(first, Some(10));
    let last = out.column(1).unwrap().get_int64(0);
    assert_eq!(last, Some(50));
}
1485
/// COLLECT gathers every value of column 1 (five rows) into a single list.
#[test]
fn test_collect() {
    let mut op = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_test_chunk()])),
        vec![AggregateExpr::collect(1)],
        vec![LogicalType::Any],
    );

    let out = op.next().unwrap().unwrap();
    let collected = out.column(0).unwrap().get_value(0).unwrap();
    let Value::List(items) = collected else {
        panic!("Expected List value");
    };
    assert_eq!(items.len(), 5);
}
1504
/// COLLECT(DISTINCT) keeps one copy of each value from the duplicate fixture.
#[test]
fn test_collect_distinct() {
    let mut op = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_test_chunk_with_duplicates()])),
        vec![AggregateExpr::collect(1).with_distinct()],
        vec![LogicalType::Any],
    );

    let out = op.next().unwrap().unwrap();
    let collected = out.column(0).unwrap().get_value(0).unwrap();
    let Value::List(items) = collected else {
        panic!("Expected List value");
    };
    assert_eq!(items.len(), 3);
}
1524
/// GROUP_CONCAT with `separator: None` joins the strings with a single space.
#[test]
fn test_group_concat() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
    ["hello", "world", "foo"].into_iter().for_each(|word| {
        builder.column_mut(0).unwrap().push_string(word);
        builder.advance_row();
    });
    let input = Box::new(MockOperator::new(vec![builder.finish()]));

    // Spelled out field by field; leaving `separator` as `None` exercises the default.
    let concat = AggregateExpr {
        function: AggregateFunction::GroupConcat,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut op = SimpleAggregateOperator::new(input, vec![concat], vec![LogicalType::String]);

    let out = op.next().unwrap().unwrap();
    let joined = out.column(0).unwrap().get_value(0).unwrap();
    assert_eq!(joined, Value::String("hello world foo".into()));
}
1552
/// SAMPLE over column 1 of the basic fixture; this test expects the first value (10).
#[test]
fn test_sample() {
    let sample_expr = AggregateExpr {
        function: AggregateFunction::Sample,
        column: Some(1),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut op = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_test_chunk()])),
        vec![sample_expr],
        vec![LogicalType::Int64],
    );

    let out = op.next().unwrap().unwrap();
    let picked = out.column(0).unwrap().get_int64(0);
    assert_eq!(picked, Some(10));
}
1574
/// Sample variance (n - 1 denominator) of the statistical fixture.
#[test]
fn test_variance_sample() {
    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut op = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_statistical_test_chunk()])),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    let var = out.column(0).unwrap().get_float64(0).unwrap();
    // Squared deviations sum to 32 over 8 values; divide by n - 1 = 7.
    let expected = 32.0 / 7.0;
    assert!((var - expected).abs() < 0.01);
}
1600
/// Population variance (n denominator) of the statistical fixture is 32 / 8 = 4.
#[test]
fn test_variance_population() {
    let variance_expr = AggregateExpr {
        function: AggregateFunction::VariancePop,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut op = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_statistical_test_chunk()])),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    let var = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((var - 4.0).abs() < 0.01);
}
1626
/// Sample variance of a single value is undefined (n - 1 = 0), so the result is NULL.
#[test]
fn test_variance_single_value() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();
    let single_row = builder.finish();

    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut op = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![single_row])),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    let cell = out.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1658
/// Ungrouped aggregation over an input with no chunks still produces one row:
/// COUNT(*) and SUM are 0, while AVG, MIN and MAX are NULL.
#[test]
fn test_empty_aggregation() {
    // The child operator yields no chunks at all.
    let mock = MockOperator::new(vec![]);

    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![
            AggregateExpr::count_star(),
            AggregateExpr::sum(0),
            AggregateExpr::avg(0),
            AggregateExpr::min(0),
            AggregateExpr::max(0),
        ],
        vec![
            LogicalType::Int64,
            LogicalType::Int64,
            LogicalType::Float64,
            LogicalType::Int64,
            LogicalType::Int64,
        ],
    );

    let result = agg.next().unwrap().unwrap();
    assert_eq!(result.row_count(), 1);

    // COUNT(*) and SUM over zero rows are 0 ...
    assert_eq!(result.column(0).unwrap().get_int64(0), Some(0));
    assert_eq!(result.column(1).unwrap().get_int64(0), Some(0));

    // ... whereas AVG (col 2), MIN (col 3) and MAX (col 4) have no value and are NULL.
    for col in 2..5 {
        let value = result.column(col).unwrap().get_value(0);
        assert!(
            matches!(value, Some(Value::Null)),
            "column {col}: expected Some(Null), got {value:?}"
        );
    }
}
1698
/// Population standard deviation of a single value is zero: the value equals the mean.
#[test]
fn test_stdev_pop_single_value() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();
    let single_row = builder.finish();

    let mut op = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![single_row])),
        vec![AggregateExpr::stdev_pop(0)],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let sigma = out.column(0).unwrap().get_float64(0).unwrap();
    assert!(sigma.abs() < 0.01);
}
1721}