Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12
13use arcstr::ArcStr;
14use grafeo_common::types::{LogicalType, Value};
15
16use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
17use super::{Operator, OperatorError, OperatorResult};
18use crate::execution::DataChunk;
19use crate::execution::chunk::DataChunkBuilder;
20
21/// State for a single aggregation computation.
22#[derive(Debug, Clone)]
23pub(crate) enum AggregateState {
24    /// Count state.
25    Count(i64),
26    /// Count distinct state (count, seen values).
27    CountDistinct(i64, HashSet<HashableValue>),
28    /// Sum state (integer sum, count of values added).
29    SumInt(i64, i64),
30    /// Sum distinct state (integer sum, count, seen values).
31    SumIntDistinct(i64, i64, HashSet<HashableValue>),
32    /// Sum state (float sum, count of values added).
33    SumFloat(f64, i64),
34    /// Sum distinct state (float sum, count, seen values).
35    SumFloatDistinct(f64, i64, HashSet<HashableValue>),
36    /// Average state (sum, count).
37    Avg(f64, i64),
38    /// Average distinct state (sum, count, seen values).
39    AvgDistinct(f64, i64, HashSet<HashableValue>),
40    /// Min state.
41    Min(Option<Value>),
42    /// Max state.
43    Max(Option<Value>),
44    /// First state.
45    First(Option<Value>),
46    /// Last state.
47    Last(Option<Value>),
48    /// Collect state.
49    Collect(Vec<Value>),
50    /// Collect distinct state (values, seen).
51    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
52    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
53    StdDev { count: i64, mean: f64, m2: f64 },
54    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
55    StdDevPop { count: i64, mean: f64, m2: f64 },
56    /// Discrete percentile state (values, percentile).
57    PercentileDisc { values: Vec<f64>, percentile: f64 },
58    /// Continuous percentile state (values, percentile).
59    PercentileCont { values: Vec<f64>, percentile: f64 },
60    /// GROUP_CONCAT / LISTAGG state (collected string values, separator).
61    GroupConcat(Vec<String>, String),
62    /// GROUP_CONCAT / LISTAGG distinct state (collected string values, separator, seen).
63    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
64    /// SAMPLE state (first non-null value encountered).
65    Sample(Option<Value>),
66    /// Sample variance state using Welford's algorithm (count, mean, M2).
67    Variance { count: i64, mean: f64, m2: f64 },
68    /// Population variance state using Welford's algorithm (count, mean, M2).
69    VariancePop { count: i64, mean: f64, m2: f64 },
70    /// Two-variable online statistics (Welford generalization for covariance/regression).
71    Bivariate {
72        /// Which binary set function this state will finalize to.
73        kind: AggregateFunction,
74        count: i64,
75        mean_x: f64,
76        mean_y: f64,
77        m2_x: f64,
78        m2_y: f64,
79        c_xy: f64,
80    },
81}
82
83impl AggregateState {
84    /// Creates initial state for an aggregation function.
85    pub(crate) fn new(
86        function: AggregateFunction,
87        distinct: bool,
88        percentile: Option<f64>,
89        separator: Option<&str>,
90    ) -> Self {
91        match (function, distinct) {
92            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
93                AggregateState::Count(0)
94            }
95            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
96                AggregateState::CountDistinct(0, HashSet::new())
97            }
98            (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
99            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
100            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
101            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
102            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
103            (AggregateFunction::Max, _) => AggregateState::Max(None),
104            (AggregateFunction::First, _) => AggregateState::First(None),
105            (AggregateFunction::Last, _) => AggregateState::Last(None),
106            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
107            (AggregateFunction::Collect, true) => {
108                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
109            }
110            // Statistical functions (Welford's algorithm for online computation)
111            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
112                count: 0,
113                mean: 0.0,
114                m2: 0.0,
115            },
116            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
117                count: 0,
118                mean: 0.0,
119                m2: 0.0,
120            },
121            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
122                values: Vec::new(),
123                percentile: percentile.unwrap_or(0.5),
124            },
125            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
126                values: Vec::new(),
127                percentile: percentile.unwrap_or(0.5),
128            },
129            (AggregateFunction::GroupConcat, false) => {
130                AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
131            }
132            (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
133                Vec::new(),
134                separator.unwrap_or(" ").to_string(),
135                HashSet::new(),
136            ),
137            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
138            // Binary set functions (all share the same Bivariate state)
139            (
140                AggregateFunction::CovarSamp
141                | AggregateFunction::CovarPop
142                | AggregateFunction::Corr
143                | AggregateFunction::RegrSlope
144                | AggregateFunction::RegrIntercept
145                | AggregateFunction::RegrR2
146                | AggregateFunction::RegrCount
147                | AggregateFunction::RegrSxx
148                | AggregateFunction::RegrSyy
149                | AggregateFunction::RegrSxy
150                | AggregateFunction::RegrAvgx
151                | AggregateFunction::RegrAvgy,
152                _,
153            ) => AggregateState::Bivariate {
154                kind: function,
155                count: 0,
156                mean_x: 0.0,
157                mean_y: 0.0,
158                m2_x: 0.0,
159                m2_y: 0.0,
160                c_xy: 0.0,
161            },
162            (AggregateFunction::Variance, _) => AggregateState::Variance {
163                count: 0,
164                mean: 0.0,
165                m2: 0.0,
166            },
167            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
168                count: 0,
169                mean: 0.0,
170                m2: 0.0,
171            },
172        }
173    }
174
175    /// Updates the state with a new value.
176    pub(crate) fn update(&mut self, value: Option<Value>) {
177        match self {
178            AggregateState::Count(count) => {
179                *count += 1;
180            }
181            AggregateState::CountDistinct(count, seen) => {
182                if let Some(ref v) = value {
183                    let hashable = HashableValue::from(v);
184                    if seen.insert(hashable) {
185                        *count += 1;
186                    }
187                }
188            }
189            AggregateState::SumInt(sum, count) => {
190                if let Some(Value::Int64(v)) = value {
191                    *sum += v;
192                    *count += 1;
193                } else if let Some(Value::Float64(v)) = value {
194                    // Convert to float sum, carrying count forward
195                    *self = AggregateState::SumFloat(*sum as f64 + v, *count + 1);
196                } else if let Some(ref v) = value {
197                    // RDF stores numeric literals as strings - try to parse
198                    if let Some(num) = value_to_f64(v) {
199                        *self = AggregateState::SumFloat(*sum as f64 + num, *count + 1);
200                    }
201                }
202            }
203            AggregateState::SumIntDistinct(sum, count, seen) => {
204                if let Some(ref v) = value {
205                    let hashable = HashableValue::from(v);
206                    if seen.insert(hashable) {
207                        if let Value::Int64(i) = v {
208                            *sum += i;
209                            *count += 1;
210                        } else if let Value::Float64(f) = v {
211                            // Convert to float distinct: move the seen set instead of cloning
212                            let moved_seen = std::mem::take(seen);
213                            *self = AggregateState::SumFloatDistinct(
214                                *sum as f64 + f,
215                                *count + 1,
216                                moved_seen,
217                            );
218                        } else if let Some(num) = value_to_f64(v) {
219                            // RDF string-encoded numerics
220                            let moved_seen = std::mem::take(seen);
221                            *self = AggregateState::SumFloatDistinct(
222                                *sum as f64 + num,
223                                *count + 1,
224                                moved_seen,
225                            );
226                        }
227                    }
228                }
229            }
230            AggregateState::SumFloat(sum, count) => {
231                if let Some(ref v) = value {
232                    // Use value_to_f64 which now handles strings
233                    if let Some(num) = value_to_f64(v) {
234                        *sum += num;
235                        *count += 1;
236                    }
237                }
238            }
239            AggregateState::SumFloatDistinct(sum, count, seen) => {
240                if let Some(ref v) = value {
241                    let hashable = HashableValue::from(v);
242                    if seen.insert(hashable)
243                        && let Some(num) = value_to_f64(v)
244                    {
245                        *sum += num;
246                        *count += 1;
247                    }
248                }
249            }
250            AggregateState::Avg(sum, count) => {
251                if let Some(ref v) = value
252                    && let Some(num) = value_to_f64(v)
253                {
254                    *sum += num;
255                    *count += 1;
256                }
257            }
258            AggregateState::AvgDistinct(sum, count, seen) => {
259                if let Some(ref v) = value {
260                    let hashable = HashableValue::from(v);
261                    if seen.insert(hashable)
262                        && let Some(num) = value_to_f64(v)
263                    {
264                        *sum += num;
265                        *count += 1;
266                    }
267                }
268            }
269            AggregateState::Min(min) => {
270                if let Some(v) = value {
271                    match min {
272                        None => *min = Some(v),
273                        Some(current) => {
274                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
275                                *min = Some(v);
276                            }
277                        }
278                    }
279                }
280            }
281            AggregateState::Max(max) => {
282                if let Some(v) = value {
283                    match max {
284                        None => *max = Some(v),
285                        Some(current) => {
286                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
287                                *max = Some(v);
288                            }
289                        }
290                    }
291                }
292            }
293            AggregateState::First(first) => {
294                if first.is_none() {
295                    *first = value;
296                }
297            }
298            AggregateState::Last(last) => {
299                if value.is_some() {
300                    *last = value;
301                }
302            }
303            AggregateState::Collect(list) => {
304                if let Some(v) = value {
305                    list.push(v);
306                }
307            }
308            AggregateState::CollectDistinct(list, seen) => {
309                if let Some(v) = value {
310                    let hashable = HashableValue::from(&v);
311                    if seen.insert(hashable) {
312                        list.push(v);
313                    }
314                }
315            }
316            // Statistical functions using Welford's online algorithm
317            AggregateState::StdDev { count, mean, m2 }
318            | AggregateState::StdDevPop { count, mean, m2 }
319            | AggregateState::Variance { count, mean, m2 }
320            | AggregateState::VariancePop { count, mean, m2 } => {
321                if let Some(ref v) = value
322                    && let Some(x) = value_to_f64(v)
323                {
324                    *count += 1;
325                    let delta = x - *mean;
326                    *mean += delta / *count as f64;
327                    let delta2 = x - *mean;
328                    *m2 += delta * delta2;
329                }
330            }
331            AggregateState::PercentileDisc { values, .. }
332            | AggregateState::PercentileCont { values, .. } => {
333                if let Some(ref v) = value
334                    && let Some(x) = value_to_f64(v)
335                {
336                    values.push(x);
337                }
338            }
339            AggregateState::GroupConcat(list, _sep) => {
340                if let Some(v) = value {
341                    list.push(agg_value_to_string(&v));
342                }
343            }
344            AggregateState::GroupConcatDistinct(list, _sep, seen) => {
345                if let Some(v) = value {
346                    let hashable = HashableValue::from(&v);
347                    if seen.insert(hashable) {
348                        list.push(agg_value_to_string(&v));
349                    }
350                }
351            }
352            AggregateState::Sample(sample) => {
353                if sample.is_none() {
354                    *sample = value;
355                }
356            }
357            AggregateState::Bivariate { .. } => {
358                // Bivariate functions require two values; use update_bivariate() instead.
359                // Single-value update is a no-op for bivariate state.
360            }
361        }
362    }
363
364    /// Updates a bivariate (two-variable) aggregate state with a pair of values.
365    ///
366    /// Uses the two-variable Welford online algorithm for numerically stable computation
367    /// of covariance and related statistics. Skips the update if either value is null.
368    fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
369        if let AggregateState::Bivariate {
370            count,
371            mean_x,
372            mean_y,
373            m2_x,
374            m2_y,
375            c_xy,
376            ..
377        } = self
378        {
379            // Skip if either value is null (SQL semantics: exclude non-pairs)
380            if let (Some(y), Some(x)) = (&y_val, &x_val)
381                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
382            {
383                *count += 1;
384                let n = *count as f64;
385                let dx = x_f - *mean_x;
386                let dy = y_f - *mean_y;
387                *mean_x += dx / n;
388                *mean_y += dy / n;
389                let dx2 = x_f - *mean_x; // post-update delta
390                let dy2 = y_f - *mean_y; // post-update delta
391                *m2_x += dx * dx2;
392                *m2_y += dy * dy2;
393                *c_xy += dx * dy2;
394            }
395        }
396    }
397
398    /// Finalizes the state and returns the result value.
399    pub(crate) fn finalize(&self) -> Value {
400        match self {
401            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
402                Value::Int64(*count)
403            }
404            AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
405                if *count == 0 {
406                    Value::Null
407                } else {
408                    Value::Int64(*sum)
409                }
410            }
411            AggregateState::SumFloat(sum, count)
412            | AggregateState::SumFloatDistinct(sum, count, _) => {
413                if *count == 0 {
414                    Value::Null
415                } else {
416                    Value::Float64(*sum)
417                }
418            }
419            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
420                if *count == 0 {
421                    Value::Null
422                } else {
423                    Value::Float64(*sum / *count as f64)
424                }
425            }
426            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
427            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
428            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
429            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
430            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
431                Value::List(list.clone().into())
432            }
433            // Sample standard deviation: sqrt(M2 / (n - 1))
434            AggregateState::StdDev { count, m2, .. } => {
435                if *count < 2 {
436                    Value::Null
437                } else {
438                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
439                }
440            }
441            // Population standard deviation: sqrt(M2 / n)
442            AggregateState::StdDevPop { count, m2, .. } => {
443                if *count == 0 {
444                    Value::Null
445                } else {
446                    Value::Float64((*m2 / *count as f64).sqrt())
447                }
448            }
449            // Sample variance: M2 / (n - 1)
450            AggregateState::Variance { count, m2, .. } => {
451                if *count < 2 {
452                    Value::Null
453                } else {
454                    Value::Float64(*m2 / (*count - 1) as f64)
455                }
456            }
457            // Population variance: M2 / n
458            AggregateState::VariancePop { count, m2, .. } => {
459                if *count == 0 {
460                    Value::Null
461                } else {
462                    Value::Float64(*m2 / *count as f64)
463                }
464            }
465            // Discrete percentile: return actual value at percentile position
466            AggregateState::PercentileDisc { values, percentile } => {
467                if values.is_empty() {
468                    Value::Null
469                } else {
470                    let mut sorted = values.clone();
471                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
472                    // Index calculation per SQL standard: floor(p * (n - 1))
473                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
474                    Value::Float64(sorted[index])
475                }
476            }
477            // Continuous percentile: interpolate between values
478            AggregateState::PercentileCont { values, percentile } => {
479                if values.is_empty() {
480                    Value::Null
481                } else {
482                    let mut sorted = values.clone();
483                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
484                    // Linear interpolation per SQL standard
485                    let rank = percentile * (sorted.len() - 1) as f64;
486                    let lower_idx = rank.floor() as usize;
487                    let upper_idx = rank.ceil() as usize;
488                    if lower_idx == upper_idx {
489                        Value::Float64(sorted[lower_idx])
490                    } else {
491                        let fraction = rank - lower_idx as f64;
492                        let result =
493                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
494                        Value::Float64(result)
495                    }
496                }
497            }
498            // GROUP_CONCAT: join strings with space separator (SPARQL default)
499            AggregateState::GroupConcat(list, sep)
500            | AggregateState::GroupConcatDistinct(list, sep, _) => {
501                Value::String(list.join(sep).into())
502            }
503            // SAMPLE: return the first non-null value seen
504            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
505            // Binary set functions: dispatch on kind
506            AggregateState::Bivariate {
507                kind,
508                count,
509                mean_x,
510                mean_y,
511                m2_x,
512                m2_y,
513                c_xy,
514            } => {
515                let n = *count;
516                match kind {
517                    AggregateFunction::CovarSamp => {
518                        if n < 2 {
519                            Value::Null
520                        } else {
521                            Value::Float64(*c_xy / (n - 1) as f64)
522                        }
523                    }
524                    AggregateFunction::CovarPop => {
525                        if n == 0 {
526                            Value::Null
527                        } else {
528                            Value::Float64(*c_xy / n as f64)
529                        }
530                    }
531                    AggregateFunction::Corr => {
532                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
533                            Value::Null
534                        } else {
535                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
536                        }
537                    }
538                    AggregateFunction::RegrSlope => {
539                        if n == 0 || *m2_x == 0.0 {
540                            Value::Null
541                        } else {
542                            Value::Float64(*c_xy / *m2_x)
543                        }
544                    }
545                    AggregateFunction::RegrIntercept => {
546                        if n == 0 || *m2_x == 0.0 {
547                            Value::Null
548                        } else {
549                            let slope = *c_xy / *m2_x;
550                            Value::Float64(*mean_y - slope * *mean_x)
551                        }
552                    }
553                    AggregateFunction::RegrR2 => {
554                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
555                            Value::Null
556                        } else {
557                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
558                        }
559                    }
560                    AggregateFunction::RegrCount => Value::Int64(n),
561                    AggregateFunction::RegrSxx => {
562                        if n == 0 {
563                            Value::Null
564                        } else {
565                            Value::Float64(*m2_x)
566                        }
567                    }
568                    AggregateFunction::RegrSyy => {
569                        if n == 0 {
570                            Value::Null
571                        } else {
572                            Value::Float64(*m2_y)
573                        }
574                    }
575                    AggregateFunction::RegrSxy => {
576                        if n == 0 {
577                            Value::Null
578                        } else {
579                            Value::Float64(*c_xy)
580                        }
581                    }
582                    AggregateFunction::RegrAvgx => {
583                        if n == 0 {
584                            Value::Null
585                        } else {
586                            Value::Float64(*mean_x)
587                        }
588                    }
589                    AggregateFunction::RegrAvgy => {
590                        if n == 0 {
591                            Value::Null
592                        } else {
593                            Value::Float64(*mean_y)
594                        }
595                    }
596                    _ => Value::Null, // non-bivariate functions never reach here
597                }
598            }
599        }
600    }
601}
602
603use super::value_utils::{compare_values, value_to_f64};
604
605/// Converts a Value to its string representation for GROUP_CONCAT.
606fn agg_value_to_string(val: &Value) -> String {
607    match val {
608        Value::String(s) => s.to_string(),
609        Value::Int64(i) => i.to_string(),
610        Value::Float64(f) => f.to_string(),
611        Value::Bool(b) => b.to_string(),
612        Value::Null => String::new(),
613        other => format!("{other:?}"),
614    }
615}
616
617/// A group key for hash-based aggregation.
618#[derive(Debug, Clone, PartialEq, Eq, Hash)]
619pub struct GroupKey(Vec<GroupKeyPart>);
620
621#[derive(Debug, Clone, PartialEq, Eq, Hash)]
622enum GroupKeyPart {
623    Null,
624    Bool(bool),
625    Int64(i64),
626    String(ArcStr),
627}
628
629impl GroupKey {
630    /// Creates a group key from column values.
631    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
632        let parts: Vec<GroupKeyPart> = group_columns
633            .iter()
634            .map(|&col_idx| {
635                chunk
636                    .column(col_idx)
637                    .and_then(|col| col.get_value(row))
638                    .map_or(GroupKeyPart::Null, |v| match v {
639                        Value::Null => GroupKeyPart::Null,
640                        Value::Bool(b) => GroupKeyPart::Bool(b),
641                        Value::Int64(i) => GroupKeyPart::Int64(i),
642                        Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
643                        Value::String(s) => GroupKeyPart::String(s.clone()),
644                        _ => GroupKeyPart::String(ArcStr::from(format!("{v:?}"))),
645                    })
646            })
647            .collect();
648        GroupKey(parts)
649    }
650
651    /// Converts the group key back to values.
652    fn to_values(&self) -> Vec<Value> {
653        self.0
654            .iter()
655            .map(|part| match part {
656                GroupKeyPart::Null => Value::Null,
657                GroupKeyPart::Bool(b) => Value::Bool(*b),
658                GroupKeyPart::Int64(i) => Value::Int64(*i),
659                GroupKeyPart::String(s) => Value::String(s.clone()),
660            })
661            .collect()
662    }
663}
664
665/// Hash-based aggregate operator.
666///
667/// Groups input by key columns and computes aggregations for each group.
668pub struct HashAggregateOperator {
669    /// Child operator to read from.
670    child: Box<dyn Operator>,
671    /// Columns to group by.
672    group_columns: Vec<usize>,
673    /// Aggregation expressions.
674    aggregates: Vec<AggregateExpr>,
675    /// Output schema.
676    output_schema: Vec<LogicalType>,
677    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
678    groups: IndexMap<GroupKey, Vec<AggregateState>>,
679    /// Whether aggregation is complete.
680    aggregation_complete: bool,
681    /// Results iterator.
682    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
683}
684
685impl HashAggregateOperator {
686    /// Creates a new hash aggregate operator.
687    ///
688    /// # Arguments
689    /// * `child` - Child operator to read from.
690    /// * `group_columns` - Column indices to group by.
691    /// * `aggregates` - Aggregation expressions.
692    /// * `output_schema` - Schema of the output (group columns + aggregate results).
693    pub fn new(
694        child: Box<dyn Operator>,
695        group_columns: Vec<usize>,
696        aggregates: Vec<AggregateExpr>,
697        output_schema: Vec<LogicalType>,
698    ) -> Self {
699        Self {
700            child,
701            group_columns,
702            aggregates,
703            output_schema,
704            groups: IndexMap::new(),
705            aggregation_complete: false,
706            results: None,
707        }
708    }
709
710    /// Performs the aggregation.
711    fn aggregate(&mut self) -> Result<(), OperatorError> {
712        while let Some(chunk) = self.child.next()? {
713            for row in chunk.selected_indices() {
714                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
715
716                // Get or create aggregate states for this group
717                let states = self.groups.entry(key).or_insert_with(|| {
718                    self.aggregates
719                        .iter()
720                        .map(|agg| {
721                            AggregateState::new(
722                                agg.function,
723                                agg.distinct,
724                                agg.percentile,
725                                agg.separator.as_deref(),
726                            )
727                        })
728                        .collect()
729                });
730
731                // Update each aggregate
732                for (i, agg) in self.aggregates.iter().enumerate() {
733                    // Binary set functions: read two column values
734                    if agg.column2.is_some() {
735                        let y_val = agg
736                            .column
737                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
738                        let x_val = agg
739                            .column2
740                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
741                        states[i].update_bivariate(y_val, x_val);
742                        continue;
743                    }
744
745                    let value = match (agg.function, agg.distinct) {
746                        // COUNT(*) without DISTINCT doesn't need a value
747                        (AggregateFunction::Count, false) => None,
748                        // COUNT DISTINCT needs the actual value to track unique values
749                        (AggregateFunction::Count, true) => agg
750                            .column
751                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
752                        _ => agg
753                            .column
754                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
755                    };
756
757                    // For COUNT without DISTINCT, always update. For others, skip nulls.
758                    match (agg.function, agg.distinct) {
759                        (AggregateFunction::Count, false) => states[i].update(None),
760                        (AggregateFunction::Count, true) => {
761                            // COUNT DISTINCT needs the value to track unique values
762                            if value.is_some() && !matches!(value, Some(Value::Null)) {
763                                states[i].update(value);
764                            }
765                        }
766                        (AggregateFunction::CountNonNull, _) => {
767                            if value.is_some() && !matches!(value, Some(Value::Null)) {
768                                states[i].update(value);
769                            }
770                        }
771                        _ => {
772                            if value.is_some() && !matches!(value, Some(Value::Null)) {
773                                states[i].update(value);
774                            }
775                        }
776                    }
777                }
778            }
779        }
780
781        self.aggregation_complete = true;
782
783        // Convert to results iterator (IndexMap::drain takes a range)
784        let results: Vec<_> = self.groups.drain(..).collect();
785        self.results = Some(results.into_iter());
786
787        Ok(())
788    }
789}
790
791impl Operator for HashAggregateOperator {
792    fn next(&mut self) -> OperatorResult {
793        // Perform aggregation if not done
794        if !self.aggregation_complete {
795            self.aggregate()?;
796        }
797
798        // Special case: no groups (global aggregation with no data)
799        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
800            // For global aggregation (no GROUP BY), return one row with initial values
801            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
802
803            for agg in &self.aggregates {
804                let state = AggregateState::new(
805                    agg.function,
806                    agg.distinct,
807                    agg.percentile,
808                    agg.separator.as_deref(),
809                );
810                let value = state.finalize();
811                if let Some(col) = builder.column_mut(self.group_columns.len()) {
812                    col.push_value(value);
813                }
814            }
815            builder.advance_row();
816
817            self.results = Some(Vec::new().into_iter()); // Mark as done
818            return Ok(Some(builder.finish()));
819        }
820
821        let Some(results) = &mut self.results else {
822            return Ok(None);
823        };
824
825        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
826
827        for (key, states) in results.by_ref() {
828            // Output group key columns
829            let key_values = key.to_values();
830            for (i, value) in key_values.into_iter().enumerate() {
831                if let Some(col) = builder.column_mut(i) {
832                    col.push_value(value);
833                }
834            }
835
836            // Output aggregate results
837            for (i, state) in states.iter().enumerate() {
838                let col_idx = self.group_columns.len() + i;
839                if let Some(col) = builder.column_mut(col_idx) {
840                    col.push_value(state.finalize());
841                }
842            }
843
844            builder.advance_row();
845
846            if builder.is_full() {
847                return Ok(Some(builder.finish()));
848            }
849        }
850
851        if builder.row_count() > 0 {
852            Ok(Some(builder.finish()))
853        } else {
854            Ok(None)
855        }
856    }
857
858    fn reset(&mut self) {
859        self.child.reset();
860        self.groups.clear();
861        self.aggregation_complete = false;
862        self.results = None;
863    }
864
865    fn name(&self) -> &'static str {
866        "HashAggregate"
867    }
868}
869
870/// Simple (non-grouping) aggregate operator for global aggregations.
871///
872/// Used when there's no GROUP BY clause - aggregates all input into a single row.
873pub struct SimpleAggregateOperator {
874    /// Child operator.
875    child: Box<dyn Operator>,
876    /// Aggregation expressions.
877    aggregates: Vec<AggregateExpr>,
878    /// Output schema.
879    output_schema: Vec<LogicalType>,
880    /// Aggregate states.
881    states: Vec<AggregateState>,
882    /// Whether aggregation is complete.
883    done: bool,
884}
885
886impl SimpleAggregateOperator {
887    /// Creates a new simple aggregate operator.
888    pub fn new(
889        child: Box<dyn Operator>,
890        aggregates: Vec<AggregateExpr>,
891        output_schema: Vec<LogicalType>,
892    ) -> Self {
893        let states = aggregates
894            .iter()
895            .map(|agg| {
896                AggregateState::new(
897                    agg.function,
898                    agg.distinct,
899                    agg.percentile,
900                    agg.separator.as_deref(),
901                )
902            })
903            .collect();
904
905        Self {
906            child,
907            aggregates,
908            output_schema,
909            states,
910            done: false,
911        }
912    }
913}
914
915impl Operator for SimpleAggregateOperator {
916    fn next(&mut self) -> OperatorResult {
917        if self.done {
918            return Ok(None);
919        }
920
921        // Process all input
922        while let Some(chunk) = self.child.next()? {
923            for row in chunk.selected_indices() {
924                for (i, agg) in self.aggregates.iter().enumerate() {
925                    // Binary set functions: read two column values
926                    if agg.column2.is_some() {
927                        let y_val = agg
928                            .column
929                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
930                        let x_val = agg
931                            .column2
932                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
933                        self.states[i].update_bivariate(y_val, x_val);
934                        continue;
935                    }
936
937                    let value = match (agg.function, agg.distinct) {
938                        // COUNT(*) without DISTINCT doesn't need a value
939                        (AggregateFunction::Count, false) => None,
940                        // COUNT DISTINCT needs the actual value to track unique values
941                        (AggregateFunction::Count, true) => agg
942                            .column
943                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
944                        _ => agg
945                            .column
946                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
947                    };
948
949                    match (agg.function, agg.distinct) {
950                        (AggregateFunction::Count, false) => self.states[i].update(None),
951                        (AggregateFunction::Count, true) => {
952                            // COUNT DISTINCT needs the value to track unique values
953                            if value.is_some() && !matches!(value, Some(Value::Null)) {
954                                self.states[i].update(value);
955                            }
956                        }
957                        (AggregateFunction::CountNonNull, _) => {
958                            if value.is_some() && !matches!(value, Some(Value::Null)) {
959                                self.states[i].update(value);
960                            }
961                        }
962                        _ => {
963                            if value.is_some() && !matches!(value, Some(Value::Null)) {
964                                self.states[i].update(value);
965                            }
966                        }
967                    }
968                }
969            }
970        }
971
972        // Output single result row
973        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
974
975        for (i, state) in self.states.iter().enumerate() {
976            if let Some(col) = builder.column_mut(i) {
977                col.push_value(state.finalize());
978            }
979        }
980        builder.advance_row();
981
982        self.done = true;
983        Ok(Some(builder.finish()))
984    }
985
986    fn reset(&mut self) {
987        self.child.reset();
988        self.states = self
989            .aggregates
990            .iter()
991            .map(|agg| {
992                AggregateState::new(
993                    agg.function,
994                    agg.distinct,
995                    agg.percentile,
996                    agg.separator.as_deref(),
997                )
998            })
999            .collect();
1000        self.done = false;
1001    }
1002
1003    fn name(&self) -> &'static str {
1004        "SimpleAggregate"
1005    }
1006}
1007
1008#[cfg(test)]
1009mod tests {
1010    use super::*;
1011    use crate::execution::chunk::DataChunkBuilder;
1012
1013    struct MockOperator {
1014        chunks: Vec<DataChunk>,
1015        position: usize,
1016    }
1017
1018    impl MockOperator {
1019        fn new(chunks: Vec<DataChunk>) -> Self {
1020            Self {
1021                chunks,
1022                position: 0,
1023            }
1024        }
1025    }
1026
1027    impl Operator for MockOperator {
1028        fn next(&mut self) -> OperatorResult {
1029            if self.position < self.chunks.len() {
1030                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1031                self.position += 1;
1032                Ok(Some(chunk))
1033            } else {
1034                Ok(None)
1035            }
1036        }
1037
1038        fn reset(&mut self) {
1039            self.position = 0;
1040        }
1041
1042        fn name(&self) -> &'static str {
1043            "Mock"
1044        }
1045    }
1046
1047    fn create_test_chunk() -> DataChunk {
1048        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
1049        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1050
1051        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1052        for (group, value) in data {
1053            builder.column_mut(0).unwrap().push_int64(group);
1054            builder.column_mut(1).unwrap().push_int64(value);
1055            builder.advance_row();
1056        }
1057
1058        builder.finish()
1059    }
1060
/// COUNT(*) over the five-row fixture yields a single row containing 5.
#[test]
fn test_simple_count() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::count_star()],
        vec![LogicalType::Int64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let count = out.column(0).unwrap().get_int64(0);
    assert_eq!(count, Some(5));

    // A second pull must report exhaustion.
    let second = op.next().unwrap();
    assert!(second.is_none());
}
1078
/// SUM over column 1 of the fixture adds up to 150.
#[test]
fn test_simple_sum() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::sum(1)], // aggregate the value column
        vec![LogicalType::Int64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // 10 + 20 + 30 + 40 + 50 = 150
    let total = out.column(0).unwrap().get_int64(0);
    assert_eq!(total, Some(150));
}
1094
/// AVG over column 1 of the fixture is 150 / 5 = 30.0.
#[test]
fn test_simple_avg() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::avg(1)],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let mean = out.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (mean - 30.0).abs();
    assert!(deviation < 0.001);
}
1111
/// MIN and MAX over column 1 of the fixture are 10 and 50.
#[test]
fn test_simple_min_max() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::min(1), AggregateExpr::max(1)],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let smallest = out.column(0).unwrap().get_int64(0);
    let largest = out.column(1).unwrap().get_int64(0);
    assert_eq!(smallest, Some(10));
    assert_eq!(largest, Some(50));
}
1127
/// SUM over numeric string literals (as stored by RDF-style data) yields their
/// numeric total as a float.
#[test]
fn test_sum_with_string_values() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
    for literal in ["30", "25", "35"] {
        builder.column_mut(0).unwrap().push_string(literal);
        builder.advance_row();
    }

    let input = MockOperator::new(vec![builder.finish()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::sum(0)],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // The strings are expected to be parsed and summed: 30 + 25 + 35 = 90
    let total = out.column(0).unwrap().get_float64(0).unwrap();
    assert!(
        (total - 90.0).abs() < 0.001,
        "Expected 90.0, got {}",
        total
    );
}
1157
/// GROUP BY column 0 with SUM(column 1) produces one total per group.
#[test]
fn test_grouped_aggregation() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let mut op = HashAggregateOperator::new(
        Box::new(input),
        vec![0],                     // grouping key: column 0
        vec![AggregateExpr::sum(1)], // summed column: column 1
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let mut totals: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = op.next().unwrap() {
        for row in chunk.selected_indices() {
            let group = chunk.column(0).unwrap().get_int64(row).unwrap();
            let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
            totals.push((group, sum));
        }
    }
    totals.sort_by_key(|pair| pair.0);

    // Group 1: 10 + 20 = 30; group 2: 30 + 40 + 50 = 120
    assert_eq!(totals, vec![(1, 30), (2, 120)]);
}
1184
/// GROUP BY column 0 with COUNT(*) reports the number of rows in each group.
#[test]
fn test_grouped_count() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let mut op = HashAggregateOperator::new(
        Box::new(input),
        vec![0],
        vec![AggregateExpr::count_star()],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let mut counts: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = op.next().unwrap() {
        for row in chunk.selected_indices() {
            let group = chunk.column(0).unwrap().get_int64(row).unwrap();
            let count = chunk.column(1).unwrap().get_int64(row).unwrap();
            counts.push((group, count));
        }
    }
    counts.sort_by_key(|pair| pair.0);

    // Group 1 has 2 rows, group 2 has 3 rows
    assert_eq!(counts, vec![(1, 2), (2, 3)]);
}
1211
/// GROUP BY column 0 with COUNT(*), SUM(column 1) and AVG(column 1) side by side.
#[test]
fn test_multiple_aggregates() {
    let input = MockOperator::new(vec![create_test_chunk()]);

    let aggregates = vec![
        AggregateExpr::count_star(),
        AggregateExpr::sum(1),
        AggregateExpr::avg(1),
    ];
    let schema = vec![
        LogicalType::Int64,   // group key
        LogicalType::Int64,   // COUNT
        LogicalType::Int64,   // SUM
        LogicalType::Float64, // AVG
    ];
    let mut op = HashAggregateOperator::new(Box::new(input), vec![0], aggregates, schema);

    let mut rows: Vec<(i64, i64, i64, f64)> = Vec::new();
    while let Some(chunk) = op.next().unwrap() {
        for row in chunk.selected_indices() {
            rows.push((
                chunk.column(0).unwrap().get_int64(row).unwrap(),
                chunk.column(1).unwrap().get_int64(row).unwrap(),
                chunk.column(2).unwrap().get_int64(row).unwrap(),
                chunk.column(3).unwrap().get_float64(row).unwrap(),
            ));
        }
    }
    rows.sort_by_key(|entry| entry.0);
    assert_eq!(rows.len(), 2);

    // Group 1: COUNT=2, SUM=30, AVG=15.0
    let (group, count, sum, avg) = rows[0];
    assert_eq!(group, 1);
    assert_eq!(count, 2);
    assert_eq!(sum, 30);
    assert!((avg - 15.0).abs() < 0.001);

    // Group 2: COUNT=3, SUM=120, AVG=40.0
    let (group, count, sum, avg) = rows[1];
    assert_eq!(group, 2);
    assert_eq!(count, 3);
    assert_eq!(sum, 120);
    assert!((avg - 40.0).abs() < 0.001);
}
1259
1260    fn create_test_chunk_with_duplicates() -> DataChunk {
1261        // Create data with duplicate values in column 1
1262        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
1263        // GROUP 1: values [10, 10, 20] -> distinct count = 2
1264        // GROUP 2: values [30, 30, 30] -> distinct count = 1
1265        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1266
1267        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1268        for (group, value) in data {
1269            builder.column_mut(0).unwrap().push_int64(group);
1270            builder.column_mut(1).unwrap().push_int64(value);
1271            builder.advance_row();
1272        }
1273
1274        builder.finish()
1275    }
1276
/// COUNT(DISTINCT column 1) over the duplicate fixture counts 10, 20 and 30 once each.
#[test]
fn test_count_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_count = AggregateExpr::count(1).with_distinct();
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![distinct_count],
        vec![LogicalType::Int64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // Distinct values are 10, 20, 30
    let count = out.column(0).unwrap().get_int64(0);
    assert_eq!(count, Some(3));
}
1293
/// GROUP BY column 0 with COUNT(DISTINCT column 1) counts distinct values per group.
#[test]
fn test_grouped_count_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut op = HashAggregateOperator::new(
        Box::new(input),
        vec![0],
        vec![AggregateExpr::count(1).with_distinct()],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let mut counts: Vec<(i64, i64)> = Vec::new();
    while let Some(chunk) = op.next().unwrap() {
        for row in chunk.selected_indices() {
            let group = chunk.column(0).unwrap().get_int64(row).unwrap();
            let count = chunk.column(1).unwrap().get_int64(row).unwrap();
            counts.push((group, count));
        }
    }
    counts.sort_by_key(|pair| pair.0);

    // Group 1: [10, 10, 20] -> 2 distinct; group 2: [30, 30, 30] -> 1 distinct
    assert_eq!(counts, vec![(1, 2), (2, 1)]);
}
1320
/// SUM(DISTINCT column 1) over the duplicate fixture adds each value once.
#[test]
fn test_sum_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_sum = AggregateExpr::sum(1).with_distinct();
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![distinct_sum],
        vec![LogicalType::Int64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // 10 + 20 + 30 = 60
    let total = out.column(0).unwrap().get_int64(0);
    assert_eq!(total, Some(60));
}
1337
/// AVG(DISTINCT column 1) over the duplicate fixture averages each value once.
#[test]
fn test_avg_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let distinct_avg = AggregateExpr::avg(1).with_distinct();
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![distinct_avg],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // (10 + 20 + 30) / 3 = 20.0
    let mean = out.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (mean - 20.0).abs();
    assert!(deviation < 0.001);
}
1355
1356    fn create_statistical_test_chunk() -> DataChunk {
1357        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1358        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1359        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1360
1361        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1362            builder.column_mut(0).unwrap().push_int64(value);
1363            builder.advance_row();
1364        }
1365
1366        builder.finish()
1367    }
1368
/// Sample standard deviation of the statistical fixture is about 2.138.
#[test]
fn test_stdev_sample() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // Data [2, 4, 4, 4, 5, 5, 7, 9]: mean 5.0, variance 32/7 = 4.571, stdev 2.138
    let stdev = out.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (stdev - 2.138).abs();
    assert!(deviation < 0.01);
}
1386
/// Population standard deviation of the statistical fixture is 2.0.
#[test]
fn test_stdev_population() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::stdev_pop(0)],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // Data [2, 4, 4, 4, 5, 5, 7, 9]: mean 5.0, variance 32/8 = 4.0, stdev 2.0
    let stdev = out.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (stdev - 2.0).abs();
    assert!(deviation < 0.01);
}
1404
/// Discrete 50th percentile (median) of the statistical fixture is 4.
#[test]
fn test_percentile_disc() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let median = AggregateExpr::percentile_disc(0, 0.5);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![median],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // Sorted data is [2, 4, 4, 4, 5, 5, 7, 9]; the expected discrete median is 4
    let value = out.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (value - 4.0).abs();
    assert!(deviation < 0.01);
}
1422
/// Continuous 50th percentile (median) of the statistical fixture is 4.5.
#[test]
fn test_percentile_cont() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);
    let median = AggregateExpr::percentile_cont(0, 0.5);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![median],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // Sorted data is [2, 4, 4, 4, 5, 5, 7, 9]; the expected value lies halfway
    // between the 4 at index 3 and the 5 at index 4: 4 + 0.5 * (5 - 4) = 4.5
    let value = out.column(0).unwrap().get_float64(0).unwrap();
    let deviation = (value - 4.5).abs();
    assert!(deviation < 0.01);
}
1441
/// The 0th and 100th discrete percentiles equal the minimum and maximum of the data.
#[test]
fn test_percentile_extremes() {
    let input = MockOperator::new(vec![create_statistical_test_chunk()]);

    let bounds = vec![
        AggregateExpr::percentile_disc(0, 0.0),
        AggregateExpr::percentile_disc(0, 1.0),
    ];
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        bounds,
        vec![LogicalType::Float64, LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // 0th percentile is the minimum, 2
    let lowest = out.column(0).unwrap().get_float64(0).unwrap();
    assert!((lowest - 2.0).abs() < 0.01);

    // 100th percentile is the maximum, 9
    let highest = out.column(1).unwrap().get_float64(0).unwrap();
    assert!((highest - 9.0).abs() < 0.01);
}
1465
/// Sample standard deviation over a single value is reported as null.
#[test]
fn test_stdev_single_value() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
    builder.column_mut(0).unwrap().push_int64(42);
    builder.advance_row();

    let input = MockOperator::new(vec![builder.finish()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::stdev(0)],
        vec![LogicalType::Float64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    // One data point gives no sample deviation, so the result must be null
    let stdev = out.column(0).unwrap().get_value(0);
    assert!(matches!(stdev, Some(Value::Null)));
}
1490
/// FIRST and LAST over column 1 of the fixture return 10 and 50.
#[test]
fn test_first_and_last() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::first(1), AggregateExpr::last(1)],
        vec![LogicalType::Int64, LogicalType::Int64],
    );

    let out = op.next().unwrap().unwrap();
    assert_eq!(out.row_count(), 1);

    let first = out.column(0).unwrap().get_int64(0);
    let last = out.column(1).unwrap().get_int64(0);
    assert_eq!(first, Some(10));
    assert_eq!(last, Some(50));
}
1507
/// COLLECT over column 1 of the fixture gathers all five values into a list.
#[test]
fn test_collect() {
    let input = MockOperator::new(vec![create_test_chunk()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::collect(1)],
        vec![LogicalType::Any],
    );

    let out = op.next().unwrap().unwrap();
    let collected = out.column(0).unwrap().get_value(0).unwrap();

    let Value::List(items) = collected else {
        panic!("Expected List value");
    };
    assert_eq!(items.len(), 5);
}
1526
/// COLLECT(DISTINCT column 1) over the duplicate fixture keeps each value once.
#[test]
fn test_collect_distinct() {
    let input = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
    let mut op = SimpleAggregateOperator::new(
        Box::new(input),
        vec![AggregateExpr::collect(1).with_distinct()],
        vec![LogicalType::Any],
    );

    let out = op.next().unwrap().unwrap();
    let collected = out.column(0).unwrap().get_value(0).unwrap();

    let Value::List(items) = collected else {
        panic!("Expected List value");
    };
    // [10, 10, 20, 30, 30, 30] reduces to the distinct values [10, 20, 30]
    assert_eq!(items.len(), 3);
}
1546
/// GROUP_CONCAT without an explicit separator joins the strings with single spaces.
#[test]
fn test_group_concat() {
    let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
    for word in ["hello", "world", "foo"] {
        builder.column_mut(0).unwrap().push_string(word);
        builder.advance_row();
    }
    let input = MockOperator::new(vec![builder.finish()]);

    let concat = AggregateExpr {
        function: AggregateFunction::GroupConcat,
        column: Some(0),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut op =
        SimpleAggregateOperator::new(Box::new(input), vec![concat], vec![LogicalType::String]);

    let out = op.next().unwrap().unwrap();
    let joined = out.column(0).unwrap().get_value(0).unwrap();
    assert_eq!(joined, Value::String("hello world foo".into()));
}
1574
/// SAMPLE over column 1 of the fixture returns the first value, 10.
#[test]
fn test_sample() {
    let input = MockOperator::new(vec![create_test_chunk()]);

    let sample = AggregateExpr {
        function: AggregateFunction::Sample,
        column: Some(1),
        column2: None,
        distinct: false,
        alias: None,
        percentile: None,
        separator: None,
    };
    let mut op =
        SimpleAggregateOperator::new(Box::new(input), vec![sample], vec![LogicalType::Int64]);

    let out = op.next().unwrap().unwrap();

    // The first non-null value of the column is 10
    let picked = out.column(0).unwrap().get_int64(0);
    assert_eq!(picked, Some(10));
}
1596
1597    #[test]
1598    fn test_variance_sample() {
1599        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1600
1601        let agg_expr = AggregateExpr {
1602            function: AggregateFunction::Variance,
1603            column: Some(0),
1604            column2: None,
1605            distinct: false,
1606            alias: None,
1607            percentile: None,
1608            separator: None,
1609        };
1610
1611        let mut agg = SimpleAggregateOperator::new(
1612            Box::new(mock),
1613            vec![agg_expr],
1614            vec![LogicalType::Float64],
1615        );
1616
1617        let result = agg.next().unwrap().unwrap();
1618        // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: M2/(n-1) = 32/7 = 4.571
1619        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1620        assert!((variance - 32.0 / 7.0).abs() < 0.01);
1621    }
1622
1623    #[test]
1624    fn test_variance_population() {
1625        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1626
1627        let agg_expr = AggregateExpr {
1628            function: AggregateFunction::VariancePop,
1629            column: Some(0),
1630            column2: None,
1631            distinct: false,
1632            alias: None,
1633            percentile: None,
1634            separator: None,
1635        };
1636
1637        let mut agg = SimpleAggregateOperator::new(
1638            Box::new(mock),
1639            vec![agg_expr],
1640            vec![LogicalType::Float64],
1641        );
1642
1643        let result = agg.next().unwrap().unwrap();
1644        // Population variance: M2/n = 32/8 = 4.0
1645        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1646        assert!((variance - 4.0).abs() < 0.01);
1647    }
1648
1649    #[test]
1650    fn test_variance_single_value() {
1651        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1652        builder.column_mut(0).unwrap().push_int64(42);
1653        builder.advance_row();
1654        let chunk = builder.finish();
1655        let mock = MockOperator::new(vec![chunk]);
1656
1657        let agg_expr = AggregateExpr {
1658            function: AggregateFunction::Variance,
1659            column: Some(0),
1660            column2: None,
1661            distinct: false,
1662            alias: None,
1663            percentile: None,
1664            separator: None,
1665        };
1666
1667        let mut agg = SimpleAggregateOperator::new(
1668            Box::new(mock),
1669            vec![agg_expr],
1670            vec![LogicalType::Float64],
1671        );
1672
1673        let result = agg.next().unwrap().unwrap();
1674        // Sample variance of single value is undefined (null)
1675        assert!(matches!(
1676            result.column(0).unwrap().get_value(0),
1677            Some(Value::Null)
1678        ));
1679    }
1680
1681    #[test]
1682    fn test_empty_aggregation() {
1683        // No input rows: COUNT should be 0, SUM/AVG/MIN/MAX should be NULL
1684        // (ISO/IEC 39075 Section 20.9)
1685        let mock = MockOperator::new(vec![]);
1686
1687        let mut agg = SimpleAggregateOperator::new(
1688            Box::new(mock),
1689            vec![
1690                AggregateExpr::count_star(),
1691                AggregateExpr::sum(0),
1692                AggregateExpr::avg(0),
1693                AggregateExpr::min(0),
1694                AggregateExpr::max(0),
1695            ],
1696            vec![
1697                LogicalType::Int64,
1698                LogicalType::Int64,
1699                LogicalType::Float64,
1700                LogicalType::Int64,
1701                LogicalType::Int64,
1702            ],
1703        );
1704
1705        let result = agg.next().unwrap().unwrap();
1706        assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); // COUNT
1707        assert!(matches!(
1708            result.column(1).unwrap().get_value(0),
1709            Some(Value::Null)
1710        )); // SUM
1711        assert!(matches!(
1712            result.column(2).unwrap().get_value(0),
1713            Some(Value::Null)
1714        )); // AVG
1715        assert!(matches!(
1716            result.column(3).unwrap().get_value(0),
1717            Some(Value::Null)
1718        )); // MIN
1719        assert!(matches!(
1720            result.column(4).unwrap().get_value(0),
1721            Some(Value::Null)
1722        )); // MAX
1723    }
1724
1725    #[test]
1726    fn test_stdev_pop_single_value() {
1727        // Single value should return 0 for population stdev
1728        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1729        builder.column_mut(0).unwrap().push_int64(42);
1730        builder.advance_row();
1731        let chunk = builder.finish();
1732
1733        let mock = MockOperator::new(vec![chunk]);
1734
1735        let mut agg = SimpleAggregateOperator::new(
1736            Box::new(mock),
1737            vec![AggregateExpr::stdev_pop(0)],
1738            vec![LogicalType::Float64],
1739        );
1740
1741        let result = agg.next().unwrap().unwrap();
1742        assert_eq!(result.row_count(), 1);
1743        // Population stdev of single value is 0
1744        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1745        assert!((stdev - 0.0).abs() < 0.01);
1746    }
1747}