Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12
13use grafeo_common::types::{LogicalType, Value};
14
15use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
16use super::{Operator, OperatorError, OperatorResult};
17use crate::execution::DataChunk;
18use crate::execution::chunk::DataChunkBuilder;
19
/// State for a single aggregation computation.
///
/// One value of this enum accumulates the inputs of one aggregate expression for
/// one group. It is created by `AggregateState::new`, fed through `update`, and
/// turned into the output value by `finalize`.
#[derive(Debug, Clone)]
enum AggregateState {
    /// Count state, shared by `COUNT(*)` and `COUNT(expr)`.
    Count(i64),
    /// Count distinct state (count, seen values).
    CountDistinct(i64, HashSet<HashableValue>),
    /// Sum state (integer). Replaced by `SumFloat` once a non-integer numeric input arrives.
    SumInt(i64),
    /// Sum distinct state (integer, seen values). Replaced by `SumFloatDistinct` on a
    /// non-integer numeric input.
    SumIntDistinct(i64, HashSet<HashableValue>),
    /// Sum state (float).
    SumFloat(f64),
    /// Sum distinct state (float, seen values).
    SumFloatDistinct(f64, HashSet<HashableValue>),
    /// Average state (running sum, number of numeric inputs).
    Avg(f64, i64),
    /// Average distinct state (running sum, number of distinct numeric inputs, seen values).
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// Min state (`None` until the first value is seen).
    Min(Option<Value>),
    /// Max state (`None` until the first value is seen).
    Max(Option<Value>),
    /// First state (the first value passed to `update`).
    First(Option<Value>),
    /// Last state (the most recent non-`None` value passed to `update`).
    Last(Option<Value>),
    /// Collect state (values in arrival order).
    Collect(Vec<Value>),
    /// Collect distinct state (values in first-seen order, seen set).
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Discrete percentile state (buffered numeric inputs, requested percentile).
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Continuous percentile state (buffered numeric inputs, requested percentile).
    PercentileCont { values: Vec<f64>, percentile: f64 },
}
60
61impl AggregateState {
62    /// Creates initial state for an aggregation function.
63    fn new(function: AggregateFunction, distinct: bool, percentile: Option<f64>) -> Self {
64        match (function, distinct) {
65            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
66                AggregateState::Count(0)
67            }
68            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
69                AggregateState::CountDistinct(0, HashSet::new())
70            }
71            (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
72            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
73            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
74            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
75            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
76            (AggregateFunction::Max, _) => AggregateState::Max(None),
77            (AggregateFunction::First, _) => AggregateState::First(None),
78            (AggregateFunction::Last, _) => AggregateState::Last(None),
79            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
80            (AggregateFunction::Collect, true) => {
81                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
82            }
83            // Statistical functions (Welford's algorithm for online computation)
84            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
85                count: 0,
86                mean: 0.0,
87                m2: 0.0,
88            },
89            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
90                count: 0,
91                mean: 0.0,
92                m2: 0.0,
93            },
94            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
95                values: Vec::new(),
96                percentile: percentile.unwrap_or(0.5),
97            },
98            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
99                values: Vec::new(),
100                percentile: percentile.unwrap_or(0.5),
101            },
102        }
103    }
104
105    /// Updates the state with a new value.
106    fn update(&mut self, value: Option<Value>) {
107        match self {
108            AggregateState::Count(count) => {
109                *count += 1;
110            }
111            AggregateState::CountDistinct(count, seen) => {
112                if let Some(ref v) = value {
113                    let hashable = HashableValue::from(v);
114                    if seen.insert(hashable) {
115                        *count += 1;
116                    }
117                }
118            }
119            AggregateState::SumInt(sum) => {
120                if let Some(Value::Int64(v)) = value {
121                    *sum += v;
122                } else if let Some(Value::Float64(v)) = value {
123                    // Convert to float sum
124                    *self = AggregateState::SumFloat(*sum as f64 + v);
125                } else if let Some(ref v) = value {
126                    // RDF stores numeric literals as strings - try to parse
127                    if let Some(num) = value_to_f64(v) {
128                        *self = AggregateState::SumFloat(*sum as f64 + num);
129                    }
130                }
131            }
132            AggregateState::SumIntDistinct(sum, seen) => {
133                if let Some(ref v) = value {
134                    let hashable = HashableValue::from(v);
135                    if seen.insert(hashable) {
136                        if let Value::Int64(i) = v {
137                            *sum += i;
138                        } else if let Value::Float64(f) = v {
139                            // Convert to float distinct — move the seen set instead of cloning
140                            let moved_seen = std::mem::take(seen);
141                            *self = AggregateState::SumFloatDistinct(*sum as f64 + f, moved_seen);
142                        } else if let Some(num) = value_to_f64(v) {
143                            // RDF string-encoded numerics
144                            let moved_seen = std::mem::take(seen);
145                            *self = AggregateState::SumFloatDistinct(*sum as f64 + num, moved_seen);
146                        }
147                    }
148                }
149            }
150            AggregateState::SumFloat(sum) => {
151                if let Some(ref v) = value {
152                    // Use value_to_f64 which now handles strings
153                    if let Some(num) = value_to_f64(v) {
154                        *sum += num;
155                    }
156                }
157            }
158            AggregateState::SumFloatDistinct(sum, seen) => {
159                if let Some(ref v) = value {
160                    let hashable = HashableValue::from(v);
161                    if seen.insert(hashable)
162                        && let Some(num) = value_to_f64(v)
163                    {
164                        *sum += num;
165                    }
166                }
167            }
168            AggregateState::Avg(sum, count) => {
169                if let Some(ref v) = value
170                    && let Some(num) = value_to_f64(v)
171                {
172                    *sum += num;
173                    *count += 1;
174                }
175            }
176            AggregateState::AvgDistinct(sum, count, seen) => {
177                if let Some(ref v) = value {
178                    let hashable = HashableValue::from(v);
179                    if seen.insert(hashable)
180                        && let Some(num) = value_to_f64(v)
181                    {
182                        *sum += num;
183                        *count += 1;
184                    }
185                }
186            }
187            AggregateState::Min(min) => {
188                if let Some(v) = value {
189                    match min {
190                        None => *min = Some(v),
191                        Some(current) => {
192                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
193                                *min = Some(v);
194                            }
195                        }
196                    }
197                }
198            }
199            AggregateState::Max(max) => {
200                if let Some(v) = value {
201                    match max {
202                        None => *max = Some(v),
203                        Some(current) => {
204                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
205                                *max = Some(v);
206                            }
207                        }
208                    }
209                }
210            }
211            AggregateState::First(first) => {
212                if first.is_none() {
213                    *first = value;
214                }
215            }
216            AggregateState::Last(last) => {
217                if value.is_some() {
218                    *last = value;
219                }
220            }
221            AggregateState::Collect(list) => {
222                if let Some(v) = value {
223                    list.push(v);
224                }
225            }
226            AggregateState::CollectDistinct(list, seen) => {
227                if let Some(v) = value {
228                    let hashable = HashableValue::from(&v);
229                    if seen.insert(hashable) {
230                        list.push(v);
231                    }
232                }
233            }
234            // Statistical functions using Welford's online algorithm
235            AggregateState::StdDev { count, mean, m2 }
236            | AggregateState::StdDevPop { count, mean, m2 } => {
237                if let Some(ref v) = value
238                    && let Some(x) = value_to_f64(v)
239                {
240                    *count += 1;
241                    let delta = x - *mean;
242                    *mean += delta / *count as f64;
243                    let delta2 = x - *mean;
244                    *m2 += delta * delta2;
245                }
246            }
247            AggregateState::PercentileDisc { values, .. }
248            | AggregateState::PercentileCont { values, .. } => {
249                if let Some(ref v) = value
250                    && let Some(x) = value_to_f64(v)
251                {
252                    values.push(x);
253                }
254            }
255        }
256    }
257
258    /// Finalizes the state and returns the result value.
259    fn finalize(&self) -> Value {
260        match self {
261            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
262                Value::Int64(*count)
263            }
264            AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
265                Value::Int64(*sum)
266            }
267            AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
268                Value::Float64(*sum)
269            }
270            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
271                if *count == 0 {
272                    Value::Null
273                } else {
274                    Value::Float64(*sum / *count as f64)
275                }
276            }
277            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
278            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
279            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
280            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
281            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
282                Value::List(list.clone().into())
283            }
284            // Sample standard deviation: sqrt(M2 / (n - 1))
285            AggregateState::StdDev { count, m2, .. } => {
286                if *count < 2 {
287                    Value::Null
288                } else {
289                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
290                }
291            }
292            // Population standard deviation: sqrt(M2 / n)
293            AggregateState::StdDevPop { count, m2, .. } => {
294                if *count == 0 {
295                    Value::Null
296                } else {
297                    Value::Float64((*m2 / *count as f64).sqrt())
298                }
299            }
300            // Discrete percentile: return actual value at percentile position
301            AggregateState::PercentileDisc { values, percentile } => {
302                if values.is_empty() {
303                    Value::Null
304                } else {
305                    let mut sorted = values.clone();
306                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
307                    // Index calculation per SQL standard: floor(p * (n - 1))
308                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
309                    Value::Float64(sorted[index])
310                }
311            }
312            // Continuous percentile: interpolate between values
313            AggregateState::PercentileCont { values, percentile } => {
314                if values.is_empty() {
315                    Value::Null
316                } else {
317                    let mut sorted = values.clone();
318                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
319                    // Linear interpolation per SQL standard
320                    let rank = percentile * (sorted.len() - 1) as f64;
321                    let lower_idx = rank.floor() as usize;
322                    let upper_idx = rank.ceil() as usize;
323                    if lower_idx == upper_idx {
324                        Value::Float64(sorted[lower_idx])
325                    } else {
326                        let fraction = rank - lower_idx as f64;
327                        let result =
328                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
329                        Value::Float64(result)
330                    }
331                }
332            }
333        }
334    }
335}
336
337use super::value_utils::{compare_values, value_to_f64};
338
339/// A group key for hash-based aggregation.
340#[derive(Debug, Clone, PartialEq, Eq, Hash)]
341pub struct GroupKey(Vec<GroupKeyPart>);
342
343#[derive(Debug, Clone, PartialEq, Eq, Hash)]
344enum GroupKeyPart {
345    Null,
346    Bool(bool),
347    Int64(i64),
348    String(String),
349}
350
351impl GroupKey {
352    /// Creates a group key from column values.
353    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
354        let parts: Vec<GroupKeyPart> = group_columns
355            .iter()
356            .map(|&col_idx| {
357                chunk
358                    .column(col_idx)
359                    .and_then(|col| col.get_value(row))
360                    .map_or(GroupKeyPart::Null, |v| match v {
361                        Value::Null => GroupKeyPart::Null,
362                        Value::Bool(b) => GroupKeyPart::Bool(b),
363                        Value::Int64(i) => GroupKeyPart::Int64(i),
364                        Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
365                        Value::String(s) => GroupKeyPart::String(s.to_string()),
366                        _ => GroupKeyPart::String(format!("{v:?}")),
367                    })
368            })
369            .collect();
370        GroupKey(parts)
371    }
372
373    /// Converts the group key back to values.
374    fn to_values(&self) -> Vec<Value> {
375        self.0
376            .iter()
377            .map(|part| match part {
378                GroupKeyPart::Null => Value::Null,
379                GroupKeyPart::Bool(b) => Value::Bool(*b),
380                GroupKeyPart::Int64(i) => Value::Int64(*i),
381                GroupKeyPart::String(s) => Value::String(s.clone().into()),
382            })
383            .collect()
384    }
385}
386
/// Hash-based aggregate operator.
///
/// Groups input by key columns and computes aggregations for each group.
/// Each output row holds the group-key columns first (in `group_columns` order),
/// followed by one column per entry in `aggregates`.
pub struct HashAggregateOperator {
    /// Child operator to read from.
    child: Box<dyn Operator>,
    /// Column indices (into the child's chunks) to group by.
    group_columns: Vec<usize>,
    /// Aggregation expressions, evaluated separately for every group.
    aggregates: Vec<AggregateExpr>,
    /// Output schema: group-key types followed by aggregate result types.
    output_schema: Vec<LogicalType>,
    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Whether the child has been fully consumed and `groups` moved into `results`.
    aggregation_complete: bool,
    /// Finished groups still waiting to be emitted; `None` until aggregation has run.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
406
407impl HashAggregateOperator {
408    /// Creates a new hash aggregate operator.
409    ///
410    /// # Arguments
411    /// * `child` - Child operator to read from.
412    /// * `group_columns` - Column indices to group by.
413    /// * `aggregates` - Aggregation expressions.
414    /// * `output_schema` - Schema of the output (group columns + aggregate results).
415    pub fn new(
416        child: Box<dyn Operator>,
417        group_columns: Vec<usize>,
418        aggregates: Vec<AggregateExpr>,
419        output_schema: Vec<LogicalType>,
420    ) -> Self {
421        Self {
422            child,
423            group_columns,
424            aggregates,
425            output_schema,
426            groups: IndexMap::new(),
427            aggregation_complete: false,
428            results: None,
429        }
430    }
431
432    /// Performs the aggregation.
433    fn aggregate(&mut self) -> Result<(), OperatorError> {
434        while let Some(chunk) = self.child.next()? {
435            for row in chunk.selected_indices() {
436                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
437
438                // Get or create aggregate states for this group
439                let states = self.groups.entry(key).or_insert_with(|| {
440                    self.aggregates
441                        .iter()
442                        .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
443                        .collect()
444                });
445
446                // Update each aggregate
447                for (i, agg) in self.aggregates.iter().enumerate() {
448                    let value = match (agg.function, agg.distinct) {
449                        // COUNT(*) without DISTINCT doesn't need a value
450                        (AggregateFunction::Count, false) => None,
451                        // COUNT DISTINCT needs the actual value to track unique values
452                        (AggregateFunction::Count, true) => agg
453                            .column
454                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
455                        _ => agg
456                            .column
457                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
458                    };
459
460                    // For COUNT without DISTINCT, always update. For others, skip nulls.
461                    match (agg.function, agg.distinct) {
462                        (AggregateFunction::Count, false) => states[i].update(None),
463                        (AggregateFunction::Count, true) => {
464                            // COUNT DISTINCT needs the value to track unique values
465                            if value.is_some() && !matches!(value, Some(Value::Null)) {
466                                states[i].update(value);
467                            }
468                        }
469                        (AggregateFunction::CountNonNull, _) => {
470                            if value.is_some() && !matches!(value, Some(Value::Null)) {
471                                states[i].update(value);
472                            }
473                        }
474                        _ => {
475                            if value.is_some() && !matches!(value, Some(Value::Null)) {
476                                states[i].update(value);
477                            }
478                        }
479                    }
480                }
481            }
482        }
483
484        self.aggregation_complete = true;
485
486        // Convert to results iterator (IndexMap::drain takes a range)
487        let results: Vec<_> = self.groups.drain(..).collect();
488        self.results = Some(results.into_iter());
489
490        Ok(())
491    }
492}
493
494impl Operator for HashAggregateOperator {
495    fn next(&mut self) -> OperatorResult {
496        // Perform aggregation if not done
497        if !self.aggregation_complete {
498            self.aggregate()?;
499        }
500
501        // Special case: no groups (global aggregation with no data)
502        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
503            // For global aggregation (no GROUP BY), return one row with initial values
504            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
505
506            for agg in &self.aggregates {
507                let state = AggregateState::new(agg.function, agg.distinct, agg.percentile);
508                let value = state.finalize();
509                if let Some(col) = builder.column_mut(self.group_columns.len()) {
510                    col.push_value(value);
511                }
512            }
513            builder.advance_row();
514
515            self.results = Some(Vec::new().into_iter()); // Mark as done
516            return Ok(Some(builder.finish()));
517        }
518
519        let Some(results) = &mut self.results else {
520            return Ok(None);
521        };
522
523        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
524
525        for (key, states) in results.by_ref() {
526            // Output group key columns
527            let key_values = key.to_values();
528            for (i, value) in key_values.into_iter().enumerate() {
529                if let Some(col) = builder.column_mut(i) {
530                    col.push_value(value);
531                }
532            }
533
534            // Output aggregate results
535            for (i, state) in states.iter().enumerate() {
536                let col_idx = self.group_columns.len() + i;
537                if let Some(col) = builder.column_mut(col_idx) {
538                    col.push_value(state.finalize());
539                }
540            }
541
542            builder.advance_row();
543
544            if builder.is_full() {
545                return Ok(Some(builder.finish()));
546            }
547        }
548
549        if builder.row_count() > 0 {
550            Ok(Some(builder.finish()))
551        } else {
552            Ok(None)
553        }
554    }
555
556    fn reset(&mut self) {
557        self.child.reset();
558        self.groups.clear();
559        self.aggregation_complete = false;
560        self.results = None;
561    }
562
563    fn name(&self) -> &'static str {
564        "HashAggregate"
565    }
566}
567
/// Simple (non-grouping) aggregate operator for global aggregations.
///
/// Used when there's no GROUP BY clause - aggregates all input into a single row.
/// That row holds one column per entry in `aggregates`, in order.
pub struct SimpleAggregateOperator {
    /// Child operator.
    child: Box<dyn Operator>,
    /// Aggregation expressions.
    aggregates: Vec<AggregateExpr>,
    /// Output schema (one type per aggregate).
    output_schema: Vec<LogicalType>,
    /// Aggregate states, index-aligned with `aggregates`.
    states: Vec<AggregateState>,
    /// Whether the single result row has already been returned.
    done: bool,
}
583
584impl SimpleAggregateOperator {
585    /// Creates a new simple aggregate operator.
586    pub fn new(
587        child: Box<dyn Operator>,
588        aggregates: Vec<AggregateExpr>,
589        output_schema: Vec<LogicalType>,
590    ) -> Self {
591        let states = aggregates
592            .iter()
593            .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
594            .collect();
595
596        Self {
597            child,
598            aggregates,
599            output_schema,
600            states,
601            done: false,
602        }
603    }
604}
605
606impl Operator for SimpleAggregateOperator {
607    fn next(&mut self) -> OperatorResult {
608        if self.done {
609            return Ok(None);
610        }
611
612        // Process all input
613        while let Some(chunk) = self.child.next()? {
614            for row in chunk.selected_indices() {
615                for (i, agg) in self.aggregates.iter().enumerate() {
616                    let value = match (agg.function, agg.distinct) {
617                        // COUNT(*) without DISTINCT doesn't need a value
618                        (AggregateFunction::Count, false) => None,
619                        // COUNT DISTINCT needs the actual value to track unique values
620                        (AggregateFunction::Count, true) => agg
621                            .column
622                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
623                        _ => agg
624                            .column
625                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
626                    };
627
628                    match (agg.function, agg.distinct) {
629                        (AggregateFunction::Count, false) => self.states[i].update(None),
630                        (AggregateFunction::Count, true) => {
631                            // COUNT DISTINCT needs the value to track unique values
632                            if value.is_some() && !matches!(value, Some(Value::Null)) {
633                                self.states[i].update(value);
634                            }
635                        }
636                        (AggregateFunction::CountNonNull, _) => {
637                            if value.is_some() && !matches!(value, Some(Value::Null)) {
638                                self.states[i].update(value);
639                            }
640                        }
641                        _ => {
642                            if value.is_some() && !matches!(value, Some(Value::Null)) {
643                                self.states[i].update(value);
644                            }
645                        }
646                    }
647                }
648            }
649        }
650
651        // Output single result row
652        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
653
654        for (i, state) in self.states.iter().enumerate() {
655            if let Some(col) = builder.column_mut(i) {
656                col.push_value(state.finalize());
657            }
658        }
659        builder.advance_row();
660
661        self.done = true;
662        Ok(Some(builder.finish()))
663    }
664
665    fn reset(&mut self) {
666        self.child.reset();
667        self.states = self
668            .aggregates
669            .iter()
670            .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
671            .collect();
672        self.done = false;
673    }
674
675    fn name(&self) -> &'static str {
676        "SimpleAggregate"
677    }
678}
679
680#[cfg(test)]
681mod tests {
682    use super::*;
683    use crate::execution::chunk::DataChunkBuilder;
684
    /// Test double that hands out a fixed list of chunks, one per `next()` call.
    struct MockOperator {
        /// Chunks to yield; slots already yielded hold `DataChunk::empty()`.
        chunks: Vec<DataChunk>,
        /// Index of the next chunk to yield.
        position: usize,
    }
689
690    impl MockOperator {
691        fn new(chunks: Vec<DataChunk>) -> Self {
692            Self {
693                chunks,
694                position: 0,
695            }
696        }
697    }
698
699    impl Operator for MockOperator {
700        fn next(&mut self) -> OperatorResult {
701            if self.position < self.chunks.len() {
702                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
703                self.position += 1;
704                Ok(Some(chunk))
705            } else {
706                Ok(None)
707            }
708        }
709
710        fn reset(&mut self) {
711            self.position = 0;
712        }
713
714        fn name(&self) -> &'static str {
715            "Mock"
716        }
717    }
718
719    fn create_test_chunk() -> DataChunk {
720        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
721        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
722
723        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
724        for (group, value) in data {
725            builder.column_mut(0).unwrap().push_int64(group);
726            builder.column_mut(1).unwrap().push_int64(value);
727            builder.advance_row();
728        }
729
730        builder.finish()
731    }
732
733    #[test]
734    fn test_simple_count() {
735        let mock = MockOperator::new(vec![create_test_chunk()]);
736
737        let mut agg = SimpleAggregateOperator::new(
738            Box::new(mock),
739            vec![AggregateExpr::count_star()],
740            vec![LogicalType::Int64],
741        );
742
743        let result = agg.next().unwrap().unwrap();
744        assert_eq!(result.row_count(), 1);
745        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
746
747        // Should be done
748        assert!(agg.next().unwrap().is_none());
749    }
750
751    #[test]
752    fn test_simple_sum() {
753        let mock = MockOperator::new(vec![create_test_chunk()]);
754
755        let mut agg = SimpleAggregateOperator::new(
756            Box::new(mock),
757            vec![AggregateExpr::sum(1)], // Sum of column 1
758            vec![LogicalType::Int64],
759        );
760
761        let result = agg.next().unwrap().unwrap();
762        assert_eq!(result.row_count(), 1);
763        // Sum: 10 + 20 + 30 + 40 + 50 = 150
764        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
765    }
766
767    #[test]
768    fn test_simple_avg() {
769        let mock = MockOperator::new(vec![create_test_chunk()]);
770
771        let mut agg = SimpleAggregateOperator::new(
772            Box::new(mock),
773            vec![AggregateExpr::avg(1)],
774            vec![LogicalType::Float64],
775        );
776
777        let result = agg.next().unwrap().unwrap();
778        assert_eq!(result.row_count(), 1);
779        // Avg: 150 / 5 = 30.0
780        let avg = result.column(0).unwrap().get_float64(0).unwrap();
781        assert!((avg - 30.0).abs() < 0.001);
782    }
783
784    #[test]
785    fn test_simple_min_max() {
786        let mock = MockOperator::new(vec![create_test_chunk()]);
787
788        let mut agg = SimpleAggregateOperator::new(
789            Box::new(mock),
790            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
791            vec![LogicalType::Int64, LogicalType::Int64],
792        );
793
794        let result = agg.next().unwrap().unwrap();
795        assert_eq!(result.row_count(), 1);
796        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
797        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
798    }
799
800    #[test]
801    fn test_sum_with_string_values() {
802        // Test SUM with string values (like RDF stores numeric literals)
803        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
804        builder.column_mut(0).unwrap().push_string("30");
805        builder.advance_row();
806        builder.column_mut(0).unwrap().push_string("25");
807        builder.advance_row();
808        builder.column_mut(0).unwrap().push_string("35");
809        builder.advance_row();
810        let chunk = builder.finish();
811
812        let mock = MockOperator::new(vec![chunk]);
813        let mut agg = SimpleAggregateOperator::new(
814            Box::new(mock),
815            vec![AggregateExpr::sum(0)],
816            vec![LogicalType::Float64],
817        );
818
819        let result = agg.next().unwrap().unwrap();
820        assert_eq!(result.row_count(), 1);
821        // Should parse strings and sum: 30 + 25 + 35 = 90
822        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
823        assert!(
824            (sum_val - 90.0).abs() < 0.001,
825            "Expected 90.0, got {}",
826            sum_val
827        );
828    }
829
830    #[test]
831    fn test_grouped_aggregation() {
832        let mock = MockOperator::new(vec![create_test_chunk()]);
833
834        // GROUP BY column 0, SUM(column 1)
835        let mut agg = HashAggregateOperator::new(
836            Box::new(mock),
837            vec![0],                     // Group by column 0
838            vec![AggregateExpr::sum(1)], // Sum of column 1
839            vec![LogicalType::Int64, LogicalType::Int64],
840        );
841
842        let mut results: Vec<(i64, i64)> = Vec::new();
843        while let Some(chunk) = agg.next().unwrap() {
844            for row in chunk.selected_indices() {
845                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
846                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
847                results.push((group, sum));
848            }
849        }
850
851        results.sort_by_key(|(g, _)| *g);
852        assert_eq!(results.len(), 2);
853        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
854        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
855    }
856
857    #[test]
858    fn test_grouped_count() {
859        let mock = MockOperator::new(vec![create_test_chunk()]);
860
861        // GROUP BY column 0, COUNT(*)
862        let mut agg = HashAggregateOperator::new(
863            Box::new(mock),
864            vec![0],
865            vec![AggregateExpr::count_star()],
866            vec![LogicalType::Int64, LogicalType::Int64],
867        );
868
869        let mut results: Vec<(i64, i64)> = Vec::new();
870        while let Some(chunk) = agg.next().unwrap() {
871            for row in chunk.selected_indices() {
872                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
873                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
874                results.push((group, count));
875            }
876        }
877
878        results.sort_by_key(|(g, _)| *g);
879        assert_eq!(results.len(), 2);
880        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
881        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
882    }
883
884    #[test]
885    fn test_multiple_aggregates() {
886        let mock = MockOperator::new(vec![create_test_chunk()]);
887
888        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
889        let mut agg = HashAggregateOperator::new(
890            Box::new(mock),
891            vec![0],
892            vec![
893                AggregateExpr::count_star(),
894                AggregateExpr::sum(1),
895                AggregateExpr::avg(1),
896            ],
897            vec![
898                LogicalType::Int64,   // Group key
899                LogicalType::Int64,   // COUNT
900                LogicalType::Int64,   // SUM
901                LogicalType::Float64, // AVG
902            ],
903        );
904
905        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
906        while let Some(chunk) = agg.next().unwrap() {
907            for row in chunk.selected_indices() {
908                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
909                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
910                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
911                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
912                results.push((group, count, sum, avg));
913            }
914        }
915
916        results.sort_by_key(|(g, _, _, _)| *g);
917        assert_eq!(results.len(), 2);
918
919        // Group 1: COUNT=2, SUM=30, AVG=15.0
920        assert_eq!(results[0].0, 1);
921        assert_eq!(results[0].1, 2);
922        assert_eq!(results[0].2, 30);
923        assert!((results[0].3 - 15.0).abs() < 0.001);
924
925        // Group 2: COUNT=3, SUM=120, AVG=40.0
926        assert_eq!(results[1].0, 2);
927        assert_eq!(results[1].1, 3);
928        assert_eq!(results[1].2, 120);
929        assert!((results[1].3 - 40.0).abs() < 0.001);
930    }
931
932    fn create_test_chunk_with_duplicates() -> DataChunk {
933        // Create data with duplicate values in column 1
934        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
935        // GROUP 1: values [10, 10, 20] -> distinct count = 2
936        // GROUP 2: values [30, 30, 30] -> distinct count = 1
937        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
938
939        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
940        for (group, value) in data {
941            builder.column_mut(0).unwrap().push_int64(group);
942            builder.column_mut(1).unwrap().push_int64(value);
943            builder.advance_row();
944        }
945
946        builder.finish()
947    }
948
949    #[test]
950    fn test_count_distinct() {
951        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
952
953        // COUNT(DISTINCT column 1)
954        let mut agg = SimpleAggregateOperator::new(
955            Box::new(mock),
956            vec![AggregateExpr::count(1).with_distinct()],
957            vec![LogicalType::Int64],
958        );
959
960        let result = agg.next().unwrap().unwrap();
961        assert_eq!(result.row_count(), 1);
962        // Total distinct values: 10, 20, 30 = 3 distinct values
963        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
964    }
965
966    #[test]
967    fn test_grouped_count_distinct() {
968        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
969
970        // GROUP BY column 0, COUNT(DISTINCT column 1)
971        let mut agg = HashAggregateOperator::new(
972            Box::new(mock),
973            vec![0],
974            vec![AggregateExpr::count(1).with_distinct()],
975            vec![LogicalType::Int64, LogicalType::Int64],
976        );
977
978        let mut results: Vec<(i64, i64)> = Vec::new();
979        while let Some(chunk) = agg.next().unwrap() {
980            for row in chunk.selected_indices() {
981                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
982                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
983                results.push((group, count));
984            }
985        }
986
987        results.sort_by_key(|(g, _)| *g);
988        assert_eq!(results.len(), 2);
989        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
990        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
991    }
992
993    #[test]
994    fn test_sum_distinct() {
995        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
996
997        // SUM(DISTINCT column 1)
998        let mut agg = SimpleAggregateOperator::new(
999            Box::new(mock),
1000            vec![AggregateExpr::sum(1).with_distinct()],
1001            vec![LogicalType::Int64],
1002        );
1003
1004        let result = agg.next().unwrap().unwrap();
1005        assert_eq!(result.row_count(), 1);
1006        // Sum of distinct values: 10 + 20 + 30 = 60
1007        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1008    }
1009
1010    #[test]
1011    fn test_avg_distinct() {
1012        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1013
1014        // AVG(DISTINCT column 1)
1015        let mut agg = SimpleAggregateOperator::new(
1016            Box::new(mock),
1017            vec![AggregateExpr::avg(1).with_distinct()],
1018            vec![LogicalType::Float64],
1019        );
1020
1021        let result = agg.next().unwrap().unwrap();
1022        assert_eq!(result.row_count(), 1);
1023        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1024        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1025        assert!((avg - 20.0).abs() < 0.001);
1026    }
1027
1028    fn create_statistical_test_chunk() -> DataChunk {
1029        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1030        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1031        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1032
1033        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1034            builder.column_mut(0).unwrap().push_int64(value);
1035            builder.advance_row();
1036        }
1037
1038        builder.finish()
1039    }
1040
1041    #[test]
1042    fn test_stdev_sample() {
1043        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1044
1045        let mut agg = SimpleAggregateOperator::new(
1046            Box::new(mock),
1047            vec![AggregateExpr::stdev(0)],
1048            vec![LogicalType::Float64],
1049        );
1050
1051        let result = agg.next().unwrap().unwrap();
1052        assert_eq!(result.row_count(), 1);
1053        // Sample standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1054        // Mean = 5.0, Variance = 32/7 = 4.571, StdDev = 2.138
1055        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1056        assert!((stdev - 2.138).abs() < 0.01);
1057    }
1058
1059    #[test]
1060    fn test_stdev_population() {
1061        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1062
1063        let mut agg = SimpleAggregateOperator::new(
1064            Box::new(mock),
1065            vec![AggregateExpr::stdev_pop(0)],
1066            vec![LogicalType::Float64],
1067        );
1068
1069        let result = agg.next().unwrap().unwrap();
1070        assert_eq!(result.row_count(), 1);
1071        // Population standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1072        // Mean = 5.0, Variance = 32/8 = 4.0, StdDev = 2.0
1073        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1074        assert!((stdev - 2.0).abs() < 0.01);
1075    }
1076
1077    #[test]
1078    fn test_percentile_disc() {
1079        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1080
1081        // Median (50th percentile discrete)
1082        let mut agg = SimpleAggregateOperator::new(
1083            Box::new(mock),
1084            vec![AggregateExpr::percentile_disc(0, 0.5)],
1085            vec![LogicalType::Float64],
1086        );
1087
1088        let result = agg.next().unwrap().unwrap();
1089        assert_eq!(result.row_count(), 1);
1090        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], index = floor(0.5 * 7) = 3, value = 4
1091        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1092        assert!((percentile - 4.0).abs() < 0.01);
1093    }
1094
1095    #[test]
1096    fn test_percentile_cont() {
1097        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1098
1099        // Median (50th percentile continuous)
1100        let mut agg = SimpleAggregateOperator::new(
1101            Box::new(mock),
1102            vec![AggregateExpr::percentile_cont(0, 0.5)],
1103            vec![LogicalType::Float64],
1104        );
1105
1106        let result = agg.next().unwrap().unwrap();
1107        assert_eq!(result.row_count(), 1);
1108        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], rank = 0.5 * 7 = 3.5
1109        // Interpolate between index 3 (4) and index 4 (5): 4 + 0.5 * (5 - 4) = 4.5
1110        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1111        assert!((percentile - 4.5).abs() < 0.01);
1112    }
1113
1114    #[test]
1115    fn test_percentile_extremes() {
1116        // Test 0th and 100th percentiles
1117        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1118
1119        let mut agg = SimpleAggregateOperator::new(
1120            Box::new(mock),
1121            vec![
1122                AggregateExpr::percentile_disc(0, 0.0),
1123                AggregateExpr::percentile_disc(0, 1.0),
1124            ],
1125            vec![LogicalType::Float64, LogicalType::Float64],
1126        );
1127
1128        let result = agg.next().unwrap().unwrap();
1129        assert_eq!(result.row_count(), 1);
1130        // 0th percentile = minimum = 2
1131        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1132        assert!((p0 - 2.0).abs() < 0.01);
1133        // 100th percentile = maximum = 9
1134        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1135        assert!((p100 - 9.0).abs() < 0.01);
1136    }
1137
1138    #[test]
1139    fn test_stdev_single_value() {
1140        // Single value should return null for sample stdev
1141        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1142        builder.column_mut(0).unwrap().push_int64(42);
1143        builder.advance_row();
1144        let chunk = builder.finish();
1145
1146        let mock = MockOperator::new(vec![chunk]);
1147
1148        let mut agg = SimpleAggregateOperator::new(
1149            Box::new(mock),
1150            vec![AggregateExpr::stdev(0)],
1151            vec![LogicalType::Float64],
1152        );
1153
1154        let result = agg.next().unwrap().unwrap();
1155        assert_eq!(result.row_count(), 1);
1156        // Sample stdev of single value is undefined (null)
1157        assert!(matches!(
1158            result.column(0).unwrap().get_value(0),
1159            Some(Value::Null)
1160        ));
1161    }
1162
1163    #[test]
1164    fn test_stdev_pop_single_value() {
1165        // Single value should return 0 for population stdev
1166        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1167        builder.column_mut(0).unwrap().push_int64(42);
1168        builder.advance_row();
1169        let chunk = builder.finish();
1170
1171        let mock = MockOperator::new(vec![chunk]);
1172
1173        let mut agg = SimpleAggregateOperator::new(
1174            Box::new(mock),
1175            vec![AggregateExpr::stdev_pop(0)],
1176            vec![LogicalType::Float64],
1177        );
1178
1179        let result = agg.next().unwrap().unwrap();
1180        assert_eq!(result.row_count(), 1);
1181        // Population stdev of single value is 0
1182        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1183        assert!((stdev - 0.0).abs() < 0.01);
1184    }
1185}