1use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, PropertyKey, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
22#[derive(Debug, Clone)]
24pub(crate) enum AggregateState {
25 Count(i64),
27 CountDistinct(i64, HashSet<HashableValue>),
29 SumInt(i64, i64),
31 SumIntDistinct(i64, i64, HashSet<HashableValue>),
33 SumFloat(f64, i64),
35 SumFloatDistinct(f64, i64, HashSet<HashableValue>),
37 Avg(f64, i64),
39 AvgDistinct(f64, i64, HashSet<HashableValue>),
41 Min(Option<Value>),
43 Max(Option<Value>),
45 First(Option<Value>),
47 Last(Option<Value>),
49 Collect(Vec<Value>),
51 CollectDistinct(Vec<Value>, HashSet<HashableValue>),
53 StdDev { count: i64, mean: f64, m2: f64 },
55 StdDevPop { count: i64, mean: f64, m2: f64 },
57 PercentileDisc { values: Vec<f64>, percentile: f64 },
59 PercentileCont { values: Vec<f64>, percentile: f64 },
61 GroupConcat(Vec<String>, String),
63 GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
65 Sample(Option<Value>),
67 Variance { count: i64, mean: f64, m2: f64 },
69 VariancePop { count: i64, mean: f64, m2: f64 },
71 Bivariate {
73 kind: AggregateFunction,
75 count: i64,
76 mean_x: f64,
77 mean_y: f64,
78 m2_x: f64,
79 m2_y: f64,
80 c_xy: f64,
81 },
82}
83
84impl AggregateState {
85 pub(crate) fn new(
87 function: AggregateFunction,
88 distinct: bool,
89 percentile: Option<f64>,
90 separator: Option<&str>,
91 ) -> Self {
92 match (function, distinct) {
93 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
94 AggregateState::Count(0)
95 }
96 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
97 AggregateState::CountDistinct(0, HashSet::new())
98 }
99 (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
100 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
101 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
102 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
103 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
105 (AggregateFunction::First, _) => AggregateState::First(None),
106 (AggregateFunction::Last, _) => AggregateState::Last(None),
107 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
108 (AggregateFunction::Collect, true) => {
109 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
110 }
111 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
113 count: 0,
114 mean: 0.0,
115 m2: 0.0,
116 },
117 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
118 count: 0,
119 mean: 0.0,
120 m2: 0.0,
121 },
122 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
123 values: Vec::new(),
124 percentile: percentile.unwrap_or(0.5),
125 },
126 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
127 values: Vec::new(),
128 percentile: percentile.unwrap_or(0.5),
129 },
130 (AggregateFunction::GroupConcat, false) => {
131 AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
132 }
133 (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
134 Vec::new(),
135 separator.unwrap_or(" ").to_string(),
136 HashSet::new(),
137 ),
138 (AggregateFunction::Sample, _) => AggregateState::Sample(None),
139 (
141 AggregateFunction::CovarSamp
142 | AggregateFunction::CovarPop
143 | AggregateFunction::Corr
144 | AggregateFunction::RegrSlope
145 | AggregateFunction::RegrIntercept
146 | AggregateFunction::RegrR2
147 | AggregateFunction::RegrCount
148 | AggregateFunction::RegrSxx
149 | AggregateFunction::RegrSyy
150 | AggregateFunction::RegrSxy
151 | AggregateFunction::RegrAvgx
152 | AggregateFunction::RegrAvgy,
153 _,
154 ) => AggregateState::Bivariate {
155 kind: function,
156 count: 0,
157 mean_x: 0.0,
158 mean_y: 0.0,
159 m2_x: 0.0,
160 m2_y: 0.0,
161 c_xy: 0.0,
162 },
163 (AggregateFunction::Variance, _) => AggregateState::Variance {
164 count: 0,
165 mean: 0.0,
166 m2: 0.0,
167 },
168 (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
169 count: 0,
170 mean: 0.0,
171 m2: 0.0,
172 },
173 }
174 }
175
176 pub(crate) fn update(&mut self, value: Option<Value>) {
178 match self {
179 AggregateState::Count(count) => {
180 *count += 1;
181 }
182 AggregateState::CountDistinct(count, seen) => {
183 if let Some(ref v) = value {
184 let hashable = HashableValue::from(v);
185 if seen.insert(hashable) {
186 *count += 1;
187 }
188 }
189 }
190 AggregateState::SumInt(sum, count) => {
191 if let Some(Value::Int64(v)) = value {
192 *sum += v;
193 *count += 1;
194 } else if let Some(Value::Float64(v)) = value {
195 *self = AggregateState::SumFloat(*sum as f64 + v, *count + 1);
197 } else if let Some(ref v) = value {
198 if let Some(num) = value_to_f64(v) {
200 *self = AggregateState::SumFloat(*sum as f64 + num, *count + 1);
201 }
202 }
203 }
204 AggregateState::SumIntDistinct(sum, count, seen) => {
205 if let Some(ref v) = value {
206 let hashable = HashableValue::from(v);
207 if seen.insert(hashable) {
208 if let Value::Int64(i) = v {
209 *sum += i;
210 *count += 1;
211 } else if let Value::Float64(f) = v {
212 let moved_seen = std::mem::take(seen);
214 *self = AggregateState::SumFloatDistinct(
215 *sum as f64 + f,
216 *count + 1,
217 moved_seen,
218 );
219 } else if let Some(num) = value_to_f64(v) {
220 let moved_seen = std::mem::take(seen);
222 *self = AggregateState::SumFloatDistinct(
223 *sum as f64 + num,
224 *count + 1,
225 moved_seen,
226 );
227 }
228 }
229 }
230 }
231 AggregateState::SumFloat(sum, count) => {
232 if let Some(ref v) = value {
233 if let Some(num) = value_to_f64(v) {
235 *sum += num;
236 *count += 1;
237 }
238 }
239 }
240 AggregateState::SumFloatDistinct(sum, count, seen) => {
241 if let Some(ref v) = value {
242 let hashable = HashableValue::from(v);
243 if seen.insert(hashable)
244 && let Some(num) = value_to_f64(v)
245 {
246 *sum += num;
247 *count += 1;
248 }
249 }
250 }
251 AggregateState::Avg(sum, count) => {
252 if let Some(ref v) = value
253 && let Some(num) = value_to_f64(v)
254 {
255 *sum += num;
256 *count += 1;
257 }
258 }
259 AggregateState::AvgDistinct(sum, count, seen) => {
260 if let Some(ref v) = value {
261 let hashable = HashableValue::from(v);
262 if seen.insert(hashable)
263 && let Some(num) = value_to_f64(v)
264 {
265 *sum += num;
266 *count += 1;
267 }
268 }
269 }
270 AggregateState::Min(min) => {
271 if let Some(v) = value {
272 match min {
273 None => *min = Some(v),
274 Some(current) => {
275 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
276 *min = Some(v);
277 }
278 }
279 }
280 }
281 }
282 AggregateState::Max(max) => {
283 if let Some(v) = value {
284 match max {
285 None => *max = Some(v),
286 Some(current) => {
287 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
288 *max = Some(v);
289 }
290 }
291 }
292 }
293 }
294 AggregateState::First(first) => {
295 if first.is_none() {
296 *first = value;
297 }
298 }
299 AggregateState::Last(last) => {
300 if value.is_some() {
301 *last = value;
302 }
303 }
304 AggregateState::Collect(list) => {
305 if let Some(v) = value {
306 list.push(v);
307 }
308 }
309 AggregateState::CollectDistinct(list, seen) => {
310 if let Some(v) = value {
311 let hashable = HashableValue::from(&v);
312 if seen.insert(hashable) {
313 list.push(v);
314 }
315 }
316 }
317 AggregateState::StdDev { count, mean, m2 }
319 | AggregateState::StdDevPop { count, mean, m2 }
320 | AggregateState::Variance { count, mean, m2 }
321 | AggregateState::VariancePop { count, mean, m2 } => {
322 if let Some(ref v) = value
323 && let Some(x) = value_to_f64(v)
324 {
325 *count += 1;
326 let delta = x - *mean;
327 *mean += delta / *count as f64;
328 let delta2 = x - *mean;
329 *m2 += delta * delta2;
330 }
331 }
332 AggregateState::PercentileDisc { values, .. }
333 | AggregateState::PercentileCont { values, .. } => {
334 if let Some(ref v) = value
335 && let Some(x) = value_to_f64(v)
336 {
337 values.push(x);
338 }
339 }
340 AggregateState::GroupConcat(list, _sep) => {
341 if let Some(v) = value {
342 list.push(agg_value_to_string(&v));
343 }
344 }
345 AggregateState::GroupConcatDistinct(list, _sep, seen) => {
346 if let Some(v) = value {
347 let hashable = HashableValue::from(&v);
348 if seen.insert(hashable) {
349 list.push(agg_value_to_string(&v));
350 }
351 }
352 }
353 AggregateState::Sample(sample) => {
354 if sample.is_none() {
355 *sample = value;
356 }
357 }
358 AggregateState::Bivariate { .. } => {
359 }
362 }
363 }
364
365 fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
370 if let AggregateState::Bivariate {
371 count,
372 mean_x,
373 mean_y,
374 m2_x,
375 m2_y,
376 c_xy,
377 ..
378 } = self
379 {
380 if let (Some(y), Some(x)) = (&y_val, &x_val)
382 && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
383 {
384 *count += 1;
385 let n = *count as f64;
386 let dx = x_f - *mean_x;
387 let dy = y_f - *mean_y;
388 *mean_x += dx / n;
389 *mean_y += dy / n;
390 let dx2 = x_f - *mean_x; let dy2 = y_f - *mean_y; *m2_x += dx * dx2;
393 *m2_y += dy * dy2;
394 *c_xy += dx * dy2;
395 }
396 }
397 }
398
399 pub(crate) fn finalize(&self) -> Value {
401 match self {
402 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
403 Value::Int64(*count)
404 }
405 AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
406 if *count == 0 {
407 Value::Null
408 } else {
409 Value::Int64(*sum)
410 }
411 }
412 AggregateState::SumFloat(sum, count)
413 | AggregateState::SumFloatDistinct(sum, count, _) => {
414 if *count == 0 {
415 Value::Null
416 } else {
417 Value::Float64(*sum)
418 }
419 }
420 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
421 if *count == 0 {
422 Value::Null
423 } else {
424 Value::Float64(*sum / *count as f64)
425 }
426 }
427 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
428 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
429 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
430 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
431 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
432 Value::List(list.clone().into())
433 }
434 AggregateState::StdDev { count, m2, .. } => {
436 if *count < 2 {
437 Value::Null
438 } else {
439 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
440 }
441 }
442 AggregateState::StdDevPop { count, m2, .. } => {
444 if *count == 0 {
445 Value::Null
446 } else {
447 Value::Float64((*m2 / *count as f64).sqrt())
448 }
449 }
450 AggregateState::Variance { count, m2, .. } => {
452 if *count < 2 {
453 Value::Null
454 } else {
455 Value::Float64(*m2 / (*count - 1) as f64)
456 }
457 }
458 AggregateState::VariancePop { count, m2, .. } => {
460 if *count == 0 {
461 Value::Null
462 } else {
463 Value::Float64(*m2 / *count as f64)
464 }
465 }
466 AggregateState::PercentileDisc { values, percentile } => {
468 if values.is_empty() {
469 Value::Null
470 } else {
471 let mut sorted = values.clone();
472 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
473 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
475 Value::Float64(sorted[index])
476 }
477 }
478 AggregateState::PercentileCont { values, percentile } => {
480 if values.is_empty() {
481 Value::Null
482 } else {
483 let mut sorted = values.clone();
484 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
485 let rank = percentile * (sorted.len() - 1) as f64;
487 let lower_idx = rank.floor() as usize;
488 let upper_idx = rank.ceil() as usize;
489 if lower_idx == upper_idx {
490 Value::Float64(sorted[lower_idx])
491 } else {
492 let fraction = rank - lower_idx as f64;
493 let result =
494 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
495 Value::Float64(result)
496 }
497 }
498 }
499 AggregateState::GroupConcat(list, sep)
501 | AggregateState::GroupConcatDistinct(list, sep, _) => {
502 Value::String(list.join(sep).into())
503 }
504 AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
506 AggregateState::Bivariate {
508 kind,
509 count,
510 mean_x,
511 mean_y,
512 m2_x,
513 m2_y,
514 c_xy,
515 } => {
516 let n = *count;
517 match kind {
518 AggregateFunction::CovarSamp => {
519 if n < 2 {
520 Value::Null
521 } else {
522 Value::Float64(*c_xy / (n - 1) as f64)
523 }
524 }
525 AggregateFunction::CovarPop => {
526 if n == 0 {
527 Value::Null
528 } else {
529 Value::Float64(*c_xy / n as f64)
530 }
531 }
532 AggregateFunction::Corr => {
533 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
534 Value::Null
535 } else {
536 Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
537 }
538 }
539 AggregateFunction::RegrSlope => {
540 if n == 0 || *m2_x == 0.0 {
541 Value::Null
542 } else {
543 Value::Float64(*c_xy / *m2_x)
544 }
545 }
546 AggregateFunction::RegrIntercept => {
547 if n == 0 || *m2_x == 0.0 {
548 Value::Null
549 } else {
550 let slope = *c_xy / *m2_x;
551 Value::Float64(*mean_y - slope * *mean_x)
552 }
553 }
554 AggregateFunction::RegrR2 => {
555 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
556 Value::Null
557 } else {
558 Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
559 }
560 }
561 AggregateFunction::RegrCount => Value::Int64(n),
562 AggregateFunction::RegrSxx => {
563 if n == 0 {
564 Value::Null
565 } else {
566 Value::Float64(*m2_x)
567 }
568 }
569 AggregateFunction::RegrSyy => {
570 if n == 0 {
571 Value::Null
572 } else {
573 Value::Float64(*m2_y)
574 }
575 }
576 AggregateFunction::RegrSxy => {
577 if n == 0 {
578 Value::Null
579 } else {
580 Value::Float64(*c_xy)
581 }
582 }
583 AggregateFunction::RegrAvgx => {
584 if n == 0 {
585 Value::Null
586 } else {
587 Value::Float64(*mean_x)
588 }
589 }
590 AggregateFunction::RegrAvgy => {
591 if n == 0 {
592 Value::Null
593 } else {
594 Value::Float64(*mean_y)
595 }
596 }
597 _ => Value::Null, }
599 }
600 }
601 }
602}
603
604use super::value_utils::{compare_values, value_to_f64};
605
606fn agg_value_to_string(val: &Value) -> String {
608 match val {
609 Value::String(s) => s.to_string(),
610 Value::Int64(i) => i.to_string(),
611 Value::Float64(f) => f.to_string(),
612 Value::Bool(b) => b.to_string(),
613 Value::Null => String::new(),
614 other => format!("{other:?}"),
615 }
616}
617
618#[derive(Debug, Clone, PartialEq, Eq, Hash)]
620pub struct GroupKey(Vec<GroupKeyPart>);
621
622#[derive(Debug, Clone, PartialEq, Eq, Hash)]
623enum GroupKeyPart {
624 Null,
625 Bool(bool),
626 Int64(i64),
627 String(ArcStr),
628 Bytes(Arc<[u8]>),
629 Date(grafeo_common::types::Date),
630 Time(grafeo_common::types::Time),
631 Timestamp(grafeo_common::types::Timestamp),
632 Duration(grafeo_common::types::Duration),
633 ZonedDatetime(grafeo_common::types::ZonedDatetime),
634 List(Vec<GroupKeyPart>),
635 Map(Vec<(ArcStr, GroupKeyPart)>),
636}
637
638impl GroupKeyPart {
639 fn from_value(v: Value) -> Self {
640 match v {
641 Value::Null => Self::Null,
642 Value::Bool(b) => Self::Bool(b),
643 Value::Int64(i) => Self::Int64(i),
644 Value::Float64(f) => Self::Int64(f.to_bits() as i64),
645 Value::String(s) => Self::String(s.clone()),
646 Value::Bytes(b) => Self::Bytes(b),
647 Value::Date(d) => Self::Date(d),
648 Value::Time(t) => Self::Time(t),
649 Value::Timestamp(ts) => Self::Timestamp(ts),
650 Value::Duration(d) => Self::Duration(d),
651 Value::ZonedDatetime(zdt) => Self::ZonedDatetime(zdt),
652 Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
653 Value::Map(map) => {
654 let entries: Vec<(ArcStr, GroupKeyPart)> = map
656 .iter()
657 .map(|(k, v)| (ArcStr::from(k.as_str()), Self::from_value(v.clone())))
658 .collect();
659 Self::Map(entries)
660 }
661 other => Self::String(ArcStr::from(format!("{other:?}"))),
663 }
664 }
665
666 fn to_value(&self) -> Value {
667 match self {
668 Self::Null => Value::Null,
669 Self::Bool(b) => Value::Bool(*b),
670 Self::Int64(i) => Value::Int64(*i),
671 Self::String(s) => Value::String(s.clone()),
672 Self::Bytes(b) => Value::Bytes(Arc::clone(b)),
673 Self::Date(d) => Value::Date(*d),
674 Self::Time(t) => Value::Time(*t),
675 Self::Timestamp(ts) => Value::Timestamp(*ts),
676 Self::Duration(d) => Value::Duration(*d),
677 Self::ZonedDatetime(zdt) => Value::ZonedDatetime(*zdt),
678 Self::List(parts) => {
679 let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
680 Value::List(Arc::from(values.into_boxed_slice()))
681 }
682 Self::Map(entries) => {
683 let map: std::collections::BTreeMap<PropertyKey, Value> = entries
684 .iter()
685 .map(|(k, v)| (PropertyKey::new(k.as_str()), v.to_value()))
686 .collect();
687 Value::Map(Arc::new(map))
688 }
689 }
690 }
691}
692
693impl GroupKey {
694 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
696 let parts: Vec<GroupKeyPart> = group_columns
697 .iter()
698 .map(|&col_idx| {
699 chunk
700 .column(col_idx)
701 .and_then(|col| col.get_value(row))
702 .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
703 })
704 .collect();
705 GroupKey(parts)
706 }
707
708 fn to_values(&self) -> Vec<Value> {
710 self.0.iter().map(GroupKeyPart::to_value).collect()
711 }
712}
713
/// Grouped aggregation backed by a hash map.
///
/// The first `next` call drains the child completely, keeping one set of aggregate states per
/// distinct group key. Output rows carry the key columns followed by the finalized aggregates.
pub struct HashAggregateOperator {
    /// Input operator.
    child: Box<dyn Operator>,
    /// Indices of the input columns that form the group key.
    group_columns: Vec<usize>,
    /// Aggregates computed for every group, in output order.
    aggregates: Vec<AggregateExpr>,
    /// Output column types: group key columns first, then one column per aggregate.
    output_schema: Vec<LogicalType>,
    /// Per-group accumulators; `IndexMap` keeps groups in first-seen (insertion) order.
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Set once the child has been fully consumed.
    aggregation_complete: bool,
    /// Groups drained from `groups`, waiting to be emitted.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
733
734impl HashAggregateOperator {
735 pub fn new(
743 child: Box<dyn Operator>,
744 group_columns: Vec<usize>,
745 aggregates: Vec<AggregateExpr>,
746 output_schema: Vec<LogicalType>,
747 ) -> Self {
748 Self {
749 child,
750 group_columns,
751 aggregates,
752 output_schema,
753 groups: IndexMap::new(),
754 aggregation_complete: false,
755 results: None,
756 }
757 }
758
759 fn aggregate(&mut self) -> Result<(), OperatorError> {
761 while let Some(chunk) = self.child.next()? {
762 for row in chunk.selected_indices() {
763 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
764
765 let states = self.groups.entry(key).or_insert_with(|| {
767 self.aggregates
768 .iter()
769 .map(|agg| {
770 AggregateState::new(
771 agg.function,
772 agg.distinct,
773 agg.percentile,
774 agg.separator.as_deref(),
775 )
776 })
777 .collect()
778 });
779
780 for (i, agg) in self.aggregates.iter().enumerate() {
782 if agg.column2.is_some() {
784 let y_val = agg
785 .column
786 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
787 let x_val = agg
788 .column2
789 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
790 states[i].update_bivariate(y_val, x_val);
791 continue;
792 }
793
794 let value = match (agg.function, agg.distinct) {
795 (AggregateFunction::Count, false) => None,
797 (AggregateFunction::Count, true) => agg
799 .column
800 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
801 _ => agg
802 .column
803 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
804 };
805
806 match (agg.function, agg.distinct) {
808 (AggregateFunction::Count, false) => states[i].update(None),
809 (AggregateFunction::Count, true) => {
810 if value.is_some() && !matches!(value, Some(Value::Null)) {
812 states[i].update(value);
813 }
814 }
815 (AggregateFunction::CountNonNull, _) => {
816 if value.is_some() && !matches!(value, Some(Value::Null)) {
817 states[i].update(value);
818 }
819 }
820 _ => {
821 if value.is_some() && !matches!(value, Some(Value::Null)) {
822 states[i].update(value);
823 }
824 }
825 }
826 }
827 }
828 }
829
830 self.aggregation_complete = true;
831
832 let results: Vec<_> = self.groups.drain(..).collect();
834 self.results = Some(results.into_iter());
835
836 Ok(())
837 }
838}
839
840impl Operator for HashAggregateOperator {
841 fn next(&mut self) -> OperatorResult {
842 if !self.aggregation_complete {
844 self.aggregate()?;
845 }
846
847 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
849 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
851
852 for agg in &self.aggregates {
853 let state = AggregateState::new(
854 agg.function,
855 agg.distinct,
856 agg.percentile,
857 agg.separator.as_deref(),
858 );
859 let value = state.finalize();
860 if let Some(col) = builder.column_mut(self.group_columns.len()) {
861 col.push_value(value);
862 }
863 }
864 builder.advance_row();
865
866 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
868 }
869
870 let Some(results) = &mut self.results else {
871 return Ok(None);
872 };
873
874 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
875
876 for (key, states) in results.by_ref() {
877 let key_values = key.to_values();
879 for (i, value) in key_values.into_iter().enumerate() {
880 if let Some(col) = builder.column_mut(i) {
881 col.push_value(value);
882 }
883 }
884
885 for (i, state) in states.iter().enumerate() {
887 let col_idx = self.group_columns.len() + i;
888 if let Some(col) = builder.column_mut(col_idx) {
889 col.push_value(state.finalize());
890 }
891 }
892
893 builder.advance_row();
894
895 if builder.is_full() {
896 return Ok(Some(builder.finish()));
897 }
898 }
899
900 if builder.row_count() > 0 {
901 Ok(Some(builder.finish()))
902 } else {
903 Ok(None)
904 }
905 }
906
907 fn reset(&mut self) {
908 self.child.reset();
909 self.groups.clear();
910 self.aggregation_complete = false;
911 self.results = None;
912 }
913
914 fn name(&self) -> &'static str {
915 "HashAggregate"
916 }
917}
918
/// Ungrouped aggregation: folds the entire input into a single output row.
pub struct SimpleAggregateOperator {
    /// Input operator.
    child: Box<dyn Operator>,
    /// Aggregates to compute, in output column order.
    aggregates: Vec<AggregateExpr>,
    /// Output column types, one per aggregate.
    output_schema: Vec<LogicalType>,
    /// One accumulator per entry in `aggregates`.
    states: Vec<AggregateState>,
    /// Set after the single result row has been returned.
    done: bool,
}
934
935impl SimpleAggregateOperator {
936 pub fn new(
938 child: Box<dyn Operator>,
939 aggregates: Vec<AggregateExpr>,
940 output_schema: Vec<LogicalType>,
941 ) -> Self {
942 let states = aggregates
943 .iter()
944 .map(|agg| {
945 AggregateState::new(
946 agg.function,
947 agg.distinct,
948 agg.percentile,
949 agg.separator.as_deref(),
950 )
951 })
952 .collect();
953
954 Self {
955 child,
956 aggregates,
957 output_schema,
958 states,
959 done: false,
960 }
961 }
962}
963
964impl Operator for SimpleAggregateOperator {
965 fn next(&mut self) -> OperatorResult {
966 if self.done {
967 return Ok(None);
968 }
969
970 while let Some(chunk) = self.child.next()? {
972 for row in chunk.selected_indices() {
973 for (i, agg) in self.aggregates.iter().enumerate() {
974 if agg.column2.is_some() {
976 let y_val = agg
977 .column
978 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
979 let x_val = agg
980 .column2
981 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
982 self.states[i].update_bivariate(y_val, x_val);
983 continue;
984 }
985
986 let value = match (agg.function, agg.distinct) {
987 (AggregateFunction::Count, false) => None,
989 (AggregateFunction::Count, true) => agg
991 .column
992 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
993 _ => agg
994 .column
995 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
996 };
997
998 match (agg.function, agg.distinct) {
999 (AggregateFunction::Count, false) => self.states[i].update(None),
1000 (AggregateFunction::Count, true) => {
1001 if value.is_some() && !matches!(value, Some(Value::Null)) {
1003 self.states[i].update(value);
1004 }
1005 }
1006 (AggregateFunction::CountNonNull, _) => {
1007 if value.is_some() && !matches!(value, Some(Value::Null)) {
1008 self.states[i].update(value);
1009 }
1010 }
1011 _ => {
1012 if value.is_some() && !matches!(value, Some(Value::Null)) {
1013 self.states[i].update(value);
1014 }
1015 }
1016 }
1017 }
1018 }
1019 }
1020
1021 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
1023
1024 for (i, state) in self.states.iter().enumerate() {
1025 if let Some(col) = builder.column_mut(i) {
1026 col.push_value(state.finalize());
1027 }
1028 }
1029 builder.advance_row();
1030
1031 self.done = true;
1032 Ok(Some(builder.finish()))
1033 }
1034
1035 fn reset(&mut self) {
1036 self.child.reset();
1037 self.states = self
1038 .aggregates
1039 .iter()
1040 .map(|agg| {
1041 AggregateState::new(
1042 agg.function,
1043 agg.distinct,
1044 agg.percentile,
1045 agg.separator.as_deref(),
1046 )
1047 })
1048 .collect();
1049 self.done = false;
1050 }
1051
1052 fn name(&self) -> &'static str {
1053 "SimpleAggregate"
1054 }
1055}
1056
1057#[cfg(test)]
1058mod tests {
1059 use super::*;
1060 use crate::execution::chunk::DataChunkBuilder;
1061
1062 struct MockOperator {
1063 chunks: Vec<DataChunk>,
1064 position: usize,
1065 }
1066
1067 impl MockOperator {
1068 fn new(chunks: Vec<DataChunk>) -> Self {
1069 Self {
1070 chunks,
1071 position: 0,
1072 }
1073 }
1074 }
1075
1076 impl Operator for MockOperator {
1077 fn next(&mut self) -> OperatorResult {
1078 if self.position < self.chunks.len() {
1079 let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1080 self.position += 1;
1081 Ok(Some(chunk))
1082 } else {
1083 Ok(None)
1084 }
1085 }
1086
1087 fn reset(&mut self) {
1088 self.position = 0;
1089 }
1090
1091 fn name(&self) -> &'static str {
1092 "Mock"
1093 }
1094 }
1095
1096 fn create_test_chunk() -> DataChunk {
1097 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1099
1100 let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1101 for (group, value) in data {
1102 builder.column_mut(0).unwrap().push_int64(group);
1103 builder.column_mut(1).unwrap().push_int64(value);
1104 builder.advance_row();
1105 }
1106
1107 builder.finish()
1108 }
1109
1110 #[test]
1111 fn test_simple_count() {
1112 let mock = MockOperator::new(vec![create_test_chunk()]);
1113
1114 let mut agg = SimpleAggregateOperator::new(
1115 Box::new(mock),
1116 vec![AggregateExpr::count_star()],
1117 vec![LogicalType::Int64],
1118 );
1119
1120 let result = agg.next().unwrap().unwrap();
1121 assert_eq!(result.row_count(), 1);
1122 assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1123
1124 assert!(agg.next().unwrap().is_none());
1126 }
1127
1128 #[test]
1129 fn test_simple_sum() {
1130 let mock = MockOperator::new(vec![create_test_chunk()]);
1131
1132 let mut agg = SimpleAggregateOperator::new(
1133 Box::new(mock),
1134 vec![AggregateExpr::sum(1)], vec![LogicalType::Int64],
1136 );
1137
1138 let result = agg.next().unwrap().unwrap();
1139 assert_eq!(result.row_count(), 1);
1140 assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1142 }
1143
1144 #[test]
1145 fn test_simple_avg() {
1146 let mock = MockOperator::new(vec![create_test_chunk()]);
1147
1148 let mut agg = SimpleAggregateOperator::new(
1149 Box::new(mock),
1150 vec![AggregateExpr::avg(1)],
1151 vec![LogicalType::Float64],
1152 );
1153
1154 let result = agg.next().unwrap().unwrap();
1155 assert_eq!(result.row_count(), 1);
1156 let avg = result.column(0).unwrap().get_float64(0).unwrap();
1158 assert!((avg - 30.0).abs() < 0.001);
1159 }
1160
1161 #[test]
1162 fn test_simple_min_max() {
1163 let mock = MockOperator::new(vec![create_test_chunk()]);
1164
1165 let mut agg = SimpleAggregateOperator::new(
1166 Box::new(mock),
1167 vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1168 vec![LogicalType::Int64, LogicalType::Int64],
1169 );
1170
1171 let result = agg.next().unwrap().unwrap();
1172 assert_eq!(result.row_count(), 1);
1173 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); }
1176
1177 #[test]
1178 fn test_sum_with_string_values() {
1179 let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1181 builder.column_mut(0).unwrap().push_string("30");
1182 builder.advance_row();
1183 builder.column_mut(0).unwrap().push_string("25");
1184 builder.advance_row();
1185 builder.column_mut(0).unwrap().push_string("35");
1186 builder.advance_row();
1187 let chunk = builder.finish();
1188
1189 let mock = MockOperator::new(vec![chunk]);
1190 let mut agg = SimpleAggregateOperator::new(
1191 Box::new(mock),
1192 vec![AggregateExpr::sum(0)],
1193 vec![LogicalType::Float64],
1194 );
1195
1196 let result = agg.next().unwrap().unwrap();
1197 assert_eq!(result.row_count(), 1);
1198 let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1200 assert!(
1201 (sum_val - 90.0).abs() < 0.001,
1202 "Expected 90.0, got {}",
1203 sum_val
1204 );
1205 }
1206
1207 #[test]
1208 fn test_grouped_aggregation() {
1209 let mock = MockOperator::new(vec![create_test_chunk()]);
1210
1211 let mut agg = HashAggregateOperator::new(
1213 Box::new(mock),
1214 vec![0], vec![AggregateExpr::sum(1)], vec![LogicalType::Int64, LogicalType::Int64],
1217 );
1218
1219 let mut results: Vec<(i64, i64)> = Vec::new();
1220 while let Some(chunk) = agg.next().unwrap() {
1221 for row in chunk.selected_indices() {
1222 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1223 let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1224 results.push((group, sum));
1225 }
1226 }
1227
1228 results.sort_by_key(|(g, _)| *g);
1229 assert_eq!(results.len(), 2);
1230 assert_eq!(results[0], (1, 30)); assert_eq!(results[1], (2, 120)); }
1233
1234 #[test]
1235 fn test_grouped_count() {
1236 let mock = MockOperator::new(vec![create_test_chunk()]);
1237
1238 let mut agg = HashAggregateOperator::new(
1240 Box::new(mock),
1241 vec![0],
1242 vec![AggregateExpr::count_star()],
1243 vec![LogicalType::Int64, LogicalType::Int64],
1244 );
1245
1246 let mut results: Vec<(i64, i64)> = Vec::new();
1247 while let Some(chunk) = agg.next().unwrap() {
1248 for row in chunk.selected_indices() {
1249 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1250 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1251 results.push((group, count));
1252 }
1253 }
1254
1255 results.sort_by_key(|(g, _)| *g);
1256 assert_eq!(results.len(), 2);
1257 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 3)); }
1260
1261 #[test]
1262 fn test_multiple_aggregates() {
1263 let mock = MockOperator::new(vec![create_test_chunk()]);
1264
1265 let mut agg = HashAggregateOperator::new(
1267 Box::new(mock),
1268 vec![0],
1269 vec![
1270 AggregateExpr::count_star(),
1271 AggregateExpr::sum(1),
1272 AggregateExpr::avg(1),
1273 ],
1274 vec![
1275 LogicalType::Int64, LogicalType::Int64, LogicalType::Int64, LogicalType::Float64, ],
1280 );
1281
1282 let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1283 while let Some(chunk) = agg.next().unwrap() {
1284 for row in chunk.selected_indices() {
1285 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1286 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1287 let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1288 let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1289 results.push((group, count, sum, avg));
1290 }
1291 }
1292
1293 results.sort_by_key(|(g, _, _, _)| *g);
1294 assert_eq!(results.len(), 2);
1295
1296 assert_eq!(results[0].0, 1);
1298 assert_eq!(results[0].1, 2);
1299 assert_eq!(results[0].2, 30);
1300 assert!((results[0].3 - 15.0).abs() < 0.001);
1301
1302 assert_eq!(results[1].0, 2);
1304 assert_eq!(results[1].1, 3);
1305 assert_eq!(results[1].2, 120);
1306 assert!((results[1].3 - 40.0).abs() < 0.001);
1307 }
1308
1309 fn create_test_chunk_with_duplicates() -> DataChunk {
1310 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1315
1316 let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1317 for (group, value) in data {
1318 builder.column_mut(0).unwrap().push_int64(group);
1319 builder.column_mut(1).unwrap().push_int64(value);
1320 builder.advance_row();
1321 }
1322
1323 builder.finish()
1324 }
1325
/// Ungrouped COUNT(DISTINCT col1).
///
/// The values 10, 10, 20, 30, 30, 30 collapse to three distinct entries.
#[test]
fn test_count_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_count = AggregateExpr::count(1).with_distinct();
    let mut agg =
        SimpleAggregateOperator::new(Box::new(input), vec![distinct_count], vec![LogicalType::Int64]);

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let counted = output.column(0).unwrap().get_int64(0);
    assert_eq!(counted, Some(3));
}
1342
/// COUNT(DISTINCT col1) grouped by col0.
///
/// Group 1 (10, 10, 20) yields 2; group 2 (30, 30, 30) yields 1.
#[test]
fn test_grouped_count_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut agg = HashAggregateOperator::new(
        Box::new(input),
        vec![0],
        vec![AggregateExpr::count(1).with_distinct()],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let mut pairs: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = agg.next().unwrap() {
        for idx in chunk.selected_indices() {
            let key = chunk.column(0).unwrap().get_int64(idx).unwrap();
            let distinct = chunk.column(1).unwrap().get_int64(idx).unwrap();
            pairs.push((key, distinct));
        }
    }

    // Output order from the hash table is unspecified; normalise by key.
    pairs.sort_by_key(|&(key, _)| key);
    assert_eq!(pairs, vec![(1, 2), (2, 1)]);
}
1369
/// Ungrouped SUM(DISTINCT col1): 10 + 20 + 30 = 60.
#[test]
fn test_sum_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_sum = AggregateExpr::sum(1).with_distinct();
    let mut agg =
        SimpleAggregateOperator::new(Box::new(input), vec![distinct_sum], vec![LogicalType::Int64]);

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let total = output.column(0).unwrap().get_int64(0);
    assert_eq!(total, Some(60));
}
1386
/// Ungrouped AVG(DISTINCT col1): (10 + 20 + 30) / 3 = 20.
#[test]
fn test_avg_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_avg = AggregateExpr::avg(1).with_distinct();
    let mut agg =
        SimpleAggregateOperator::new(Box::new(input), vec![distinct_avg], vec![LogicalType::Float64]);

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let mean = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((mean - 20.0).abs() < 0.001);
}
1404
1405 fn create_statistical_test_chunk() -> DataChunk {
1406 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1409
1410 for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1411 builder.column_mut(0).unwrap().push_int64(value);
1412 builder.advance_row();
1413 }
1414
1415 builder.finish()
1416 }
1417
/// Sample standard deviation of the statistical fixture: sqrt(32 / 7) ~ 2.138.
#[test]
fn test_stdev_sample() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let sample_stdev = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((sample_stdev - 2.138).abs() < 0.01);
}
1435
/// Population standard deviation of the statistical fixture: sqrt(32 / 8) = 2.
#[test]
fn test_stdev_population() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::stdev_pop(0)],
        vec![LogicalType::Float64],
    );

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let pop_stdev = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((pop_stdev - 2.0).abs() < 0.01);
}
1453
/// Discrete 50th percentile of 2, 4, 4, 4, 5, 5, 7, 9 is expected to be 4.0.
#[test]
fn test_percentile_disc() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let median = AggregateExpr::percentile_disc(0, 0.5);
    let mut agg =
        SimpleAggregateOperator::new(Box::new(input), vec![median], vec![LogicalType::Float64]);

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let value = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((value - 4.0).abs() < 0.01);
}
1471
/// Continuous 50th percentile interpolates between the 4th and 5th sorted
/// values (4 and 5), giving 4.5.
#[test]
fn test_percentile_cont() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let median = AggregateExpr::percentile_cont(0, 0.5);
    let mut agg =
        SimpleAggregateOperator::new(Box::new(input), vec![median], vec![LogicalType::Float64]);

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let value = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((value - 4.5).abs() < 0.01);
}
1490
/// The 0th and 100th discrete percentiles should equal the minimum (2) and
/// maximum (9) of the statistical fixture.
#[test]
fn test_percentile_extremes() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let lowest = AggregateExpr::percentile_disc(0, 0.0);
    let highest = AggregateExpr::percentile_disc(0, 1.0);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![lowest, highest],
        vec![LogicalType::Float64, LogicalType::Float64],
    );

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let p0 = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((p0 - 2.0).abs() < 0.01);

    let p100 = output.column(1).unwrap().get_float64(0).unwrap();
    assert!((p100 - 9.0).abs() < 0.01);
}
1514
/// Sample stdev of a single value is undefined (division by n - 1 = 0), so the
/// operator is expected to emit NULL.
#[test]
fn test_stdev_single_value() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();
    let input = MockOperator::new(vec![builder.finish()]);

    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let cell = output.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1539
/// FIRST(col1) and LAST(col1) over the base fixture are expected to be 10 and 50.
#[test]
fn test_first_and_last() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let aggregates = vec![AggregateExpr::first(1), AggregateExpr::last(1)];
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        aggregates,
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let first = output.column(0).unwrap().get_int64(0);
    assert_eq!(first, Some(10));
    let last = output.column(1).unwrap().get_int64(0);
    assert_eq!(last, Some(50));
}
1556
/// COLLECT(col1) over the base fixture yields a list with one entry per input row (5).
#[test]
fn test_collect() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::collect(1)],
        vec![LogicalType::Any],
    );

    let output = agg.next().unwrap().unwrap();
    let collected = output.column(0).unwrap().get_value(0).unwrap();
    match collected {
        Value::List(items) => assert_eq!(items.len(), 5),
        _ => panic!("Expected List value"),
    }
}
1575
/// COLLECT(DISTINCT col1) over 10, 10, 20, 30, 30, 30 yields three values.
///
/// Checking the length alone would also accept a list like [10, 10, 20]. The
/// test therefore also asserts that no two returned elements are equal.
#[test]
fn test_collect_distinct() {
    let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);

    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::collect(1).with_distinct()],
        vec![LogicalType::Any],
    );

    let result = agg.next().unwrap().unwrap();
    let val = result.column(0).unwrap().get_value(0).unwrap();
    if let Value::List(items) = val {
        assert_eq!(items.len(), 3);
        // DISTINCT must actually drop duplicates, not merely return the right count.
        for (i, earlier) in items.iter().enumerate() {
            for later in items.iter().skip(i + 1) {
                assert_ne!(earlier, later, "collect(DISTINCT) returned a duplicate value");
            }
        }
    } else {
        panic!("Expected List value");
    }
}
1595
/// GROUP_CONCAT with no explicit separator joins the strings with a single
/// space: "hello world foo".
#[test]
fn test_group_concat() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
    for word in ["hello", "world", "foo"] {
        builder.column_mut(0).unwrap().push_string(word);
        builder.advance_row();
    }
    let input = MockOperator::new(vec![builder.finish()]);

    let concat = AggregateExpr {
        function: AggregateFunction::GroupConcat,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut agg =
        SimpleAggregateOperator::new(Box::new(input), vec![concat], vec![LogicalType::String]);

    let output = agg.next().unwrap().unwrap();
    let joined = output.column(0).unwrap().get_value(0).unwrap();
    let expected = Value::String("hello world foo".into());
    assert_eq!(joined, expected);
}
1623
/// SAMPLE(col1) over the base fixture is expected to return 10.
#[test]
fn test_sample() {
    let input = MockOperator::new(vec![create_test_chunk()]);

    let sample = AggregateExpr {
        function: AggregateFunction::Sample,
        column: Some(1),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut agg =
        SimpleAggregateOperator::new(Box::new(input), vec![sample], vec![LogicalType::Int64]);

    let output = agg.next().unwrap().unwrap();
    let picked = output.column(0).unwrap().get_int64(0);
    assert_eq!(picked, Some(10));
}
1645
/// Sample variance of the statistical fixture: 32 / (8 - 1) = 32/7.
#[test]
fn test_variance_sample() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let output = agg.next().unwrap().unwrap();
    let sample_var = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((sample_var - 32.0 / 7.0).abs() < 0.01);
}
1671
/// Population variance of the statistical fixture: 32 / 8 = 4.
#[test]
fn test_variance_population() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::VariancePop,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let output = agg.next().unwrap().unwrap();
    let pop_var = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((pop_var - 4.0).abs() < 0.01);
}
1697
/// Sample variance of a single value is undefined (n - 1 = 0), so the operator
/// is expected to emit NULL.
#[test]
fn test_variance_single_value() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();
    let input = MockOperator::new(vec![builder.finish()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let output = agg.next().unwrap().unwrap();
    let cell = output.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1729
/// Ungrouped aggregation over an input with no chunks still emits one row.
///
/// COUNT(*) is 0, and SUM/AVG/MIN/MAX over nothing are all NULL.
#[test]
fn test_empty_aggregation() {
    let input = MockOperator::new(vec![]);
    let aggregates = vec![
        AggregateExpr::count_star(),
        AggregateExpr::sum(0),
        AggregateExpr::avg(0),
        AggregateExpr::min(0),
        AggregateExpr::max(0),
    ];
    let output_types = vec![
        LogicalType::Int64,   // COUNT(*)
        LogicalType::Int64,   // SUM
        LogicalType::Float64, // AVG
        LogicalType::Int64,   // MIN
        LogicalType::Int64,   // MAX
    ];
    let mut agg = SimpleAggregateOperator::new(Box::new(input), aggregates, output_types);

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.column(0).unwrap().get_int64(0), Some(0));

    // Every aggregate after COUNT(*) has no input to work on and must be NULL.
    for col in 1..5 {
        assert!(matches!(
            output.column(col).unwrap().get_value(0),
            Some(Value::Null)
        ));
    }
}
1773
/// Population stdev of a single value is 0. Unlike the sample stdev, it does
/// not divide by n - 1, so it is defined for one value.
#[test]
fn test_stdev_pop_single_value() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();
    let input = MockOperator::new(vec![builder.finish()]);

    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::stdev_pop(0)],
        vec![LogicalType::Float64],
    );

    let output = agg.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let pop_stdev = output.column(0).unwrap().get_float64(0).unwrap();
    assert!(pop_stdev.abs() < 0.01);
}
1796}