1use indexmap::IndexMap;
11use std::collections::HashSet;
12
13use arcstr::ArcStr;
14use grafeo_common::types::{LogicalType, Value};
15
16use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
17use super::{Operator, OperatorError, OperatorResult};
18use crate::execution::DataChunk;
19use crate::execution::chunk::DataChunkBuilder;
20
21#[derive(Debug, Clone)]
23pub(crate) enum AggregateState {
24 Count(i64),
26 CountDistinct(i64, HashSet<HashableValue>),
28 SumInt(i64, i64),
30 SumIntDistinct(i64, i64, HashSet<HashableValue>),
32 SumFloat(f64, i64),
34 SumFloatDistinct(f64, i64, HashSet<HashableValue>),
36 Avg(f64, i64),
38 AvgDistinct(f64, i64, HashSet<HashableValue>),
40 Min(Option<Value>),
42 Max(Option<Value>),
44 First(Option<Value>),
46 Last(Option<Value>),
48 Collect(Vec<Value>),
50 CollectDistinct(Vec<Value>, HashSet<HashableValue>),
52 StdDev { count: i64, mean: f64, m2: f64 },
54 StdDevPop { count: i64, mean: f64, m2: f64 },
56 PercentileDisc { values: Vec<f64>, percentile: f64 },
58 PercentileCont { values: Vec<f64>, percentile: f64 },
60 GroupConcat(Vec<String>, String),
62 GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
64 Sample(Option<Value>),
66 Variance { count: i64, mean: f64, m2: f64 },
68 VariancePop { count: i64, mean: f64, m2: f64 },
70 Bivariate {
72 kind: AggregateFunction,
74 count: i64,
75 mean_x: f64,
76 mean_y: f64,
77 m2_x: f64,
78 m2_y: f64,
79 c_xy: f64,
80 },
81}
82
83impl AggregateState {
84 pub(crate) fn new(
86 function: AggregateFunction,
87 distinct: bool,
88 percentile: Option<f64>,
89 separator: Option<&str>,
90 ) -> Self {
91 match (function, distinct) {
92 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
93 AggregateState::Count(0)
94 }
95 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
96 AggregateState::CountDistinct(0, HashSet::new())
97 }
98 (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
99 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
100 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
101 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
102 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
104 (AggregateFunction::First, _) => AggregateState::First(None),
105 (AggregateFunction::Last, _) => AggregateState::Last(None),
106 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
107 (AggregateFunction::Collect, true) => {
108 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
109 }
110 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
112 count: 0,
113 mean: 0.0,
114 m2: 0.0,
115 },
116 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
117 count: 0,
118 mean: 0.0,
119 m2: 0.0,
120 },
121 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
122 values: Vec::new(),
123 percentile: percentile.unwrap_or(0.5),
124 },
125 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
126 values: Vec::new(),
127 percentile: percentile.unwrap_or(0.5),
128 },
129 (AggregateFunction::GroupConcat, false) => {
130 AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
131 }
132 (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
133 Vec::new(),
134 separator.unwrap_or(" ").to_string(),
135 HashSet::new(),
136 ),
137 (AggregateFunction::Sample, _) => AggregateState::Sample(None),
138 (
140 AggregateFunction::CovarSamp
141 | AggregateFunction::CovarPop
142 | AggregateFunction::Corr
143 | AggregateFunction::RegrSlope
144 | AggregateFunction::RegrIntercept
145 | AggregateFunction::RegrR2
146 | AggregateFunction::RegrCount
147 | AggregateFunction::RegrSxx
148 | AggregateFunction::RegrSyy
149 | AggregateFunction::RegrSxy
150 | AggregateFunction::RegrAvgx
151 | AggregateFunction::RegrAvgy,
152 _,
153 ) => AggregateState::Bivariate {
154 kind: function,
155 count: 0,
156 mean_x: 0.0,
157 mean_y: 0.0,
158 m2_x: 0.0,
159 m2_y: 0.0,
160 c_xy: 0.0,
161 },
162 (AggregateFunction::Variance, _) => AggregateState::Variance {
163 count: 0,
164 mean: 0.0,
165 m2: 0.0,
166 },
167 (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
168 count: 0,
169 mean: 0.0,
170 m2: 0.0,
171 },
172 }
173 }
174
175 pub(crate) fn update(&mut self, value: Option<Value>) {
177 match self {
178 AggregateState::Count(count) => {
179 *count += 1;
180 }
181 AggregateState::CountDistinct(count, seen) => {
182 if let Some(ref v) = value {
183 let hashable = HashableValue::from(v);
184 if seen.insert(hashable) {
185 *count += 1;
186 }
187 }
188 }
189 AggregateState::SumInt(sum, count) => {
190 if let Some(Value::Int64(v)) = value {
191 *sum += v;
192 *count += 1;
193 } else if let Some(Value::Float64(v)) = value {
194 *self = AggregateState::SumFloat(*sum as f64 + v, *count + 1);
196 } else if let Some(ref v) = value {
197 if let Some(num) = value_to_f64(v) {
199 *self = AggregateState::SumFloat(*sum as f64 + num, *count + 1);
200 }
201 }
202 }
203 AggregateState::SumIntDistinct(sum, count, seen) => {
204 if let Some(ref v) = value {
205 let hashable = HashableValue::from(v);
206 if seen.insert(hashable) {
207 if let Value::Int64(i) = v {
208 *sum += i;
209 *count += 1;
210 } else if let Value::Float64(f) = v {
211 let moved_seen = std::mem::take(seen);
213 *self = AggregateState::SumFloatDistinct(
214 *sum as f64 + f,
215 *count + 1,
216 moved_seen,
217 );
218 } else if let Some(num) = value_to_f64(v) {
219 let moved_seen = std::mem::take(seen);
221 *self = AggregateState::SumFloatDistinct(
222 *sum as f64 + num,
223 *count + 1,
224 moved_seen,
225 );
226 }
227 }
228 }
229 }
230 AggregateState::SumFloat(sum, count) => {
231 if let Some(ref v) = value {
232 if let Some(num) = value_to_f64(v) {
234 *sum += num;
235 *count += 1;
236 }
237 }
238 }
239 AggregateState::SumFloatDistinct(sum, count, seen) => {
240 if let Some(ref v) = value {
241 let hashable = HashableValue::from(v);
242 if seen.insert(hashable)
243 && let Some(num) = value_to_f64(v)
244 {
245 *sum += num;
246 *count += 1;
247 }
248 }
249 }
250 AggregateState::Avg(sum, count) => {
251 if let Some(ref v) = value
252 && let Some(num) = value_to_f64(v)
253 {
254 *sum += num;
255 *count += 1;
256 }
257 }
258 AggregateState::AvgDistinct(sum, count, seen) => {
259 if let Some(ref v) = value {
260 let hashable = HashableValue::from(v);
261 if seen.insert(hashable)
262 && let Some(num) = value_to_f64(v)
263 {
264 *sum += num;
265 *count += 1;
266 }
267 }
268 }
269 AggregateState::Min(min) => {
270 if let Some(v) = value {
271 match min {
272 None => *min = Some(v),
273 Some(current) => {
274 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
275 *min = Some(v);
276 }
277 }
278 }
279 }
280 }
281 AggregateState::Max(max) => {
282 if let Some(v) = value {
283 match max {
284 None => *max = Some(v),
285 Some(current) => {
286 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
287 *max = Some(v);
288 }
289 }
290 }
291 }
292 }
293 AggregateState::First(first) => {
294 if first.is_none() {
295 *first = value;
296 }
297 }
298 AggregateState::Last(last) => {
299 if value.is_some() {
300 *last = value;
301 }
302 }
303 AggregateState::Collect(list) => {
304 if let Some(v) = value {
305 list.push(v);
306 }
307 }
308 AggregateState::CollectDistinct(list, seen) => {
309 if let Some(v) = value {
310 let hashable = HashableValue::from(&v);
311 if seen.insert(hashable) {
312 list.push(v);
313 }
314 }
315 }
316 AggregateState::StdDev { count, mean, m2 }
318 | AggregateState::StdDevPop { count, mean, m2 }
319 | AggregateState::Variance { count, mean, m2 }
320 | AggregateState::VariancePop { count, mean, m2 } => {
321 if let Some(ref v) = value
322 && let Some(x) = value_to_f64(v)
323 {
324 *count += 1;
325 let delta = x - *mean;
326 *mean += delta / *count as f64;
327 let delta2 = x - *mean;
328 *m2 += delta * delta2;
329 }
330 }
331 AggregateState::PercentileDisc { values, .. }
332 | AggregateState::PercentileCont { values, .. } => {
333 if let Some(ref v) = value
334 && let Some(x) = value_to_f64(v)
335 {
336 values.push(x);
337 }
338 }
339 AggregateState::GroupConcat(list, _sep) => {
340 if let Some(v) = value {
341 list.push(agg_value_to_string(&v));
342 }
343 }
344 AggregateState::GroupConcatDistinct(list, _sep, seen) => {
345 if let Some(v) = value {
346 let hashable = HashableValue::from(&v);
347 if seen.insert(hashable) {
348 list.push(agg_value_to_string(&v));
349 }
350 }
351 }
352 AggregateState::Sample(sample) => {
353 if sample.is_none() {
354 *sample = value;
355 }
356 }
357 AggregateState::Bivariate { .. } => {
358 }
361 }
362 }
363
364 fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
369 if let AggregateState::Bivariate {
370 count,
371 mean_x,
372 mean_y,
373 m2_x,
374 m2_y,
375 c_xy,
376 ..
377 } = self
378 {
379 if let (Some(y), Some(x)) = (&y_val, &x_val)
381 && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
382 {
383 *count += 1;
384 let n = *count as f64;
385 let dx = x_f - *mean_x;
386 let dy = y_f - *mean_y;
387 *mean_x += dx / n;
388 *mean_y += dy / n;
389 let dx2 = x_f - *mean_x; let dy2 = y_f - *mean_y; *m2_x += dx * dx2;
392 *m2_y += dy * dy2;
393 *c_xy += dx * dy2;
394 }
395 }
396 }
397
398 pub(crate) fn finalize(&self) -> Value {
400 match self {
401 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
402 Value::Int64(*count)
403 }
404 AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
405 if *count == 0 {
406 Value::Null
407 } else {
408 Value::Int64(*sum)
409 }
410 }
411 AggregateState::SumFloat(sum, count)
412 | AggregateState::SumFloatDistinct(sum, count, _) => {
413 if *count == 0 {
414 Value::Null
415 } else {
416 Value::Float64(*sum)
417 }
418 }
419 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
420 if *count == 0 {
421 Value::Null
422 } else {
423 Value::Float64(*sum / *count as f64)
424 }
425 }
426 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
427 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
428 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
429 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
430 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
431 Value::List(list.clone().into())
432 }
433 AggregateState::StdDev { count, m2, .. } => {
435 if *count < 2 {
436 Value::Null
437 } else {
438 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
439 }
440 }
441 AggregateState::StdDevPop { count, m2, .. } => {
443 if *count == 0 {
444 Value::Null
445 } else {
446 Value::Float64((*m2 / *count as f64).sqrt())
447 }
448 }
449 AggregateState::Variance { count, m2, .. } => {
451 if *count < 2 {
452 Value::Null
453 } else {
454 Value::Float64(*m2 / (*count - 1) as f64)
455 }
456 }
457 AggregateState::VariancePop { count, m2, .. } => {
459 if *count == 0 {
460 Value::Null
461 } else {
462 Value::Float64(*m2 / *count as f64)
463 }
464 }
465 AggregateState::PercentileDisc { values, percentile } => {
467 if values.is_empty() {
468 Value::Null
469 } else {
470 let mut sorted = values.clone();
471 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
472 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
474 Value::Float64(sorted[index])
475 }
476 }
477 AggregateState::PercentileCont { values, percentile } => {
479 if values.is_empty() {
480 Value::Null
481 } else {
482 let mut sorted = values.clone();
483 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
484 let rank = percentile * (sorted.len() - 1) as f64;
486 let lower_idx = rank.floor() as usize;
487 let upper_idx = rank.ceil() as usize;
488 if lower_idx == upper_idx {
489 Value::Float64(sorted[lower_idx])
490 } else {
491 let fraction = rank - lower_idx as f64;
492 let result =
493 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
494 Value::Float64(result)
495 }
496 }
497 }
498 AggregateState::GroupConcat(list, sep)
500 | AggregateState::GroupConcatDistinct(list, sep, _) => {
501 Value::String(list.join(sep).into())
502 }
503 AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
505 AggregateState::Bivariate {
507 kind,
508 count,
509 mean_x,
510 mean_y,
511 m2_x,
512 m2_y,
513 c_xy,
514 } => {
515 let n = *count;
516 match kind {
517 AggregateFunction::CovarSamp => {
518 if n < 2 {
519 Value::Null
520 } else {
521 Value::Float64(*c_xy / (n - 1) as f64)
522 }
523 }
524 AggregateFunction::CovarPop => {
525 if n == 0 {
526 Value::Null
527 } else {
528 Value::Float64(*c_xy / n as f64)
529 }
530 }
531 AggregateFunction::Corr => {
532 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
533 Value::Null
534 } else {
535 Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
536 }
537 }
538 AggregateFunction::RegrSlope => {
539 if n == 0 || *m2_x == 0.0 {
540 Value::Null
541 } else {
542 Value::Float64(*c_xy / *m2_x)
543 }
544 }
545 AggregateFunction::RegrIntercept => {
546 if n == 0 || *m2_x == 0.0 {
547 Value::Null
548 } else {
549 let slope = *c_xy / *m2_x;
550 Value::Float64(*mean_y - slope * *mean_x)
551 }
552 }
553 AggregateFunction::RegrR2 => {
554 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
555 Value::Null
556 } else {
557 Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
558 }
559 }
560 AggregateFunction::RegrCount => Value::Int64(n),
561 AggregateFunction::RegrSxx => {
562 if n == 0 {
563 Value::Null
564 } else {
565 Value::Float64(*m2_x)
566 }
567 }
568 AggregateFunction::RegrSyy => {
569 if n == 0 {
570 Value::Null
571 } else {
572 Value::Float64(*m2_y)
573 }
574 }
575 AggregateFunction::RegrSxy => {
576 if n == 0 {
577 Value::Null
578 } else {
579 Value::Float64(*c_xy)
580 }
581 }
582 AggregateFunction::RegrAvgx => {
583 if n == 0 {
584 Value::Null
585 } else {
586 Value::Float64(*mean_x)
587 }
588 }
589 AggregateFunction::RegrAvgy => {
590 if n == 0 {
591 Value::Null
592 } else {
593 Value::Float64(*mean_y)
594 }
595 }
596 _ => Value::Null, }
598 }
599 }
600 }
601}
602
603use super::value_utils::{compare_values, value_to_f64};
604
605fn agg_value_to_string(val: &Value) -> String {
607 match val {
608 Value::String(s) => s.to_string(),
609 Value::Int64(i) => i.to_string(),
610 Value::Float64(f) => f.to_string(),
611 Value::Bool(b) => b.to_string(),
612 Value::Null => String::new(),
613 other => format!("{other:?}"),
614 }
615}
616
617#[derive(Debug, Clone, PartialEq, Eq, Hash)]
619pub struct GroupKey(Vec<GroupKeyPart>);
620
621#[derive(Debug, Clone, PartialEq, Eq, Hash)]
622enum GroupKeyPart {
623 Null,
624 Bool(bool),
625 Int64(i64),
626 String(ArcStr),
627}
628
629impl GroupKey {
630 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
632 let parts: Vec<GroupKeyPart> = group_columns
633 .iter()
634 .map(|&col_idx| {
635 chunk
636 .column(col_idx)
637 .and_then(|col| col.get_value(row))
638 .map_or(GroupKeyPart::Null, |v| match v {
639 Value::Null => GroupKeyPart::Null,
640 Value::Bool(b) => GroupKeyPart::Bool(b),
641 Value::Int64(i) => GroupKeyPart::Int64(i),
642 Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
643 Value::String(s) => GroupKeyPart::String(s.clone()),
644 _ => GroupKeyPart::String(ArcStr::from(format!("{v:?}"))),
645 })
646 })
647 .collect();
648 GroupKey(parts)
649 }
650
651 fn to_values(&self) -> Vec<Value> {
653 self.0
654 .iter()
655 .map(|part| match part {
656 GroupKeyPart::Null => Value::Null,
657 GroupKeyPart::Bool(b) => Value::Bool(*b),
658 GroupKeyPart::Int64(i) => Value::Int64(*i),
659 GroupKeyPart::String(s) => Value::String(s.clone()),
660 })
661 .collect()
662 }
663}
664
665pub struct HashAggregateOperator {
669 child: Box<dyn Operator>,
671 group_columns: Vec<usize>,
673 aggregates: Vec<AggregateExpr>,
675 output_schema: Vec<LogicalType>,
677 groups: IndexMap<GroupKey, Vec<AggregateState>>,
679 aggregation_complete: bool,
681 results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
683}
684
685impl HashAggregateOperator {
686 pub fn new(
694 child: Box<dyn Operator>,
695 group_columns: Vec<usize>,
696 aggregates: Vec<AggregateExpr>,
697 output_schema: Vec<LogicalType>,
698 ) -> Self {
699 Self {
700 child,
701 group_columns,
702 aggregates,
703 output_schema,
704 groups: IndexMap::new(),
705 aggregation_complete: false,
706 results: None,
707 }
708 }
709
710 fn aggregate(&mut self) -> Result<(), OperatorError> {
712 while let Some(chunk) = self.child.next()? {
713 for row in chunk.selected_indices() {
714 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
715
716 let states = self.groups.entry(key).or_insert_with(|| {
718 self.aggregates
719 .iter()
720 .map(|agg| {
721 AggregateState::new(
722 agg.function,
723 agg.distinct,
724 agg.percentile,
725 agg.separator.as_deref(),
726 )
727 })
728 .collect()
729 });
730
731 for (i, agg) in self.aggregates.iter().enumerate() {
733 if agg.column2.is_some() {
735 let y_val = agg
736 .column
737 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
738 let x_val = agg
739 .column2
740 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
741 states[i].update_bivariate(y_val, x_val);
742 continue;
743 }
744
745 let value = match (agg.function, agg.distinct) {
746 (AggregateFunction::Count, false) => None,
748 (AggregateFunction::Count, true) => agg
750 .column
751 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
752 _ => agg
753 .column
754 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
755 };
756
757 match (agg.function, agg.distinct) {
759 (AggregateFunction::Count, false) => states[i].update(None),
760 (AggregateFunction::Count, true) => {
761 if value.is_some() && !matches!(value, Some(Value::Null)) {
763 states[i].update(value);
764 }
765 }
766 (AggregateFunction::CountNonNull, _) => {
767 if value.is_some() && !matches!(value, Some(Value::Null)) {
768 states[i].update(value);
769 }
770 }
771 _ => {
772 if value.is_some() && !matches!(value, Some(Value::Null)) {
773 states[i].update(value);
774 }
775 }
776 }
777 }
778 }
779 }
780
781 self.aggregation_complete = true;
782
783 let results: Vec<_> = self.groups.drain(..).collect();
785 self.results = Some(results.into_iter());
786
787 Ok(())
788 }
789}
790
791impl Operator for HashAggregateOperator {
792 fn next(&mut self) -> OperatorResult {
793 if !self.aggregation_complete {
795 self.aggregate()?;
796 }
797
798 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
800 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
802
803 for agg in &self.aggregates {
804 let state = AggregateState::new(
805 agg.function,
806 agg.distinct,
807 agg.percentile,
808 agg.separator.as_deref(),
809 );
810 let value = state.finalize();
811 if let Some(col) = builder.column_mut(self.group_columns.len()) {
812 col.push_value(value);
813 }
814 }
815 builder.advance_row();
816
817 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
819 }
820
821 let Some(results) = &mut self.results else {
822 return Ok(None);
823 };
824
825 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
826
827 for (key, states) in results.by_ref() {
828 let key_values = key.to_values();
830 for (i, value) in key_values.into_iter().enumerate() {
831 if let Some(col) = builder.column_mut(i) {
832 col.push_value(value);
833 }
834 }
835
836 for (i, state) in states.iter().enumerate() {
838 let col_idx = self.group_columns.len() + i;
839 if let Some(col) = builder.column_mut(col_idx) {
840 col.push_value(state.finalize());
841 }
842 }
843
844 builder.advance_row();
845
846 if builder.is_full() {
847 return Ok(Some(builder.finish()));
848 }
849 }
850
851 if builder.row_count() > 0 {
852 Ok(Some(builder.finish()))
853 } else {
854 Ok(None)
855 }
856 }
857
858 fn reset(&mut self) {
859 self.child.reset();
860 self.groups.clear();
861 self.aggregation_complete = false;
862 self.results = None;
863 }
864
865 fn name(&self) -> &'static str {
866 "HashAggregate"
867 }
868}
869
870pub struct SimpleAggregateOperator {
874 child: Box<dyn Operator>,
876 aggregates: Vec<AggregateExpr>,
878 output_schema: Vec<LogicalType>,
880 states: Vec<AggregateState>,
882 done: bool,
884}
885
886impl SimpleAggregateOperator {
887 pub fn new(
889 child: Box<dyn Operator>,
890 aggregates: Vec<AggregateExpr>,
891 output_schema: Vec<LogicalType>,
892 ) -> Self {
893 let states = aggregates
894 .iter()
895 .map(|agg| {
896 AggregateState::new(
897 agg.function,
898 agg.distinct,
899 agg.percentile,
900 agg.separator.as_deref(),
901 )
902 })
903 .collect();
904
905 Self {
906 child,
907 aggregates,
908 output_schema,
909 states,
910 done: false,
911 }
912 }
913}
914
915impl Operator for SimpleAggregateOperator {
916 fn next(&mut self) -> OperatorResult {
917 if self.done {
918 return Ok(None);
919 }
920
921 while let Some(chunk) = self.child.next()? {
923 for row in chunk.selected_indices() {
924 for (i, agg) in self.aggregates.iter().enumerate() {
925 if agg.column2.is_some() {
927 let y_val = agg
928 .column
929 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
930 let x_val = agg
931 .column2
932 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
933 self.states[i].update_bivariate(y_val, x_val);
934 continue;
935 }
936
937 let value = match (agg.function, agg.distinct) {
938 (AggregateFunction::Count, false) => None,
940 (AggregateFunction::Count, true) => agg
942 .column
943 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
944 _ => agg
945 .column
946 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
947 };
948
949 match (agg.function, agg.distinct) {
950 (AggregateFunction::Count, false) => self.states[i].update(None),
951 (AggregateFunction::Count, true) => {
952 if value.is_some() && !matches!(value, Some(Value::Null)) {
954 self.states[i].update(value);
955 }
956 }
957 (AggregateFunction::CountNonNull, _) => {
958 if value.is_some() && !matches!(value, Some(Value::Null)) {
959 self.states[i].update(value);
960 }
961 }
962 _ => {
963 if value.is_some() && !matches!(value, Some(Value::Null)) {
964 self.states[i].update(value);
965 }
966 }
967 }
968 }
969 }
970 }
971
972 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
974
975 for (i, state) in self.states.iter().enumerate() {
976 if let Some(col) = builder.column_mut(i) {
977 col.push_value(state.finalize());
978 }
979 }
980 builder.advance_row();
981
982 self.done = true;
983 Ok(Some(builder.finish()))
984 }
985
986 fn reset(&mut self) {
987 self.child.reset();
988 self.states = self
989 .aggregates
990 .iter()
991 .map(|agg| {
992 AggregateState::new(
993 agg.function,
994 agg.distinct,
995 agg.percentile,
996 agg.separator.as_deref(),
997 )
998 })
999 .collect();
1000 self.done = false;
1001 }
1002
1003 fn name(&self) -> &'static str {
1004 "SimpleAggregate"
1005 }
1006}
1007
1008#[cfg(test)]
1009mod tests {
1010 use super::*;
1011 use crate::execution::chunk::DataChunkBuilder;
1012
1013 struct MockOperator {
1014 chunks: Vec<DataChunk>,
1015 position: usize,
1016 }
1017
1018 impl MockOperator {
1019 fn new(chunks: Vec<DataChunk>) -> Self {
1020 Self {
1021 chunks,
1022 position: 0,
1023 }
1024 }
1025 }
1026
1027 impl Operator for MockOperator {
1028 fn next(&mut self) -> OperatorResult {
1029 if self.position < self.chunks.len() {
1030 let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1031 self.position += 1;
1032 Ok(Some(chunk))
1033 } else {
1034 Ok(None)
1035 }
1036 }
1037
1038 fn reset(&mut self) {
1039 self.position = 0;
1040 }
1041
1042 fn name(&self) -> &'static str {
1043 "Mock"
1044 }
1045 }
1046
1047 fn create_test_chunk() -> DataChunk {
1048 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1050
1051 let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1052 for (group, value) in data {
1053 builder.column_mut(0).unwrap().push_int64(group);
1054 builder.column_mut(1).unwrap().push_int64(value);
1055 builder.advance_row();
1056 }
1057
1058 builder.finish()
1059 }
1060
/// `COUNT(*)` over the five-row chunk yields one row holding 5, after which
/// the operator is exhausted.
#[test]
fn test_simple_count() {
    let source = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(source),
        vec![AggregateExpr::count_star()],
        vec![LogicalType::Int64],
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);
    assert_eq!(output.column(0).unwrap().get_int64(0), Some(5));

    // A second pull must report end of stream.
    assert!(op.next().unwrap().is_none());
}
1078
/// `SUM` over the value column: 10 + 20 + 30 + 40 + 50 = 150.
#[test]
fn test_simple_sum() {
    let source = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(source),
        vec![AggregateExpr::sum(1)],
        vec![LogicalType::Int64],
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);
    assert_eq!(output.column(0).unwrap().get_int64(0), Some(150));
}
1094
/// `AVG` over the value column: 150 / 5 = 30.
#[test]
fn test_simple_avg() {
    let source = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(source),
        vec![AggregateExpr::avg(1)],
        vec![LogicalType::Float64],
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let mean = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((mean - 30.0).abs() < 0.001);
}
1111
/// `MIN` and `MAX` over the value column yield 10 and 50.
#[test]
fn test_simple_min_max() {
    let source = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(source),
        vec![AggregateExpr::min(1), AggregateExpr::max(1)],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let smallest = output.column(0).unwrap().get_int64(0);
    let largest = output.column(1).unwrap().get_int64(0);
    assert_eq!(smallest, Some(10));
    assert_eq!(largest, Some(50));
}
1127
/// `SUM` over a string column holding "30", "25" and "35" is expected to
/// come out as the float 90.0.
#[test]
fn test_sum_with_string_values() {
    let mut input = DataChunkBuilder::new(&[LogicalType::String]);
    for text in ["30", "25", "35"] {
        input.column_mut(0).unwrap().push_string(text);
        input.advance_row();
    }

    let source = MockOperator::new(vec![input.finish()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(source),
        vec![AggregateExpr::sum(0)],
        vec![LogicalType::Float64],
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let sum_val = output.column(0).unwrap().get_float64(0).unwrap();
    assert!(
        (sum_val - 90.0).abs() < 0.001,
        "Expected 90.0, got {}",
        sum_val
    );
}
1157
/// Grouped `SUM`: group 1 sums to 30, group 2 sums to 120.
#[test]
fn test_grouped_aggregation() {
    let source = MockOperator::new(vec![create_test_chunk()]);
    let mut op = HashAggregateOperator::new(
        Box::new(source),
        vec![0],
        vec![AggregateExpr::sum(1)],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let mut rows: Vec<(i64, i64)> = Vec::new();
    while let Some(output) = op.next().unwrap() {
        for idx in output.selected_indices() {
            let group_id = output.column(0).unwrap().get_int64(idx).unwrap();
            let total = output.column(1).unwrap().get_int64(idx).unwrap();
            rows.push((group_id, total));
        }
    }

    // Compare in group order regardless of the order groups were emitted in.
    rows.sort_by_key(|&(group_id, _)| group_id);
    assert_eq!(rows.len(), 2);
    assert_eq!(rows[0], (1, 30));
    assert_eq!(rows[1], (2, 120));
}
1184
/// Grouped `COUNT(*)`: group 1 has two rows, group 2 has three.
#[test]
fn test_grouped_count() {
    let source = MockOperator::new(vec![create_test_chunk()]);
    let mut op = HashAggregateOperator::new(
        Box::new(source),
        vec![0],
        vec![AggregateExpr::count_star()],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let mut rows: Vec<(i64, i64)> = Vec::new();
    while let Some(output) = op.next().unwrap() {
        for idx in output.selected_indices() {
            let group_id = output.column(0).unwrap().get_int64(idx).unwrap();
            let row_count = output.column(1).unwrap().get_int64(idx).unwrap();
            rows.push((group_id, row_count));
        }
    }

    // Compare in group order regardless of the order groups were emitted in.
    rows.sort_by_key(|&(group_id, _)| group_id);
    assert_eq!(rows.len(), 2);
    assert_eq!(rows[0], (1, 2));
    assert_eq!(rows[1], (2, 3));
}
1211
/// Several aggregates per group: count, sum and average come out in
/// consecutive columns after the group key.
#[test]
fn test_multiple_aggregates() {
    let source = MockOperator::new(vec![create_test_chunk()]);
    let mut op = HashAggregateOperator::new(
        Box::new(source),
        vec![0],
        vec![
            AggregateExpr::count_star(),
            AggregateExpr::sum(1),
            AggregateExpr::avg(1),
        ],
        vec![
            LogicalType::Int64,
            LogicalType::Int64,
            LogicalType::Int64,
            LogicalType::Float64,
        ],
    );

    let mut rows: Vec<(i64, i64, i64, f64)> = Vec::new();
    while let Some(output) = op.next().unwrap() {
        for idx in output.selected_indices() {
            let group_id = output.column(0).unwrap().get_int64(idx).unwrap();
            let row_count = output.column(1).unwrap().get_int64(idx).unwrap();
            let total = output.column(2).unwrap().get_int64(idx).unwrap();
            let mean = output.column(3).unwrap().get_float64(idx).unwrap();
            rows.push((group_id, row_count, total, mean));
        }
    }

    // Compare in group order regardless of the order groups were emitted in.
    rows.sort_by_key(|&(group_id, ..)| group_id);
    assert_eq!(rows.len(), 2);

    // Group 1: rows 10 and 20.
    let (group_id, row_count, total, mean) = rows[0];
    assert_eq!(group_id, 1);
    assert_eq!(row_count, 2);
    assert_eq!(total, 30);
    assert!((mean - 15.0).abs() < 0.001);

    // Group 2: rows 30, 40 and 50.
    let (group_id, row_count, total, mean) = rows[1];
    assert_eq!(group_id, 2);
    assert_eq!(row_count, 3);
    assert_eq!(total, 120);
    assert!((mean - 40.0).abs() < 0.001);
}
1259
1260 fn create_test_chunk_with_duplicates() -> DataChunk {
1261 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1266
1267 let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1268 for (group, value) in data {
1269 builder.column_mut(0).unwrap().push_int64(group);
1270 builder.column_mut(1).unwrap().push_int64(value);
1271 builder.advance_row();
1272 }
1273
1274 builder.finish()
1275 }
1276
/// `COUNT(DISTINCT value)`: the distinct values are 10, 20 and 30.
#[test]
fn test_count_distinct() {
    let source = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(source),
        vec![AggregateExpr::count(1).with_distinct()],
        vec![LogicalType::Int64],
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);
    assert_eq!(output.column(0).unwrap().get_int64(0), Some(3));
}
1293
/// Grouped `COUNT(DISTINCT value)`: group 1 has two distinct values,
/// group 2 has one.
#[test]
fn test_grouped_count_distinct() {
    let source = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut op = HashAggregateOperator::new(
        Box::new(source),
        vec![0],
        vec![AggregateExpr::count(1).with_distinct()],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let mut rows: Vec<(i64, i64)> = Vec::new();
    while let Some(output) = op.next().unwrap() {
        for idx in output.selected_indices() {
            let group_id = output.column(0).unwrap().get_int64(idx).unwrap();
            let distinct_count = output.column(1).unwrap().get_int64(idx).unwrap();
            rows.push((group_id, distinct_count));
        }
    }

    // Compare in group order regardless of the order groups were emitted in.
    rows.sort_by_key(|&(group_id, _)| group_id);
    assert_eq!(rows.len(), 2);
    assert_eq!(rows[0], (1, 2));
    assert_eq!(rows[1], (2, 1));
}
1320
/// `SUM(DISTINCT value)`: 10 + 20 + 30 = 60.
#[test]
fn test_sum_distinct() {
    let source = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(source),
        vec![AggregateExpr::sum(1).with_distinct()],
        vec![LogicalType::Int64],
    );

    let output = op.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);
    assert_eq!(output.column(0).unwrap().get_int64(0), Some(60));
}
1337
/// Ungrouped `AVG(DISTINCT col1)` over the duplicate fixture.
///
/// The distinct values in column 1 are 10, 20 and 30, so the expected
/// average is (10 + 20 + 30) / 3 = 20.0.
#[test]
fn test_avg_distinct() {
    let source = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let aggregates = vec![AggregateExpr::avg(1).with_distinct()];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();

    assert_eq!(chunk.row_count(), 1);
    let distinct_avg = chunk.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (distinct_avg - 20.0).abs();
    assert!(deviation < 0.001);
}
1355
1356 fn create_statistical_test_chunk() -> DataChunk {
1357 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1360
1361 for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1362 builder.column_mut(0).unwrap().push_int64(value);
1363 builder.advance_row();
1364 }
1365
1366 builder.finish()
1367 }
1368
/// Sample standard deviation over the statistical fixture.
///
/// Sample variance of the fixture is 32 / 7, so the expected result is
/// sqrt(32 / 7), roughly 2.138.
#[test]
fn test_stdev_sample() {
    let source = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::stdev(0)];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();

    assert_eq!(chunk.row_count(), 1);
    let sample_stdev = chunk.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (sample_stdev - 2.138).abs();
    assert!(deviation < 0.01);
}
1386
/// Population standard deviation over the statistical fixture.
///
/// Population variance of the fixture is 32 / 8 = 4, so the expected
/// result is sqrt(4) = 2.0.
#[test]
fn test_stdev_population() {
    let source = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::stdev_pop(0)];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();

    assert_eq!(chunk.row_count(), 1);
    let population_stdev = chunk.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (population_stdev - 2.0).abs();
    assert!(deviation < 0.01);
}
1404
/// Discrete 50th percentile over the statistical fixture.
///
/// The test expects an actual data point, 4.0, for the sorted input
/// 2, 4, 4, 4, 5, 5, 7, 9.
#[test]
fn test_percentile_disc() {
    let source = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::percentile_disc(0, 0.5)];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();

    assert_eq!(chunk.row_count(), 1);
    let median = chunk.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (median - 4.0).abs();
    assert!(deviation < 0.01);
}
1422
/// Continuous (interpolated) 50th percentile over the statistical fixture.
///
/// For the sorted input 2, 4, 4, 4, 5, 5, 7, 9 the two middle elements are
/// 4 and 5, and the test expects the value halfway between them: 4.5.
#[test]
fn test_percentile_cont() {
    let source = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::percentile_cont(0, 0.5)];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();

    assert_eq!(chunk.row_count(), 1);
    let median = chunk.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (median - 4.5).abs();
    assert!(deviation < 0.01);
}
1441
/// Discrete percentiles at the boundaries 0.0 and 1.0 over the statistical fixture.
///
/// The test expects the minimum (2.0) for percentile 0.0 and the maximum (9.0)
/// for percentile 1.0.
#[test]
fn test_percentile_extremes() {
    let source = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![
        AggregateExpr::percentile_disc(0, 0.0),
        AggregateExpr::percentile_disc(0, 1.0),
    ];
    let output_types = vec![LogicalType::Float64, LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();
    assert_eq!(chunk.row_count(), 1);

    // Output column 0 is the 0th percentile.
    let lowest = chunk.column(0).unwrap().get_float64(0).unwrap();
    assert!((lowest - 2.0).abs() < 0.01);

    // Output column 1 is the 100th percentile.
    let highest = chunk.column(1).unwrap().get_float64(0).unwrap();
    assert!((highest - 9.0).abs() < 0.01);
}
1465
/// Sample standard deviation of a single value.
///
/// With only one input row (42), the test expects the aggregate to produce
/// `Value::Null` rather than a number.
#[test]
fn test_stdev_single_value() {
    let single_row_chunk = {
        let mut chunk_builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        chunk_builder.column_mut(0).unwrap().push_int64(42);
        chunk_builder.advance_row();
        chunk_builder.finish()
    };
    let source = MockOperator::new(vec![single_row_chunk]);

    let aggregates = vec![AggregateExpr::stdev(0)];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();

    assert_eq!(chunk.row_count(), 1);
    let cell = chunk.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1490
/// `FIRST(col1)` and `LAST(col1)` over the shared `create_test_chunk` fixture.
///
/// Expects 10 for the first value and 50 for the last one; presumably these are
/// the first and last entries of the fixture's column 1 (fixture defined elsewhere
/// in this module).
#[test]
fn test_first_and_last() {
    let source = MockOperator::new(vec![create_test_chunk()]);
    let aggregates = vec![AggregateExpr::first(1), AggregateExpr::last(1)];
    let output_types = vec![LogicalType::Int64, LogicalType::Int64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();
    assert_eq!(chunk.row_count(), 1);

    let first_value = chunk.column(0).unwrap().get_int64(0);
    assert_eq!(first_value, Some(10));

    let last_value = chunk.column(1).unwrap().get_int64(0);
    assert_eq!(last_value, Some(50));
}
1507
/// `COLLECT(col1)` over the shared `create_test_chunk` fixture.
///
/// Expects a `Value::List` with five entries; presumably one per fixture row
/// (fixture defined elsewhere in this module).
#[test]
fn test_collect() {
    let source = MockOperator::new(vec![create_test_chunk()]);
    let aggregates = vec![AggregateExpr::collect(1)];
    let output_types = vec![LogicalType::Any];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();
    let collected = chunk.column(0).unwrap().get_value(0).unwrap();

    match collected {
        Value::List(items) => assert_eq!(items.len(), 5),
        _ => panic!("Expected List value"),
    }
}
1526
/// `COLLECT(DISTINCT col1)` over the duplicate fixture.
///
/// Column 1 of the fixture has three distinct values (10, 20, 30), so the
/// resulting list is expected to have three entries.
#[test]
fn test_collect_distinct() {
    let source = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let aggregates = vec![AggregateExpr::collect(1).with_distinct()];
    let output_types = vec![LogicalType::Any];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();
    let collected = chunk.column(0).unwrap().get_value(0).unwrap();

    match collected {
        Value::List(items) => assert_eq!(items.len(), 3),
        _ => panic!("Expected List value"),
    }
}
1546
/// `GROUP_CONCAT` over three strings with no explicit separator.
///
/// With `separator: None`, the test expects the inputs to be joined by a
/// single space: "hello world foo".
#[test]
fn test_group_concat() {
    let words_chunk = {
        let mut chunk_builder = DataChunkBuilder::new(&[LogicalType::String]);
        for word in ["hello", "world", "foo"] {
            chunk_builder.column_mut(0).unwrap().push_string(word);
            chunk_builder.advance_row();
        }
        chunk_builder.finish()
    };
    let source = MockOperator::new(vec![words_chunk]);

    // Built as a struct literal so every field, including the absent separator, is explicit.
    let concat_expr = AggregateExpr {
        function: AggregateFunction::GroupConcat,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut operator = SimpleAggregateOperator::new(
        Box::new(source),
        vec![concat_expr],
        vec![LogicalType::String],
    );

    let chunk = operator.next().unwrap().unwrap();
    let joined = chunk.column(0).unwrap().get_value(0).unwrap();
    assert_eq!(joined, Value::String("hello world foo".into()));
}
1574
/// `SAMPLE(col1)` over the shared `create_test_chunk` fixture.
///
/// Expects 10; presumably the first value of the fixture's column 1
/// (fixture defined elsewhere in this module).
#[test]
fn test_sample() {
    let source = MockOperator::new(vec![create_test_chunk()]);

    let sample_expr = AggregateExpr {
        function: AggregateFunction::Sample,
        column: Some(1),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut operator = SimpleAggregateOperator::new(
        Box::new(source),
        vec![sample_expr],
        vec![LogicalType::Int64],
    );

    let chunk = operator.next().unwrap().unwrap();
    let sampled = chunk.column(0).unwrap().get_int64(0);
    assert_eq!(sampled, Some(10));
}
1596
/// Sample variance over the statistical fixture.
///
/// The fixture's sum of squared deviations is 32 over 8 values, so the
/// expected sample variance is 32 / 7.
#[test]
fn test_variance_sample() {
    let source = MockOperator::new(vec![create_statistical_test_chunk()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let aggregates = vec![variance_expr];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();
    let sample_variance = chunk.column(0).unwrap().get_float64(0).unwrap();

    let expected = 32.0 / 7.0;
    assert!((sample_variance - expected).abs() < 0.01);
}
1622
/// Population variance over the statistical fixture.
///
/// The fixture's sum of squared deviations is 32 over 8 values, so the
/// expected population variance is 32 / 8 = 4.0.
#[test]
fn test_variance_population() {
    let source = MockOperator::new(vec![create_statistical_test_chunk()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::VariancePop,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let aggregates = vec![variance_expr];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();
    let population_variance = chunk.column(0).unwrap().get_float64(0).unwrap();

    let deviation = (population_variance - 4.0).abs();
    assert!(deviation < 0.01);
}
1648
/// Sample variance of a single value.
///
/// With only one input row (42), the test expects the aggregate to produce
/// `Value::Null` rather than a number.
#[test]
fn test_variance_single_value() {
    let single_row_chunk = {
        let mut chunk_builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        chunk_builder.column_mut(0).unwrap().push_int64(42);
        chunk_builder.advance_row();
        chunk_builder.finish()
    };
    let source = MockOperator::new(vec![single_row_chunk]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let aggregates = vec![variance_expr];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();
    let cell = chunk.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1680
/// Global aggregation over a child operator that yields no chunks at all.
///
/// The test expects a single output row in which `COUNT(*)` is 0 and
/// `SUM`, `AVG`, `MIN` and `MAX` are all `Value::Null`.
#[test]
fn test_empty_aggregation() {
    let source = MockOperator::new(vec![]);

    let aggregates = vec![
        AggregateExpr::count_star(),
        AggregateExpr::sum(0),
        AggregateExpr::avg(0),
        AggregateExpr::min(0),
        AggregateExpr::max(0),
    ];
    let output_types = vec![
        LogicalType::Int64,
        LogicalType::Int64,
        LogicalType::Float64,
        LogicalType::Int64,
        LogicalType::Int64,
    ];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();

    // Output column 0: COUNT(*).
    assert_eq!(chunk.column(0).unwrap().get_int64(0), Some(0));

    // Output columns 1..=4: SUM, AVG, MIN, MAX, checked in that order.
    for column_index in 1..=4 {
        let cell = chunk.column(column_index).unwrap().get_value(0);
        assert!(matches!(cell, Some(Value::Null)));
    }
}
1724
/// Population standard deviation of a single value.
///
/// Unlike the sample variant, the test expects a numeric result here:
/// 0.0 for the single input row (42).
#[test]
fn test_stdev_pop_single_value() {
    let single_row_chunk = {
        let mut chunk_builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        chunk_builder.column_mut(0).unwrap().push_int64(42);
        chunk_builder.advance_row();
        chunk_builder.finish()
    };
    let source = MockOperator::new(vec![single_row_chunk]);

    let aggregates = vec![AggregateExpr::stdev_pop(0)];
    let output_types = vec![LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(source), aggregates, output_types);

    let chunk = operator.next().unwrap().unwrap();

    assert_eq!(chunk.row_count(), 1);
    let population_stdev = chunk.column(0).unwrap().get_float64(0).unwrap();
    assert!(population_stdev.abs() < 0.01);
}
1747}