1use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, PropertyKey, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
22#[derive(Debug, Clone)]
30#[allow(missing_docs)]
31pub enum AggregateState {
32 Count(i64),
34 CountDistinct(i64, HashSet<HashableValue>),
36 SumInt(i64, i64),
38 SumIntDistinct(i64, i64, HashSet<HashableValue>),
40 SumFloat(f64, f64, i64),
42 SumFloatDistinct(f64, f64, i64, HashSet<HashableValue>),
44 Avg(f64, i64),
46 AvgDistinct(f64, i64, HashSet<HashableValue>),
48 Min(Option<Value>),
50 Max(Option<Value>),
52 First(Option<Value>),
54 Last(Option<Value>),
56 Collect(Vec<Value>),
58 CollectDistinct(Vec<Value>, HashSet<HashableValue>),
60 StdDev { count: i64, mean: f64, m2: f64 },
62 StdDevPop { count: i64, mean: f64, m2: f64 },
64 PercentileDisc { values: Vec<f64>, percentile: f64 },
66 PercentileCont { values: Vec<f64>, percentile: f64 },
68 GroupConcat(Vec<String>, String),
70 GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
72 Sample(Option<Value>),
74 Variance { count: i64, mean: f64, m2: f64 },
76 VariancePop { count: i64, mean: f64, m2: f64 },
78 Bivariate {
80 kind: AggregateFunction,
82 count: i64,
83 mean_x: f64,
84 mean_y: f64,
85 m2_x: f64,
86 m2_y: f64,
87 c_xy: f64,
88 },
89 Frozen(Value),
93}
94
95impl AggregateState {
96 pub fn new(
98 function: AggregateFunction,
99 distinct: bool,
100 percentile: Option<f64>,
101 separator: Option<&str>,
102 ) -> Self {
103 match (function, distinct) {
104 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
105 AggregateState::Count(0)
106 }
107 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
108 AggregateState::CountDistinct(0, HashSet::new())
109 }
110 (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
111 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
112 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
113 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
114 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
116 (AggregateFunction::First, _) => AggregateState::First(None),
117 (AggregateFunction::Last, _) => AggregateState::Last(None),
118 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
119 (AggregateFunction::Collect, true) => {
120 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
121 }
122 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
124 count: 0,
125 mean: 0.0,
126 m2: 0.0,
127 },
128 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
129 count: 0,
130 mean: 0.0,
131 m2: 0.0,
132 },
133 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
134 values: Vec::new(),
135 percentile: percentile.unwrap_or(0.5),
136 },
137 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
138 values: Vec::new(),
139 percentile: percentile.unwrap_or(0.5),
140 },
141 (AggregateFunction::GroupConcat, false) => {
142 AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
143 }
144 (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
145 Vec::new(),
146 separator.unwrap_or(" ").to_string(),
147 HashSet::new(),
148 ),
149 (AggregateFunction::Sample, _) => AggregateState::Sample(None),
150 (
152 AggregateFunction::CovarSamp
153 | AggregateFunction::CovarPop
154 | AggregateFunction::Corr
155 | AggregateFunction::RegrSlope
156 | AggregateFunction::RegrIntercept
157 | AggregateFunction::RegrR2
158 | AggregateFunction::RegrCount
159 | AggregateFunction::RegrSxx
160 | AggregateFunction::RegrSyy
161 | AggregateFunction::RegrSxy
162 | AggregateFunction::RegrAvgx
163 | AggregateFunction::RegrAvgy,
164 _,
165 ) => AggregateState::Bivariate {
166 kind: function,
167 count: 0,
168 mean_x: 0.0,
169 mean_y: 0.0,
170 m2_x: 0.0,
171 m2_y: 0.0,
172 c_xy: 0.0,
173 },
174 (AggregateFunction::Variance, _) => AggregateState::Variance {
175 count: 0,
176 mean: 0.0,
177 m2: 0.0,
178 },
179 (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
180 count: 0,
181 mean: 0.0,
182 m2: 0.0,
183 },
184 }
185 }
186
187 pub fn update(&mut self, value: Option<Value>) {
192 match self {
193 AggregateState::Count(count) => {
194 *count += 1;
195 }
196 AggregateState::CountDistinct(count, seen) => {
197 if let Some(ref v) = value {
198 let hashable = HashableValue::from(v);
199 if seen.insert(hashable) {
200 *count += 1;
201 }
202 }
203 }
204 AggregateState::SumInt(sum, count) => {
205 if let Some(Value::Int64(v)) = value {
206 *sum += v;
207 *count += 1;
208 } else if let Some(Value::Float64(v)) = value {
209 *self = AggregateState::SumFloat(*sum as f64 + v, 0.0, *count + 1);
211 } else if let Some(ref v) = value {
212 if let Some(num) = value_to_f64(v) {
214 *self = AggregateState::SumFloat(*sum as f64 + num, 0.0, *count + 1);
215 }
216 }
217 }
218 AggregateState::SumIntDistinct(sum, count, seen) => {
219 if let Some(ref v) = value {
220 let hashable = HashableValue::from(v);
221 if seen.insert(hashable) {
222 if let Value::Int64(i) = v {
223 *sum += i;
224 *count += 1;
225 } else if let Value::Float64(f) = v {
226 let moved_seen = std::mem::take(seen);
228 *self = AggregateState::SumFloatDistinct(
229 *sum as f64 + f,
230 0.0,
231 *count + 1,
232 moved_seen,
233 );
234 } else if let Some(num) = value_to_f64(v) {
235 let moved_seen = std::mem::take(seen);
237 *self = AggregateState::SumFloatDistinct(
238 *sum as f64 + num,
239 0.0,
240 *count + 1,
241 moved_seen,
242 );
243 }
244 }
245 }
246 }
247 AggregateState::SumFloat(sum, comp, count) => {
248 if let Some(ref v) = value {
249 if let Some(num) = value_to_f64(v) {
251 let y = num - *comp;
253 let t = *sum + y;
254 *comp = (t - *sum) - y;
255 *sum = t;
256 *count += 1;
257 }
258 }
259 }
260 AggregateState::SumFloatDistinct(sum, comp, count, seen) => {
261 if let Some(ref v) = value {
262 let hashable = HashableValue::from(v);
263 if seen.insert(hashable)
264 && let Some(num) = value_to_f64(v)
265 {
266 let y = num - *comp;
267 let t = *sum + y;
268 *comp = (t - *sum) - y;
269 *sum = t;
270 *count += 1;
271 }
272 }
273 }
274 AggregateState::Avg(sum, count) => {
275 if let Some(ref v) = value
276 && let Some(num) = value_to_f64(v)
277 {
278 *sum += num;
279 *count += 1;
280 }
281 }
282 AggregateState::AvgDistinct(sum, count, seen) => {
283 if let Some(ref v) = value {
284 let hashable = HashableValue::from(v);
285 if seen.insert(hashable)
286 && let Some(num) = value_to_f64(v)
287 {
288 *sum += num;
289 *count += 1;
290 }
291 }
292 }
293 AggregateState::Min(min) => {
294 if let Some(v) = value {
295 match min {
296 None => *min = Some(v),
297 Some(current) => {
298 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
299 *min = Some(v);
300 }
301 }
302 }
303 }
304 }
305 AggregateState::Max(max) => {
306 if let Some(v) = value {
307 match max {
308 None => *max = Some(v),
309 Some(current) => {
310 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
311 *max = Some(v);
312 }
313 }
314 }
315 }
316 }
317 AggregateState::First(first) => {
318 if first.is_none() {
319 *first = value;
320 }
321 }
322 AggregateState::Last(last) => {
323 if value.is_some() {
324 *last = value;
325 }
326 }
327 AggregateState::Collect(list) => {
328 if let Some(v) = value {
329 list.push(v);
330 }
331 }
332 AggregateState::CollectDistinct(list, seen) => {
333 if let Some(v) = value {
334 let hashable = HashableValue::from(&v);
335 if seen.insert(hashable) {
336 list.push(v);
337 }
338 }
339 }
340 AggregateState::StdDev { count, mean, m2 }
342 | AggregateState::StdDevPop { count, mean, m2 }
343 | AggregateState::Variance { count, mean, m2 }
344 | AggregateState::VariancePop { count, mean, m2 } => {
345 if let Some(ref v) = value
346 && let Some(x) = value_to_f64(v)
347 {
348 *count += 1;
349 let delta = x - *mean;
350 *mean += delta / *count as f64;
351 let delta2 = x - *mean;
352 *m2 += delta * delta2;
353 }
354 }
355 AggregateState::PercentileDisc { values, .. }
356 | AggregateState::PercentileCont { values, .. } => {
357 if let Some(ref v) = value
358 && let Some(x) = value_to_f64(v)
359 {
360 values.push(x);
361 }
362 }
363 AggregateState::GroupConcat(list, _sep) => {
364 if let Some(v) = value {
365 list.push(agg_value_to_string(&v));
366 }
367 }
368 AggregateState::GroupConcatDistinct(list, _sep, seen) => {
369 if let Some(v) = value {
370 let hashable = HashableValue::from(&v);
371 if seen.insert(hashable) {
372 list.push(agg_value_to_string(&v));
373 }
374 }
375 }
376 AggregateState::Sample(sample) => {
377 if sample.is_none() {
378 *sample = value;
379 }
380 }
381 AggregateState::Bivariate { .. } => {
382 }
385 AggregateState::Frozen(_) => {}
386 }
387 }
388
389 pub fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
394 if let AggregateState::Bivariate {
395 count,
396 mean_x,
397 mean_y,
398 m2_x,
399 m2_y,
400 c_xy,
401 ..
402 } = self
403 {
404 if let (Some(y), Some(x)) = (&y_val, &x_val)
406 && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
407 {
408 *count += 1;
409 let n = *count as f64;
410 let dx = x_f - *mean_x;
411 let dy = y_f - *mean_y;
412 *mean_x += dx / n;
413 *mean_y += dy / n;
414 let dx2 = x_f - *mean_x; let dy2 = y_f - *mean_y; *m2_x += dx * dx2;
417 *m2_y += dy * dy2;
418 *c_xy += dx * dy2;
419 }
420 }
421 }
422
423 pub fn finalize(&self) -> Value {
425 match self {
426 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
427 Value::Int64(*count)
428 }
429 AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
430 if *count == 0 {
431 Value::Null
432 } else {
433 Value::Int64(*sum)
434 }
435 }
436 AggregateState::SumFloat(sum, _, count)
437 | AggregateState::SumFloatDistinct(sum, _, count, _) => {
438 if *count == 0 {
439 Value::Null
440 } else {
441 Value::Float64(*sum)
442 }
443 }
444 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
445 if *count == 0 {
446 Value::Null
447 } else {
448 Value::Float64(*sum / *count as f64)
449 }
450 }
451 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
452 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
453 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
454 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
455 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
456 Value::List(list.clone().into())
457 }
458 AggregateState::StdDev { count, m2, .. } => {
460 if *count < 2 {
461 Value::Null
462 } else {
463 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
464 }
465 }
466 AggregateState::StdDevPop { count, m2, .. } => {
468 if *count == 0 {
469 Value::Null
470 } else {
471 Value::Float64((*m2 / *count as f64).sqrt())
472 }
473 }
474 AggregateState::Variance { count, m2, .. } => {
476 if *count < 2 {
477 Value::Null
478 } else {
479 Value::Float64(*m2 / (*count - 1) as f64)
480 }
481 }
482 AggregateState::VariancePop { count, m2, .. } => {
484 if *count == 0 {
485 Value::Null
486 } else {
487 Value::Float64(*m2 / *count as f64)
488 }
489 }
490 AggregateState::PercentileDisc { values, percentile } => {
492 if values.is_empty() {
493 Value::Null
494 } else {
495 let mut sorted = values.clone();
496 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
497 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
499 Value::Float64(sorted[index])
500 }
501 }
502 AggregateState::PercentileCont { values, percentile } => {
504 if values.is_empty() {
505 Value::Null
506 } else {
507 let mut sorted = values.clone();
508 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
509 let rank = percentile * (sorted.len() - 1) as f64;
511 let lower_idx = rank.floor() as usize;
512 let upper_idx = rank.ceil() as usize;
513 if lower_idx == upper_idx {
514 Value::Float64(sorted[lower_idx])
515 } else {
516 let fraction = rank - lower_idx as f64;
517 let result =
518 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
519 Value::Float64(result)
520 }
521 }
522 }
523 AggregateState::GroupConcat(list, sep)
525 | AggregateState::GroupConcatDistinct(list, sep, _) => {
526 Value::String(list.join(sep).into())
527 }
528 AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
530 AggregateState::Frozen(val) => val.clone(),
531 AggregateState::Bivariate {
533 kind,
534 count,
535 mean_x,
536 mean_y,
537 m2_x,
538 m2_y,
539 c_xy,
540 } => {
541 let n = *count;
542 match kind {
543 AggregateFunction::CovarSamp => {
544 if n < 2 {
545 Value::Null
546 } else {
547 Value::Float64(*c_xy / (n - 1) as f64)
548 }
549 }
550 AggregateFunction::CovarPop => {
551 if n == 0 {
552 Value::Null
553 } else {
554 Value::Float64(*c_xy / n as f64)
555 }
556 }
557 AggregateFunction::Corr => {
558 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
559 Value::Null
560 } else {
561 Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
562 }
563 }
564 AggregateFunction::RegrSlope => {
565 if n == 0 || *m2_x == 0.0 {
566 Value::Null
567 } else {
568 Value::Float64(*c_xy / *m2_x)
569 }
570 }
571 AggregateFunction::RegrIntercept => {
572 if n == 0 || *m2_x == 0.0 {
573 Value::Null
574 } else {
575 let slope = *c_xy / *m2_x;
576 Value::Float64(*mean_y - slope * *mean_x)
577 }
578 }
579 AggregateFunction::RegrR2 => {
580 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
581 Value::Null
582 } else {
583 Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
584 }
585 }
586 AggregateFunction::RegrCount => Value::Int64(n),
587 AggregateFunction::RegrSxx => {
588 if n == 0 {
589 Value::Null
590 } else {
591 Value::Float64(*m2_x)
592 }
593 }
594 AggregateFunction::RegrSyy => {
595 if n == 0 {
596 Value::Null
597 } else {
598 Value::Float64(*m2_y)
599 }
600 }
601 AggregateFunction::RegrSxy => {
602 if n == 0 {
603 Value::Null
604 } else {
605 Value::Float64(*c_xy)
606 }
607 }
608 AggregateFunction::RegrAvgx => {
609 if n == 0 {
610 Value::Null
611 } else {
612 Value::Float64(*mean_x)
613 }
614 }
615 AggregateFunction::RegrAvgy => {
616 if n == 0 {
617 Value::Null
618 } else {
619 Value::Float64(*mean_y)
620 }
621 }
622 _ => Value::Null, }
624 }
625 }
626 }
627}
628
629use super::value_utils::{compare_values, value_to_f64};
630
631fn agg_value_to_string(val: &Value) -> String {
633 match val {
634 Value::String(s) => s.to_string(),
635 Value::Int64(i) => i.to_string(),
636 Value::Float64(f) => f.to_string(),
637 Value::Bool(b) => b.to_string(),
638 Value::Null => String::new(),
639 other => format!("{other:?}"),
640 }
641}
642
643#[derive(Debug, Clone, PartialEq, Eq, Hash)]
645pub struct GroupKey(Vec<GroupKeyPart>);
646
647#[derive(Debug, Clone, PartialEq, Eq, Hash)]
648enum GroupKeyPart {
649 Null,
650 Bool(bool),
651 Int64(i64),
652 String(ArcStr),
653 Bytes(Arc<[u8]>),
654 Date(grafeo_common::types::Date),
655 Time(grafeo_common::types::Time),
656 Timestamp(grafeo_common::types::Timestamp),
657 Duration(grafeo_common::types::Duration),
658 ZonedDatetime(grafeo_common::types::ZonedDatetime),
659 List(Vec<GroupKeyPart>),
660 Map(Vec<(ArcStr, GroupKeyPart)>),
661}
662
663impl GroupKeyPart {
664 fn from_value(v: Value) -> Self {
665 match v {
666 Value::Null => Self::Null,
667 Value::Bool(b) => Self::Bool(b),
668 Value::Int64(i) => Self::Int64(i),
669 Value::Float64(f) => Self::Int64(f.to_bits() as i64),
670 Value::String(s) => Self::String(s.clone()),
671 Value::Bytes(b) => Self::Bytes(b),
672 Value::Date(d) => Self::Date(d),
673 Value::Time(t) => Self::Time(t),
674 Value::Timestamp(ts) => Self::Timestamp(ts),
675 Value::Duration(d) => Self::Duration(d),
676 Value::ZonedDatetime(zdt) => Self::ZonedDatetime(zdt),
677 Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
678 Value::Map(map) => {
679 let entries: Vec<(ArcStr, GroupKeyPart)> = map
681 .iter()
682 .map(|(k, v)| (ArcStr::from(k.as_str()), Self::from_value(v.clone())))
683 .collect();
684 Self::Map(entries)
685 }
686 other => Self::String(ArcStr::from(format!("{other:?}"))),
688 }
689 }
690
691 fn to_value(&self) -> Value {
692 match self {
693 Self::Null => Value::Null,
694 Self::Bool(b) => Value::Bool(*b),
695 Self::Int64(i) => Value::Int64(*i),
696 Self::String(s) => Value::String(s.clone()),
697 Self::Bytes(b) => Value::Bytes(Arc::clone(b)),
698 Self::Date(d) => Value::Date(*d),
699 Self::Time(t) => Value::Time(*t),
700 Self::Timestamp(ts) => Value::Timestamp(*ts),
701 Self::Duration(d) => Value::Duration(*d),
702 Self::ZonedDatetime(zdt) => Value::ZonedDatetime(*zdt),
703 Self::List(parts) => {
704 let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
705 Value::List(Arc::from(values.into_boxed_slice()))
706 }
707 Self::Map(entries) => {
708 let map: std::collections::BTreeMap<PropertyKey, Value> = entries
709 .iter()
710 .map(|(k, v)| (PropertyKey::new(k.as_str()), v.to_value()))
711 .collect();
712 Value::Map(Arc::new(map))
713 }
714 }
715 }
716}
717
718impl GroupKey {
719 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
721 let parts: Vec<GroupKeyPart> = group_columns
722 .iter()
723 .map(|&col_idx| {
724 chunk
725 .column(col_idx)
726 .and_then(|col| col.get_value(row))
727 .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
728 })
729 .collect();
730 GroupKey(parts)
731 }
732
733 fn to_values(&self) -> Vec<Value> {
735 self.0.iter().map(GroupKeyPart::to_value).collect()
736 }
737}
738
739pub struct HashAggregateOperator {
743 child: Box<dyn Operator>,
745 group_columns: Vec<usize>,
747 aggregates: Vec<AggregateExpr>,
749 output_schema: Vec<LogicalType>,
751 groups: IndexMap<GroupKey, Vec<AggregateState>>,
753 aggregation_complete: bool,
755 results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
757}
758
759impl HashAggregateOperator {
760 pub fn new(
768 child: Box<dyn Operator>,
769 group_columns: Vec<usize>,
770 aggregates: Vec<AggregateExpr>,
771 output_schema: Vec<LogicalType>,
772 ) -> Self {
773 Self {
774 child,
775 group_columns,
776 aggregates,
777 output_schema,
778 groups: IndexMap::new(),
779 aggregation_complete: false,
780 results: None,
781 }
782 }
783
784 fn aggregate(&mut self) -> Result<(), OperatorError> {
786 while let Some(chunk) = self.child.next()? {
787 for row in chunk.selected_indices() {
788 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
789
790 let states = self.groups.entry(key).or_insert_with(|| {
792 self.aggregates
793 .iter()
794 .map(|agg| {
795 AggregateState::new(
796 agg.function,
797 agg.distinct,
798 agg.percentile,
799 agg.separator.as_deref(),
800 )
801 })
802 .collect()
803 });
804
805 for (i, agg) in self.aggregates.iter().enumerate() {
807 if agg.column2.is_some() {
809 let y_val = agg
810 .column
811 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
812 let x_val = agg
813 .column2
814 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
815 states[i].update_bivariate(y_val, x_val);
816 continue;
817 }
818
819 let value = match (agg.function, agg.distinct) {
820 (AggregateFunction::Count, false) => None,
822 (AggregateFunction::Count, true) => agg
824 .column
825 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
826 _ => agg
827 .column
828 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
829 };
830
831 match (agg.function, agg.distinct) {
833 (AggregateFunction::Count, false) => states[i].update(None),
834 (AggregateFunction::Count, true) => {
835 if value.is_some() && !matches!(value, Some(Value::Null)) {
837 states[i].update(value);
838 }
839 }
840 (AggregateFunction::CountNonNull, _) => {
841 if value.is_some() && !matches!(value, Some(Value::Null)) {
842 states[i].update(value);
843 }
844 }
845 _ => {
846 if value.is_some() && !matches!(value, Some(Value::Null)) {
847 states[i].update(value);
848 }
849 }
850 }
851 }
852 }
853 }
854
855 self.aggregation_complete = true;
856
857 let results: Vec<_> = self.groups.drain(..).collect();
859 self.results = Some(results.into_iter());
860
861 Ok(())
862 }
863}
864
865impl Operator for HashAggregateOperator {
866 fn next(&mut self) -> OperatorResult {
867 if !self.aggregation_complete {
869 self.aggregate()?;
870 }
871
872 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
874 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
876
877 for agg in &self.aggregates {
878 let state = AggregateState::new(
879 agg.function,
880 agg.distinct,
881 agg.percentile,
882 agg.separator.as_deref(),
883 );
884 let value = state.finalize();
885 if let Some(col) = builder.column_mut(self.group_columns.len()) {
886 col.push_value(value);
887 }
888 }
889 builder.advance_row();
890
891 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
893 }
894
895 let Some(results) = &mut self.results else {
896 return Ok(None);
897 };
898
899 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
900
901 for (key, states) in results.by_ref() {
902 let key_values = key.to_values();
904 for (i, value) in key_values.into_iter().enumerate() {
905 if let Some(col) = builder.column_mut(i) {
906 col.push_value(value);
907 }
908 }
909
910 for (i, state) in states.iter().enumerate() {
912 let col_idx = self.group_columns.len() + i;
913 if let Some(col) = builder.column_mut(col_idx) {
914 col.push_value(state.finalize());
915 }
916 }
917
918 builder.advance_row();
919
920 if builder.is_full() {
921 return Ok(Some(builder.finish()));
922 }
923 }
924
925 if builder.row_count() > 0 {
926 Ok(Some(builder.finish()))
927 } else {
928 Ok(None)
929 }
930 }
931
932 fn reset(&mut self) {
933 self.child.reset();
934 self.groups.clear();
935 self.aggregation_complete = false;
936 self.results = None;
937 }
938
939 fn name(&self) -> &'static str {
940 "HashAggregate"
941 }
942}
943
944pub struct SimpleAggregateOperator {
948 child: Box<dyn Operator>,
950 aggregates: Vec<AggregateExpr>,
952 output_schema: Vec<LogicalType>,
954 states: Vec<AggregateState>,
956 done: bool,
958}
959
960impl SimpleAggregateOperator {
961 pub fn new(
963 child: Box<dyn Operator>,
964 aggregates: Vec<AggregateExpr>,
965 output_schema: Vec<LogicalType>,
966 ) -> Self {
967 let states = aggregates
968 .iter()
969 .map(|agg| {
970 AggregateState::new(
971 agg.function,
972 agg.distinct,
973 agg.percentile,
974 agg.separator.as_deref(),
975 )
976 })
977 .collect();
978
979 Self {
980 child,
981 aggregates,
982 output_schema,
983 states,
984 done: false,
985 }
986 }
987}
988
989impl Operator for SimpleAggregateOperator {
990 fn next(&mut self) -> OperatorResult {
991 if self.done {
992 return Ok(None);
993 }
994
995 while let Some(chunk) = self.child.next()? {
997 for row in chunk.selected_indices() {
998 for (i, agg) in self.aggregates.iter().enumerate() {
999 if agg.column2.is_some() {
1001 let y_val = agg
1002 .column
1003 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
1004 let x_val = agg
1005 .column2
1006 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
1007 self.states[i].update_bivariate(y_val, x_val);
1008 continue;
1009 }
1010
1011 let value = match (agg.function, agg.distinct) {
1012 (AggregateFunction::Count, false) => None,
1014 (AggregateFunction::Count, true) => agg
1016 .column
1017 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1018 _ => agg
1019 .column
1020 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1021 };
1022
1023 match (agg.function, agg.distinct) {
1024 (AggregateFunction::Count, false) => self.states[i].update(None),
1025 (AggregateFunction::Count, true) => {
1026 if value.is_some() && !matches!(value, Some(Value::Null)) {
1028 self.states[i].update(value);
1029 }
1030 }
1031 (AggregateFunction::CountNonNull, _) => {
1032 if value.is_some() && !matches!(value, Some(Value::Null)) {
1033 self.states[i].update(value);
1034 }
1035 }
1036 _ => {
1037 if value.is_some() && !matches!(value, Some(Value::Null)) {
1038 self.states[i].update(value);
1039 }
1040 }
1041 }
1042 }
1043 }
1044 }
1045
1046 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
1048
1049 for (i, state) in self.states.iter().enumerate() {
1050 if let Some(col) = builder.column_mut(i) {
1051 col.push_value(state.finalize());
1052 }
1053 }
1054 builder.advance_row();
1055
1056 self.done = true;
1057 Ok(Some(builder.finish()))
1058 }
1059
1060 fn reset(&mut self) {
1061 self.child.reset();
1062 self.states = self
1063 .aggregates
1064 .iter()
1065 .map(|agg| {
1066 AggregateState::new(
1067 agg.function,
1068 agg.distinct,
1069 agg.percentile,
1070 agg.separator.as_deref(),
1071 )
1072 })
1073 .collect();
1074 self.done = false;
1075 }
1076
1077 fn name(&self) -> &'static str {
1078 "SimpleAggregate"
1079 }
1080}
1081
1082#[cfg(test)]
1083mod tests {
1084 use super::*;
1085 use crate::execution::chunk::DataChunkBuilder;
1086
1087 struct MockOperator {
1088 chunks: Vec<DataChunk>,
1089 position: usize,
1090 }
1091
1092 impl MockOperator {
1093 fn new(chunks: Vec<DataChunk>) -> Self {
1094 Self {
1095 chunks,
1096 position: 0,
1097 }
1098 }
1099 }
1100
1101 impl Operator for MockOperator {
1102 fn next(&mut self) -> OperatorResult {
1103 if self.position < self.chunks.len() {
1104 let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1105 self.position += 1;
1106 Ok(Some(chunk))
1107 } else {
1108 Ok(None)
1109 }
1110 }
1111
1112 fn reset(&mut self) {
1113 self.position = 0;
1114 }
1115
1116 fn name(&self) -> &'static str {
1117 "Mock"
1118 }
1119 }
1120
1121 fn create_test_chunk() -> DataChunk {
1122 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1124
1125 let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1126 for (group, value) in data {
1127 builder.column_mut(0).unwrap().push_int64(group);
1128 builder.column_mut(1).unwrap().push_int64(value);
1129 builder.advance_row();
1130 }
1131
1132 builder.finish()
1133 }
1134
1135 #[test]
1136 fn test_simple_count() {
1137 let mock = MockOperator::new(vec![create_test_chunk()]);
1138
1139 let mut agg = SimpleAggregateOperator::new(
1140 Box::new(mock),
1141 vec![AggregateExpr::count_star()],
1142 vec![LogicalType::Int64],
1143 );
1144
1145 let result = agg.next().unwrap().unwrap();
1146 assert_eq!(result.row_count(), 1);
1147 assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1148
1149 assert!(agg.next().unwrap().is_none());
1151 }
1152
1153 #[test]
1154 fn test_simple_sum() {
1155 let mock = MockOperator::new(vec![create_test_chunk()]);
1156
1157 let mut agg = SimpleAggregateOperator::new(
1158 Box::new(mock),
1159 vec![AggregateExpr::sum(1)], vec![LogicalType::Int64],
1161 );
1162
1163 let result = agg.next().unwrap().unwrap();
1164 assert_eq!(result.row_count(), 1);
1165 assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1167 }
1168
1169 #[test]
1170 fn test_simple_avg() {
1171 let mock = MockOperator::new(vec![create_test_chunk()]);
1172
1173 let mut agg = SimpleAggregateOperator::new(
1174 Box::new(mock),
1175 vec![AggregateExpr::avg(1)],
1176 vec![LogicalType::Float64],
1177 );
1178
1179 let result = agg.next().unwrap().unwrap();
1180 assert_eq!(result.row_count(), 1);
1181 let avg = result.column(0).unwrap().get_float64(0).unwrap();
1183 assert!((avg - 30.0).abs() < 0.001);
1184 }
1185
1186 #[test]
1187 fn test_simple_min_max() {
1188 let mock = MockOperator::new(vec![create_test_chunk()]);
1189
1190 let mut agg = SimpleAggregateOperator::new(
1191 Box::new(mock),
1192 vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1193 vec![LogicalType::Int64, LogicalType::Int64],
1194 );
1195
1196 let result = agg.next().unwrap().unwrap();
1197 assert_eq!(result.row_count(), 1);
1198 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); }
1201
1202 #[test]
1203 fn test_sum_with_string_values() {
1204 let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1206 builder.column_mut(0).unwrap().push_string("30");
1207 builder.advance_row();
1208 builder.column_mut(0).unwrap().push_string("25");
1209 builder.advance_row();
1210 builder.column_mut(0).unwrap().push_string("35");
1211 builder.advance_row();
1212 let chunk = builder.finish();
1213
1214 let mock = MockOperator::new(vec![chunk]);
1215 let mut agg = SimpleAggregateOperator::new(
1216 Box::new(mock),
1217 vec![AggregateExpr::sum(0)],
1218 vec![LogicalType::Float64],
1219 );
1220
1221 let result = agg.next().unwrap().unwrap();
1222 assert_eq!(result.row_count(), 1);
1223 let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1225 assert!(
1226 (sum_val - 90.0).abs() < 0.001,
1227 "Expected 90.0, got {}",
1228 sum_val
1229 );
1230 }
1231
1232 #[test]
1233 fn test_grouped_aggregation() {
1234 let mock = MockOperator::new(vec![create_test_chunk()]);
1235
1236 let mut agg = HashAggregateOperator::new(
1238 Box::new(mock),
1239 vec![0], vec![AggregateExpr::sum(1)], vec![LogicalType::Int64, LogicalType::Int64],
1242 );
1243
1244 let mut results: Vec<(i64, i64)> = Vec::new();
1245 while let Some(chunk) = agg.next().unwrap() {
1246 for row in chunk.selected_indices() {
1247 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1248 let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1249 results.push((group, sum));
1250 }
1251 }
1252
1253 results.sort_by_key(|(g, _)| *g);
1254 assert_eq!(results.len(), 2);
1255 assert_eq!(results[0], (1, 30)); assert_eq!(results[1], (2, 120)); }
1258
1259 #[test]
1260 fn test_grouped_count() {
1261 let mock = MockOperator::new(vec![create_test_chunk()]);
1262
1263 let mut agg = HashAggregateOperator::new(
1265 Box::new(mock),
1266 vec![0],
1267 vec![AggregateExpr::count_star()],
1268 vec![LogicalType::Int64, LogicalType::Int64],
1269 );
1270
1271 let mut results: Vec<(i64, i64)> = Vec::new();
1272 while let Some(chunk) = agg.next().unwrap() {
1273 for row in chunk.selected_indices() {
1274 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1275 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1276 results.push((group, count));
1277 }
1278 }
1279
1280 results.sort_by_key(|(g, _)| *g);
1281 assert_eq!(results.len(), 2);
1282 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 3)); }
1285
1286 #[test]
1287 fn test_multiple_aggregates() {
1288 let mock = MockOperator::new(vec![create_test_chunk()]);
1289
1290 let mut agg = HashAggregateOperator::new(
1292 Box::new(mock),
1293 vec![0],
1294 vec![
1295 AggregateExpr::count_star(),
1296 AggregateExpr::sum(1),
1297 AggregateExpr::avg(1),
1298 ],
1299 vec![
1300 LogicalType::Int64, LogicalType::Int64, LogicalType::Int64, LogicalType::Float64, ],
1305 );
1306
1307 let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1308 while let Some(chunk) = agg.next().unwrap() {
1309 for row in chunk.selected_indices() {
1310 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1311 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1312 let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1313 let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1314 results.push((group, count, sum, avg));
1315 }
1316 }
1317
1318 results.sort_by_key(|(g, _, _, _)| *g);
1319 assert_eq!(results.len(), 2);
1320
1321 assert_eq!(results[0].0, 1);
1323 assert_eq!(results[0].1, 2);
1324 assert_eq!(results[0].2, 30);
1325 assert!((results[0].3 - 15.0).abs() < 0.001);
1326
1327 assert_eq!(results[1].0, 2);
1329 assert_eq!(results[1].1, 3);
1330 assert_eq!(results[1].2, 120);
1331 assert!((results[1].3 - 40.0).abs() < 0.001);
1332 }
1333
1334 fn create_test_chunk_with_duplicates() -> DataChunk {
1335 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1340
1341 let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1342 for (group, value) in data {
1343 builder.column_mut(0).unwrap().push_int64(group);
1344 builder.column_mut(1).unwrap().push_int64(value);
1345 builder.advance_row();
1346 }
1347
1348 builder.finish()
1349 }
1350
1351 #[test]
1352 fn test_count_distinct() {
1353 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1354
1355 let mut agg = SimpleAggregateOperator::new(
1357 Box::new(mock),
1358 vec![AggregateExpr::count(1).with_distinct()],
1359 vec![LogicalType::Int64],
1360 );
1361
1362 let result = agg.next().unwrap().unwrap();
1363 assert_eq!(result.row_count(), 1);
1364 assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1366 }
1367
1368 #[test]
1369 fn test_grouped_count_distinct() {
1370 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1371
1372 let mut agg = HashAggregateOperator::new(
1374 Box::new(mock),
1375 vec![0],
1376 vec![AggregateExpr::count(1).with_distinct()],
1377 vec![LogicalType::Int64, LogicalType::Int64],
1378 );
1379
1380 let mut results: Vec<(i64, i64)> = Vec::new();
1381 while let Some(chunk) = agg.next().unwrap() {
1382 for row in chunk.selected_indices() {
1383 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1384 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1385 results.push((group, count));
1386 }
1387 }
1388
1389 results.sort_by_key(|(g, _)| *g);
1390 assert_eq!(results.len(), 2);
1391 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 1)); }
1394
1395 #[test]
1396 fn test_sum_distinct() {
1397 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1398
1399 let mut agg = SimpleAggregateOperator::new(
1401 Box::new(mock),
1402 vec![AggregateExpr::sum(1).with_distinct()],
1403 vec![LogicalType::Int64],
1404 );
1405
1406 let result = agg.next().unwrap().unwrap();
1407 assert_eq!(result.row_count(), 1);
1408 assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1410 }
1411
1412 #[test]
1413 fn test_avg_distinct() {
1414 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1415
1416 let mut agg = SimpleAggregateOperator::new(
1418 Box::new(mock),
1419 vec![AggregateExpr::avg(1).with_distinct()],
1420 vec![LogicalType::Float64],
1421 );
1422
1423 let result = agg.next().unwrap().unwrap();
1424 assert_eq!(result.row_count(), 1);
1425 let avg = result.column(0).unwrap().get_float64(0).unwrap();
1427 assert!((avg - 20.0).abs() < 0.001);
1428 }
1429
1430 fn create_statistical_test_chunk() -> DataChunk {
1431 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1434
1435 for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1436 builder.column_mut(0).unwrap().push_int64(value);
1437 builder.advance_row();
1438 }
1439
1440 builder.finish()
1441 }
1442
1443 #[test]
1444 fn test_stdev_sample() {
1445 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1446
1447 let mut agg = SimpleAggregateOperator::new(
1448 Box::new(mock),
1449 vec![AggregateExpr::stdev(0)],
1450 vec![LogicalType::Float64],
1451 );
1452
1453 let result = agg.next().unwrap().unwrap();
1454 assert_eq!(result.row_count(), 1);
1455 let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1458 assert!((stdev - 2.138).abs() < 0.01);
1459 }
1460
1461 #[test]
1462 fn test_stdev_population() {
1463 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1464
1465 let mut agg = SimpleAggregateOperator::new(
1466 Box::new(mock),
1467 vec![AggregateExpr::stdev_pop(0)],
1468 vec![LogicalType::Float64],
1469 );
1470
1471 let result = agg.next().unwrap().unwrap();
1472 assert_eq!(result.row_count(), 1);
1473 let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1476 assert!((stdev - 2.0).abs() < 0.01);
1477 }
1478
1479 #[test]
1480 fn test_percentile_disc() {
1481 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1482
1483 let mut agg = SimpleAggregateOperator::new(
1485 Box::new(mock),
1486 vec![AggregateExpr::percentile_disc(0, 0.5)],
1487 vec![LogicalType::Float64],
1488 );
1489
1490 let result = agg.next().unwrap().unwrap();
1491 assert_eq!(result.row_count(), 1);
1492 let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1494 assert!((percentile - 4.0).abs() < 0.01);
1495 }
1496
1497 #[test]
1498 fn test_percentile_cont() {
1499 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1500
1501 let mut agg = SimpleAggregateOperator::new(
1503 Box::new(mock),
1504 vec![AggregateExpr::percentile_cont(0, 0.5)],
1505 vec![LogicalType::Float64],
1506 );
1507
1508 let result = agg.next().unwrap().unwrap();
1509 assert_eq!(result.row_count(), 1);
1510 let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1513 assert!((percentile - 4.5).abs() < 0.01);
1514 }
1515
1516 #[test]
1517 fn test_percentile_extremes() {
1518 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1520
1521 let mut agg = SimpleAggregateOperator::new(
1522 Box::new(mock),
1523 vec![
1524 AggregateExpr::percentile_disc(0, 0.0),
1525 AggregateExpr::percentile_disc(0, 1.0),
1526 ],
1527 vec![LogicalType::Float64, LogicalType::Float64],
1528 );
1529
1530 let result = agg.next().unwrap().unwrap();
1531 assert_eq!(result.row_count(), 1);
1532 let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1534 assert!((p0 - 2.0).abs() < 0.01);
1535 let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1537 assert!((p100 - 9.0).abs() < 0.01);
1538 }
1539
1540 #[test]
1541 fn test_stdev_single_value() {
1542 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1544 builder.column_mut(0).unwrap().push_int64(42);
1545 builder.advance_row();
1546 let chunk = builder.finish();
1547
1548 let mock = MockOperator::new(vec![chunk]);
1549
1550 let mut agg = SimpleAggregateOperator::new(
1551 Box::new(mock),
1552 vec![AggregateExpr::stdev(0)],
1553 vec![LogicalType::Float64],
1554 );
1555
1556 let result = agg.next().unwrap().unwrap();
1557 assert_eq!(result.row_count(), 1);
1558 assert!(matches!(
1560 result.column(0).unwrap().get_value(0),
1561 Some(Value::Null)
1562 ));
1563 }
1564
1565 #[test]
1566 fn test_first_and_last() {
1567 let mock = MockOperator::new(vec![create_test_chunk()]);
1568
1569 let mut agg = SimpleAggregateOperator::new(
1570 Box::new(mock),
1571 vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1572 vec![LogicalType::Int64, LogicalType::Int64],
1573 );
1574
1575 let result = agg.next().unwrap().unwrap();
1576 assert_eq!(result.row_count(), 1);
1577 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1579 assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1580 }
1581
1582 #[test]
1583 fn test_collect() {
1584 let mock = MockOperator::new(vec![create_test_chunk()]);
1585
1586 let mut agg = SimpleAggregateOperator::new(
1587 Box::new(mock),
1588 vec![AggregateExpr::collect(1)],
1589 vec![LogicalType::Any],
1590 );
1591
1592 let result = agg.next().unwrap().unwrap();
1593 let val = result.column(0).unwrap().get_value(0).unwrap();
1594 if let Value::List(items) = val {
1595 assert_eq!(items.len(), 5);
1596 } else {
1597 panic!("Expected List value");
1598 }
1599 }
1600
1601 #[test]
1602 fn test_collect_distinct() {
1603 let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1604
1605 let mut agg = SimpleAggregateOperator::new(
1606 Box::new(mock),
1607 vec![AggregateExpr::collect(1).with_distinct()],
1608 vec![LogicalType::Any],
1609 );
1610
1611 let result = agg.next().unwrap().unwrap();
1612 let val = result.column(0).unwrap().get_value(0).unwrap();
1613 if let Value::List(items) = val {
1614 assert_eq!(items.len(), 3);
1616 } else {
1617 panic!("Expected List value");
1618 }
1619 }
1620
1621 #[test]
1622 fn test_group_concat() {
1623 let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1624 for s in ["hello", "world", "foo"] {
1625 builder.column_mut(0).unwrap().push_string(s);
1626 builder.advance_row();
1627 }
1628 let chunk = builder.finish();
1629 let mock = MockOperator::new(vec![chunk]);
1630
1631 let agg_expr = AggregateExpr {
1632 function: AggregateFunction::GroupConcat,
1633 column: Some(0),
1634 column2: None,
1635 distinct: false,
1636 alias: None,
1637 percentile: None,
1638 separator: None,
1639 };
1640
1641 let mut agg =
1642 SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1643
1644 let result = agg.next().unwrap().unwrap();
1645 let val = result.column(0).unwrap().get_value(0).unwrap();
1646 assert_eq!(val, Value::String("hello world foo".into()));
1647 }
1648
1649 #[test]
1650 fn test_sample() {
1651 let mock = MockOperator::new(vec![create_test_chunk()]);
1652
1653 let agg_expr = AggregateExpr {
1654 function: AggregateFunction::Sample,
1655 column: Some(1),
1656 column2: None,
1657 distinct: false,
1658 alias: None,
1659 percentile: None,
1660 separator: None,
1661 };
1662
1663 let mut agg =
1664 SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1665
1666 let result = agg.next().unwrap().unwrap();
1667 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1669 }
1670
1671 #[test]
1672 fn test_variance_sample() {
1673 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1674
1675 let agg_expr = AggregateExpr {
1676 function: AggregateFunction::Variance,
1677 column: Some(0),
1678 column2: None,
1679 distinct: false,
1680 alias: None,
1681 percentile: None,
1682 separator: None,
1683 };
1684
1685 let mut agg = SimpleAggregateOperator::new(
1686 Box::new(mock),
1687 vec![agg_expr],
1688 vec![LogicalType::Float64],
1689 );
1690
1691 let result = agg.next().unwrap().unwrap();
1692 let variance = result.column(0).unwrap().get_float64(0).unwrap();
1694 assert!((variance - 32.0 / 7.0).abs() < 0.01);
1695 }
1696
1697 #[test]
1698 fn test_variance_population() {
1699 let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1700
1701 let agg_expr = AggregateExpr {
1702 function: AggregateFunction::VariancePop,
1703 column: Some(0),
1704 column2: None,
1705 distinct: false,
1706 alias: None,
1707 percentile: None,
1708 separator: None,
1709 };
1710
1711 let mut agg = SimpleAggregateOperator::new(
1712 Box::new(mock),
1713 vec![agg_expr],
1714 vec![LogicalType::Float64],
1715 );
1716
1717 let result = agg.next().unwrap().unwrap();
1718 let variance = result.column(0).unwrap().get_float64(0).unwrap();
1720 assert!((variance - 4.0).abs() < 0.01);
1721 }
1722
1723 #[test]
1724 fn test_variance_single_value() {
1725 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1726 builder.column_mut(0).unwrap().push_int64(42);
1727 builder.advance_row();
1728 let chunk = builder.finish();
1729 let mock = MockOperator::new(vec![chunk]);
1730
1731 let agg_expr = AggregateExpr {
1732 function: AggregateFunction::Variance,
1733 column: Some(0),
1734 column2: None,
1735 distinct: false,
1736 alias: None,
1737 percentile: None,
1738 separator: None,
1739 };
1740
1741 let mut agg = SimpleAggregateOperator::new(
1742 Box::new(mock),
1743 vec![agg_expr],
1744 vec![LogicalType::Float64],
1745 );
1746
1747 let result = agg.next().unwrap().unwrap();
1748 assert!(matches!(
1750 result.column(0).unwrap().get_value(0),
1751 Some(Value::Null)
1752 ));
1753 }
1754
1755 #[test]
1756 fn test_empty_aggregation() {
1757 let mock = MockOperator::new(vec![]);
1760
1761 let mut agg = SimpleAggregateOperator::new(
1762 Box::new(mock),
1763 vec![
1764 AggregateExpr::count_star(),
1765 AggregateExpr::sum(0),
1766 AggregateExpr::avg(0),
1767 AggregateExpr::min(0),
1768 AggregateExpr::max(0),
1769 ],
1770 vec![
1771 LogicalType::Int64,
1772 LogicalType::Int64,
1773 LogicalType::Float64,
1774 LogicalType::Int64,
1775 LogicalType::Int64,
1776 ],
1777 );
1778
1779 let result = agg.next().unwrap().unwrap();
1780 assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); assert!(matches!(
1782 result.column(1).unwrap().get_value(0),
1783 Some(Value::Null)
1784 )); assert!(matches!(
1786 result.column(2).unwrap().get_value(0),
1787 Some(Value::Null)
1788 )); assert!(matches!(
1790 result.column(3).unwrap().get_value(0),
1791 Some(Value::Null)
1792 )); assert!(matches!(
1794 result.column(4).unwrap().get_value(0),
1795 Some(Value::Null)
1796 )); }
1798
1799 #[test]
1800 fn test_stdev_pop_single_value() {
1801 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1803 builder.column_mut(0).unwrap().push_int64(42);
1804 builder.advance_row();
1805 let chunk = builder.finish();
1806
1807 let mock = MockOperator::new(vec![chunk]);
1808
1809 let mut agg = SimpleAggregateOperator::new(
1810 Box::new(mock),
1811 vec![AggregateExpr::stdev_pop(0)],
1812 vec![LogicalType::Float64],
1813 );
1814
1815 let result = agg.next().unwrap().unwrap();
1816 assert_eq!(result.row_count(), 1);
1817 let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1819 assert!((stdev - 0.0).abs() < 0.01);
1820 }
1821}