Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
/// State for a single aggregation computation.
///
/// One value of this enum holds the running state of one aggregate expression.
/// It is created by `AggregateState::new`, advanced one input at a time by `update`
/// (or `update_bivariate` for the two-argument functions) and turned into a result
/// [`Value`] by `finalize`.
#[derive(Debug, Clone)]
pub(crate) enum AggregateState {
    /// Count state: incremented on every `update` call, whatever the value is.
    Count(i64),
    /// Count distinct state (count, seen values).
    ///
    /// The count only grows when a value enters the seen set for the first time.
    CountDistinct(i64, HashSet<HashableValue>),
    /// Sum state (integer sum, count of values added).
    ///
    /// `update` replaces this with `SumFloat` as soon as a non-`Int64` value that
    /// converts to a float is added.
    SumInt(i64, i64),
    /// Sum distinct state (integer sum, count, seen values).
    ///
    /// `update` replaces this with `SumFloatDistinct` (moving the seen set across) as
    /// soon as a new non-`Int64` value that converts to a float is added.
    SumIntDistinct(i64, i64, HashSet<HashableValue>),
    /// Sum state (float sum, count of values added).
    SumFloat(f64, i64),
    /// Sum distinct state (float sum, count, seen values).
    SumFloatDistinct(f64, i64, HashSet<HashableValue>),
    /// Average state (sum, count); finalized as `sum / count`.
    Avg(f64, i64),
    /// Average distinct state (sum, count, seen values).
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// Min state: smallest value seen so far as ordered by `compare_values`.
    Min(Option<Value>),
    /// Max state: largest value seen so far as ordered by `compare_values`.
    Max(Option<Value>),
    /// First state: the first `Some` value handed to `update`.
    First(Option<Value>),
    /// Last state: the most recent `Some` value handed to `update`.
    Last(Option<Value>),
    /// Collect state: every value handed to `update`, in arrival order.
    Collect(Vec<Value>),
    /// Collect distinct state (values, seen): first occurrences only, in arrival order.
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Discrete percentile state (values, percentile).
    ///
    /// All numeric inputs are buffered; sorting happens in `finalize`.
    /// `new` uses 0.5 when no percentile is supplied.
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Continuous percentile state (values, percentile).
    ///
    /// All numeric inputs are buffered; sorting and interpolation happen in `finalize`.
    /// `new` uses 0.5 when no percentile is supplied.
    PercentileCont { values: Vec<f64>, percentile: f64 },
    /// GROUP_CONCAT / LISTAGG state (collected string values, separator).
    GroupConcat(Vec<String>, String),
    /// GROUP_CONCAT / LISTAGG distinct state (collected string values, separator, seen).
    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
    /// SAMPLE state: the first `Some` value handed to `update`.
    Sample(Option<Value>),
    /// Sample variance state using Welford's algorithm (count, mean, M2).
    Variance { count: i64, mean: f64, m2: f64 },
    /// Population variance state using Welford's algorithm (count, mean, M2).
    VariancePop { count: i64, mean: f64, m2: f64 },
    /// Two-variable online statistics (Welford generalization for covariance/regression).
    ///
    /// Only `update_bivariate` advances this state; the single-value `update` leaves it
    /// untouched.
    Bivariate {
        /// Which binary set function this state will finalize to.
        kind: AggregateFunction,
        /// Number of (y, x) pairs accumulated so far.
        count: i64,
        /// Running mean of the x values.
        mean_x: f64,
        /// Running mean of the y values.
        mean_y: f64,
        /// Running sum of squared deviations of x from its mean.
        m2_x: f64,
        /// Running sum of squared deviations of y from its mean.
        m2_y: f64,
        /// Running sum of the x/y co-deviations (co-moment).
        c_xy: f64,
    },
}
83
84impl AggregateState {
85    /// Creates initial state for an aggregation function.
86    pub(crate) fn new(
87        function: AggregateFunction,
88        distinct: bool,
89        percentile: Option<f64>,
90        separator: Option<&str>,
91    ) -> Self {
92        match (function, distinct) {
93            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
94                AggregateState::Count(0)
95            }
96            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
97                AggregateState::CountDistinct(0, HashSet::new())
98            }
99            (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
100            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
101            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
102            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
103            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
104            (AggregateFunction::Max, _) => AggregateState::Max(None),
105            (AggregateFunction::First, _) => AggregateState::First(None),
106            (AggregateFunction::Last, _) => AggregateState::Last(None),
107            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
108            (AggregateFunction::Collect, true) => {
109                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
110            }
111            // Statistical functions (Welford's algorithm for online computation)
112            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
113                count: 0,
114                mean: 0.0,
115                m2: 0.0,
116            },
117            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
118                count: 0,
119                mean: 0.0,
120                m2: 0.0,
121            },
122            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
123                values: Vec::new(),
124                percentile: percentile.unwrap_or(0.5),
125            },
126            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
127                values: Vec::new(),
128                percentile: percentile.unwrap_or(0.5),
129            },
130            (AggregateFunction::GroupConcat, false) => {
131                AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
132            }
133            (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
134                Vec::new(),
135                separator.unwrap_or(" ").to_string(),
136                HashSet::new(),
137            ),
138            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
139            // Binary set functions (all share the same Bivariate state)
140            (
141                AggregateFunction::CovarSamp
142                | AggregateFunction::CovarPop
143                | AggregateFunction::Corr
144                | AggregateFunction::RegrSlope
145                | AggregateFunction::RegrIntercept
146                | AggregateFunction::RegrR2
147                | AggregateFunction::RegrCount
148                | AggregateFunction::RegrSxx
149                | AggregateFunction::RegrSyy
150                | AggregateFunction::RegrSxy
151                | AggregateFunction::RegrAvgx
152                | AggregateFunction::RegrAvgy,
153                _,
154            ) => AggregateState::Bivariate {
155                kind: function,
156                count: 0,
157                mean_x: 0.0,
158                mean_y: 0.0,
159                m2_x: 0.0,
160                m2_y: 0.0,
161                c_xy: 0.0,
162            },
163            (AggregateFunction::Variance, _) => AggregateState::Variance {
164                count: 0,
165                mean: 0.0,
166                m2: 0.0,
167            },
168            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
169                count: 0,
170                mean: 0.0,
171                m2: 0.0,
172            },
173        }
174    }
175
176    /// Updates the state with a new value.
177    pub(crate) fn update(&mut self, value: Option<Value>) {
178        match self {
179            AggregateState::Count(count) => {
180                *count += 1;
181            }
182            AggregateState::CountDistinct(count, seen) => {
183                if let Some(ref v) = value {
184                    let hashable = HashableValue::from(v);
185                    if seen.insert(hashable) {
186                        *count += 1;
187                    }
188                }
189            }
190            AggregateState::SumInt(sum, count) => {
191                if let Some(Value::Int64(v)) = value {
192                    *sum += v;
193                    *count += 1;
194                } else if let Some(Value::Float64(v)) = value {
195                    // Convert to float sum, carrying count forward
196                    *self = AggregateState::SumFloat(*sum as f64 + v, *count + 1);
197                } else if let Some(ref v) = value {
198                    // RDF stores numeric literals as strings - try to parse
199                    if let Some(num) = value_to_f64(v) {
200                        *self = AggregateState::SumFloat(*sum as f64 + num, *count + 1);
201                    }
202                }
203            }
204            AggregateState::SumIntDistinct(sum, count, seen) => {
205                if let Some(ref v) = value {
206                    let hashable = HashableValue::from(v);
207                    if seen.insert(hashable) {
208                        if let Value::Int64(i) = v {
209                            *sum += i;
210                            *count += 1;
211                        } else if let Value::Float64(f) = v {
212                            // Convert to float distinct: move the seen set instead of cloning
213                            let moved_seen = std::mem::take(seen);
214                            *self = AggregateState::SumFloatDistinct(
215                                *sum as f64 + f,
216                                *count + 1,
217                                moved_seen,
218                            );
219                        } else if let Some(num) = value_to_f64(v) {
220                            // RDF string-encoded numerics
221                            let moved_seen = std::mem::take(seen);
222                            *self = AggregateState::SumFloatDistinct(
223                                *sum as f64 + num,
224                                *count + 1,
225                                moved_seen,
226                            );
227                        }
228                    }
229                }
230            }
231            AggregateState::SumFloat(sum, count) => {
232                if let Some(ref v) = value {
233                    // Use value_to_f64 which now handles strings
234                    if let Some(num) = value_to_f64(v) {
235                        *sum += num;
236                        *count += 1;
237                    }
238                }
239            }
240            AggregateState::SumFloatDistinct(sum, count, seen) => {
241                if let Some(ref v) = value {
242                    let hashable = HashableValue::from(v);
243                    if seen.insert(hashable)
244                        && let Some(num) = value_to_f64(v)
245                    {
246                        *sum += num;
247                        *count += 1;
248                    }
249                }
250            }
251            AggregateState::Avg(sum, count) => {
252                if let Some(ref v) = value
253                    && let Some(num) = value_to_f64(v)
254                {
255                    *sum += num;
256                    *count += 1;
257                }
258            }
259            AggregateState::AvgDistinct(sum, count, seen) => {
260                if let Some(ref v) = value {
261                    let hashable = HashableValue::from(v);
262                    if seen.insert(hashable)
263                        && let Some(num) = value_to_f64(v)
264                    {
265                        *sum += num;
266                        *count += 1;
267                    }
268                }
269            }
270            AggregateState::Min(min) => {
271                if let Some(v) = value {
272                    match min {
273                        None => *min = Some(v),
274                        Some(current) => {
275                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
276                                *min = Some(v);
277                            }
278                        }
279                    }
280                }
281            }
282            AggregateState::Max(max) => {
283                if let Some(v) = value {
284                    match max {
285                        None => *max = Some(v),
286                        Some(current) => {
287                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
288                                *max = Some(v);
289                            }
290                        }
291                    }
292                }
293            }
294            AggregateState::First(first) => {
295                if first.is_none() {
296                    *first = value;
297                }
298            }
299            AggregateState::Last(last) => {
300                if value.is_some() {
301                    *last = value;
302                }
303            }
304            AggregateState::Collect(list) => {
305                if let Some(v) = value {
306                    list.push(v);
307                }
308            }
309            AggregateState::CollectDistinct(list, seen) => {
310                if let Some(v) = value {
311                    let hashable = HashableValue::from(&v);
312                    if seen.insert(hashable) {
313                        list.push(v);
314                    }
315                }
316            }
317            // Statistical functions using Welford's online algorithm
318            AggregateState::StdDev { count, mean, m2 }
319            | AggregateState::StdDevPop { count, mean, m2 }
320            | AggregateState::Variance { count, mean, m2 }
321            | AggregateState::VariancePop { count, mean, m2 } => {
322                if let Some(ref v) = value
323                    && let Some(x) = value_to_f64(v)
324                {
325                    *count += 1;
326                    let delta = x - *mean;
327                    *mean += delta / *count as f64;
328                    let delta2 = x - *mean;
329                    *m2 += delta * delta2;
330                }
331            }
332            AggregateState::PercentileDisc { values, .. }
333            | AggregateState::PercentileCont { values, .. } => {
334                if let Some(ref v) = value
335                    && let Some(x) = value_to_f64(v)
336                {
337                    values.push(x);
338                }
339            }
340            AggregateState::GroupConcat(list, _sep) => {
341                if let Some(v) = value {
342                    list.push(agg_value_to_string(&v));
343                }
344            }
345            AggregateState::GroupConcatDistinct(list, _sep, seen) => {
346                if let Some(v) = value {
347                    let hashable = HashableValue::from(&v);
348                    if seen.insert(hashable) {
349                        list.push(agg_value_to_string(&v));
350                    }
351                }
352            }
353            AggregateState::Sample(sample) => {
354                if sample.is_none() {
355                    *sample = value;
356                }
357            }
358            AggregateState::Bivariate { .. } => {
359                // Bivariate functions require two values; use update_bivariate() instead.
360                // Single-value update is a no-op for bivariate state.
361            }
362        }
363    }
364
365    /// Updates a bivariate (two-variable) aggregate state with a pair of values.
366    ///
367    /// Uses the two-variable Welford online algorithm for numerically stable computation
368    /// of covariance and related statistics. Skips the update if either value is null.
369    fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
370        if let AggregateState::Bivariate {
371            count,
372            mean_x,
373            mean_y,
374            m2_x,
375            m2_y,
376            c_xy,
377            ..
378        } = self
379        {
380            // Skip if either value is null (SQL semantics: exclude non-pairs)
381            if let (Some(y), Some(x)) = (&y_val, &x_val)
382                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
383            {
384                *count += 1;
385                let n = *count as f64;
386                let dx = x_f - *mean_x;
387                let dy = y_f - *mean_y;
388                *mean_x += dx / n;
389                *mean_y += dy / n;
390                let dx2 = x_f - *mean_x; // post-update delta
391                let dy2 = y_f - *mean_y; // post-update delta
392                *m2_x += dx * dx2;
393                *m2_y += dy * dy2;
394                *c_xy += dx * dy2;
395            }
396        }
397    }
398
399    /// Finalizes the state and returns the result value.
400    pub(crate) fn finalize(&self) -> Value {
401        match self {
402            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
403                Value::Int64(*count)
404            }
405            AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
406                if *count == 0 {
407                    Value::Null
408                } else {
409                    Value::Int64(*sum)
410                }
411            }
412            AggregateState::SumFloat(sum, count)
413            | AggregateState::SumFloatDistinct(sum, count, _) => {
414                if *count == 0 {
415                    Value::Null
416                } else {
417                    Value::Float64(*sum)
418                }
419            }
420            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
421                if *count == 0 {
422                    Value::Null
423                } else {
424                    Value::Float64(*sum / *count as f64)
425                }
426            }
427            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
428            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
429            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
430            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
431            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
432                Value::List(list.clone().into())
433            }
434            // Sample standard deviation: sqrt(M2 / (n - 1))
435            AggregateState::StdDev { count, m2, .. } => {
436                if *count < 2 {
437                    Value::Null
438                } else {
439                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
440                }
441            }
442            // Population standard deviation: sqrt(M2 / n)
443            AggregateState::StdDevPop { count, m2, .. } => {
444                if *count == 0 {
445                    Value::Null
446                } else {
447                    Value::Float64((*m2 / *count as f64).sqrt())
448                }
449            }
450            // Sample variance: M2 / (n - 1)
451            AggregateState::Variance { count, m2, .. } => {
452                if *count < 2 {
453                    Value::Null
454                } else {
455                    Value::Float64(*m2 / (*count - 1) as f64)
456                }
457            }
458            // Population variance: M2 / n
459            AggregateState::VariancePop { count, m2, .. } => {
460                if *count == 0 {
461                    Value::Null
462                } else {
463                    Value::Float64(*m2 / *count as f64)
464                }
465            }
466            // Discrete percentile: return actual value at percentile position
467            AggregateState::PercentileDisc { values, percentile } => {
468                if values.is_empty() {
469                    Value::Null
470                } else {
471                    let mut sorted = values.clone();
472                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
473                    // Index calculation per SQL standard: floor(p * (n - 1))
474                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
475                    Value::Float64(sorted[index])
476                }
477            }
478            // Continuous percentile: interpolate between values
479            AggregateState::PercentileCont { values, percentile } => {
480                if values.is_empty() {
481                    Value::Null
482                } else {
483                    let mut sorted = values.clone();
484                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
485                    // Linear interpolation per SQL standard
486                    let rank = percentile * (sorted.len() - 1) as f64;
487                    let lower_idx = rank.floor() as usize;
488                    let upper_idx = rank.ceil() as usize;
489                    if lower_idx == upper_idx {
490                        Value::Float64(sorted[lower_idx])
491                    } else {
492                        let fraction = rank - lower_idx as f64;
493                        let result =
494                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
495                        Value::Float64(result)
496                    }
497                }
498            }
499            // GROUP_CONCAT: join strings with space separator (SPARQL default)
500            AggregateState::GroupConcat(list, sep)
501            | AggregateState::GroupConcatDistinct(list, sep, _) => {
502                Value::String(list.join(sep).into())
503            }
504            // SAMPLE: return the first non-null value seen
505            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
506            // Binary set functions: dispatch on kind
507            AggregateState::Bivariate {
508                kind,
509                count,
510                mean_x,
511                mean_y,
512                m2_x,
513                m2_y,
514                c_xy,
515            } => {
516                let n = *count;
517                match kind {
518                    AggregateFunction::CovarSamp => {
519                        if n < 2 {
520                            Value::Null
521                        } else {
522                            Value::Float64(*c_xy / (n - 1) as f64)
523                        }
524                    }
525                    AggregateFunction::CovarPop => {
526                        if n == 0 {
527                            Value::Null
528                        } else {
529                            Value::Float64(*c_xy / n as f64)
530                        }
531                    }
532                    AggregateFunction::Corr => {
533                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
534                            Value::Null
535                        } else {
536                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
537                        }
538                    }
539                    AggregateFunction::RegrSlope => {
540                        if n == 0 || *m2_x == 0.0 {
541                            Value::Null
542                        } else {
543                            Value::Float64(*c_xy / *m2_x)
544                        }
545                    }
546                    AggregateFunction::RegrIntercept => {
547                        if n == 0 || *m2_x == 0.0 {
548                            Value::Null
549                        } else {
550                            let slope = *c_xy / *m2_x;
551                            Value::Float64(*mean_y - slope * *mean_x)
552                        }
553                    }
554                    AggregateFunction::RegrR2 => {
555                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
556                            Value::Null
557                        } else {
558                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
559                        }
560                    }
561                    AggregateFunction::RegrCount => Value::Int64(n),
562                    AggregateFunction::RegrSxx => {
563                        if n == 0 {
564                            Value::Null
565                        } else {
566                            Value::Float64(*m2_x)
567                        }
568                    }
569                    AggregateFunction::RegrSyy => {
570                        if n == 0 {
571                            Value::Null
572                        } else {
573                            Value::Float64(*m2_y)
574                        }
575                    }
576                    AggregateFunction::RegrSxy => {
577                        if n == 0 {
578                            Value::Null
579                        } else {
580                            Value::Float64(*c_xy)
581                        }
582                    }
583                    AggregateFunction::RegrAvgx => {
584                        if n == 0 {
585                            Value::Null
586                        } else {
587                            Value::Float64(*mean_x)
588                        }
589                    }
590                    AggregateFunction::RegrAvgy => {
591                        if n == 0 {
592                            Value::Null
593                        } else {
594                            Value::Float64(*mean_y)
595                        }
596                    }
597                    _ => Value::Null, // non-bivariate functions never reach here
598                }
599            }
600        }
601    }
602}
603
604use super::value_utils::{compare_values, value_to_f64};
605
606/// Converts a Value to its string representation for GROUP_CONCAT.
607fn agg_value_to_string(val: &Value) -> String {
608    match val {
609        Value::String(s) => s.to_string(),
610        Value::Int64(i) => i.to_string(),
611        Value::Float64(f) => f.to_string(),
612        Value::Bool(b) => b.to_string(),
613        Value::Null => String::new(),
614        other => format!("{other:?}"),
615    }
616}
617
/// A group key for hash-based aggregation.
///
/// Holds one `GroupKeyPart` per grouping column, in the order the columns were passed
/// to `GroupKey::from_row`. The derived `Eq`/`Hash` make it usable as a map key.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(Vec<GroupKeyPart>);
621
622#[derive(Debug, Clone, PartialEq, Eq, Hash)]
623enum GroupKeyPart {
624    Null,
625    Bool(bool),
626    Int64(i64),
627    String(ArcStr),
628    List(Vec<GroupKeyPart>),
629}
630
631impl GroupKeyPart {
632    fn from_value(v: Value) -> Self {
633        match v {
634            Value::Null => Self::Null,
635            Value::Bool(b) => Self::Bool(b),
636            Value::Int64(i) => Self::Int64(i),
637            Value::Float64(f) => Self::Int64(f.to_bits() as i64),
638            Value::String(s) => Self::String(s.clone()),
639            Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
640            other => Self::String(ArcStr::from(format!("{other:?}"))),
641        }
642    }
643
644    fn to_value(&self) -> Value {
645        match self {
646            Self::Null => Value::Null,
647            Self::Bool(b) => Value::Bool(*b),
648            Self::Int64(i) => Value::Int64(*i),
649            Self::String(s) => Value::String(s.clone()),
650            Self::List(parts) => {
651                let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
652                Value::List(Arc::from(values.into_boxed_slice()))
653            }
654        }
655    }
656}
657
658impl GroupKey {
659    /// Creates a group key from column values.
660    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
661        let parts: Vec<GroupKeyPart> = group_columns
662            .iter()
663            .map(|&col_idx| {
664                chunk
665                    .column(col_idx)
666                    .and_then(|col| col.get_value(row))
667                    .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
668            })
669            .collect();
670        GroupKey(parts)
671    }
672
673    /// Converts the group key back to values.
674    fn to_values(&self) -> Vec<Value> {
675        self.0.iter().map(GroupKeyPart::to_value).collect()
676    }
677}
678
/// Hash-based aggregate operator.
///
/// Groups input by key columns and computes aggregations for each group.
/// The child operator is drained chunk by chunk; every selected row is mapped to a
/// `GroupKey` and folded into that group's per-aggregate states.
pub struct HashAggregateOperator {
    /// Child operator to read from.
    child: Box<dyn Operator>,
    /// Columns to group by (indices into the child's chunks).
    group_columns: Vec<usize>,
    /// Aggregation expressions; each group keeps one state per entry, in this order.
    aggregates: Vec<AggregateExpr>,
    /// Output schema.
    output_schema: Vec<LogicalType>,
    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Whether aggregation is complete.
    aggregation_complete: bool,
    /// Results iterator.
    // NOTE(review): presumably filled from `groups` once aggregation has finished;
    // the code that sets it is outside the visible part of the file, so confirm there.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
698
699impl HashAggregateOperator {
700    /// Creates a new hash aggregate operator.
701    ///
702    /// # Arguments
703    /// * `child` - Child operator to read from.
704    /// * `group_columns` - Column indices to group by.
705    /// * `aggregates` - Aggregation expressions.
706    /// * `output_schema` - Schema of the output (group columns + aggregate results).
707    pub fn new(
708        child: Box<dyn Operator>,
709        group_columns: Vec<usize>,
710        aggregates: Vec<AggregateExpr>,
711        output_schema: Vec<LogicalType>,
712    ) -> Self {
713        Self {
714            child,
715            group_columns,
716            aggregates,
717            output_schema,
718            groups: IndexMap::new(),
719            aggregation_complete: false,
720            results: None,
721        }
722    }
723
724    /// Performs the aggregation.
725    fn aggregate(&mut self) -> Result<(), OperatorError> {
726        while let Some(chunk) = self.child.next()? {
727            for row in chunk.selected_indices() {
728                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
729
730                // Get or create aggregate states for this group
731                let states = self.groups.entry(key).or_insert_with(|| {
732                    self.aggregates
733                        .iter()
734                        .map(|agg| {
735                            AggregateState::new(
736                                agg.function,
737                                agg.distinct,
738                                agg.percentile,
739                                agg.separator.as_deref(),
740                            )
741                        })
742                        .collect()
743                });
744
745                // Update each aggregate
746                for (i, agg) in self.aggregates.iter().enumerate() {
747                    // Binary set functions: read two column values
748                    if agg.column2.is_some() {
749                        let y_val = agg
750                            .column
751                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
752                        let x_val = agg
753                            .column2
754                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
755                        states[i].update_bivariate(y_val, x_val);
756                        continue;
757                    }
758
759                    let value = match (agg.function, agg.distinct) {
760                        // COUNT(*) without DISTINCT doesn't need a value
761                        (AggregateFunction::Count, false) => None,
762                        // COUNT DISTINCT needs the actual value to track unique values
763                        (AggregateFunction::Count, true) => agg
764                            .column
765                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
766                        _ => agg
767                            .column
768                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
769                    };
770
771                    // For COUNT without DISTINCT, always update. For others, skip nulls.
772                    match (agg.function, agg.distinct) {
773                        (AggregateFunction::Count, false) => states[i].update(None),
774                        (AggregateFunction::Count, true) => {
775                            // COUNT DISTINCT needs the value to track unique values
776                            if value.is_some() && !matches!(value, Some(Value::Null)) {
777                                states[i].update(value);
778                            }
779                        }
780                        (AggregateFunction::CountNonNull, _) => {
781                            if value.is_some() && !matches!(value, Some(Value::Null)) {
782                                states[i].update(value);
783                            }
784                        }
785                        _ => {
786                            if value.is_some() && !matches!(value, Some(Value::Null)) {
787                                states[i].update(value);
788                            }
789                        }
790                    }
791                }
792            }
793        }
794
795        self.aggregation_complete = true;
796
797        // Convert to results iterator (IndexMap::drain takes a range)
798        let results: Vec<_> = self.groups.drain(..).collect();
799        self.results = Some(results.into_iter());
800
801        Ok(())
802    }
803}
804
805impl Operator for HashAggregateOperator {
806    fn next(&mut self) -> OperatorResult {
807        // Perform aggregation if not done
808        if !self.aggregation_complete {
809            self.aggregate()?;
810        }
811
812        // Special case: no groups (global aggregation with no data)
813        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
814            // For global aggregation (no GROUP BY), return one row with initial values
815            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
816
817            for agg in &self.aggregates {
818                let state = AggregateState::new(
819                    agg.function,
820                    agg.distinct,
821                    agg.percentile,
822                    agg.separator.as_deref(),
823                );
824                let value = state.finalize();
825                if let Some(col) = builder.column_mut(self.group_columns.len()) {
826                    col.push_value(value);
827                }
828            }
829            builder.advance_row();
830
831            self.results = Some(Vec::new().into_iter()); // Mark as done
832            return Ok(Some(builder.finish()));
833        }
834
835        let Some(results) = &mut self.results else {
836            return Ok(None);
837        };
838
839        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
840
841        for (key, states) in results.by_ref() {
842            // Output group key columns
843            let key_values = key.to_values();
844            for (i, value) in key_values.into_iter().enumerate() {
845                if let Some(col) = builder.column_mut(i) {
846                    col.push_value(value);
847                }
848            }
849
850            // Output aggregate results
851            for (i, state) in states.iter().enumerate() {
852                let col_idx = self.group_columns.len() + i;
853                if let Some(col) = builder.column_mut(col_idx) {
854                    col.push_value(state.finalize());
855                }
856            }
857
858            builder.advance_row();
859
860            if builder.is_full() {
861                return Ok(Some(builder.finish()));
862            }
863        }
864
865        if builder.row_count() > 0 {
866            Ok(Some(builder.finish()))
867        } else {
868            Ok(None)
869        }
870    }
871
872    fn reset(&mut self) {
873        self.child.reset();
874        self.groups.clear();
875        self.aggregation_complete = false;
876        self.results = None;
877    }
878
879    fn name(&self) -> &'static str {
880        "HashAggregate"
881    }
882}
883
884/// Simple (non-grouping) aggregate operator for global aggregations.
885///
886/// Used when there's no GROUP BY clause - aggregates all input into a single row.
887pub struct SimpleAggregateOperator {
888    /// Child operator.
889    child: Box<dyn Operator>,
890    /// Aggregation expressions.
891    aggregates: Vec<AggregateExpr>,
892    /// Output schema.
893    output_schema: Vec<LogicalType>,
894    /// Aggregate states.
895    states: Vec<AggregateState>,
896    /// Whether aggregation is complete.
897    done: bool,
898}
899
900impl SimpleAggregateOperator {
901    /// Creates a new simple aggregate operator.
902    pub fn new(
903        child: Box<dyn Operator>,
904        aggregates: Vec<AggregateExpr>,
905        output_schema: Vec<LogicalType>,
906    ) -> Self {
907        let states = aggregates
908            .iter()
909            .map(|agg| {
910                AggregateState::new(
911                    agg.function,
912                    agg.distinct,
913                    agg.percentile,
914                    agg.separator.as_deref(),
915                )
916            })
917            .collect();
918
919        Self {
920            child,
921            aggregates,
922            output_schema,
923            states,
924            done: false,
925        }
926    }
927}
928
929impl Operator for SimpleAggregateOperator {
930    fn next(&mut self) -> OperatorResult {
931        if self.done {
932            return Ok(None);
933        }
934
935        // Process all input
936        while let Some(chunk) = self.child.next()? {
937            for row in chunk.selected_indices() {
938                for (i, agg) in self.aggregates.iter().enumerate() {
939                    // Binary set functions: read two column values
940                    if agg.column2.is_some() {
941                        let y_val = agg
942                            .column
943                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
944                        let x_val = agg
945                            .column2
946                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
947                        self.states[i].update_bivariate(y_val, x_val);
948                        continue;
949                    }
950
951                    let value = match (agg.function, agg.distinct) {
952                        // COUNT(*) without DISTINCT doesn't need a value
953                        (AggregateFunction::Count, false) => None,
954                        // COUNT DISTINCT needs the actual value to track unique values
955                        (AggregateFunction::Count, true) => agg
956                            .column
957                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
958                        _ => agg
959                            .column
960                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
961                    };
962
963                    match (agg.function, agg.distinct) {
964                        (AggregateFunction::Count, false) => self.states[i].update(None),
965                        (AggregateFunction::Count, true) => {
966                            // COUNT DISTINCT needs the value to track unique values
967                            if value.is_some() && !matches!(value, Some(Value::Null)) {
968                                self.states[i].update(value);
969                            }
970                        }
971                        (AggregateFunction::CountNonNull, _) => {
972                            if value.is_some() && !matches!(value, Some(Value::Null)) {
973                                self.states[i].update(value);
974                            }
975                        }
976                        _ => {
977                            if value.is_some() && !matches!(value, Some(Value::Null)) {
978                                self.states[i].update(value);
979                            }
980                        }
981                    }
982                }
983            }
984        }
985
986        // Output single result row
987        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
988
989        for (i, state) in self.states.iter().enumerate() {
990            if let Some(col) = builder.column_mut(i) {
991                col.push_value(state.finalize());
992            }
993        }
994        builder.advance_row();
995
996        self.done = true;
997        Ok(Some(builder.finish()))
998    }
999
1000    fn reset(&mut self) {
1001        self.child.reset();
1002        self.states = self
1003            .aggregates
1004            .iter()
1005            .map(|agg| {
1006                AggregateState::new(
1007                    agg.function,
1008                    agg.distinct,
1009                    agg.percentile,
1010                    agg.separator.as_deref(),
1011                )
1012            })
1013            .collect();
1014        self.done = false;
1015    }
1016
1017    fn name(&self) -> &'static str {
1018        "SimpleAggregate"
1019    }
1020}
1021
1022#[cfg(test)]
1023mod tests {
1024    use super::*;
1025    use crate::execution::chunk::DataChunkBuilder;
1026
1027    struct MockOperator {
1028        chunks: Vec<DataChunk>,
1029        position: usize,
1030    }
1031
1032    impl MockOperator {
1033        fn new(chunks: Vec<DataChunk>) -> Self {
1034            Self {
1035                chunks,
1036                position: 0,
1037            }
1038        }
1039    }
1040
1041    impl Operator for MockOperator {
1042        fn next(&mut self) -> OperatorResult {
1043            if self.position < self.chunks.len() {
1044                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1045                self.position += 1;
1046                Ok(Some(chunk))
1047            } else {
1048                Ok(None)
1049            }
1050        }
1051
1052        fn reset(&mut self) {
1053            self.position = 0;
1054        }
1055
1056        fn name(&self) -> &'static str {
1057            "Mock"
1058        }
1059    }
1060
1061    fn create_test_chunk() -> DataChunk {
1062        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
1063        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1064
1065        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1066        for (group, value) in data {
1067            builder.column_mut(0).unwrap().push_int64(group);
1068            builder.column_mut(1).unwrap().push_int64(value);
1069            builder.advance_row();
1070        }
1071
1072        builder.finish()
1073    }
1074
1075    #[test]
1076    fn test_simple_count() {
1077        let mock = MockOperator::new(vec![create_test_chunk()]);
1078
1079        let mut agg = SimpleAggregateOperator::new(
1080            Box::new(mock),
1081            vec![AggregateExpr::count_star()],
1082            vec![LogicalType::Int64],
1083        );
1084
1085        let result = agg.next().unwrap().unwrap();
1086        assert_eq!(result.row_count(), 1);
1087        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1088
1089        // Should be done
1090        assert!(agg.next().unwrap().is_none());
1091    }
1092
1093    #[test]
1094    fn test_simple_sum() {
1095        let mock = MockOperator::new(vec![create_test_chunk()]);
1096
1097        let mut agg = SimpleAggregateOperator::new(
1098            Box::new(mock),
1099            vec![AggregateExpr::sum(1)], // Sum of column 1
1100            vec![LogicalType::Int64],
1101        );
1102
1103        let result = agg.next().unwrap().unwrap();
1104        assert_eq!(result.row_count(), 1);
1105        // Sum: 10 + 20 + 30 + 40 + 50 = 150
1106        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1107    }
1108
1109    #[test]
1110    fn test_simple_avg() {
1111        let mock = MockOperator::new(vec![create_test_chunk()]);
1112
1113        let mut agg = SimpleAggregateOperator::new(
1114            Box::new(mock),
1115            vec![AggregateExpr::avg(1)],
1116            vec![LogicalType::Float64],
1117        );
1118
1119        let result = agg.next().unwrap().unwrap();
1120        assert_eq!(result.row_count(), 1);
1121        // Avg: 150 / 5 = 30.0
1122        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1123        assert!((avg - 30.0).abs() < 0.001);
1124    }
1125
1126    #[test]
1127    fn test_simple_min_max() {
1128        let mock = MockOperator::new(vec![create_test_chunk()]);
1129
1130        let mut agg = SimpleAggregateOperator::new(
1131            Box::new(mock),
1132            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1133            vec![LogicalType::Int64, LogicalType::Int64],
1134        );
1135
1136        let result = agg.next().unwrap().unwrap();
1137        assert_eq!(result.row_count(), 1);
1138        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
1139        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
1140    }
1141
1142    #[test]
1143    fn test_sum_with_string_values() {
1144        // Test SUM with string values (like RDF stores numeric literals)
1145        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1146        builder.column_mut(0).unwrap().push_string("30");
1147        builder.advance_row();
1148        builder.column_mut(0).unwrap().push_string("25");
1149        builder.advance_row();
1150        builder.column_mut(0).unwrap().push_string("35");
1151        builder.advance_row();
1152        let chunk = builder.finish();
1153
1154        let mock = MockOperator::new(vec![chunk]);
1155        let mut agg = SimpleAggregateOperator::new(
1156            Box::new(mock),
1157            vec![AggregateExpr::sum(0)],
1158            vec![LogicalType::Float64],
1159        );
1160
1161        let result = agg.next().unwrap().unwrap();
1162        assert_eq!(result.row_count(), 1);
1163        // Should parse strings and sum: 30 + 25 + 35 = 90
1164        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1165        assert!(
1166            (sum_val - 90.0).abs() < 0.001,
1167            "Expected 90.0, got {}",
1168            sum_val
1169        );
1170    }
1171
1172    #[test]
1173    fn test_grouped_aggregation() {
1174        let mock = MockOperator::new(vec![create_test_chunk()]);
1175
1176        // GROUP BY column 0, SUM(column 1)
1177        let mut agg = HashAggregateOperator::new(
1178            Box::new(mock),
1179            vec![0],                     // Group by column 0
1180            vec![AggregateExpr::sum(1)], // Sum of column 1
1181            vec![LogicalType::Int64, LogicalType::Int64],
1182        );
1183
1184        let mut results: Vec<(i64, i64)> = Vec::new();
1185        while let Some(chunk) = agg.next().unwrap() {
1186            for row in chunk.selected_indices() {
1187                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1188                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1189                results.push((group, sum));
1190            }
1191        }
1192
1193        results.sort_by_key(|(g, _)| *g);
1194        assert_eq!(results.len(), 2);
1195        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
1196        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
1197    }
1198
1199    #[test]
1200    fn test_grouped_count() {
1201        let mock = MockOperator::new(vec![create_test_chunk()]);
1202
1203        // GROUP BY column 0, COUNT(*)
1204        let mut agg = HashAggregateOperator::new(
1205            Box::new(mock),
1206            vec![0],
1207            vec![AggregateExpr::count_star()],
1208            vec![LogicalType::Int64, LogicalType::Int64],
1209        );
1210
1211        let mut results: Vec<(i64, i64)> = Vec::new();
1212        while let Some(chunk) = agg.next().unwrap() {
1213            for row in chunk.selected_indices() {
1214                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1215                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1216                results.push((group, count));
1217            }
1218        }
1219
1220        results.sort_by_key(|(g, _)| *g);
1221        assert_eq!(results.len(), 2);
1222        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
1223        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
1224    }
1225
1226    #[test]
1227    fn test_multiple_aggregates() {
1228        let mock = MockOperator::new(vec![create_test_chunk()]);
1229
1230        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
1231        let mut agg = HashAggregateOperator::new(
1232            Box::new(mock),
1233            vec![0],
1234            vec![
1235                AggregateExpr::count_star(),
1236                AggregateExpr::sum(1),
1237                AggregateExpr::avg(1),
1238            ],
1239            vec![
1240                LogicalType::Int64,   // Group key
1241                LogicalType::Int64,   // COUNT
1242                LogicalType::Int64,   // SUM
1243                LogicalType::Float64, // AVG
1244            ],
1245        );
1246
1247        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1248        while let Some(chunk) = agg.next().unwrap() {
1249            for row in chunk.selected_indices() {
1250                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1251                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1252                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1253                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1254                results.push((group, count, sum, avg));
1255            }
1256        }
1257
1258        results.sort_by_key(|(g, _, _, _)| *g);
1259        assert_eq!(results.len(), 2);
1260
1261        // Group 1: COUNT=2, SUM=30, AVG=15.0
1262        assert_eq!(results[0].0, 1);
1263        assert_eq!(results[0].1, 2);
1264        assert_eq!(results[0].2, 30);
1265        assert!((results[0].3 - 15.0).abs() < 0.001);
1266
1267        // Group 2: COUNT=3, SUM=120, AVG=40.0
1268        assert_eq!(results[1].0, 2);
1269        assert_eq!(results[1].1, 3);
1270        assert_eq!(results[1].2, 120);
1271        assert!((results[1].3 - 40.0).abs() < 0.001);
1272    }
1273
1274    fn create_test_chunk_with_duplicates() -> DataChunk {
1275        // Create data with duplicate values in column 1
1276        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
1277        // GROUP 1: values [10, 10, 20] -> distinct count = 2
1278        // GROUP 2: values [30, 30, 30] -> distinct count = 1
1279        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1280
1281        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1282        for (group, value) in data {
1283            builder.column_mut(0).unwrap().push_int64(group);
1284            builder.column_mut(1).unwrap().push_int64(value);
1285            builder.advance_row();
1286        }
1287
1288        builder.finish()
1289    }
1290
1291    #[test]
1292    fn test_count_distinct() {
1293        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1294
1295        // COUNT(DISTINCT column 1)
1296        let mut agg = SimpleAggregateOperator::new(
1297            Box::new(mock),
1298            vec![AggregateExpr::count(1).with_distinct()],
1299            vec![LogicalType::Int64],
1300        );
1301
1302        let result = agg.next().unwrap().unwrap();
1303        assert_eq!(result.row_count(), 1);
1304        // Total distinct values: 10, 20, 30 = 3 distinct values
1305        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1306    }
1307
1308    #[test]
1309    fn test_grouped_count_distinct() {
1310        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1311
1312        // GROUP BY column 0, COUNT(DISTINCT column 1)
1313        let mut agg = HashAggregateOperator::new(
1314            Box::new(mock),
1315            vec![0],
1316            vec![AggregateExpr::count(1).with_distinct()],
1317            vec![LogicalType::Int64, LogicalType::Int64],
1318        );
1319
1320        let mut results: Vec<(i64, i64)> = Vec::new();
1321        while let Some(chunk) = agg.next().unwrap() {
1322            for row in chunk.selected_indices() {
1323                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1324                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1325                results.push((group, count));
1326            }
1327        }
1328
1329        results.sort_by_key(|(g, _)| *g);
1330        assert_eq!(results.len(), 2);
1331        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
1332        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
1333    }
1334
1335    #[test]
1336    fn test_sum_distinct() {
1337        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1338
1339        // SUM(DISTINCT column 1)
1340        let mut agg = SimpleAggregateOperator::new(
1341            Box::new(mock),
1342            vec![AggregateExpr::sum(1).with_distinct()],
1343            vec![LogicalType::Int64],
1344        );
1345
1346        let result = agg.next().unwrap().unwrap();
1347        assert_eq!(result.row_count(), 1);
1348        // Sum of distinct values: 10 + 20 + 30 = 60
1349        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1350    }
1351
1352    #[test]
1353    fn test_avg_distinct() {
1354        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1355
1356        // AVG(DISTINCT column 1)
1357        let mut agg = SimpleAggregateOperator::new(
1358            Box::new(mock),
1359            vec![AggregateExpr::avg(1).with_distinct()],
1360            vec![LogicalType::Float64],
1361        );
1362
1363        let result = agg.next().unwrap().unwrap();
1364        assert_eq!(result.row_count(), 1);
1365        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1366        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1367        assert!((avg - 20.0).abs() < 0.001);
1368    }
1369
1370    fn create_statistical_test_chunk() -> DataChunk {
1371        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1372        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1373        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1374
1375        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1376            builder.column_mut(0).unwrap().push_int64(value);
1377            builder.advance_row();
1378        }
1379
1380        builder.finish()
1381    }
1382
1383    #[test]
1384    fn test_stdev_sample() {
1385        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1386
1387        let mut agg = SimpleAggregateOperator::new(
1388            Box::new(mock),
1389            vec![AggregateExpr::stdev(0)],
1390            vec![LogicalType::Float64],
1391        );
1392
1393        let result = agg.next().unwrap().unwrap();
1394        assert_eq!(result.row_count(), 1);
1395        // Sample standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1396        // Mean = 5.0, Variance = 32/7 = 4.571, StdDev = 2.138
1397        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1398        assert!((stdev - 2.138).abs() < 0.01);
1399    }
1400
1401    #[test]
1402    fn test_stdev_population() {
1403        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1404
1405        let mut agg = SimpleAggregateOperator::new(
1406            Box::new(mock),
1407            vec![AggregateExpr::stdev_pop(0)],
1408            vec![LogicalType::Float64],
1409        );
1410
1411        let result = agg.next().unwrap().unwrap();
1412        assert_eq!(result.row_count(), 1);
1413        // Population standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1414        // Mean = 5.0, Variance = 32/8 = 4.0, StdDev = 2.0
1415        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1416        assert!((stdev - 2.0).abs() < 0.01);
1417    }
1418
1419    #[test]
1420    fn test_percentile_disc() {
1421        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1422
1423        // Median (50th percentile discrete)
1424        let mut agg = SimpleAggregateOperator::new(
1425            Box::new(mock),
1426            vec![AggregateExpr::percentile_disc(0, 0.5)],
1427            vec![LogicalType::Float64],
1428        );
1429
1430        let result = agg.next().unwrap().unwrap();
1431        assert_eq!(result.row_count(), 1);
1432        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], index = floor(0.5 * 7) = 3, value = 4
1433        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1434        assert!((percentile - 4.0).abs() < 0.01);
1435    }
1436
1437    #[test]
1438    fn test_percentile_cont() {
1439        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1440
1441        // Median (50th percentile continuous)
1442        let mut agg = SimpleAggregateOperator::new(
1443            Box::new(mock),
1444            vec![AggregateExpr::percentile_cont(0, 0.5)],
1445            vec![LogicalType::Float64],
1446        );
1447
1448        let result = agg.next().unwrap().unwrap();
1449        assert_eq!(result.row_count(), 1);
1450        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], rank = 0.5 * 7 = 3.5
1451        // Interpolate between index 3 (4) and index 4 (5): 4 + 0.5 * (5 - 4) = 4.5
1452        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1453        assert!((percentile - 4.5).abs() < 0.01);
1454    }
1455
1456    #[test]
1457    fn test_percentile_extremes() {
1458        // Test 0th and 100th percentiles
1459        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1460
1461        let mut agg = SimpleAggregateOperator::new(
1462            Box::new(mock),
1463            vec![
1464                AggregateExpr::percentile_disc(0, 0.0),
1465                AggregateExpr::percentile_disc(0, 1.0),
1466            ],
1467            vec![LogicalType::Float64, LogicalType::Float64],
1468        );
1469
1470        let result = agg.next().unwrap().unwrap();
1471        assert_eq!(result.row_count(), 1);
1472        // 0th percentile = minimum = 2
1473        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1474        assert!((p0 - 2.0).abs() < 0.01);
1475        // 100th percentile = maximum = 9
1476        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1477        assert!((p100 - 9.0).abs() < 0.01);
1478    }
1479
1480    #[test]
1481    fn test_stdev_single_value() {
1482        // Single value should return null for sample stdev
1483        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1484        builder.column_mut(0).unwrap().push_int64(42);
1485        builder.advance_row();
1486        let chunk = builder.finish();
1487
1488        let mock = MockOperator::new(vec![chunk]);
1489
1490        let mut agg = SimpleAggregateOperator::new(
1491            Box::new(mock),
1492            vec![AggregateExpr::stdev(0)],
1493            vec![LogicalType::Float64],
1494        );
1495
1496        let result = agg.next().unwrap().unwrap();
1497        assert_eq!(result.row_count(), 1);
1498        // Sample stdev of single value is undefined (null)
1499        assert!(matches!(
1500            result.column(0).unwrap().get_value(0),
1501            Some(Value::Null)
1502        ));
1503    }
1504
1505    #[test]
1506    fn test_first_and_last() {
1507        let mock = MockOperator::new(vec![create_test_chunk()]);
1508
1509        let mut agg = SimpleAggregateOperator::new(
1510            Box::new(mock),
1511            vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1512            vec![LogicalType::Int64, LogicalType::Int64],
1513        );
1514
1515        let result = agg.next().unwrap().unwrap();
1516        assert_eq!(result.row_count(), 1);
1517        // First: 10, Last: 50
1518        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1519        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1520    }
1521
1522    #[test]
1523    fn test_collect() {
1524        let mock = MockOperator::new(vec![create_test_chunk()]);
1525
1526        let mut agg = SimpleAggregateOperator::new(
1527            Box::new(mock),
1528            vec![AggregateExpr::collect(1)],
1529            vec![LogicalType::Any],
1530        );
1531
1532        let result = agg.next().unwrap().unwrap();
1533        let val = result.column(0).unwrap().get_value(0).unwrap();
1534        if let Value::List(items) = val {
1535            assert_eq!(items.len(), 5);
1536        } else {
1537            panic!("Expected List value");
1538        }
1539    }
1540
1541    #[test]
1542    fn test_collect_distinct() {
1543        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1544
1545        let mut agg = SimpleAggregateOperator::new(
1546            Box::new(mock),
1547            vec![AggregateExpr::collect(1).with_distinct()],
1548            vec![LogicalType::Any],
1549        );
1550
1551        let result = agg.next().unwrap().unwrap();
1552        let val = result.column(0).unwrap().get_value(0).unwrap();
1553        if let Value::List(items) = val {
1554            // [10, 10, 20, 30, 30, 30] -> distinct: [10, 20, 30]
1555            assert_eq!(items.len(), 3);
1556        } else {
1557            panic!("Expected List value");
1558        }
1559    }
1560
1561    #[test]
1562    fn test_group_concat() {
1563        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1564        for s in ["hello", "world", "foo"] {
1565            builder.column_mut(0).unwrap().push_string(s);
1566            builder.advance_row();
1567        }
1568        let chunk = builder.finish();
1569        let mock = MockOperator::new(vec![chunk]);
1570
1571        let agg_expr = AggregateExpr {
1572            function: AggregateFunction::GroupConcat,
1573            column: Some(0),
1574            column2: None,
1575            distinct: false,
1576            alias: None,
1577            percentile: None,
1578            separator: None,
1579        };
1580
1581        let mut agg =
1582            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1583
1584        let result = agg.next().unwrap().unwrap();
1585        let val = result.column(0).unwrap().get_value(0).unwrap();
1586        assert_eq!(val, Value::String("hello world foo".into()));
1587    }
1588
1589    #[test]
1590    fn test_sample() {
1591        let mock = MockOperator::new(vec![create_test_chunk()]);
1592
1593        let agg_expr = AggregateExpr {
1594            function: AggregateFunction::Sample,
1595            column: Some(1),
1596            column2: None,
1597            distinct: false,
1598            alias: None,
1599            percentile: None,
1600            separator: None,
1601        };
1602
1603        let mut agg =
1604            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1605
1606        let result = agg.next().unwrap().unwrap();
1607        // Sample should return the first non-null value (10)
1608        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1609    }
1610
1611    #[test]
1612    fn test_variance_sample() {
1613        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1614
1615        let agg_expr = AggregateExpr {
1616            function: AggregateFunction::Variance,
1617            column: Some(0),
1618            column2: None,
1619            distinct: false,
1620            alias: None,
1621            percentile: None,
1622            separator: None,
1623        };
1624
1625        let mut agg = SimpleAggregateOperator::new(
1626            Box::new(mock),
1627            vec![agg_expr],
1628            vec![LogicalType::Float64],
1629        );
1630
1631        let result = agg.next().unwrap().unwrap();
1632        // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: M2/(n-1) = 32/7 = 4.571
1633        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1634        assert!((variance - 32.0 / 7.0).abs() < 0.01);
1635    }
1636
1637    #[test]
1638    fn test_variance_population() {
1639        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1640
1641        let agg_expr = AggregateExpr {
1642            function: AggregateFunction::VariancePop,
1643            column: Some(0),
1644            column2: None,
1645            distinct: false,
1646            alias: None,
1647            percentile: None,
1648            separator: None,
1649        };
1650
1651        let mut agg = SimpleAggregateOperator::new(
1652            Box::new(mock),
1653            vec![agg_expr],
1654            vec![LogicalType::Float64],
1655        );
1656
1657        let result = agg.next().unwrap().unwrap();
1658        // Population variance: M2/n = 32/8 = 4.0
1659        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1660        assert!((variance - 4.0).abs() < 0.01);
1661    }
1662
1663    #[test]
1664    fn test_variance_single_value() {
1665        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1666        builder.column_mut(0).unwrap().push_int64(42);
1667        builder.advance_row();
1668        let chunk = builder.finish();
1669        let mock = MockOperator::new(vec![chunk]);
1670
1671        let agg_expr = AggregateExpr {
1672            function: AggregateFunction::Variance,
1673            column: Some(0),
1674            column2: None,
1675            distinct: false,
1676            alias: None,
1677            percentile: None,
1678            separator: None,
1679        };
1680
1681        let mut agg = SimpleAggregateOperator::new(
1682            Box::new(mock),
1683            vec![agg_expr],
1684            vec![LogicalType::Float64],
1685        );
1686
1687        let result = agg.next().unwrap().unwrap();
1688        // Sample variance of single value is undefined (null)
1689        assert!(matches!(
1690            result.column(0).unwrap().get_value(0),
1691            Some(Value::Null)
1692        ));
1693    }
1694
1695    #[test]
1696    fn test_empty_aggregation() {
1697        // No input rows: COUNT should be 0, SUM/AVG/MIN/MAX should be NULL
1698        // (ISO/IEC 39075 Section 20.9)
1699        let mock = MockOperator::new(vec![]);
1700
1701        let mut agg = SimpleAggregateOperator::new(
1702            Box::new(mock),
1703            vec![
1704                AggregateExpr::count_star(),
1705                AggregateExpr::sum(0),
1706                AggregateExpr::avg(0),
1707                AggregateExpr::min(0),
1708                AggregateExpr::max(0),
1709            ],
1710            vec![
1711                LogicalType::Int64,
1712                LogicalType::Int64,
1713                LogicalType::Float64,
1714                LogicalType::Int64,
1715                LogicalType::Int64,
1716            ],
1717        );
1718
1719        let result = agg.next().unwrap().unwrap();
1720        assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); // COUNT
1721        assert!(matches!(
1722            result.column(1).unwrap().get_value(0),
1723            Some(Value::Null)
1724        )); // SUM
1725        assert!(matches!(
1726            result.column(2).unwrap().get_value(0),
1727            Some(Value::Null)
1728        )); // AVG
1729        assert!(matches!(
1730            result.column(3).unwrap().get_value(0),
1731            Some(Value::Null)
1732        )); // MIN
1733        assert!(matches!(
1734            result.column(4).unwrap().get_value(0),
1735            Some(Value::Null)
1736        )); // MAX
1737    }
1738
1739    #[test]
1740    fn test_stdev_pop_single_value() {
1741        // Single value should return 0 for population stdev
1742        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1743        builder.column_mut(0).unwrap().push_int64(42);
1744        builder.advance_row();
1745        let chunk = builder.finish();
1746
1747        let mock = MockOperator::new(vec![chunk]);
1748
1749        let mut agg = SimpleAggregateOperator::new(
1750            Box::new(mock),
1751            vec![AggregateExpr::stdev_pop(0)],
1752            vec![LogicalType::Float64],
1753        );
1754
1755        let result = agg.next().unwrap().unwrap();
1756        assert_eq!(result.row_count(), 1);
1757        // Population stdev of single value is 0
1758        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1759        assert!((stdev - 0.0).abs() < 0.01);
1760    }
1761}