1use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
/// Running accumulator for one aggregate expression.
///
/// A state is created by `AggregateState::new`, fed through `update` (or
/// `update_bivariate` for the two-column variant) and turned into a result
/// by `finalize`.
#[derive(Debug, Clone)]
pub(crate) enum AggregateState {
    /// Counter incremented on every `update` call.
    Count(i64),
    /// Number of distinct values seen, plus the set of values already counted.
    CountDistinct(i64, HashSet<HashableValue>),
    /// `(sum, count)` of integer inputs; `update` replaces this state with
    /// `SumFloat` as soon as a non-integer numeric value arrives.
    SumInt(i64, i64),
    /// `(sum, count, seen)`: like `SumInt`, but each distinct value is added once.
    SumIntDistinct(i64, i64, HashSet<HashableValue>),
    /// `(sum, count)` accumulated as floating point.
    SumFloat(f64, i64),
    /// `(sum, count, seen)`: like `SumFloat`, but each distinct value is added once.
    SumFloatDistinct(f64, i64, HashSet<HashableValue>),
    /// `(sum, count)`; the result is `sum / count`.
    Avg(f64, i64),
    /// `(sum, count, seen)`: like `Avg`, but each distinct value is added once.
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// Smallest value seen so far (per `compare_values`), if any.
    Min(Option<Value>),
    /// Largest value seen so far (per `compare_values`), if any.
    Max(Option<Value>),
    /// First value supplied, if any.
    First(Option<Value>),
    /// Most recent value supplied, if any.
    Last(Option<Value>),
    /// Every value supplied, in arrival order.
    Collect(Vec<Value>),
    /// Distinct values in first-arrival order, plus the set used to de-duplicate.
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Welford running statistics; finalized as the sample standard deviation.
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Welford running statistics; finalized as the population standard deviation.
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Collected numeric inputs and the requested percentile (discrete result).
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Collected numeric inputs and the requested percentile (interpolated result).
    PercentileCont { values: Vec<f64>, percentile: f64 },
    /// `(pieces, separator)`: stringified inputs joined with the separator on finalize.
    GroupConcat(Vec<String>, String),
    /// `(pieces, separator, seen)`: like `GroupConcat`, but each distinct value once.
    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
    /// First value supplied, if any (updated the same way as `First`).
    Sample(Option<Value>),
    /// Welford running statistics; finalized as the sample variance.
    Variance { count: i64, mean: f64, m2: f64 },
    /// Welford running statistics; finalized as the population variance.
    VariancePop { count: i64, mean: f64, m2: f64 },
    /// Shared state for all two-column (y, x) statistics.
    Bivariate {
        /// Which bivariate function `finalize` should compute.
        kind: AggregateFunction,
        /// Number of (y, x) pairs where both sides were numeric.
        count: i64,
        /// Running mean of x.
        mean_x: f64,
        /// Running mean of y.
        mean_y: f64,
        /// Running sum of squared deviations of x.
        m2_x: f64,
        /// Running sum of squared deviations of y.
        m2_y: f64,
        /// Running sum of x/y co-deviations.
        c_xy: f64,
    },
}
83
84impl AggregateState {
85 pub(crate) fn new(
87 function: AggregateFunction,
88 distinct: bool,
89 percentile: Option<f64>,
90 separator: Option<&str>,
91 ) -> Self {
92 match (function, distinct) {
93 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
94 AggregateState::Count(0)
95 }
96 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
97 AggregateState::CountDistinct(0, HashSet::new())
98 }
99 (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
100 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
101 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
102 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
103 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
105 (AggregateFunction::First, _) => AggregateState::First(None),
106 (AggregateFunction::Last, _) => AggregateState::Last(None),
107 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
108 (AggregateFunction::Collect, true) => {
109 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
110 }
111 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
113 count: 0,
114 mean: 0.0,
115 m2: 0.0,
116 },
117 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
118 count: 0,
119 mean: 0.0,
120 m2: 0.0,
121 },
122 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
123 values: Vec::new(),
124 percentile: percentile.unwrap_or(0.5),
125 },
126 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
127 values: Vec::new(),
128 percentile: percentile.unwrap_or(0.5),
129 },
130 (AggregateFunction::GroupConcat, false) => {
131 AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
132 }
133 (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
134 Vec::new(),
135 separator.unwrap_or(" ").to_string(),
136 HashSet::new(),
137 ),
138 (AggregateFunction::Sample, _) => AggregateState::Sample(None),
139 (
141 AggregateFunction::CovarSamp
142 | AggregateFunction::CovarPop
143 | AggregateFunction::Corr
144 | AggregateFunction::RegrSlope
145 | AggregateFunction::RegrIntercept
146 | AggregateFunction::RegrR2
147 | AggregateFunction::RegrCount
148 | AggregateFunction::RegrSxx
149 | AggregateFunction::RegrSyy
150 | AggregateFunction::RegrSxy
151 | AggregateFunction::RegrAvgx
152 | AggregateFunction::RegrAvgy,
153 _,
154 ) => AggregateState::Bivariate {
155 kind: function,
156 count: 0,
157 mean_x: 0.0,
158 mean_y: 0.0,
159 m2_x: 0.0,
160 m2_y: 0.0,
161 c_xy: 0.0,
162 },
163 (AggregateFunction::Variance, _) => AggregateState::Variance {
164 count: 0,
165 mean: 0.0,
166 m2: 0.0,
167 },
168 (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
169 count: 0,
170 mean: 0.0,
171 m2: 0.0,
172 },
173 }
174 }
175
176 pub(crate) fn update(&mut self, value: Option<Value>) {
178 match self {
179 AggregateState::Count(count) => {
180 *count += 1;
181 }
182 AggregateState::CountDistinct(count, seen) => {
183 if let Some(ref v) = value {
184 let hashable = HashableValue::from(v);
185 if seen.insert(hashable) {
186 *count += 1;
187 }
188 }
189 }
190 AggregateState::SumInt(sum, count) => {
191 if let Some(Value::Int64(v)) = value {
192 *sum += v;
193 *count += 1;
194 } else if let Some(Value::Float64(v)) = value {
195 *self = AggregateState::SumFloat(*sum as f64 + v, *count + 1);
197 } else if let Some(ref v) = value {
198 if let Some(num) = value_to_f64(v) {
200 *self = AggregateState::SumFloat(*sum as f64 + num, *count + 1);
201 }
202 }
203 }
204 AggregateState::SumIntDistinct(sum, count, seen) => {
205 if let Some(ref v) = value {
206 let hashable = HashableValue::from(v);
207 if seen.insert(hashable) {
208 if let Value::Int64(i) = v {
209 *sum += i;
210 *count += 1;
211 } else if let Value::Float64(f) = v {
212 let moved_seen = std::mem::take(seen);
214 *self = AggregateState::SumFloatDistinct(
215 *sum as f64 + f,
216 *count + 1,
217 moved_seen,
218 );
219 } else if let Some(num) = value_to_f64(v) {
220 let moved_seen = std::mem::take(seen);
222 *self = AggregateState::SumFloatDistinct(
223 *sum as f64 + num,
224 *count + 1,
225 moved_seen,
226 );
227 }
228 }
229 }
230 }
231 AggregateState::SumFloat(sum, count) => {
232 if let Some(ref v) = value {
233 if let Some(num) = value_to_f64(v) {
235 *sum += num;
236 *count += 1;
237 }
238 }
239 }
240 AggregateState::SumFloatDistinct(sum, count, seen) => {
241 if let Some(ref v) = value {
242 let hashable = HashableValue::from(v);
243 if seen.insert(hashable)
244 && let Some(num) = value_to_f64(v)
245 {
246 *sum += num;
247 *count += 1;
248 }
249 }
250 }
251 AggregateState::Avg(sum, count) => {
252 if let Some(ref v) = value
253 && let Some(num) = value_to_f64(v)
254 {
255 *sum += num;
256 *count += 1;
257 }
258 }
259 AggregateState::AvgDistinct(sum, count, seen) => {
260 if let Some(ref v) = value {
261 let hashable = HashableValue::from(v);
262 if seen.insert(hashable)
263 && let Some(num) = value_to_f64(v)
264 {
265 *sum += num;
266 *count += 1;
267 }
268 }
269 }
270 AggregateState::Min(min) => {
271 if let Some(v) = value {
272 match min {
273 None => *min = Some(v),
274 Some(current) => {
275 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
276 *min = Some(v);
277 }
278 }
279 }
280 }
281 }
282 AggregateState::Max(max) => {
283 if let Some(v) = value {
284 match max {
285 None => *max = Some(v),
286 Some(current) => {
287 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
288 *max = Some(v);
289 }
290 }
291 }
292 }
293 }
294 AggregateState::First(first) => {
295 if first.is_none() {
296 *first = value;
297 }
298 }
299 AggregateState::Last(last) => {
300 if value.is_some() {
301 *last = value;
302 }
303 }
304 AggregateState::Collect(list) => {
305 if let Some(v) = value {
306 list.push(v);
307 }
308 }
309 AggregateState::CollectDistinct(list, seen) => {
310 if let Some(v) = value {
311 let hashable = HashableValue::from(&v);
312 if seen.insert(hashable) {
313 list.push(v);
314 }
315 }
316 }
317 AggregateState::StdDev { count, mean, m2 }
319 | AggregateState::StdDevPop { count, mean, m2 }
320 | AggregateState::Variance { count, mean, m2 }
321 | AggregateState::VariancePop { count, mean, m2 } => {
322 if let Some(ref v) = value
323 && let Some(x) = value_to_f64(v)
324 {
325 *count += 1;
326 let delta = x - *mean;
327 *mean += delta / *count as f64;
328 let delta2 = x - *mean;
329 *m2 += delta * delta2;
330 }
331 }
332 AggregateState::PercentileDisc { values, .. }
333 | AggregateState::PercentileCont { values, .. } => {
334 if let Some(ref v) = value
335 && let Some(x) = value_to_f64(v)
336 {
337 values.push(x);
338 }
339 }
340 AggregateState::GroupConcat(list, _sep) => {
341 if let Some(v) = value {
342 list.push(agg_value_to_string(&v));
343 }
344 }
345 AggregateState::GroupConcatDistinct(list, _sep, seen) => {
346 if let Some(v) = value {
347 let hashable = HashableValue::from(&v);
348 if seen.insert(hashable) {
349 list.push(agg_value_to_string(&v));
350 }
351 }
352 }
353 AggregateState::Sample(sample) => {
354 if sample.is_none() {
355 *sample = value;
356 }
357 }
358 AggregateState::Bivariate { .. } => {
359 }
362 }
363 }
364
365 fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
370 if let AggregateState::Bivariate {
371 count,
372 mean_x,
373 mean_y,
374 m2_x,
375 m2_y,
376 c_xy,
377 ..
378 } = self
379 {
380 if let (Some(y), Some(x)) = (&y_val, &x_val)
382 && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
383 {
384 *count += 1;
385 let n = *count as f64;
386 let dx = x_f - *mean_x;
387 let dy = y_f - *mean_y;
388 *mean_x += dx / n;
389 *mean_y += dy / n;
390 let dx2 = x_f - *mean_x; let dy2 = y_f - *mean_y; *m2_x += dx * dx2;
393 *m2_y += dy * dy2;
394 *c_xy += dx * dy2;
395 }
396 }
397 }
398
399 pub(crate) fn finalize(&self) -> Value {
401 match self {
402 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
403 Value::Int64(*count)
404 }
405 AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
406 if *count == 0 {
407 Value::Null
408 } else {
409 Value::Int64(*sum)
410 }
411 }
412 AggregateState::SumFloat(sum, count)
413 | AggregateState::SumFloatDistinct(sum, count, _) => {
414 if *count == 0 {
415 Value::Null
416 } else {
417 Value::Float64(*sum)
418 }
419 }
420 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
421 if *count == 0 {
422 Value::Null
423 } else {
424 Value::Float64(*sum / *count as f64)
425 }
426 }
427 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
428 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
429 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
430 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
431 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
432 Value::List(list.clone().into())
433 }
434 AggregateState::StdDev { count, m2, .. } => {
436 if *count < 2 {
437 Value::Null
438 } else {
439 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
440 }
441 }
442 AggregateState::StdDevPop { count, m2, .. } => {
444 if *count == 0 {
445 Value::Null
446 } else {
447 Value::Float64((*m2 / *count as f64).sqrt())
448 }
449 }
450 AggregateState::Variance { count, m2, .. } => {
452 if *count < 2 {
453 Value::Null
454 } else {
455 Value::Float64(*m2 / (*count - 1) as f64)
456 }
457 }
458 AggregateState::VariancePop { count, m2, .. } => {
460 if *count == 0 {
461 Value::Null
462 } else {
463 Value::Float64(*m2 / *count as f64)
464 }
465 }
466 AggregateState::PercentileDisc { values, percentile } => {
468 if values.is_empty() {
469 Value::Null
470 } else {
471 let mut sorted = values.clone();
472 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
473 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
475 Value::Float64(sorted[index])
476 }
477 }
478 AggregateState::PercentileCont { values, percentile } => {
480 if values.is_empty() {
481 Value::Null
482 } else {
483 let mut sorted = values.clone();
484 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
485 let rank = percentile * (sorted.len() - 1) as f64;
487 let lower_idx = rank.floor() as usize;
488 let upper_idx = rank.ceil() as usize;
489 if lower_idx == upper_idx {
490 Value::Float64(sorted[lower_idx])
491 } else {
492 let fraction = rank - lower_idx as f64;
493 let result =
494 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
495 Value::Float64(result)
496 }
497 }
498 }
499 AggregateState::GroupConcat(list, sep)
501 | AggregateState::GroupConcatDistinct(list, sep, _) => {
502 Value::String(list.join(sep).into())
503 }
504 AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
506 AggregateState::Bivariate {
508 kind,
509 count,
510 mean_x,
511 mean_y,
512 m2_x,
513 m2_y,
514 c_xy,
515 } => {
516 let n = *count;
517 match kind {
518 AggregateFunction::CovarSamp => {
519 if n < 2 {
520 Value::Null
521 } else {
522 Value::Float64(*c_xy / (n - 1) as f64)
523 }
524 }
525 AggregateFunction::CovarPop => {
526 if n == 0 {
527 Value::Null
528 } else {
529 Value::Float64(*c_xy / n as f64)
530 }
531 }
532 AggregateFunction::Corr => {
533 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
534 Value::Null
535 } else {
536 Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
537 }
538 }
539 AggregateFunction::RegrSlope => {
540 if n == 0 || *m2_x == 0.0 {
541 Value::Null
542 } else {
543 Value::Float64(*c_xy / *m2_x)
544 }
545 }
546 AggregateFunction::RegrIntercept => {
547 if n == 0 || *m2_x == 0.0 {
548 Value::Null
549 } else {
550 let slope = *c_xy / *m2_x;
551 Value::Float64(*mean_y - slope * *mean_x)
552 }
553 }
554 AggregateFunction::RegrR2 => {
555 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
556 Value::Null
557 } else {
558 Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
559 }
560 }
561 AggregateFunction::RegrCount => Value::Int64(n),
562 AggregateFunction::RegrSxx => {
563 if n == 0 {
564 Value::Null
565 } else {
566 Value::Float64(*m2_x)
567 }
568 }
569 AggregateFunction::RegrSyy => {
570 if n == 0 {
571 Value::Null
572 } else {
573 Value::Float64(*m2_y)
574 }
575 }
576 AggregateFunction::RegrSxy => {
577 if n == 0 {
578 Value::Null
579 } else {
580 Value::Float64(*c_xy)
581 }
582 }
583 AggregateFunction::RegrAvgx => {
584 if n == 0 {
585 Value::Null
586 } else {
587 Value::Float64(*mean_x)
588 }
589 }
590 AggregateFunction::RegrAvgy => {
591 if n == 0 {
592 Value::Null
593 } else {
594 Value::Float64(*mean_y)
595 }
596 }
597 _ => Value::Null, }
599 }
600 }
601 }
602}
603
604use super::value_utils::{compare_values, value_to_f64};
605
606fn agg_value_to_string(val: &Value) -> String {
608 match val {
609 Value::String(s) => s.to_string(),
610 Value::Int64(i) => i.to_string(),
611 Value::Float64(f) => f.to_string(),
612 Value::Bool(b) => b.to_string(),
613 Value::Null => String::new(),
614 other => format!("{other:?}"),
615 }
616}
617
/// Hashable key identifying one group: one `GroupKeyPart` per grouping
/// column, in the order the grouping columns were given.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(Vec<GroupKeyPart>);
621
622#[derive(Debug, Clone, PartialEq, Eq, Hash)]
623enum GroupKeyPart {
624 Null,
625 Bool(bool),
626 Int64(i64),
627 String(ArcStr),
628 List(Vec<GroupKeyPart>),
629}
630
631impl GroupKeyPart {
632 fn from_value(v: Value) -> Self {
633 match v {
634 Value::Null => Self::Null,
635 Value::Bool(b) => Self::Bool(b),
636 Value::Int64(i) => Self::Int64(i),
637 Value::Float64(f) => Self::Int64(f.to_bits() as i64),
638 Value::String(s) => Self::String(s.clone()),
639 Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
640 other => Self::String(ArcStr::from(format!("{other:?}"))),
641 }
642 }
643
644 fn to_value(&self) -> Value {
645 match self {
646 Self::Null => Value::Null,
647 Self::Bool(b) => Value::Bool(*b),
648 Self::Int64(i) => Value::Int64(*i),
649 Self::String(s) => Value::String(s.clone()),
650 Self::List(parts) => {
651 let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
652 Value::List(Arc::from(values.into_boxed_slice()))
653 }
654 }
655 }
656}
657
658impl GroupKey {
659 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
661 let parts: Vec<GroupKeyPart> = group_columns
662 .iter()
663 .map(|&col_idx| {
664 chunk
665 .column(col_idx)
666 .and_then(|col| col.get_value(row))
667 .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
668 })
669 .collect();
670 GroupKey(parts)
671 }
672
673 fn to_values(&self) -> Vec<Value> {
675 self.0.iter().map(GroupKeyPart::to_value).collect()
676 }
677}
678
/// Grouped aggregation operator.
///
/// Drains its child, keeps one set of aggregate states per distinct group
/// key, and then emits one output row per group: the group-key columns
/// first, followed by one column per aggregate.
pub struct HashAggregateOperator {
    /// Input operator whose chunks are aggregated.
    child: Box<dyn Operator>,
    /// Indices of the input columns that form the group key.
    group_columns: Vec<usize>,
    /// Aggregate expressions evaluated for every group.
    aggregates: Vec<AggregateExpr>,
    /// Column types used to build the output chunks.
    output_schema: Vec<LogicalType>,
    /// Per-group aggregate states, keyed by group key (insertion-ordered map).
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Set once the child has been fully consumed.
    aggregation_complete: bool,
    /// Finished groups still waiting to be emitted; `None` before aggregation.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
698
699impl HashAggregateOperator {
700 pub fn new(
708 child: Box<dyn Operator>,
709 group_columns: Vec<usize>,
710 aggregates: Vec<AggregateExpr>,
711 output_schema: Vec<LogicalType>,
712 ) -> Self {
713 Self {
714 child,
715 group_columns,
716 aggregates,
717 output_schema,
718 groups: IndexMap::new(),
719 aggregation_complete: false,
720 results: None,
721 }
722 }
723
724 fn aggregate(&mut self) -> Result<(), OperatorError> {
726 while let Some(chunk) = self.child.next()? {
727 for row in chunk.selected_indices() {
728 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
729
730 let states = self.groups.entry(key).or_insert_with(|| {
732 self.aggregates
733 .iter()
734 .map(|agg| {
735 AggregateState::new(
736 agg.function,
737 agg.distinct,
738 agg.percentile,
739 agg.separator.as_deref(),
740 )
741 })
742 .collect()
743 });
744
745 for (i, agg) in self.aggregates.iter().enumerate() {
747 if agg.column2.is_some() {
749 let y_val = agg
750 .column
751 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
752 let x_val = agg
753 .column2
754 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
755 states[i].update_bivariate(y_val, x_val);
756 continue;
757 }
758
759 let value = match (agg.function, agg.distinct) {
760 (AggregateFunction::Count, false) => None,
762 (AggregateFunction::Count, true) => agg
764 .column
765 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
766 _ => agg
767 .column
768 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
769 };
770
771 match (agg.function, agg.distinct) {
773 (AggregateFunction::Count, false) => states[i].update(None),
774 (AggregateFunction::Count, true) => {
775 if value.is_some() && !matches!(value, Some(Value::Null)) {
777 states[i].update(value);
778 }
779 }
780 (AggregateFunction::CountNonNull, _) => {
781 if value.is_some() && !matches!(value, Some(Value::Null)) {
782 states[i].update(value);
783 }
784 }
785 _ => {
786 if value.is_some() && !matches!(value, Some(Value::Null)) {
787 states[i].update(value);
788 }
789 }
790 }
791 }
792 }
793 }
794
795 self.aggregation_complete = true;
796
797 let results: Vec<_> = self.groups.drain(..).collect();
799 self.results = Some(results.into_iter());
800
801 Ok(())
802 }
803}
804
805impl Operator for HashAggregateOperator {
806 fn next(&mut self) -> OperatorResult {
807 if !self.aggregation_complete {
809 self.aggregate()?;
810 }
811
812 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
814 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
816
817 for agg in &self.aggregates {
818 let state = AggregateState::new(
819 agg.function,
820 agg.distinct,
821 agg.percentile,
822 agg.separator.as_deref(),
823 );
824 let value = state.finalize();
825 if let Some(col) = builder.column_mut(self.group_columns.len()) {
826 col.push_value(value);
827 }
828 }
829 builder.advance_row();
830
831 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
833 }
834
835 let Some(results) = &mut self.results else {
836 return Ok(None);
837 };
838
839 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
840
841 for (key, states) in results.by_ref() {
842 let key_values = key.to_values();
844 for (i, value) in key_values.into_iter().enumerate() {
845 if let Some(col) = builder.column_mut(i) {
846 col.push_value(value);
847 }
848 }
849
850 for (i, state) in states.iter().enumerate() {
852 let col_idx = self.group_columns.len() + i;
853 if let Some(col) = builder.column_mut(col_idx) {
854 col.push_value(state.finalize());
855 }
856 }
857
858 builder.advance_row();
859
860 if builder.is_full() {
861 return Ok(Some(builder.finish()));
862 }
863 }
864
865 if builder.row_count() > 0 {
866 Ok(Some(builder.finish()))
867 } else {
868 Ok(None)
869 }
870 }
871
872 fn reset(&mut self) {
873 self.child.reset();
874 self.groups.clear();
875 self.aggregation_complete = false;
876 self.results = None;
877 }
878
879 fn name(&self) -> &'static str {
880 "HashAggregate"
881 }
882}
883
/// Ungrouped aggregation operator.
///
/// Folds every selected row of its child into one state per aggregate and
/// emits exactly one output row.
pub struct SimpleAggregateOperator {
    /// Input operator whose chunks are aggregated.
    child: Box<dyn Operator>,
    /// Aggregate expressions to evaluate.
    aggregates: Vec<AggregateExpr>,
    /// Column types used to build the single output row.
    output_schema: Vec<LogicalType>,
    /// One running state per entry in `aggregates`, in the same order.
    states: Vec<AggregateState>,
    /// Set once the result row has been emitted.
    done: bool,
}
899
900impl SimpleAggregateOperator {
901 pub fn new(
903 child: Box<dyn Operator>,
904 aggregates: Vec<AggregateExpr>,
905 output_schema: Vec<LogicalType>,
906 ) -> Self {
907 let states = aggregates
908 .iter()
909 .map(|agg| {
910 AggregateState::new(
911 agg.function,
912 agg.distinct,
913 agg.percentile,
914 agg.separator.as_deref(),
915 )
916 })
917 .collect();
918
919 Self {
920 child,
921 aggregates,
922 output_schema,
923 states,
924 done: false,
925 }
926 }
927}
928
929impl Operator for SimpleAggregateOperator {
930 fn next(&mut self) -> OperatorResult {
931 if self.done {
932 return Ok(None);
933 }
934
935 while let Some(chunk) = self.child.next()? {
937 for row in chunk.selected_indices() {
938 for (i, agg) in self.aggregates.iter().enumerate() {
939 if agg.column2.is_some() {
941 let y_val = agg
942 .column
943 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
944 let x_val = agg
945 .column2
946 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
947 self.states[i].update_bivariate(y_val, x_val);
948 continue;
949 }
950
951 let value = match (agg.function, agg.distinct) {
952 (AggregateFunction::Count, false) => None,
954 (AggregateFunction::Count, true) => agg
956 .column
957 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
958 _ => agg
959 .column
960 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
961 };
962
963 match (agg.function, agg.distinct) {
964 (AggregateFunction::Count, false) => self.states[i].update(None),
965 (AggregateFunction::Count, true) => {
966 if value.is_some() && !matches!(value, Some(Value::Null)) {
968 self.states[i].update(value);
969 }
970 }
971 (AggregateFunction::CountNonNull, _) => {
972 if value.is_some() && !matches!(value, Some(Value::Null)) {
973 self.states[i].update(value);
974 }
975 }
976 _ => {
977 if value.is_some() && !matches!(value, Some(Value::Null)) {
978 self.states[i].update(value);
979 }
980 }
981 }
982 }
983 }
984 }
985
986 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
988
989 for (i, state) in self.states.iter().enumerate() {
990 if let Some(col) = builder.column_mut(i) {
991 col.push_value(state.finalize());
992 }
993 }
994 builder.advance_row();
995
996 self.done = true;
997 Ok(Some(builder.finish()))
998 }
999
1000 fn reset(&mut self) {
1001 self.child.reset();
1002 self.states = self
1003 .aggregates
1004 .iter()
1005 .map(|agg| {
1006 AggregateState::new(
1007 agg.function,
1008 agg.distinct,
1009 agg.percentile,
1010 agg.separator.as_deref(),
1011 )
1012 })
1013 .collect();
1014 self.done = false;
1015 }
1016
1017 fn name(&self) -> &'static str {
1018 "SimpleAggregate"
1019 }
1020}
1021
1022#[cfg(test)]
1023mod tests {
1024 use super::*;
1025 use crate::execution::chunk::DataChunkBuilder;
1026
    /// Test operator that hands out a fixed list of chunks, one per `next` call.
    struct MockOperator {
        /// Chunks to return; each slot is swapped for an empty chunk once served.
        chunks: Vec<DataChunk>,
        /// Index of the next chunk to return.
        position: usize,
    }
1031
1032 impl MockOperator {
1033 fn new(chunks: Vec<DataChunk>) -> Self {
1034 Self {
1035 chunks,
1036 position: 0,
1037 }
1038 }
1039 }
1040
1041 impl Operator for MockOperator {
1042 fn next(&mut self) -> OperatorResult {
1043 if self.position < self.chunks.len() {
1044 let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1045 self.position += 1;
1046 Ok(Some(chunk))
1047 } else {
1048 Ok(None)
1049 }
1050 }
1051
1052 fn reset(&mut self) {
1053 self.position = 0;
1054 }
1055
1056 fn name(&self) -> &'static str {
1057 "Mock"
1058 }
1059 }
1060
1061 fn create_test_chunk() -> DataChunk {
1062 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1064
1065 let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1066 for (group, value) in data {
1067 builder.column_mut(0).unwrap().push_int64(group);
1068 builder.column_mut(1).unwrap().push_int64(value);
1069 builder.advance_row();
1070 }
1071
1072 builder.finish()
1073 }
1074
1075 #[test]
1076 fn test_simple_count() {
1077 let mock = MockOperator::new(vec![create_test_chunk()]);
1078
1079 let mut agg = SimpleAggregateOperator::new(
1080 Box::new(mock),
1081 vec![AggregateExpr::count_star()],
1082 vec![LogicalType::Int64],
1083 );
1084
1085 let result = agg.next().unwrap().unwrap();
1086 assert_eq!(result.row_count(), 1);
1087 assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1088
1089 assert!(agg.next().unwrap().is_none());
1091 }
1092
1093 #[test]
1094 fn test_simple_sum() {
1095 let mock = MockOperator::new(vec![create_test_chunk()]);
1096
1097 let mut agg = SimpleAggregateOperator::new(
1098 Box::new(mock),
1099 vec![AggregateExpr::sum(1)], vec![LogicalType::Int64],
1101 );
1102
1103 let result = agg.next().unwrap().unwrap();
1104 assert_eq!(result.row_count(), 1);
1105 assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1107 }
1108
1109 #[test]
1110 fn test_simple_avg() {
1111 let mock = MockOperator::new(vec![create_test_chunk()]);
1112
1113 let mut agg = SimpleAggregateOperator::new(
1114 Box::new(mock),
1115 vec![AggregateExpr::avg(1)],
1116 vec![LogicalType::Float64],
1117 );
1118
1119 let result = agg.next().unwrap().unwrap();
1120 assert_eq!(result.row_count(), 1);
1121 let avg = result.column(0).unwrap().get_float64(0).unwrap();
1123 assert!((avg - 30.0).abs() < 0.001);
1124 }
1125
1126 #[test]
1127 fn test_simple_min_max() {
1128 let mock = MockOperator::new(vec![create_test_chunk()]);
1129
1130 let mut agg = SimpleAggregateOperator::new(
1131 Box::new(mock),
1132 vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1133 vec![LogicalType::Int64, LogicalType::Int64],
1134 );
1135
1136 let result = agg.next().unwrap().unwrap();
1137 assert_eq!(result.row_count(), 1);
1138 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); }
1141
1142 #[test]
1143 fn test_sum_with_string_values() {
1144 let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1146 builder.column_mut(0).unwrap().push_string("30");
1147 builder.advance_row();
1148 builder.column_mut(0).unwrap().push_string("25");
1149 builder.advance_row();
1150 builder.column_mut(0).unwrap().push_string("35");
1151 builder.advance_row();
1152 let chunk = builder.finish();
1153
1154 let mock = MockOperator::new(vec![chunk]);
1155 let mut agg = SimpleAggregateOperator::new(
1156 Box::new(mock),
1157 vec![AggregateExpr::sum(0)],
1158 vec![LogicalType::Float64],
1159 );
1160
1161 let result = agg.next().unwrap().unwrap();
1162 assert_eq!(result.row_count(), 1);
1163 let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1165 assert!(
1166 (sum_val - 90.0).abs() < 0.001,
1167 "Expected 90.0, got {}",
1168 sum_val
1169 );
1170 }
1171
1172 #[test]
1173 fn test_grouped_aggregation() {
1174 let mock = MockOperator::new(vec![create_test_chunk()]);
1175
1176 let mut agg = HashAggregateOperator::new(
1178 Box::new(mock),
1179 vec![0], vec![AggregateExpr::sum(1)], vec![LogicalType::Int64, LogicalType::Int64],
1182 );
1183
1184 let mut results: Vec<(i64, i64)> = Vec::new();
1185 while let Some(chunk) = agg.next().unwrap() {
1186 for row in chunk.selected_indices() {
1187 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1188 let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1189 results.push((group, sum));
1190 }
1191 }
1192
1193 results.sort_by_key(|(g, _)| *g);
1194 assert_eq!(results.len(), 2);
1195 assert_eq!(results[0], (1, 30)); assert_eq!(results[1], (2, 120)); }
1198
1199 #[test]
1200 fn test_grouped_count() {
1201 let mock = MockOperator::new(vec![create_test_chunk()]);
1202
1203 let mut agg = HashAggregateOperator::new(
1205 Box::new(mock),
1206 vec![0],
1207 vec![AggregateExpr::count_star()],
1208 vec![LogicalType::Int64, LogicalType::Int64],
1209 );
1210
1211 let mut results: Vec<(i64, i64)> = Vec::new();
1212 while let Some(chunk) = agg.next().unwrap() {
1213 for row in chunk.selected_indices() {
1214 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1215 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1216 results.push((group, count));
1217 }
1218 }
1219
1220 results.sort_by_key(|(g, _)| *g);
1221 assert_eq!(results.len(), 2);
1222 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 3)); }
1225
1226 #[test]
1227 fn test_multiple_aggregates() {
1228 let mock = MockOperator::new(vec![create_test_chunk()]);
1229
1230 let mut agg = HashAggregateOperator::new(
1232 Box::new(mock),
1233 vec![0],
1234 vec![
1235 AggregateExpr::count_star(),
1236 AggregateExpr::sum(1),
1237 AggregateExpr::avg(1),
1238 ],
1239 vec![
1240 LogicalType::Int64, LogicalType::Int64, LogicalType::Int64, LogicalType::Float64, ],
1245 );
1246
1247 let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1248 while let Some(chunk) = agg.next().unwrap() {
1249 for row in chunk.selected_indices() {
1250 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1251 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1252 let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1253 let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1254 results.push((group, count, sum, avg));
1255 }
1256 }
1257
1258 results.sort_by_key(|(g, _, _, _)| *g);
1259 assert_eq!(results.len(), 2);
1260
1261 assert_eq!(results[0].0, 1);
1263 assert_eq!(results[0].1, 2);
1264 assert_eq!(results[0].2, 30);
1265 assert!((results[0].3 - 15.0).abs() < 0.001);
1266
1267 assert_eq!(results[1].0, 2);
1269 assert_eq!(results[1].1, 3);
1270 assert_eq!(results[1].2, 120);
1271 assert!((results[1].3 - 40.0).abs() < 0.001);
1272 }
1273
1274 fn create_test_chunk_with_duplicates() -> DataChunk {
1275 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1280
1281 let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1282 for (group, value) in data {
1283 builder.column_mut(0).unwrap().push_int64(group);
1284 builder.column_mut(1).unwrap().push_int64(value);
1285 builder.advance_row();
1286 }
1287
1288 builder.finish()
1289 }
1290
1291 #[test]
1292 fn test_count_distinct() {
1293 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1294
1295 let mut agg = SimpleAggregateOperator::new(
1297 Box::new(mock),
1298 vec![AggregateExpr::count(1).with_distinct()],
1299 vec![LogicalType::Int64],
1300 );
1301
1302 let result = agg.next().unwrap().unwrap();
1303 assert_eq!(result.row_count(), 1);
1304 assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1306 }
1307
/// A distinct count is computed independently inside each group.
///
/// Grouping on column 0, group 1 has the values {10, 20} and group 2 has
/// only {30}, so the expected `(group, count)` pairs are `(1, 2)` and `(2, 1)`.
#[test]
fn test_grouped_count_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let group_columns = vec![0];
    let aggregates = vec![AggregateExpr::count(1).with_distinct()];
    let output_schema = vec![LogicalType::Int64, LogicalType::Int64];

    let mut op =
        HashAggregateOperator::new(Box::new(input), group_columns, aggregates, output_schema);

    // Drain every output chunk into (group, distinct count) pairs.
    let mut per_group: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = op.next().unwrap() {
        for row in chunk.selected_indices() {
            per_group.push((
                chunk.column(0).unwrap().get_int64(row).unwrap(),
                chunk.column(1).unwrap().get_int64(row).unwrap(),
            ));
        }
    }

    // The test sorts by group key, so it does not depend on emission order.
    per_group.sort_by_key(|&(group, _)| group);
    assert_eq!(per_group, vec![(1, 2), (2, 1)]);
}
1334
/// A distinct sum adds each distinct value exactly once.
///
/// The distinct values are 10, 20 and 30, so the expected sum is 60.
#[test]
fn test_sum_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let aggregates = vec![AggregateExpr::sum(1).with_distinct()];
    let output_schema = vec![LogicalType::Int64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let distinct_sum = out.column(0).unwrap().get_int64(0);
    assert_eq!(distinct_sum, Some(60));
}
1351
/// A distinct average is taken over the distinct values only.
///
/// The distinct values are 10, 20 and 30, so the expected mean is 20.0.
#[test]
fn test_avg_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let aggregates = vec![AggregateExpr::avg(1).with_distinct()];
    let output_schema = vec![LogicalType::Float64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let mean = out.column(0).unwrap().get_float64(0).unwrap();
    let error = (mean - 20.0).abs();
    assert!(error < 0.001);
}
1369
1370 fn create_statistical_test_chunk() -> DataChunk {
1371 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1374
1375 for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1376 builder.column_mut(0).unwrap().push_int64(value);
1377 builder.advance_row();
1378 }
1379
1380 builder.finish()
1381 }
1382
/// Sample standard deviation of the statistical fixture.
///
/// Expected value is sqrt(32 / 7), approximately 2.138.
#[test]
fn test_stdev_sample() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::stdev(0)];
    let output_schema = vec![LogicalType::Float64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let sample_stdev = out.column(0).unwrap().get_float64(0).unwrap();
    let error = (sample_stdev - 2.138).abs();
    assert!(error < 0.01);
}
1400
/// Population standard deviation of the statistical fixture.
///
/// Expected value is sqrt(32 / 8) = 2.0.
#[test]
fn test_stdev_population() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::stdev_pop(0)];
    let output_schema = vec![LogicalType::Float64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let population_stdev = out.column(0).unwrap().get_float64(0).unwrap();
    let error = (population_stdev - 2.0).abs();
    assert!(error < 0.01);
}
1418
/// Discrete median (percentile 0.5) of the statistical fixture.
///
/// The test expects 4.0 for the samples 2, 4, 4, 4, 5, 5, 7, 9.
#[test]
fn test_percentile_disc() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::percentile_disc(0, 0.5)];
    let output_schema = vec![LogicalType::Float64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let median = out.column(0).unwrap().get_float64(0).unwrap();
    let error = (median - 4.0).abs();
    assert!(error < 0.01);
}
1436
/// Continuous median (percentile 0.5) of the statistical fixture.
///
/// The test expects 4.5, the midpoint between the two middle samples 4 and 5.
#[test]
fn test_percentile_cont() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::percentile_cont(0, 0.5)];
    let output_schema = vec![LogicalType::Float64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let median = out.column(0).unwrap().get_float64(0).unwrap();
    let error = (median - 4.5).abs();
    assert!(error < 0.01);
}
1455
/// Percentiles 0.0 and 1.0 select the smallest and largest sample.
///
/// For the statistical fixture these are 2.0 and 9.0 respectively.
#[test]
fn test_percentile_extremes() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![
        AggregateExpr::percentile_disc(0, 0.0),
        AggregateExpr::percentile_disc(0, 1.0),
    ];
    let output_schema = vec![LogicalType::Float64, LogicalType::Float64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // Column 0: percentile 0.0 -> minimum sample.
    let lowest = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((lowest - 2.0).abs() < 0.01);

    // Column 1: percentile 1.0 -> maximum sample.
    let highest = out.column(1).unwrap().get_float64(0).unwrap();
    assert!((highest - 9.0).abs() < 0.01);
}
1479
/// Sample standard deviation of a single value is expected to be null.
///
/// With only one sample (42) the test requires the output cell to be
/// `Value::Null` rather than a number.
#[test]
fn test_stdev_single_value() {
    let mut single = DataChunkBuilder::new(&[LogicalType::Int64]);
    single.column_mut(0).unwrap().push_int64(42);
    single.advance_row();

    let input = MockOperator::new(vec![single.finish()]);
    let aggregates = vec![AggregateExpr::stdev(0)];
    let output_schema = vec![LogicalType::Float64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let cell = out.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1504
/// `first` and `last` over column 1 of the `create_test_chunk()` fixture.
///
/// The test expects 10 for `first` (output column 0) and 50 for `last`
/// (output column 1).
#[test]
fn test_first_and_last() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let aggregates = vec![AggregateExpr::first(1), AggregateExpr::last(1)];
    let output_schema = vec![LogicalType::Int64, LogicalType::Int64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let first_value = out.column(0).unwrap().get_int64(0);
    assert_eq!(first_value, Some(10));

    let last_value = out.column(1).unwrap().get_int64(0);
    assert_eq!(last_value, Some(50));
}
1521
/// `collect` over column 1 of the `create_test_chunk()` fixture yields a list.
///
/// Asserts that exactly one output row is produced (as the other
/// `SimpleAggregateOperator` tests in this module do) and that the collected
/// list contains five items.
#[test]
fn test_collect() {
    let input = MockOperator::new(vec![create_test_chunk()]);

    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::collect(1)],
        vec![LogicalType::Any],
    );

    let out = op.next().unwrap().unwrap();
    // Pin the single-row shape before inspecting the cell.
    assert_eq!(out.row_count(), 1);

    let collected = out.column(0).unwrap().get_value(0).unwrap();
    match collected {
        Value::List(items) => assert_eq!(items.len(), 5),
        _ => panic!("Expected List value"),
    }
}
1540
/// A distinct `collect` keeps each value only once.
///
/// The value column holds 10, 10, 20, 30, 30, 30, so the collected list must
/// have three items. Also asserts that exactly one output row is produced,
/// as the other `SimpleAggregateOperator` tests in this module do.
#[test]
fn test_collect_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);

    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::collect(1).with_distinct()],
        vec![LogicalType::Any],
    );

    let out = op.next().unwrap().unwrap();
    // Pin the single-row shape before inspecting the cell.
    assert_eq!(out.row_count(), 1);

    let collected = out.column(0).unwrap().get_value(0).unwrap();
    match collected {
        Value::List(items) => assert_eq!(items.len(), 3),
        _ => panic!("Expected List value"),
    }
}
1560
/// `GroupConcat` without an explicit separator joins the input strings.
///
/// With `separator: None` the test expects "hello", "world", "foo" to be
/// joined as `"hello world foo"`. Also asserts that exactly one output row is
/// produced, as the other `SimpleAggregateOperator` tests in this module do.
#[test]
fn test_group_concat() {
    let mut words = DataChunkBuilder::new(&[LogicalType::String]);
    for word in ["hello", "world", "foo"] {
        words.column_mut(0).unwrap().push_string(word);
        words.advance_row();
    }
    let input = MockOperator::new(vec![words.finish()]);

    // Built as a struct literal; no separator is supplied.
    let concat = AggregateExpr {
        function: AggregateFunction::GroupConcat,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut op =
        SimpleAggregateOperator::new(Box::new(input), vec![concat], vec![LogicalType::String]);

    let out = op.next().unwrap().unwrap();
    // Pin the single-row shape before inspecting the cell.
    assert_eq!(out.row_count(), 1);

    let joined = out.column(0).unwrap().get_value(0).unwrap();
    assert_eq!(joined, Value::String("hello world foo".into()));
}
1588
/// `Sample` over column 1 of the `create_test_chunk()` fixture.
///
/// The test pins the sampled value to 10.
/// NOTE(review): presumably this is the first value of the fixture's column 1;
/// confirm against `create_test_chunk()`.
/// Also asserts that exactly one output row is produced, as the other
/// `SimpleAggregateOperator` tests in this module do.
#[test]
fn test_sample() {
    let input = MockOperator::new(vec![create_test_chunk()]);

    let sample = AggregateExpr {
        function: AggregateFunction::Sample,
        column: Some(1),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut op =
        SimpleAggregateOperator::new(Box::new(input), vec![sample], vec![LogicalType::Int64]);

    let out = op.next().unwrap().unwrap();
    // Pin the single-row shape before inspecting the cell.
    assert_eq!(out.row_count(), 1);

    let sampled = out.column(0).unwrap().get_int64(0);
    assert_eq!(sampled, Some(10));
}
1610
/// Sample variance of the statistical fixture.
///
/// Expected value is 32 / 7 (sum of squared deviations over n - 1). Also
/// asserts that exactly one output row is produced, matching the parallel
/// `test_stdev_sample` test.
#[test]
fn test_variance_sample() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    // Pin the single-row shape before inspecting the cell.
    assert_eq!(out.row_count(), 1);

    let variance = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((variance - 32.0 / 7.0).abs() < 0.01);
}
1636
/// Population variance of the statistical fixture.
///
/// Expected value is 32 / 8 = 4.0. Also asserts that exactly one output row
/// is produced, matching the parallel `test_stdev_population` test.
#[test]
fn test_variance_population() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::VariancePop,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    // Pin the single-row shape before inspecting the cell.
    assert_eq!(out.row_count(), 1);

    let variance = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((variance - 4.0).abs() < 0.01);
}
1662
/// Sample variance of a single value is expected to be null.
///
/// With only one sample (42) the test requires the output cell to be
/// `Value::Null`. Also asserts that exactly one output row is produced,
/// matching the parallel `test_stdev_single_value` test.
#[test]
fn test_variance_single_value() {
    let mut single = DataChunkBuilder::new(&[LogicalType::Int64]);
    single.column_mut(0).unwrap().push_int64(42);
    single.advance_row();
    let input = MockOperator::new(vec![single.finish()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    // Pin the single-row shape before inspecting the cell.
    assert_eq!(out.row_count(), 1);

    let cell = out.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1694
/// Aggregating an input that yields no chunks still produces a result.
///
/// The test expects `count_star` to be 0 and `sum`, `avg`, `min` and `max`
/// to all be `Value::Null`.
#[test]
fn test_empty_aggregation() {
    let input = MockOperator::new(vec![]);

    let aggregates = vec![
        AggregateExpr::count_star(),
        AggregateExpr::sum(0),
        AggregateExpr::avg(0),
        AggregateExpr::min(0),
        AggregateExpr::max(0),
    ];
    let output_schema = vec![
        LogicalType::Int64,
        LogicalType::Int64,
        LogicalType::Float64,
        LogicalType::Int64,
        LogicalType::Int64,
    ];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();

    // Column 0 is the row count, which is zero for empty input.
    assert_eq!(out.column(0).unwrap().get_int64(0), Some(0));

    // Columns 1..=4 are sum, avg, min and max, in that order.
    for col in 1..5 {
        assert!(matches!(
            out.column(col).unwrap().get_value(0),
            Some(Value::Null)
        ));
    }
}
1738
/// Population standard deviation of a single value is zero.
///
/// Unlike the sample variant, the test expects a numeric 0.0 (not null) for
/// the single sample 42.
#[test]
fn test_stdev_pop_single_value() {
    let mut single = DataChunkBuilder::new(&[LogicalType::Int64]);
    single.column_mut(0).unwrap().push_int64(42);
    single.advance_row();

    let input = MockOperator::new(vec![single.finish()]);
    let aggregates = vec![AggregateExpr::stdev_pop(0)];
    let output_schema = vec![LogicalType::Float64];

    let mut op = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let population_stdev = out.column(0).unwrap().get_float64(0).unwrap();
    assert!(population_stdev.abs() < 0.01);
}
1761}