Skip to main content

alopex_sql/executor/query/
aggregate.rs

1use std::cmp::Ordering;
2use std::collections::{HashMap, HashSet};
3
4use crate::catalog::ColumnMetadata;
5use crate::executor::evaluator::EvalContext;
6use crate::executor::memory::{MemoryPolicy, MemoryTracker};
7use crate::executor::{ExecutorError, Result};
8use crate::planner::aggregate_expr::{AggregateExpr, AggregateFunction};
9use crate::planner::typed_expr::TypedExpr;
10use crate::storage::{RowCodec, SqlValue};
11
12use super::{Row, RowIterator};
13
14/// Byte-encoded group key for hash-based aggregation.
15pub type GroupKeyBytes = Vec<u8>;
16
17fn encode_group_value(value: &SqlValue, buf: &mut Vec<u8>) -> Result<()> {
18    buf.push(value.type_tag());
19    match value {
20        SqlValue::Null => Ok(()),
21        SqlValue::Integer(v) => {
22            buf.extend_from_slice(&v.to_le_bytes());
23            Ok(())
24        }
25        SqlValue::BigInt(v) => {
26            buf.extend_from_slice(&v.to_le_bytes());
27            Ok(())
28        }
29        SqlValue::Float(v) => {
30            buf.extend_from_slice(&v.to_bits().to_le_bytes());
31            Ok(())
32        }
33        SqlValue::Double(v) => {
34            buf.extend_from_slice(&v.to_bits().to_le_bytes());
35            Ok(())
36        }
37        SqlValue::Text(s) => {
38            let len = u32::try_from(s.len()).map_err(|_| ExecutorError::InvalidOperation {
39                operation: "aggregate".into(),
40                reason: "text length exceeds u32::MAX".into(),
41            })?;
42            buf.extend_from_slice(&len.to_le_bytes());
43            buf.extend_from_slice(s.as_bytes());
44            Ok(())
45        }
46        SqlValue::Blob(bytes) => {
47            let len = u32::try_from(bytes.len()).map_err(|_| ExecutorError::InvalidOperation {
48                operation: "aggregate".into(),
49                reason: "blob length exceeds u32::MAX".into(),
50            })?;
51            buf.extend_from_slice(&len.to_le_bytes());
52            buf.extend_from_slice(bytes);
53            Ok(())
54        }
55        SqlValue::Boolean(b) => {
56            buf.push(u8::from(*b));
57            Ok(())
58        }
59        SqlValue::Timestamp(v) => {
60            buf.extend_from_slice(&v.to_le_bytes());
61            Ok(())
62        }
63        SqlValue::Vector(values) => {
64            let len = u32::try_from(values.len()).map_err(|_| ExecutorError::InvalidOperation {
65                operation: "aggregate".into(),
66                reason: "vector length exceeds u32::MAX".into(),
67            })?;
68            buf.extend_from_slice(&len.to_le_bytes());
69            for f in values {
70                buf.extend_from_slice(&f.to_bits().to_le_bytes());
71            }
72            Ok(())
73        }
74    }
75}
76
77/// Encode group key values into a deterministic byte sequence.
78pub fn encode_group_key(values: &[SqlValue]) -> Result<GroupKeyBytes> {
79    let mut buf = Vec::new();
80    for value in values {
81        encode_group_value(value, &mut buf)?;
82    }
83    Ok(buf)
84}
85
86/// Accumulator interface for aggregate function execution.
87pub trait Accumulator: Send {
88    /// Update the accumulator with a new value (None for COUNT(*) rows).
89    fn update(&mut self, value: Option<SqlValue>) -> Result<()>;
90    /// Finalize the accumulator and return the resulting SqlValue.
91    fn finalize(&self) -> Result<SqlValue>;
92    /// Clone the accumulator as a trait object.
93    fn clone_box(&self) -> Box<dyn Accumulator>;
94}
95
96impl Clone for Box<dyn Accumulator> {
97    fn clone(&self) -> Self {
98        self.clone_box()
99    }
100}
101
/// Counting state shared by `COUNT` and `COUNT(DISTINCT ...)`.
#[derive(Debug, Clone)]
pub struct CountAccumulator {
    /// Number of inputs counted so far.
    count: usize,
    /// Encoded values already seen; `Some` only in DISTINCT mode.
    distinct_values: Option<HashSet<Vec<u8>>>,
}

impl CountAccumulator {
    /// Build an empty counter; `distinct` turns on de-duplication of inputs.
    pub fn new(distinct: bool) -> Self {
        let distinct_values = distinct.then(HashSet::new);
        Self {
            count: 0,
            distinct_values,
        }
    }
}
118
119impl Accumulator for CountAccumulator {
120    fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
121        match (&mut self.distinct_values, value) {
122            (Some(distinct), Some(value)) => {
123                if value.is_null() {
124                    return Ok(());
125                }
126                let encoded = RowCodec::encode(std::slice::from_ref(&value));
127                if distinct.insert(encoded) {
128                    self.count += 1;
129                }
130            }
131            (Some(_), None) => {
132                self.count += 1;
133            }
134            (None, Some(value)) => {
135                if !value.is_null() {
136                    self.count += 1;
137                }
138            }
139            (None, None) => {
140                self.count += 1;
141            }
142        }
143        Ok(())
144    }
145
146    fn finalize(&self) -> Result<SqlValue> {
147        Ok(SqlValue::BigInt(self.count as i64))
148    }
149
150    fn clone_box(&self) -> Box<dyn Accumulator> {
151        Box::new(self.clone())
152    }
153}
154
/// Running state for `SUM`.
#[derive(Debug, Clone, Default)]
pub struct SumAccumulator {
    /// Sum of the non-NULL inputs; `None` until the first one arrives.
    sum: Option<f64>,
}

impl SumAccumulator {
    /// Build an accumulator that has not seen any input yet.
    pub fn new() -> Self {
        Self::default()
    }
}
173
174impl Accumulator for SumAccumulator {
175    fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
176        let Some(value) = value else {
177            return Ok(());
178        };
179        if value.is_null() {
180            return Ok(());
181        }
182        let numeric = numeric_to_f64(&value)?;
183        self.sum = Some(self.sum.unwrap_or(0.0) + numeric);
184        Ok(())
185    }
186
187    fn finalize(&self) -> Result<SqlValue> {
188        Ok(self.sum.map_or(SqlValue::Null, SqlValue::Double))
189    }
190
191    fn clone_box(&self) -> Box<dyn Accumulator> {
192        Box::new(self.clone())
193    }
194}
195
/// Running state for `TOTAL`: like `SUM`, but the result is `0.0` rather than
/// NULL when there was no non-NULL input.
#[derive(Debug, Clone, Default)]
pub struct TotalAccumulator {
    /// Sum of the non-NULL inputs; `None` until the first one arrives.
    sum: Option<f64>,
}

impl TotalAccumulator {
    /// Build an accumulator that has not seen any input yet.
    pub fn new() -> Self {
        Self::default()
    }
}
214
215impl Accumulator for TotalAccumulator {
216    fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
217        let Some(value) = value else {
218            return Ok(());
219        };
220        if value.is_null() {
221            return Ok(());
222        }
223        let numeric = numeric_to_f64(&value)?;
224        self.sum = Some(self.sum.unwrap_or(0.0) + numeric);
225        Ok(())
226    }
227
228    fn finalize(&self) -> Result<SqlValue> {
229        Ok(SqlValue::Double(self.sum.unwrap_or(0.0)))
230    }
231
232    fn clone_box(&self) -> Box<dyn Accumulator> {
233        Box::new(self.clone())
234    }
235}
236
/// Running state for `AVG`.
#[derive(Debug, Clone, Default)]
pub struct AvgAccumulator {
    /// Sum of the non-NULL inputs; `None` until the first one arrives.
    sum: Option<f64>,
    /// Number of non-NULL inputs added to `sum`.
    count: usize,
}

impl AvgAccumulator {
    /// Build an accumulator that has not seen any input yet.
    pub fn new() -> Self {
        Self::default()
    }
}
259
260impl Accumulator for AvgAccumulator {
261    fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
262        let Some(value) = value else {
263            return Ok(());
264        };
265        if value.is_null() {
266            return Ok(());
267        }
268        let numeric = numeric_to_f64(&value)?;
269        self.sum = Some(self.sum.unwrap_or(0.0) + numeric);
270        self.count += 1;
271        Ok(())
272    }
273
274    fn finalize(&self) -> Result<SqlValue> {
275        if self.count == 0 {
276            return Ok(SqlValue::Null);
277        }
278        let sum = self.sum.unwrap_or(0.0);
279        Ok(SqlValue::Double(sum / self.count as f64))
280    }
281
282    fn clone_box(&self) -> Box<dyn Accumulator> {
283        Box::new(self.clone())
284    }
285}
286
287fn numeric_to_f64(value: &SqlValue) -> Result<f64> {
288    match value {
289        SqlValue::Integer(v) => Ok(*v as f64),
290        SqlValue::BigInt(v) => Ok(*v as f64),
291        SqlValue::Float(v) => Ok(*v as f64),
292        SqlValue::Double(v) => Ok(*v),
293        _ => Err(ExecutorError::Evaluation(
294            crate::executor::EvaluationError::TypeMismatch {
295                expected: "numeric".into(),
296                actual: value.type_name().into(),
297            },
298        )),
299    }
300}
301
302/// Accumulator for MIN / MAX.
303#[derive(Debug, Clone)]
304pub struct MinMaxAccumulator {
305    value: Option<SqlValue>,
306    is_min: bool,
307}
308
309impl MinMaxAccumulator {
310    /// Create a new min/max accumulator.
311    pub fn new(is_min: bool) -> Self {
312        Self {
313            value: None,
314            is_min,
315        }
316    }
317}
318
319impl Accumulator for MinMaxAccumulator {
320    fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
321        let Some(value) = value else {
322            return Ok(());
323        };
324        if value.is_null() {
325            return Ok(());
326        }
327
328        match &self.value {
329            None => {
330                self.value = Some(value);
331            }
332            Some(current) => {
333                if std::mem::discriminant(current) != std::mem::discriminant(&value) {
334                    return Err(ExecutorError::Evaluation(
335                        crate::executor::EvaluationError::TypeMismatch {
336                            expected: current.type_name().into(),
337                            actual: value.type_name().into(),
338                        },
339                    ));
340                }
341                let ordering = value.partial_cmp(current).ok_or_else(|| {
342                    ExecutorError::Evaluation(crate::executor::EvaluationError::TypeMismatch {
343                        expected: current.type_name().into(),
344                        actual: value.type_name().into(),
345                    })
346                })?;
347                let should_replace = matches!(
348                    (self.is_min, ordering),
349                    (true, Ordering::Less) | (false, Ordering::Greater)
350                );
351                if should_replace {
352                    self.value = Some(value);
353                }
354            }
355        }
356        Ok(())
357    }
358
359    fn finalize(&self) -> Result<SqlValue> {
360        Ok(self.value.clone().unwrap_or(SqlValue::Null))
361    }
362
363    fn clone_box(&self) -> Box<dyn Accumulator> {
364        Box::new(self.clone())
365    }
366}
367
/// Running state for `GROUP_CONCAT`.
#[derive(Debug, Clone)]
pub struct GroupConcatAccumulator {
    /// Text inputs collected so far, in arrival order.
    values: Vec<String>,
    /// String placed between consecutive values in the result.
    separator: String,
}

impl GroupConcatAccumulator {
    /// Build an empty accumulator that joins its inputs with `separator`.
    pub fn new(separator: String) -> Self {
        let values = Vec::new();
        Self { values, separator }
    }
}
384
385impl Accumulator for GroupConcatAccumulator {
386    fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
387        let Some(value) = value else {
388            return Ok(());
389        };
390        match value {
391            SqlValue::Null => Ok(()),
392            SqlValue::Text(text) => {
393                self.values.push(text);
394                Ok(())
395            }
396            other => Err(ExecutorError::Evaluation(
397                crate::executor::EvaluationError::TypeMismatch {
398                    expected: "Text".into(),
399                    actual: other.type_name().into(),
400                },
401            )),
402        }
403    }
404
405    fn finalize(&self) -> Result<SqlValue> {
406        if self.values.is_empty() {
407            return Ok(SqlValue::Null);
408        }
409        Ok(SqlValue::Text(self.values.join(&self.separator)))
410    }
411
412    fn clone_box(&self) -> Box<dyn Accumulator> {
413        Box::new(self.clone())
414    }
415}
416
/// Running state for `STRING_AGG`.
#[derive(Debug, Clone)]
pub struct StringAggAccumulator {
    /// Text inputs collected so far, in arrival order.
    values: Vec<String>,
    /// String placed between consecutive values in the result.
    separator: String,
}

impl StringAggAccumulator {
    /// Build an empty accumulator that joins its inputs with `separator`.
    pub fn new(separator: String) -> Self {
        let values = Vec::new();
        Self { values, separator }
    }
}
433
434impl Accumulator for StringAggAccumulator {
435    fn update(&mut self, value: Option<SqlValue>) -> Result<()> {
436        let Some(value) = value else {
437            return Ok(());
438        };
439        match value {
440            SqlValue::Null => Ok(()),
441            SqlValue::Text(s) => {
442                self.values.push(s);
443                Ok(())
444            }
445            other => Err(ExecutorError::Evaluation(
446                crate::executor::EvaluationError::TypeMismatch {
447                    expected: "Text".into(),
448                    actual: other.type_name().into(),
449                },
450            )),
451        }
452    }
453
454    fn finalize(&self) -> Result<SqlValue> {
455        if self.values.is_empty() {
456            return Ok(SqlValue::Null);
457        }
458        Ok(SqlValue::Text(self.values.join(&self.separator)))
459    }
460
461    fn clone_box(&self) -> Box<dyn Accumulator> {
462        Box::new(self.clone())
463    }
464}
465
466/// Create a new accumulator instance for the aggregate function.
467pub fn create_accumulator(function: &AggregateFunction, distinct: bool) -> Box<dyn Accumulator> {
468    match function {
469        AggregateFunction::Count => Box::new(CountAccumulator::new(distinct)),
470        AggregateFunction::Sum => Box::new(SumAccumulator::new()),
471        AggregateFunction::Total => Box::new(TotalAccumulator::new()),
472        AggregateFunction::Avg => Box::new(AvgAccumulator::new()),
473        AggregateFunction::Min => Box::new(MinMaxAccumulator::new(true)),
474        AggregateFunction::Max => Box::new(MinMaxAccumulator::new(false)),
475        AggregateFunction::GroupConcat { separator } => {
476            let sep = separator.clone().unwrap_or_else(|| ",".to_string());
477            Box::new(GroupConcatAccumulator::new(sep))
478        }
479        AggregateFunction::StringAgg { separator } => {
480            let sep = separator.clone().unwrap_or_else(|| ",".to_string());
481            Box::new(StringAggAccumulator::new(sep))
482        }
483    }
484}
485
486const DEFAULT_GROUP_LIMIT: usize = 1_000_000;
487const AGGREGATE_ACCUMULATOR_OVERHEAD_BYTES: u64 = 32;
488
489struct AggregateGroup {
490    key_values: Vec<SqlValue>,
491    accumulators: Vec<Box<dyn Accumulator>>,
492}
493
494/// Iterator that performs hash-based aggregation over input rows.
495pub struct AggregateIterator<'a> {
496    input: Box<dyn RowIterator + 'a>,
497    group_keys: Vec<TypedExpr>,
498    aggregates: Vec<AggregateExpr>,
499    having: Option<TypedExpr>,
500    hash_table: Option<HashMap<GroupKeyBytes, AggregateGroup>>,
501    result_rows: Vec<Row>,
502    index: usize,
503    schema: Vec<ColumnMetadata>,
504    group_limit: usize,
505    memory_tracker: Option<MemoryTracker>,
506}
507
508impl<'a> AggregateIterator<'a> {
509    /// Create a new aggregate iterator with the default group limit.
510    pub fn new(
511        input: Box<dyn RowIterator + 'a>,
512        group_keys: Vec<TypedExpr>,
513        aggregates: Vec<AggregateExpr>,
514        having: Option<TypedExpr>,
515        schema: Vec<ColumnMetadata>,
516    ) -> Self {
517        Self {
518            input,
519            group_keys,
520            aggregates,
521            having,
522            hash_table: None,
523            result_rows: Vec::new(),
524            index: 0,
525            schema,
526            group_limit: DEFAULT_GROUP_LIMIT,
527            memory_tracker: None,
528        }
529    }
530
531    /// Override the maximum number of groups allowed during aggregation.
532    pub fn with_group_limit(mut self, limit: usize) -> Self {
533        self.group_limit = limit;
534        self
535    }
536
537    /// Attach a memory policy for enforcing in-flight aggregation limits.
538    pub fn with_memory_policy(mut self, policy: Option<MemoryPolicy>) -> Self {
539        self.memory_tracker = policy.map(MemoryTracker::new);
540        self
541    }
542
543    fn build_hash_table(&mut self) -> Result<()> {
544        let mut table: HashMap<GroupKeyBytes, AggregateGroup> = HashMap::new();
545        let mut next_row_id = 0u64;
546
547        while let Some(result) = self.input.next_row() {
548            let row = result?;
549            let ctx = EvalContext::new(&row.values);
550
551            let mut key_values = Vec::with_capacity(self.group_keys.len());
552            for expr in &self.group_keys {
553                key_values.push(crate::executor::evaluator::evaluate(expr, &ctx)?);
554            }
555            let key_bytes = encode_group_key(&key_values)?;
556
557            if !table.contains_key(&key_bytes) {
558                if table.len() + 1 > self.group_limit {
559                    return Err(ExecutorError::ResourceExhausted {
560                        message: format!(
561                            "GROUP BY result exceeds memory limit (max groups: {})",
562                            self.group_limit
563                        ),
564                    });
565                }
566                if let Some(tracker) = &mut self.memory_tracker {
567                    tracker.add_values(&key_values)?;
568                    tracker.add_bytes(
569                        self.aggregates.len() as u64 * AGGREGATE_ACCUMULATOR_OVERHEAD_BYTES,
570                    )?;
571                }
572                let accumulators = self
573                    .aggregates
574                    .iter()
575                    .map(|agg| create_accumulator(&agg.function, agg.distinct))
576                    .collect::<Vec<_>>();
577                table.insert(
578                    key_bytes.clone(),
579                    AggregateGroup {
580                        key_values: key_values.clone(),
581                        accumulators,
582                    },
583                );
584            }
585
586            if let Some(group) = table.get_mut(&key_bytes) {
587                for (idx, agg) in self.aggregates.iter().enumerate() {
588                    let value = match &agg.arg {
589                        None => None,
590                        Some(expr) => Some(crate::executor::evaluator::evaluate(expr, &ctx)?),
591                    };
592                    if let Some(tracker) = &mut self.memory_tracker
593                        && matches!(
594                            agg.function,
595                            AggregateFunction::GroupConcat { .. }
596                                | AggregateFunction::StringAgg { .. }
597                        )
598                        && let Some(value_ref) = value.as_ref()
599                    {
600                        tracker.add_value(value_ref)?;
601                    }
602                    group.accumulators[idx].update(value)?;
603                }
604            }
605        }
606
607        if table.is_empty() && self.group_keys.is_empty() {
608            if let Some(tracker) = &mut self.memory_tracker {
609                tracker.add_bytes(
610                    self.aggregates.len() as u64 * AGGREGATE_ACCUMULATOR_OVERHEAD_BYTES,
611                )?;
612            }
613            let accumulators = self
614                .aggregates
615                .iter()
616                .map(|agg| create_accumulator(&agg.function, agg.distinct))
617                .collect::<Vec<_>>();
618            table.insert(
619                Vec::new(),
620                AggregateGroup {
621                    key_values: Vec::new(),
622                    accumulators,
623                },
624            );
625        }
626
627        let mut rows = Vec::with_capacity(table.len());
628        for group in table.values() {
629            let mut values = Vec::with_capacity(self.group_keys.len() + self.aggregates.len());
630            values.extend(group.key_values.iter().cloned());
631            for acc in &group.accumulators {
632                values.push(acc.finalize()?);
633            }
634            let row = Row::new(next_row_id, values);
635            next_row_id += 1;
636            if let Some(tracker) = &mut self.memory_tracker {
637                tracker.add_row(&row.values)?;
638            }
639
640            if let Some(having) = &self.having {
641                let ctx = EvalContext::new(&row.values);
642                match crate::executor::evaluator::evaluate(having, &ctx)? {
643                    SqlValue::Boolean(true) => rows.push(row),
644                    SqlValue::Boolean(false) | SqlValue::Null => {}
645                    other => {
646                        return Err(ExecutorError::Evaluation(
647                            crate::executor::EvaluationError::TypeMismatch {
648                                expected: "Boolean".into(),
649                                actual: other.type_name().into(),
650                            },
651                        ));
652                    }
653                }
654            } else {
655                rows.push(row);
656            }
657        }
658
659        self.hash_table = Some(table);
660        self.result_rows = rows;
661        Ok(())
662    }
663}
664
665impl<'a> RowIterator for AggregateIterator<'a> {
666    fn next_row(&mut self) -> Option<Result<Row>> {
667        if self.hash_table.is_none()
668            && let Err(err) = self.build_hash_table()
669        {
670            return Some(Err(err));
671        }
672
673        if self.index >= self.result_rows.len() {
674            return None;
675        }
676        let row = self.result_rows[self.index].clone();
677        self.index += 1;
678        Some(Ok(row))
679    }
680
681    fn schema(&self) -> &[ColumnMetadata] {
682        &self.schema
683    }
684}
685
686/// Iterator that performs streaming aggregation over sorted input.
687pub struct StreamingAggregateIterator<'a> {
688    input: Box<dyn RowIterator + 'a>,
689    group_keys: Vec<TypedExpr>,
690    aggregates: Vec<AggregateExpr>,
691    having: Option<TypedExpr>,
692    schema: Vec<ColumnMetadata>,
693    current_key: Option<Vec<SqlValue>>,
694    accumulators: Vec<Box<dyn Accumulator>>,
695    pending_row: Option<Row>,
696    finished: bool,
697    next_row_id: u64,
698    saw_row: bool,
699}
700
701impl<'a> StreamingAggregateIterator<'a> {
702    pub fn new(
703        input: Box<dyn RowIterator + 'a>,
704        group_keys: Vec<TypedExpr>,
705        aggregates: Vec<AggregateExpr>,
706        having: Option<TypedExpr>,
707        schema: Vec<ColumnMetadata>,
708    ) -> Self {
709        Self {
710            input,
711            group_keys,
712            aggregates,
713            having,
714            schema,
715            current_key: None,
716            accumulators: Vec::new(),
717            pending_row: None,
718            finished: false,
719            next_row_id: 0,
720            saw_row: false,
721        }
722    }
723
724    fn init_accumulators(&self) -> Vec<Box<dyn Accumulator>> {
725        self.aggregates
726            .iter()
727            .map(|agg| create_accumulator(&agg.function, agg.distinct))
728            .collect()
729    }
730
731    fn update_accumulators(&mut self, ctx: &EvalContext<'_>) -> Result<()> {
732        for (idx, agg) in self.aggregates.iter().enumerate() {
733            let value = match &agg.arg {
734                None => None,
735                Some(expr) => Some(crate::executor::evaluator::evaluate(expr, ctx)?),
736            };
737            self.accumulators[idx].update(value)?;
738        }
739        Ok(())
740    }
741
742    fn finalize_group(&mut self, key_values: &[SqlValue]) -> Result<Option<Row>> {
743        let mut values = Vec::with_capacity(self.group_keys.len() + self.aggregates.len());
744        values.extend(key_values.iter().cloned());
745        for acc in &self.accumulators {
746            values.push(acc.finalize()?);
747        }
748        let row = Row::new(self.next_row_id, values);
749        self.next_row_id = self.next_row_id.saturating_add(1);
750
751        if let Some(having) = &self.having {
752            let ctx = EvalContext::new(&row.values);
753            match crate::executor::evaluator::evaluate(having, &ctx)? {
754                SqlValue::Boolean(true) => Ok(Some(row)),
755                SqlValue::Boolean(false) | SqlValue::Null => Ok(None),
756                other => Err(ExecutorError::Evaluation(
757                    crate::executor::EvaluationError::TypeMismatch {
758                        expected: "Boolean".into(),
759                        actual: other.type_name().into(),
760                    },
761                )),
762            }
763        } else {
764            Ok(Some(row))
765        }
766    }
767}
768
769impl<'a> RowIterator for StreamingAggregateIterator<'a> {
770    fn next_row(&mut self) -> Option<Result<Row>> {
771        if let Some(row) = self.pending_row.take() {
772            return Some(Ok(row));
773        }
774        if self.finished {
775            return None;
776        }
777
778        loop {
779            match self.input.next_row() {
780                Some(Ok(row)) => {
781                    self.saw_row = true;
782                    let ctx = EvalContext::new(&row.values);
783                    let mut key_values = Vec::with_capacity(self.group_keys.len());
784                    for expr in &self.group_keys {
785                        match crate::executor::evaluator::evaluate(expr, &ctx) {
786                            Ok(value) => key_values.push(value),
787                            Err(err) => return Some(Err(err)),
788                        }
789                    }
790
791                    match &self.current_key {
792                        None => {
793                            self.current_key = Some(key_values);
794                            self.accumulators = self.init_accumulators();
795                            if let Err(err) = self.update_accumulators(&ctx) {
796                                return Some(Err(err));
797                            }
798                        }
799                        Some(current_key) if *current_key == key_values => {
800                            if let Err(err) = self.update_accumulators(&ctx) {
801                                return Some(Err(err));
802                            }
803                        }
804                        Some(_) => {
805                            let current_key = self.current_key.clone().unwrap_or_default();
806                            let output = match self.finalize_group(&current_key) {
807                                Ok(value) => value,
808                                Err(err) => return Some(Err(err)),
809                            };
810                            self.current_key = Some(key_values);
811                            self.accumulators = self.init_accumulators();
812                            if let Err(err) = self.update_accumulators(&ctx) {
813                                return Some(Err(err));
814                            }
815                            if let Some(row) = output {
816                                return Some(Ok(row));
817                            }
818                        }
819                    }
820                }
821                Some(Err(err)) => return Some(Err(err)),
822                None => {
823                    self.finished = true;
824                    if let Some(current_key) = self.current_key.take() {
825                        return match self.finalize_group(&current_key) {
826                            Ok(Some(row)) => Some(Ok(row)),
827                            Ok(None) => None,
828                            Err(err) => Some(Err(err)),
829                        };
830                    }
831
832                    if self.group_keys.is_empty() && !self.saw_row {
833                        self.accumulators = self.init_accumulators();
834                        return match self.finalize_group(&[]) {
835                            Ok(Some(row)) => Some(Ok(row)),
836                            Ok(None) => None,
837                            Err(err) => Some(Err(err)),
838                        };
839                    }
840
841                    return None;
842                }
843            }
844        }
845    }
846
847    fn schema(&self) -> &[ColumnMetadata] {
848        &self.schema
849    }
850}
851
852/// Build output schema for aggregate results.
853pub fn build_aggregate_schema(
854    group_keys: &[TypedExpr],
855    aggregates: &[AggregateExpr],
856) -> Vec<ColumnMetadata> {
857    let mut schema = Vec::new();
858    for (idx, key) in group_keys.iter().enumerate() {
859        let name = match &key.kind {
860            crate::planner::typed_expr::TypedExprKind::ColumnRef { column, .. } => column.clone(),
861            _ => format!("group_{idx}"),
862        };
863        schema.push(ColumnMetadata::new(name, key.resolved_type.clone()));
864    }
865    for (idx, agg) in aggregates.iter().enumerate() {
866        let name = match &agg.function {
867            AggregateFunction::Count => format!("count_{idx}"),
868            AggregateFunction::Sum => format!("sum_{idx}"),
869            AggregateFunction::Total => format!("total_{idx}"),
870            AggregateFunction::Avg => format!("avg_{idx}"),
871            AggregateFunction::Min => format!("min_{idx}"),
872            AggregateFunction::Max => format!("max_{idx}"),
873            AggregateFunction::GroupConcat { .. } => format!("group_concat_{idx}"),
874            AggregateFunction::StringAgg { .. } => format!("string_agg_{idx}"),
875        };
876        schema.push(ColumnMetadata::new(name, agg.result_type.clone()));
877    }
878    schema
879}
880
#[cfg(test)]
mod tests {
    use super::*;

    /// Feed every input to the accumulator, failing the test on any error.
    fn feed(acc: &mut dyn Accumulator, inputs: impl IntoIterator<Item = Option<SqlValue>>) {
        for input in inputs {
            acc.update(input).unwrap();
        }
    }

    /// Shorthand for a text input.
    fn text(s: &str) -> Option<SqlValue> {
        Some(SqlValue::Text(s.into()))
    }

    #[test]
    fn count_accumulator_counts_rows_and_skips_nulls() {
        let mut counter = CountAccumulator::new(false);
        feed(
            &mut counter,
            [None, Some(SqlValue::Null), Some(SqlValue::Integer(1))],
        );
        assert_eq!(counter.finalize().unwrap(), SqlValue::BigInt(2));
    }

    #[test]
    fn count_accumulator_distinct_deduplicates() {
        let mut counter = CountAccumulator::new(true);
        feed(
            &mut counter,
            [
                Some(SqlValue::Integer(1)),
                Some(SqlValue::Integer(1)),
                Some(SqlValue::Integer(2)),
            ],
        );
        assert_eq!(counter.finalize().unwrap(), SqlValue::BigInt(2));
    }

    #[test]
    fn sum_accumulator_aggregates_numeric_values() {
        let mut summer = SumAccumulator::new();
        feed(
            &mut summer,
            [
                Some(SqlValue::Integer(2)),
                Some(SqlValue::Double(3.5)),
                Some(SqlValue::Null),
            ],
        );
        assert_eq!(summer.finalize().unwrap(), SqlValue::Double(5.5));
    }

    #[test]
    fn total_accumulator_returns_zero_for_empty() {
        let untouched = TotalAccumulator::new();
        assert_eq!(untouched.finalize().unwrap(), SqlValue::Double(0.0));
    }

    #[test]
    fn total_accumulator_aggregates_numeric_values() {
        let mut totaler = TotalAccumulator::new();
        feed(
            &mut totaler,
            [
                Some(SqlValue::Integer(2)),
                Some(SqlValue::Null),
                Some(SqlValue::Double(1.5)),
            ],
        );
        assert_eq!(totaler.finalize().unwrap(), SqlValue::Double(3.5));
    }

    #[test]
    fn avg_accumulator_handles_empty_and_nulls() {
        let mut averager = AvgAccumulator::new();
        // No input at all averages to NULL.
        assert_eq!(averager.finalize().unwrap(), SqlValue::Null);
        feed(
            &mut averager,
            [
                Some(SqlValue::Null),
                Some(SqlValue::BigInt(4)),
                Some(SqlValue::Integer(2)),
            ],
        );
        assert_eq!(averager.finalize().unwrap(), SqlValue::Double(3.0));
    }

    #[test]
    fn min_max_accumulator_tracks_extremes() {
        let mut smallest = MinMaxAccumulator::new(true);
        let mut largest = MinMaxAccumulator::new(false);
        for sample in [3, 1, 2] {
            smallest.update(Some(SqlValue::Integer(sample))).unwrap();
            largest.update(Some(SqlValue::Integer(sample))).unwrap();
        }
        assert_eq!(smallest.finalize().unwrap(), SqlValue::Integer(1));
        assert_eq!(largest.finalize().unwrap(), SqlValue::Integer(3));
    }

    #[test]
    fn min_max_accumulator_rejects_type_mismatch() {
        let mut smallest = MinMaxAccumulator::new(true);
        smallest.update(Some(SqlValue::Integer(1))).unwrap();
        let err = smallest.update(text("bad")).unwrap_err();
        assert!(
            matches!(
                err,
                ExecutorError::Evaluation(crate::executor::EvaluationError::TypeMismatch {
                    ..
                })
            ),
            "unexpected error {:?}",
            err
        );
    }

    #[test]
    fn group_concat_accumulator_joins_values() {
        let mut joiner = GroupConcatAccumulator::new("|".into());
        feed(&mut joiner, [text("a"), Some(SqlValue::Null), text("b")]);
        assert_eq!(joiner.finalize().unwrap(), SqlValue::Text("a|b".into()));
    }

    #[test]
    fn group_concat_accumulator_empty_returns_null() {
        let untouched = GroupConcatAccumulator::new(",".into());
        assert_eq!(untouched.finalize().unwrap(), SqlValue::Null);
    }

    #[test]
    fn string_agg_accumulator_joins_values() {
        let mut joiner = StringAggAccumulator::new("::".into());
        feed(&mut joiner, [text("a"), Some(SqlValue::Null), text("b")]);
        assert_eq!(joiner.finalize().unwrap(), SqlValue::Text("a::b".into()));
    }

    #[test]
    fn string_agg_accumulator_empty_returns_null() {
        let untouched = StringAggAccumulator::new(",".into());
        assert_eq!(untouched.finalize().unwrap(), SqlValue::Null);
    }

    #[test]
    fn encode_group_key_is_deterministic() {
        let key = vec![
            SqlValue::Integer(1),
            SqlValue::Text("a".into()),
            SqlValue::Null,
        ];
        let first_pass = encode_group_key(&key).unwrap();
        let second_pass = encode_group_key(&key).unwrap();
        assert_eq!(first_pass, second_pass);
    }
}