1use crate::ast::AggregateFunc;
4use crate::executor::{Relation, RelationEntry, SharedTables};
5use alloc::collections::BTreeMap;
6use alloc::rc::Rc;
7use alloc::string::String;
8use alloc::vec::Vec;
9use cynos_core::{Row, Value, DUMMY_ROW_ID};
10use libm::{exp, log, sqrt};
11
12pub struct AggregateExecutor {
14 group_by: Vec<usize>,
16 aggregates: Vec<(AggregateFunc, Option<usize>)>,
18}
19
20impl AggregateExecutor {
21 pub fn new(group_by: Vec<usize>, aggregates: Vec<(AggregateFunc, Option<usize>)>) -> Self {
23 Self {
24 group_by,
25 aggregates,
26 }
27 }
28
29 pub fn no_group(aggregates: Vec<(AggregateFunc, Option<usize>)>) -> Self {
31 Self::new(Vec::new(), aggregates)
32 }
33
34 pub fn execute(&self, input: Relation) -> Relation {
36 let tables = input.tables().to_vec();
37 let shared_tables: SharedTables = tables.clone().into();
38 let result_column_count = self.group_by.len() + self.aggregates.len();
40
41 if self.group_by.is_empty() {
42 let version_sum: u64 = input.iter().map(|e| e.row.version()).sum();
45 let values = self.compute_aggregates(input.iter());
46 let entry = RelationEntry::new_combined(
47 Rc::new(Row::dummy_with_version(version_sum, values)),
48 shared_tables,
49 );
50 return Relation {
51 entries: alloc::vec![entry],
52 tables,
53 table_column_counts: alloc::vec![result_column_count],
54 };
55 }
56
57 let mut groups: BTreeMap<String, Vec<&RelationEntry>> = BTreeMap::new();
59
60 for entry in input.iter() {
61 let key = self.make_group_key(entry);
62 groups.entry(key).or_default().push(entry);
63 }
64
65 let entries: Vec<RelationEntry> = groups
66 .into_iter()
67 .map(|(_, group_entries)| {
68 let mut values = Vec::new();
69
70 let version_sum: u64 = group_entries.iter().map(|e| e.row.version()).sum();
72
73 if let Some(first) = group_entries.first() {
75 for &idx in &self.group_by {
76 values.push(first.get_field(idx).cloned().unwrap_or(Value::Null));
77 }
78 }
79
80 let agg_values = self.compute_aggregates(group_entries.iter().copied());
82 values.extend(agg_values);
83
84 RelationEntry::new_combined(
85 Rc::new(Row::dummy_with_version(version_sum, values)),
86 shared_tables.clone(),
87 )
88 })
89 .collect();
90
91 Relation {
92 entries,
93 tables,
94 table_column_counts: alloc::vec![result_column_count],
95 }
96 }
97
98 fn make_group_key(&self, entry: &RelationEntry) -> String {
99 self.group_by
100 .iter()
101 .map(|&idx| {
102 entry
103 .get_field(idx)
104 .map(value_to_string)
105 .unwrap_or_else(|| String::from("null"))
106 })
107 .collect::<Vec<_>>()
108 .join("|")
109 }
110
111 fn compute_aggregates<'a>(
112 &self,
113 entries: impl Iterator<Item = &'a RelationEntry>,
114 ) -> Vec<Value> {
115 let entries: Vec<_> = entries.collect();
116
117 self.aggregates
118 .iter()
119 .map(|(func, col_idx)| self.compute_single_aggregate(*func, *col_idx, &entries))
120 .collect()
121 }
122
123 fn compute_single_aggregate(
124 &self,
125 func: AggregateFunc,
126 col_idx: Option<usize>,
127 entries: &[&RelationEntry],
128 ) -> Value {
129 match func {
130 AggregateFunc::Count => {
131 if let Some(idx) = col_idx {
132 let count = entries
134 .iter()
135 .filter(|e| {
136 e.get_field(idx)
137 .map(|v| !v.is_null())
138 .unwrap_or(false)
139 })
140 .count();
141 Value::Int64(count as i64)
142 } else {
143 Value::Int64(entries.len() as i64)
145 }
146 }
147 AggregateFunc::Sum => {
148 let idx = col_idx.unwrap_or(0);
149 let sum = entries
150 .iter()
151 .filter_map(|e| e.get_field(idx))
152 .filter(|v| !v.is_null())
153 .fold(0.0f64, |acc, v| {
154 acc + match v {
155 Value::Int32(i) => *i as f64,
156 Value::Int64(i) => *i as f64,
157 Value::Float64(f) => *f,
158 _ => 0.0,
159 }
160 });
161
162 if entries.iter().all(|e| {
163 e.get_field(idx)
164 .map(|v| v.is_null() || matches!(v, Value::Int32(_) | Value::Int64(_)))
165 .unwrap_or(true)
166 }) {
167 Value::Int64(sum as i64)
168 } else {
169 Value::Float64(sum)
170 }
171 }
172 AggregateFunc::Avg => {
173 let idx = col_idx.unwrap_or(0);
174 let values: Vec<f64> = entries
175 .iter()
176 .filter_map(|e| e.get_field(idx))
177 .filter(|v| !v.is_null())
178 .filter_map(|v| match v {
179 Value::Int32(i) => Some(*i as f64),
180 Value::Int64(i) => Some(*i as f64),
181 Value::Float64(f) => Some(*f),
182 _ => None,
183 })
184 .collect();
185
186 if values.is_empty() {
187 Value::Null
188 } else {
189 let sum: f64 = values.iter().sum();
190 Value::Float64(sum / values.len() as f64)
191 }
192 }
193 AggregateFunc::Min => {
194 let idx = col_idx.unwrap_or(0);
195 entries
196 .iter()
197 .filter_map(|e| e.get_field(idx))
198 .filter(|v| !v.is_null())
199 .min()
200 .cloned()
201 .unwrap_or(Value::Null)
202 }
203 AggregateFunc::Max => {
204 let idx = col_idx.unwrap_or(0);
205 entries
206 .iter()
207 .filter_map(|e| e.get_field(idx))
208 .filter(|v| !v.is_null())
209 .max()
210 .cloned()
211 .unwrap_or(Value::Null)
212 }
213 AggregateFunc::Distinct => {
214 let idx = col_idx.unwrap_or(0);
215 let mut seen: BTreeMap<String, Value> = BTreeMap::new();
216 for entry in entries {
217 if let Some(v) = entry.get_field(idx) {
218 let key = value_to_string(v);
219 seen.entry(key).or_insert_with(|| v.clone());
220 }
221 }
222 Value::Int64(seen.len() as i64)
224 }
225 AggregateFunc::StdDev => {
226 let idx = col_idx.unwrap_or(0);
227 let values: Vec<f64> = entries
228 .iter()
229 .filter_map(|e| e.get_field(idx))
230 .filter(|v| !v.is_null())
231 .filter_map(|v| match v {
232 Value::Int32(i) => Some(*i as f64),
233 Value::Int64(i) => Some(*i as f64),
234 Value::Float64(f) => Some(*f),
235 _ => None,
236 })
237 .collect();
238
239 if values.is_empty() {
240 Value::Null
241 } else {
242 let mean: f64 = values.iter().sum::<f64>() / values.len() as f64;
243 let variance: f64 = values
244 .iter()
245 .map(|v| (v - mean) * (v - mean))
246 .sum::<f64>()
247 / values.len() as f64;
248 Value::Float64(sqrt(variance))
249 }
250 }
251 AggregateFunc::GeoMean => {
252 let idx = col_idx.unwrap_or(0);
253 let values: Vec<f64> = entries
254 .iter()
255 .filter_map(|e| e.get_field(idx))
256 .filter(|v| !v.is_null())
257 .filter_map(|v| match v {
258 Value::Int32(i) => Some(*i as f64),
259 Value::Int64(i) => Some(*i as f64),
260 Value::Float64(f) => Some(*f),
261 _ => None,
262 })
263 .filter(|&v| v > 0.0)
264 .collect();
265
266 if values.is_empty() {
267 Value::Null
268 } else {
269 let log_sum: f64 = values.iter().map(|v| log(*v)).sum();
270 let geomean = exp(log_sum / values.len() as f64);
271 Value::Float64(geomean)
272 }
273 }
274 }
275 }
276}
277
278fn value_to_string(value: &Value) -> String {
279 match value {
280 Value::Null => String::from("null"),
281 Value::Boolean(b) => alloc::format!("{}", b),
282 Value::Int32(i) => alloc::format!("{}", i),
283 Value::Int64(i) => alloc::format!("{}", i),
284 Value::Float64(f) => alloc::format!("{}", f),
285 Value::String(s) => s.clone(),
286 Value::DateTime(d) => alloc::format!("{}", d),
287 Value::Bytes(b) => alloc::format!("{:?}", b),
288 Value::Jsonb(j) => alloc::format!("{:?}", j.0),
289 }
290}
291
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Wraps `rows` into a relation over a single table named `t`.
    fn relation(rows: Vec<Row>) -> Relation {
        Relation::from_rows_owned(rows, vec!["t".into()])
    }

    /// One-column relation; row ids are assigned 0, 1, 2, ... in input order.
    fn single_column(values: Vec<Value>) -> Relation {
        relation(
            (0..)
                .zip(values)
                .map(|(id, value)| Row::new(id, vec![value]))
                .collect(),
        )
    }

    /// Two-column `(group key, amount)` relation; ids assigned in input order.
    fn keyed_amounts(pairs: Vec<(&str, i64)>) -> Relation {
        relation(
            (0..)
                .zip(pairs)
                .map(|(id, (key, amount))| {
                    Row::new(id, vec![Value::String(key.into()), Value::Int64(amount)])
                })
                .collect(),
        )
    }

    fn ints(values: &[i64]) -> Vec<Value> {
        values.iter().copied().map(Value::Int64).collect()
    }

    fn floats(values: &[f64]) -> Vec<Value> {
        values.iter().copied().map(Value::Float64).collect()
    }

    /// Runs an ungrouped aggregation over `input`.
    fn run(input: Relation, aggregates: Vec<(AggregateFunc, Option<usize>)>) -> Relation {
        AggregateExecutor::no_group(aggregates).execute(input)
    }

    /// Column `column` of the first result entry.
    fn field(result: &Relation, column: usize) -> Option<&Value> {
        result.entries[0].get_field(column)
    }

    /// Asserts `actual` is a `Float64` within `tolerance` of `expected`,
    /// panicking with `message` when it is not a `Float64` at all.
    fn assert_close(actual: Option<&Value>, expected: f64, tolerance: f64, message: &str) {
        match actual {
            Some(Value::Float64(got)) => assert!((got - expected).abs() < tolerance),
            _ => panic!("{}", message),
        }
    }

    #[test]
    fn test_count_star() {
        let result = run(single_column(ints(&[1, 2, 3])), vec![(AggregateFunc::Count, None)]);

        assert_eq!(result.len(), 1);
        assert_eq!(field(&result, 0), Some(&Value::Int64(3)));
    }

    #[test]
    fn test_count_column() {
        let input = single_column(vec![Value::Int64(1), Value::Null, Value::Int64(3)]);
        let result = run(input, vec![(AggregateFunc::Count, Some(0))]);

        assert_eq!(result.len(), 1);
        // The NULL row is not counted.
        assert_eq!(field(&result, 0), Some(&Value::Int64(2)));
    }

    #[test]
    fn test_sum() {
        let result = run(single_column(ints(&[10, 20, 30])), vec![(AggregateFunc::Sum, Some(0))]);

        assert_eq!(field(&result, 0), Some(&Value::Int64(60)));
    }

    #[test]
    fn test_avg() {
        let result = run(single_column(ints(&[10, 20, 30])), vec![(AggregateFunc::Avg, Some(0))]);

        assert_eq!(field(&result, 0), Some(&Value::Float64(20.0)));
    }

    #[test]
    fn test_min_max() {
        let result = run(
            single_column(ints(&[30, 10, 20])),
            vec![(AggregateFunc::Min, Some(0)), (AggregateFunc::Max, Some(0))],
        );

        assert_eq!(field(&result, 0), Some(&Value::Int64(10)));
        assert_eq!(field(&result, 1), Some(&Value::Int64(30)));
    }

    #[test]
    fn test_group_by() {
        let input = keyed_amounts(vec![("A", 10), ("A", 20), ("B", 30)]);
        let executor = AggregateExecutor::new(vec![0], vec![(AggregateFunc::Sum, Some(1))]);
        let result = executor.execute(input);

        assert_eq!(result.len(), 2);
        let mut sums: Vec<i64> = result
            .entries
            .iter()
            .filter_map(|entry| match entry.get_field(1) {
                Some(Value::Int64(sum)) => Some(*sum),
                _ => None,
            })
            .collect();
        sums.sort_unstable();
        // Group A sums to 10 + 20; group B holds only 30.
        assert_eq!(sums, vec![30, 30]);
    }

    #[test]
    fn test_empty_relation() {
        let result = run(
            relation(Vec::new()),
            vec![
                (AggregateFunc::Count, None),
                (AggregateFunc::Sum, Some(0)),
                (AggregateFunc::Avg, Some(0)),
            ],
        );

        // An ungrouped aggregate still yields exactly one row.
        assert_eq!(result.len(), 1);
        assert_eq!(field(&result, 0), Some(&Value::Int64(0)));
        assert_eq!(field(&result, 1), Some(&Value::Int64(0)));
        assert_eq!(field(&result, 2), Some(&Value::Null));
    }

    #[test]
    fn test_stddev() {
        let input = single_column(floats(&[2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]));
        let result = run(input, vec![(AggregateFunc::StdDev, Some(0))]);

        assert_close(field(&result, 0), 2.0, 0.001, "Expected Float64 value");
    }

    #[test]
    fn test_distinct() {
        let result = run(
            single_column(ints(&[1, 2, 1, 3, 2])),
            vec![(AggregateFunc::Distinct, Some(0))],
        );

        assert_eq!(result.len(), 1);
        assert_eq!(field(&result, 0), Some(&Value::Int64(3)));
    }

    #[test]
    fn test_distinct_with_nulls() {
        let input = single_column(vec![
            Value::Int64(1),
            Value::Null,
            Value::Int64(1),
            Value::Null,
        ]);
        let result = run(input, vec![(AggregateFunc::Distinct, Some(0))]);

        // NULL is counted as one distinct value alongside 1.
        assert_eq!(field(&result, 0), Some(&Value::Int64(2)));
    }

    #[test]
    fn test_geomean() {
        let result = run(
            single_column(floats(&[2.0, 8.0])),
            vec![(AggregateFunc::GeoMean, Some(0))],
        );

        assert_close(field(&result, 0), 4.0, 0.001, "Expected Float64 value");
    }

    #[test]
    fn test_geomean_single_value() {
        let result = run(single_column(floats(&[5.0])), vec![(AggregateFunc::GeoMean, Some(0))]);

        assert_close(field(&result, 0), 5.0, 0.001, "Expected Float64 value");
    }

    #[test]
    fn test_geomean_with_zero_and_negative() {
        let result = run(
            single_column(floats(&[2.0, 0.0, -1.0, 8.0])),
            vec![(AggregateFunc::GeoMean, Some(0))],
        );

        // Non-positive values are skipped, leaving sqrt(2 * 8).
        assert_close(field(&result, 0), 4.0, 0.001, "Expected Float64 value");
    }

    #[test]
    fn test_geomean_all_non_positive() {
        let result = run(
            single_column(floats(&[0.0, -1.0])),
            vec![(AggregateFunc::GeoMean, Some(0))],
        );

        assert_eq!(field(&result, 0), Some(&Value::Null));
    }

    #[test]
    fn test_sum_with_nulls() {
        let input = single_column(vec![
            Value::Int64(10),
            Value::Null,
            Value::Int64(20),
            Value::Null,
        ]);
        let result = run(input, vec![(AggregateFunc::Sum, Some(0))]);

        assert_eq!(field(&result, 0), Some(&Value::Int64(30)));
    }

    #[test]
    fn test_sum_mixed_types() {
        let input = single_column(vec![
            Value::Int32(10),
            Value::Int64(20),
            Value::Float64(30.5),
        ]);
        let result = run(input, vec![(AggregateFunc::Sum, Some(0))]);

        assert_close(field(&result, 0), 60.5, 0.001, "Expected Float64 value for mixed types");
    }

    #[test]
    fn test_min_max_with_nulls() {
        let input = single_column(vec![
            Value::Int64(30),
            Value::Null,
            Value::Int64(10),
            Value::Null,
            Value::Int64(20),
        ]);
        let result = run(
            input,
            vec![(AggregateFunc::Min, Some(0)), (AggregateFunc::Max, Some(0))],
        );

        assert_eq!(field(&result, 0), Some(&Value::Int64(10)));
        assert_eq!(field(&result, 1), Some(&Value::Int64(30)));
    }

    #[test]
    fn test_min_max_all_nulls() {
        let result = run(
            single_column(vec![Value::Null, Value::Null]),
            vec![(AggregateFunc::Min, Some(0)), (AggregateFunc::Max, Some(0))],
        );

        assert_eq!(field(&result, 0), Some(&Value::Null));
        assert_eq!(field(&result, 1), Some(&Value::Null));
    }

    #[test]
    fn test_stddev_single_value() {
        let result = run(single_column(floats(&[5.0])), vec![(AggregateFunc::StdDev, Some(0))]);

        assert_close(field(&result, 0), 0.0, 0.001, "Expected Float64 value");
    }

    #[test]
    fn test_stddev_empty() {
        let result = run(relation(Vec::new()), vec![(AggregateFunc::StdDev, Some(0))]);

        assert_eq!(field(&result, 0), Some(&Value::Null));
    }

    #[test]
    fn test_stddev_with_nulls() {
        let input = single_column(vec![
            Value::Float64(2.0),
            Value::Null,
            Value::Float64(4.0),
            Value::Null,
            Value::Float64(6.0),
        ]);
        let result = run(input, vec![(AggregateFunc::StdDev, Some(0))]);

        // Population stddev of {2, 4, 6} is sqrt(8 / 3).
        assert_close(field(&result, 0), 1.633, 0.01, "Expected Float64 value");
    }

    #[test]
    fn test_avg_with_nulls() {
        let input = single_column(vec![
            Value::Int64(10),
            Value::Null,
            Value::Int64(20),
            Value::Null,
        ]);
        let result = run(input, vec![(AggregateFunc::Avg, Some(0))]);

        assert_eq!(field(&result, 0), Some(&Value::Float64(15.0)));
    }

    #[test]
    fn test_multiple_aggregates() {
        let result = run(
            single_column(ints(&[10, 20, 30, 40])),
            vec![
                (AggregateFunc::Count, None),
                (AggregateFunc::Sum, Some(0)),
                (AggregateFunc::Avg, Some(0)),
                (AggregateFunc::Min, Some(0)),
                (AggregateFunc::Max, Some(0)),
            ],
        );

        assert_eq!(field(&result, 0), Some(&Value::Int64(4)));
        assert_eq!(field(&result, 1), Some(&Value::Int64(100)));
        assert_eq!(field(&result, 2), Some(&Value::Float64(25.0)));
        assert_eq!(field(&result, 3), Some(&Value::Int64(10)));
        assert_eq!(field(&result, 4), Some(&Value::Int64(40)));
    }

    #[test]
    fn test_group_by_with_multiple_aggregates() {
        let input = keyed_amounts(vec![("A", 10), ("A", 20), ("A", 30), ("B", 100)]);
        let executor = AggregateExecutor::new(
            vec![0],
            vec![
                (AggregateFunc::Count, None),
                (AggregateFunc::Sum, Some(1)),
                (AggregateFunc::Avg, Some(1)),
            ],
        );
        let result = executor.execute(input);

        assert_eq!(result.len(), 2);

        for entry in &result.entries {
            // Expected (count, sum, avg) for each group key.
            let (count, sum, avg) = match entry.get_field(0) {
                Some(Value::String(key)) if key == "A" => (3, 60, 20.0),
                Some(Value::String(key)) if key == "B" => (1, 100, 100.0),
                _ => panic!("Unexpected group key"),
            };
            assert_eq!(entry.get_field(1), Some(&Value::Int64(count)));
            assert_eq!(entry.get_field(2), Some(&Value::Int64(sum)));
            assert_eq!(entry.get_field(3), Some(&Value::Float64(avg)));
        }
    }
}