1use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, PropertyKey, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
/// Running state of one aggregate expression for one group.
///
/// `new` builds the empty state, `update` / `update_bivariate` fold input values
/// in, and `finalize` turns the state into the output [`Value`]. The `*Distinct`
/// variants additionally remember every value already seen so repeats are skipped.
#[derive(Debug, Clone)]
pub(crate) enum AggregateState {
    /// `COUNT`: number of values (or rows) counted so far.
    Count(i64),
    /// `COUNT(DISTINCT ..)`: (count, values already seen).
    CountDistinct(i64, HashSet<HashableValue>),
    /// Integer `SUM`: (sum, number of values added).
    SumInt(i64, i64),
    /// Integer `SUM(DISTINCT ..)`: (sum, number of values added, values already seen).
    SumIntDistinct(i64, i64, HashSet<HashableValue>),
    /// Floating-point `SUM` with Kahan compensation: (sum, compensation, number of values added).
    SumFloat(f64, f64, i64),
    /// Floating-point `SUM(DISTINCT ..)`: (sum, compensation, number of values added, values already seen).
    SumFloatDistinct(f64, f64, i64, HashSet<HashableValue>),
    /// `AVG`: (running sum, number of values added).
    Avg(f64, i64),
    /// `AVG(DISTINCT ..)`: (running sum, number of values added, values already seen).
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// `MIN`: smallest value so far; `None` until the first value arrives.
    Min(Option<Value>),
    /// `MAX`: largest value so far; `None` until the first value arrives.
    Max(Option<Value>),
    /// First value received.
    First(Option<Value>),
    /// Most recently received value.
    Last(Option<Value>),
    /// `COLLECT`: every value, in arrival order.
    Collect(Vec<Value>),
    /// `COLLECT(DISTINCT ..)`: first occurrence of each value, plus the seen set.
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Sample standard deviation (Welford accumulators: count, running mean, sum of squared deviations).
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Population standard deviation (same accumulators as `StdDev`).
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Discrete percentile: numeric values collected so far and the requested fraction.
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Continuous (linearly interpolated) percentile.
    PercentileCont { values: Vec<f64>, percentile: f64 },
    /// `GROUP_CONCAT`: rendered values and the separator used to join them.
    GroupConcat(Vec<String>, String),
    /// `GROUP_CONCAT(DISTINCT ..)`: rendered values, separator, values already seen.
    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
    /// `SAMPLE`: the first value received.
    Sample(Option<Value>),
    /// Sample variance (Welford accumulators).
    Variance { count: i64, mean: f64, m2: f64 },
    /// Population variance (Welford accumulators).
    VariancePop { count: i64, mean: f64, m2: f64 },
    /// Two-argument statistics (covariance, correlation, `REGR_*`) computed online.
    Bivariate {
        /// Which bivariate aggregate `finalize` should produce.
        kind: AggregateFunction,
        /// Number of (y, x) pairs where both sides were numeric.
        count: i64,
        mean_x: f64,
        mean_y: f64,
        /// Sum of squared deviations of x.
        m2_x: f64,
        /// Sum of squared deviations of y.
        m2_y: f64,
        /// Sum of products of x and y deviations (co-moment).
        c_xy: f64,
    },
}
83
84impl AggregateState {
85 pub(crate) fn new(
87 function: AggregateFunction,
88 distinct: bool,
89 percentile: Option<f64>,
90 separator: Option<&str>,
91 ) -> Self {
92 match (function, distinct) {
93 (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
94 AggregateState::Count(0)
95 }
96 (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
97 AggregateState::CountDistinct(0, HashSet::new())
98 }
99 (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
100 (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
101 (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
102 (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
103 (AggregateFunction::Min, _) => AggregateState::Min(None), (AggregateFunction::Max, _) => AggregateState::Max(None),
105 (AggregateFunction::First, _) => AggregateState::First(None),
106 (AggregateFunction::Last, _) => AggregateState::Last(None),
107 (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
108 (AggregateFunction::Collect, true) => {
109 AggregateState::CollectDistinct(Vec::new(), HashSet::new())
110 }
111 (AggregateFunction::StdDev, _) => AggregateState::StdDev {
113 count: 0,
114 mean: 0.0,
115 m2: 0.0,
116 },
117 (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
118 count: 0,
119 mean: 0.0,
120 m2: 0.0,
121 },
122 (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
123 values: Vec::new(),
124 percentile: percentile.unwrap_or(0.5),
125 },
126 (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
127 values: Vec::new(),
128 percentile: percentile.unwrap_or(0.5),
129 },
130 (AggregateFunction::GroupConcat, false) => {
131 AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
132 }
133 (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
134 Vec::new(),
135 separator.unwrap_or(" ").to_string(),
136 HashSet::new(),
137 ),
138 (AggregateFunction::Sample, _) => AggregateState::Sample(None),
139 (
141 AggregateFunction::CovarSamp
142 | AggregateFunction::CovarPop
143 | AggregateFunction::Corr
144 | AggregateFunction::RegrSlope
145 | AggregateFunction::RegrIntercept
146 | AggregateFunction::RegrR2
147 | AggregateFunction::RegrCount
148 | AggregateFunction::RegrSxx
149 | AggregateFunction::RegrSyy
150 | AggregateFunction::RegrSxy
151 | AggregateFunction::RegrAvgx
152 | AggregateFunction::RegrAvgy,
153 _,
154 ) => AggregateState::Bivariate {
155 kind: function,
156 count: 0,
157 mean_x: 0.0,
158 mean_y: 0.0,
159 m2_x: 0.0,
160 m2_y: 0.0,
161 c_xy: 0.0,
162 },
163 (AggregateFunction::Variance, _) => AggregateState::Variance {
164 count: 0,
165 mean: 0.0,
166 m2: 0.0,
167 },
168 (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
169 count: 0,
170 mean: 0.0,
171 m2: 0.0,
172 },
173 }
174 }
175
176 pub(crate) fn update(&mut self, value: Option<Value>) {
178 match self {
179 AggregateState::Count(count) => {
180 *count += 1;
181 }
182 AggregateState::CountDistinct(count, seen) => {
183 if let Some(ref v) = value {
184 let hashable = HashableValue::from(v);
185 if seen.insert(hashable) {
186 *count += 1;
187 }
188 }
189 }
190 AggregateState::SumInt(sum, count) => {
191 if let Some(Value::Int64(v)) = value {
192 *sum += v;
193 *count += 1;
194 } else if let Some(Value::Float64(v)) = value {
195 *self = AggregateState::SumFloat(*sum as f64 + v, 0.0, *count + 1);
197 } else if let Some(ref v) = value {
198 if let Some(num) = value_to_f64(v) {
200 *self = AggregateState::SumFloat(*sum as f64 + num, 0.0, *count + 1);
201 }
202 }
203 }
204 AggregateState::SumIntDistinct(sum, count, seen) => {
205 if let Some(ref v) = value {
206 let hashable = HashableValue::from(v);
207 if seen.insert(hashable) {
208 if let Value::Int64(i) = v {
209 *sum += i;
210 *count += 1;
211 } else if let Value::Float64(f) = v {
212 let moved_seen = std::mem::take(seen);
214 *self = AggregateState::SumFloatDistinct(
215 *sum as f64 + f,
216 0.0,
217 *count + 1,
218 moved_seen,
219 );
220 } else if let Some(num) = value_to_f64(v) {
221 let moved_seen = std::mem::take(seen);
223 *self = AggregateState::SumFloatDistinct(
224 *sum as f64 + num,
225 0.0,
226 *count + 1,
227 moved_seen,
228 );
229 }
230 }
231 }
232 }
233 AggregateState::SumFloat(sum, comp, count) => {
234 if let Some(ref v) = value {
235 if let Some(num) = value_to_f64(v) {
237 let y = num - *comp;
239 let t = *sum + y;
240 *comp = (t - *sum) - y;
241 *sum = t;
242 *count += 1;
243 }
244 }
245 }
246 AggregateState::SumFloatDistinct(sum, comp, count, seen) => {
247 if let Some(ref v) = value {
248 let hashable = HashableValue::from(v);
249 if seen.insert(hashable)
250 && let Some(num) = value_to_f64(v)
251 {
252 let y = num - *comp;
253 let t = *sum + y;
254 *comp = (t - *sum) - y;
255 *sum = t;
256 *count += 1;
257 }
258 }
259 }
260 AggregateState::Avg(sum, count) => {
261 if let Some(ref v) = value
262 && let Some(num) = value_to_f64(v)
263 {
264 *sum += num;
265 *count += 1;
266 }
267 }
268 AggregateState::AvgDistinct(sum, count, seen) => {
269 if let Some(ref v) = value {
270 let hashable = HashableValue::from(v);
271 if seen.insert(hashable)
272 && let Some(num) = value_to_f64(v)
273 {
274 *sum += num;
275 *count += 1;
276 }
277 }
278 }
279 AggregateState::Min(min) => {
280 if let Some(v) = value {
281 match min {
282 None => *min = Some(v),
283 Some(current) => {
284 if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
285 *min = Some(v);
286 }
287 }
288 }
289 }
290 }
291 AggregateState::Max(max) => {
292 if let Some(v) = value {
293 match max {
294 None => *max = Some(v),
295 Some(current) => {
296 if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
297 *max = Some(v);
298 }
299 }
300 }
301 }
302 }
303 AggregateState::First(first) => {
304 if first.is_none() {
305 *first = value;
306 }
307 }
308 AggregateState::Last(last) => {
309 if value.is_some() {
310 *last = value;
311 }
312 }
313 AggregateState::Collect(list) => {
314 if let Some(v) = value {
315 list.push(v);
316 }
317 }
318 AggregateState::CollectDistinct(list, seen) => {
319 if let Some(v) = value {
320 let hashable = HashableValue::from(&v);
321 if seen.insert(hashable) {
322 list.push(v);
323 }
324 }
325 }
326 AggregateState::StdDev { count, mean, m2 }
328 | AggregateState::StdDevPop { count, mean, m2 }
329 | AggregateState::Variance { count, mean, m2 }
330 | AggregateState::VariancePop { count, mean, m2 } => {
331 if let Some(ref v) = value
332 && let Some(x) = value_to_f64(v)
333 {
334 *count += 1;
335 let delta = x - *mean;
336 *mean += delta / *count as f64;
337 let delta2 = x - *mean;
338 *m2 += delta * delta2;
339 }
340 }
341 AggregateState::PercentileDisc { values, .. }
342 | AggregateState::PercentileCont { values, .. } => {
343 if let Some(ref v) = value
344 && let Some(x) = value_to_f64(v)
345 {
346 values.push(x);
347 }
348 }
349 AggregateState::GroupConcat(list, _sep) => {
350 if let Some(v) = value {
351 list.push(agg_value_to_string(&v));
352 }
353 }
354 AggregateState::GroupConcatDistinct(list, _sep, seen) => {
355 if let Some(v) = value {
356 let hashable = HashableValue::from(&v);
357 if seen.insert(hashable) {
358 list.push(agg_value_to_string(&v));
359 }
360 }
361 }
362 AggregateState::Sample(sample) => {
363 if sample.is_none() {
364 *sample = value;
365 }
366 }
367 AggregateState::Bivariate { .. } => {
368 }
371 }
372 }
373
374 fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
379 if let AggregateState::Bivariate {
380 count,
381 mean_x,
382 mean_y,
383 m2_x,
384 m2_y,
385 c_xy,
386 ..
387 } = self
388 {
389 if let (Some(y), Some(x)) = (&y_val, &x_val)
391 && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
392 {
393 *count += 1;
394 let n = *count as f64;
395 let dx = x_f - *mean_x;
396 let dy = y_f - *mean_y;
397 *mean_x += dx / n;
398 *mean_y += dy / n;
399 let dx2 = x_f - *mean_x; let dy2 = y_f - *mean_y; *m2_x += dx * dx2;
402 *m2_y += dy * dy2;
403 *c_xy += dx * dy2;
404 }
405 }
406 }
407
408 pub(crate) fn finalize(&self) -> Value {
410 match self {
411 AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
412 Value::Int64(*count)
413 }
414 AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
415 if *count == 0 {
416 Value::Null
417 } else {
418 Value::Int64(*sum)
419 }
420 }
421 AggregateState::SumFloat(sum, _, count)
422 | AggregateState::SumFloatDistinct(sum, _, count, _) => {
423 if *count == 0 {
424 Value::Null
425 } else {
426 Value::Float64(*sum)
427 }
428 }
429 AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
430 if *count == 0 {
431 Value::Null
432 } else {
433 Value::Float64(*sum / *count as f64)
434 }
435 }
436 AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
437 AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
438 AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
439 AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
440 AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
441 Value::List(list.clone().into())
442 }
443 AggregateState::StdDev { count, m2, .. } => {
445 if *count < 2 {
446 Value::Null
447 } else {
448 Value::Float64((*m2 / (*count - 1) as f64).sqrt())
449 }
450 }
451 AggregateState::StdDevPop { count, m2, .. } => {
453 if *count == 0 {
454 Value::Null
455 } else {
456 Value::Float64((*m2 / *count as f64).sqrt())
457 }
458 }
459 AggregateState::Variance { count, m2, .. } => {
461 if *count < 2 {
462 Value::Null
463 } else {
464 Value::Float64(*m2 / (*count - 1) as f64)
465 }
466 }
467 AggregateState::VariancePop { count, m2, .. } => {
469 if *count == 0 {
470 Value::Null
471 } else {
472 Value::Float64(*m2 / *count as f64)
473 }
474 }
475 AggregateState::PercentileDisc { values, percentile } => {
477 if values.is_empty() {
478 Value::Null
479 } else {
480 let mut sorted = values.clone();
481 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
482 let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
484 Value::Float64(sorted[index])
485 }
486 }
487 AggregateState::PercentileCont { values, percentile } => {
489 if values.is_empty() {
490 Value::Null
491 } else {
492 let mut sorted = values.clone();
493 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
494 let rank = percentile * (sorted.len() - 1) as f64;
496 let lower_idx = rank.floor() as usize;
497 let upper_idx = rank.ceil() as usize;
498 if lower_idx == upper_idx {
499 Value::Float64(sorted[lower_idx])
500 } else {
501 let fraction = rank - lower_idx as f64;
502 let result =
503 sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
504 Value::Float64(result)
505 }
506 }
507 }
508 AggregateState::GroupConcat(list, sep)
510 | AggregateState::GroupConcatDistinct(list, sep, _) => {
511 Value::String(list.join(sep).into())
512 }
513 AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
515 AggregateState::Bivariate {
517 kind,
518 count,
519 mean_x,
520 mean_y,
521 m2_x,
522 m2_y,
523 c_xy,
524 } => {
525 let n = *count;
526 match kind {
527 AggregateFunction::CovarSamp => {
528 if n < 2 {
529 Value::Null
530 } else {
531 Value::Float64(*c_xy / (n - 1) as f64)
532 }
533 }
534 AggregateFunction::CovarPop => {
535 if n == 0 {
536 Value::Null
537 } else {
538 Value::Float64(*c_xy / n as f64)
539 }
540 }
541 AggregateFunction::Corr => {
542 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
543 Value::Null
544 } else {
545 Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
546 }
547 }
548 AggregateFunction::RegrSlope => {
549 if n == 0 || *m2_x == 0.0 {
550 Value::Null
551 } else {
552 Value::Float64(*c_xy / *m2_x)
553 }
554 }
555 AggregateFunction::RegrIntercept => {
556 if n == 0 || *m2_x == 0.0 {
557 Value::Null
558 } else {
559 let slope = *c_xy / *m2_x;
560 Value::Float64(*mean_y - slope * *mean_x)
561 }
562 }
563 AggregateFunction::RegrR2 => {
564 if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
565 Value::Null
566 } else {
567 Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
568 }
569 }
570 AggregateFunction::RegrCount => Value::Int64(n),
571 AggregateFunction::RegrSxx => {
572 if n == 0 {
573 Value::Null
574 } else {
575 Value::Float64(*m2_x)
576 }
577 }
578 AggregateFunction::RegrSyy => {
579 if n == 0 {
580 Value::Null
581 } else {
582 Value::Float64(*m2_y)
583 }
584 }
585 AggregateFunction::RegrSxy => {
586 if n == 0 {
587 Value::Null
588 } else {
589 Value::Float64(*c_xy)
590 }
591 }
592 AggregateFunction::RegrAvgx => {
593 if n == 0 {
594 Value::Null
595 } else {
596 Value::Float64(*mean_x)
597 }
598 }
599 AggregateFunction::RegrAvgy => {
600 if n == 0 {
601 Value::Null
602 } else {
603 Value::Float64(*mean_y)
604 }
605 }
606 _ => Value::Null, }
608 }
609 }
610 }
611}
612
613use super::value_utils::{compare_values, value_to_f64};
614
615fn agg_value_to_string(val: &Value) -> String {
617 match val {
618 Value::String(s) => s.to_string(),
619 Value::Int64(i) => i.to_string(),
620 Value::Float64(f) => f.to_string(),
621 Value::Bool(b) => b.to_string(),
622 Value::Null => String::new(),
623 other => format!("{other:?}"),
624 }
625}
626
/// Hashable grouping key: one part per group-by column of an input row,
/// in the order the group columns are listed.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(Vec<GroupKeyPart>);
630
631#[derive(Debug, Clone, PartialEq, Eq, Hash)]
632enum GroupKeyPart {
633 Null,
634 Bool(bool),
635 Int64(i64),
636 String(ArcStr),
637 Bytes(Arc<[u8]>),
638 Date(grafeo_common::types::Date),
639 Time(grafeo_common::types::Time),
640 Timestamp(grafeo_common::types::Timestamp),
641 Duration(grafeo_common::types::Duration),
642 ZonedDatetime(grafeo_common::types::ZonedDatetime),
643 List(Vec<GroupKeyPart>),
644 Map(Vec<(ArcStr, GroupKeyPart)>),
645}
646
647impl GroupKeyPart {
648 fn from_value(v: Value) -> Self {
649 match v {
650 Value::Null => Self::Null,
651 Value::Bool(b) => Self::Bool(b),
652 Value::Int64(i) => Self::Int64(i),
653 Value::Float64(f) => Self::Int64(f.to_bits() as i64),
654 Value::String(s) => Self::String(s.clone()),
655 Value::Bytes(b) => Self::Bytes(b),
656 Value::Date(d) => Self::Date(d),
657 Value::Time(t) => Self::Time(t),
658 Value::Timestamp(ts) => Self::Timestamp(ts),
659 Value::Duration(d) => Self::Duration(d),
660 Value::ZonedDatetime(zdt) => Self::ZonedDatetime(zdt),
661 Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
662 Value::Map(map) => {
663 let entries: Vec<(ArcStr, GroupKeyPart)> = map
665 .iter()
666 .map(|(k, v)| (ArcStr::from(k.as_str()), Self::from_value(v.clone())))
667 .collect();
668 Self::Map(entries)
669 }
670 other => Self::String(ArcStr::from(format!("{other:?}"))),
672 }
673 }
674
675 fn to_value(&self) -> Value {
676 match self {
677 Self::Null => Value::Null,
678 Self::Bool(b) => Value::Bool(*b),
679 Self::Int64(i) => Value::Int64(*i),
680 Self::String(s) => Value::String(s.clone()),
681 Self::Bytes(b) => Value::Bytes(Arc::clone(b)),
682 Self::Date(d) => Value::Date(*d),
683 Self::Time(t) => Value::Time(*t),
684 Self::Timestamp(ts) => Value::Timestamp(*ts),
685 Self::Duration(d) => Value::Duration(*d),
686 Self::ZonedDatetime(zdt) => Value::ZonedDatetime(*zdt),
687 Self::List(parts) => {
688 let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
689 Value::List(Arc::from(values.into_boxed_slice()))
690 }
691 Self::Map(entries) => {
692 let map: std::collections::BTreeMap<PropertyKey, Value> = entries
693 .iter()
694 .map(|(k, v)| (PropertyKey::new(k.as_str()), v.to_value()))
695 .collect();
696 Value::Map(Arc::new(map))
697 }
698 }
699 }
700}
701
702impl GroupKey {
703 fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
705 let parts: Vec<GroupKeyPart> = group_columns
706 .iter()
707 .map(|&col_idx| {
708 chunk
709 .column(col_idx)
710 .and_then(|col| col.get_value(row))
711 .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
712 })
713 .collect();
714 GroupKey(parts)
715 }
716
717 fn to_values(&self) -> Vec<Value> {
719 self.0.iter().map(GroupKeyPart::to_value).collect()
720 }
721}
722
/// Grouped aggregation operator.
///
/// On the first call to `next` it consumes its whole input, keeping one vector of
/// [`AggregateState`]s per distinct group key, then emits one row per group:
/// the group column values followed by one column per aggregate.
pub struct HashAggregateOperator {
    /// Input operator.
    child: Box<dyn Operator>,
    /// Input column indices forming the group key.
    group_columns: Vec<usize>,
    /// Aggregates computed for every group.
    aggregates: Vec<AggregateExpr>,
    /// Output column types (group columns first, then aggregates).
    output_schema: Vec<LogicalType>,
    /// Per-group states, in first-seen (insertion) order while aggregating.
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Whether the child has been fully consumed.
    aggregation_complete: bool,
    /// Finished groups still waiting to be emitted.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
742
743impl HashAggregateOperator {
744 pub fn new(
752 child: Box<dyn Operator>,
753 group_columns: Vec<usize>,
754 aggregates: Vec<AggregateExpr>,
755 output_schema: Vec<LogicalType>,
756 ) -> Self {
757 Self {
758 child,
759 group_columns,
760 aggregates,
761 output_schema,
762 groups: IndexMap::new(),
763 aggregation_complete: false,
764 results: None,
765 }
766 }
767
768 fn aggregate(&mut self) -> Result<(), OperatorError> {
770 while let Some(chunk) = self.child.next()? {
771 for row in chunk.selected_indices() {
772 let key = GroupKey::from_row(&chunk, row, &self.group_columns);
773
774 let states = self.groups.entry(key).or_insert_with(|| {
776 self.aggregates
777 .iter()
778 .map(|agg| {
779 AggregateState::new(
780 agg.function,
781 agg.distinct,
782 agg.percentile,
783 agg.separator.as_deref(),
784 )
785 })
786 .collect()
787 });
788
789 for (i, agg) in self.aggregates.iter().enumerate() {
791 if agg.column2.is_some() {
793 let y_val = agg
794 .column
795 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
796 let x_val = agg
797 .column2
798 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
799 states[i].update_bivariate(y_val, x_val);
800 continue;
801 }
802
803 let value = match (agg.function, agg.distinct) {
804 (AggregateFunction::Count, false) => None,
806 (AggregateFunction::Count, true) => agg
808 .column
809 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
810 _ => agg
811 .column
812 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
813 };
814
815 match (agg.function, agg.distinct) {
817 (AggregateFunction::Count, false) => states[i].update(None),
818 (AggregateFunction::Count, true) => {
819 if value.is_some() && !matches!(value, Some(Value::Null)) {
821 states[i].update(value);
822 }
823 }
824 (AggregateFunction::CountNonNull, _) => {
825 if value.is_some() && !matches!(value, Some(Value::Null)) {
826 states[i].update(value);
827 }
828 }
829 _ => {
830 if value.is_some() && !matches!(value, Some(Value::Null)) {
831 states[i].update(value);
832 }
833 }
834 }
835 }
836 }
837 }
838
839 self.aggregation_complete = true;
840
841 let results: Vec<_> = self.groups.drain(..).collect();
843 self.results = Some(results.into_iter());
844
845 Ok(())
846 }
847}
848
849impl Operator for HashAggregateOperator {
850 fn next(&mut self) -> OperatorResult {
851 if !self.aggregation_complete {
853 self.aggregate()?;
854 }
855
856 if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
858 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
860
861 for agg in &self.aggregates {
862 let state = AggregateState::new(
863 agg.function,
864 agg.distinct,
865 agg.percentile,
866 agg.separator.as_deref(),
867 );
868 let value = state.finalize();
869 if let Some(col) = builder.column_mut(self.group_columns.len()) {
870 col.push_value(value);
871 }
872 }
873 builder.advance_row();
874
875 self.results = Some(Vec::new().into_iter()); return Ok(Some(builder.finish()));
877 }
878
879 let Some(results) = &mut self.results else {
880 return Ok(None);
881 };
882
883 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
884
885 for (key, states) in results.by_ref() {
886 let key_values = key.to_values();
888 for (i, value) in key_values.into_iter().enumerate() {
889 if let Some(col) = builder.column_mut(i) {
890 col.push_value(value);
891 }
892 }
893
894 for (i, state) in states.iter().enumerate() {
896 let col_idx = self.group_columns.len() + i;
897 if let Some(col) = builder.column_mut(col_idx) {
898 col.push_value(state.finalize());
899 }
900 }
901
902 builder.advance_row();
903
904 if builder.is_full() {
905 return Ok(Some(builder.finish()));
906 }
907 }
908
909 if builder.row_count() > 0 {
910 Ok(Some(builder.finish()))
911 } else {
912 Ok(None)
913 }
914 }
915
916 fn reset(&mut self) {
917 self.child.reset();
918 self.groups.clear();
919 self.aggregation_complete = false;
920 self.results = None;
921 }
922
923 fn name(&self) -> &'static str {
924 "HashAggregate"
925 }
926}
927
/// Ungrouped aggregation operator.
///
/// Consumes its whole input and emits a single row with one column per aggregate.
/// That row is emitted even when the input is empty.
pub struct SimpleAggregateOperator {
    /// Input operator.
    child: Box<dyn Operator>,
    /// Aggregates to compute.
    aggregates: Vec<AggregateExpr>,
    /// Output column types, one per aggregate.
    output_schema: Vec<LogicalType>,
    /// One running state per aggregate, index-aligned with `aggregates`.
    states: Vec<AggregateState>,
    /// Whether the single result row has already been returned.
    done: bool,
}
943
944impl SimpleAggregateOperator {
945 pub fn new(
947 child: Box<dyn Operator>,
948 aggregates: Vec<AggregateExpr>,
949 output_schema: Vec<LogicalType>,
950 ) -> Self {
951 let states = aggregates
952 .iter()
953 .map(|agg| {
954 AggregateState::new(
955 agg.function,
956 agg.distinct,
957 agg.percentile,
958 agg.separator.as_deref(),
959 )
960 })
961 .collect();
962
963 Self {
964 child,
965 aggregates,
966 output_schema,
967 states,
968 done: false,
969 }
970 }
971}
972
973impl Operator for SimpleAggregateOperator {
974 fn next(&mut self) -> OperatorResult {
975 if self.done {
976 return Ok(None);
977 }
978
979 while let Some(chunk) = self.child.next()? {
981 for row in chunk.selected_indices() {
982 for (i, agg) in self.aggregates.iter().enumerate() {
983 if agg.column2.is_some() {
985 let y_val = agg
986 .column
987 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
988 let x_val = agg
989 .column2
990 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
991 self.states[i].update_bivariate(y_val, x_val);
992 continue;
993 }
994
995 let value = match (agg.function, agg.distinct) {
996 (AggregateFunction::Count, false) => None,
998 (AggregateFunction::Count, true) => agg
1000 .column
1001 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1002 _ => agg
1003 .column
1004 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1005 };
1006
1007 match (agg.function, agg.distinct) {
1008 (AggregateFunction::Count, false) => self.states[i].update(None),
1009 (AggregateFunction::Count, true) => {
1010 if value.is_some() && !matches!(value, Some(Value::Null)) {
1012 self.states[i].update(value);
1013 }
1014 }
1015 (AggregateFunction::CountNonNull, _) => {
1016 if value.is_some() && !matches!(value, Some(Value::Null)) {
1017 self.states[i].update(value);
1018 }
1019 }
1020 _ => {
1021 if value.is_some() && !matches!(value, Some(Value::Null)) {
1022 self.states[i].update(value);
1023 }
1024 }
1025 }
1026 }
1027 }
1028 }
1029
1030 let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
1032
1033 for (i, state) in self.states.iter().enumerate() {
1034 if let Some(col) = builder.column_mut(i) {
1035 col.push_value(state.finalize());
1036 }
1037 }
1038 builder.advance_row();
1039
1040 self.done = true;
1041 Ok(Some(builder.finish()))
1042 }
1043
1044 fn reset(&mut self) {
1045 self.child.reset();
1046 self.states = self
1047 .aggregates
1048 .iter()
1049 .map(|agg| {
1050 AggregateState::new(
1051 agg.function,
1052 agg.distinct,
1053 agg.percentile,
1054 agg.separator.as_deref(),
1055 )
1056 })
1057 .collect();
1058 self.done = false;
1059 }
1060
1061 fn name(&self) -> &'static str {
1062 "SimpleAggregate"
1063 }
1064}
1065
1066#[cfg(test)]
1067mod tests {
1068 use super::*;
1069 use crate::execution::chunk::DataChunkBuilder;
1070
1071 struct MockOperator {
1072 chunks: Vec<DataChunk>,
1073 position: usize,
1074 }
1075
1076 impl MockOperator {
1077 fn new(chunks: Vec<DataChunk>) -> Self {
1078 Self {
1079 chunks,
1080 position: 0,
1081 }
1082 }
1083 }
1084
1085 impl Operator for MockOperator {
1086 fn next(&mut self) -> OperatorResult {
1087 if self.position < self.chunks.len() {
1088 let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1089 self.position += 1;
1090 Ok(Some(chunk))
1091 } else {
1092 Ok(None)
1093 }
1094 }
1095
1096 fn reset(&mut self) {
1097 self.position = 0;
1098 }
1099
1100 fn name(&self) -> &'static str {
1101 "Mock"
1102 }
1103 }
1104
1105 fn create_test_chunk() -> DataChunk {
1106 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1108
1109 let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1110 for (group, value) in data {
1111 builder.column_mut(0).unwrap().push_int64(group);
1112 builder.column_mut(1).unwrap().push_int64(value);
1113 builder.advance_row();
1114 }
1115
1116 builder.finish()
1117 }
1118
1119 #[test]
1120 fn test_simple_count() {
1121 let mock = MockOperator::new(vec![create_test_chunk()]);
1122
1123 let mut agg = SimpleAggregateOperator::new(
1124 Box::new(mock),
1125 vec![AggregateExpr::count_star()],
1126 vec![LogicalType::Int64],
1127 );
1128
1129 let result = agg.next().unwrap().unwrap();
1130 assert_eq!(result.row_count(), 1);
1131 assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1132
1133 assert!(agg.next().unwrap().is_none());
1135 }
1136
1137 #[test]
1138 fn test_simple_sum() {
1139 let mock = MockOperator::new(vec![create_test_chunk()]);
1140
1141 let mut agg = SimpleAggregateOperator::new(
1142 Box::new(mock),
1143 vec![AggregateExpr::sum(1)], vec![LogicalType::Int64],
1145 );
1146
1147 let result = agg.next().unwrap().unwrap();
1148 assert_eq!(result.row_count(), 1);
1149 assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1151 }
1152
1153 #[test]
1154 fn test_simple_avg() {
1155 let mock = MockOperator::new(vec![create_test_chunk()]);
1156
1157 let mut agg = SimpleAggregateOperator::new(
1158 Box::new(mock),
1159 vec![AggregateExpr::avg(1)],
1160 vec![LogicalType::Float64],
1161 );
1162
1163 let result = agg.next().unwrap().unwrap();
1164 assert_eq!(result.row_count(), 1);
1165 let avg = result.column(0).unwrap().get_float64(0).unwrap();
1167 assert!((avg - 30.0).abs() < 0.001);
1168 }
1169
1170 #[test]
1171 fn test_simple_min_max() {
1172 let mock = MockOperator::new(vec![create_test_chunk()]);
1173
1174 let mut agg = SimpleAggregateOperator::new(
1175 Box::new(mock),
1176 vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1177 vec![LogicalType::Int64, LogicalType::Int64],
1178 );
1179
1180 let result = agg.next().unwrap().unwrap();
1181 assert_eq!(result.row_count(), 1);
1182 assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); }
1185
1186 #[test]
1187 fn test_sum_with_string_values() {
1188 let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1190 builder.column_mut(0).unwrap().push_string("30");
1191 builder.advance_row();
1192 builder.column_mut(0).unwrap().push_string("25");
1193 builder.advance_row();
1194 builder.column_mut(0).unwrap().push_string("35");
1195 builder.advance_row();
1196 let chunk = builder.finish();
1197
1198 let mock = MockOperator::new(vec![chunk]);
1199 let mut agg = SimpleAggregateOperator::new(
1200 Box::new(mock),
1201 vec![AggregateExpr::sum(0)],
1202 vec![LogicalType::Float64],
1203 );
1204
1205 let result = agg.next().unwrap().unwrap();
1206 assert_eq!(result.row_count(), 1);
1207 let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1209 assert!(
1210 (sum_val - 90.0).abs() < 0.001,
1211 "Expected 90.0, got {}",
1212 sum_val
1213 );
1214 }
1215
1216 #[test]
1217 fn test_grouped_aggregation() {
1218 let mock = MockOperator::new(vec![create_test_chunk()]);
1219
1220 let mut agg = HashAggregateOperator::new(
1222 Box::new(mock),
1223 vec![0], vec![AggregateExpr::sum(1)], vec![LogicalType::Int64, LogicalType::Int64],
1226 );
1227
1228 let mut results: Vec<(i64, i64)> = Vec::new();
1229 while let Some(chunk) = agg.next().unwrap() {
1230 for row in chunk.selected_indices() {
1231 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1232 let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1233 results.push((group, sum));
1234 }
1235 }
1236
1237 results.sort_by_key(|(g, _)| *g);
1238 assert_eq!(results.len(), 2);
1239 assert_eq!(results[0], (1, 30)); assert_eq!(results[1], (2, 120)); }
1242
1243 #[test]
1244 fn test_grouped_count() {
1245 let mock = MockOperator::new(vec![create_test_chunk()]);
1246
1247 let mut agg = HashAggregateOperator::new(
1249 Box::new(mock),
1250 vec![0],
1251 vec![AggregateExpr::count_star()],
1252 vec![LogicalType::Int64, LogicalType::Int64],
1253 );
1254
1255 let mut results: Vec<(i64, i64)> = Vec::new();
1256 while let Some(chunk) = agg.next().unwrap() {
1257 for row in chunk.selected_indices() {
1258 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1259 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1260 results.push((group, count));
1261 }
1262 }
1263
1264 results.sort_by_key(|(g, _)| *g);
1265 assert_eq!(results.len(), 2);
1266 assert_eq!(results[0], (1, 2)); assert_eq!(results[1], (2, 3)); }
1269
1270 #[test]
1271 fn test_multiple_aggregates() {
1272 let mock = MockOperator::new(vec![create_test_chunk()]);
1273
1274 let mut agg = HashAggregateOperator::new(
1276 Box::new(mock),
1277 vec![0],
1278 vec![
1279 AggregateExpr::count_star(),
1280 AggregateExpr::sum(1),
1281 AggregateExpr::avg(1),
1282 ],
1283 vec![
1284 LogicalType::Int64, LogicalType::Int64, LogicalType::Int64, LogicalType::Float64, ],
1289 );
1290
1291 let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1292 while let Some(chunk) = agg.next().unwrap() {
1293 for row in chunk.selected_indices() {
1294 let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1295 let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1296 let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1297 let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1298 results.push((group, count, sum, avg));
1299 }
1300 }
1301
1302 results.sort_by_key(|(g, _, _, _)| *g);
1303 assert_eq!(results.len(), 2);
1304
1305 assert_eq!(results[0].0, 1);
1307 assert_eq!(results[0].1, 2);
1308 assert_eq!(results[0].2, 30);
1309 assert!((results[0].3 - 15.0).abs() < 0.001);
1310
1311 assert_eq!(results[1].0, 2);
1313 assert_eq!(results[1].1, 3);
1314 assert_eq!(results[1].2, 120);
1315 assert!((results[1].3 - 40.0).abs() < 0.001);
1316 }
1317
1318 fn create_test_chunk_with_duplicates() -> DataChunk {
1319 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1324
1325 let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1326 for (group, value) in data {
1327 builder.column_mut(0).unwrap().push_int64(group);
1328 builder.column_mut(1).unwrap().push_int64(value);
1329 builder.advance_row();
1330 }
1331
1332 builder.finish()
1333 }
1334
#[test]
fn test_count_distinct() {
    // COUNT(DISTINCT col1) over {10, 10, 20, 30, 30, 30} sees three distinct values.
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_count = AggregateExpr::count(1).with_distinct();
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![distinct_count],
        vec![LogicalType::Int64],
    );

    let chunk = agg.next().unwrap().unwrap();
    assert_eq!(chunk.row_count(), 1);
    let counts = chunk.column(0).unwrap();
    assert_eq!(counts.get_int64(0), Some(3));
}
1351
#[test]
fn test_grouped_count_distinct() {
    // Per group: group 1 has {10, 20} (2 distinct), group 2 has {30} (1 distinct).
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut agg = HashAggregateOperator::new(
        Box::new(input),
        vec![0],
        vec![AggregateExpr::count(1).with_distinct()],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let mut per_group: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = agg.next().unwrap() {
        let keys = chunk.column(0).unwrap();
        let counts = chunk.column(1).unwrap();
        for row in chunk.selected_indices() {
            per_group.push((keys.get_int64(row).unwrap(), counts.get_int64(row).unwrap()));
        }
    }

    // Group output order is unspecified; normalise before comparing.
    per_group.sort_unstable_by_key(|&(key, _)| key);
    assert_eq!(per_group, vec![(1, 2), (2, 1)]);
}
1378
#[test]
fn test_sum_distinct() {
    // SUM(DISTINCT col1) adds each of 10, 20 and 30 exactly once: 60.
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let sum_expr = AggregateExpr::sum(1).with_distinct();
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![sum_expr],
        vec![LogicalType::Int64],
    );

    let chunk = agg.next().unwrap().unwrap();
    assert_eq!(chunk.row_count(), 1);
    assert_eq!(chunk.column(0).unwrap().get_int64(0), Some(60));
}
1395
#[test]
fn test_avg_distinct() {
    // AVG(DISTINCT col1) = (10 + 20 + 30) / 3 = 20.
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut agg = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::avg(1).with_distinct()],
        vec![LogicalType::Float64],
    );

    let chunk = agg.next().unwrap().unwrap();
    assert_eq!(chunk.row_count(), 1);
    let mean = chunk.column(0).unwrap().get_float64(0).unwrap();
    assert!((mean - 20.0).abs() < 0.001);
}
1413
1414 fn create_statistical_test_chunk() -> DataChunk {
1415 let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1418
1419 for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1420 builder.column_mut(0).unwrap().push_int64(value);
1421 builder.advance_row();
1422 }
1423
1424 builder.finish()
1425 }
1426
#[test]
fn test_stdev_sample() {
    // Sample standard deviation: sqrt(32 / 7), approximately 2.138.
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_statistical_test_chunk()])),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );

    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let sd = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((sd - 2.138).abs() < 0.01);
}
1444
#[test]
fn test_stdev_population() {
    // Population standard deviation: sqrt(32 / 8) = 2.
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_statistical_test_chunk()])),
        vec![AggregateExpr::stdev_pop(0)],
        vec![LogicalType::Float64],
    );

    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let sd = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((sd - 2.0).abs() < 0.01);
}
1462
#[test]
fn test_percentile_disc() {
    // Discrete median of sorted [2, 4, 4, 4, 5, 5, 7, 9]: the test expects the
    // lower of the two middle elements, 4.
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_statistical_test_chunk()])),
        vec![AggregateExpr::percentile_disc(0, 0.5)],
        vec![LogicalType::Float64],
    );

    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let median = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((median - 4.0).abs() < 0.01);
}
1480
#[test]
fn test_percentile_cont() {
    // Continuous median: expected 4.5, halfway between the middle elements 4 and 5.
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_statistical_test_chunk()])),
        vec![AggregateExpr::percentile_cont(0, 0.5)],
        vec![LogicalType::Float64],
    );

    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let median = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((median - 4.5).abs() < 0.01);
}
1499
#[test]
fn test_percentile_extremes() {
    // The 0th and 100th discrete percentiles are expected to be the minimum (2)
    // and the maximum (9) of the dataset.
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_statistical_test_chunk()])),
        vec![
            AggregateExpr::percentile_disc(0, 0.0),
            AggregateExpr::percentile_disc(0, 1.0),
        ],
        vec![LogicalType::Float64, LogicalType::Float64],
    );

    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let lowest = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((lowest - 2.0).abs() < 0.01);

    let highest = out.column(1).unwrap().get_float64(0).unwrap();
    assert!((highest - 9.0).abs() < 0.01);
}
1523
#[test]
fn test_stdev_single_value() {
    // Sample standard deviation of one observation has an n - 1 = 0 denominator,
    // so the expected result is NULL.
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();

    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![builder.finish()])),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );

    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let value = out.column(0).unwrap().get_value(0);
    assert!(matches!(value, Some(Value::Null)));
}
1548
#[test]
fn test_first_and_last() {
    // FIRST/LAST over column 1 are expected to yield the first (10) and last (50)
    // input values in arrival order.
    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![create_test_chunk()])),
        vec![AggregateExpr::first(1), AggregateExpr::last(1)],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let first = out.column(0).unwrap().get_int64(0);
    assert_eq!(first, Some(10));

    let last = out.column(1).unwrap().get_int64(0);
    assert_eq!(last, Some(50));
}
1565
#[test]
fn test_collect() {
    // COLLECT over column 1 should produce one row holding a list of all 5 input values.
    let mock = MockOperator::new(vec![create_test_chunk()]);

    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::collect(1)],
        vec![LogicalType::Any],
    );

    let result = agg.next().unwrap().unwrap();
    // An ungrouped aggregate must emit exactly one row.
    assert_eq!(result.row_count(), 1);

    let val = result.column(0).unwrap().get_value(0).unwrap();
    let Value::List(items) = val else {
        panic!("Expected List value");
    };
    assert_eq!(items.len(), 5);
}
1584
#[test]
fn test_collect_distinct() {
    // COLLECT(DISTINCT col1) over {10, 10, 20, 30, 30, 30} should yield three
    // pairwise-distinct values.
    let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);

    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![AggregateExpr::collect(1).with_distinct()],
        vec![LogicalType::Any],
    );

    let result = agg.next().unwrap().unwrap();
    // An ungrouped aggregate must emit exactly one row.
    assert_eq!(result.row_count(), 1);

    let val = result.column(0).unwrap().get_value(0).unwrap();
    let Value::List(items) = val else {
        panic!("Expected List value");
    };
    assert_eq!(items.len(), 3);
    // A length check alone would accept e.g. [10, 10, 20]; verify deduplication.
    for (i, a) in items.iter().enumerate() {
        for b in items.iter().skip(i + 1) {
            assert_ne!(a, b, "COLLECT DISTINCT returned a duplicate element");
        }
    }
}
1604
#[test]
fn test_group_concat() {
    // GROUP_CONCAT with no explicit separator is expected to join with a single space.
    let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
    for s in ["hello", "world", "foo"] {
        builder.column_mut(0).unwrap().push_string(s);
        builder.advance_row();
    }
    let chunk = builder.finish();
    let mock = MockOperator::new(vec![chunk]);

    let agg_expr = AggregateExpr {
        function: AggregateFunction::GroupConcat,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut agg =
        SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);

    let result = agg.next().unwrap().unwrap();
    // An ungrouped aggregate must emit exactly one row.
    assert_eq!(result.row_count(), 1);
    let val = result.column(0).unwrap().get_value(0).unwrap();
    assert_eq!(val, Value::String("hello world foo".into()));
}
1632
#[test]
fn test_sample() {
    // SAMPLE over column 1 of the basic test chunk.
    // NOTE(review): this pins 10 (the first input value). If SAMPLE is allowed to
    // return any input value, this expectation is implementation-specific; confirm.
    let mock = MockOperator::new(vec![create_test_chunk()]);

    let agg_expr = AggregateExpr {
        function: AggregateFunction::Sample,
        column: Some(1),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut agg =
        SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);

    let result = agg.next().unwrap().unwrap();
    // An ungrouped aggregate must emit exactly one row.
    assert_eq!(result.row_count(), 1);
    assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
}
1654
#[test]
fn test_variance_sample() {
    // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: 32 / 7.
    let mock = MockOperator::new(vec![create_statistical_test_chunk()]);

    let agg_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![agg_expr],
        vec![LogicalType::Float64],
    );

    let result = agg.next().unwrap().unwrap();
    // An ungrouped aggregate must emit exactly one row.
    assert_eq!(result.row_count(), 1);
    let variance = result.column(0).unwrap().get_float64(0).unwrap();
    assert!((variance - 32.0 / 7.0).abs() < 0.01);
}
1680
#[test]
fn test_variance_population() {
    // Population variance of [2, 4, 4, 4, 5, 5, 7, 9]: 32 / 8 = 4.
    let mock = MockOperator::new(vec![create_statistical_test_chunk()]);

    let agg_expr = AggregateExpr {
        function: AggregateFunction::VariancePop,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![agg_expr],
        vec![LogicalType::Float64],
    );

    let result = agg.next().unwrap().unwrap();
    // An ungrouped aggregate must emit exactly one row.
    assert_eq!(result.row_count(), 1);
    let variance = result.column(0).unwrap().get_float64(0).unwrap();
    assert!((variance - 4.0).abs() < 0.01);
}
1706
#[test]
fn test_variance_single_value() {
    // Sample variance of a single observation (n - 1 = 0) is expected to be NULL.
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();
    let chunk = builder.finish();
    let mock = MockOperator::new(vec![chunk]);

    let agg_expr = AggregateExpr {
        function: AggregateFunction::Variance,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };

    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![agg_expr],
        vec![LogicalType::Float64],
    );

    let result = agg.next().unwrap().unwrap();
    // An ungrouped aggregate must emit exactly one row, even when the value is NULL.
    assert_eq!(result.row_count(), 1);
    assert!(matches!(
        result.column(0).unwrap().get_value(0),
        Some(Value::Null)
    ));
}
1738
#[test]
fn test_empty_aggregation() {
    // An ungrouped aggregate over no input must still emit exactly one row:
    // COUNT(*) is 0 and every other aggregate (SUM, AVG, MIN, MAX) is NULL.
    let mock = MockOperator::new(vec![]);

    let mut agg = SimpleAggregateOperator::new(
        Box::new(mock),
        vec![
            AggregateExpr::count_star(),
            AggregateExpr::sum(0),
            AggregateExpr::avg(0),
            AggregateExpr::min(0),
            AggregateExpr::max(0),
        ],
        vec![
            LogicalType::Int64,
            LogicalType::Int64,
            LogicalType::Float64,
            LogicalType::Int64,
            LogicalType::Int64,
        ],
    );

    let result = agg.next().unwrap().unwrap();
    // Empty input is the case where an operator is most likely to emit zero or
    // several rows, so pin the row count explicitly.
    assert_eq!(result.row_count(), 1);
    assert_eq!(result.column(0).unwrap().get_int64(0), Some(0));

    // Columns 1..5 are SUM, AVG, MIN and MAX respectively.
    for col in 1..5 {
        assert!(
            matches!(result.column(col).unwrap().get_value(0), Some(Value::Null)),
            "aggregate column {col} should be NULL on empty input"
        );
    }
}
1782
#[test]
fn test_stdev_pop_single_value() {
    // A single observation has zero spread, so population standard deviation is
    // expected to be 0 (not NULL, unlike the sample variant).
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();

    let mut agg = SimpleAggregateOperator::new(
        Box::new(MockOperator::new(vec![builder.finish()])),
        vec![AggregateExpr::stdev_pop(0)],
        vec![LogicalType::Float64],
    );

    let out = agg.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);
    let sd = out.column(0).unwrap().get_float64(0).unwrap();
    assert!(sd.abs() < 0.01);
}
1805}