Skip to main content

graphos_core/execution/operators/
aggregate.rs

//! Aggregation operators for GROUP BY and aggregation functions.
//!
//! This module provides:
//! - `HashAggregateOperator`: hash-based grouping with aggregation functions
//! - `SimpleAggregateOperator`: global aggregation (no GROUP BY) into a single row
//! - Aggregation functions: COUNT, SUM, AVG, MIN, MAX, FIRST, LAST and COLLECT
6
7use std::collections::{HashMap, HashSet};
8
9use graphos_common::types::{LogicalType, Value};
10
/// Hashable stand-in for a [`Value`], used as the element type of the
/// "already seen" sets that implement DISTINCT aggregation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum HashableValue {
    /// NULL.
    Null,
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A 64-bit float, stored as an IEEE-754 bit pattern so it can be hashed.
    Float64Bits(u64),
    /// A string value.
    String(String),
    /// Any other kind of value, keyed by its `Debug` rendering.
    Other(String),
}
21
22impl From<&Value> for HashableValue {
23    fn from(v: &Value) -> Self {
24        match v {
25            Value::Null => HashableValue::Null,
26            Value::Bool(b) => HashableValue::Bool(*b),
27            Value::Int64(i) => HashableValue::Int64(*i),
28            Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
29            Value::String(s) => HashableValue::String(s.to_string()),
30            other => HashableValue::Other(format!("{other:?}")),
31        }
32    }
33}
34
35impl From<Value> for HashableValue {
36    fn from(v: Value) -> Self {
37        Self::from(&v)
38    }
39}
40
41use super::{Operator, OperatorError, OperatorResult};
42use crate::execution::DataChunk;
43use crate::execution::chunk::DataChunkBuilder;
44
/// The aggregation functions understood by the aggregate operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// `COUNT(*)`: counts every input row, NULLs included.
    Count,
    /// `COUNT(column)`: counts only rows whose value is not NULL.
    CountNonNull,
    /// `SUM(column)`: sum of the numeric values.
    Sum,
    /// `AVG(column)`: arithmetic mean of the numeric values.
    Avg,
    /// `MIN(column)`: smallest value.
    Min,
    /// `MAX(column)`: largest value.
    Max,
    /// `FIRST(column)`: first non-NULL value seen in the group.
    First,
    /// `LAST(column)`: last non-NULL value seen in the group.
    Last,
    /// `COLLECT(column)`: gathers the values into a list.
    Collect,
}
67
68/// An aggregation expression.
69#[derive(Debug, Clone)]
70pub struct AggregateExpr {
71    /// The aggregation function.
72    pub function: AggregateFunction,
73    /// Column index to aggregate (None for COUNT(*)).
74    pub column: Option<usize>,
75    /// Whether to aggregate distinct values only.
76    pub distinct: bool,
77    /// Output alias (for naming the result column).
78    pub alias: Option<String>,
79}
80
81impl AggregateExpr {
82    /// Creates a COUNT(*) expression.
83    pub fn count_star() -> Self {
84        Self {
85            function: AggregateFunction::Count,
86            column: None,
87            distinct: false,
88            alias: None,
89        }
90    }
91
92    /// Creates a COUNT(column) expression.
93    pub fn count(column: usize) -> Self {
94        Self {
95            function: AggregateFunction::CountNonNull,
96            column: Some(column),
97            distinct: false,
98            alias: None,
99        }
100    }
101
102    /// Creates a SUM(column) expression.
103    pub fn sum(column: usize) -> Self {
104        Self {
105            function: AggregateFunction::Sum,
106            column: Some(column),
107            distinct: false,
108            alias: None,
109        }
110    }
111
112    /// Creates an AVG(column) expression.
113    pub fn avg(column: usize) -> Self {
114        Self {
115            function: AggregateFunction::Avg,
116            column: Some(column),
117            distinct: false,
118            alias: None,
119        }
120    }
121
122    /// Creates a MIN(column) expression.
123    pub fn min(column: usize) -> Self {
124        Self {
125            function: AggregateFunction::Min,
126            column: Some(column),
127            distinct: false,
128            alias: None,
129        }
130    }
131
132    /// Creates a MAX(column) expression.
133    pub fn max(column: usize) -> Self {
134        Self {
135            function: AggregateFunction::Max,
136            column: Some(column),
137            distinct: false,
138            alias: None,
139        }
140    }
141
142    /// Creates a FIRST(column) expression.
143    pub fn first(column: usize) -> Self {
144        Self {
145            function: AggregateFunction::First,
146            column: Some(column),
147            distinct: false,
148            alias: None,
149        }
150    }
151
152    /// Creates a LAST(column) expression.
153    pub fn last(column: usize) -> Self {
154        Self {
155            function: AggregateFunction::Last,
156            column: Some(column),
157            distinct: false,
158            alias: None,
159        }
160    }
161
162    /// Creates a COLLECT(column) expression.
163    pub fn collect(column: usize) -> Self {
164        Self {
165            function: AggregateFunction::Collect,
166            column: Some(column),
167            distinct: false,
168            alias: None,
169        }
170    }
171
172    /// Sets the distinct flag.
173    pub fn with_distinct(mut self) -> Self {
174        self.distinct = true;
175        self
176    }
177
178    /// Sets the output alias.
179    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
180        self.alias = Some(alias.into());
181        self
182    }
183}
184
185/// State for a single aggregation computation.
186#[derive(Debug, Clone)]
187enum AggregateState {
188    /// Count state.
189    Count(i64),
190    /// Count distinct state (count, seen values).
191    CountDistinct(i64, HashSet<HashableValue>),
192    /// Sum state (integer).
193    SumInt(i64),
194    /// Sum distinct state (integer, seen values).
195    SumIntDistinct(i64, HashSet<HashableValue>),
196    /// Sum state (float).
197    SumFloat(f64),
198    /// Sum distinct state (float, seen values).
199    SumFloatDistinct(f64, HashSet<HashableValue>),
200    /// Average state (sum, count).
201    Avg(f64, i64),
202    /// Average distinct state (sum, count, seen values).
203    AvgDistinct(f64, i64, HashSet<HashableValue>),
204    /// Min state.
205    Min(Option<Value>),
206    /// Max state.
207    Max(Option<Value>),
208    /// First state.
209    First(Option<Value>),
210    /// Last state.
211    Last(Option<Value>),
212    /// Collect state.
213    Collect(Vec<Value>),
214    /// Collect distinct state (values, seen).
215    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
216}
217
218impl AggregateState {
219    /// Creates initial state for an aggregation function.
220    fn new(function: AggregateFunction, distinct: bool) -> Self {
221        match (function, distinct) {
222            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
223                AggregateState::Count(0)
224            }
225            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
226                AggregateState::CountDistinct(0, HashSet::new())
227            }
228            (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
229            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
230            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
231            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
232            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
233            (AggregateFunction::Max, _) => AggregateState::Max(None),
234            (AggregateFunction::First, _) => AggregateState::First(None),
235            (AggregateFunction::Last, _) => AggregateState::Last(None),
236            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
237            (AggregateFunction::Collect, true) => {
238                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
239            }
240        }
241    }
242
243    /// Updates the state with a new value.
244    fn update(&mut self, value: Option<Value>) {
245        match self {
246            AggregateState::Count(count) => {
247                *count += 1;
248            }
249            AggregateState::CountDistinct(count, seen) => {
250                if let Some(ref v) = value {
251                    let hashable = HashableValue::from(v);
252                    if seen.insert(hashable) {
253                        *count += 1;
254                    }
255                }
256            }
257            AggregateState::SumInt(sum) => {
258                if let Some(Value::Int64(v)) = value {
259                    *sum += v;
260                } else if let Some(Value::Float64(v)) = value {
261                    // Convert to float sum
262                    *self = AggregateState::SumFloat(*sum as f64 + v);
263                }
264            }
265            AggregateState::SumIntDistinct(sum, seen) => {
266                if let Some(ref v) = value {
267                    let hashable = HashableValue::from(v);
268                    if seen.insert(hashable) {
269                        if let Value::Int64(i) = v {
270                            *sum += i;
271                        } else if let Value::Float64(f) = v {
272                            // Convert to float distinct
273                            let seen_clone = seen.clone();
274                            *self = AggregateState::SumFloatDistinct(*sum as f64 + f, seen_clone);
275                        }
276                    }
277                }
278            }
279            AggregateState::SumFloat(sum) => {
280                if let Some(Value::Int64(v)) = value {
281                    *sum += v as f64;
282                } else if let Some(Value::Float64(v)) = value {
283                    *sum += v;
284                }
285            }
286            AggregateState::SumFloatDistinct(sum, seen) => {
287                if let Some(ref v) = value {
288                    let hashable = HashableValue::from(v);
289                    if seen.insert(hashable) {
290                        if let Some(num) = value_to_f64(v) {
291                            *sum += num;
292                        }
293                    }
294                }
295            }
296            AggregateState::Avg(sum, count) => {
297                if let Some(ref v) = value {
298                    if let Some(num) = value_to_f64(v) {
299                        *sum += num;
300                        *count += 1;
301                    }
302                }
303            }
304            AggregateState::AvgDistinct(sum, count, seen) => {
305                if let Some(ref v) = value {
306                    let hashable = HashableValue::from(v);
307                    if seen.insert(hashable) {
308                        if let Some(num) = value_to_f64(v) {
309                            *sum += num;
310                            *count += 1;
311                        }
312                    }
313                }
314            }
315            AggregateState::Min(min) => {
316                if let Some(v) = value {
317                    match min {
318                        None => *min = Some(v),
319                        Some(current) => {
320                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
321                                *min = Some(v);
322                            }
323                        }
324                    }
325                }
326            }
327            AggregateState::Max(max) => {
328                if let Some(v) = value {
329                    match max {
330                        None => *max = Some(v),
331                        Some(current) => {
332                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
333                                *max = Some(v);
334                            }
335                        }
336                    }
337                }
338            }
339            AggregateState::First(first) => {
340                if first.is_none() {
341                    *first = value;
342                }
343            }
344            AggregateState::Last(last) => {
345                if value.is_some() {
346                    *last = value;
347                }
348            }
349            AggregateState::Collect(list) => {
350                if let Some(v) = value {
351                    list.push(v);
352                }
353            }
354            AggregateState::CollectDistinct(list, seen) => {
355                if let Some(v) = value {
356                    let hashable = HashableValue::from(&v);
357                    if seen.insert(hashable) {
358                        list.push(v);
359                    }
360                }
361            }
362        }
363    }
364
365    /// Finalizes the state and returns the result value.
366    fn finalize(&self) -> Value {
367        match self {
368            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
369                Value::Int64(*count)
370            }
371            AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
372                Value::Int64(*sum)
373            }
374            AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
375                Value::Float64(*sum)
376            }
377            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
378                if *count == 0 {
379                    Value::Null
380                } else {
381                    Value::Float64(*sum / *count as f64)
382                }
383            }
384            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
385            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
386            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
387            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
388            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
389                Value::List(list.clone().into())
390            }
391        }
392    }
393}
394
395/// Convert a value to f64 for numeric aggregations.
396fn value_to_f64(value: &Value) -> Option<f64> {
397    match value {
398        Value::Int64(i) => Some(*i as f64),
399        Value::Float64(f) => Some(*f),
400        _ => None,
401    }
402}
403
404/// Compare two values.
405fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
406    match (a, b) {
407        (Value::Int64(a), Value::Int64(b)) => Some(a.cmp(b)),
408        (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b),
409        (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
410        (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
411        (Value::Int64(a), Value::Float64(b)) => (*a as f64).partial_cmp(b),
412        (Value::Float64(a), Value::Int64(b)) => a.partial_cmp(&(*b as f64)),
413        _ => None,
414    }
415}
416
417/// A group key for hash-based aggregation.
418#[derive(Debug, Clone, PartialEq, Eq, Hash)]
419pub struct GroupKey(Vec<GroupKeyPart>);
420
421#[derive(Debug, Clone, PartialEq, Eq, Hash)]
422enum GroupKeyPart {
423    Null,
424    Bool(bool),
425    Int64(i64),
426    String(String),
427}
428
429impl GroupKey {
430    /// Creates a group key from column values.
431    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
432        let parts: Vec<GroupKeyPart> = group_columns
433            .iter()
434            .map(|&col_idx| {
435                chunk
436                    .column(col_idx)
437                    .and_then(|col| col.get_value(row))
438                    .map(|v| match v {
439                        Value::Null => GroupKeyPart::Null,
440                        Value::Bool(b) => GroupKeyPart::Bool(b),
441                        Value::Int64(i) => GroupKeyPart::Int64(i),
442                        Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
443                        Value::String(s) => GroupKeyPart::String(s.to_string()),
444                        _ => GroupKeyPart::String(format!("{v:?}")),
445                    })
446                    .unwrap_or(GroupKeyPart::Null)
447            })
448            .collect();
449        GroupKey(parts)
450    }
451
452    /// Converts the group key back to values.
453    fn to_values(&self) -> Vec<Value> {
454        self.0
455            .iter()
456            .map(|part| match part {
457                GroupKeyPart::Null => Value::Null,
458                GroupKeyPart::Bool(b) => Value::Bool(*b),
459                GroupKeyPart::Int64(i) => Value::Int64(*i),
460                GroupKeyPart::String(s) => Value::String(s.clone().into()),
461            })
462            .collect()
463    }
464}
465
466/// Hash-based aggregate operator.
467///
468/// Groups input by key columns and computes aggregations for each group.
469pub struct HashAggregateOperator {
470    /// Child operator to read from.
471    child: Box<dyn Operator>,
472    /// Columns to group by.
473    group_columns: Vec<usize>,
474    /// Aggregation expressions.
475    aggregates: Vec<AggregateExpr>,
476    /// Output schema.
477    output_schema: Vec<LogicalType>,
478    /// Hash table: group key -> (group values, aggregate states).
479    groups: HashMap<GroupKey, Vec<AggregateState>>,
480    /// Whether aggregation is complete.
481    aggregation_complete: bool,
482    /// Results iterator.
483    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
484}
485
486impl HashAggregateOperator {
487    /// Creates a new hash aggregate operator.
488    ///
489    /// # Arguments
490    /// * `child` - Child operator to read from.
491    /// * `group_columns` - Column indices to group by.
492    /// * `aggregates` - Aggregation expressions.
493    /// * `output_schema` - Schema of the output (group columns + aggregate results).
494    pub fn new(
495        child: Box<dyn Operator>,
496        group_columns: Vec<usize>,
497        aggregates: Vec<AggregateExpr>,
498        output_schema: Vec<LogicalType>,
499    ) -> Self {
500        Self {
501            child,
502            group_columns,
503            aggregates,
504            output_schema,
505            groups: HashMap::new(),
506            aggregation_complete: false,
507            results: None,
508        }
509    }
510
511    /// Performs the aggregation.
512    fn aggregate(&mut self) -> Result<(), OperatorError> {
513        while let Some(chunk) = self.child.next()? {
514            for row in chunk.selected_indices() {
515                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
516
517                // Get or create aggregate states for this group
518                let states = self.groups.entry(key).or_insert_with(|| {
519                    self.aggregates
520                        .iter()
521                        .map(|agg| AggregateState::new(agg.function, agg.distinct))
522                        .collect()
523                });
524
525                // Update each aggregate
526                for (i, agg) in self.aggregates.iter().enumerate() {
527                    let value = match agg.function {
528                        AggregateFunction::Count => None, // COUNT(*) doesn't need a value
529                        _ => agg
530                            .column
531                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
532                    };
533
534                    // For COUNT, always update. For others, skip nulls unless counting.
535                    match agg.function {
536                        AggregateFunction::Count => states[i].update(None),
537                        AggregateFunction::CountNonNull => {
538                            if value.is_some() && !matches!(value, Some(Value::Null)) {
539                                states[i].update(value);
540                            }
541                        }
542                        _ => {
543                            if value.is_some() && !matches!(value, Some(Value::Null)) {
544                                states[i].update(value);
545                            }
546                        }
547                    }
548                }
549            }
550        }
551
552        self.aggregation_complete = true;
553
554        // Convert to results iterator
555        let results: Vec<_> = self.groups.drain().collect();
556        self.results = Some(results.into_iter());
557
558        Ok(())
559    }
560}
561
562impl Operator for HashAggregateOperator {
563    fn next(&mut self) -> OperatorResult {
564        // Perform aggregation if not done
565        if !self.aggregation_complete {
566            self.aggregate()?;
567        }
568
569        // Special case: no groups (global aggregation with no data)
570        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
571            // For global aggregation (no GROUP BY), return one row with initial values
572            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
573
574            for agg in &self.aggregates {
575                let state = AggregateState::new(agg.function, agg.distinct);
576                let value = state.finalize();
577                if let Some(col) = builder.column_mut(self.group_columns.len()) {
578                    col.push_value(value);
579                }
580            }
581            builder.advance_row();
582
583            self.results = Some(Vec::new().into_iter()); // Mark as done
584            return Ok(Some(builder.finish()));
585        }
586
587        let results = match &mut self.results {
588            Some(r) => r,
589            None => return Ok(None),
590        };
591
592        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
593
594        for (key, states) in results.by_ref() {
595            // Output group key columns
596            let key_values = key.to_values();
597            for (i, value) in key_values.into_iter().enumerate() {
598                if let Some(col) = builder.column_mut(i) {
599                    col.push_value(value);
600                }
601            }
602
603            // Output aggregate results
604            for (i, state) in states.iter().enumerate() {
605                let col_idx = self.group_columns.len() + i;
606                if let Some(col) = builder.column_mut(col_idx) {
607                    col.push_value(state.finalize());
608                }
609            }
610
611            builder.advance_row();
612
613            if builder.is_full() {
614                return Ok(Some(builder.finish()));
615            }
616        }
617
618        if builder.row_count() > 0 {
619            Ok(Some(builder.finish()))
620        } else {
621            Ok(None)
622        }
623    }
624
625    fn reset(&mut self) {
626        self.child.reset();
627        self.groups.clear();
628        self.aggregation_complete = false;
629        self.results = None;
630    }
631
632    fn name(&self) -> &'static str {
633        "HashAggregate"
634    }
635}
636
637/// Simple (non-grouping) aggregate operator for global aggregations.
638///
639/// Used when there's no GROUP BY clause - aggregates all input into a single row.
640pub struct SimpleAggregateOperator {
641    /// Child operator.
642    child: Box<dyn Operator>,
643    /// Aggregation expressions.
644    aggregates: Vec<AggregateExpr>,
645    /// Output schema.
646    output_schema: Vec<LogicalType>,
647    /// Aggregate states.
648    states: Vec<AggregateState>,
649    /// Whether aggregation is complete.
650    done: bool,
651}
652
653impl SimpleAggregateOperator {
654    /// Creates a new simple aggregate operator.
655    pub fn new(
656        child: Box<dyn Operator>,
657        aggregates: Vec<AggregateExpr>,
658        output_schema: Vec<LogicalType>,
659    ) -> Self {
660        let states = aggregates
661            .iter()
662            .map(|agg| AggregateState::new(agg.function, agg.distinct))
663            .collect();
664
665        Self {
666            child,
667            aggregates,
668            output_schema,
669            states,
670            done: false,
671        }
672    }
673}
674
675impl Operator for SimpleAggregateOperator {
676    fn next(&mut self) -> OperatorResult {
677        if self.done {
678            return Ok(None);
679        }
680
681        // Process all input
682        while let Some(chunk) = self.child.next()? {
683            for row in chunk.selected_indices() {
684                for (i, agg) in self.aggregates.iter().enumerate() {
685                    let value = match agg.function {
686                        AggregateFunction::Count => None,
687                        _ => agg
688                            .column
689                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
690                    };
691
692                    match agg.function {
693                        AggregateFunction::Count => self.states[i].update(None),
694                        AggregateFunction::CountNonNull => {
695                            if value.is_some() && !matches!(value, Some(Value::Null)) {
696                                self.states[i].update(value);
697                            }
698                        }
699                        _ => {
700                            if value.is_some() && !matches!(value, Some(Value::Null)) {
701                                self.states[i].update(value);
702                            }
703                        }
704                    }
705                }
706            }
707        }
708
709        // Output single result row
710        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
711
712        for (i, state) in self.states.iter().enumerate() {
713            if let Some(col) = builder.column_mut(i) {
714                col.push_value(state.finalize());
715            }
716        }
717        builder.advance_row();
718
719        self.done = true;
720        Ok(Some(builder.finish()))
721    }
722
723    fn reset(&mut self) {
724        self.child.reset();
725        self.states = self
726            .aggregates
727            .iter()
728            .map(|agg| AggregateState::new(agg.function, agg.distinct))
729            .collect();
730        self.done = false;
731    }
732
733    fn name(&self) -> &'static str {
734        "SimpleAggregate"
735    }
736}
737
738#[cfg(test)]
739mod tests {
740    use super::*;
741    use crate::execution::chunk::DataChunkBuilder;
742
743    struct MockOperator {
744        chunks: Vec<DataChunk>,
745        position: usize,
746    }
747
748    impl MockOperator {
749        fn new(chunks: Vec<DataChunk>) -> Self {
750            Self {
751                chunks,
752                position: 0,
753            }
754        }
755    }
756
757    impl Operator for MockOperator {
758        fn next(&mut self) -> OperatorResult {
759            if self.position < self.chunks.len() {
760                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
761                self.position += 1;
762                Ok(Some(chunk))
763            } else {
764                Ok(None)
765            }
766        }
767
768        fn reset(&mut self) {
769            self.position = 0;
770        }
771
772        fn name(&self) -> &'static str {
773            "Mock"
774        }
775    }
776
777    fn create_test_chunk() -> DataChunk {
778        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
779        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
780
781        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
782        for (group, value) in data {
783            builder.column_mut(0).unwrap().push_int64(group);
784            builder.column_mut(1).unwrap().push_int64(value);
785            builder.advance_row();
786        }
787
788        builder.finish()
789    }
790
791    #[test]
792    fn test_simple_count() {
793        let mock = MockOperator::new(vec![create_test_chunk()]);
794
795        let mut agg = SimpleAggregateOperator::new(
796            Box::new(mock),
797            vec![AggregateExpr::count_star()],
798            vec![LogicalType::Int64],
799        );
800
801        let result = agg.next().unwrap().unwrap();
802        assert_eq!(result.row_count(), 1);
803        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
804
805        // Should be done
806        assert!(agg.next().unwrap().is_none());
807    }
808
809    #[test]
810    fn test_simple_sum() {
811        let mock = MockOperator::new(vec![create_test_chunk()]);
812
813        let mut agg = SimpleAggregateOperator::new(
814            Box::new(mock),
815            vec![AggregateExpr::sum(1)], // Sum of column 1
816            vec![LogicalType::Int64],
817        );
818
819        let result = agg.next().unwrap().unwrap();
820        assert_eq!(result.row_count(), 1);
821        // Sum: 10 + 20 + 30 + 40 + 50 = 150
822        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
823    }
824
825    #[test]
826    fn test_simple_avg() {
827        let mock = MockOperator::new(vec![create_test_chunk()]);
828
829        let mut agg = SimpleAggregateOperator::new(
830            Box::new(mock),
831            vec![AggregateExpr::avg(1)],
832            vec![LogicalType::Float64],
833        );
834
835        let result = agg.next().unwrap().unwrap();
836        assert_eq!(result.row_count(), 1);
837        // Avg: 150 / 5 = 30.0
838        let avg = result.column(0).unwrap().get_float64(0).unwrap();
839        assert!((avg - 30.0).abs() < 0.001);
840    }
841
842    #[test]
843    fn test_simple_min_max() {
844        let mock = MockOperator::new(vec![create_test_chunk()]);
845
846        let mut agg = SimpleAggregateOperator::new(
847            Box::new(mock),
848            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
849            vec![LogicalType::Int64, LogicalType::Int64],
850        );
851
852        let result = agg.next().unwrap().unwrap();
853        assert_eq!(result.row_count(), 1);
854        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
855        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
856    }
857
858    #[test]
859    fn test_grouped_aggregation() {
860        let mock = MockOperator::new(vec![create_test_chunk()]);
861
862        // GROUP BY column 0, SUM(column 1)
863        let mut agg = HashAggregateOperator::new(
864            Box::new(mock),
865            vec![0],                     // Group by column 0
866            vec![AggregateExpr::sum(1)], // Sum of column 1
867            vec![LogicalType::Int64, LogicalType::Int64],
868        );
869
870        let mut results: Vec<(i64, i64)> = Vec::new();
871        while let Some(chunk) = agg.next().unwrap() {
872            for row in chunk.selected_indices() {
873                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
874                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
875                results.push((group, sum));
876            }
877        }
878
879        results.sort_by_key(|(g, _)| *g);
880        assert_eq!(results.len(), 2);
881        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
882        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
883    }
884
885    #[test]
886    fn test_grouped_count() {
887        let mock = MockOperator::new(vec![create_test_chunk()]);
888
889        // GROUP BY column 0, COUNT(*)
890        let mut agg = HashAggregateOperator::new(
891            Box::new(mock),
892            vec![0],
893            vec![AggregateExpr::count_star()],
894            vec![LogicalType::Int64, LogicalType::Int64],
895        );
896
897        let mut results: Vec<(i64, i64)> = Vec::new();
898        while let Some(chunk) = agg.next().unwrap() {
899            for row in chunk.selected_indices() {
900                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
901                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
902                results.push((group, count));
903            }
904        }
905
906        results.sort_by_key(|(g, _)| *g);
907        assert_eq!(results.len(), 2);
908        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
909        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
910    }
911
912    #[test]
913    fn test_multiple_aggregates() {
914        let mock = MockOperator::new(vec![create_test_chunk()]);
915
916        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
917        let mut agg = HashAggregateOperator::new(
918            Box::new(mock),
919            vec![0],
920            vec![
921                AggregateExpr::count_star(),
922                AggregateExpr::sum(1),
923                AggregateExpr::avg(1),
924            ],
925            vec![
926                LogicalType::Int64,   // Group key
927                LogicalType::Int64,   // COUNT
928                LogicalType::Int64,   // SUM
929                LogicalType::Float64, // AVG
930            ],
931        );
932
933        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
934        while let Some(chunk) = agg.next().unwrap() {
935            for row in chunk.selected_indices() {
936                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
937                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
938                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
939                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
940                results.push((group, count, sum, avg));
941            }
942        }
943
944        results.sort_by_key(|(g, _, _, _)| *g);
945        assert_eq!(results.len(), 2);
946
947        // Group 1: COUNT=2, SUM=30, AVG=15.0
948        assert_eq!(results[0].0, 1);
949        assert_eq!(results[0].1, 2);
950        assert_eq!(results[0].2, 30);
951        assert!((results[0].3 - 15.0).abs() < 0.001);
952
953        // Group 2: COUNT=3, SUM=120, AVG=40.0
954        assert_eq!(results[1].0, 2);
955        assert_eq!(results[1].1, 3);
956        assert_eq!(results[1].2, 120);
957        assert!((results[1].3 - 40.0).abs() < 0.001);
958    }
959
960    fn create_test_chunk_with_duplicates() -> DataChunk {
961        // Create data with duplicate values in column 1
962        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
963        // GROUP 1: values [10, 10, 20] -> distinct count = 2
964        // GROUP 2: values [30, 30, 30] -> distinct count = 1
965        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
966
967        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
968        for (group, value) in data {
969            builder.column_mut(0).unwrap().push_int64(group);
970            builder.column_mut(1).unwrap().push_int64(value);
971            builder.advance_row();
972        }
973
974        builder.finish()
975    }
976
977    #[test]
978    fn test_count_distinct() {
979        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
980
981        // COUNT(DISTINCT column 1)
982        let mut agg = SimpleAggregateOperator::new(
983            Box::new(mock),
984            vec![AggregateExpr::count(1).with_distinct()],
985            vec![LogicalType::Int64],
986        );
987
988        let result = agg.next().unwrap().unwrap();
989        assert_eq!(result.row_count(), 1);
990        // Total distinct values: 10, 20, 30 = 3 distinct values
991        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
992    }
993
994    #[test]
995    fn test_grouped_count_distinct() {
996        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
997
998        // GROUP BY column 0, COUNT(DISTINCT column 1)
999        let mut agg = HashAggregateOperator::new(
1000            Box::new(mock),
1001            vec![0],
1002            vec![AggregateExpr::count(1).with_distinct()],
1003            vec![LogicalType::Int64, LogicalType::Int64],
1004        );
1005
1006        let mut results: Vec<(i64, i64)> = Vec::new();
1007        while let Some(chunk) = agg.next().unwrap() {
1008            for row in chunk.selected_indices() {
1009                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1010                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1011                results.push((group, count));
1012            }
1013        }
1014
1015        results.sort_by_key(|(g, _)| *g);
1016        assert_eq!(results.len(), 2);
1017        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
1018        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
1019    }
1020
1021    #[test]
1022    fn test_sum_distinct() {
1023        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1024
1025        // SUM(DISTINCT column 1)
1026        let mut agg = SimpleAggregateOperator::new(
1027            Box::new(mock),
1028            vec![AggregateExpr::sum(1).with_distinct()],
1029            vec![LogicalType::Int64],
1030        );
1031
1032        let result = agg.next().unwrap().unwrap();
1033        assert_eq!(result.row_count(), 1);
1034        // Sum of distinct values: 10 + 20 + 30 = 60
1035        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1036    }
1037
1038    #[test]
1039    fn test_avg_distinct() {
1040        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1041
1042        // AVG(DISTINCT column 1)
1043        let mut agg = SimpleAggregateOperator::new(
1044            Box::new(mock),
1045            vec![AggregateExpr::avg(1).with_distinct()],
1046            vec![LogicalType::Float64],
1047        );
1048
1049        let result = agg.next().unwrap().unwrap();
1050        assert_eq!(result.row_count(), 1);
1051        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1052        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1053        assert!((avg - 20.0).abs() < 0.001);
1054    }
1055}