Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12
13use grafeo_common::types::{LogicalType, Value};
14
15use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
16use super::{Operator, OperatorError, OperatorResult};
17use crate::execution::DataChunk;
18use crate::execution::chunk::DataChunkBuilder;
19
20/// State for a single aggregation computation.
21#[derive(Debug, Clone)]
22pub(crate) enum AggregateState {
23    /// Count state.
24    Count(i64),
25    /// Count distinct state (count, seen values).
26    CountDistinct(i64, HashSet<HashableValue>),
27    /// Sum state (integer sum, count of values added).
28    SumInt(i64, i64),
29    /// Sum distinct state (integer sum, count, seen values).
30    SumIntDistinct(i64, i64, HashSet<HashableValue>),
31    /// Sum state (float sum, count of values added).
32    SumFloat(f64, i64),
33    /// Sum distinct state (float sum, count, seen values).
34    SumFloatDistinct(f64, i64, HashSet<HashableValue>),
35    /// Average state (sum, count).
36    Avg(f64, i64),
37    /// Average distinct state (sum, count, seen values).
38    AvgDistinct(f64, i64, HashSet<HashableValue>),
39    /// Min state.
40    Min(Option<Value>),
41    /// Max state.
42    Max(Option<Value>),
43    /// First state.
44    First(Option<Value>),
45    /// Last state.
46    Last(Option<Value>),
47    /// Collect state.
48    Collect(Vec<Value>),
49    /// Collect distinct state (values, seen).
50    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
51    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
52    StdDev { count: i64, mean: f64, m2: f64 },
53    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
54    StdDevPop { count: i64, mean: f64, m2: f64 },
55    /// Discrete percentile state (values, percentile).
56    PercentileDisc { values: Vec<f64>, percentile: f64 },
57    /// Continuous percentile state (values, percentile).
58    PercentileCont { values: Vec<f64>, percentile: f64 },
59    /// GROUP_CONCAT / LISTAGG state (collected string values, separator).
60    GroupConcat(Vec<String>, String),
61    /// GROUP_CONCAT / LISTAGG distinct state (collected string values, separator, seen).
62    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
63    /// SAMPLE state (first non-null value encountered).
64    Sample(Option<Value>),
65    /// Sample variance state using Welford's algorithm (count, mean, M2).
66    Variance { count: i64, mean: f64, m2: f64 },
67    /// Population variance state using Welford's algorithm (count, mean, M2).
68    VariancePop { count: i64, mean: f64, m2: f64 },
69    /// Two-variable online statistics (Welford generalization for covariance/regression).
70    Bivariate {
71        /// Which binary set function this state will finalize to.
72        kind: AggregateFunction,
73        count: i64,
74        mean_x: f64,
75        mean_y: f64,
76        m2_x: f64,
77        m2_y: f64,
78        c_xy: f64,
79    },
80}
81
82impl AggregateState {
83    /// Creates initial state for an aggregation function.
84    pub(crate) fn new(
85        function: AggregateFunction,
86        distinct: bool,
87        percentile: Option<f64>,
88        separator: Option<&str>,
89    ) -> Self {
90        match (function, distinct) {
91            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
92                AggregateState::Count(0)
93            }
94            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
95                AggregateState::CountDistinct(0, HashSet::new())
96            }
97            (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
98            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
99            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
100            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
101            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
102            (AggregateFunction::Max, _) => AggregateState::Max(None),
103            (AggregateFunction::First, _) => AggregateState::First(None),
104            (AggregateFunction::Last, _) => AggregateState::Last(None),
105            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
106            (AggregateFunction::Collect, true) => {
107                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
108            }
109            // Statistical functions (Welford's algorithm for online computation)
110            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
111                count: 0,
112                mean: 0.0,
113                m2: 0.0,
114            },
115            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
116                count: 0,
117                mean: 0.0,
118                m2: 0.0,
119            },
120            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
121                values: Vec::new(),
122                percentile: percentile.unwrap_or(0.5),
123            },
124            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
125                values: Vec::new(),
126                percentile: percentile.unwrap_or(0.5),
127            },
128            (AggregateFunction::GroupConcat, false) => {
129                AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
130            }
131            (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
132                Vec::new(),
133                separator.unwrap_or(" ").to_string(),
134                HashSet::new(),
135            ),
136            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
137            // Binary set functions (all share the same Bivariate state)
138            (
139                AggregateFunction::CovarSamp
140                | AggregateFunction::CovarPop
141                | AggregateFunction::Corr
142                | AggregateFunction::RegrSlope
143                | AggregateFunction::RegrIntercept
144                | AggregateFunction::RegrR2
145                | AggregateFunction::RegrCount
146                | AggregateFunction::RegrSxx
147                | AggregateFunction::RegrSyy
148                | AggregateFunction::RegrSxy
149                | AggregateFunction::RegrAvgx
150                | AggregateFunction::RegrAvgy,
151                _,
152            ) => AggregateState::Bivariate {
153                kind: function,
154                count: 0,
155                mean_x: 0.0,
156                mean_y: 0.0,
157                m2_x: 0.0,
158                m2_y: 0.0,
159                c_xy: 0.0,
160            },
161            (AggregateFunction::Variance, _) => AggregateState::Variance {
162                count: 0,
163                mean: 0.0,
164                m2: 0.0,
165            },
166            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
167                count: 0,
168                mean: 0.0,
169                m2: 0.0,
170            },
171        }
172    }
173
174    /// Updates the state with a new value.
175    pub(crate) fn update(&mut self, value: Option<Value>) {
176        match self {
177            AggregateState::Count(count) => {
178                *count += 1;
179            }
180            AggregateState::CountDistinct(count, seen) => {
181                if let Some(ref v) = value {
182                    let hashable = HashableValue::from(v);
183                    if seen.insert(hashable) {
184                        *count += 1;
185                    }
186                }
187            }
188            AggregateState::SumInt(sum, count) => {
189                if let Some(Value::Int64(v)) = value {
190                    *sum += v;
191                    *count += 1;
192                } else if let Some(Value::Float64(v)) = value {
193                    // Convert to float sum, carrying count forward
194                    *self = AggregateState::SumFloat(*sum as f64 + v, *count + 1);
195                } else if let Some(ref v) = value {
196                    // RDF stores numeric literals as strings - try to parse
197                    if let Some(num) = value_to_f64(v) {
198                        *self = AggregateState::SumFloat(*sum as f64 + num, *count + 1);
199                    }
200                }
201            }
202            AggregateState::SumIntDistinct(sum, count, seen) => {
203                if let Some(ref v) = value {
204                    let hashable = HashableValue::from(v);
205                    if seen.insert(hashable) {
206                        if let Value::Int64(i) = v {
207                            *sum += i;
208                            *count += 1;
209                        } else if let Value::Float64(f) = v {
210                            // Convert to float distinct: move the seen set instead of cloning
211                            let moved_seen = std::mem::take(seen);
212                            *self = AggregateState::SumFloatDistinct(
213                                *sum as f64 + f,
214                                *count + 1,
215                                moved_seen,
216                            );
217                        } else if let Some(num) = value_to_f64(v) {
218                            // RDF string-encoded numerics
219                            let moved_seen = std::mem::take(seen);
220                            *self = AggregateState::SumFloatDistinct(
221                                *sum as f64 + num,
222                                *count + 1,
223                                moved_seen,
224                            );
225                        }
226                    }
227                }
228            }
229            AggregateState::SumFloat(sum, count) => {
230                if let Some(ref v) = value {
231                    // Use value_to_f64 which now handles strings
232                    if let Some(num) = value_to_f64(v) {
233                        *sum += num;
234                        *count += 1;
235                    }
236                }
237            }
238            AggregateState::SumFloatDistinct(sum, count, seen) => {
239                if let Some(ref v) = value {
240                    let hashable = HashableValue::from(v);
241                    if seen.insert(hashable)
242                        && let Some(num) = value_to_f64(v)
243                    {
244                        *sum += num;
245                        *count += 1;
246                    }
247                }
248            }
249            AggregateState::Avg(sum, count) => {
250                if let Some(ref v) = value
251                    && let Some(num) = value_to_f64(v)
252                {
253                    *sum += num;
254                    *count += 1;
255                }
256            }
257            AggregateState::AvgDistinct(sum, count, seen) => {
258                if let Some(ref v) = value {
259                    let hashable = HashableValue::from(v);
260                    if seen.insert(hashable)
261                        && let Some(num) = value_to_f64(v)
262                    {
263                        *sum += num;
264                        *count += 1;
265                    }
266                }
267            }
268            AggregateState::Min(min) => {
269                if let Some(v) = value {
270                    match min {
271                        None => *min = Some(v),
272                        Some(current) => {
273                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
274                                *min = Some(v);
275                            }
276                        }
277                    }
278                }
279            }
280            AggregateState::Max(max) => {
281                if let Some(v) = value {
282                    match max {
283                        None => *max = Some(v),
284                        Some(current) => {
285                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
286                                *max = Some(v);
287                            }
288                        }
289                    }
290                }
291            }
292            AggregateState::First(first) => {
293                if first.is_none() {
294                    *first = value;
295                }
296            }
297            AggregateState::Last(last) => {
298                if value.is_some() {
299                    *last = value;
300                }
301            }
302            AggregateState::Collect(list) => {
303                if let Some(v) = value {
304                    list.push(v);
305                }
306            }
307            AggregateState::CollectDistinct(list, seen) => {
308                if let Some(v) = value {
309                    let hashable = HashableValue::from(&v);
310                    if seen.insert(hashable) {
311                        list.push(v);
312                    }
313                }
314            }
315            // Statistical functions using Welford's online algorithm
316            AggregateState::StdDev { count, mean, m2 }
317            | AggregateState::StdDevPop { count, mean, m2 }
318            | AggregateState::Variance { count, mean, m2 }
319            | AggregateState::VariancePop { count, mean, m2 } => {
320                if let Some(ref v) = value
321                    && let Some(x) = value_to_f64(v)
322                {
323                    *count += 1;
324                    let delta = x - *mean;
325                    *mean += delta / *count as f64;
326                    let delta2 = x - *mean;
327                    *m2 += delta * delta2;
328                }
329            }
330            AggregateState::PercentileDisc { values, .. }
331            | AggregateState::PercentileCont { values, .. } => {
332                if let Some(ref v) = value
333                    && let Some(x) = value_to_f64(v)
334                {
335                    values.push(x);
336                }
337            }
338            AggregateState::GroupConcat(list, _sep) => {
339                if let Some(v) = value {
340                    list.push(agg_value_to_string(&v));
341                }
342            }
343            AggregateState::GroupConcatDistinct(list, _sep, seen) => {
344                if let Some(v) = value {
345                    let hashable = HashableValue::from(&v);
346                    if seen.insert(hashable) {
347                        list.push(agg_value_to_string(&v));
348                    }
349                }
350            }
351            AggregateState::Sample(sample) => {
352                if sample.is_none() {
353                    *sample = value;
354                }
355            }
356            AggregateState::Bivariate { .. } => {
357                // Bivariate functions require two values; use update_bivariate() instead.
358                // Single-value update is a no-op for bivariate state.
359            }
360        }
361    }
362
363    /// Updates a bivariate (two-variable) aggregate state with a pair of values.
364    ///
365    /// Uses the two-variable Welford online algorithm for numerically stable computation
366    /// of covariance and related statistics. Skips the update if either value is null.
367    fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
368        if let AggregateState::Bivariate {
369            count,
370            mean_x,
371            mean_y,
372            m2_x,
373            m2_y,
374            c_xy,
375            ..
376        } = self
377        {
378            // Skip if either value is null (SQL semantics: exclude non-pairs)
379            if let (Some(y), Some(x)) = (&y_val, &x_val)
380                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
381            {
382                *count += 1;
383                let n = *count as f64;
384                let dx = x_f - *mean_x;
385                let dy = y_f - *mean_y;
386                *mean_x += dx / n;
387                *mean_y += dy / n;
388                let dx2 = x_f - *mean_x; // post-update delta
389                let dy2 = y_f - *mean_y; // post-update delta
390                *m2_x += dx * dx2;
391                *m2_y += dy * dy2;
392                *c_xy += dx * dy2;
393            }
394        }
395    }
396
397    /// Finalizes the state and returns the result value.
398    pub(crate) fn finalize(&self) -> Value {
399        match self {
400            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
401                Value::Int64(*count)
402            }
403            AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
404                if *count == 0 {
405                    Value::Null
406                } else {
407                    Value::Int64(*sum)
408                }
409            }
410            AggregateState::SumFloat(sum, count)
411            | AggregateState::SumFloatDistinct(sum, count, _) => {
412                if *count == 0 {
413                    Value::Null
414                } else {
415                    Value::Float64(*sum)
416                }
417            }
418            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
419                if *count == 0 {
420                    Value::Null
421                } else {
422                    Value::Float64(*sum / *count as f64)
423                }
424            }
425            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
426            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
427            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
428            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
429            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
430                Value::List(list.clone().into())
431            }
432            // Sample standard deviation: sqrt(M2 / (n - 1))
433            AggregateState::StdDev { count, m2, .. } => {
434                if *count < 2 {
435                    Value::Null
436                } else {
437                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
438                }
439            }
440            // Population standard deviation: sqrt(M2 / n)
441            AggregateState::StdDevPop { count, m2, .. } => {
442                if *count == 0 {
443                    Value::Null
444                } else {
445                    Value::Float64((*m2 / *count as f64).sqrt())
446                }
447            }
448            // Sample variance: M2 / (n - 1)
449            AggregateState::Variance { count, m2, .. } => {
450                if *count < 2 {
451                    Value::Null
452                } else {
453                    Value::Float64(*m2 / (*count - 1) as f64)
454                }
455            }
456            // Population variance: M2 / n
457            AggregateState::VariancePop { count, m2, .. } => {
458                if *count == 0 {
459                    Value::Null
460                } else {
461                    Value::Float64(*m2 / *count as f64)
462                }
463            }
464            // Discrete percentile: return actual value at percentile position
465            AggregateState::PercentileDisc { values, percentile } => {
466                if values.is_empty() {
467                    Value::Null
468                } else {
469                    let mut sorted = values.clone();
470                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
471                    // Index calculation per SQL standard: floor(p * (n - 1))
472                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
473                    Value::Float64(sorted[index])
474                }
475            }
476            // Continuous percentile: interpolate between values
477            AggregateState::PercentileCont { values, percentile } => {
478                if values.is_empty() {
479                    Value::Null
480                } else {
481                    let mut sorted = values.clone();
482                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
483                    // Linear interpolation per SQL standard
484                    let rank = percentile * (sorted.len() - 1) as f64;
485                    let lower_idx = rank.floor() as usize;
486                    let upper_idx = rank.ceil() as usize;
487                    if lower_idx == upper_idx {
488                        Value::Float64(sorted[lower_idx])
489                    } else {
490                        let fraction = rank - lower_idx as f64;
491                        let result =
492                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
493                        Value::Float64(result)
494                    }
495                }
496            }
497            // GROUP_CONCAT: join strings with space separator (SPARQL default)
498            AggregateState::GroupConcat(list, sep)
499            | AggregateState::GroupConcatDistinct(list, sep, _) => {
500                Value::String(list.join(sep).into())
501            }
502            // SAMPLE: return the first non-null value seen
503            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
504            // Binary set functions: dispatch on kind
505            AggregateState::Bivariate {
506                kind,
507                count,
508                mean_x,
509                mean_y,
510                m2_x,
511                m2_y,
512                c_xy,
513            } => {
514                let n = *count;
515                match kind {
516                    AggregateFunction::CovarSamp => {
517                        if n < 2 {
518                            Value::Null
519                        } else {
520                            Value::Float64(*c_xy / (n - 1) as f64)
521                        }
522                    }
523                    AggregateFunction::CovarPop => {
524                        if n == 0 {
525                            Value::Null
526                        } else {
527                            Value::Float64(*c_xy / n as f64)
528                        }
529                    }
530                    AggregateFunction::Corr => {
531                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
532                            Value::Null
533                        } else {
534                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
535                        }
536                    }
537                    AggregateFunction::RegrSlope => {
538                        if n == 0 || *m2_x == 0.0 {
539                            Value::Null
540                        } else {
541                            Value::Float64(*c_xy / *m2_x)
542                        }
543                    }
544                    AggregateFunction::RegrIntercept => {
545                        if n == 0 || *m2_x == 0.0 {
546                            Value::Null
547                        } else {
548                            let slope = *c_xy / *m2_x;
549                            Value::Float64(*mean_y - slope * *mean_x)
550                        }
551                    }
552                    AggregateFunction::RegrR2 => {
553                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
554                            Value::Null
555                        } else {
556                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
557                        }
558                    }
559                    AggregateFunction::RegrCount => Value::Int64(n),
560                    AggregateFunction::RegrSxx => {
561                        if n == 0 {
562                            Value::Null
563                        } else {
564                            Value::Float64(*m2_x)
565                        }
566                    }
567                    AggregateFunction::RegrSyy => {
568                        if n == 0 {
569                            Value::Null
570                        } else {
571                            Value::Float64(*m2_y)
572                        }
573                    }
574                    AggregateFunction::RegrSxy => {
575                        if n == 0 {
576                            Value::Null
577                        } else {
578                            Value::Float64(*c_xy)
579                        }
580                    }
581                    AggregateFunction::RegrAvgx => {
582                        if n == 0 {
583                            Value::Null
584                        } else {
585                            Value::Float64(*mean_x)
586                        }
587                    }
588                    AggregateFunction::RegrAvgy => {
589                        if n == 0 {
590                            Value::Null
591                        } else {
592                            Value::Float64(*mean_y)
593                        }
594                    }
595                    _ => Value::Null, // non-bivariate functions never reach here
596                }
597            }
598        }
599    }
600}
601
602use super::value_utils::{compare_values, value_to_f64};
603
604/// Converts a Value to its string representation for GROUP_CONCAT.
605fn agg_value_to_string(val: &Value) -> String {
606    match val {
607        Value::String(s) => s.to_string(),
608        Value::Int64(i) => i.to_string(),
609        Value::Float64(f) => f.to_string(),
610        Value::Bool(b) => b.to_string(),
611        Value::Null => String::new(),
612        other => format!("{other:?}"),
613    }
614}
615
616/// A group key for hash-based aggregation.
617#[derive(Debug, Clone, PartialEq, Eq, Hash)]
618pub struct GroupKey(Vec<GroupKeyPart>);
619
620#[derive(Debug, Clone, PartialEq, Eq, Hash)]
621enum GroupKeyPart {
622    Null,
623    Bool(bool),
624    Int64(i64),
625    String(String),
626}
627
628impl GroupKey {
629    /// Creates a group key from column values.
630    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
631        let parts: Vec<GroupKeyPart> = group_columns
632            .iter()
633            .map(|&col_idx| {
634                chunk
635                    .column(col_idx)
636                    .and_then(|col| col.get_value(row))
637                    .map_or(GroupKeyPart::Null, |v| match v {
638                        Value::Null => GroupKeyPart::Null,
639                        Value::Bool(b) => GroupKeyPart::Bool(b),
640                        Value::Int64(i) => GroupKeyPart::Int64(i),
641                        Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
642                        Value::String(s) => GroupKeyPart::String(s.to_string()),
643                        _ => GroupKeyPart::String(format!("{v:?}")),
644                    })
645            })
646            .collect();
647        GroupKey(parts)
648    }
649
650    /// Converts the group key back to values.
651    fn to_values(&self) -> Vec<Value> {
652        self.0
653            .iter()
654            .map(|part| match part {
655                GroupKeyPart::Null => Value::Null,
656                GroupKeyPart::Bool(b) => Value::Bool(*b),
657                GroupKeyPart::Int64(i) => Value::Int64(*i),
658                GroupKeyPart::String(s) => Value::String(s.clone().into()),
659            })
660            .collect()
661    }
662}
663
664/// Hash-based aggregate operator.
665///
666/// Groups input by key columns and computes aggregations for each group.
667pub struct HashAggregateOperator {
668    /// Child operator to read from.
669    child: Box<dyn Operator>,
670    /// Columns to group by.
671    group_columns: Vec<usize>,
672    /// Aggregation expressions.
673    aggregates: Vec<AggregateExpr>,
674    /// Output schema.
675    output_schema: Vec<LogicalType>,
676    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
677    groups: IndexMap<GroupKey, Vec<AggregateState>>,
678    /// Whether aggregation is complete.
679    aggregation_complete: bool,
680    /// Results iterator.
681    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
682}
683
684impl HashAggregateOperator {
685    /// Creates a new hash aggregate operator.
686    ///
687    /// # Arguments
688    /// * `child` - Child operator to read from.
689    /// * `group_columns` - Column indices to group by.
690    /// * `aggregates` - Aggregation expressions.
691    /// * `output_schema` - Schema of the output (group columns + aggregate results).
692    pub fn new(
693        child: Box<dyn Operator>,
694        group_columns: Vec<usize>,
695        aggregates: Vec<AggregateExpr>,
696        output_schema: Vec<LogicalType>,
697    ) -> Self {
698        Self {
699            child,
700            group_columns,
701            aggregates,
702            output_schema,
703            groups: IndexMap::new(),
704            aggregation_complete: false,
705            results: None,
706        }
707    }
708
709    /// Performs the aggregation.
710    fn aggregate(&mut self) -> Result<(), OperatorError> {
711        while let Some(chunk) = self.child.next()? {
712            for row in chunk.selected_indices() {
713                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
714
715                // Get or create aggregate states for this group
716                let states = self.groups.entry(key).or_insert_with(|| {
717                    self.aggregates
718                        .iter()
719                        .map(|agg| {
720                            AggregateState::new(
721                                agg.function,
722                                agg.distinct,
723                                agg.percentile,
724                                agg.separator.as_deref(),
725                            )
726                        })
727                        .collect()
728                });
729
730                // Update each aggregate
731                for (i, agg) in self.aggregates.iter().enumerate() {
732                    // Binary set functions: read two column values
733                    if agg.column2.is_some() {
734                        let y_val = agg
735                            .column
736                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
737                        let x_val = agg
738                            .column2
739                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
740                        states[i].update_bivariate(y_val, x_val);
741                        continue;
742                    }
743
744                    let value = match (agg.function, agg.distinct) {
745                        // COUNT(*) without DISTINCT doesn't need a value
746                        (AggregateFunction::Count, false) => None,
747                        // COUNT DISTINCT needs the actual value to track unique values
748                        (AggregateFunction::Count, true) => agg
749                            .column
750                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
751                        _ => agg
752                            .column
753                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
754                    };
755
756                    // For COUNT without DISTINCT, always update. For others, skip nulls.
757                    match (agg.function, agg.distinct) {
758                        (AggregateFunction::Count, false) => states[i].update(None),
759                        (AggregateFunction::Count, true) => {
760                            // COUNT DISTINCT needs the value to track unique values
761                            if value.is_some() && !matches!(value, Some(Value::Null)) {
762                                states[i].update(value);
763                            }
764                        }
765                        (AggregateFunction::CountNonNull, _) => {
766                            if value.is_some() && !matches!(value, Some(Value::Null)) {
767                                states[i].update(value);
768                            }
769                        }
770                        _ => {
771                            if value.is_some() && !matches!(value, Some(Value::Null)) {
772                                states[i].update(value);
773                            }
774                        }
775                    }
776                }
777            }
778        }
779
780        self.aggregation_complete = true;
781
782        // Convert to results iterator (IndexMap::drain takes a range)
783        let results: Vec<_> = self.groups.drain(..).collect();
784        self.results = Some(results.into_iter());
785
786        Ok(())
787    }
788}
789
790impl Operator for HashAggregateOperator {
791    fn next(&mut self) -> OperatorResult {
792        // Perform aggregation if not done
793        if !self.aggregation_complete {
794            self.aggregate()?;
795        }
796
797        // Special case: no groups (global aggregation with no data)
798        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
799            // For global aggregation (no GROUP BY), return one row with initial values
800            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
801
802            for agg in &self.aggregates {
803                let state = AggregateState::new(
804                    agg.function,
805                    agg.distinct,
806                    agg.percentile,
807                    agg.separator.as_deref(),
808                );
809                let value = state.finalize();
810                if let Some(col) = builder.column_mut(self.group_columns.len()) {
811                    col.push_value(value);
812                }
813            }
814            builder.advance_row();
815
816            self.results = Some(Vec::new().into_iter()); // Mark as done
817            return Ok(Some(builder.finish()));
818        }
819
820        let Some(results) = &mut self.results else {
821            return Ok(None);
822        };
823
824        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
825
826        for (key, states) in results.by_ref() {
827            // Output group key columns
828            let key_values = key.to_values();
829            for (i, value) in key_values.into_iter().enumerate() {
830                if let Some(col) = builder.column_mut(i) {
831                    col.push_value(value);
832                }
833            }
834
835            // Output aggregate results
836            for (i, state) in states.iter().enumerate() {
837                let col_idx = self.group_columns.len() + i;
838                if let Some(col) = builder.column_mut(col_idx) {
839                    col.push_value(state.finalize());
840                }
841            }
842
843            builder.advance_row();
844
845            if builder.is_full() {
846                return Ok(Some(builder.finish()));
847            }
848        }
849
850        if builder.row_count() > 0 {
851            Ok(Some(builder.finish()))
852        } else {
853            Ok(None)
854        }
855    }
856
857    fn reset(&mut self) {
858        self.child.reset();
859        self.groups.clear();
860        self.aggregation_complete = false;
861        self.results = None;
862    }
863
864    fn name(&self) -> &'static str {
865        "HashAggregate"
866    }
867}
868
/// Simple (non-grouping) aggregate operator for global aggregations.
///
/// Used when there's no GROUP BY clause - aggregates all input into a single row.
/// The first `next()` call consumes the whole child and emits exactly one row
/// (one column per aggregate, in `aggregates` order), even for empty input;
/// later calls return `None` until `reset()` is called.
pub struct SimpleAggregateOperator {
    /// Child operator supplying the input rows.
    child: Box<dyn Operator>,
    /// Aggregation expressions; one output column is produced per entry.
    aggregates: Vec<AggregateExpr>,
    /// Output schema used to build the single result chunk.
    output_schema: Vec<LogicalType>,
    /// Running aggregate states, index-aligned with `aggregates`.
    states: Vec<AggregateState>,
    /// Whether the single result row has already been emitted.
    done: bool,
}
884
885impl SimpleAggregateOperator {
886    /// Creates a new simple aggregate operator.
887    pub fn new(
888        child: Box<dyn Operator>,
889        aggregates: Vec<AggregateExpr>,
890        output_schema: Vec<LogicalType>,
891    ) -> Self {
892        let states = aggregates
893            .iter()
894            .map(|agg| {
895                AggregateState::new(
896                    agg.function,
897                    agg.distinct,
898                    agg.percentile,
899                    agg.separator.as_deref(),
900                )
901            })
902            .collect();
903
904        Self {
905            child,
906            aggregates,
907            output_schema,
908            states,
909            done: false,
910        }
911    }
912}
913
914impl Operator for SimpleAggregateOperator {
915    fn next(&mut self) -> OperatorResult {
916        if self.done {
917            return Ok(None);
918        }
919
920        // Process all input
921        while let Some(chunk) = self.child.next()? {
922            for row in chunk.selected_indices() {
923                for (i, agg) in self.aggregates.iter().enumerate() {
924                    // Binary set functions: read two column values
925                    if agg.column2.is_some() {
926                        let y_val = agg
927                            .column
928                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
929                        let x_val = agg
930                            .column2
931                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
932                        self.states[i].update_bivariate(y_val, x_val);
933                        continue;
934                    }
935
936                    let value = match (agg.function, agg.distinct) {
937                        // COUNT(*) without DISTINCT doesn't need a value
938                        (AggregateFunction::Count, false) => None,
939                        // COUNT DISTINCT needs the actual value to track unique values
940                        (AggregateFunction::Count, true) => agg
941                            .column
942                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
943                        _ => agg
944                            .column
945                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
946                    };
947
948                    match (agg.function, agg.distinct) {
949                        (AggregateFunction::Count, false) => self.states[i].update(None),
950                        (AggregateFunction::Count, true) => {
951                            // COUNT DISTINCT needs the value to track unique values
952                            if value.is_some() && !matches!(value, Some(Value::Null)) {
953                                self.states[i].update(value);
954                            }
955                        }
956                        (AggregateFunction::CountNonNull, _) => {
957                            if value.is_some() && !matches!(value, Some(Value::Null)) {
958                                self.states[i].update(value);
959                            }
960                        }
961                        _ => {
962                            if value.is_some() && !matches!(value, Some(Value::Null)) {
963                                self.states[i].update(value);
964                            }
965                        }
966                    }
967                }
968            }
969        }
970
971        // Output single result row
972        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
973
974        for (i, state) in self.states.iter().enumerate() {
975            if let Some(col) = builder.column_mut(i) {
976                col.push_value(state.finalize());
977            }
978        }
979        builder.advance_row();
980
981        self.done = true;
982        Ok(Some(builder.finish()))
983    }
984
985    fn reset(&mut self) {
986        self.child.reset();
987        self.states = self
988            .aggregates
989            .iter()
990            .map(|agg| {
991                AggregateState::new(
992                    agg.function,
993                    agg.distinct,
994                    agg.percentile,
995                    agg.separator.as_deref(),
996                )
997            })
998            .collect();
999        self.done = false;
1000    }
1001
1002    fn name(&self) -> &'static str {
1003        "SimpleAggregate"
1004    }
1005}
1006
1007#[cfg(test)]
1008mod tests {
1009    use super::*;
1010    use crate::execution::chunk::DataChunkBuilder;
1011
1012    struct MockOperator {
1013        chunks: Vec<DataChunk>,
1014        position: usize,
1015    }
1016
1017    impl MockOperator {
1018        fn new(chunks: Vec<DataChunk>) -> Self {
1019            Self {
1020                chunks,
1021                position: 0,
1022            }
1023        }
1024    }
1025
1026    impl Operator for MockOperator {
1027        fn next(&mut self) -> OperatorResult {
1028            if self.position < self.chunks.len() {
1029                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1030                self.position += 1;
1031                Ok(Some(chunk))
1032            } else {
1033                Ok(None)
1034            }
1035        }
1036
1037        fn reset(&mut self) {
1038            self.position = 0;
1039        }
1040
1041        fn name(&self) -> &'static str {
1042            "Mock"
1043        }
1044    }
1045
1046    fn create_test_chunk() -> DataChunk {
1047        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
1048        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1049
1050        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1051        for (group, value) in data {
1052            builder.column_mut(0).unwrap().push_int64(group);
1053            builder.column_mut(1).unwrap().push_int64(value);
1054            builder.advance_row();
1055        }
1056
1057        builder.finish()
1058    }
1059
1060    #[test]
1061    fn test_simple_count() {
1062        let mock = MockOperator::new(vec![create_test_chunk()]);
1063
1064        let mut agg = SimpleAggregateOperator::new(
1065            Box::new(mock),
1066            vec![AggregateExpr::count_star()],
1067            vec![LogicalType::Int64],
1068        );
1069
1070        let result = agg.next().unwrap().unwrap();
1071        assert_eq!(result.row_count(), 1);
1072        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1073
1074        // Should be done
1075        assert!(agg.next().unwrap().is_none());
1076    }
1077
1078    #[test]
1079    fn test_simple_sum() {
1080        let mock = MockOperator::new(vec![create_test_chunk()]);
1081
1082        let mut agg = SimpleAggregateOperator::new(
1083            Box::new(mock),
1084            vec![AggregateExpr::sum(1)], // Sum of column 1
1085            vec![LogicalType::Int64],
1086        );
1087
1088        let result = agg.next().unwrap().unwrap();
1089        assert_eq!(result.row_count(), 1);
1090        // Sum: 10 + 20 + 30 + 40 + 50 = 150
1091        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1092    }
1093
1094    #[test]
1095    fn test_simple_avg() {
1096        let mock = MockOperator::new(vec![create_test_chunk()]);
1097
1098        let mut agg = SimpleAggregateOperator::new(
1099            Box::new(mock),
1100            vec![AggregateExpr::avg(1)],
1101            vec![LogicalType::Float64],
1102        );
1103
1104        let result = agg.next().unwrap().unwrap();
1105        assert_eq!(result.row_count(), 1);
1106        // Avg: 150 / 5 = 30.0
1107        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1108        assert!((avg - 30.0).abs() < 0.001);
1109    }
1110
1111    #[test]
1112    fn test_simple_min_max() {
1113        let mock = MockOperator::new(vec![create_test_chunk()]);
1114
1115        let mut agg = SimpleAggregateOperator::new(
1116            Box::new(mock),
1117            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1118            vec![LogicalType::Int64, LogicalType::Int64],
1119        );
1120
1121        let result = agg.next().unwrap().unwrap();
1122        assert_eq!(result.row_count(), 1);
1123        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
1124        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
1125    }
1126
1127    #[test]
1128    fn test_sum_with_string_values() {
1129        // Test SUM with string values (like RDF stores numeric literals)
1130        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1131        builder.column_mut(0).unwrap().push_string("30");
1132        builder.advance_row();
1133        builder.column_mut(0).unwrap().push_string("25");
1134        builder.advance_row();
1135        builder.column_mut(0).unwrap().push_string("35");
1136        builder.advance_row();
1137        let chunk = builder.finish();
1138
1139        let mock = MockOperator::new(vec![chunk]);
1140        let mut agg = SimpleAggregateOperator::new(
1141            Box::new(mock),
1142            vec![AggregateExpr::sum(0)],
1143            vec![LogicalType::Float64],
1144        );
1145
1146        let result = agg.next().unwrap().unwrap();
1147        assert_eq!(result.row_count(), 1);
1148        // Should parse strings and sum: 30 + 25 + 35 = 90
1149        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1150        assert!(
1151            (sum_val - 90.0).abs() < 0.001,
1152            "Expected 90.0, got {}",
1153            sum_val
1154        );
1155    }
1156
1157    #[test]
1158    fn test_grouped_aggregation() {
1159        let mock = MockOperator::new(vec![create_test_chunk()]);
1160
1161        // GROUP BY column 0, SUM(column 1)
1162        let mut agg = HashAggregateOperator::new(
1163            Box::new(mock),
1164            vec![0],                     // Group by column 0
1165            vec![AggregateExpr::sum(1)], // Sum of column 1
1166            vec![LogicalType::Int64, LogicalType::Int64],
1167        );
1168
1169        let mut results: Vec<(i64, i64)> = Vec::new();
1170        while let Some(chunk) = agg.next().unwrap() {
1171            for row in chunk.selected_indices() {
1172                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1173                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1174                results.push((group, sum));
1175            }
1176        }
1177
1178        results.sort_by_key(|(g, _)| *g);
1179        assert_eq!(results.len(), 2);
1180        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
1181        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
1182    }
1183
1184    #[test]
1185    fn test_grouped_count() {
1186        let mock = MockOperator::new(vec![create_test_chunk()]);
1187
1188        // GROUP BY column 0, COUNT(*)
1189        let mut agg = HashAggregateOperator::new(
1190            Box::new(mock),
1191            vec![0],
1192            vec![AggregateExpr::count_star()],
1193            vec![LogicalType::Int64, LogicalType::Int64],
1194        );
1195
1196        let mut results: Vec<(i64, i64)> = Vec::new();
1197        while let Some(chunk) = agg.next().unwrap() {
1198            for row in chunk.selected_indices() {
1199                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1200                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1201                results.push((group, count));
1202            }
1203        }
1204
1205        results.sort_by_key(|(g, _)| *g);
1206        assert_eq!(results.len(), 2);
1207        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
1208        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
1209    }
1210
1211    #[test]
1212    fn test_multiple_aggregates() {
1213        let mock = MockOperator::new(vec![create_test_chunk()]);
1214
1215        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
1216        let mut agg = HashAggregateOperator::new(
1217            Box::new(mock),
1218            vec![0],
1219            vec![
1220                AggregateExpr::count_star(),
1221                AggregateExpr::sum(1),
1222                AggregateExpr::avg(1),
1223            ],
1224            vec![
1225                LogicalType::Int64,   // Group key
1226                LogicalType::Int64,   // COUNT
1227                LogicalType::Int64,   // SUM
1228                LogicalType::Float64, // AVG
1229            ],
1230        );
1231
1232        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1233        while let Some(chunk) = agg.next().unwrap() {
1234            for row in chunk.selected_indices() {
1235                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1236                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1237                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1238                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1239                results.push((group, count, sum, avg));
1240            }
1241        }
1242
1243        results.sort_by_key(|(g, _, _, _)| *g);
1244        assert_eq!(results.len(), 2);
1245
1246        // Group 1: COUNT=2, SUM=30, AVG=15.0
1247        assert_eq!(results[0].0, 1);
1248        assert_eq!(results[0].1, 2);
1249        assert_eq!(results[0].2, 30);
1250        assert!((results[0].3 - 15.0).abs() < 0.001);
1251
1252        // Group 2: COUNT=3, SUM=120, AVG=40.0
1253        assert_eq!(results[1].0, 2);
1254        assert_eq!(results[1].1, 3);
1255        assert_eq!(results[1].2, 120);
1256        assert!((results[1].3 - 40.0).abs() < 0.001);
1257    }
1258
1259    fn create_test_chunk_with_duplicates() -> DataChunk {
1260        // Create data with duplicate values in column 1
1261        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
1262        // GROUP 1: values [10, 10, 20] -> distinct count = 2
1263        // GROUP 2: values [30, 30, 30] -> distinct count = 1
1264        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1265
1266        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1267        for (group, value) in data {
1268            builder.column_mut(0).unwrap().push_int64(group);
1269            builder.column_mut(1).unwrap().push_int64(value);
1270            builder.advance_row();
1271        }
1272
1273        builder.finish()
1274    }
1275
1276    #[test]
1277    fn test_count_distinct() {
1278        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1279
1280        // COUNT(DISTINCT column 1)
1281        let mut agg = SimpleAggregateOperator::new(
1282            Box::new(mock),
1283            vec![AggregateExpr::count(1).with_distinct()],
1284            vec![LogicalType::Int64],
1285        );
1286
1287        let result = agg.next().unwrap().unwrap();
1288        assert_eq!(result.row_count(), 1);
1289        // Total distinct values: 10, 20, 30 = 3 distinct values
1290        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1291    }
1292
1293    #[test]
1294    fn test_grouped_count_distinct() {
1295        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1296
1297        // GROUP BY column 0, COUNT(DISTINCT column 1)
1298        let mut agg = HashAggregateOperator::new(
1299            Box::new(mock),
1300            vec![0],
1301            vec![AggregateExpr::count(1).with_distinct()],
1302            vec![LogicalType::Int64, LogicalType::Int64],
1303        );
1304
1305        let mut results: Vec<(i64, i64)> = Vec::new();
1306        while let Some(chunk) = agg.next().unwrap() {
1307            for row in chunk.selected_indices() {
1308                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1309                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1310                results.push((group, count));
1311            }
1312        }
1313
1314        results.sort_by_key(|(g, _)| *g);
1315        assert_eq!(results.len(), 2);
1316        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
1317        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
1318    }
1319
1320    #[test]
1321    fn test_sum_distinct() {
1322        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1323
1324        // SUM(DISTINCT column 1)
1325        let mut agg = SimpleAggregateOperator::new(
1326            Box::new(mock),
1327            vec![AggregateExpr::sum(1).with_distinct()],
1328            vec![LogicalType::Int64],
1329        );
1330
1331        let result = agg.next().unwrap().unwrap();
1332        assert_eq!(result.row_count(), 1);
1333        // Sum of distinct values: 10 + 20 + 30 = 60
1334        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1335    }
1336
1337    #[test]
1338    fn test_avg_distinct() {
1339        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1340
1341        // AVG(DISTINCT column 1)
1342        let mut agg = SimpleAggregateOperator::new(
1343            Box::new(mock),
1344            vec![AggregateExpr::avg(1).with_distinct()],
1345            vec![LogicalType::Float64],
1346        );
1347
1348        let result = agg.next().unwrap().unwrap();
1349        assert_eq!(result.row_count(), 1);
1350        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1351        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1352        assert!((avg - 20.0).abs() < 0.001);
1353    }
1354
1355    fn create_statistical_test_chunk() -> DataChunk {
1356        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1357        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1358        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1359
1360        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1361            builder.column_mut(0).unwrap().push_int64(value);
1362            builder.advance_row();
1363        }
1364
1365        builder.finish()
1366    }
1367
1368    #[test]
1369    fn test_stdev_sample() {
1370        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1371
1372        let mut agg = SimpleAggregateOperator::new(
1373            Box::new(mock),
1374            vec![AggregateExpr::stdev(0)],
1375            vec![LogicalType::Float64],
1376        );
1377
1378        let result = agg.next().unwrap().unwrap();
1379        assert_eq!(result.row_count(), 1);
1380        // Sample standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1381        // Mean = 5.0, Variance = 32/7 = 4.571, StdDev = 2.138
1382        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1383        assert!((stdev - 2.138).abs() < 0.01);
1384    }
1385
1386    #[test]
1387    fn test_stdev_population() {
1388        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1389
1390        let mut agg = SimpleAggregateOperator::new(
1391            Box::new(mock),
1392            vec![AggregateExpr::stdev_pop(0)],
1393            vec![LogicalType::Float64],
1394        );
1395
1396        let result = agg.next().unwrap().unwrap();
1397        assert_eq!(result.row_count(), 1);
1398        // Population standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1399        // Mean = 5.0, Variance = 32/8 = 4.0, StdDev = 2.0
1400        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1401        assert!((stdev - 2.0).abs() < 0.01);
1402    }
1403
1404    #[test]
1405    fn test_percentile_disc() {
1406        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1407
1408        // Median (50th percentile discrete)
1409        let mut agg = SimpleAggregateOperator::new(
1410            Box::new(mock),
1411            vec![AggregateExpr::percentile_disc(0, 0.5)],
1412            vec![LogicalType::Float64],
1413        );
1414
1415        let result = agg.next().unwrap().unwrap();
1416        assert_eq!(result.row_count(), 1);
1417        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], index = floor(0.5 * 7) = 3, value = 4
1418        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1419        assert!((percentile - 4.0).abs() < 0.01);
1420    }
1421
1422    #[test]
1423    fn test_percentile_cont() {
1424        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1425
1426        // Median (50th percentile continuous)
1427        let mut agg = SimpleAggregateOperator::new(
1428            Box::new(mock),
1429            vec![AggregateExpr::percentile_cont(0, 0.5)],
1430            vec![LogicalType::Float64],
1431        );
1432
1433        let result = agg.next().unwrap().unwrap();
1434        assert_eq!(result.row_count(), 1);
1435        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], rank = 0.5 * 7 = 3.5
1436        // Interpolate between index 3 (4) and index 4 (5): 4 + 0.5 * (5 - 4) = 4.5
1437        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1438        assert!((percentile - 4.5).abs() < 0.01);
1439    }
1440
1441    #[test]
1442    fn test_percentile_extremes() {
1443        // Test 0th and 100th percentiles
1444        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1445
1446        let mut agg = SimpleAggregateOperator::new(
1447            Box::new(mock),
1448            vec![
1449                AggregateExpr::percentile_disc(0, 0.0),
1450                AggregateExpr::percentile_disc(0, 1.0),
1451            ],
1452            vec![LogicalType::Float64, LogicalType::Float64],
1453        );
1454
1455        let result = agg.next().unwrap().unwrap();
1456        assert_eq!(result.row_count(), 1);
1457        // 0th percentile = minimum = 2
1458        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1459        assert!((p0 - 2.0).abs() < 0.01);
1460        // 100th percentile = maximum = 9
1461        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1462        assert!((p100 - 9.0).abs() < 0.01);
1463    }
1464
1465    #[test]
1466    fn test_stdev_single_value() {
1467        // Single value should return null for sample stdev
1468        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1469        builder.column_mut(0).unwrap().push_int64(42);
1470        builder.advance_row();
1471        let chunk = builder.finish();
1472
1473        let mock = MockOperator::new(vec![chunk]);
1474
1475        let mut agg = SimpleAggregateOperator::new(
1476            Box::new(mock),
1477            vec![AggregateExpr::stdev(0)],
1478            vec![LogicalType::Float64],
1479        );
1480
1481        let result = agg.next().unwrap().unwrap();
1482        assert_eq!(result.row_count(), 1);
1483        // Sample stdev of single value is undefined (null)
1484        assert!(matches!(
1485            result.column(0).unwrap().get_value(0),
1486            Some(Value::Null)
1487        ));
1488    }
1489
1490    #[test]
1491    fn test_first_and_last() {
1492        let mock = MockOperator::new(vec![create_test_chunk()]);
1493
1494        let mut agg = SimpleAggregateOperator::new(
1495            Box::new(mock),
1496            vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1497            vec![LogicalType::Int64, LogicalType::Int64],
1498        );
1499
1500        let result = agg.next().unwrap().unwrap();
1501        assert_eq!(result.row_count(), 1);
1502        // First: 10, Last: 50
1503        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1504        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1505    }
1506
1507    #[test]
1508    fn test_collect() {
1509        let mock = MockOperator::new(vec![create_test_chunk()]);
1510
1511        let mut agg = SimpleAggregateOperator::new(
1512            Box::new(mock),
1513            vec![AggregateExpr::collect(1)],
1514            vec![LogicalType::Any],
1515        );
1516
1517        let result = agg.next().unwrap().unwrap();
1518        let val = result.column(0).unwrap().get_value(0).unwrap();
1519        if let Value::List(items) = val {
1520            assert_eq!(items.len(), 5);
1521        } else {
1522            panic!("Expected List value");
1523        }
1524    }
1525
1526    #[test]
1527    fn test_collect_distinct() {
1528        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1529
1530        let mut agg = SimpleAggregateOperator::new(
1531            Box::new(mock),
1532            vec![AggregateExpr::collect(1).with_distinct()],
1533            vec![LogicalType::Any],
1534        );
1535
1536        let result = agg.next().unwrap().unwrap();
1537        let val = result.column(0).unwrap().get_value(0).unwrap();
1538        if let Value::List(items) = val {
1539            // [10, 10, 20, 30, 30, 30] -> distinct: [10, 20, 30]
1540            assert_eq!(items.len(), 3);
1541        } else {
1542            panic!("Expected List value");
1543        }
1544    }
1545
1546    #[test]
1547    fn test_group_concat() {
1548        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1549        for s in ["hello", "world", "foo"] {
1550            builder.column_mut(0).unwrap().push_string(s);
1551            builder.advance_row();
1552        }
1553        let chunk = builder.finish();
1554        let mock = MockOperator::new(vec![chunk]);
1555
1556        let agg_expr = AggregateExpr {
1557            function: AggregateFunction::GroupConcat,
1558            column: Some(0),
1559            column2: None,
1560            distinct: false,
1561            alias: None,
1562            percentile: None,
1563            separator: None,
1564        };
1565
1566        let mut agg =
1567            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1568
1569        let result = agg.next().unwrap().unwrap();
1570        let val = result.column(0).unwrap().get_value(0).unwrap();
1571        assert_eq!(val, Value::String("hello world foo".into()));
1572    }
1573
1574    #[test]
1575    fn test_sample() {
1576        let mock = MockOperator::new(vec![create_test_chunk()]);
1577
1578        let agg_expr = AggregateExpr {
1579            function: AggregateFunction::Sample,
1580            column: Some(1),
1581            column2: None,
1582            distinct: false,
1583            alias: None,
1584            percentile: None,
1585            separator: None,
1586        };
1587
1588        let mut agg =
1589            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1590
1591        let result = agg.next().unwrap().unwrap();
1592        // Sample should return the first non-null value (10)
1593        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1594    }
1595
1596    #[test]
1597    fn test_variance_sample() {
1598        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1599
1600        let agg_expr = AggregateExpr {
1601            function: AggregateFunction::Variance,
1602            column: Some(0),
1603            column2: None,
1604            distinct: false,
1605            alias: None,
1606            percentile: None,
1607            separator: None,
1608        };
1609
1610        let mut agg = SimpleAggregateOperator::new(
1611            Box::new(mock),
1612            vec![agg_expr],
1613            vec![LogicalType::Float64],
1614        );
1615
1616        let result = agg.next().unwrap().unwrap();
1617        // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: M2/(n-1) = 32/7 = 4.571
1618        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1619        assert!((variance - 32.0 / 7.0).abs() < 0.01);
1620    }
1621
1622    #[test]
1623    fn test_variance_population() {
1624        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1625
1626        let agg_expr = AggregateExpr {
1627            function: AggregateFunction::VariancePop,
1628            column: Some(0),
1629            column2: None,
1630            distinct: false,
1631            alias: None,
1632            percentile: None,
1633            separator: None,
1634        };
1635
1636        let mut agg = SimpleAggregateOperator::new(
1637            Box::new(mock),
1638            vec![agg_expr],
1639            vec![LogicalType::Float64],
1640        );
1641
1642        let result = agg.next().unwrap().unwrap();
1643        // Population variance: M2/n = 32/8 = 4.0
1644        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1645        assert!((variance - 4.0).abs() < 0.01);
1646    }
1647
1648    #[test]
1649    fn test_variance_single_value() {
1650        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1651        builder.column_mut(0).unwrap().push_int64(42);
1652        builder.advance_row();
1653        let chunk = builder.finish();
1654        let mock = MockOperator::new(vec![chunk]);
1655
1656        let agg_expr = AggregateExpr {
1657            function: AggregateFunction::Variance,
1658            column: Some(0),
1659            column2: None,
1660            distinct: false,
1661            alias: None,
1662            percentile: None,
1663            separator: None,
1664        };
1665
1666        let mut agg = SimpleAggregateOperator::new(
1667            Box::new(mock),
1668            vec![agg_expr],
1669            vec![LogicalType::Float64],
1670        );
1671
1672        let result = agg.next().unwrap().unwrap();
1673        // Sample variance of single value is undefined (null)
1674        assert!(matches!(
1675            result.column(0).unwrap().get_value(0),
1676            Some(Value::Null)
1677        ));
1678    }
1679
1680    #[test]
1681    fn test_empty_aggregation() {
1682        // No input rows: COUNT should be 0, SUM/AVG/MIN/MAX should be NULL
1683        // (ISO/IEC 39075 Section 20.9)
1684        let mock = MockOperator::new(vec![]);
1685
1686        let mut agg = SimpleAggregateOperator::new(
1687            Box::new(mock),
1688            vec![
1689                AggregateExpr::count_star(),
1690                AggregateExpr::sum(0),
1691                AggregateExpr::avg(0),
1692                AggregateExpr::min(0),
1693                AggregateExpr::max(0),
1694            ],
1695            vec![
1696                LogicalType::Int64,
1697                LogicalType::Int64,
1698                LogicalType::Float64,
1699                LogicalType::Int64,
1700                LogicalType::Int64,
1701            ],
1702        );
1703
1704        let result = agg.next().unwrap().unwrap();
1705        assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); // COUNT
1706        assert!(matches!(
1707            result.column(1).unwrap().get_value(0),
1708            Some(Value::Null)
1709        )); // SUM
1710        assert!(matches!(
1711            result.column(2).unwrap().get_value(0),
1712            Some(Value::Null)
1713        )); // AVG
1714        assert!(matches!(
1715            result.column(3).unwrap().get_value(0),
1716            Some(Value::Null)
1717        )); // MIN
1718        assert!(matches!(
1719            result.column(4).unwrap().get_value(0),
1720            Some(Value::Null)
1721        )); // MAX
1722    }
1723
1724    #[test]
1725    fn test_stdev_pop_single_value() {
1726        // Single value should return 0 for population stdev
1727        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1728        builder.column_mut(0).unwrap().push_int64(42);
1729        builder.advance_row();
1730        let chunk = builder.finish();
1731
1732        let mock = MockOperator::new(vec![chunk]);
1733
1734        let mut agg = SimpleAggregateOperator::new(
1735            Box::new(mock),
1736            vec![AggregateExpr::stdev_pop(0)],
1737            vec![LogicalType::Float64],
1738        );
1739
1740        let result = agg.next().unwrap().unwrap();
1741        assert_eq!(result.row_count(), 1);
1742        // Population stdev of single value is 0
1743        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1744        assert!((stdev - 0.0).abs() < 0.01);
1745    }
1746}