1use std::cmp::Ordering;
2use std::collections::{HashMap, HashSet};
3
4use crate::catalog::ColumnMetadata;
5use crate::executor::evaluator::EvalContext;
6use crate::executor::memory::{MemoryPolicy, MemoryTracker};
7use crate::executor::{ExecutorError, Result};
8use crate::planner::aggregate_expr::{AggregateExpr, AggregateFunction};
9use crate::planner::typed_expr::TypedExpr;
10use crate::storage::{RowCodec, SqlValue};
11
12use super::{Row, RowIterator};
13
14pub type GroupKeyBytes = Vec<u8>;
16
17fn encode_group_value(value: &SqlValue, buf: &mut Vec<u8>) -> Result<()> {
18 buf.push(value.type_tag());
19 match value {
20 SqlValue::Null => Ok(()),
21 SqlValue::Integer(v) => {
22 buf.extend_from_slice(&v.to_le_bytes());
23 Ok(())
24 }
25 SqlValue::BigInt(v) => {
26 buf.extend_from_slice(&v.to_le_bytes());
27 Ok(())
28 }
29 SqlValue::Float(v) => {
30 buf.extend_from_slice(&v.to_bits().to_le_bytes());
31 Ok(())
32 }
33 SqlValue::Double(v) => {
34 buf.extend_from_slice(&v.to_bits().to_le_bytes());
35 Ok(())
36 }
37 SqlValue::Text(s) => {
38 let len = u32::try_from(s.len()).map_err(|_| ExecutorError::InvalidOperation {
39 operation: "aggregate".into(),
40 reason: "text length exceeds u32::MAX".into(),
41 })?;
42 buf.extend_from_slice(&len.to_le_bytes());
43 buf.extend_from_slice(s.as_bytes());
44 Ok(())
45 }
46 SqlValue::Blob(bytes) => {
47 let len = u32::try_from(bytes.len()).map_err(|_| ExecutorError::InvalidOperation {
48 operation: "aggregate".into(),
49 reason: "blob length exceeds u32::MAX".into(),
50 })?;
51 buf.extend_from_slice(&len.to_le_bytes());
52 buf.extend_from_slice(bytes);
53 Ok(())
54 }
55 SqlValue::Boolean(b) => {
56 buf.push(u8::from(*b));
57 Ok(())
58 }
59 SqlValue::Timestamp(v) => {
60 buf.extend_from_slice(&v.to_le_bytes());
61 Ok(())
62 }
63 SqlValue::Vector(values) => {
64 let len = u32::try_from(values.len()).map_err(|_| ExecutorError::InvalidOperation {
65 operation: "aggregate".into(),
66 reason: "vector length exceeds u32::MAX".into(),
67 })?;
68 buf.extend_from_slice(&len.to_le_bytes());
69 for f in values {
70 buf.extend_from_slice(&f.to_bits().to_le_bytes());
71 }
72 Ok(())
73 }
74 }
75}
76
77pub fn encode_group_key(values: &[SqlValue]) -> Result<GroupKeyBytes> {
79 let mut buf = Vec::new();
80 for value in values {
81 encode_group_value(value, &mut buf)?;
82 }
83 Ok(buf)
84}
85
86pub trait Accumulator: Send {
88 fn update(&mut self, value: Option<SqlValue>) -> Result<()>;
90 fn finalize(&self) -> Result<SqlValue>;
92 fn clone_box(&self) -> Box<dyn Accumulator>;
94}
95
96impl Clone for Box<dyn Accumulator> {
97 fn clone(&self) -> Self {
98 self.clone_box()
99 }
100}
101
102#[derive(Debug, Clone)]
104pub struct CountAccumulator {
105 count: usize,
106 distinct_values: Option<HashSet<Vec<u8>>>,
107}
108
109impl CountAccumulator {
110 pub fn new(distinct: bool) -> Self {
112 Self {
113 count: 0,
114 distinct_values: if distinct { Some(HashSet::new()) } else { None },
115 }
116 }
117}
118
119impl Accumulator for CountAccumulator {
120 fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
121 match (&mut self.distinct_values, value) {
122 (Some(distinct), Some(value)) => {
123 if value.is_null() {
124 return Ok(());
125 }
126 let encoded = RowCodec::encode(std::slice::from_ref(&value));
127 if distinct.insert(encoded) {
128 self.count += 1;
129 }
130 }
131 (Some(_), None) => {
132 self.count += 1;
133 }
134 (None, Some(value)) => {
135 if !value.is_null() {
136 self.count += 1;
137 }
138 }
139 (None, None) => {
140 self.count += 1;
141 }
142 }
143 Ok(())
144 }
145
146 fn finalize(&self) -> Result<SqlValue> {
147 Ok(SqlValue::BigInt(self.count as i64))
148 }
149
150 fn clone_box(&self) -> Box<dyn Accumulator> {
151 Box::new(self.clone())
152 }
153}
154
155#[derive(Debug, Clone)]
157pub struct SumAccumulator {
158 sum: Option<f64>,
159}
160
161impl SumAccumulator {
162 pub fn new() -> Self {
164 Self { sum: None }
165 }
166}
167
168impl Default for SumAccumulator {
169 fn default() -> Self {
170 Self::new()
171 }
172}
173
174impl Accumulator for SumAccumulator {
175 fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
176 let Some(value) = value else {
177 return Ok(());
178 };
179 if value.is_null() {
180 return Ok(());
181 }
182 let numeric = numeric_to_f64(&value)?;
183 self.sum = Some(self.sum.unwrap_or(0.0) + numeric);
184 Ok(())
185 }
186
187 fn finalize(&self) -> Result<SqlValue> {
188 Ok(self.sum.map_or(SqlValue::Null, SqlValue::Double))
189 }
190
191 fn clone_box(&self) -> Box<dyn Accumulator> {
192 Box::new(self.clone())
193 }
194}
195
196#[derive(Debug, Clone)]
198pub struct TotalAccumulator {
199 sum: Option<f64>,
200}
201
202impl TotalAccumulator {
203 pub fn new() -> Self {
205 Self { sum: None }
206 }
207}
208
209impl Default for TotalAccumulator {
210 fn default() -> Self {
211 Self::new()
212 }
213}
214
215impl Accumulator for TotalAccumulator {
216 fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
217 let Some(value) = value else {
218 return Ok(());
219 };
220 if value.is_null() {
221 return Ok(());
222 }
223 let numeric = numeric_to_f64(&value)?;
224 self.sum = Some(self.sum.unwrap_or(0.0) + numeric);
225 Ok(())
226 }
227
228 fn finalize(&self) -> Result<SqlValue> {
229 Ok(SqlValue::Double(self.sum.unwrap_or(0.0)))
230 }
231
232 fn clone_box(&self) -> Box<dyn Accumulator> {
233 Box::new(self.clone())
234 }
235}
236
237#[derive(Debug, Clone)]
239pub struct AvgAccumulator {
240 sum: Option<f64>,
241 count: usize,
242}
243
244impl AvgAccumulator {
245 pub fn new() -> Self {
247 Self {
248 sum: None,
249 count: 0,
250 }
251 }
252}
253
254impl Default for AvgAccumulator {
255 fn default() -> Self {
256 Self::new()
257 }
258}
259
260impl Accumulator for AvgAccumulator {
261 fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
262 let Some(value) = value else {
263 return Ok(());
264 };
265 if value.is_null() {
266 return Ok(());
267 }
268 let numeric = numeric_to_f64(&value)?;
269 self.sum = Some(self.sum.unwrap_or(0.0) + numeric);
270 self.count += 1;
271 Ok(())
272 }
273
274 fn finalize(&self) -> Result<SqlValue> {
275 if self.count == 0 {
276 return Ok(SqlValue::Null);
277 }
278 let sum = self.sum.unwrap_or(0.0);
279 Ok(SqlValue::Double(sum / self.count as f64))
280 }
281
282 fn clone_box(&self) -> Box<dyn Accumulator> {
283 Box::new(self.clone())
284 }
285}
286
287fn numeric_to_f64(value: &SqlValue) -> Result<f64> {
288 match value {
289 SqlValue::Integer(v) => Ok(*v as f64),
290 SqlValue::BigInt(v) => Ok(*v as f64),
291 SqlValue::Float(v) => Ok(*v as f64),
292 SqlValue::Double(v) => Ok(*v),
293 _ => Err(ExecutorError::Evaluation(
294 crate::executor::EvaluationError::TypeMismatch {
295 expected: "numeric".into(),
296 actual: value.type_name().into(),
297 },
298 )),
299 }
300}
301
302#[derive(Debug, Clone)]
304pub struct MinMaxAccumulator {
305 value: Option<SqlValue>,
306 is_min: bool,
307}
308
309impl MinMaxAccumulator {
310 pub fn new(is_min: bool) -> Self {
312 Self {
313 value: None,
314 is_min,
315 }
316 }
317}
318
319impl Accumulator for MinMaxAccumulator {
320 fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
321 let Some(value) = value else {
322 return Ok(());
323 };
324 if value.is_null() {
325 return Ok(());
326 }
327
328 match &self.value {
329 None => {
330 self.value = Some(value);
331 }
332 Some(current) => {
333 if std::mem::discriminant(current) != std::mem::discriminant(&value) {
334 return Err(ExecutorError::Evaluation(
335 crate::executor::EvaluationError::TypeMismatch {
336 expected: current.type_name().into(),
337 actual: value.type_name().into(),
338 },
339 ));
340 }
341 let ordering = value.partial_cmp(current).ok_or_else(|| {
342 ExecutorError::Evaluation(crate::executor::EvaluationError::TypeMismatch {
343 expected: current.type_name().into(),
344 actual: value.type_name().into(),
345 })
346 })?;
347 let should_replace = matches!(
348 (self.is_min, ordering),
349 (true, Ordering::Less) | (false, Ordering::Greater)
350 );
351 if should_replace {
352 self.value = Some(value);
353 }
354 }
355 }
356 Ok(())
357 }
358
359 fn finalize(&self) -> Result<SqlValue> {
360 Ok(self.value.clone().unwrap_or(SqlValue::Null))
361 }
362
363 fn clone_box(&self) -> Box<dyn Accumulator> {
364 Box::new(self.clone())
365 }
366}
367
368#[derive(Debug, Clone)]
370pub struct GroupConcatAccumulator {
371 values: Vec<String>,
372 separator: String,
373}
374
375impl GroupConcatAccumulator {
376 pub fn new(separator: String) -> Self {
378 Self {
379 values: Vec::new(),
380 separator,
381 }
382 }
383}
384
385impl Accumulator for GroupConcatAccumulator {
386 fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
387 let Some(value) = value else {
388 return Ok(());
389 };
390 match value {
391 SqlValue::Null => Ok(()),
392 SqlValue::Text(text) => {
393 self.values.push(text);
394 Ok(())
395 }
396 other => Err(ExecutorError::Evaluation(
397 crate::executor::EvaluationError::TypeMismatch {
398 expected: "Text".into(),
399 actual: other.type_name().into(),
400 },
401 )),
402 }
403 }
404
405 fn finalize(&self) -> Result<SqlValue> {
406 if self.values.is_empty() {
407 return Ok(SqlValue::Null);
408 }
409 Ok(SqlValue::Text(self.values.join(&self.separator)))
410 }
411
412 fn clone_box(&self) -> Box<dyn Accumulator> {
413 Box::new(self.clone())
414 }
415}
416
417#[derive(Debug, Clone)]
419pub struct StringAggAccumulator {
420 values: Vec<String>,
421 separator: String,
422}
423
424impl StringAggAccumulator {
425 pub fn new(separator: String) -> Self {
427 Self {
428 values: Vec::new(),
429 separator,
430 }
431 }
432}
433
434impl Accumulator for StringAggAccumulator {
435 fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
436 let Some(value) = value else {
437 return Ok(());
438 };
439 match value {
440 SqlValue::Null => Ok(()),
441 SqlValue::Text(s) => {
442 self.values.push(s);
443 Ok(())
444 }
445 other => Err(ExecutorError::Evaluation(
446 crate::executor::EvaluationError::TypeMismatch {
447 expected: "Text".into(),
448 actual: other.type_name().into(),
449 },
450 )),
451 }
452 }
453
454 fn finalize(&self) -> Result<SqlValue> {
455 if self.values.is_empty() {
456 return Ok(SqlValue::Null);
457 }
458 Ok(SqlValue::Text(self.values.join(&self.separator)))
459 }
460
461 fn clone_box(&self) -> Box<dyn Accumulator> {
462 Box::new(self.clone())
463 }
464}
465
466pub fn create_accumulator(function: &AggregateFunction, distinct: bool) -> Box<dyn Accumulator> {
468 match function {
469 AggregateFunction::Count => Box::new(CountAccumulator::new(distinct)),
470 AggregateFunction::Sum => Box::new(SumAccumulator::new()),
471 AggregateFunction::Total => Box::new(TotalAccumulator::new()),
472 AggregateFunction::Avg => Box::new(AvgAccumulator::new()),
473 AggregateFunction::Min => Box::new(MinMaxAccumulator::new(true)),
474 AggregateFunction::Max => Box::new(MinMaxAccumulator::new(false)),
475 AggregateFunction::GroupConcat { separator } => {
476 let sep = separator.clone().unwrap_or_else(|| ",".to_string());
477 Box::new(GroupConcatAccumulator::new(sep))
478 }
479 AggregateFunction::StringAgg { separator } => {
480 let sep = separator.clone().unwrap_or_else(|| ",".to_string());
481 Box::new(StringAggAccumulator::new(sep))
482 }
483 }
484}
485
486const DEFAULT_GROUP_LIMIT: usize = 1_000_000;
487const AGGREGATE_ACCUMULATOR_OVERHEAD_BYTES: u64 = 32;
488
489struct AggregateGroup {
490 key_values: Vec<SqlValue>,
491 accumulators: Vec<Box<dyn Accumulator>>,
492}
493
494pub struct AggregateIterator<'a> {
496 input: Box<dyn RowIterator + 'a>,
497 group_keys: Vec<TypedExpr>,
498 aggregates: Vec<AggregateExpr>,
499 having: Option<TypedExpr>,
500 hash_table: Option<HashMap<GroupKeyBytes, AggregateGroup>>,
501 result_rows: Vec<Row>,
502 index: usize,
503 schema: Vec<ColumnMetadata>,
504 group_limit: usize,
505 memory_tracker: Option<MemoryTracker>,
506}
507
508impl<'a> AggregateIterator<'a> {
509 pub fn new(
511 input: Box<dyn RowIterator + 'a>,
512 group_keys: Vec<TypedExpr>,
513 aggregates: Vec<AggregateExpr>,
514 having: Option<TypedExpr>,
515 schema: Vec<ColumnMetadata>,
516 ) -> Self {
517 Self {
518 input,
519 group_keys,
520 aggregates,
521 having,
522 hash_table: None,
523 result_rows: Vec::new(),
524 index: 0,
525 schema,
526 group_limit: DEFAULT_GROUP_LIMIT,
527 memory_tracker: None,
528 }
529 }
530
531 pub fn with_group_limit(mut self, limit: usize) -> Self {
533 self.group_limit = limit;
534 self
535 }
536
537 pub fn with_memory_policy(mut self, policy: Option<MemoryPolicy>) -> Self {
539 self.memory_tracker = policy.map(MemoryTracker::new);
540 self
541 }
542
543 fn build_hash_table(&mut self) -> Result<()> {
544 let mut table: HashMap<GroupKeyBytes, AggregateGroup> = HashMap::new();
545 let mut next_row_id = 0u64;
546
547 while let Some(result) = self.input.next_row() {
548 let row = result?;
549 let ctx = EvalContext::new(&row.values);
550
551 let mut key_values = Vec::with_capacity(self.group_keys.len());
552 for expr in &self.group_keys {
553 key_values.push(crate::executor::evaluator::evaluate(expr, &ctx)?);
554 }
555 let key_bytes = encode_group_key(&key_values)?;
556
557 if !table.contains_key(&key_bytes) {
558 if table.len() + 1 > self.group_limit {
559 return Err(ExecutorError::ResourceExhausted {
560 message: format!(
561 "GROUP BY result exceeds memory limit (max groups: {})",
562 self.group_limit
563 ),
564 });
565 }
566 if let Some(tracker) = &mut self.memory_tracker {
567 tracker.add_values(&key_values)?;
568 tracker.add_bytes(
569 self.aggregates.len() as u64 * AGGREGATE_ACCUMULATOR_OVERHEAD_BYTES,
570 )?;
571 }
572 let accumulators = self
573 .aggregates
574 .iter()
575 .map(|agg| create_accumulator(&agg.function, agg.distinct))
576 .collect::<Vec<_>>();
577 table.insert(
578 key_bytes.clone(),
579 AggregateGroup {
580 key_values: key_values.clone(),
581 accumulators,
582 },
583 );
584 }
585
586 if let Some(group) = table.get_mut(&key_bytes) {
587 for (idx, agg) in self.aggregates.iter().enumerate() {
588 let value = match &agg.arg {
589 None => None,
590 Some(expr) => Some(crate::executor::evaluator::evaluate(expr, &ctx)?),
591 };
592 if let Some(tracker) = &mut self.memory_tracker
593 && matches!(
594 agg.function,
595 AggregateFunction::GroupConcat { .. }
596 | AggregateFunction::StringAgg { .. }
597 )
598 && let Some(value_ref) = value.as_ref()
599 {
600 tracker.add_value(value_ref)?;
601 }
602 group.accumulators[idx].update(value)?;
603 }
604 }
605 }
606
607 if table.is_empty() && self.group_keys.is_empty() {
608 if let Some(tracker) = &mut self.memory_tracker {
609 tracker.add_bytes(
610 self.aggregates.len() as u64 * AGGREGATE_ACCUMULATOR_OVERHEAD_BYTES,
611 )?;
612 }
613 let accumulators = self
614 .aggregates
615 .iter()
616 .map(|agg| create_accumulator(&agg.function, agg.distinct))
617 .collect::<Vec<_>>();
618 table.insert(
619 Vec::new(),
620 AggregateGroup {
621 key_values: Vec::new(),
622 accumulators,
623 },
624 );
625 }
626
627 let mut rows = Vec::with_capacity(table.len());
628 for group in table.values() {
629 let mut values = Vec::with_capacity(self.group_keys.len() + self.aggregates.len());
630 values.extend(group.key_values.iter().cloned());
631 for acc in &group.accumulators {
632 values.push(acc.finalize()?);
633 }
634 let row = Row::new(next_row_id, values);
635 next_row_id += 1;
636 if let Some(tracker) = &mut self.memory_tracker {
637 tracker.add_row(&row.values)?;
638 }
639
640 if let Some(having) = &self.having {
641 let ctx = EvalContext::new(&row.values);
642 match crate::executor::evaluator::evaluate(having, &ctx)? {
643 SqlValue::Boolean(true) => rows.push(row),
644 SqlValue::Boolean(false) | SqlValue::Null => {}
645 other => {
646 return Err(ExecutorError::Evaluation(
647 crate::executor::EvaluationError::TypeMismatch {
648 expected: "Boolean".into(),
649 actual: other.type_name().into(),
650 },
651 ));
652 }
653 }
654 } else {
655 rows.push(row);
656 }
657 }
658
659 self.hash_table = Some(table);
660 self.result_rows = rows;
661 Ok(())
662 }
663}
664
665impl<'a> RowIterator for AggregateIterator<'a> {
666 fn next_row(&mut self) -> Option<Result<Row>> {
667 if self.hash_table.is_none()
668 && let Err(err) = self.build_hash_table()
669 {
670 return Some(Err(err));
671 }
672
673 if self.index >= self.result_rows.len() {
674 return None;
675 }
676 let row = self.result_rows[self.index].clone();
677 self.index += 1;
678 Some(Ok(row))
679 }
680
681 fn schema(&self) -> &[ColumnMetadata] {
682 &self.schema
683 }
684}
685
686pub struct StreamingAggregateIterator<'a> {
688 input: Box<dyn RowIterator + 'a>,
689 group_keys: Vec<TypedExpr>,
690 aggregates: Vec<AggregateExpr>,
691 having: Option<TypedExpr>,
692 schema: Vec<ColumnMetadata>,
693 current_key: Option<Vec<SqlValue>>,
694 accumulators: Vec<Box<dyn Accumulator>>,
695 pending_row: Option<Row>,
696 finished: bool,
697 next_row_id: u64,
698 saw_row: bool,
699}
700
701impl<'a> StreamingAggregateIterator<'a> {
702 pub fn new(
703 input: Box<dyn RowIterator + 'a>,
704 group_keys: Vec<TypedExpr>,
705 aggregates: Vec<AggregateExpr>,
706 having: Option<TypedExpr>,
707 schema: Vec<ColumnMetadata>,
708 ) -> Self {
709 Self {
710 input,
711 group_keys,
712 aggregates,
713 having,
714 schema,
715 current_key: None,
716 accumulators: Vec::new(),
717 pending_row: None,
718 finished: false,
719 next_row_id: 0,
720 saw_row: false,
721 }
722 }
723
724 fn init_accumulators(&self) -> Vec<Box<dyn Accumulator>> {
725 self.aggregates
726 .iter()
727 .map(|agg| create_accumulator(&agg.function, agg.distinct))
728 .collect()
729 }
730
731 fn update_accumulators(&mut self, ctx: &EvalContext<'_>) -> Result<()> {
732 for (idx, agg) in self.aggregates.iter().enumerate() {
733 let value = match &agg.arg {
734 None => None,
735 Some(expr) => Some(crate::executor::evaluator::evaluate(expr, ctx)?),
736 };
737 self.accumulators[idx].update(value)?;
738 }
739 Ok(())
740 }
741
742 fn finalize_group(&mut self, key_values: &[SqlValue]) -> Result<Option<Row>> {
743 let mut values = Vec::with_capacity(self.group_keys.len() + self.aggregates.len());
744 values.extend(key_values.iter().cloned());
745 for acc in &self.accumulators {
746 values.push(acc.finalize()?);
747 }
748 let row = Row::new(self.next_row_id, values);
749 self.next_row_id = self.next_row_id.saturating_add(1);
750
751 if let Some(having) = &self.having {
752 let ctx = EvalContext::new(&row.values);
753 match crate::executor::evaluator::evaluate(having, &ctx)? {
754 SqlValue::Boolean(true) => Ok(Some(row)),
755 SqlValue::Boolean(false) | SqlValue::Null => Ok(None),
756 other => Err(ExecutorError::Evaluation(
757 crate::executor::EvaluationError::TypeMismatch {
758 expected: "Boolean".into(),
759 actual: other.type_name().into(),
760 },
761 )),
762 }
763 } else {
764 Ok(Some(row))
765 }
766 }
767}
768
769impl<'a> RowIterator for StreamingAggregateIterator<'a> {
770 fn next_row(&mut self) -> Option<Result<Row>> {
771 if let Some(row) = self.pending_row.take() {
772 return Some(Ok(row));
773 }
774 if self.finished {
775 return None;
776 }
777
778 loop {
779 match self.input.next_row() {
780 Some(Ok(row)) => {
781 self.saw_row = true;
782 let ctx = EvalContext::new(&row.values);
783 let mut key_values = Vec::with_capacity(self.group_keys.len());
784 for expr in &self.group_keys {
785 match crate::executor::evaluator::evaluate(expr, &ctx) {
786 Ok(value) => key_values.push(value),
787 Err(err) => return Some(Err(err)),
788 }
789 }
790
791 match &self.current_key {
792 None => {
793 self.current_key = Some(key_values);
794 self.accumulators = self.init_accumulators();
795 if let Err(err) = self.update_accumulators(&ctx) {
796 return Some(Err(err));
797 }
798 }
799 Some(current_key) if *current_key == key_values => {
800 if let Err(err) = self.update_accumulators(&ctx) {
801 return Some(Err(err));
802 }
803 }
804 Some(_) => {
805 let current_key = self.current_key.clone().unwrap_or_default();
806 let output = match self.finalize_group(¤t_key) {
807 Ok(value) => value,
808 Err(err) => return Some(Err(err)),
809 };
810 self.current_key = Some(key_values);
811 self.accumulators = self.init_accumulators();
812 if let Err(err) = self.update_accumulators(&ctx) {
813 return Some(Err(err));
814 }
815 if let Some(row) = output {
816 return Some(Ok(row));
817 }
818 }
819 }
820 }
821 Some(Err(err)) => return Some(Err(err)),
822 None => {
823 self.finished = true;
824 if let Some(current_key) = self.current_key.take() {
825 return match self.finalize_group(¤t_key) {
826 Ok(Some(row)) => Some(Ok(row)),
827 Ok(None) => None,
828 Err(err) => Some(Err(err)),
829 };
830 }
831
832 if self.group_keys.is_empty() && !self.saw_row {
833 self.accumulators = self.init_accumulators();
834 return match self.finalize_group(&[]) {
835 Ok(Some(row)) => Some(Ok(row)),
836 Ok(None) => None,
837 Err(err) => Some(Err(err)),
838 };
839 }
840
841 return None;
842 }
843 }
844 }
845 }
846
847 fn schema(&self) -> &[ColumnMetadata] {
848 &self.schema
849 }
850}
851
852pub fn build_aggregate_schema(
854 group_keys: &[TypedExpr],
855 aggregates: &[AggregateExpr],
856) -> Vec<ColumnMetadata> {
857 let mut schema = Vec::new();
858 for (idx, key) in group_keys.iter().enumerate() {
859 let name = match &key.kind {
860 crate::planner::typed_expr::TypedExprKind::ColumnRef { column, .. } => column.clone(),
861 _ => format!("group_{idx}"),
862 };
863 schema.push(ColumnMetadata::new(name, key.resolved_type.clone()));
864 }
865 for (idx, agg) in aggregates.iter().enumerate() {
866 let name = match &agg.function {
867 AggregateFunction::Count => format!("count_{idx}"),
868 AggregateFunction::Sum => format!("sum_{idx}"),
869 AggregateFunction::Total => format!("total_{idx}"),
870 AggregateFunction::Avg => format!("avg_{idx}"),
871 AggregateFunction::Min => format!("min_{idx}"),
872 AggregateFunction::Max => format!("max_{idx}"),
873 AggregateFunction::GroupConcat { .. } => format!("group_concat_{idx}"),
874 AggregateFunction::StringAgg { .. } => format!("string_agg_{idx}"),
875 };
876 schema.push(ColumnMetadata::new(name, agg.result_type.clone()));
877 }
878 schema
879}
880
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds every input to `acc`, panicking on the first update error.
    fn feed(acc: &mut dyn Accumulator, inputs: impl IntoIterator<Item = Option<SqlValue>>) {
        for input in inputs {
            acc.update(input).unwrap();
        }
    }

    #[test]
    fn count_accumulator_counts_rows_and_skips_nulls() {
        let mut acc = CountAccumulator::new(false);
        feed(&mut acc, [None, Some(SqlValue::Null), Some(SqlValue::Integer(1))]);
        assert_eq!(acc.finalize().unwrap(), SqlValue::BigInt(2));
    }

    #[test]
    fn count_accumulator_distinct_deduplicates() {
        let mut acc = CountAccumulator::new(true);
        feed(&mut acc, [1, 1, 2].map(|n| Some(SqlValue::Integer(n))));
        assert_eq!(acc.finalize().unwrap(), SqlValue::BigInt(2));
    }

    #[test]
    fn sum_accumulator_aggregates_numeric_values() {
        let mut acc = SumAccumulator::new();
        feed(
            &mut acc,
            [SqlValue::Integer(2), SqlValue::Double(3.5), SqlValue::Null].map(Some),
        );
        assert_eq!(acc.finalize().unwrap(), SqlValue::Double(5.5));
    }

    #[test]
    fn total_accumulator_returns_zero_for_empty() {
        assert_eq!(
            TotalAccumulator::new().finalize().unwrap(),
            SqlValue::Double(0.0)
        );
    }

    #[test]
    fn total_accumulator_aggregates_numeric_values() {
        let mut acc = TotalAccumulator::new();
        feed(
            &mut acc,
            [SqlValue::Integer(2), SqlValue::Null, SqlValue::Double(1.5)].map(Some),
        );
        assert_eq!(acc.finalize().unwrap(), SqlValue::Double(3.5));
    }

    #[test]
    fn avg_accumulator_handles_empty_and_nulls() {
        let mut acc = AvgAccumulator::new();
        assert_eq!(acc.finalize().unwrap(), SqlValue::Null);
        feed(
            &mut acc,
            [SqlValue::Null, SqlValue::BigInt(4), SqlValue::Integer(2)].map(Some),
        );
        assert_eq!(acc.finalize().unwrap(), SqlValue::Double(3.0));
    }

    #[test]
    fn min_max_accumulator_tracks_extremes() {
        let mut min_acc = MinMaxAccumulator::new(true);
        let mut max_acc = MinMaxAccumulator::new(false);
        let inputs = [3, 1, 2].map(|n| Some(SqlValue::Integer(n)));
        feed(&mut min_acc, inputs.clone());
        feed(&mut max_acc, inputs);
        assert_eq!(min_acc.finalize().unwrap(), SqlValue::Integer(1));
        assert_eq!(max_acc.finalize().unwrap(), SqlValue::Integer(3));
    }

    #[test]
    fn min_max_accumulator_rejects_type_mismatch() {
        let mut acc = MinMaxAccumulator::new(true);
        acc.update(Some(SqlValue::Integer(1))).unwrap();
        let err = acc.update(Some(SqlValue::Text("bad".into()))).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutorError::Evaluation(crate::executor::EvaluationError::TypeMismatch { .. })
            ),
            "unexpected error {:?}",
            err
        );
    }

    #[test]
    fn group_concat_accumulator_joins_values() {
        let mut acc = GroupConcatAccumulator::new("|".into());
        feed(
            &mut acc,
            [
                SqlValue::Text("a".into()),
                SqlValue::Null,
                SqlValue::Text("b".into()),
            ]
            .map(Some),
        );
        assert_eq!(acc.finalize().unwrap(), SqlValue::Text("a|b".into()));
    }

    #[test]
    fn group_concat_accumulator_empty_returns_null() {
        assert_eq!(
            GroupConcatAccumulator::new(",".into()).finalize().unwrap(),
            SqlValue::Null
        );
    }

    #[test]
    fn string_agg_accumulator_joins_values() {
        let mut acc = StringAggAccumulator::new("::".into());
        feed(
            &mut acc,
            [
                SqlValue::Text("a".into()),
                SqlValue::Null,
                SqlValue::Text("b".into()),
            ]
            .map(Some),
        );
        assert_eq!(acc.finalize().unwrap(), SqlValue::Text("a::b".into()));
    }

    #[test]
    fn string_agg_accumulator_empty_returns_null() {
        assert_eq!(
            StringAggAccumulator::new(",".into()).finalize().unwrap(),
            SqlValue::Null
        );
    }

    #[test]
    fn encode_group_key_is_deterministic() {
        let values = [
            SqlValue::Integer(1),
            SqlValue::Text("a".into()),
            SqlValue::Null,
        ];
        let encodings = [
            encode_group_key(&values).unwrap(),
            encode_group_key(&values).unwrap(),
        ];
        assert_eq!(encodings[0], encodings[1]);
    }
}