1use indexmap::IndexMap;
11use std::collections::HashSet;
12
13use grafeo_common::types::{LogicalType, Value};
14
15use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
16use super::{Operator, OperatorError, OperatorResult};
17use crate::execution::DataChunk;
18use crate::execution::chunk::DataChunkBuilder;
19
20#[derive(Debug, Clone)]
22pub(crate) enum AggregateState {
23 Count(i64),
25 CountDistinct(i64, HashSet<HashableValue>),
27 SumInt(i64, i64),
29 SumIntDistinct(i64, i64, HashSet<HashableValue>),
31 SumFloat(f64, i64),
33 SumFloatDistinct(f64, i64, HashSet<HashableValue>),
35 Avg(f64, i64),
37 AvgDistinct(f64, i64, HashSet<HashableValue>),
39 Min(Option<Value>),
41 Max(Option<Value>),
43 First(Option<Value>),
45 Last(Option<Value>),
47 Collect(Vec<Value>),
49 CollectDistinct(Vec<Value>, HashSet<HashableValue>),
51 StdDev { count: i64, mean: f64, m2: f64 },
53 StdDevPop { count: i64, mean: f64, m2: f64 },
55 PercentileDisc { values: Vec<f64>, percentile: f64 },
57 PercentileCont { values: Vec<f64>, percentile: f64 },
59 GroupConcat(Vec<String>, String),
61 GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
63 Sample(Option<Value>),
65 Variance { count: i64, mean: f64, m2: f64 },
67 VariancePop { count: i64, mean: f64, m2: f64 },
69 Bivariate {
71 kind: AggregateFunction,
73 count: i64,
74 mean_x: f64,
75 mean_y: f64,
76 m2_x: f64,
77 m2_y: f64,
78 c_xy: f64,
79 },
80}
81
82impl AggregateState {
83 pub(crate) fn new(
85 function: AggregateFunction,
86 distinct: bool,
87 percentile: Option<f64>,
88 separator: Option<&str>,
89 ) -> Self {
90 match (function, distinct) {
91 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
92 AggregateState::Count(0)
93 }
94 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
95 AggregateState::CountDistinct(0, HashSet::new())
96 }
97 (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
98 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
99 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
100 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
101 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
103 (AggregateFunction::First, _) => AggregateState::First(None),
104 (AggregateFunction::Last, _) => AggregateState::Last(None),
105 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
106 (AggregateFunction::Collect, true) => {
107 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
108 }
109 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
111 count: 0,
112 mean: 0.0,
113 m2: 0.0,
114 },
115 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
116 count: 0,
117 mean: 0.0,
118 m2: 0.0,
119 },
120 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
121 values: Vec::new(),
122 percentile: percentile.unwrap_or(0.5),
123 },
124 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
125 values: Vec::new(),
126 percentile: percentile.unwrap_or(0.5),
127 },
128 (AggregateFunction::GroupConcat, false) => {
129 AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
130 }
131 (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
132 Vec::new(),
133 separator.unwrap_or(" ").to_string(),
134 HashSet::new(),
135 ),
136 (AggregateFunction::Sample, _) => AggregateState::Sample(None),
137 (
139 AggregateFunction::CovarSamp
140 | AggregateFunction::CovarPop
141 | AggregateFunction::Corr
142 | AggregateFunction::RegrSlope
143 | AggregateFunction::RegrIntercept
144 | AggregateFunction::RegrR2
145 | AggregateFunction::RegrCount
146 | AggregateFunction::RegrSxx
147 | AggregateFunction::RegrSyy
148 | AggregateFunction::RegrSxy
149 | AggregateFunction::RegrAvgx
150 | AggregateFunction::RegrAvgy,
151 _,
152 ) => AggregateState::Bivariate {
153 kind: function,
154 count: 0,
155 mean_x: 0.0,
156 mean_y: 0.0,
157 m2_x: 0.0,
158 m2_y: 0.0,
159 c_xy: 0.0,
160 },
161 (AggregateFunction::Variance, _) => AggregateState::Variance {
162 count: 0,
163 mean: 0.0,
164 m2: 0.0,
165 },
166 (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
167 count: 0,
168 mean: 0.0,
169 m2: 0.0,
170 },
171 }
172 }
173
174 pub(crate) fn update(&mut self, value: Option<Value>) {
176 match self {
177 AggregateState::Count(count) => {
178 *count += 1;
179 }
180 AggregateState::CountDistinct(count, seen) => {
181 if let Some(ref v) = value {
182 let hashable = HashableValue::from(v);
183 if seen.insert(hashable) {
184 *count += 1;
185 }
186 }
187 }
188 AggregateState::SumInt(sum, count) => {
189 if let Some(Value::Int64(v)) = value {
190 *sum += v;
191 *count += 1;
192 } else if let Some(Value::Float64(v)) = value {
193 *self = AggregateState::SumFloat(*sum as f64 + v, *count + 1);
195 } else if let Some(ref v) = value {
196 if let Some(num) = value_to_f64(v) {
198 *self = AggregateState::SumFloat(*sum as f64 + num, *count + 1);
199 }
200 }
201 }
202 AggregateState::SumIntDistinct(sum, count, seen) => {
203 if let Some(ref v) = value {
204 let hashable = HashableValue::from(v);
205 if seen.insert(hashable) {
206 if let Value::Int64(i) = v {
207 *sum += i;
208 *count += 1;
209 } else if let Value::Float64(f) = v {
210 let moved_seen = std::mem::take(seen);
212 *self = AggregateState::SumFloatDistinct(
213 *sum as f64 + f,
214 *count + 1,
215 moved_seen,
216 );
217 } else if let Some(num) = value_to_f64(v) {
218 let moved_seen = std::mem::take(seen);
220 *self = AggregateState::SumFloatDistinct(
221 *sum as f64 + num,
222 *count + 1,
223 moved_seen,
224 );
225 }
226 }
227 }
228 }
229 AggregateState::SumFloat(sum, count) => {
230 if let Some(ref v) = value {
231 if let Some(num) = value_to_f64(v) {
233 *sum += num;
234 *count += 1;
235 }
236 }
237 }
238 AggregateState::SumFloatDistinct(sum, count, seen) => {
239 if let Some(ref v) = value {
240 let hashable = HashableValue::from(v);
241 if seen.insert(hashable)
242 && let Some(num) = value_to_f64(v)
243 {
244 *sum += num;
245 *count += 1;
246 }
247 }
248 }
249 AggregateState::Avg(sum, count) => {
250 if let Some(ref v) = value
251 && let Some(num) = value_to_f64(v)
252 {
253 *sum += num;
254 *count += 1;
255 }
256 }
257 AggregateState::AvgDistinct(sum, count, seen) => {
258 if let Some(ref v) = value {
259 let hashable = HashableValue::from(v);
260 if seen.insert(hashable)
261 && let Some(num) = value_to_f64(v)
262 {
263 *sum += num;
264 *count += 1;
265 }
266 }
267 }
268 AggregateState::Min(min) => {
269 if let Some(v) = value {
270 match min {
271 None => *min = Some(v),
272 Some(current) => {
273 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
274 *min = Some(v);
275 }
276 }
277 }
278 }
279 }
280 AggregateState::Max(max) => {
281 if let Some(v) = value {
282 match max {
283 None => *max = Some(v),
284 Some(current) => {
285 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
286 *max = Some(v);
287 }
288 }
289 }
290 }
291 }
292 AggregateState::First(first) => {
293 if first.is_none() {
294 *first = value;
295 }
296 }
297 AggregateState::Last(last) => {
298 if value.is_some() {
299 *last = value;
300 }
301 }
302 AggregateState::Collect(list) => {
303 if let Some(v) = value {
304 list.push(v);
305 }
306 }
307 AggregateState::CollectDistinct(list, seen) => {
308 if let Some(v) = value {
309 let hashable = HashableValue::from(&v);
310 if seen.insert(hashable) {
311 list.push(v);
312 }
313 }
314 }
315 AggregateState::StdDev { count, mean, m2 }
317 | AggregateState::StdDevPop { count, mean, m2 }
318 | AggregateState::Variance { count, mean, m2 }
319 | AggregateState::VariancePop { count, mean, m2 } => {
320 if let Some(ref v) = value
321 && let Some(x) = value_to_f64(v)
322 {
323 *count += 1;
324 let delta = x - *mean;
325 *mean += delta / *count as f64;
326 let delta2 = x - *mean;
327 *m2 += delta * delta2;
328 }
329 }
330 AggregateState::PercentileDisc { values, .. }
331 | AggregateState::PercentileCont { values, .. } => {
332 if let Some(ref v) = value
333 && let Some(x) = value_to_f64(v)
334 {
335 values.push(x);
336 }
337 }
338 AggregateState::GroupConcat(list, _sep) => {
339 if let Some(v) = value {
340 list.push(agg_value_to_string(&v));
341 }
342 }
343 AggregateState::GroupConcatDistinct(list, _sep, seen) => {
344 if let Some(v) = value {
345 let hashable = HashableValue::from(&v);
346 if seen.insert(hashable) {
347 list.push(agg_value_to_string(&v));
348 }
349 }
350 }
351 AggregateState::Sample(sample) => {
352 if sample.is_none() {
353 *sample = value;
354 }
355 }
356 AggregateState::Bivariate { .. } => {
357 }
360 }
361 }
362
363 fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
368 if let AggregateState::Bivariate {
369 count,
370 mean_x,
371 mean_y,
372 m2_x,
373 m2_y,
374 c_xy,
375 ..
376 } = self
377 {
378 if let (Some(y), Some(x)) = (&y_val, &x_val)
380 && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
381 {
382 *count += 1;
383 let n = *count as f64;
384 let dx = x_f - *mean_x;
385 let dy = y_f - *mean_y;
386 *mean_x += dx / n;
387 *mean_y += dy / n;
388 let dx2 = x_f - *mean_x; let dy2 = y_f - *mean_y; *m2_x += dx * dx2;
391 *m2_y += dy * dy2;
392 *c_xy += dx * dy2;
393 }
394 }
395 }
396
397 pub(crate) fn finalize(&self) -> Value {
399 match self {
400 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
401 Value::Int64(*count)
402 }
403 AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
404 if *count == 0 {
405 Value::Null
406 } else {
407 Value::Int64(*sum)
408 }
409 }
410 AggregateState::SumFloat(sum, count)
411 | AggregateState::SumFloatDistinct(sum, count, _) => {
412 if *count == 0 {
413 Value::Null
414 } else {
415 Value::Float64(*sum)
416 }
417 }
418 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
419 if *count == 0 {
420 Value::Null
421 } else {
422 Value::Float64(*sum / *count as f64)
423 }
424 }
425 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
426 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
427 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
428 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
429 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
430 Value::List(list.clone().into())
431 }
432 AggregateState::StdDev { count, m2, .. } => {
434 if *count < 2 {
435 Value::Null
436 } else {
437 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
438 }
439 }
440 AggregateState::StdDevPop { count, m2, .. } => {
442 if *count == 0 {
443 Value::Null
444 } else {
445 Value::Float64((*m2 / *count as f64).sqrt())
446 }
447 }
448 AggregateState::Variance { count, m2, .. } => {
450 if *count < 2 {
451 Value::Null
452 } else {
453 Value::Float64(*m2 / (*count - 1) as f64)
454 }
455 }
456 AggregateState::VariancePop { count, m2, .. } => {
458 if *count == 0 {
459 Value::Null
460 } else {
461 Value::Float64(*m2 / *count as f64)
462 }
463 }
464 AggregateState::PercentileDisc { values, percentile } => {
466 if values.is_empty() {
467 Value::Null
468 } else {
469 let mut sorted = values.clone();
470 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
471 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
473 Value::Float64(sorted[index])
474 }
475 }
476 AggregateState::PercentileCont { values, percentile } => {
478 if values.is_empty() {
479 Value::Null
480 } else {
481 let mut sorted = values.clone();
482 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
483 let rank = percentile * (sorted.len() - 1) as f64;
485 let lower_idx = rank.floor() as usize;
486 let upper_idx = rank.ceil() as usize;
487 if lower_idx == upper_idx {
488 Value::Float64(sorted[lower_idx])
489 } else {
490 let fraction = rank - lower_idx as f64;
491 let result =
492 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
493 Value::Float64(result)
494 }
495 }
496 }
497 AggregateState::GroupConcat(list, sep)
499 | AggregateState::GroupConcatDistinct(list, sep, _) => {
500 Value::String(list.join(sep).into())
501 }
502 AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
504 AggregateState::Bivariate {
506 kind,
507 count,
508 mean_x,
509 mean_y,
510 m2_x,
511 m2_y,
512 c_xy,
513 } => {
514 let n = *count;
515 match kind {
516 AggregateFunction::CovarSamp => {
517 if n < 2 {
518 Value::Null
519 } else {
520 Value::Float64(*c_xy / (n - 1) as f64)
521 }
522 }
523 AggregateFunction::CovarPop => {
524 if n == 0 {
525 Value::Null
526 } else {
527 Value::Float64(*c_xy / n as f64)
528 }
529 }
530 AggregateFunction::Corr => {
531 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
532 Value::Null
533 } else {
534 Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
535 }
536 }
537 AggregateFunction::RegrSlope => {
538 if n == 0 || *m2_x == 0.0 {
539 Value::Null
540 } else {
541 Value::Float64(*c_xy / *m2_x)
542 }
543 }
544 AggregateFunction::RegrIntercept => {
545 if n == 0 || *m2_x == 0.0 {
546 Value::Null
547 } else {
548 let slope = *c_xy / *m2_x;
549 Value::Float64(*mean_y - slope * *mean_x)
550 }
551 }
552 AggregateFunction::RegrR2 => {
553 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
554 Value::Null
555 } else {
556 Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
557 }
558 }
559 AggregateFunction::RegrCount => Value::Int64(n),
560 AggregateFunction::RegrSxx => {
561 if n == 0 {
562 Value::Null
563 } else {
564 Value::Float64(*m2_x)
565 }
566 }
567 AggregateFunction::RegrSyy => {
568 if n == 0 {
569 Value::Null
570 } else {
571 Value::Float64(*m2_y)
572 }
573 }
574 AggregateFunction::RegrSxy => {
575 if n == 0 {
576 Value::Null
577 } else {
578 Value::Float64(*c_xy)
579 }
580 }
581 AggregateFunction::RegrAvgx => {
582 if n == 0 {
583 Value::Null
584 } else {
585 Value::Float64(*mean_x)
586 }
587 }
588 AggregateFunction::RegrAvgy => {
589 if n == 0 {
590 Value::Null
591 } else {
592 Value::Float64(*mean_y)
593 }
594 }
595 _ => Value::Null, }
597 }
598 }
599 }
600}
601
602use super::value_utils::{compare_values, value_to_f64};
603
604fn agg_value_to_string(val: &Value) -> String {
606 match val {
607 Value::String(s) => s.to_string(),
608 Value::Int64(i) => i.to_string(),
609 Value::Float64(f) => f.to_string(),
610 Value::Bool(b) => b.to_string(),
611 Value::Null => String::new(),
612 other => format!("{other:?}"),
613 }
614}
615
616#[derive(Debug, Clone, PartialEq, Eq, Hash)]
618pub struct GroupKey(Vec<GroupKeyPart>);
619
620#[derive(Debug, Clone, PartialEq, Eq, Hash)]
621enum GroupKeyPart {
622 Null,
623 Bool(bool),
624 Int64(i64),
625 String(String),
626}
627
628impl GroupKey {
629 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
631 let parts: Vec<GroupKeyPart> = group_columns
632 .iter()
633 .map(|&col_idx| {
634 chunk
635 .column(col_idx)
636 .and_then(|col| col.get_value(row))
637 .map_or(GroupKeyPart::Null, |v| match v {
638 Value::Null => GroupKeyPart::Null,
639 Value::Bool(b) => GroupKeyPart::Bool(b),
640 Value::Int64(i) => GroupKeyPart::Int64(i),
641 Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
642 Value::String(s) => GroupKeyPart::String(s.to_string()),
643 _ => GroupKeyPart::String(format!("{v:?}")),
644 })
645 })
646 .collect();
647 GroupKey(parts)
648 }
649
650 fn to_values(&self) -> Vec<Value> {
652 self.0
653 .iter()
654 .map(|part| match part {
655 GroupKeyPart::Null => Value::Null,
656 GroupKeyPart::Bool(b) => Value::Bool(*b),
657 GroupKeyPart::Int64(i) => Value::Int64(*i),
658 GroupKeyPart::String(s) => Value::String(s.clone().into()),
659 })
660 .collect()
661 }
662}
663
664pub struct HashAggregateOperator {
668 child: Box<dyn Operator>,
670 group_columns: Vec<usize>,
672 aggregates: Vec<AggregateExpr>,
674 output_schema: Vec<LogicalType>,
676 groups: IndexMap<GroupKey, Vec<AggregateState>>,
678 aggregation_complete: bool,
680 results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
682}
683
684impl HashAggregateOperator {
685 pub fn new(
693 child: Box<dyn Operator>,
694 group_columns: Vec<usize>,
695 aggregates: Vec<AggregateExpr>,
696 output_schema: Vec<LogicalType>,
697 ) -> Self {
698 Self {
699 child,
700 group_columns,
701 aggregates,
702 output_schema,
703 groups: IndexMap::new(),
704 aggregation_complete: false,
705 results: None,
706 }
707 }
708
709 fn aggregate(&mut self) -> Result<(), OperatorError> {
711 while let Some(chunk) = self.child.next()? {
712 for row in chunk.selected_indices() {
713 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
714
715 let states = self.groups.entry(key).or_insert_with(|| {
717 self.aggregates
718 .iter()
719 .map(|agg| {
720 AggregateState::new(
721 agg.function,
722 agg.distinct,
723 agg.percentile,
724 agg.separator.as_deref(),
725 )
726 })
727 .collect()
728 });
729
730 for (i, agg) in self.aggregates.iter().enumerate() {
732 if agg.column2.is_some() {
734 let y_val = agg
735 .column
736 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
737 let x_val = agg
738 .column2
739 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
740 states[i].update_bivariate(y_val, x_val);
741 continue;
742 }
743
744 let value = match (agg.function, agg.distinct) {
745 (AggregateFunction::Count, false) => None,
747 (AggregateFunction::Count, true) => agg
749 .column
750 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
751 _ => agg
752 .column
753 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
754 };
755
756 match (agg.function, agg.distinct) {
758 (AggregateFunction::Count, false) => states[i].update(None),
759 (AggregateFunction::Count, true) => {
760 if value.is_some() && !matches!(value, Some(Value::Null)) {
762 states[i].update(value);
763 }
764 }
765 (AggregateFunction::CountNonNull, _) => {
766 if value.is_some() && !matches!(value, Some(Value::Null)) {
767 states[i].update(value);
768 }
769 }
770 _ => {
771 if value.is_some() && !matches!(value, Some(Value::Null)) {
772 states[i].update(value);
773 }
774 }
775 }
776 }
777 }
778 }
779
780 self.aggregation_complete = true;
781
782 let results: Vec<_> = self.groups.drain(..).collect();
784 self.results = Some(results.into_iter());
785
786 Ok(())
787 }
788}
789
790impl Operator for HashAggregateOperator {
791 fn next(&mut self) -> OperatorResult {
792 if !self.aggregation_complete {
794 self.aggregate()?;
795 }
796
797 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
799 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
801
802 for agg in &self.aggregates {
803 let state = AggregateState::new(
804 agg.function,
805 agg.distinct,
806 agg.percentile,
807 agg.separator.as_deref(),
808 );
809 let value = state.finalize();
810 if let Some(col) = builder.column_mut(self.group_columns.len()) {
811 col.push_value(value);
812 }
813 }
814 builder.advance_row();
815
816 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
818 }
819
820 let Some(results) = &mut self.results else {
821 return Ok(None);
822 };
823
824 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
825
826 for (key, states) in results.by_ref() {
827 let key_values = key.to_values();
829 for (i, value) in key_values.into_iter().enumerate() {
830 if let Some(col) = builder.column_mut(i) {
831 col.push_value(value);
832 }
833 }
834
835 for (i, state) in states.iter().enumerate() {
837 let col_idx = self.group_columns.len() + i;
838 if let Some(col) = builder.column_mut(col_idx) {
839 col.push_value(state.finalize());
840 }
841 }
842
843 builder.advance_row();
844
845 if builder.is_full() {
846 return Ok(Some(builder.finish()));
847 }
848 }
849
850 if builder.row_count() > 0 {
851 Ok(Some(builder.finish()))
852 } else {
853 Ok(None)
854 }
855 }
856
857 fn reset(&mut self) {
858 self.child.reset();
859 self.groups.clear();
860 self.aggregation_complete = false;
861 self.results = None;
862 }
863
864 fn name(&self) -> &'static str {
865 "HashAggregate"
866 }
867}
868
869pub struct SimpleAggregateOperator {
873 child: Box<dyn Operator>,
875 aggregates: Vec<AggregateExpr>,
877 output_schema: Vec<LogicalType>,
879 states: Vec<AggregateState>,
881 done: bool,
883}
884
885impl SimpleAggregateOperator {
886 pub fn new(
888 child: Box<dyn Operator>,
889 aggregates: Vec<AggregateExpr>,
890 output_schema: Vec<LogicalType>,
891 ) -> Self {
892 let states = aggregates
893 .iter()
894 .map(|agg| {
895 AggregateState::new(
896 agg.function,
897 agg.distinct,
898 agg.percentile,
899 agg.separator.as_deref(),
900 )
901 })
902 .collect();
903
904 Self {
905 child,
906 aggregates,
907 output_schema,
908 states,
909 done: false,
910 }
911 }
912}
913
914impl Operator for SimpleAggregateOperator {
915 fn next(&mut self) -> OperatorResult {
916 if self.done {
917 return Ok(None);
918 }
919
920 while let Some(chunk) = self.child.next()? {
922 for row in chunk.selected_indices() {
923 for (i, agg) in self.aggregates.iter().enumerate() {
924 if agg.column2.is_some() {
926 let y_val = agg
927 .column
928 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
929 let x_val = agg
930 .column2
931 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
932 self.states[i].update_bivariate(y_val, x_val);
933 continue;
934 }
935
936 let value = match (agg.function, agg.distinct) {
937 (AggregateFunction::Count, false) => None,
939 (AggregateFunction::Count, true) => agg
941 .column
942 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
943 _ => agg
944 .column
945 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
946 };
947
948 match (agg.function, agg.distinct) {
949 (AggregateFunction::Count, false) => self.states[i].update(None),
950 (AggregateFunction::Count, true) => {
951 if value.is_some() && !matches!(value, Some(Value::Null)) {
953 self.states[i].update(value);
954 }
955 }
956 (AggregateFunction::CountNonNull, _) => {
957 if value.is_some() && !matches!(value, Some(Value::Null)) {
958 self.states[i].update(value);
959 }
960 }
961 _ => {
962 if value.is_some() && !matches!(value, Some(Value::Null)) {
963 self.states[i].update(value);
964 }
965 }
966 }
967 }
968 }
969 }
970
971 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
973
974 for (i, state) in self.states.iter().enumerate() {
975 if let Some(col) = builder.column_mut(i) {
976 col.push_value(state.finalize());
977 }
978 }
979 builder.advance_row();
980
981 self.done = true;
982 Ok(Some(builder.finish()))
983 }
984
985 fn reset(&mut self) {
986 self.child.reset();
987 self.states = self
988 .aggregates
989 .iter()
990 .map(|agg| {
991 AggregateState::new(
992 agg.function,
993 agg.distinct,
994 agg.percentile,
995 agg.separator.as_deref(),
996 )
997 })
998 .collect();
999 self.done = false;
1000 }
1001
1002 fn name(&self) -> &'static str {
1003 "SimpleAggregate"
1004 }
1005}
1006
1007#[cfg(test)]
1008mod tests {
1009 use super::*;
1010 use crate::execution::chunk::DataChunkBuilder;
1011
1012 struct MockOperator {
1013 chunks: Vec<DataChunk>,
1014 position: usize,
1015 }
1016
1017 impl MockOperator {
1018 fn new(chunks: Vec<DataChunk>) -> Self {
1019 Self {
1020 chunks,
1021 position: 0,
1022 }
1023 }
1024 }
1025
1026 impl Operator for MockOperator {
1027 fn next(&mut self) -> OperatorResult {
1028 if self.position < self.chunks.len() {
1029 let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1030 self.position += 1;
1031 Ok(Some(chunk))
1032 } else {
1033 Ok(None)
1034 }
1035 }
1036
1037 fn reset(&mut self) {
1038 self.position = 0;
1039 }
1040
1041 fn name(&self) -> &'static str {
1042 "Mock"
1043 }
1044 }
1045
1046 fn create_test_chunk() -> DataChunk {
1047 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1049
1050 let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1051 for (group, value) in data {
1052 builder.column_mut(0).unwrap().push_int64(group);
1053 builder.column_mut(1).unwrap().push_int64(value);
1054 builder.advance_row();
1055 }
1056
1057 builder.finish()
1058 }
1059
1060 #[test]
1061 fn test_simple_count() {
1062 let mock = MockOperator::new(vec![create_test_chunk()]);
1063
1064 let mut agg = SimpleAggregateOperator::new(
1065 Box::new(mock),
1066 vec![AggregateExpr::count_star()],
1067 vec![LogicalType::Int64],
1068 );
1069
1070 let result = agg.next().unwrap().unwrap();
1071 assert_eq!(result.row_count(), 1);
1072 assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1073
1074 assert!(agg.next().unwrap().is_none());
1076 }
1077
1078 #[test]
1079 fn test_simple_sum() {
1080 let mock = MockOperator::new(vec![create_test_chunk()]);
1081
1082 let mut agg = SimpleAggregateOperator::new(
1083 Box::new(mock),
1084 vec![AggregateExpr::sum(1)], vec![LogicalType::Int64],
1086 );
1087
1088 let result = agg.next().unwrap().unwrap();
1089 assert_eq!(result.row_count(), 1);
1090 assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1092 }
1093
1094 #[test]
1095 fn test_simple_avg() {
1096 let mock = MockOperator::new(vec![create_test_chunk()]);
1097
1098 let mut agg = SimpleAggregateOperator::new(
1099 Box::new(mock),
1100 vec![AggregateExpr::avg(1)],
1101 vec![LogicalType::Float64],
1102 );
1103
1104 let result = agg.next().unwrap().unwrap();
1105 assert_eq!(result.row_count(), 1);
1106 let avg = result.column(0).unwrap().get_float64(0).unwrap();
1108 assert!((avg - 30.0).abs() < 0.001);
1109 }
1110
1111 #[test]
1112 fn test_simple_min_max() {
1113 let mock = MockOperator::new(vec![create_test_chunk()]);
1114
1115 let mut agg = SimpleAggregateOperator::new(
1116 Box::new(mock),
1117 vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1118 vec![LogicalType::Int64, LogicalType::Int64],
1119 );
1120
1121 let result = agg.next().unwrap().unwrap();
1122 assert_eq!(result.row_count(), 1);
1123 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); }
1126
1127 #[test]
1128 fn test_sum_with_string_values() {
1129 let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1131 builder.column_mut(0).unwrap().push_string("30");
1132 builder.advance_row();
1133 builder.column_mut(0).unwrap().push_string("25");
1134 builder.advance_row();
1135 builder.column_mut(0).unwrap().push_string("35");
1136 builder.advance_row();
1137 let chunk = builder.finish();
1138
1139 let mock = MockOperator::new(vec![chunk]);
1140 let mut agg = SimpleAggregateOperator::new(
1141 Box::new(mock),
1142 vec![AggregateExpr::sum(0)],
1143 vec![LogicalType::Float64],
1144 );
1145
1146 let result = agg.next().unwrap().unwrap();
1147 assert_eq!(result.row_count(), 1);
1148 let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1150 assert!(
1151 (sum_val - 90.0).abs() < 0.001,
1152 "Expected 90.0, got {}",
1153 sum_val
1154 );
1155 }
1156
1157 #[test]
1158 fn test_grouped_aggregation() {
1159 let mock = MockOperator::new(vec![create_test_chunk()]);
1160
1161 let mut agg = HashAggregateOperator::new(
1163 Box::new(mock),
1164 vec![0], vec![AggregateExpr::sum(1)], vec![LogicalType::Int64, LogicalType::Int64],
1167 );
1168
1169 let mut results: Vec<(i64, i64)> = Vec::new();
1170 while let Some(chunk) = agg.next().unwrap() {
1171 for row in chunk.selected_indices() {
1172 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1173 let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1174 results.push((group, sum));
1175 }
1176 }
1177
1178 results.sort_by_key(|(g, _)| *g);
1179 assert_eq!(results.len(), 2);
1180 assert_eq!(results[0], (1, 30)); assert_eq!(results[1], (2, 120)); }
1183
1184 #[test]
1185 fn test_grouped_count() {
1186 let mock = MockOperator::new(vec![create_test_chunk()]);
1187
1188 let mut agg = HashAggregateOperator::new(
1190 Box::new(mock),
1191 vec![0],
1192 vec![AggregateExpr::count_star()],
1193 vec![LogicalType::Int64, LogicalType::Int64],
1194 );
1195
1196 let mut results: Vec<(i64, i64)> = Vec::new();
1197 while let Some(chunk) = agg.next().unwrap() {
1198 for row in chunk.selected_indices() {
1199 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1200 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1201 results.push((group, count));
1202 }
1203 }
1204
1205 results.sort_by_key(|(g, _)| *g);
1206 assert_eq!(results.len(), 2);
1207 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 3)); }
1210
1211 #[test]
1212 fn test_multiple_aggregates() {
1213 let mock = MockOperator::new(vec![create_test_chunk()]);
1214
1215 let mut agg = HashAggregateOperator::new(
1217 Box::new(mock),
1218 vec![0],
1219 vec![
1220 AggregateExpr::count_star(),
1221 AggregateExpr::sum(1),
1222 AggregateExpr::avg(1),
1223 ],
1224 vec![
1225 LogicalType::Int64, LogicalType::Int64, LogicalType::Int64, LogicalType::Float64, ],
1230 );
1231
1232 let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1233 while let Some(chunk) = agg.next().unwrap() {
1234 for row in chunk.selected_indices() {
1235 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1236 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1237 let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1238 let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1239 results.push((group, count, sum, avg));
1240 }
1241 }
1242
1243 results.sort_by_key(|(g, _, _, _)| *g);
1244 assert_eq!(results.len(), 2);
1245
1246 assert_eq!(results[0].0, 1);
1248 assert_eq!(results[0].1, 2);
1249 assert_eq!(results[0].2, 30);
1250 assert!((results[0].3 - 15.0).abs() < 0.001);
1251
1252 assert_eq!(results[1].0, 2);
1254 assert_eq!(results[1].1, 3);
1255 assert_eq!(results[1].2, 120);
1256 assert!((results[1].3 - 40.0).abs() < 0.001);
1257 }
1258
1259 fn create_test_chunk_with_duplicates() -> DataChunk {
1260 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1265
1266 let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1267 for (group, value) in data {
1268 builder.column_mut(0).unwrap().push_int64(group);
1269 builder.column_mut(1).unwrap().push_int64(value);
1270 builder.advance_row();
1271 }
1272
1273 builder.finish()
1274 }
1275
1276 #[test]
1277 fn test_count_distinct() {
1278 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1279
1280 let mut agg = SimpleAggregateOperator::new(
1282 Box::new(mock),
1283 vec![AggregateExpr::count(1).with_distinct()],
1284 vec![LogicalType::Int64],
1285 );
1286
1287 let result = agg.next().unwrap().unwrap();
1288 assert_eq!(result.row_count(), 1);
1289 assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1291 }
1292
1293 #[test]
1294 fn test_grouped_count_distinct() {
1295 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1296
1297 let mut agg = HashAggregateOperator::new(
1299 Box::new(mock),
1300 vec![0],
1301 vec![AggregateExpr::count(1).with_distinct()],
1302 vec![LogicalType::Int64, LogicalType::Int64],
1303 );
1304
1305 let mut results: Vec<(i64, i64)> = Vec::new();
1306 while let Some(chunk) = agg.next().unwrap() {
1307 for row in chunk.selected_indices() {
1308 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1309 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1310 results.push((group, count));
1311 }
1312 }
1313
1314 results.sort_by_key(|(g, _)| *g);
1315 assert_eq!(results.len(), 2);
1316 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 1)); }
1319
1320 #[test]
1321 fn test_sum_distinct() {
1322 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1323
1324 let mut agg = SimpleAggregateOperator::new(
1326 Box::new(mock),
1327 vec![AggregateExpr::sum(1).with_distinct()],
1328 vec![LogicalType::Int64],
1329 );
1330
1331 let result = agg.next().unwrap().unwrap();
1332 assert_eq!(result.row_count(), 1);
1333 assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1335 }
1336
1337 #[test]
1338 fn test_avg_distinct() {
1339 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1340
1341 let mut agg = SimpleAggregateOperator::new(
1343 Box::new(mock),
1344 vec![AggregateExpr::avg(1).with_distinct()],
1345 vec![LogicalType::Float64],
1346 );
1347
1348 let result = agg.next().unwrap().unwrap();
1349 assert_eq!(result.row_count(), 1);
1350 let avg = result.column(0).unwrap().get_float64(0).unwrap();
1352 assert!((avg - 20.0).abs() < 0.001);
1353 }
1354
1355 fn create_statistical_test_chunk() -> DataChunk {
1356 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1359
1360 for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1361 builder.column_mut(0).unwrap().push_int64(value);
1362 builder.advance_row();
1363 }
1364
1365 builder.finish()
1366 }
1367
1368 #[test]
1369 fn test_stdev_sample() {
1370 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1371
1372 let mut agg = SimpleAggregateOperator::new(
1373 Box::new(mock),
1374 vec![AggregateExpr::stdev(0)],
1375 vec![LogicalType::Float64],
1376 );
1377
1378 let result = agg.next().unwrap().unwrap();
1379 assert_eq!(result.row_count(), 1);
1380 let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1383 assert!((stdev - 2.138).abs() < 0.01);
1384 }
1385
1386 #[test]
1387 fn test_stdev_population() {
1388 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1389
1390 let mut agg = SimpleAggregateOperator::new(
1391 Box::new(mock),
1392 vec![AggregateExpr::stdev_pop(0)],
1393 vec![LogicalType::Float64],
1394 );
1395
1396 let result = agg.next().unwrap().unwrap();
1397 assert_eq!(result.row_count(), 1);
1398 let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1401 assert!((stdev - 2.0).abs() < 0.01);
1402 }
1403
1404 #[test]
1405 fn test_percentile_disc() {
1406 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1407
1408 let mut agg = SimpleAggregateOperator::new(
1410 Box::new(mock),
1411 vec![AggregateExpr::percentile_disc(0, 0.5)],
1412 vec![LogicalType::Float64],
1413 );
1414
1415 let result = agg.next().unwrap().unwrap();
1416 assert_eq!(result.row_count(), 1);
1417 let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1419 assert!((percentile - 4.0).abs() < 0.01);
1420 }
1421
1422 #[test]
1423 fn test_percentile_cont() {
1424 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1425
1426 let mut agg = SimpleAggregateOperator::new(
1428 Box::new(mock),
1429 vec![AggregateExpr::percentile_cont(0, 0.5)],
1430 vec![LogicalType::Float64],
1431 );
1432
1433 let result = agg.next().unwrap().unwrap();
1434 assert_eq!(result.row_count(), 1);
1435 let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1438 assert!((percentile - 4.5).abs() < 0.01);
1439 }
1440
1441 #[test]
1442 fn test_percentile_extremes() {
1443 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1445
1446 let mut agg = SimpleAggregateOperator::new(
1447 Box::new(mock),
1448 vec![
1449 AggregateExpr::percentile_disc(0, 0.0),
1450 AggregateExpr::percentile_disc(0, 1.0),
1451 ],
1452 vec![LogicalType::Float64, LogicalType::Float64],
1453 );
1454
1455 let result = agg.next().unwrap().unwrap();
1456 assert_eq!(result.row_count(), 1);
1457 let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1459 assert!((p0 - 2.0).abs() < 0.01);
1460 let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1462 assert!((p100 - 9.0).abs() < 0.01);
1463 }
1464
1465 #[test]
1466 fn test_stdev_single_value() {
1467 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1469 builder.column_mut(0).unwrap().push_int64(42);
1470 builder.advance_row();
1471 let chunk = builder.finish();
1472
1473 let mock = MockOperator::new(vec![chunk]);
1474
1475 let mut agg = SimpleAggregateOperator::new(
1476 Box::new(mock),
1477 vec![AggregateExpr::stdev(0)],
1478 vec![LogicalType::Float64],
1479 );
1480
1481 let result = agg.next().unwrap().unwrap();
1482 assert_eq!(result.row_count(), 1);
1483 assert!(matches!(
1485 result.column(0).unwrap().get_value(0),
1486 Some(Value::Null)
1487 ));
1488 }
1489
1490 #[test]
1491 fn test_first_and_last() {
1492 let mock = MockOperator::new(vec![create_test_chunk()]);
1493
1494 let mut agg = SimpleAggregateOperator::new(
1495 Box::new(mock),
1496 vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1497 vec![LogicalType::Int64, LogicalType::Int64],
1498 );
1499
1500 let result = agg.next().unwrap().unwrap();
1501 assert_eq!(result.row_count(), 1);
1502 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1504 assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1505 }
1506
1507 #[test]
1508 fn test_collect() {
1509 let mock = MockOperator::new(vec![create_test_chunk()]);
1510
1511 let mut agg = SimpleAggregateOperator::new(
1512 Box::new(mock),
1513 vec![AggregateExpr::collect(1)],
1514 vec![LogicalType::Any],
1515 );
1516
1517 let result = agg.next().unwrap().unwrap();
1518 let val = result.column(0).unwrap().get_value(0).unwrap();
1519 if let Value::List(items) = val {
1520 assert_eq!(items.len(), 5);
1521 } else {
1522 panic!("Expected List value");
1523 }
1524 }
1525
1526 #[test]
1527 fn test_collect_distinct() {
1528 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1529
1530 let mut agg = SimpleAggregateOperator::new(
1531 Box::new(mock),
1532 vec![AggregateExpr::collect(1).with_distinct()],
1533 vec![LogicalType::Any],
1534 );
1535
1536 let result = agg.next().unwrap().unwrap();
1537 let val = result.column(0).unwrap().get_value(0).unwrap();
1538 if let Value::List(items) = val {
1539 assert_eq!(items.len(), 3);
1541 } else {
1542 panic!("Expected List value");
1543 }
1544 }
1545
1546 #[test]
1547 fn test_group_concat() {
1548 let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1549 for s in ["hello", "world", "foo"] {
1550 builder.column_mut(0).unwrap().push_string(s);
1551 builder.advance_row();
1552 }
1553 let chunk = builder.finish();
1554 let mock = MockOperator::new(vec![chunk]);
1555
1556 let agg_expr = AggregateExpr {
1557 function: AggregateFunction::GroupConcat,
1558 column: Some(0),
1559 column2: None,
1560 distinct: false,
1561 alias: None,
1562 percentile: None,
1563 separator: None,
1564 };
1565
1566 let mut agg =
1567 SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1568
1569 let result = agg.next().unwrap().unwrap();
1570 let val = result.column(0).unwrap().get_value(0).unwrap();
1571 assert_eq!(val, Value::String("hello world foo".into()));
1572 }
1573
1574 #[test]
1575 fn test_sample() {
1576 let mock = MockOperator::new(vec![create_test_chunk()]);
1577
1578 let agg_expr = AggregateExpr {
1579 function: AggregateFunction::Sample,
1580 column: Some(1),
1581 column2: None,
1582 distinct: false,
1583 alias: None,
1584 percentile: None,
1585 separator: None,
1586 };
1587
1588 let mut agg =
1589 SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1590
1591 let result = agg.next().unwrap().unwrap();
1592 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1594 }
1595
1596 #[test]
1597 fn test_variance_sample() {
1598 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1599
1600 let agg_expr = AggregateExpr {
1601 function: AggregateFunction::Variance,
1602 column: Some(0),
1603 column2: None,
1604 distinct: false,
1605 alias: None,
1606 percentile: None,
1607 separator: None,
1608 };
1609
1610 let mut agg = SimpleAggregateOperator::new(
1611 Box::new(mock),
1612 vec![agg_expr],
1613 vec![LogicalType::Float64],
1614 );
1615
1616 let result = agg.next().unwrap().unwrap();
1617 let variance = result.column(0).unwrap().get_float64(0).unwrap();
1619 assert!((variance - 32.0 / 7.0).abs() < 0.01);
1620 }
1621
1622 #[test]
1623 fn test_variance_population() {
1624 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1625
1626 let agg_expr = AggregateExpr {
1627 function: AggregateFunction::VariancePop,
1628 column: Some(0),
1629 column2: None,
1630 distinct: false,
1631 alias: None,
1632 percentile: None,
1633 separator: None,
1634 };
1635
1636 let mut agg = SimpleAggregateOperator::new(
1637 Box::new(mock),
1638 vec![agg_expr],
1639 vec![LogicalType::Float64],
1640 );
1641
1642 let result = agg.next().unwrap().unwrap();
1643 let variance = result.column(0).unwrap().get_float64(0).unwrap();
1645 assert!((variance - 4.0).abs() < 0.01);
1646 }
1647
1648 #[test]
1649 fn test_variance_single_value() {
1650 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1651 builder.column_mut(0).unwrap().push_int64(42);
1652 builder.advance_row();
1653 let chunk = builder.finish();
1654 let mock = MockOperator::new(vec![chunk]);
1655
1656 let agg_expr = AggregateExpr {
1657 function: AggregateFunction::Variance,
1658 column: Some(0),
1659 column2: None,
1660 distinct: false,
1661 alias: None,
1662 percentile: None,
1663 separator: None,
1664 };
1665
1666 let mut agg = SimpleAggregateOperator::new(
1667 Box::new(mock),
1668 vec![agg_expr],
1669 vec![LogicalType::Float64],
1670 );
1671
1672 let result = agg.next().unwrap().unwrap();
1673 assert!(matches!(
1675 result.column(0).unwrap().get_value(0),
1676 Some(Value::Null)
1677 ));
1678 }
1679
1680 #[test]
1681 fn test_empty_aggregation() {
1682 let mock = MockOperator::new(vec![]);
1685
1686 let mut agg = SimpleAggregateOperator::new(
1687 Box::new(mock),
1688 vec![
1689 AggregateExpr::count_star(),
1690 AggregateExpr::sum(0),
1691 AggregateExpr::avg(0),
1692 AggregateExpr::min(0),
1693 AggregateExpr::max(0),
1694 ],
1695 vec![
1696 LogicalType::Int64,
1697 LogicalType::Int64,
1698 LogicalType::Float64,
1699 LogicalType::Int64,
1700 LogicalType::Int64,
1701 ],
1702 );
1703
1704 let result = agg.next().unwrap().unwrap();
1705 assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); assert!(matches!(
1707 result.column(1).unwrap().get_value(0),
1708 Some(Value::Null)
1709 )); assert!(matches!(
1711 result.column(2).unwrap().get_value(0),
1712 Some(Value::Null)
1713 )); assert!(matches!(
1715 result.column(3).unwrap().get_value(0),
1716 Some(Value::Null)
1717 )); assert!(matches!(
1719 result.column(4).unwrap().get_value(0),
1720 Some(Value::Null)
1721 )); }
1723
1724 #[test]
1725 fn test_stdev_pop_single_value() {
1726 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1728 builder.column_mut(0).unwrap().push_int64(42);
1729 builder.advance_row();
1730 let chunk = builder.finish();
1731
1732 let mock = MockOperator::new(vec![chunk]);
1733
1734 let mut agg = SimpleAggregateOperator::new(
1735 Box::new(mock),
1736 vec![AggregateExpr::stdev_pop(0)],
1737 vec![LogicalType::Float64],
1738 );
1739
1740 let result = agg.next().unwrap().unwrap();
1741 assert_eq!(result.row_count(), 1);
1742 let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1744 assert!((stdev - 0.0).abs() < 0.01);
1745 }
1746}