1use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, PropertyKey, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
/// Running state of one aggregate function over one group of rows.
///
/// A state is created by [`AggregateState::new`], fed through
/// [`AggregateState::update`] (or [`AggregateState::update_bivariate`] for the
/// two-argument statistics) and turned into a result by
/// [`AggregateState::finalize`].
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum AggregateState {
    /// Number of `update` calls received.
    Count(i64),
    /// Number of distinct values seen, plus the set of values already counted.
    CountDistinct(i64, HashSet<HashableValue>),
    /// `(sum, count)` while every input so far was an integer; `update` replaces
    /// the state with `SumFloat` when a non-integer numeric value arrives.
    SumInt(i64, i64),
    /// `(sum, count, seen)`: like `SumInt`, but each distinct value is added once.
    SumIntDistinct(i64, i64, HashSet<HashableValue>),
    /// `(sum, compensation, count)`: floating-point sum using Kahan compensated
    /// summation; the second field carries the running rounding error.
    SumFloat(f64, f64, i64),
    /// `(sum, compensation, count, seen)`: like `SumFloat`, over distinct values.
    SumFloatDistinct(f64, f64, i64, HashSet<HashableValue>),
    /// `(sum, count)` of the numeric inputs; the average is computed in `finalize`.
    Avg(f64, i64),
    /// `(sum, count, seen)`: like `Avg`, over distinct values.
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// Smallest value seen so far according to `compare_values`.
    Min(Option<Value>),
    /// Largest value seen so far according to `compare_values`.
    Max(Option<Value>),
    /// First value offered; later updates are ignored once this is `Some`.
    First(Option<Value>),
    /// Most recent `Some` value offered.
    Last(Option<Value>),
    /// Every value offered, in arrival order.
    Collect(Vec<Value>),
    /// Distinct values in first-arrival order, plus the set used to de-duplicate.
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Sample standard deviation (divides by `count - 1`), accumulated online:
    /// `mean` is the running mean and `m2` the running sum of squared deviations.
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Population standard deviation (divides by `count`); same accumulators as `StdDev`.
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Buffered numeric inputs and the requested percentile; `finalize` returns
    /// one of the buffered values.
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Buffered numeric inputs and the requested percentile; `finalize`
    /// interpolates linearly between the two neighbouring values.
    PercentileCont { values: Vec<f64>, percentile: f64 },
    /// `(pieces, separator)`: stringified inputs joined with the separator in `finalize`.
    GroupConcat(Vec<String>, String),
    /// `(pieces, separator, seen)`: like `GroupConcat`, over distinct values.
    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
    /// First value offered (same update rule as `First`).
    Sample(Option<Value>),
    /// Sample variance (divides by `count - 1`); same accumulators as `StdDev`.
    Variance { count: i64, mean: f64, m2: f64 },
    /// Population variance (divides by `count`); same accumulators as `StdDev`.
    VariancePop { count: i64, mean: f64, m2: f64 },
    /// Shared accumulators for the two-argument statistics (covariance,
    /// correlation, regression). `kind` selects which result `finalize` derives;
    /// `m2_x` / `m2_y` are sums of squared deviations and `c_xy` is the
    /// co-moment. Only `update_bivariate` changes this state.
    Bivariate {
        kind: AggregateFunction,
        count: i64,
        mean_x: f64,
        mean_y: f64,
        m2_x: f64,
        m2_y: f64,
        c_xy: f64,
    },
    /// A fixed result: `update` leaves it untouched and `finalize` returns a clone.
    Frozen(Value),
}
94
95impl AggregateState {
96 pub fn new(
98 function: AggregateFunction,
99 distinct: bool,
100 percentile: Option<f64>,
101 separator: Option<&str>,
102 ) -> Self {
103 match (function, distinct) {
104 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
105 AggregateState::Count(0)
106 }
107 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
108 AggregateState::CountDistinct(0, HashSet::new())
109 }
110 (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
111 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
112 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
113 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
114 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
116 (AggregateFunction::First, _) => AggregateState::First(None),
117 (AggregateFunction::Last, _) => AggregateState::Last(None),
118 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
119 (AggregateFunction::Collect, true) => {
120 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
121 }
122 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
124 count: 0,
125 mean: 0.0,
126 m2: 0.0,
127 },
128 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
129 count: 0,
130 mean: 0.0,
131 m2: 0.0,
132 },
133 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
134 values: Vec::new(),
135 percentile: percentile.unwrap_or(0.5),
136 },
137 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
138 values: Vec::new(),
139 percentile: percentile.unwrap_or(0.5),
140 },
141 (AggregateFunction::GroupConcat, false) => {
142 AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
143 }
144 (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
145 Vec::new(),
146 separator.unwrap_or(" ").to_string(),
147 HashSet::new(),
148 ),
149 (AggregateFunction::Sample, _) => AggregateState::Sample(None),
150 (
152 AggregateFunction::CovarSamp
153 | AggregateFunction::CovarPop
154 | AggregateFunction::Corr
155 | AggregateFunction::RegrSlope
156 | AggregateFunction::RegrIntercept
157 | AggregateFunction::RegrR2
158 | AggregateFunction::RegrCount
159 | AggregateFunction::RegrSxx
160 | AggregateFunction::RegrSyy
161 | AggregateFunction::RegrSxy
162 | AggregateFunction::RegrAvgx
163 | AggregateFunction::RegrAvgy,
164 _,
165 ) => AggregateState::Bivariate {
166 kind: function,
167 count: 0,
168 mean_x: 0.0,
169 mean_y: 0.0,
170 m2_x: 0.0,
171 m2_y: 0.0,
172 c_xy: 0.0,
173 },
174 (AggregateFunction::Variance, _) => AggregateState::Variance {
175 count: 0,
176 mean: 0.0,
177 m2: 0.0,
178 },
179 (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
180 count: 0,
181 mean: 0.0,
182 m2: 0.0,
183 },
184 }
185 }
186
187 pub fn update(&mut self, value: Option<Value>) {
192 match self {
193 AggregateState::Count(count) => {
194 *count += 1;
195 }
196 AggregateState::CountDistinct(count, seen) => {
197 if let Some(ref v) = value {
198 let hashable = HashableValue::from(v);
199 if seen.insert(hashable) {
200 *count += 1;
201 }
202 }
203 }
204 AggregateState::SumInt(sum, count) => {
205 if let Some(Value::Int64(v)) = value {
206 *sum += v;
207 *count += 1;
208 } else if let Some(Value::Float64(v)) = value {
209 *self = AggregateState::SumFloat(*sum as f64 + v, 0.0, *count + 1);
211 } else if let Some(ref v) = value {
212 if let Some(num) = value_to_f64(v) {
214 *self = AggregateState::SumFloat(*sum as f64 + num, 0.0, *count + 1);
215 }
216 }
217 }
218 AggregateState::SumIntDistinct(sum, count, seen) => {
219 if let Some(ref v) = value {
220 let hashable = HashableValue::from(v);
221 if seen.insert(hashable) {
222 if let Value::Int64(i) = v {
223 *sum += i;
224 *count += 1;
225 } else if let Value::Float64(f) = v {
226 let moved_seen = std::mem::take(seen);
228 *self = AggregateState::SumFloatDistinct(
229 *sum as f64 + f,
230 0.0,
231 *count + 1,
232 moved_seen,
233 );
234 } else if let Some(num) = value_to_f64(v) {
235 let moved_seen = std::mem::take(seen);
237 *self = AggregateState::SumFloatDistinct(
238 *sum as f64 + num,
239 0.0,
240 *count + 1,
241 moved_seen,
242 );
243 }
244 }
245 }
246 }
247 AggregateState::SumFloat(sum, comp, count) => {
248 if let Some(ref v) = value {
249 if let Some(num) = value_to_f64(v) {
251 let y = num - *comp;
253 let t = *sum + y;
254 *comp = (t - *sum) - y;
255 *sum = t;
256 *count += 1;
257 }
258 }
259 }
260 AggregateState::SumFloatDistinct(sum, comp, count, seen) => {
261 if let Some(ref v) = value {
262 let hashable = HashableValue::from(v);
263 if seen.insert(hashable)
264 && let Some(num) = value_to_f64(v)
265 {
266 let y = num - *comp;
267 let t = *sum + y;
268 *comp = (t - *sum) - y;
269 *sum = t;
270 *count += 1;
271 }
272 }
273 }
274 AggregateState::Avg(sum, count) => {
275 if let Some(ref v) = value
276 && let Some(num) = value_to_f64(v)
277 {
278 *sum += num;
279 *count += 1;
280 }
281 }
282 AggregateState::AvgDistinct(sum, count, seen) => {
283 if let Some(ref v) = value {
284 let hashable = HashableValue::from(v);
285 if seen.insert(hashable)
286 && let Some(num) = value_to_f64(v)
287 {
288 *sum += num;
289 *count += 1;
290 }
291 }
292 }
293 AggregateState::Min(min) => {
294 if let Some(v) = value {
295 match min {
296 None => *min = Some(v),
297 Some(current) => {
298 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
299 *min = Some(v);
300 }
301 }
302 }
303 }
304 }
305 AggregateState::Max(max) => {
306 if let Some(v) = value {
307 match max {
308 None => *max = Some(v),
309 Some(current) => {
310 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
311 *max = Some(v);
312 }
313 }
314 }
315 }
316 }
317 AggregateState::First(first) => {
318 if first.is_none() {
319 *first = value;
320 }
321 }
322 AggregateState::Last(last) => {
323 if value.is_some() {
324 *last = value;
325 }
326 }
327 AggregateState::Collect(list) => {
328 if let Some(v) = value {
329 list.push(v);
330 }
331 }
332 AggregateState::CollectDistinct(list, seen) => {
333 if let Some(v) = value {
334 let hashable = HashableValue::from(&v);
335 if seen.insert(hashable) {
336 list.push(v);
337 }
338 }
339 }
340 AggregateState::StdDev { count, mean, m2 }
342 | AggregateState::StdDevPop { count, mean, m2 }
343 | AggregateState::Variance { count, mean, m2 }
344 | AggregateState::VariancePop { count, mean, m2 } => {
345 if let Some(ref v) = value
346 && let Some(x) = value_to_f64(v)
347 {
348 *count += 1;
349 let delta = x - *mean;
350 *mean += delta / *count as f64;
351 let delta2 = x - *mean;
352 *m2 += delta * delta2;
353 }
354 }
355 AggregateState::PercentileDisc { values, .. }
356 | AggregateState::PercentileCont { values, .. } => {
357 if let Some(ref v) = value
358 && let Some(x) = value_to_f64(v)
359 {
360 values.push(x);
361 }
362 }
363 AggregateState::GroupConcat(list, _sep) => {
364 if let Some(v) = value {
365 list.push(agg_value_to_string(&v));
366 }
367 }
368 AggregateState::GroupConcatDistinct(list, _sep, seen) => {
369 if let Some(v) = value {
370 let hashable = HashableValue::from(&v);
371 if seen.insert(hashable) {
372 list.push(agg_value_to_string(&v));
373 }
374 }
375 }
376 AggregateState::Sample(sample) => {
377 if sample.is_none() {
378 *sample = value;
379 }
380 }
381 AggregateState::Bivariate { .. } => {
382 }
385 AggregateState::Frozen(_) => {}
386 }
387 }
388
389 pub fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
394 if let AggregateState::Bivariate {
395 count,
396 mean_x,
397 mean_y,
398 m2_x,
399 m2_y,
400 c_xy,
401 ..
402 } = self
403 {
404 if let (Some(y), Some(x)) = (&y_val, &x_val)
406 && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
407 {
408 *count += 1;
409 let n = *count as f64;
410 let dx = x_f - *mean_x;
411 let dy = y_f - *mean_y;
412 *mean_x += dx / n;
413 *mean_y += dy / n;
414 let dx2 = x_f - *mean_x; let dy2 = y_f - *mean_y; *m2_x += dx * dx2;
417 *m2_y += dy * dy2;
418 *c_xy += dx * dy2;
419 }
420 }
421 }
422
423 pub fn finalize(&self) -> Value {
425 match self {
426 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
427 Value::Int64(*count)
428 }
429 AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
430 if *count == 0 {
431 Value::Null
432 } else {
433 Value::Int64(*sum)
434 }
435 }
436 AggregateState::SumFloat(sum, _, count)
437 | AggregateState::SumFloatDistinct(sum, _, count, _) => {
438 if *count == 0 {
439 Value::Null
440 } else {
441 Value::Float64(*sum)
442 }
443 }
444 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
445 if *count == 0 {
446 Value::Null
447 } else {
448 Value::Float64(*sum / *count as f64)
449 }
450 }
451 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
452 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
453 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
454 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
455 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
456 Value::List(list.clone().into())
457 }
458 AggregateState::StdDev { count, m2, .. } => {
460 if *count < 2 {
461 Value::Null
462 } else {
463 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
464 }
465 }
466 AggregateState::StdDevPop { count, m2, .. } => {
468 if *count == 0 {
469 Value::Null
470 } else {
471 Value::Float64((*m2 / *count as f64).sqrt())
472 }
473 }
474 AggregateState::Variance { count, m2, .. } => {
476 if *count < 2 {
477 Value::Null
478 } else {
479 Value::Float64(*m2 / (*count - 1) as f64)
480 }
481 }
482 AggregateState::VariancePop { count, m2, .. } => {
484 if *count == 0 {
485 Value::Null
486 } else {
487 Value::Float64(*m2 / *count as f64)
488 }
489 }
490 AggregateState::PercentileDisc { values, percentile } => {
492 if values.is_empty() {
493 Value::Null
494 } else {
495 let mut sorted = values.clone();
496 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
497 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
500 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
501 Value::Float64(sorted[index])
502 }
503 }
504 AggregateState::PercentileCont { values, percentile } => {
506 if values.is_empty() {
507 Value::Null
508 } else {
509 let mut sorted = values.clone();
510 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
511 let rank = percentile * (sorted.len() - 1) as f64;
513 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
515 let lower_idx = rank.floor() as usize;
516 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
518 let upper_idx = rank.ceil() as usize;
519 if lower_idx == upper_idx {
520 Value::Float64(sorted[lower_idx])
521 } else {
522 let fraction = rank - lower_idx as f64;
523 let result =
524 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
525 Value::Float64(result)
526 }
527 }
528 }
529 AggregateState::GroupConcat(list, sep)
531 | AggregateState::GroupConcatDistinct(list, sep, _) => {
532 Value::String(list.join(sep).into())
533 }
534 AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
536 AggregateState::Frozen(val) => val.clone(),
537 AggregateState::Bivariate {
539 kind,
540 count,
541 mean_x,
542 mean_y,
543 m2_x,
544 m2_y,
545 c_xy,
546 } => {
547 let n = *count;
548 match kind {
549 AggregateFunction::CovarSamp => {
550 if n < 2 {
551 Value::Null
552 } else {
553 Value::Float64(*c_xy / (n - 1) as f64)
554 }
555 }
556 AggregateFunction::CovarPop => {
557 if n == 0 {
558 Value::Null
559 } else {
560 Value::Float64(*c_xy / n as f64)
561 }
562 }
563 AggregateFunction::Corr => {
564 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
565 Value::Null
566 } else {
567 Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
568 }
569 }
570 AggregateFunction::RegrSlope => {
571 if n == 0 || *m2_x == 0.0 {
572 Value::Null
573 } else {
574 Value::Float64(*c_xy / *m2_x)
575 }
576 }
577 AggregateFunction::RegrIntercept => {
578 if n == 0 || *m2_x == 0.0 {
579 Value::Null
580 } else {
581 let slope = *c_xy / *m2_x;
582 Value::Float64(*mean_y - slope * *mean_x)
583 }
584 }
585 AggregateFunction::RegrR2 => {
586 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
587 Value::Null
588 } else {
589 Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
590 }
591 }
592 AggregateFunction::RegrCount => Value::Int64(n),
593 AggregateFunction::RegrSxx => {
594 if n == 0 {
595 Value::Null
596 } else {
597 Value::Float64(*m2_x)
598 }
599 }
600 AggregateFunction::RegrSyy => {
601 if n == 0 {
602 Value::Null
603 } else {
604 Value::Float64(*m2_y)
605 }
606 }
607 AggregateFunction::RegrSxy => {
608 if n == 0 {
609 Value::Null
610 } else {
611 Value::Float64(*c_xy)
612 }
613 }
614 AggregateFunction::RegrAvgx => {
615 if n == 0 {
616 Value::Null
617 } else {
618 Value::Float64(*mean_x)
619 }
620 }
621 AggregateFunction::RegrAvgy => {
622 if n == 0 {
623 Value::Null
624 } else {
625 Value::Float64(*mean_y)
626 }
627 }
628 _ => Value::Null, }
630 }
631 }
632 }
633}
634
635use super::value_utils::{compare_values, value_to_f64};
636
637fn agg_value_to_string(val: &Value) -> String {
639 match val {
640 Value::String(s) => s.to_string(),
641 Value::Int64(i) => i.to_string(),
642 Value::Float64(f) => f.to_string(),
643 Value::Bool(b) => b.to_string(),
644 Value::Null => String::new(),
645 other => format!("{other:?}"),
646 }
647}
648
/// Hashable key identifying one group: one `GroupKeyPart` per grouping column,
/// in the order the grouping columns were given.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(Vec<GroupKeyPart>);
652
653#[derive(Debug, Clone, PartialEq, Eq, Hash)]
654enum GroupKeyPart {
655 Null,
656 Bool(bool),
657 Int64(i64),
658 String(ArcStr),
659 Bytes(Arc<[u8]>),
660 Date(grafeo_common::types::Date),
661 Time(grafeo_common::types::Time),
662 Timestamp(grafeo_common::types::Timestamp),
663 Duration(grafeo_common::types::Duration),
664 ZonedDatetime(grafeo_common::types::ZonedDatetime),
665 List(Vec<GroupKeyPart>),
666 Map(Vec<(ArcStr, GroupKeyPart)>),
667}
668
669impl GroupKeyPart {
670 fn from_value(v: Value) -> Self {
671 match v {
672 Value::Null => Self::Null,
673 Value::Bool(b) => Self::Bool(b),
674 Value::Int64(i) => Self::Int64(i),
675 #[allow(clippy::cast_possible_wrap)]
677 Value::Float64(f) => Self::Int64(f.to_bits() as i64),
678 Value::String(s) => Self::String(s.clone()),
679 Value::Bytes(b) => Self::Bytes(b),
680 Value::Date(d) => Self::Date(d),
681 Value::Time(t) => Self::Time(t),
682 Value::Timestamp(ts) => Self::Timestamp(ts),
683 Value::Duration(d) => Self::Duration(d),
684 Value::ZonedDatetime(zdt) => Self::ZonedDatetime(zdt),
685 Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
686 Value::Map(map) => {
687 let entries: Vec<(ArcStr, GroupKeyPart)> = map
689 .iter()
690 .map(|(k, v)| (ArcStr::from(k.as_str()), Self::from_value(v.clone())))
691 .collect();
692 Self::Map(entries)
693 }
694 other => Self::String(ArcStr::from(format!("{other:?}"))),
696 }
697 }
698
699 fn to_value(&self) -> Value {
700 match self {
701 Self::Null => Value::Null,
702 Self::Bool(b) => Value::Bool(*b),
703 Self::Int64(i) => Value::Int64(*i),
704 Self::String(s) => Value::String(s.clone()),
705 Self::Bytes(b) => Value::Bytes(Arc::clone(b)),
706 Self::Date(d) => Value::Date(*d),
707 Self::Time(t) => Value::Time(*t),
708 Self::Timestamp(ts) => Value::Timestamp(*ts),
709 Self::Duration(d) => Value::Duration(*d),
710 Self::ZonedDatetime(zdt) => Value::ZonedDatetime(*zdt),
711 Self::List(parts) => {
712 let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
713 Value::List(Arc::from(values.into_boxed_slice()))
714 }
715 Self::Map(entries) => {
716 let map: std::collections::BTreeMap<PropertyKey, Value> = entries
717 .iter()
718 .map(|(k, v)| (PropertyKey::new(k.as_str()), v.to_value()))
719 .collect();
720 Value::Map(Arc::new(map))
721 }
722 }
723 }
724}
725
726impl GroupKey {
727 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
729 let parts: Vec<GroupKeyPart> = group_columns
730 .iter()
731 .map(|&col_idx| {
732 chunk
733 .column(col_idx)
734 .and_then(|col| col.get_value(row))
735 .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
736 })
737 .collect();
738 GroupKey(parts)
739 }
740
741 fn to_values(&self) -> Vec<Value> {
743 self.0.iter().map(GroupKeyPart::to_value).collect()
744 }
745}
746
/// Grouping aggregate operator: consumes its whole child, keeps one set of
/// aggregate states per distinct group key, then emits one row per group.
pub struct HashAggregateOperator {
    /// Input operator whose chunks are aggregated.
    child: Box<dyn Operator>,
    /// Indices of the input columns that form the group key.
    group_columns: Vec<usize>,
    /// Aggregates evaluated for every group.
    aggregates: Vec<AggregateExpr>,
    /// Column types handed to the output chunk builder.
    output_schema: Vec<LogicalType>,
    /// Per-group aggregate states, keyed by group key.
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Set once the child has been fully consumed.
    aggregation_complete: bool,
    /// Groups still waiting to be emitted; `None` until aggregation has run.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
766
767impl HashAggregateOperator {
768 pub fn new(
776 child: Box<dyn Operator>,
777 group_columns: Vec<usize>,
778 aggregates: Vec<AggregateExpr>,
779 output_schema: Vec<LogicalType>,
780 ) -> Self {
781 Self {
782 child,
783 group_columns,
784 aggregates,
785 output_schema,
786 groups: IndexMap::new(),
787 aggregation_complete: false,
788 results: None,
789 }
790 }
791
792 pub fn into_parts(self) -> (Box<dyn Operator>, Vec<usize>, Vec<AggregateExpr>) {
794 (self.child, self.group_columns, self.aggregates)
795 }
796
797 fn aggregate(&mut self) -> Result<(), OperatorError> {
799 while let Some(chunk) = self.child.next()? {
800 for row in chunk.selected_indices() {
801 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
802
803 let states = self.groups.entry(key).or_insert_with(|| {
805 self.aggregates
806 .iter()
807 .map(|agg| {
808 AggregateState::new(
809 agg.function,
810 agg.distinct,
811 agg.percentile,
812 agg.separator.as_deref(),
813 )
814 })
815 .collect()
816 });
817
818 for (i, agg) in self.aggregates.iter().enumerate() {
820 if agg.column2.is_some() {
822 let y_val = agg
823 .column
824 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
825 let x_val = agg
826 .column2
827 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
828 states[i].update_bivariate(y_val, x_val);
829 continue;
830 }
831
832 let value = match (agg.function, agg.distinct) {
833 (AggregateFunction::Count, false) => None,
835 (AggregateFunction::Count, true) => agg
837 .column
838 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
839 _ => agg
840 .column
841 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
842 };
843
844 match (agg.function, agg.distinct) {
846 (AggregateFunction::Count, false) => states[i].update(None),
847 (AggregateFunction::Count, true) => {
848 if value.is_some() && !matches!(value, Some(Value::Null)) {
850 states[i].update(value);
851 }
852 }
853 (AggregateFunction::CountNonNull, _) => {
854 if value.is_some() && !matches!(value, Some(Value::Null)) {
855 states[i].update(value);
856 }
857 }
858 _ => {
859 if value.is_some() && !matches!(value, Some(Value::Null)) {
860 states[i].update(value);
861 }
862 }
863 }
864 }
865 }
866 }
867
868 self.aggregation_complete = true;
869
870 let results: Vec<_> = self.groups.drain(..).collect();
872 self.results = Some(results.into_iter());
873
874 Ok(())
875 }
876}
877
878impl Operator for HashAggregateOperator {
879 fn next(&mut self) -> OperatorResult {
880 if !self.aggregation_complete {
882 self.aggregate()?;
883 }
884
885 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
887 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
889
890 for agg in &self.aggregates {
891 let state = AggregateState::new(
892 agg.function,
893 agg.distinct,
894 agg.percentile,
895 agg.separator.as_deref(),
896 );
897 let value = state.finalize();
898 if let Some(col) = builder.column_mut(self.group_columns.len()) {
899 col.push_value(value);
900 }
901 }
902 builder.advance_row();
903
904 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
906 }
907
908 let Some(results) = &mut self.results else {
909 return Ok(None);
910 };
911
912 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
913
914 for (key, states) in results.by_ref() {
915 let key_values = key.to_values();
917 for (i, value) in key_values.into_iter().enumerate() {
918 if let Some(col) = builder.column_mut(i) {
919 col.push_value(value);
920 }
921 }
922
923 for (i, state) in states.iter().enumerate() {
925 let col_idx = self.group_columns.len() + i;
926 if let Some(col) = builder.column_mut(col_idx) {
927 col.push_value(state.finalize());
928 }
929 }
930
931 builder.advance_row();
932
933 if builder.is_full() {
934 return Ok(Some(builder.finish()));
935 }
936 }
937
938 if builder.row_count() > 0 {
939 Ok(Some(builder.finish()))
940 } else {
941 Ok(None)
942 }
943 }
944
945 fn reset(&mut self) {
946 self.child.reset();
947 self.groups.clear();
948 self.aggregation_complete = false;
949 self.results = None;
950 }
951
952 fn name(&self) -> &'static str {
953 "HashAggregate"
954 }
955
956 fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
957 self
958 }
959}
960
/// Aggregate operator without grouping: folds every input row into a single
/// set of aggregate states and emits exactly one output row.
pub struct SimpleAggregateOperator {
    /// Input operator whose chunks are aggregated.
    child: Box<dyn Operator>,
    /// Aggregates to evaluate.
    aggregates: Vec<AggregateExpr>,
    /// Column types handed to the output chunk builder.
    output_schema: Vec<LogicalType>,
    /// One running state per entry of `aggregates`, in the same order.
    states: Vec<AggregateState>,
    /// Set once the result row has been emitted.
    done: bool,
}
976
977impl SimpleAggregateOperator {
978 pub fn new(
980 child: Box<dyn Operator>,
981 aggregates: Vec<AggregateExpr>,
982 output_schema: Vec<LogicalType>,
983 ) -> Self {
984 let states = aggregates
985 .iter()
986 .map(|agg| {
987 AggregateState::new(
988 agg.function,
989 agg.distinct,
990 agg.percentile,
991 agg.separator.as_deref(),
992 )
993 })
994 .collect();
995
996 Self {
997 child,
998 aggregates,
999 output_schema,
1000 states,
1001 done: false,
1002 }
1003 }
1004}
1005
1006impl Operator for SimpleAggregateOperator {
1007 fn next(&mut self) -> OperatorResult {
1008 if self.done {
1009 return Ok(None);
1010 }
1011
1012 while let Some(chunk) = self.child.next()? {
1014 for row in chunk.selected_indices() {
1015 for (i, agg) in self.aggregates.iter().enumerate() {
1016 if agg.column2.is_some() {
1018 let y_val = agg
1019 .column
1020 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
1021 let x_val = agg
1022 .column2
1023 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
1024 self.states[i].update_bivariate(y_val, x_val);
1025 continue;
1026 }
1027
1028 let value = match (agg.function, agg.distinct) {
1029 (AggregateFunction::Count, false) => None,
1031 (AggregateFunction::Count, true) => agg
1033 .column
1034 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1035 _ => agg
1036 .column
1037 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1038 };
1039
1040 match (agg.function, agg.distinct) {
1041 (AggregateFunction::Count, false) => self.states[i].update(None),
1042 (AggregateFunction::Count, true) => {
1043 if value.is_some() && !matches!(value, Some(Value::Null)) {
1045 self.states[i].update(value);
1046 }
1047 }
1048 (AggregateFunction::CountNonNull, _) => {
1049 if value.is_some() && !matches!(value, Some(Value::Null)) {
1050 self.states[i].update(value);
1051 }
1052 }
1053 _ => {
1054 if value.is_some() && !matches!(value, Some(Value::Null)) {
1055 self.states[i].update(value);
1056 }
1057 }
1058 }
1059 }
1060 }
1061 }
1062
1063 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
1065
1066 for (i, state) in self.states.iter().enumerate() {
1067 if let Some(col) = builder.column_mut(i) {
1068 col.push_value(state.finalize());
1069 }
1070 }
1071 builder.advance_row();
1072
1073 self.done = true;
1074 Ok(Some(builder.finish()))
1075 }
1076
1077 fn reset(&mut self) {
1078 self.child.reset();
1079 self.states = self
1080 .aggregates
1081 .iter()
1082 .map(|agg| {
1083 AggregateState::new(
1084 agg.function,
1085 agg.distinct,
1086 agg.percentile,
1087 agg.separator.as_deref(),
1088 )
1089 })
1090 .collect();
1091 self.done = false;
1092 }
1093
1094 fn name(&self) -> &'static str {
1095 "SimpleAggregate"
1096 }
1097
1098 fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
1099 self
1100 }
1101}
1102
1103#[cfg(test)]
1104mod tests {
1105 use super::*;
1106 use crate::execution::chunk::DataChunkBuilder;
1107
1108 struct MockOperator {
1109 chunks: Vec<DataChunk>,
1110 position: usize,
1111 }
1112
1113 impl MockOperator {
1114 fn new(chunks: Vec<DataChunk>) -> Self {
1115 Self {
1116 chunks,
1117 position: 0,
1118 }
1119 }
1120 }
1121
1122 impl Operator for MockOperator {
1123 fn next(&mut self) -> OperatorResult {
1124 if self.position < self.chunks.len() {
1125 let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1126 self.position += 1;
1127 Ok(Some(chunk))
1128 } else {
1129 Ok(None)
1130 }
1131 }
1132
1133 fn reset(&mut self) {
1134 self.position = 0;
1135 }
1136
1137 fn name(&self) -> &'static str {
1138 "Mock"
1139 }
1140
1141 fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
1142 self
1143 }
1144 }
1145
1146 fn create_test_chunk() -> DataChunk {
1147 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1149
1150 let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1151 for (group, value) in data {
1152 builder.column_mut(0).unwrap().push_int64(group);
1153 builder.column_mut(1).unwrap().push_int64(value);
1154 builder.advance_row();
1155 }
1156
1157 builder.finish()
1158 }
1159
1160 #[test]
1161 fn test_simple_count() {
1162 let mock = MockOperator::new(vec![create_test_chunk()]);
1163
1164 let mut agg = SimpleAggregateOperator::new(
1165 Box::new(mock),
1166 vec![AggregateExpr::count_star()],
1167 vec![LogicalType::Int64],
1168 );
1169
1170 let result = agg.next().unwrap().unwrap();
1171 assert_eq!(result.row_count(), 1);
1172 assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1173
1174 assert!(agg.next().unwrap().is_none());
1176 }
1177
1178 #[test]
1179 fn test_simple_sum() {
1180 let mock = MockOperator::new(vec![create_test_chunk()]);
1181
1182 let mut agg = SimpleAggregateOperator::new(
1183 Box::new(mock),
1184 vec![AggregateExpr::sum(1)], vec![LogicalType::Int64],
1186 );
1187
1188 let result = agg.next().unwrap().unwrap();
1189 assert_eq!(result.row_count(), 1);
1190 assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1192 }
1193
1194 #[test]
1195 fn test_simple_avg() {
1196 let mock = MockOperator::new(vec![create_test_chunk()]);
1197
1198 let mut agg = SimpleAggregateOperator::new(
1199 Box::new(mock),
1200 vec![AggregateExpr::avg(1)],
1201 vec![LogicalType::Float64],
1202 );
1203
1204 let result = agg.next().unwrap().unwrap();
1205 assert_eq!(result.row_count(), 1);
1206 let avg = result.column(0).unwrap().get_float64(0).unwrap();
1208 assert!((avg - 30.0).abs() < 0.001);
1209 }
1210
1211 #[test]
1212 fn test_simple_min_max() {
1213 let mock = MockOperator::new(vec![create_test_chunk()]);
1214
1215 let mut agg = SimpleAggregateOperator::new(
1216 Box::new(mock),
1217 vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1218 vec![LogicalType::Int64, LogicalType::Int64],
1219 );
1220
1221 let result = agg.next().unwrap().unwrap();
1222 assert_eq!(result.row_count(), 1);
1223 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); }
1226
1227 #[test]
1228 fn test_sum_with_string_values() {
1229 let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1231 builder.column_mut(0).unwrap().push_string("30");
1232 builder.advance_row();
1233 builder.column_mut(0).unwrap().push_string("25");
1234 builder.advance_row();
1235 builder.column_mut(0).unwrap().push_string("35");
1236 builder.advance_row();
1237 let chunk = builder.finish();
1238
1239 let mock = MockOperator::new(vec![chunk]);
1240 let mut agg = SimpleAggregateOperator::new(
1241 Box::new(mock),
1242 vec![AggregateExpr::sum(0)],
1243 vec![LogicalType::Float64],
1244 );
1245
1246 let result = agg.next().unwrap().unwrap();
1247 assert_eq!(result.row_count(), 1);
1248 let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1250 assert!(
1251 (sum_val - 90.0).abs() < 0.001,
1252 "Expected 90.0, got {}",
1253 sum_val
1254 );
1255 }
1256
1257 #[test]
1258 fn test_grouped_aggregation() {
1259 let mock = MockOperator::new(vec![create_test_chunk()]);
1260
1261 let mut agg = HashAggregateOperator::new(
1263 Box::new(mock),
1264 vec![0], vec![AggregateExpr::sum(1)], vec![LogicalType::Int64, LogicalType::Int64],
1267 );
1268
1269 let mut results: Vec<(i64, i64)> = Vec::new();
1270 while let Some(chunk) = agg.next().unwrap() {
1271 for row in chunk.selected_indices() {
1272 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1273 let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1274 results.push((group, sum));
1275 }
1276 }
1277
1278 results.sort_by_key(|(g, _)| *g);
1279 assert_eq!(results.len(), 2);
1280 assert_eq!(results[0], (1, 30)); assert_eq!(results[1], (2, 120)); }
1283
1284 #[test]
1285 fn test_grouped_count() {
1286 let mock = MockOperator::new(vec![create_test_chunk()]);
1287
1288 let mut agg = HashAggregateOperator::new(
1290 Box::new(mock),
1291 vec![0],
1292 vec![AggregateExpr::count_star()],
1293 vec![LogicalType::Int64, LogicalType::Int64],
1294 );
1295
1296 let mut results: Vec<(i64, i64)> = Vec::new();
1297 while let Some(chunk) = agg.next().unwrap() {
1298 for row in chunk.selected_indices() {
1299 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1300 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1301 results.push((group, count));
1302 }
1303 }
1304
1305 results.sort_by_key(|(g, _)| *g);
1306 assert_eq!(results.len(), 2);
1307 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 3)); }
1310
/// Runs `COUNT(*)`, `SUM` and `AVG` side by side in a single grouped aggregation.
///
/// The test reads the output as: group key in column 0, then one column per aggregate
/// in the order the aggregates were listed.
#[test]
fn test_multiple_aggregates() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let aggregates = vec![
        AggregateExpr::count_star(),
        AggregateExpr::sum(1),
        AggregateExpr::avg(1),
    ];
    let output_schema = vec![
        LogicalType::Int64,   // group key
        LogicalType::Int64,   // COUNT(*)
        LogicalType::Int64,   // SUM over column 1
        LogicalType::Float64, // AVG over column 1
    ];
    let mut operator =
        HashAggregateOperator::new(Box::new(input), vec![0], aggregates, output_schema);

    let mut rows: Vec<(i64, i64, i64, f64)> = Vec::new();
    while let Some(chunk) = operator.next().unwrap() {
        for idx in chunk.selected_indices() {
            let key = chunk.column(0).unwrap().get_int64(idx).unwrap();
            let tally = chunk.column(1).unwrap().get_int64(idx).unwrap();
            let total = chunk.column(2).unwrap().get_int64(idx).unwrap();
            let mean = chunk.column(3).unwrap().get_float64(idx).unwrap();
            rows.push((key, tally, total, mean));
        }
    }

    rows.sort_by_key(|&(key, ..)| key);
    assert_eq!(rows.len(), 2);

    // (group, count, sum, avg) per group; the average is compared with a tolerance.
    let expected = [(1_i64, 2_i64, 30_i64, 15.0_f64), (2, 3, 120, 40.0)];
    for (actual, wanted) in rows.iter().zip(expected.iter()) {
        assert_eq!(actual.0, wanted.0);
        assert_eq!(actual.1, wanted.1);
        assert_eq!(actual.2, wanted.2);
        assert!((actual.3 - wanted.3).abs() < 0.001);
    }
}
1358
1359 fn create_test_chunk_with_duplicates() -> DataChunk {
1360 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1365
1366 let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1367 for (group, value) in data {
1368 builder.column_mut(0).unwrap().push_int64(group);
1369 builder.column_mut(1).unwrap().push_int64(value);
1370 builder.advance_row();
1371 }
1372
1373 builder.finish()
1374 }
1375
/// Ungrouped `COUNT(DISTINCT col 1)` over the duplicate-laden chunk.
///
/// Column 1 contains 10, 10, 20, 30, 30, 30, i.e. three distinct values.
#[test]
fn test_count_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_count_expr = AggregateExpr::count(1).with_distinct();
    let mut operator = SimpleAggregateOperator::new(
        Box::new(input),
        vec![distinct_count_expr],
        vec![LogicalType::Int64],
    );

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let distinct_count = output.column(0).unwrap().get_int64(0);
    assert_eq!(distinct_count, Some(3));
}
1392
/// Grouped `COUNT(DISTINCT col 1)` keyed on column 0.
///
/// In the duplicate chunk, group 1 has the distinct values {10, 20} and group 2 only
/// {30}, hence the expected counts of 2 and 1.
#[test]
fn test_grouped_count_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let output_schema = vec![LogicalType::Int64, LogicalType::Int64];
    let mut operator = HashAggregateOperator::new(
        Box::new(input),
        vec![0],
        vec![AggregateExpr::count(1).with_distinct()],
        output_schema,
    );

    let mut rows: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = operator.next().unwrap() {
        for idx in chunk.selected_indices() {
            let key = chunk.column(0).unwrap().get_int64(idx).unwrap();
            let distinct = chunk.column(1).unwrap().get_int64(idx).unwrap();
            rows.push((key, distinct));
        }
    }

    // Sort by group key so the comparison does not depend on emission order.
    rows.sort_by_key(|&(key, _)| key);
    assert_eq!(rows, vec![(1, 2), (2, 1)]);
}
1419
/// Ungrouped `SUM(DISTINCT col 1)`: the distinct values 10, 20 and 30 add up to 60.
#[test]
fn test_sum_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_sum_expr = AggregateExpr::sum(1).with_distinct();
    let mut operator = SimpleAggregateOperator::new(
        Box::new(input),
        vec![distinct_sum_expr],
        vec![LogicalType::Int64],
    );

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let distinct_sum = output.column(0).unwrap().get_int64(0);
    assert_eq!(distinct_sum, Some(60));
}
1436
/// Ungrouped `AVG(DISTINCT col 1)`: the distinct values 10, 20 and 30 average to 20.
#[test]
fn test_avg_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_avg_expr = AggregateExpr::avg(1).with_distinct();
    let mut operator = SimpleAggregateOperator::new(
        Box::new(input),
        vec![distinct_avg_expr],
        vec![LogicalType::Float64],
    );

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    // Floating-point result: compare against 20.0 with a small tolerance.
    let mean = output.column(0).unwrap().get_float64(0).unwrap();
    let error = (mean - 20.0).abs();
    assert!(error < 0.001);
}
1454
1455 fn create_statistical_test_chunk() -> DataChunk {
1456 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1459
1460 for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1461 builder.column_mut(0).unwrap().push_int64(value);
1462 builder.advance_row();
1463 }
1464
1465 builder.finish()
1466 }
1467
/// Sample standard deviation over the statistical fixture.
///
/// With a sum of squared deviations of 32 over 8 values, sqrt(32 / 7) is roughly 2.138.
#[test]
fn test_stdev_sample() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::stdev(0)];
    let mut operator =
        SimpleAggregateOperator::new(Box::new(input), aggregates, vec![LogicalType::Float64]);

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let sample_stdev = output.column(0).unwrap().get_float64(0).unwrap();
    let error = (sample_stdev - 2.138).abs();
    assert!(error < 0.01);
}
1485
/// Population standard deviation over the statistical fixture.
///
/// With a sum of squared deviations of 32 over 8 values, sqrt(32 / 8) is exactly 2.
#[test]
fn test_stdev_population() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::stdev_pop(0)];
    let mut operator =
        SimpleAggregateOperator::new(Box::new(input), aggregates, vec![LogicalType::Float64]);

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let population_stdev = output.column(0).unwrap().get_float64(0).unwrap();
    let error = (population_stdev - 2.0).abs();
    assert!(error < 0.01);
}
1503
/// Discrete 50th percentile over the statistical fixture; the test pins the result to 4.
#[test]
fn test_percentile_disc() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::percentile_disc(0, 0.5)];
    let mut operator =
        SimpleAggregateOperator::new(Box::new(input), aggregates, vec![LogicalType::Float64]);

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let median = output.column(0).unwrap().get_float64(0).unwrap();
    let error = (median - 4.0).abs();
    assert!(error < 0.01);
}
1521
/// Continuous 50th percentile over the statistical fixture.
///
/// The expected 4.5 lies halfway between the two middle values of the sorted data
/// (4 and 5).
#[test]
fn test_percentile_cont() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![AggregateExpr::percentile_cont(0, 0.5)];
    let mut operator =
        SimpleAggregateOperator::new(Box::new(input), aggregates, vec![LogicalType::Float64]);

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let median = output.column(0).unwrap().get_float64(0).unwrap();
    let error = (median - 4.5).abs();
    assert!(error < 0.01);
}
1540
/// Discrete percentiles at the boundaries: 0.0 must yield the smallest fixture value (2)
/// and 1.0 the largest (9).
#[test]
fn test_percentile_extremes() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let aggregates = vec![
        AggregateExpr::percentile_disc(0, 0.0),
        AggregateExpr::percentile_disc(0, 1.0),
    ];
    let output_schema = vec![LogicalType::Float64, LogicalType::Float64];
    let mut operator = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let lowest = output.column(0).unwrap().get_float64(0).unwrap();
    assert!((lowest - 2.0).abs() < 0.01);

    let highest = output.column(1).unwrap().get_float64(0).unwrap();
    assert!((highest - 9.0).abs() < 0.01);
}
1564
/// Sample standard deviation of a single value: the test expects a `Null` result.
#[test]
fn test_stdev_single_value() {
    // One-row input containing only the value 42.
    let single_row = {
        let mut chunk_builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        chunk_builder.column_mut(0).unwrap().push_int64(42);
        chunk_builder.advance_row();
        chunk_builder.finish()
    };
    let input = MockOperator::new(vec![single_row]);

    let aggregates = vec![AggregateExpr::stdev(0)];
    let mut operator =
        SimpleAggregateOperator::new(Box::new(input), aggregates, vec![LogicalType::Float64]);

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let cell = output.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1589
/// `FIRST` and `LAST` over column 1 of the shared test chunk.
///
/// NOTE(review): the expected 10 and 50 assume `create_test_chunk` (defined elsewhere in
/// this module) starts with 10 and ends with 50 in column 1 — confirm against the helper.
#[test]
fn test_first_and_last() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let aggregates = vec![AggregateExpr::first(1), AggregateExpr::last(1)];
    let output_schema = vec![LogicalType::Int64, LogicalType::Int64];
    let mut operator = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let first_seen = output.column(0).unwrap().get_int64(0);
    assert_eq!(first_seen, Some(10));

    let last_seen = output.column(1).unwrap().get_int64(0);
    assert_eq!(last_seen, Some(50));
}
1606
/// `COLLECT` over column 1 must produce a list value with one entry per input row.
///
/// NOTE(review): the expected length of 5 assumes `create_test_chunk` (defined elsewhere
/// in this module) contains five rows — confirm against the helper.
#[test]
fn test_collect() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let aggregates = vec![AggregateExpr::collect(1)];
    let mut operator =
        SimpleAggregateOperator::new(Box::new(input), aggregates, vec![LogicalType::Any]);

    let output = operator.next().unwrap().unwrap();
    let collected = output.column(0).unwrap().get_value(0).unwrap();
    match collected {
        Value::List(items) => assert_eq!(items.len(), 5),
        _ => panic!("Expected List value"),
    }
}
1625
/// `COLLECT(DISTINCT col 1)` over the duplicate chunk must keep one entry per distinct
/// value: 10, 20 and 30, i.e. a list of length 3.
#[test]
fn test_collect_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let aggregates = vec![AggregateExpr::collect(1).with_distinct()];
    let mut operator =
        SimpleAggregateOperator::new(Box::new(input), aggregates, vec![LogicalType::Any]);

    let output = operator.next().unwrap().unwrap();
    let collected = output.column(0).unwrap().get_value(0).unwrap();
    match collected {
        Value::List(items) => assert_eq!(items.len(), 3),
        _ => panic!("Expected List value"),
    }
}
1645
/// `GROUP_CONCAT` without an explicit separator.
///
/// The assertion pins the behavior for `separator: None`: the three input strings come
/// back joined by single spaces.
#[test]
fn test_group_concat() {
    let words = {
        let mut chunk_builder = DataChunkBuilder::new(&[LogicalType::String]);
        for word in ["hello", "world", "foo"] {
            chunk_builder.column_mut(0).unwrap().push_string(word);
            chunk_builder.advance_row();
        }
        chunk_builder.finish()
    };
    let input = MockOperator::new(vec![words]);

    // Built as a struct literal; only `function` and `column` carry non-default intent.
    let concat_expr = AggregateExpr {
        function: AggregateFunction::GroupConcat,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut operator = SimpleAggregateOperator::new(
        Box::new(input),
        vec![concat_expr],
        vec![LogicalType::String],
    );

    let output = operator.next().unwrap().unwrap();
    let joined = output.column(0).unwrap().get_value(0).unwrap();
    assert_eq!(joined, Value::String("hello world foo".into()));
}
1673
/// `SAMPLE` over column 1 of the shared test chunk; the test pins the result to 10.
///
/// NOTE(review): presumably `Sample` yields the first value it encounters and
/// `create_test_chunk` starts with 10 in column 1 — confirm against the implementation.
#[test]
fn test_sample() {
    let input = MockOperator::new(vec![create_test_chunk()]);

    let sample_expr = AggregateExpr {
        function: AggregateFunction::Sample,
        column: Some(1),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut operator = SimpleAggregateOperator::new(
        Box::new(input),
        vec![sample_expr],
        vec![LogicalType::Int64],
    );

    let output = operator.next().unwrap().unwrap();
    let sampled = output.column(0).unwrap().get_int64(0);
    assert_eq!(sampled, Some(10));
}
1695
/// Sample variance over the statistical fixture: sum of squared deviations 32 divided by
/// (8 - 1) gives 32/7.
#[test]
fn test_variance_sample() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut operator = SimpleAggregateOperator::new(
        Box::new(input),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let output = operator.next().unwrap().unwrap();
    let sample_variance = output.column(0).unwrap().get_float64(0).unwrap();

    let expected: f64 = 32.0 / 7.0;
    assert!((sample_variance - expected).abs() < 0.01);
}
1721
/// Population variance over the statistical fixture: sum of squared deviations 32
/// divided by 8 gives 4.
#[test]
fn test_variance_population() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::VariancePop,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut operator = SimpleAggregateOperator::new(
        Box::new(input),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let output = operator.next().unwrap().unwrap();
    let population_variance = output.column(0).unwrap().get_float64(0).unwrap();

    let error = (population_variance - 4.0).abs();
    assert!(error < 0.01);
}
1747
/// Sample variance of a single value: the test expects a `Null` result.
#[test]
fn test_variance_single_value() {
    // One-row input containing only the value 42.
    let single_row = {
        let mut chunk_builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        chunk_builder.column_mut(0).unwrap().push_int64(42);
        chunk_builder.advance_row();
        chunk_builder.finish()
    };
    let input = MockOperator::new(vec![single_row]);

    let variance_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut operator = SimpleAggregateOperator::new(
        Box::new(input),
        vec![variance_expr],
        vec![LogicalType::Float64],
    );

    let output = operator.next().unwrap().unwrap();
    let cell = output.column(0).unwrap().get_value(0);
    assert!(matches!(cell, Some(Value::Null)));
}
1779
/// Ungrouped aggregation over an input that yields no chunks at all.
///
/// The test expects a result row regardless: `COUNT(*)` is 0, while `SUM`, `AVG`, `MIN`
/// and `MAX` are all `Null`.
#[test]
fn test_empty_aggregation() {
    let input = MockOperator::new(vec![]);

    let aggregates = vec![
        AggregateExpr::count_star(),
        AggregateExpr::sum(0),
        AggregateExpr::avg(0),
        AggregateExpr::min(0),
        AggregateExpr::max(0),
    ];
    let output_schema = vec![
        LogicalType::Int64,   // COUNT(*)
        LogicalType::Int64,   // SUM
        LogicalType::Float64, // AVG
        LogicalType::Int64,   // MIN
        LogicalType::Int64,   // MAX
    ];
    let mut operator = SimpleAggregateOperator::new(Box::new(input), aggregates, output_schema);

    let output = operator.next().unwrap().unwrap();

    let row_total = output.column(0).unwrap().get_int64(0);
    assert_eq!(row_total, Some(0));

    // Columns 1 through 4 hold SUM, AVG, MIN and MAX, each expected to be Null.
    for col in 1..=4 {
        let cell = output.column(col).unwrap().get_value(0);
        assert!(matches!(cell, Some(Value::Null)));
    }
}
1823
/// Population standard deviation of a single value: the test expects 0 rather than
/// `Null`.
#[test]
fn test_stdev_pop_single_value() {
    // One-row input containing only the value 42.
    let single_row = {
        let mut chunk_builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        chunk_builder.column_mut(0).unwrap().push_int64(42);
        chunk_builder.advance_row();
        chunk_builder.finish()
    };
    let input = MockOperator::new(vec![single_row]);

    let aggregates = vec![AggregateExpr::stdev_pop(0)];
    let mut operator =
        SimpleAggregateOperator::new(Box::new(input), aggregates, vec![LogicalType::Float64]);

    let output = operator.next().unwrap().unwrap();
    assert_eq!(output.row_count(), 1);

    let population_stdev = output.column(0).unwrap().get_float64(0).unwrap();
    assert!(population_stdev.abs() < 0.01);
}
1846
/// A boxed `HashAggregateOperator` converted with `into_any` must downcast back to its
/// concrete type.
#[test]
fn test_hash_aggregate_into_any() {
    let input = MockOperator::new(vec![]);
    let output_schema = vec![LogicalType::Int64, LogicalType::Int64];
    let operator = HashAggregateOperator::new(
        Box::new(input),
        vec![0],
        vec![AggregateExpr::count_star()],
        output_schema,
    );

    let erased = Box::new(operator).into_any();
    let recovered = erased.downcast::<HashAggregateOperator>();
    assert!(recovered.is_ok());
}
1859
/// A boxed `SimpleAggregateOperator` converted with `into_any` must downcast back to its
/// concrete type.
#[test]
fn test_simple_aggregate_into_any() {
    let input = MockOperator::new(vec![]);
    let aggregates = vec![AggregateExpr::count_star()];
    let operator =
        SimpleAggregateOperator::new(Box::new(input), aggregates, vec![LogicalType::Int64]);

    let erased = Box::new(operator).into_any();
    let recovered = erased.downcast::<SimpleAggregateOperator>();
    assert!(recovered.is_ok());
}
1871
/// `into_parts` must hand back the child operator, the group columns and the aggregate
/// expressions the operator was constructed with.
#[test]
fn test_hash_aggregate_into_parts() {
    let input = MockOperator::new(vec![]);
    let aggregates = vec![AggregateExpr::sum(1), AggregateExpr::count_star()];
    let output_schema = vec![LogicalType::Int64, LogicalType::Int64, LogicalType::Int64];
    let operator =
        HashAggregateOperator::new(Box::new(input), vec![0, 2], aggregates, output_schema);

    let (mut inner, keys, exprs) = operator.into_parts();
    assert_eq!(keys, vec![0, 2]);
    assert_eq!(exprs.len(), 2);
    // The mock was created without chunks, so the recovered child is exhausted at once.
    assert!(inner.next().unwrap().is_none());
}
1886}