Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12
13use grafeo_common::types::{LogicalType, Value};
14
15use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
16use super::{Operator, OperatorError, OperatorResult};
17use crate::execution::DataChunk;
18use crate::execution::chunk::DataChunkBuilder;
19
20/// State for a single aggregation computation.
21#[derive(Debug, Clone)]
22pub(crate) enum AggregateState {
23    /// Count state.
24    Count(i64),
25    /// Count distinct state (count, seen values).
26    CountDistinct(i64, HashSet<HashableValue>),
27    /// Sum state (integer).
28    SumInt(i64),
29    /// Sum distinct state (integer, seen values).
30    SumIntDistinct(i64, HashSet<HashableValue>),
31    /// Sum state (float).
32    SumFloat(f64),
33    /// Sum distinct state (float, seen values).
34    SumFloatDistinct(f64, HashSet<HashableValue>),
35    /// Average state (sum, count).
36    Avg(f64, i64),
37    /// Average distinct state (sum, count, seen values).
38    AvgDistinct(f64, i64, HashSet<HashableValue>),
39    /// Min state.
40    Min(Option<Value>),
41    /// Max state.
42    Max(Option<Value>),
43    /// First state.
44    First(Option<Value>),
45    /// Last state.
46    Last(Option<Value>),
47    /// Collect state.
48    Collect(Vec<Value>),
49    /// Collect distinct state (values, seen).
50    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
51    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
52    StdDev { count: i64, mean: f64, m2: f64 },
53    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
54    StdDevPop { count: i64, mean: f64, m2: f64 },
55    /// Discrete percentile state (values, percentile).
56    PercentileDisc { values: Vec<f64>, percentile: f64 },
57    /// Continuous percentile state (values, percentile).
58    PercentileCont { values: Vec<f64>, percentile: f64 },
59    /// GROUP_CONCAT / LISTAGG state (collected string values, separator).
60    GroupConcat(Vec<String>, String),
61    /// GROUP_CONCAT / LISTAGG distinct state (collected string values, separator, seen).
62    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
63    /// SAMPLE state (first non-null value encountered).
64    Sample(Option<Value>),
65    /// Sample variance state using Welford's algorithm (count, mean, M2).
66    Variance { count: i64, mean: f64, m2: f64 },
67    /// Population variance state using Welford's algorithm (count, mean, M2).
68    VariancePop { count: i64, mean: f64, m2: f64 },
69    /// Two-variable online statistics (Welford generalization for covariance/regression).
70    Bivariate {
71        /// Which binary set function this state will finalize to.
72        kind: AggregateFunction,
73        count: i64,
74        mean_x: f64,
75        mean_y: f64,
76        m2_x: f64,
77        m2_y: f64,
78        c_xy: f64,
79    },
80}
81
82impl AggregateState {
83    /// Creates initial state for an aggregation function.
84    pub(crate) fn new(
85        function: AggregateFunction,
86        distinct: bool,
87        percentile: Option<f64>,
88        separator: Option<&str>,
89    ) -> Self {
90        match (function, distinct) {
91            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
92                AggregateState::Count(0)
93            }
94            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
95                AggregateState::CountDistinct(0, HashSet::new())
96            }
97            (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
98            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
99            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
100            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
101            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
102            (AggregateFunction::Max, _) => AggregateState::Max(None),
103            (AggregateFunction::First, _) => AggregateState::First(None),
104            (AggregateFunction::Last, _) => AggregateState::Last(None),
105            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
106            (AggregateFunction::Collect, true) => {
107                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
108            }
109            // Statistical functions (Welford's algorithm for online computation)
110            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
111                count: 0,
112                mean: 0.0,
113                m2: 0.0,
114            },
115            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
116                count: 0,
117                mean: 0.0,
118                m2: 0.0,
119            },
120            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
121                values: Vec::new(),
122                percentile: percentile.unwrap_or(0.5),
123            },
124            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
125                values: Vec::new(),
126                percentile: percentile.unwrap_or(0.5),
127            },
128            (AggregateFunction::GroupConcat, false) => {
129                AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
130            }
131            (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
132                Vec::new(),
133                separator.unwrap_or(" ").to_string(),
134                HashSet::new(),
135            ),
136            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
137            // Binary set functions (all share the same Bivariate state)
138            (
139                AggregateFunction::CovarSamp
140                | AggregateFunction::CovarPop
141                | AggregateFunction::Corr
142                | AggregateFunction::RegrSlope
143                | AggregateFunction::RegrIntercept
144                | AggregateFunction::RegrR2
145                | AggregateFunction::RegrCount
146                | AggregateFunction::RegrSxx
147                | AggregateFunction::RegrSyy
148                | AggregateFunction::RegrSxy
149                | AggregateFunction::RegrAvgx
150                | AggregateFunction::RegrAvgy,
151                _,
152            ) => AggregateState::Bivariate {
153                kind: function,
154                count: 0,
155                mean_x: 0.0,
156                mean_y: 0.0,
157                m2_x: 0.0,
158                m2_y: 0.0,
159                c_xy: 0.0,
160            },
161            (AggregateFunction::Variance, _) => AggregateState::Variance {
162                count: 0,
163                mean: 0.0,
164                m2: 0.0,
165            },
166            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
167                count: 0,
168                mean: 0.0,
169                m2: 0.0,
170            },
171        }
172    }
173
174    /// Updates the state with a new value.
175    pub(crate) fn update(&mut self, value: Option<Value>) {
176        match self {
177            AggregateState::Count(count) => {
178                *count += 1;
179            }
180            AggregateState::CountDistinct(count, seen) => {
181                if let Some(ref v) = value {
182                    let hashable = HashableValue::from(v);
183                    if seen.insert(hashable) {
184                        *count += 1;
185                    }
186                }
187            }
188            AggregateState::SumInt(sum) => {
189                if let Some(Value::Int64(v)) = value {
190                    *sum += v;
191                } else if let Some(Value::Float64(v)) = value {
192                    // Convert to float sum
193                    *self = AggregateState::SumFloat(*sum as f64 + v);
194                } else if let Some(ref v) = value {
195                    // RDF stores numeric literals as strings - try to parse
196                    if let Some(num) = value_to_f64(v) {
197                        *self = AggregateState::SumFloat(*sum as f64 + num);
198                    }
199                }
200            }
201            AggregateState::SumIntDistinct(sum, seen) => {
202                if let Some(ref v) = value {
203                    let hashable = HashableValue::from(v);
204                    if seen.insert(hashable) {
205                        if let Value::Int64(i) = v {
206                            *sum += i;
207                        } else if let Value::Float64(f) = v {
208                            // Convert to float distinct: move the seen set instead of cloning
209                            let moved_seen = std::mem::take(seen);
210                            *self = AggregateState::SumFloatDistinct(*sum as f64 + f, moved_seen);
211                        } else if let Some(num) = value_to_f64(v) {
212                            // RDF string-encoded numerics
213                            let moved_seen = std::mem::take(seen);
214                            *self = AggregateState::SumFloatDistinct(*sum as f64 + num, moved_seen);
215                        }
216                    }
217                }
218            }
219            AggregateState::SumFloat(sum) => {
220                if let Some(ref v) = value {
221                    // Use value_to_f64 which now handles strings
222                    if let Some(num) = value_to_f64(v) {
223                        *sum += num;
224                    }
225                }
226            }
227            AggregateState::SumFloatDistinct(sum, seen) => {
228                if let Some(ref v) = value {
229                    let hashable = HashableValue::from(v);
230                    if seen.insert(hashable)
231                        && let Some(num) = value_to_f64(v)
232                    {
233                        *sum += num;
234                    }
235                }
236            }
237            AggregateState::Avg(sum, count) => {
238                if let Some(ref v) = value
239                    && let Some(num) = value_to_f64(v)
240                {
241                    *sum += num;
242                    *count += 1;
243                }
244            }
245            AggregateState::AvgDistinct(sum, count, seen) => {
246                if let Some(ref v) = value {
247                    let hashable = HashableValue::from(v);
248                    if seen.insert(hashable)
249                        && let Some(num) = value_to_f64(v)
250                    {
251                        *sum += num;
252                        *count += 1;
253                    }
254                }
255            }
256            AggregateState::Min(min) => {
257                if let Some(v) = value {
258                    match min {
259                        None => *min = Some(v),
260                        Some(current) => {
261                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
262                                *min = Some(v);
263                            }
264                        }
265                    }
266                }
267            }
268            AggregateState::Max(max) => {
269                if let Some(v) = value {
270                    match max {
271                        None => *max = Some(v),
272                        Some(current) => {
273                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
274                                *max = Some(v);
275                            }
276                        }
277                    }
278                }
279            }
280            AggregateState::First(first) => {
281                if first.is_none() {
282                    *first = value;
283                }
284            }
285            AggregateState::Last(last) => {
286                if value.is_some() {
287                    *last = value;
288                }
289            }
290            AggregateState::Collect(list) => {
291                if let Some(v) = value {
292                    list.push(v);
293                }
294            }
295            AggregateState::CollectDistinct(list, seen) => {
296                if let Some(v) = value {
297                    let hashable = HashableValue::from(&v);
298                    if seen.insert(hashable) {
299                        list.push(v);
300                    }
301                }
302            }
303            // Statistical functions using Welford's online algorithm
304            AggregateState::StdDev { count, mean, m2 }
305            | AggregateState::StdDevPop { count, mean, m2 }
306            | AggregateState::Variance { count, mean, m2 }
307            | AggregateState::VariancePop { count, mean, m2 } => {
308                if let Some(ref v) = value
309                    && let Some(x) = value_to_f64(v)
310                {
311                    *count += 1;
312                    let delta = x - *mean;
313                    *mean += delta / *count as f64;
314                    let delta2 = x - *mean;
315                    *m2 += delta * delta2;
316                }
317            }
318            AggregateState::PercentileDisc { values, .. }
319            | AggregateState::PercentileCont { values, .. } => {
320                if let Some(ref v) = value
321                    && let Some(x) = value_to_f64(v)
322                {
323                    values.push(x);
324                }
325            }
326            AggregateState::GroupConcat(list, _sep) => {
327                if let Some(v) = value {
328                    list.push(agg_value_to_string(&v));
329                }
330            }
331            AggregateState::GroupConcatDistinct(list, _sep, seen) => {
332                if let Some(v) = value {
333                    let hashable = HashableValue::from(&v);
334                    if seen.insert(hashable) {
335                        list.push(agg_value_to_string(&v));
336                    }
337                }
338            }
339            AggregateState::Sample(sample) => {
340                if sample.is_none() {
341                    *sample = value;
342                }
343            }
344            AggregateState::Bivariate { .. } => {
345                // Bivariate functions require two values; use update_bivariate() instead.
346                // Single-value update is a no-op for bivariate state.
347            }
348        }
349    }
350
351    /// Updates a bivariate (two-variable) aggregate state with a pair of values.
352    ///
353    /// Uses the two-variable Welford online algorithm for numerically stable computation
354    /// of covariance and related statistics. Skips the update if either value is null.
355    fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
356        if let AggregateState::Bivariate {
357            count,
358            mean_x,
359            mean_y,
360            m2_x,
361            m2_y,
362            c_xy,
363            ..
364        } = self
365        {
366            // Skip if either value is null (SQL semantics: exclude non-pairs)
367            if let (Some(y), Some(x)) = (&y_val, &x_val)
368                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
369            {
370                *count += 1;
371                let n = *count as f64;
372                let dx = x_f - *mean_x;
373                let dy = y_f - *mean_y;
374                *mean_x += dx / n;
375                *mean_y += dy / n;
376                let dx2 = x_f - *mean_x; // post-update delta
377                let dy2 = y_f - *mean_y; // post-update delta
378                *m2_x += dx * dx2;
379                *m2_y += dy * dy2;
380                *c_xy += dx * dy2;
381            }
382        }
383    }
384
385    /// Finalizes the state and returns the result value.
386    pub(crate) fn finalize(&self) -> Value {
387        match self {
388            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
389                Value::Int64(*count)
390            }
391            AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
392                Value::Int64(*sum)
393            }
394            AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
395                Value::Float64(*sum)
396            }
397            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
398                if *count == 0 {
399                    Value::Null
400                } else {
401                    Value::Float64(*sum / *count as f64)
402                }
403            }
404            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
405            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
406            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
407            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
408            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
409                Value::List(list.clone().into())
410            }
411            // Sample standard deviation: sqrt(M2 / (n - 1))
412            AggregateState::StdDev { count, m2, .. } => {
413                if *count < 2 {
414                    Value::Null
415                } else {
416                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
417                }
418            }
419            // Population standard deviation: sqrt(M2 / n)
420            AggregateState::StdDevPop { count, m2, .. } => {
421                if *count == 0 {
422                    Value::Null
423                } else {
424                    Value::Float64((*m2 / *count as f64).sqrt())
425                }
426            }
427            // Sample variance: M2 / (n - 1)
428            AggregateState::Variance { count, m2, .. } => {
429                if *count < 2 {
430                    Value::Null
431                } else {
432                    Value::Float64(*m2 / (*count - 1) as f64)
433                }
434            }
435            // Population variance: M2 / n
436            AggregateState::VariancePop { count, m2, .. } => {
437                if *count == 0 {
438                    Value::Null
439                } else {
440                    Value::Float64(*m2 / *count as f64)
441                }
442            }
443            // Discrete percentile: return actual value at percentile position
444            AggregateState::PercentileDisc { values, percentile } => {
445                if values.is_empty() {
446                    Value::Null
447                } else {
448                    let mut sorted = values.clone();
449                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
450                    // Index calculation per SQL standard: floor(p * (n - 1))
451                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
452                    Value::Float64(sorted[index])
453                }
454            }
455            // Continuous percentile: interpolate between values
456            AggregateState::PercentileCont { values, percentile } => {
457                if values.is_empty() {
458                    Value::Null
459                } else {
460                    let mut sorted = values.clone();
461                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
462                    // Linear interpolation per SQL standard
463                    let rank = percentile * (sorted.len() - 1) as f64;
464                    let lower_idx = rank.floor() as usize;
465                    let upper_idx = rank.ceil() as usize;
466                    if lower_idx == upper_idx {
467                        Value::Float64(sorted[lower_idx])
468                    } else {
469                        let fraction = rank - lower_idx as f64;
470                        let result =
471                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
472                        Value::Float64(result)
473                    }
474                }
475            }
476            // GROUP_CONCAT: join strings with space separator (SPARQL default)
477            AggregateState::GroupConcat(list, sep)
478            | AggregateState::GroupConcatDistinct(list, sep, _) => {
479                Value::String(list.join(sep).into())
480            }
481            // SAMPLE: return the first non-null value seen
482            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
483            // Binary set functions: dispatch on kind
484            AggregateState::Bivariate {
485                kind,
486                count,
487                mean_x,
488                mean_y,
489                m2_x,
490                m2_y,
491                c_xy,
492            } => {
493                let n = *count;
494                match kind {
495                    AggregateFunction::CovarSamp => {
496                        if n < 2 {
497                            Value::Null
498                        } else {
499                            Value::Float64(*c_xy / (n - 1) as f64)
500                        }
501                    }
502                    AggregateFunction::CovarPop => {
503                        if n == 0 {
504                            Value::Null
505                        } else {
506                            Value::Float64(*c_xy / n as f64)
507                        }
508                    }
509                    AggregateFunction::Corr => {
510                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
511                            Value::Null
512                        } else {
513                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
514                        }
515                    }
516                    AggregateFunction::RegrSlope => {
517                        if n == 0 || *m2_x == 0.0 {
518                            Value::Null
519                        } else {
520                            Value::Float64(*c_xy / *m2_x)
521                        }
522                    }
523                    AggregateFunction::RegrIntercept => {
524                        if n == 0 || *m2_x == 0.0 {
525                            Value::Null
526                        } else {
527                            let slope = *c_xy / *m2_x;
528                            Value::Float64(*mean_y - slope * *mean_x)
529                        }
530                    }
531                    AggregateFunction::RegrR2 => {
532                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
533                            Value::Null
534                        } else {
535                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
536                        }
537                    }
538                    AggregateFunction::RegrCount => Value::Int64(n),
539                    AggregateFunction::RegrSxx => {
540                        if n == 0 {
541                            Value::Null
542                        } else {
543                            Value::Float64(*m2_x)
544                        }
545                    }
546                    AggregateFunction::RegrSyy => {
547                        if n == 0 {
548                            Value::Null
549                        } else {
550                            Value::Float64(*m2_y)
551                        }
552                    }
553                    AggregateFunction::RegrSxy => {
554                        if n == 0 {
555                            Value::Null
556                        } else {
557                            Value::Float64(*c_xy)
558                        }
559                    }
560                    AggregateFunction::RegrAvgx => {
561                        if n == 0 {
562                            Value::Null
563                        } else {
564                            Value::Float64(*mean_x)
565                        }
566                    }
567                    AggregateFunction::RegrAvgy => {
568                        if n == 0 {
569                            Value::Null
570                        } else {
571                            Value::Float64(*mean_y)
572                        }
573                    }
574                    _ => Value::Null, // non-bivariate functions never reach here
575                }
576            }
577        }
578    }
579}
580
581use super::value_utils::{compare_values, value_to_f64};
582
583/// Converts a Value to its string representation for GROUP_CONCAT.
584fn agg_value_to_string(val: &Value) -> String {
585    match val {
586        Value::String(s) => s.to_string(),
587        Value::Int64(i) => i.to_string(),
588        Value::Float64(f) => f.to_string(),
589        Value::Bool(b) => b.to_string(),
590        Value::Null => String::new(),
591        other => format!("{other:?}"),
592    }
593}
594
595/// A group key for hash-based aggregation.
596#[derive(Debug, Clone, PartialEq, Eq, Hash)]
597pub struct GroupKey(Vec<GroupKeyPart>);
598
599#[derive(Debug, Clone, PartialEq, Eq, Hash)]
600enum GroupKeyPart {
601    Null,
602    Bool(bool),
603    Int64(i64),
604    String(String),
605}
606
607impl GroupKey {
608    /// Creates a group key from column values.
609    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
610        let parts: Vec<GroupKeyPart> = group_columns
611            .iter()
612            .map(|&col_idx| {
613                chunk
614                    .column(col_idx)
615                    .and_then(|col| col.get_value(row))
616                    .map_or(GroupKeyPart::Null, |v| match v {
617                        Value::Null => GroupKeyPart::Null,
618                        Value::Bool(b) => GroupKeyPart::Bool(b),
619                        Value::Int64(i) => GroupKeyPart::Int64(i),
620                        Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
621                        Value::String(s) => GroupKeyPart::String(s.to_string()),
622                        _ => GroupKeyPart::String(format!("{v:?}")),
623                    })
624            })
625            .collect();
626        GroupKey(parts)
627    }
628
629    /// Converts the group key back to values.
630    fn to_values(&self) -> Vec<Value> {
631        self.0
632            .iter()
633            .map(|part| match part {
634                GroupKeyPart::Null => Value::Null,
635                GroupKeyPart::Bool(b) => Value::Bool(*b),
636                GroupKeyPart::Int64(i) => Value::Int64(*i),
637                GroupKeyPart::String(s) => Value::String(s.clone().into()),
638            })
639            .collect()
640    }
641}
642
643/// Hash-based aggregate operator.
644///
645/// Groups input by key columns and computes aggregations for each group.
646pub struct HashAggregateOperator {
647    /// Child operator to read from.
648    child: Box<dyn Operator>,
649    /// Columns to group by.
650    group_columns: Vec<usize>,
651    /// Aggregation expressions.
652    aggregates: Vec<AggregateExpr>,
653    /// Output schema.
654    output_schema: Vec<LogicalType>,
655    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
656    groups: IndexMap<GroupKey, Vec<AggregateState>>,
657    /// Whether aggregation is complete.
658    aggregation_complete: bool,
659    /// Results iterator.
660    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
661}
662
663impl HashAggregateOperator {
664    /// Creates a new hash aggregate operator.
665    ///
666    /// # Arguments
667    /// * `child` - Child operator to read from.
668    /// * `group_columns` - Column indices to group by.
669    /// * `aggregates` - Aggregation expressions.
670    /// * `output_schema` - Schema of the output (group columns + aggregate results).
671    pub fn new(
672        child: Box<dyn Operator>,
673        group_columns: Vec<usize>,
674        aggregates: Vec<AggregateExpr>,
675        output_schema: Vec<LogicalType>,
676    ) -> Self {
677        Self {
678            child,
679            group_columns,
680            aggregates,
681            output_schema,
682            groups: IndexMap::new(),
683            aggregation_complete: false,
684            results: None,
685        }
686    }
687
688    /// Performs the aggregation.
689    fn aggregate(&mut self) -> Result<(), OperatorError> {
690        while let Some(chunk) = self.child.next()? {
691            for row in chunk.selected_indices() {
692                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
693
694                // Get or create aggregate states for this group
695                let states = self.groups.entry(key).or_insert_with(|| {
696                    self.aggregates
697                        .iter()
698                        .map(|agg| {
699                            AggregateState::new(
700                                agg.function,
701                                agg.distinct,
702                                agg.percentile,
703                                agg.separator.as_deref(),
704                            )
705                        })
706                        .collect()
707                });
708
709                // Update each aggregate
710                for (i, agg) in self.aggregates.iter().enumerate() {
711                    // Binary set functions: read two column values
712                    if agg.column2.is_some() {
713                        let y_val = agg
714                            .column
715                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
716                        let x_val = agg
717                            .column2
718                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
719                        states[i].update_bivariate(y_val, x_val);
720                        continue;
721                    }
722
723                    let value = match (agg.function, agg.distinct) {
724                        // COUNT(*) without DISTINCT doesn't need a value
725                        (AggregateFunction::Count, false) => None,
726                        // COUNT DISTINCT needs the actual value to track unique values
727                        (AggregateFunction::Count, true) => agg
728                            .column
729                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
730                        _ => agg
731                            .column
732                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
733                    };
734
735                    // For COUNT without DISTINCT, always update. For others, skip nulls.
736                    match (agg.function, agg.distinct) {
737                        (AggregateFunction::Count, false) => states[i].update(None),
738                        (AggregateFunction::Count, true) => {
739                            // COUNT DISTINCT needs the value to track unique values
740                            if value.is_some() && !matches!(value, Some(Value::Null)) {
741                                states[i].update(value);
742                            }
743                        }
744                        (AggregateFunction::CountNonNull, _) => {
745                            if value.is_some() && !matches!(value, Some(Value::Null)) {
746                                states[i].update(value);
747                            }
748                        }
749                        _ => {
750                            if value.is_some() && !matches!(value, Some(Value::Null)) {
751                                states[i].update(value);
752                            }
753                        }
754                    }
755                }
756            }
757        }
758
759        self.aggregation_complete = true;
760
761        // Convert to results iterator (IndexMap::drain takes a range)
762        let results: Vec<_> = self.groups.drain(..).collect();
763        self.results = Some(results.into_iter());
764
765        Ok(())
766    }
767}
768
769impl Operator for HashAggregateOperator {
770    fn next(&mut self) -> OperatorResult {
771        // Perform aggregation if not done
772        if !self.aggregation_complete {
773            self.aggregate()?;
774        }
775
776        // Special case: no groups (global aggregation with no data)
777        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
778            // For global aggregation (no GROUP BY), return one row with initial values
779            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
780
781            for agg in &self.aggregates {
782                let state = AggregateState::new(
783                    agg.function,
784                    agg.distinct,
785                    agg.percentile,
786                    agg.separator.as_deref(),
787                );
788                let value = state.finalize();
789                if let Some(col) = builder.column_mut(self.group_columns.len()) {
790                    col.push_value(value);
791                }
792            }
793            builder.advance_row();
794
795            self.results = Some(Vec::new().into_iter()); // Mark as done
796            return Ok(Some(builder.finish()));
797        }
798
799        let Some(results) = &mut self.results else {
800            return Ok(None);
801        };
802
803        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
804
805        for (key, states) in results.by_ref() {
806            // Output group key columns
807            let key_values = key.to_values();
808            for (i, value) in key_values.into_iter().enumerate() {
809                if let Some(col) = builder.column_mut(i) {
810                    col.push_value(value);
811                }
812            }
813
814            // Output aggregate results
815            for (i, state) in states.iter().enumerate() {
816                let col_idx = self.group_columns.len() + i;
817                if let Some(col) = builder.column_mut(col_idx) {
818                    col.push_value(state.finalize());
819                }
820            }
821
822            builder.advance_row();
823
824            if builder.is_full() {
825                return Ok(Some(builder.finish()));
826            }
827        }
828
829        if builder.row_count() > 0 {
830            Ok(Some(builder.finish()))
831        } else {
832            Ok(None)
833        }
834    }
835
836    fn reset(&mut self) {
837        self.child.reset();
838        self.groups.clear();
839        self.aggregation_complete = false;
840        self.results = None;
841    }
842
843    fn name(&self) -> &'static str {
844        "HashAggregate"
845    }
846}
847
848/// Simple (non-grouping) aggregate operator for global aggregations.
849///
850/// Used when there's no GROUP BY clause - aggregates all input into a single row.
851pub struct SimpleAggregateOperator {
852    /// Child operator.
853    child: Box<dyn Operator>,
854    /// Aggregation expressions.
855    aggregates: Vec<AggregateExpr>,
856    /// Output schema.
857    output_schema: Vec<LogicalType>,
858    /// Aggregate states.
859    states: Vec<AggregateState>,
860    /// Whether aggregation is complete.
861    done: bool,
862}
863
864impl SimpleAggregateOperator {
865    /// Creates a new simple aggregate operator.
866    pub fn new(
867        child: Box<dyn Operator>,
868        aggregates: Vec<AggregateExpr>,
869        output_schema: Vec<LogicalType>,
870    ) -> Self {
871        let states = aggregates
872            .iter()
873            .map(|agg| {
874                AggregateState::new(
875                    agg.function,
876                    agg.distinct,
877                    agg.percentile,
878                    agg.separator.as_deref(),
879                )
880            })
881            .collect();
882
883        Self {
884            child,
885            aggregates,
886            output_schema,
887            states,
888            done: false,
889        }
890    }
891}
892
893impl Operator for SimpleAggregateOperator {
894    fn next(&mut self) -> OperatorResult {
895        if self.done {
896            return Ok(None);
897        }
898
899        // Process all input
900        while let Some(chunk) = self.child.next()? {
901            for row in chunk.selected_indices() {
902                for (i, agg) in self.aggregates.iter().enumerate() {
903                    // Binary set functions: read two column values
904                    if agg.column2.is_some() {
905                        let y_val = agg
906                            .column
907                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
908                        let x_val = agg
909                            .column2
910                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
911                        self.states[i].update_bivariate(y_val, x_val);
912                        continue;
913                    }
914
915                    let value = match (agg.function, agg.distinct) {
916                        // COUNT(*) without DISTINCT doesn't need a value
917                        (AggregateFunction::Count, false) => None,
918                        // COUNT DISTINCT needs the actual value to track unique values
919                        (AggregateFunction::Count, true) => agg
920                            .column
921                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
922                        _ => agg
923                            .column
924                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
925                    };
926
927                    match (agg.function, agg.distinct) {
928                        (AggregateFunction::Count, false) => self.states[i].update(None),
929                        (AggregateFunction::Count, true) => {
930                            // COUNT DISTINCT needs the value to track unique values
931                            if value.is_some() && !matches!(value, Some(Value::Null)) {
932                                self.states[i].update(value);
933                            }
934                        }
935                        (AggregateFunction::CountNonNull, _) => {
936                            if value.is_some() && !matches!(value, Some(Value::Null)) {
937                                self.states[i].update(value);
938                            }
939                        }
940                        _ => {
941                            if value.is_some() && !matches!(value, Some(Value::Null)) {
942                                self.states[i].update(value);
943                            }
944                        }
945                    }
946                }
947            }
948        }
949
950        // Output single result row
951        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
952
953        for (i, state) in self.states.iter().enumerate() {
954            if let Some(col) = builder.column_mut(i) {
955                col.push_value(state.finalize());
956            }
957        }
958        builder.advance_row();
959
960        self.done = true;
961        Ok(Some(builder.finish()))
962    }
963
964    fn reset(&mut self) {
965        self.child.reset();
966        self.states = self
967            .aggregates
968            .iter()
969            .map(|agg| {
970                AggregateState::new(
971                    agg.function,
972                    agg.distinct,
973                    agg.percentile,
974                    agg.separator.as_deref(),
975                )
976            })
977            .collect();
978        self.done = false;
979    }
980
981    fn name(&self) -> &'static str {
982        "SimpleAggregate"
983    }
984}
985
986#[cfg(test)]
987mod tests {
988    use super::*;
989    use crate::execution::chunk::DataChunkBuilder;
990
991    struct MockOperator {
992        chunks: Vec<DataChunk>,
993        position: usize,
994    }
995
996    impl MockOperator {
997        fn new(chunks: Vec<DataChunk>) -> Self {
998            Self {
999                chunks,
1000                position: 0,
1001            }
1002        }
1003    }
1004
1005    impl Operator for MockOperator {
1006        fn next(&mut self) -> OperatorResult {
1007            if self.position < self.chunks.len() {
1008                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1009                self.position += 1;
1010                Ok(Some(chunk))
1011            } else {
1012                Ok(None)
1013            }
1014        }
1015
1016        fn reset(&mut self) {
1017            self.position = 0;
1018        }
1019
1020        fn name(&self) -> &'static str {
1021            "Mock"
1022        }
1023    }
1024
1025    fn create_test_chunk() -> DataChunk {
1026        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
1027        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1028
1029        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1030        for (group, value) in data {
1031            builder.column_mut(0).unwrap().push_int64(group);
1032            builder.column_mut(1).unwrap().push_int64(value);
1033            builder.advance_row();
1034        }
1035
1036        builder.finish()
1037    }
1038
1039    #[test]
1040    fn test_simple_count() {
1041        let mock = MockOperator::new(vec![create_test_chunk()]);
1042
1043        let mut agg = SimpleAggregateOperator::new(
1044            Box::new(mock),
1045            vec![AggregateExpr::count_star()],
1046            vec![LogicalType::Int64],
1047        );
1048
1049        let result = agg.next().unwrap().unwrap();
1050        assert_eq!(result.row_count(), 1);
1051        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1052
1053        // Should be done
1054        assert!(agg.next().unwrap().is_none());
1055    }
1056
1057    #[test]
1058    fn test_simple_sum() {
1059        let mock = MockOperator::new(vec![create_test_chunk()]);
1060
1061        let mut agg = SimpleAggregateOperator::new(
1062            Box::new(mock),
1063            vec![AggregateExpr::sum(1)], // Sum of column 1
1064            vec![LogicalType::Int64],
1065        );
1066
1067        let result = agg.next().unwrap().unwrap();
1068        assert_eq!(result.row_count(), 1);
1069        // Sum: 10 + 20 + 30 + 40 + 50 = 150
1070        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1071    }
1072
1073    #[test]
1074    fn test_simple_avg() {
1075        let mock = MockOperator::new(vec![create_test_chunk()]);
1076
1077        let mut agg = SimpleAggregateOperator::new(
1078            Box::new(mock),
1079            vec![AggregateExpr::avg(1)],
1080            vec![LogicalType::Float64],
1081        );
1082
1083        let result = agg.next().unwrap().unwrap();
1084        assert_eq!(result.row_count(), 1);
1085        // Avg: 150 / 5 = 30.0
1086        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1087        assert!((avg - 30.0).abs() < 0.001);
1088    }
1089
1090    #[test]
1091    fn test_simple_min_max() {
1092        let mock = MockOperator::new(vec![create_test_chunk()]);
1093
1094        let mut agg = SimpleAggregateOperator::new(
1095            Box::new(mock),
1096            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1097            vec![LogicalType::Int64, LogicalType::Int64],
1098        );
1099
1100        let result = agg.next().unwrap().unwrap();
1101        assert_eq!(result.row_count(), 1);
1102        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
1103        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
1104    }
1105
1106    #[test]
1107    fn test_sum_with_string_values() {
1108        // Test SUM with string values (like RDF stores numeric literals)
1109        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1110        builder.column_mut(0).unwrap().push_string("30");
1111        builder.advance_row();
1112        builder.column_mut(0).unwrap().push_string("25");
1113        builder.advance_row();
1114        builder.column_mut(0).unwrap().push_string("35");
1115        builder.advance_row();
1116        let chunk = builder.finish();
1117
1118        let mock = MockOperator::new(vec![chunk]);
1119        let mut agg = SimpleAggregateOperator::new(
1120            Box::new(mock),
1121            vec![AggregateExpr::sum(0)],
1122            vec![LogicalType::Float64],
1123        );
1124
1125        let result = agg.next().unwrap().unwrap();
1126        assert_eq!(result.row_count(), 1);
1127        // Should parse strings and sum: 30 + 25 + 35 = 90
1128        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1129        assert!(
1130            (sum_val - 90.0).abs() < 0.001,
1131            "Expected 90.0, got {}",
1132            sum_val
1133        );
1134    }
1135
1136    #[test]
1137    fn test_grouped_aggregation() {
1138        let mock = MockOperator::new(vec![create_test_chunk()]);
1139
1140        // GROUP BY column 0, SUM(column 1)
1141        let mut agg = HashAggregateOperator::new(
1142            Box::new(mock),
1143            vec![0],                     // Group by column 0
1144            vec![AggregateExpr::sum(1)], // Sum of column 1
1145            vec![LogicalType::Int64, LogicalType::Int64],
1146        );
1147
1148        let mut results: Vec<(i64, i64)> = Vec::new();
1149        while let Some(chunk) = agg.next().unwrap() {
1150            for row in chunk.selected_indices() {
1151                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1152                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1153                results.push((group, sum));
1154            }
1155        }
1156
1157        results.sort_by_key(|(g, _)| *g);
1158        assert_eq!(results.len(), 2);
1159        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
1160        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
1161    }
1162
1163    #[test]
1164    fn test_grouped_count() {
1165        let mock = MockOperator::new(vec![create_test_chunk()]);
1166
1167        // GROUP BY column 0, COUNT(*)
1168        let mut agg = HashAggregateOperator::new(
1169            Box::new(mock),
1170            vec![0],
1171            vec![AggregateExpr::count_star()],
1172            vec![LogicalType::Int64, LogicalType::Int64],
1173        );
1174
1175        let mut results: Vec<(i64, i64)> = Vec::new();
1176        while let Some(chunk) = agg.next().unwrap() {
1177            for row in chunk.selected_indices() {
1178                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1179                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1180                results.push((group, count));
1181            }
1182        }
1183
1184        results.sort_by_key(|(g, _)| *g);
1185        assert_eq!(results.len(), 2);
1186        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
1187        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
1188    }
1189
1190    #[test]
1191    fn test_multiple_aggregates() {
1192        let mock = MockOperator::new(vec![create_test_chunk()]);
1193
1194        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
1195        let mut agg = HashAggregateOperator::new(
1196            Box::new(mock),
1197            vec![0],
1198            vec![
1199                AggregateExpr::count_star(),
1200                AggregateExpr::sum(1),
1201                AggregateExpr::avg(1),
1202            ],
1203            vec![
1204                LogicalType::Int64,   // Group key
1205                LogicalType::Int64,   // COUNT
1206                LogicalType::Int64,   // SUM
1207                LogicalType::Float64, // AVG
1208            ],
1209        );
1210
1211        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1212        while let Some(chunk) = agg.next().unwrap() {
1213            for row in chunk.selected_indices() {
1214                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1215                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1216                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1217                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1218                results.push((group, count, sum, avg));
1219            }
1220        }
1221
1222        results.sort_by_key(|(g, _, _, _)| *g);
1223        assert_eq!(results.len(), 2);
1224
1225        // Group 1: COUNT=2, SUM=30, AVG=15.0
1226        assert_eq!(results[0].0, 1);
1227        assert_eq!(results[0].1, 2);
1228        assert_eq!(results[0].2, 30);
1229        assert!((results[0].3 - 15.0).abs() < 0.001);
1230
1231        // Group 2: COUNT=3, SUM=120, AVG=40.0
1232        assert_eq!(results[1].0, 2);
1233        assert_eq!(results[1].1, 3);
1234        assert_eq!(results[1].2, 120);
1235        assert!((results[1].3 - 40.0).abs() < 0.001);
1236    }
1237
1238    fn create_test_chunk_with_duplicates() -> DataChunk {
1239        // Create data with duplicate values in column 1
1240        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
1241        // GROUP 1: values [10, 10, 20] -> distinct count = 2
1242        // GROUP 2: values [30, 30, 30] -> distinct count = 1
1243        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1244
1245        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1246        for (group, value) in data {
1247            builder.column_mut(0).unwrap().push_int64(group);
1248            builder.column_mut(1).unwrap().push_int64(value);
1249            builder.advance_row();
1250        }
1251
1252        builder.finish()
1253    }
1254
1255    #[test]
1256    fn test_count_distinct() {
1257        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1258
1259        // COUNT(DISTINCT column 1)
1260        let mut agg = SimpleAggregateOperator::new(
1261            Box::new(mock),
1262            vec![AggregateExpr::count(1).with_distinct()],
1263            vec![LogicalType::Int64],
1264        );
1265
1266        let result = agg.next().unwrap().unwrap();
1267        assert_eq!(result.row_count(), 1);
1268        // Total distinct values: 10, 20, 30 = 3 distinct values
1269        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1270    }
1271
1272    #[test]
1273    fn test_grouped_count_distinct() {
1274        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1275
1276        // GROUP BY column 0, COUNT(DISTINCT column 1)
1277        let mut agg = HashAggregateOperator::new(
1278            Box::new(mock),
1279            vec![0],
1280            vec![AggregateExpr::count(1).with_distinct()],
1281            vec![LogicalType::Int64, LogicalType::Int64],
1282        );
1283
1284        let mut results: Vec<(i64, i64)> = Vec::new();
1285        while let Some(chunk) = agg.next().unwrap() {
1286            for row in chunk.selected_indices() {
1287                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1288                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1289                results.push((group, count));
1290            }
1291        }
1292
1293        results.sort_by_key(|(g, _)| *g);
1294        assert_eq!(results.len(), 2);
1295        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
1296        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
1297    }
1298
1299    #[test]
1300    fn test_sum_distinct() {
1301        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1302
1303        // SUM(DISTINCT column 1)
1304        let mut agg = SimpleAggregateOperator::new(
1305            Box::new(mock),
1306            vec![AggregateExpr::sum(1).with_distinct()],
1307            vec![LogicalType::Int64],
1308        );
1309
1310        let result = agg.next().unwrap().unwrap();
1311        assert_eq!(result.row_count(), 1);
1312        // Sum of distinct values: 10 + 20 + 30 = 60
1313        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1314    }
1315
1316    #[test]
1317    fn test_avg_distinct() {
1318        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1319
1320        // AVG(DISTINCT column 1)
1321        let mut agg = SimpleAggregateOperator::new(
1322            Box::new(mock),
1323            vec![AggregateExpr::avg(1).with_distinct()],
1324            vec![LogicalType::Float64],
1325        );
1326
1327        let result = agg.next().unwrap().unwrap();
1328        assert_eq!(result.row_count(), 1);
1329        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1330        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1331        assert!((avg - 20.0).abs() < 0.001);
1332    }
1333
1334    fn create_statistical_test_chunk() -> DataChunk {
1335        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1336        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1337        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1338
1339        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1340            builder.column_mut(0).unwrap().push_int64(value);
1341            builder.advance_row();
1342        }
1343
1344        builder.finish()
1345    }
1346
1347    #[test]
1348    fn test_stdev_sample() {
1349        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1350
1351        let mut agg = SimpleAggregateOperator::new(
1352            Box::new(mock),
1353            vec![AggregateExpr::stdev(0)],
1354            vec![LogicalType::Float64],
1355        );
1356
1357        let result = agg.next().unwrap().unwrap();
1358        assert_eq!(result.row_count(), 1);
1359        // Sample standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1360        // Mean = 5.0, Variance = 32/7 = 4.571, StdDev = 2.138
1361        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1362        assert!((stdev - 2.138).abs() < 0.01);
1363    }
1364
1365    #[test]
1366    fn test_stdev_population() {
1367        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1368
1369        let mut agg = SimpleAggregateOperator::new(
1370            Box::new(mock),
1371            vec![AggregateExpr::stdev_pop(0)],
1372            vec![LogicalType::Float64],
1373        );
1374
1375        let result = agg.next().unwrap().unwrap();
1376        assert_eq!(result.row_count(), 1);
1377        // Population standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1378        // Mean = 5.0, Variance = 32/8 = 4.0, StdDev = 2.0
1379        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1380        assert!((stdev - 2.0).abs() < 0.01);
1381    }
1382
1383    #[test]
1384    fn test_percentile_disc() {
1385        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1386
1387        // Median (50th percentile discrete)
1388        let mut agg = SimpleAggregateOperator::new(
1389            Box::new(mock),
1390            vec![AggregateExpr::percentile_disc(0, 0.5)],
1391            vec![LogicalType::Float64],
1392        );
1393
1394        let result = agg.next().unwrap().unwrap();
1395        assert_eq!(result.row_count(), 1);
1396        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], index = floor(0.5 * 7) = 3, value = 4
1397        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1398        assert!((percentile - 4.0).abs() < 0.01);
1399    }
1400
1401    #[test]
1402    fn test_percentile_cont() {
1403        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1404
1405        // Median (50th percentile continuous)
1406        let mut agg = SimpleAggregateOperator::new(
1407            Box::new(mock),
1408            vec![AggregateExpr::percentile_cont(0, 0.5)],
1409            vec![LogicalType::Float64],
1410        );
1411
1412        let result = agg.next().unwrap().unwrap();
1413        assert_eq!(result.row_count(), 1);
1414        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], rank = 0.5 * 7 = 3.5
1415        // Interpolate between index 3 (4) and index 4 (5): 4 + 0.5 * (5 - 4) = 4.5
1416        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1417        assert!((percentile - 4.5).abs() < 0.01);
1418    }
1419
1420    #[test]
1421    fn test_percentile_extremes() {
1422        // Test 0th and 100th percentiles
1423        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1424
1425        let mut agg = SimpleAggregateOperator::new(
1426            Box::new(mock),
1427            vec![
1428                AggregateExpr::percentile_disc(0, 0.0),
1429                AggregateExpr::percentile_disc(0, 1.0),
1430            ],
1431            vec![LogicalType::Float64, LogicalType::Float64],
1432        );
1433
1434        let result = agg.next().unwrap().unwrap();
1435        assert_eq!(result.row_count(), 1);
1436        // 0th percentile = minimum = 2
1437        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1438        assert!((p0 - 2.0).abs() < 0.01);
1439        // 100th percentile = maximum = 9
1440        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1441        assert!((p100 - 9.0).abs() < 0.01);
1442    }
1443
1444    #[test]
1445    fn test_stdev_single_value() {
1446        // Single value should return null for sample stdev
1447        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1448        builder.column_mut(0).unwrap().push_int64(42);
1449        builder.advance_row();
1450        let chunk = builder.finish();
1451
1452        let mock = MockOperator::new(vec![chunk]);
1453
1454        let mut agg = SimpleAggregateOperator::new(
1455            Box::new(mock),
1456            vec![AggregateExpr::stdev(0)],
1457            vec![LogicalType::Float64],
1458        );
1459
1460        let result = agg.next().unwrap().unwrap();
1461        assert_eq!(result.row_count(), 1);
1462        // Sample stdev of single value is undefined (null)
1463        assert!(matches!(
1464            result.column(0).unwrap().get_value(0),
1465            Some(Value::Null)
1466        ));
1467    }
1468
1469    #[test]
1470    fn test_first_and_last() {
1471        let mock = MockOperator::new(vec![create_test_chunk()]);
1472
1473        let mut agg = SimpleAggregateOperator::new(
1474            Box::new(mock),
1475            vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1476            vec![LogicalType::Int64, LogicalType::Int64],
1477        );
1478
1479        let result = agg.next().unwrap().unwrap();
1480        assert_eq!(result.row_count(), 1);
1481        // First: 10, Last: 50
1482        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1483        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1484    }
1485
1486    #[test]
1487    fn test_collect() {
1488        let mock = MockOperator::new(vec![create_test_chunk()]);
1489
1490        let mut agg = SimpleAggregateOperator::new(
1491            Box::new(mock),
1492            vec![AggregateExpr::collect(1)],
1493            vec![LogicalType::Any],
1494        );
1495
1496        let result = agg.next().unwrap().unwrap();
1497        let val = result.column(0).unwrap().get_value(0).unwrap();
1498        if let Value::List(items) = val {
1499            assert_eq!(items.len(), 5);
1500        } else {
1501            panic!("Expected List value");
1502        }
1503    }
1504
1505    #[test]
1506    fn test_collect_distinct() {
1507        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1508
1509        let mut agg = SimpleAggregateOperator::new(
1510            Box::new(mock),
1511            vec![AggregateExpr::collect(1).with_distinct()],
1512            vec![LogicalType::Any],
1513        );
1514
1515        let result = agg.next().unwrap().unwrap();
1516        let val = result.column(0).unwrap().get_value(0).unwrap();
1517        if let Value::List(items) = val {
1518            // [10, 10, 20, 30, 30, 30] -> distinct: [10, 20, 30]
1519            assert_eq!(items.len(), 3);
1520        } else {
1521            panic!("Expected List value");
1522        }
1523    }
1524
1525    #[test]
1526    fn test_group_concat() {
1527        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1528        for s in ["hello", "world", "foo"] {
1529            builder.column_mut(0).unwrap().push_string(s);
1530            builder.advance_row();
1531        }
1532        let chunk = builder.finish();
1533        let mock = MockOperator::new(vec![chunk]);
1534
1535        let agg_expr = AggregateExpr {
1536            function: AggregateFunction::GroupConcat,
1537            column: Some(0),
1538            column2: None,
1539            distinct: false,
1540            alias: None,
1541            percentile: None,
1542            separator: None,
1543        };
1544
1545        let mut agg =
1546            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1547
1548        let result = agg.next().unwrap().unwrap();
1549        let val = result.column(0).unwrap().get_value(0).unwrap();
1550        assert_eq!(val, Value::String("hello world foo".into()));
1551    }
1552
1553    #[test]
1554    fn test_sample() {
1555        let mock = MockOperator::new(vec![create_test_chunk()]);
1556
1557        let agg_expr = AggregateExpr {
1558            function: AggregateFunction::Sample,
1559            column: Some(1),
1560            column2: None,
1561            distinct: false,
1562            alias: None,
1563            percentile: None,
1564            separator: None,
1565        };
1566
1567        let mut agg =
1568            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1569
1570        let result = agg.next().unwrap().unwrap();
1571        // Sample should return the first non-null value (10)
1572        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1573    }
1574
1575    #[test]
1576    fn test_variance_sample() {
1577        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1578
1579        let agg_expr = AggregateExpr {
1580            function: AggregateFunction::Variance,
1581            column: Some(0),
1582            column2: None,
1583            distinct: false,
1584            alias: None,
1585            percentile: None,
1586            separator: None,
1587        };
1588
1589        let mut agg = SimpleAggregateOperator::new(
1590            Box::new(mock),
1591            vec![agg_expr],
1592            vec![LogicalType::Float64],
1593        );
1594
1595        let result = agg.next().unwrap().unwrap();
1596        // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: M2/(n-1) = 32/7 = 4.571
1597        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1598        assert!((variance - 32.0 / 7.0).abs() < 0.01);
1599    }
1600
1601    #[test]
1602    fn test_variance_population() {
1603        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1604
1605        let agg_expr = AggregateExpr {
1606            function: AggregateFunction::VariancePop,
1607            column: Some(0),
1608            column2: None,
1609            distinct: false,
1610            alias: None,
1611            percentile: None,
1612            separator: None,
1613        };
1614
1615        let mut agg = SimpleAggregateOperator::new(
1616            Box::new(mock),
1617            vec![agg_expr],
1618            vec![LogicalType::Float64],
1619        );
1620
1621        let result = agg.next().unwrap().unwrap();
1622        // Population variance: M2/n = 32/8 = 4.0
1623        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1624        assert!((variance - 4.0).abs() < 0.01);
1625    }
1626
1627    #[test]
1628    fn test_variance_single_value() {
1629        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1630        builder.column_mut(0).unwrap().push_int64(42);
1631        builder.advance_row();
1632        let chunk = builder.finish();
1633        let mock = MockOperator::new(vec![chunk]);
1634
1635        let agg_expr = AggregateExpr {
1636            function: AggregateFunction::Variance,
1637            column: Some(0),
1638            column2: None,
1639            distinct: false,
1640            alias: None,
1641            percentile: None,
1642            separator: None,
1643        };
1644
1645        let mut agg = SimpleAggregateOperator::new(
1646            Box::new(mock),
1647            vec![agg_expr],
1648            vec![LogicalType::Float64],
1649        );
1650
1651        let result = agg.next().unwrap().unwrap();
1652        // Sample variance of single value is undefined (null)
1653        assert!(matches!(
1654            result.column(0).unwrap().get_value(0),
1655            Some(Value::Null)
1656        ));
1657    }
1658
1659    #[test]
1660    fn test_empty_aggregation() {
1661        // No input rows: COUNT should be 0, SUM 0, AVG null, MIN/MAX null
1662        let mock = MockOperator::new(vec![]);
1663
1664        let mut agg = SimpleAggregateOperator::new(
1665            Box::new(mock),
1666            vec![
1667                AggregateExpr::count_star(),
1668                AggregateExpr::sum(0),
1669                AggregateExpr::avg(0),
1670                AggregateExpr::min(0),
1671                AggregateExpr::max(0),
1672            ],
1673            vec![
1674                LogicalType::Int64,
1675                LogicalType::Int64,
1676                LogicalType::Float64,
1677                LogicalType::Int64,
1678                LogicalType::Int64,
1679            ],
1680        );
1681
1682        let result = agg.next().unwrap().unwrap();
1683        assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); // COUNT
1684        assert_eq!(result.column(1).unwrap().get_int64(0), Some(0)); // SUM
1685        assert!(matches!(
1686            result.column(2).unwrap().get_value(0),
1687            Some(Value::Null)
1688        )); // AVG
1689        assert!(matches!(
1690            result.column(3).unwrap().get_value(0),
1691            Some(Value::Null)
1692        )); // MIN
1693        assert!(matches!(
1694            result.column(4).unwrap().get_value(0),
1695            Some(Value::Null)
1696        )); // MAX
1697    }
1698
1699    #[test]
1700    fn test_stdev_pop_single_value() {
1701        // Single value should return 0 for population stdev
1702        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1703        builder.column_mut(0).unwrap().push_int64(42);
1704        builder.advance_row();
1705        let chunk = builder.finish();
1706
1707        let mock = MockOperator::new(vec![chunk]);
1708
1709        let mut agg = SimpleAggregateOperator::new(
1710            Box::new(mock),
1711            vec![AggregateExpr::stdev_pop(0)],
1712            vec![LogicalType::Float64],
1713        );
1714
1715        let result = agg.next().unwrap().unwrap();
1716        assert_eq!(result.row_count(), 1);
1717        // Population stdev of single value is 0
1718        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1719        assert!((stdev - 0.0).abs() < 0.01);
1720    }
1721}