Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12
13use grafeo_common::types::{LogicalType, Value};
14
15use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
16use super::{Operator, OperatorError, OperatorResult};
17use crate::execution::DataChunk;
18use crate::execution::chunk::DataChunkBuilder;
19
20/// State for a single aggregation computation.
21#[derive(Debug, Clone)]
22pub(crate) enum AggregateState {
23    /// Count state.
24    Count(i64),
25    /// Count distinct state (count, seen values).
26    CountDistinct(i64, HashSet<HashableValue>),
27    /// Sum state (integer).
28    SumInt(i64),
29    /// Sum distinct state (integer, seen values).
30    SumIntDistinct(i64, HashSet<HashableValue>),
31    /// Sum state (float).
32    SumFloat(f64),
33    /// Sum distinct state (float, seen values).
34    SumFloatDistinct(f64, HashSet<HashableValue>),
35    /// Average state (sum, count).
36    Avg(f64, i64),
37    /// Average distinct state (sum, count, seen values).
38    AvgDistinct(f64, i64, HashSet<HashableValue>),
39    /// Min state.
40    Min(Option<Value>),
41    /// Max state.
42    Max(Option<Value>),
43    /// First state.
44    First(Option<Value>),
45    /// Last state.
46    Last(Option<Value>),
47    /// Collect state.
48    Collect(Vec<Value>),
49    /// Collect distinct state (values, seen).
50    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
51    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
52    StdDev { count: i64, mean: f64, m2: f64 },
53    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
54    StdDevPop { count: i64, mean: f64, m2: f64 },
55    /// Discrete percentile state (values, percentile).
56    PercentileDisc { values: Vec<f64>, percentile: f64 },
57    /// Continuous percentile state (values, percentile).
58    PercentileCont { values: Vec<f64>, percentile: f64 },
59    /// GROUP_CONCAT state (collected string values, separator defaults to space).
60    GroupConcat(Vec<String>),
61    /// GROUP_CONCAT distinct state (collected string values, seen).
62    GroupConcatDistinct(Vec<String>, HashSet<HashableValue>),
63    /// SAMPLE state (first non-null value encountered).
64    Sample(Option<Value>),
65    /// Sample variance state using Welford's algorithm (count, mean, M2).
66    Variance { count: i64, mean: f64, m2: f64 },
67    /// Population variance state using Welford's algorithm (count, mean, M2).
68    VariancePop { count: i64, mean: f64, m2: f64 },
69    /// Two-variable online statistics (Welford generalization for covariance/regression).
70    Bivariate {
71        /// Which binary set function this state will finalize to.
72        kind: AggregateFunction,
73        count: i64,
74        mean_x: f64,
75        mean_y: f64,
76        m2_x: f64,
77        m2_y: f64,
78        c_xy: f64,
79    },
80}
81
82impl AggregateState {
83    /// Creates initial state for an aggregation function.
84    pub(crate) fn new(
85        function: AggregateFunction,
86        distinct: bool,
87        percentile: Option<f64>,
88    ) -> Self {
89        match (function, distinct) {
90            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
91                AggregateState::Count(0)
92            }
93            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
94                AggregateState::CountDistinct(0, HashSet::new())
95            }
96            (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
97            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
98            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
99            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
100            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
101            (AggregateFunction::Max, _) => AggregateState::Max(None),
102            (AggregateFunction::First, _) => AggregateState::First(None),
103            (AggregateFunction::Last, _) => AggregateState::Last(None),
104            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
105            (AggregateFunction::Collect, true) => {
106                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
107            }
108            // Statistical functions (Welford's algorithm for online computation)
109            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
110                count: 0,
111                mean: 0.0,
112                m2: 0.0,
113            },
114            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
115                count: 0,
116                mean: 0.0,
117                m2: 0.0,
118            },
119            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
120                values: Vec::new(),
121                percentile: percentile.unwrap_or(0.5),
122            },
123            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
124                values: Vec::new(),
125                percentile: percentile.unwrap_or(0.5),
126            },
127            (AggregateFunction::GroupConcat, false) => AggregateState::GroupConcat(Vec::new()),
128            (AggregateFunction::GroupConcat, true) => {
129                AggregateState::GroupConcatDistinct(Vec::new(), HashSet::new())
130            }
131            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
132            // Binary set functions (all share the same Bivariate state)
133            (
134                AggregateFunction::CovarSamp
135                | AggregateFunction::CovarPop
136                | AggregateFunction::Corr
137                | AggregateFunction::RegrSlope
138                | AggregateFunction::RegrIntercept
139                | AggregateFunction::RegrR2
140                | AggregateFunction::RegrCount
141                | AggregateFunction::RegrSxx
142                | AggregateFunction::RegrSyy
143                | AggregateFunction::RegrSxy
144                | AggregateFunction::RegrAvgx
145                | AggregateFunction::RegrAvgy,
146                _,
147            ) => AggregateState::Bivariate {
148                kind: function,
149                count: 0,
150                mean_x: 0.0,
151                mean_y: 0.0,
152                m2_x: 0.0,
153                m2_y: 0.0,
154                c_xy: 0.0,
155            },
156            (AggregateFunction::Variance, _) => AggregateState::Variance {
157                count: 0,
158                mean: 0.0,
159                m2: 0.0,
160            },
161            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
162                count: 0,
163                mean: 0.0,
164                m2: 0.0,
165            },
166        }
167    }
168
169    /// Updates the state with a new value.
170    pub(crate) fn update(&mut self, value: Option<Value>) {
171        match self {
172            AggregateState::Count(count) => {
173                *count += 1;
174            }
175            AggregateState::CountDistinct(count, seen) => {
176                if let Some(ref v) = value {
177                    let hashable = HashableValue::from(v);
178                    if seen.insert(hashable) {
179                        *count += 1;
180                    }
181                }
182            }
183            AggregateState::SumInt(sum) => {
184                if let Some(Value::Int64(v)) = value {
185                    *sum += v;
186                } else if let Some(Value::Float64(v)) = value {
187                    // Convert to float sum
188                    *self = AggregateState::SumFloat(*sum as f64 + v);
189                } else if let Some(ref v) = value {
190                    // RDF stores numeric literals as strings - try to parse
191                    if let Some(num) = value_to_f64(v) {
192                        *self = AggregateState::SumFloat(*sum as f64 + num);
193                    }
194                }
195            }
196            AggregateState::SumIntDistinct(sum, seen) => {
197                if let Some(ref v) = value {
198                    let hashable = HashableValue::from(v);
199                    if seen.insert(hashable) {
200                        if let Value::Int64(i) = v {
201                            *sum += i;
202                        } else if let Value::Float64(f) = v {
203                            // Convert to float distinct: move the seen set instead of cloning
204                            let moved_seen = std::mem::take(seen);
205                            *self = AggregateState::SumFloatDistinct(*sum as f64 + f, moved_seen);
206                        } else if let Some(num) = value_to_f64(v) {
207                            // RDF string-encoded numerics
208                            let moved_seen = std::mem::take(seen);
209                            *self = AggregateState::SumFloatDistinct(*sum as f64 + num, moved_seen);
210                        }
211                    }
212                }
213            }
214            AggregateState::SumFloat(sum) => {
215                if let Some(ref v) = value {
216                    // Use value_to_f64 which now handles strings
217                    if let Some(num) = value_to_f64(v) {
218                        *sum += num;
219                    }
220                }
221            }
222            AggregateState::SumFloatDistinct(sum, seen) => {
223                if let Some(ref v) = value {
224                    let hashable = HashableValue::from(v);
225                    if seen.insert(hashable)
226                        && let Some(num) = value_to_f64(v)
227                    {
228                        *sum += num;
229                    }
230                }
231            }
232            AggregateState::Avg(sum, count) => {
233                if let Some(ref v) = value
234                    && let Some(num) = value_to_f64(v)
235                {
236                    *sum += num;
237                    *count += 1;
238                }
239            }
240            AggregateState::AvgDistinct(sum, count, seen) => {
241                if let Some(ref v) = value {
242                    let hashable = HashableValue::from(v);
243                    if seen.insert(hashable)
244                        && let Some(num) = value_to_f64(v)
245                    {
246                        *sum += num;
247                        *count += 1;
248                    }
249                }
250            }
251            AggregateState::Min(min) => {
252                if let Some(v) = value {
253                    match min {
254                        None => *min = Some(v),
255                        Some(current) => {
256                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
257                                *min = Some(v);
258                            }
259                        }
260                    }
261                }
262            }
263            AggregateState::Max(max) => {
264                if let Some(v) = value {
265                    match max {
266                        None => *max = Some(v),
267                        Some(current) => {
268                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
269                                *max = Some(v);
270                            }
271                        }
272                    }
273                }
274            }
275            AggregateState::First(first) => {
276                if first.is_none() {
277                    *first = value;
278                }
279            }
280            AggregateState::Last(last) => {
281                if value.is_some() {
282                    *last = value;
283                }
284            }
285            AggregateState::Collect(list) => {
286                if let Some(v) = value {
287                    list.push(v);
288                }
289            }
290            AggregateState::CollectDistinct(list, seen) => {
291                if let Some(v) = value {
292                    let hashable = HashableValue::from(&v);
293                    if seen.insert(hashable) {
294                        list.push(v);
295                    }
296                }
297            }
298            // Statistical functions using Welford's online algorithm
299            AggregateState::StdDev { count, mean, m2 }
300            | AggregateState::StdDevPop { count, mean, m2 }
301            | AggregateState::Variance { count, mean, m2 }
302            | AggregateState::VariancePop { count, mean, m2 } => {
303                if let Some(ref v) = value
304                    && let Some(x) = value_to_f64(v)
305                {
306                    *count += 1;
307                    let delta = x - *mean;
308                    *mean += delta / *count as f64;
309                    let delta2 = x - *mean;
310                    *m2 += delta * delta2;
311                }
312            }
313            AggregateState::PercentileDisc { values, .. }
314            | AggregateState::PercentileCont { values, .. } => {
315                if let Some(ref v) = value
316                    && let Some(x) = value_to_f64(v)
317                {
318                    values.push(x);
319                }
320            }
321            AggregateState::GroupConcat(list) => {
322                if let Some(v) = value {
323                    list.push(agg_value_to_string(&v));
324                }
325            }
326            AggregateState::GroupConcatDistinct(list, seen) => {
327                if let Some(v) = value {
328                    let hashable = HashableValue::from(&v);
329                    if seen.insert(hashable) {
330                        list.push(agg_value_to_string(&v));
331                    }
332                }
333            }
334            AggregateState::Sample(sample) => {
335                if sample.is_none() {
336                    *sample = value;
337                }
338            }
339            AggregateState::Bivariate { .. } => {
340                // Bivariate functions require two values; use update_bivariate() instead.
341                // Single-value update is a no-op for bivariate state.
342            }
343        }
344    }
345
346    /// Updates a bivariate (two-variable) aggregate state with a pair of values.
347    ///
348    /// Uses the two-variable Welford online algorithm for numerically stable computation
349    /// of covariance and related statistics. Skips the update if either value is null.
350    fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
351        if let AggregateState::Bivariate {
352            count,
353            mean_x,
354            mean_y,
355            m2_x,
356            m2_y,
357            c_xy,
358            ..
359        } = self
360        {
361            // Skip if either value is null (SQL semantics: exclude non-pairs)
362            if let (Some(y), Some(x)) = (&y_val, &x_val)
363                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
364            {
365                *count += 1;
366                let n = *count as f64;
367                let dx = x_f - *mean_x;
368                let dy = y_f - *mean_y;
369                *mean_x += dx / n;
370                *mean_y += dy / n;
371                let dx2 = x_f - *mean_x; // post-update delta
372                let dy2 = y_f - *mean_y; // post-update delta
373                *m2_x += dx * dx2;
374                *m2_y += dy * dy2;
375                *c_xy += dx * dy2;
376            }
377        }
378    }
379
380    /// Finalizes the state and returns the result value.
381    pub(crate) fn finalize(&self) -> Value {
382        match self {
383            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
384                Value::Int64(*count)
385            }
386            AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
387                Value::Int64(*sum)
388            }
389            AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
390                Value::Float64(*sum)
391            }
392            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
393                if *count == 0 {
394                    Value::Null
395                } else {
396                    Value::Float64(*sum / *count as f64)
397                }
398            }
399            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
400            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
401            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
402            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
403            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
404                Value::List(list.clone().into())
405            }
406            // Sample standard deviation: sqrt(M2 / (n - 1))
407            AggregateState::StdDev { count, m2, .. } => {
408                if *count < 2 {
409                    Value::Null
410                } else {
411                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
412                }
413            }
414            // Population standard deviation: sqrt(M2 / n)
415            AggregateState::StdDevPop { count, m2, .. } => {
416                if *count == 0 {
417                    Value::Null
418                } else {
419                    Value::Float64((*m2 / *count as f64).sqrt())
420                }
421            }
422            // Sample variance: M2 / (n - 1)
423            AggregateState::Variance { count, m2, .. } => {
424                if *count < 2 {
425                    Value::Null
426                } else {
427                    Value::Float64(*m2 / (*count - 1) as f64)
428                }
429            }
430            // Population variance: M2 / n
431            AggregateState::VariancePop { count, m2, .. } => {
432                if *count == 0 {
433                    Value::Null
434                } else {
435                    Value::Float64(*m2 / *count as f64)
436                }
437            }
438            // Discrete percentile: return actual value at percentile position
439            AggregateState::PercentileDisc { values, percentile } => {
440                if values.is_empty() {
441                    Value::Null
442                } else {
443                    let mut sorted = values.clone();
444                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
445                    // Index calculation per SQL standard: floor(p * (n - 1))
446                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
447                    Value::Float64(sorted[index])
448                }
449            }
450            // Continuous percentile: interpolate between values
451            AggregateState::PercentileCont { values, percentile } => {
452                if values.is_empty() {
453                    Value::Null
454                } else {
455                    let mut sorted = values.clone();
456                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
457                    // Linear interpolation per SQL standard
458                    let rank = percentile * (sorted.len() - 1) as f64;
459                    let lower_idx = rank.floor() as usize;
460                    let upper_idx = rank.ceil() as usize;
461                    if lower_idx == upper_idx {
462                        Value::Float64(sorted[lower_idx])
463                    } else {
464                        let fraction = rank - lower_idx as f64;
465                        let result =
466                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
467                        Value::Float64(result)
468                    }
469                }
470            }
471            // GROUP_CONCAT: join strings with space separator (SPARQL default)
472            AggregateState::GroupConcat(list) | AggregateState::GroupConcatDistinct(list, _) => {
473                Value::String(list.join(" ").into())
474            }
475            // SAMPLE: return the first non-null value seen
476            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
477            // Binary set functions: dispatch on kind
478            AggregateState::Bivariate {
479                kind,
480                count,
481                mean_x,
482                mean_y,
483                m2_x,
484                m2_y,
485                c_xy,
486            } => {
487                let n = *count;
488                match kind {
489                    AggregateFunction::CovarSamp => {
490                        if n < 2 {
491                            Value::Null
492                        } else {
493                            Value::Float64(*c_xy / (n - 1) as f64)
494                        }
495                    }
496                    AggregateFunction::CovarPop => {
497                        if n == 0 {
498                            Value::Null
499                        } else {
500                            Value::Float64(*c_xy / n as f64)
501                        }
502                    }
503                    AggregateFunction::Corr => {
504                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
505                            Value::Null
506                        } else {
507                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
508                        }
509                    }
510                    AggregateFunction::RegrSlope => {
511                        if n == 0 || *m2_x == 0.0 {
512                            Value::Null
513                        } else {
514                            Value::Float64(*c_xy / *m2_x)
515                        }
516                    }
517                    AggregateFunction::RegrIntercept => {
518                        if n == 0 || *m2_x == 0.0 {
519                            Value::Null
520                        } else {
521                            let slope = *c_xy / *m2_x;
522                            Value::Float64(*mean_y - slope * *mean_x)
523                        }
524                    }
525                    AggregateFunction::RegrR2 => {
526                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
527                            Value::Null
528                        } else {
529                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
530                        }
531                    }
532                    AggregateFunction::RegrCount => Value::Int64(n),
533                    AggregateFunction::RegrSxx => {
534                        if n == 0 {
535                            Value::Null
536                        } else {
537                            Value::Float64(*m2_x)
538                        }
539                    }
540                    AggregateFunction::RegrSyy => {
541                        if n == 0 {
542                            Value::Null
543                        } else {
544                            Value::Float64(*m2_y)
545                        }
546                    }
547                    AggregateFunction::RegrSxy => {
548                        if n == 0 {
549                            Value::Null
550                        } else {
551                            Value::Float64(*c_xy)
552                        }
553                    }
554                    AggregateFunction::RegrAvgx => {
555                        if n == 0 {
556                            Value::Null
557                        } else {
558                            Value::Float64(*mean_x)
559                        }
560                    }
561                    AggregateFunction::RegrAvgy => {
562                        if n == 0 {
563                            Value::Null
564                        } else {
565                            Value::Float64(*mean_y)
566                        }
567                    }
568                    _ => Value::Null, // non-bivariate functions never reach here
569                }
570            }
571        }
572    }
573}
574
575use super::value_utils::{compare_values, value_to_f64};
576
577/// Converts a Value to its string representation for GROUP_CONCAT.
578fn agg_value_to_string(val: &Value) -> String {
579    match val {
580        Value::String(s) => s.to_string(),
581        Value::Int64(i) => i.to_string(),
582        Value::Float64(f) => f.to_string(),
583        Value::Bool(b) => b.to_string(),
584        Value::Null => String::new(),
585        other => format!("{other:?}"),
586    }
587}
588
589/// A group key for hash-based aggregation.
590#[derive(Debug, Clone, PartialEq, Eq, Hash)]
591pub struct GroupKey(Vec<GroupKeyPart>);
592
593#[derive(Debug, Clone, PartialEq, Eq, Hash)]
594enum GroupKeyPart {
595    Null,
596    Bool(bool),
597    Int64(i64),
598    String(String),
599}
600
601impl GroupKey {
602    /// Creates a group key from column values.
603    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
604        let parts: Vec<GroupKeyPart> = group_columns
605            .iter()
606            .map(|&col_idx| {
607                chunk
608                    .column(col_idx)
609                    .and_then(|col| col.get_value(row))
610                    .map_or(GroupKeyPart::Null, |v| match v {
611                        Value::Null => GroupKeyPart::Null,
612                        Value::Bool(b) => GroupKeyPart::Bool(b),
613                        Value::Int64(i) => GroupKeyPart::Int64(i),
614                        Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
615                        Value::String(s) => GroupKeyPart::String(s.to_string()),
616                        _ => GroupKeyPart::String(format!("{v:?}")),
617                    })
618            })
619            .collect();
620        GroupKey(parts)
621    }
622
623    /// Converts the group key back to values.
624    fn to_values(&self) -> Vec<Value> {
625        self.0
626            .iter()
627            .map(|part| match part {
628                GroupKeyPart::Null => Value::Null,
629                GroupKeyPart::Bool(b) => Value::Bool(*b),
630                GroupKeyPart::Int64(i) => Value::Int64(*i),
631                GroupKeyPart::String(s) => Value::String(s.clone().into()),
632            })
633            .collect()
634    }
635}
636
637/// Hash-based aggregate operator.
638///
639/// Groups input by key columns and computes aggregations for each group.
640pub struct HashAggregateOperator {
641    /// Child operator to read from.
642    child: Box<dyn Operator>,
643    /// Columns to group by.
644    group_columns: Vec<usize>,
645    /// Aggregation expressions.
646    aggregates: Vec<AggregateExpr>,
647    /// Output schema.
648    output_schema: Vec<LogicalType>,
649    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
650    groups: IndexMap<GroupKey, Vec<AggregateState>>,
651    /// Whether aggregation is complete.
652    aggregation_complete: bool,
653    /// Results iterator.
654    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
655}
656
657impl HashAggregateOperator {
658    /// Creates a new hash aggregate operator.
659    ///
660    /// # Arguments
661    /// * `child` - Child operator to read from.
662    /// * `group_columns` - Column indices to group by.
663    /// * `aggregates` - Aggregation expressions.
664    /// * `output_schema` - Schema of the output (group columns + aggregate results).
665    pub fn new(
666        child: Box<dyn Operator>,
667        group_columns: Vec<usize>,
668        aggregates: Vec<AggregateExpr>,
669        output_schema: Vec<LogicalType>,
670    ) -> Self {
671        Self {
672            child,
673            group_columns,
674            aggregates,
675            output_schema,
676            groups: IndexMap::new(),
677            aggregation_complete: false,
678            results: None,
679        }
680    }
681
682    /// Performs the aggregation.
683    fn aggregate(&mut self) -> Result<(), OperatorError> {
684        while let Some(chunk) = self.child.next()? {
685            for row in chunk.selected_indices() {
686                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
687
688                // Get or create aggregate states for this group
689                let states = self.groups.entry(key).or_insert_with(|| {
690                    self.aggregates
691                        .iter()
692                        .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
693                        .collect()
694                });
695
696                // Update each aggregate
697                for (i, agg) in self.aggregates.iter().enumerate() {
698                    // Binary set functions: read two column values
699                    if agg.column2.is_some() {
700                        let y_val = agg
701                            .column
702                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
703                        let x_val = agg
704                            .column2
705                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
706                        states[i].update_bivariate(y_val, x_val);
707                        continue;
708                    }
709
710                    let value = match (agg.function, agg.distinct) {
711                        // COUNT(*) without DISTINCT doesn't need a value
712                        (AggregateFunction::Count, false) => None,
713                        // COUNT DISTINCT needs the actual value to track unique values
714                        (AggregateFunction::Count, true) => agg
715                            .column
716                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
717                        _ => agg
718                            .column
719                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
720                    };
721
722                    // For COUNT without DISTINCT, always update. For others, skip nulls.
723                    match (agg.function, agg.distinct) {
724                        (AggregateFunction::Count, false) => states[i].update(None),
725                        (AggregateFunction::Count, true) => {
726                            // COUNT DISTINCT needs the value to track unique values
727                            if value.is_some() && !matches!(value, Some(Value::Null)) {
728                                states[i].update(value);
729                            }
730                        }
731                        (AggregateFunction::CountNonNull, _) => {
732                            if value.is_some() && !matches!(value, Some(Value::Null)) {
733                                states[i].update(value);
734                            }
735                        }
736                        _ => {
737                            if value.is_some() && !matches!(value, Some(Value::Null)) {
738                                states[i].update(value);
739                            }
740                        }
741                    }
742                }
743            }
744        }
745
746        self.aggregation_complete = true;
747
748        // Convert to results iterator (IndexMap::drain takes a range)
749        let results: Vec<_> = self.groups.drain(..).collect();
750        self.results = Some(results.into_iter());
751
752        Ok(())
753    }
754}
755
756impl Operator for HashAggregateOperator {
757    fn next(&mut self) -> OperatorResult {
758        // Perform aggregation if not done
759        if !self.aggregation_complete {
760            self.aggregate()?;
761        }
762
763        // Special case: no groups (global aggregation with no data)
764        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
765            // For global aggregation (no GROUP BY), return one row with initial values
766            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
767
768            for agg in &self.aggregates {
769                let state = AggregateState::new(agg.function, agg.distinct, agg.percentile);
770                let value = state.finalize();
771                if let Some(col) = builder.column_mut(self.group_columns.len()) {
772                    col.push_value(value);
773                }
774            }
775            builder.advance_row();
776
777            self.results = Some(Vec::new().into_iter()); // Mark as done
778            return Ok(Some(builder.finish()));
779        }
780
781        let Some(results) = &mut self.results else {
782            return Ok(None);
783        };
784
785        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
786
787        for (key, states) in results.by_ref() {
788            // Output group key columns
789            let key_values = key.to_values();
790            for (i, value) in key_values.into_iter().enumerate() {
791                if let Some(col) = builder.column_mut(i) {
792                    col.push_value(value);
793                }
794            }
795
796            // Output aggregate results
797            for (i, state) in states.iter().enumerate() {
798                let col_idx = self.group_columns.len() + i;
799                if let Some(col) = builder.column_mut(col_idx) {
800                    col.push_value(state.finalize());
801                }
802            }
803
804            builder.advance_row();
805
806            if builder.is_full() {
807                return Ok(Some(builder.finish()));
808            }
809        }
810
811        if builder.row_count() > 0 {
812            Ok(Some(builder.finish()))
813        } else {
814            Ok(None)
815        }
816    }
817
818    fn reset(&mut self) {
819        self.child.reset();
820        self.groups.clear();
821        self.aggregation_complete = false;
822        self.results = None;
823    }
824
825    fn name(&self) -> &'static str {
826        "HashAggregate"
827    }
828}
829
830/// Simple (non-grouping) aggregate operator for global aggregations.
831///
832/// Used when there's no GROUP BY clause - aggregates all input into a single row.
833pub struct SimpleAggregateOperator {
834    /// Child operator.
835    child: Box<dyn Operator>,
836    /// Aggregation expressions.
837    aggregates: Vec<AggregateExpr>,
838    /// Output schema.
839    output_schema: Vec<LogicalType>,
840    /// Aggregate states.
841    states: Vec<AggregateState>,
842    /// Whether aggregation is complete.
843    done: bool,
844}
845
846impl SimpleAggregateOperator {
847    /// Creates a new simple aggregate operator.
848    pub fn new(
849        child: Box<dyn Operator>,
850        aggregates: Vec<AggregateExpr>,
851        output_schema: Vec<LogicalType>,
852    ) -> Self {
853        let states = aggregates
854            .iter()
855            .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
856            .collect();
857
858        Self {
859            child,
860            aggregates,
861            output_schema,
862            states,
863            done: false,
864        }
865    }
866}
867
868impl Operator for SimpleAggregateOperator {
869    fn next(&mut self) -> OperatorResult {
870        if self.done {
871            return Ok(None);
872        }
873
874        // Process all input
875        while let Some(chunk) = self.child.next()? {
876            for row in chunk.selected_indices() {
877                for (i, agg) in self.aggregates.iter().enumerate() {
878                    // Binary set functions: read two column values
879                    if agg.column2.is_some() {
880                        let y_val = agg
881                            .column
882                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
883                        let x_val = agg
884                            .column2
885                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
886                        self.states[i].update_bivariate(y_val, x_val);
887                        continue;
888                    }
889
890                    let value = match (agg.function, agg.distinct) {
891                        // COUNT(*) without DISTINCT doesn't need a value
892                        (AggregateFunction::Count, false) => None,
893                        // COUNT DISTINCT needs the actual value to track unique values
894                        (AggregateFunction::Count, true) => agg
895                            .column
896                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
897                        _ => agg
898                            .column
899                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
900                    };
901
902                    match (agg.function, agg.distinct) {
903                        (AggregateFunction::Count, false) => self.states[i].update(None),
904                        (AggregateFunction::Count, true) => {
905                            // COUNT DISTINCT needs the value to track unique values
906                            if value.is_some() && !matches!(value, Some(Value::Null)) {
907                                self.states[i].update(value);
908                            }
909                        }
910                        (AggregateFunction::CountNonNull, _) => {
911                            if value.is_some() && !matches!(value, Some(Value::Null)) {
912                                self.states[i].update(value);
913                            }
914                        }
915                        _ => {
916                            if value.is_some() && !matches!(value, Some(Value::Null)) {
917                                self.states[i].update(value);
918                            }
919                        }
920                    }
921                }
922            }
923        }
924
925        // Output single result row
926        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
927
928        for (i, state) in self.states.iter().enumerate() {
929            if let Some(col) = builder.column_mut(i) {
930                col.push_value(state.finalize());
931            }
932        }
933        builder.advance_row();
934
935        self.done = true;
936        Ok(Some(builder.finish()))
937    }
938
939    fn reset(&mut self) {
940        self.child.reset();
941        self.states = self
942            .aggregates
943            .iter()
944            .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
945            .collect();
946        self.done = false;
947    }
948
949    fn name(&self) -> &'static str {
950        "SimpleAggregate"
951    }
952}
953
954#[cfg(test)]
955mod tests {
956    use super::*;
957    use crate::execution::chunk::DataChunkBuilder;
958
959    struct MockOperator {
960        chunks: Vec<DataChunk>,
961        position: usize,
962    }
963
964    impl MockOperator {
965        fn new(chunks: Vec<DataChunk>) -> Self {
966            Self {
967                chunks,
968                position: 0,
969            }
970        }
971    }
972
973    impl Operator for MockOperator {
974        fn next(&mut self) -> OperatorResult {
975            if self.position < self.chunks.len() {
976                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
977                self.position += 1;
978                Ok(Some(chunk))
979            } else {
980                Ok(None)
981            }
982        }
983
984        fn reset(&mut self) {
985            self.position = 0;
986        }
987
988        fn name(&self) -> &'static str {
989            "Mock"
990        }
991    }
992
993    fn create_test_chunk() -> DataChunk {
994        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
995        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
996
997        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
998        for (group, value) in data {
999            builder.column_mut(0).unwrap().push_int64(group);
1000            builder.column_mut(1).unwrap().push_int64(value);
1001            builder.advance_row();
1002        }
1003
1004        builder.finish()
1005    }
1006
1007    #[test]
1008    fn test_simple_count() {
1009        let mock = MockOperator::new(vec![create_test_chunk()]);
1010
1011        let mut agg = SimpleAggregateOperator::new(
1012            Box::new(mock),
1013            vec![AggregateExpr::count_star()],
1014            vec![LogicalType::Int64],
1015        );
1016
1017        let result = agg.next().unwrap().unwrap();
1018        assert_eq!(result.row_count(), 1);
1019        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1020
1021        // Should be done
1022        assert!(agg.next().unwrap().is_none());
1023    }
1024
1025    #[test]
1026    fn test_simple_sum() {
1027        let mock = MockOperator::new(vec![create_test_chunk()]);
1028
1029        let mut agg = SimpleAggregateOperator::new(
1030            Box::new(mock),
1031            vec![AggregateExpr::sum(1)], // Sum of column 1
1032            vec![LogicalType::Int64],
1033        );
1034
1035        let result = agg.next().unwrap().unwrap();
1036        assert_eq!(result.row_count(), 1);
1037        // Sum: 10 + 20 + 30 + 40 + 50 = 150
1038        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1039    }
1040
1041    #[test]
1042    fn test_simple_avg() {
1043        let mock = MockOperator::new(vec![create_test_chunk()]);
1044
1045        let mut agg = SimpleAggregateOperator::new(
1046            Box::new(mock),
1047            vec![AggregateExpr::avg(1)],
1048            vec![LogicalType::Float64],
1049        );
1050
1051        let result = agg.next().unwrap().unwrap();
1052        assert_eq!(result.row_count(), 1);
1053        // Avg: 150 / 5 = 30.0
1054        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1055        assert!((avg - 30.0).abs() < 0.001);
1056    }
1057
1058    #[test]
1059    fn test_simple_min_max() {
1060        let mock = MockOperator::new(vec![create_test_chunk()]);
1061
1062        let mut agg = SimpleAggregateOperator::new(
1063            Box::new(mock),
1064            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1065            vec![LogicalType::Int64, LogicalType::Int64],
1066        );
1067
1068        let result = agg.next().unwrap().unwrap();
1069        assert_eq!(result.row_count(), 1);
1070        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
1071        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
1072    }
1073
1074    #[test]
1075    fn test_sum_with_string_values() {
1076        // Test SUM with string values (like RDF stores numeric literals)
1077        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1078        builder.column_mut(0).unwrap().push_string("30");
1079        builder.advance_row();
1080        builder.column_mut(0).unwrap().push_string("25");
1081        builder.advance_row();
1082        builder.column_mut(0).unwrap().push_string("35");
1083        builder.advance_row();
1084        let chunk = builder.finish();
1085
1086        let mock = MockOperator::new(vec![chunk]);
1087        let mut agg = SimpleAggregateOperator::new(
1088            Box::new(mock),
1089            vec![AggregateExpr::sum(0)],
1090            vec![LogicalType::Float64],
1091        );
1092
1093        let result = agg.next().unwrap().unwrap();
1094        assert_eq!(result.row_count(), 1);
1095        // Should parse strings and sum: 30 + 25 + 35 = 90
1096        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1097        assert!(
1098            (sum_val - 90.0).abs() < 0.001,
1099            "Expected 90.0, got {}",
1100            sum_val
1101        );
1102    }
1103
1104    #[test]
1105    fn test_grouped_aggregation() {
1106        let mock = MockOperator::new(vec![create_test_chunk()]);
1107
1108        // GROUP BY column 0, SUM(column 1)
1109        let mut agg = HashAggregateOperator::new(
1110            Box::new(mock),
1111            vec![0],                     // Group by column 0
1112            vec![AggregateExpr::sum(1)], // Sum of column 1
1113            vec![LogicalType::Int64, LogicalType::Int64],
1114        );
1115
1116        let mut results: Vec<(i64, i64)> = Vec::new();
1117        while let Some(chunk) = agg.next().unwrap() {
1118            for row in chunk.selected_indices() {
1119                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1120                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1121                results.push((group, sum));
1122            }
1123        }
1124
1125        results.sort_by_key(|(g, _)| *g);
1126        assert_eq!(results.len(), 2);
1127        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
1128        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
1129    }
1130
1131    #[test]
1132    fn test_grouped_count() {
1133        let mock = MockOperator::new(vec![create_test_chunk()]);
1134
1135        // GROUP BY column 0, COUNT(*)
1136        let mut agg = HashAggregateOperator::new(
1137            Box::new(mock),
1138            vec![0],
1139            vec![AggregateExpr::count_star()],
1140            vec![LogicalType::Int64, LogicalType::Int64],
1141        );
1142
1143        let mut results: Vec<(i64, i64)> = Vec::new();
1144        while let Some(chunk) = agg.next().unwrap() {
1145            for row in chunk.selected_indices() {
1146                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1147                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1148                results.push((group, count));
1149            }
1150        }
1151
1152        results.sort_by_key(|(g, _)| *g);
1153        assert_eq!(results.len(), 2);
1154        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
1155        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
1156    }
1157
1158    #[test]
1159    fn test_multiple_aggregates() {
1160        let mock = MockOperator::new(vec![create_test_chunk()]);
1161
1162        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
1163        let mut agg = HashAggregateOperator::new(
1164            Box::new(mock),
1165            vec![0],
1166            vec![
1167                AggregateExpr::count_star(),
1168                AggregateExpr::sum(1),
1169                AggregateExpr::avg(1),
1170            ],
1171            vec![
1172                LogicalType::Int64,   // Group key
1173                LogicalType::Int64,   // COUNT
1174                LogicalType::Int64,   // SUM
1175                LogicalType::Float64, // AVG
1176            ],
1177        );
1178
1179        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1180        while let Some(chunk) = agg.next().unwrap() {
1181            for row in chunk.selected_indices() {
1182                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1183                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1184                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1185                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1186                results.push((group, count, sum, avg));
1187            }
1188        }
1189
1190        results.sort_by_key(|(g, _, _, _)| *g);
1191        assert_eq!(results.len(), 2);
1192
1193        // Group 1: COUNT=2, SUM=30, AVG=15.0
1194        assert_eq!(results[0].0, 1);
1195        assert_eq!(results[0].1, 2);
1196        assert_eq!(results[0].2, 30);
1197        assert!((results[0].3 - 15.0).abs() < 0.001);
1198
1199        // Group 2: COUNT=3, SUM=120, AVG=40.0
1200        assert_eq!(results[1].0, 2);
1201        assert_eq!(results[1].1, 3);
1202        assert_eq!(results[1].2, 120);
1203        assert!((results[1].3 - 40.0).abs() < 0.001);
1204    }
1205
1206    fn create_test_chunk_with_duplicates() -> DataChunk {
1207        // Create data with duplicate values in column 1
1208        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
1209        // GROUP 1: values [10, 10, 20] -> distinct count = 2
1210        // GROUP 2: values [30, 30, 30] -> distinct count = 1
1211        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1212
1213        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1214        for (group, value) in data {
1215            builder.column_mut(0).unwrap().push_int64(group);
1216            builder.column_mut(1).unwrap().push_int64(value);
1217            builder.advance_row();
1218        }
1219
1220        builder.finish()
1221    }
1222
1223    #[test]
1224    fn test_count_distinct() {
1225        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1226
1227        // COUNT(DISTINCT column 1)
1228        let mut agg = SimpleAggregateOperator::new(
1229            Box::new(mock),
1230            vec![AggregateExpr::count(1).with_distinct()],
1231            vec![LogicalType::Int64],
1232        );
1233
1234        let result = agg.next().unwrap().unwrap();
1235        assert_eq!(result.row_count(), 1);
1236        // Total distinct values: 10, 20, 30 = 3 distinct values
1237        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1238    }
1239
1240    #[test]
1241    fn test_grouped_count_distinct() {
1242        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1243
1244        // GROUP BY column 0, COUNT(DISTINCT column 1)
1245        let mut agg = HashAggregateOperator::new(
1246            Box::new(mock),
1247            vec![0],
1248            vec![AggregateExpr::count(1).with_distinct()],
1249            vec![LogicalType::Int64, LogicalType::Int64],
1250        );
1251
1252        let mut results: Vec<(i64, i64)> = Vec::new();
1253        while let Some(chunk) = agg.next().unwrap() {
1254            for row in chunk.selected_indices() {
1255                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1256                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1257                results.push((group, count));
1258            }
1259        }
1260
1261        results.sort_by_key(|(g, _)| *g);
1262        assert_eq!(results.len(), 2);
1263        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
1264        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
1265    }
1266
1267    #[test]
1268    fn test_sum_distinct() {
1269        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1270
1271        // SUM(DISTINCT column 1)
1272        let mut agg = SimpleAggregateOperator::new(
1273            Box::new(mock),
1274            vec![AggregateExpr::sum(1).with_distinct()],
1275            vec![LogicalType::Int64],
1276        );
1277
1278        let result = agg.next().unwrap().unwrap();
1279        assert_eq!(result.row_count(), 1);
1280        // Sum of distinct values: 10 + 20 + 30 = 60
1281        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1282    }
1283
1284    #[test]
1285    fn test_avg_distinct() {
1286        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1287
1288        // AVG(DISTINCT column 1)
1289        let mut agg = SimpleAggregateOperator::new(
1290            Box::new(mock),
1291            vec![AggregateExpr::avg(1).with_distinct()],
1292            vec![LogicalType::Float64],
1293        );
1294
1295        let result = agg.next().unwrap().unwrap();
1296        assert_eq!(result.row_count(), 1);
1297        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1298        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1299        assert!((avg - 20.0).abs() < 0.001);
1300    }
1301
1302    fn create_statistical_test_chunk() -> DataChunk {
1303        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1304        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1305        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1306
1307        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1308            builder.column_mut(0).unwrap().push_int64(value);
1309            builder.advance_row();
1310        }
1311
1312        builder.finish()
1313    }
1314
1315    #[test]
1316    fn test_stdev_sample() {
1317        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1318
1319        let mut agg = SimpleAggregateOperator::new(
1320            Box::new(mock),
1321            vec![AggregateExpr::stdev(0)],
1322            vec![LogicalType::Float64],
1323        );
1324
1325        let result = agg.next().unwrap().unwrap();
1326        assert_eq!(result.row_count(), 1);
1327        // Sample standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1328        // Mean = 5.0, Variance = 32/7 = 4.571, StdDev = 2.138
1329        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1330        assert!((stdev - 2.138).abs() < 0.01);
1331    }
1332
1333    #[test]
1334    fn test_stdev_population() {
1335        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1336
1337        let mut agg = SimpleAggregateOperator::new(
1338            Box::new(mock),
1339            vec![AggregateExpr::stdev_pop(0)],
1340            vec![LogicalType::Float64],
1341        );
1342
1343        let result = agg.next().unwrap().unwrap();
1344        assert_eq!(result.row_count(), 1);
1345        // Population standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1346        // Mean = 5.0, Variance = 32/8 = 4.0, StdDev = 2.0
1347        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1348        assert!((stdev - 2.0).abs() < 0.01);
1349    }
1350
1351    #[test]
1352    fn test_percentile_disc() {
1353        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1354
1355        // Median (50th percentile discrete)
1356        let mut agg = SimpleAggregateOperator::new(
1357            Box::new(mock),
1358            vec![AggregateExpr::percentile_disc(0, 0.5)],
1359            vec![LogicalType::Float64],
1360        );
1361
1362        let result = agg.next().unwrap().unwrap();
1363        assert_eq!(result.row_count(), 1);
1364        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], index = floor(0.5 * 7) = 3, value = 4
1365        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1366        assert!((percentile - 4.0).abs() < 0.01);
1367    }
1368
1369    #[test]
1370    fn test_percentile_cont() {
1371        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1372
1373        // Median (50th percentile continuous)
1374        let mut agg = SimpleAggregateOperator::new(
1375            Box::new(mock),
1376            vec![AggregateExpr::percentile_cont(0, 0.5)],
1377            vec![LogicalType::Float64],
1378        );
1379
1380        let result = agg.next().unwrap().unwrap();
1381        assert_eq!(result.row_count(), 1);
1382        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], rank = 0.5 * 7 = 3.5
1383        // Interpolate between index 3 (4) and index 4 (5): 4 + 0.5 * (5 - 4) = 4.5
1384        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1385        assert!((percentile - 4.5).abs() < 0.01);
1386    }
1387
1388    #[test]
1389    fn test_percentile_extremes() {
1390        // Test 0th and 100th percentiles
1391        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1392
1393        let mut agg = SimpleAggregateOperator::new(
1394            Box::new(mock),
1395            vec![
1396                AggregateExpr::percentile_disc(0, 0.0),
1397                AggregateExpr::percentile_disc(0, 1.0),
1398            ],
1399            vec![LogicalType::Float64, LogicalType::Float64],
1400        );
1401
1402        let result = agg.next().unwrap().unwrap();
1403        assert_eq!(result.row_count(), 1);
1404        // 0th percentile = minimum = 2
1405        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1406        assert!((p0 - 2.0).abs() < 0.01);
1407        // 100th percentile = maximum = 9
1408        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1409        assert!((p100 - 9.0).abs() < 0.01);
1410    }
1411
1412    #[test]
1413    fn test_stdev_single_value() {
1414        // Single value should return null for sample stdev
1415        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1416        builder.column_mut(0).unwrap().push_int64(42);
1417        builder.advance_row();
1418        let chunk = builder.finish();
1419
1420        let mock = MockOperator::new(vec![chunk]);
1421
1422        let mut agg = SimpleAggregateOperator::new(
1423            Box::new(mock),
1424            vec![AggregateExpr::stdev(0)],
1425            vec![LogicalType::Float64],
1426        );
1427
1428        let result = agg.next().unwrap().unwrap();
1429        assert_eq!(result.row_count(), 1);
1430        // Sample stdev of single value is undefined (null)
1431        assert!(matches!(
1432            result.column(0).unwrap().get_value(0),
1433            Some(Value::Null)
1434        ));
1435    }
1436
1437    #[test]
1438    fn test_first_and_last() {
1439        let mock = MockOperator::new(vec![create_test_chunk()]);
1440
1441        let mut agg = SimpleAggregateOperator::new(
1442            Box::new(mock),
1443            vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1444            vec![LogicalType::Int64, LogicalType::Int64],
1445        );
1446
1447        let result = agg.next().unwrap().unwrap();
1448        assert_eq!(result.row_count(), 1);
1449        // First: 10, Last: 50
1450        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1451        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1452    }
1453
1454    #[test]
1455    fn test_collect() {
1456        let mock = MockOperator::new(vec![create_test_chunk()]);
1457
1458        let mut agg = SimpleAggregateOperator::new(
1459            Box::new(mock),
1460            vec![AggregateExpr::collect(1)],
1461            vec![LogicalType::Any],
1462        );
1463
1464        let result = agg.next().unwrap().unwrap();
1465        let val = result.column(0).unwrap().get_value(0).unwrap();
1466        if let Value::List(items) = val {
1467            assert_eq!(items.len(), 5);
1468        } else {
1469            panic!("Expected List value");
1470        }
1471    }
1472
1473    #[test]
1474    fn test_collect_distinct() {
1475        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1476
1477        let mut agg = SimpleAggregateOperator::new(
1478            Box::new(mock),
1479            vec![AggregateExpr::collect(1).with_distinct()],
1480            vec![LogicalType::Any],
1481        );
1482
1483        let result = agg.next().unwrap().unwrap();
1484        let val = result.column(0).unwrap().get_value(0).unwrap();
1485        if let Value::List(items) = val {
1486            // [10, 10, 20, 30, 30, 30] -> distinct: [10, 20, 30]
1487            assert_eq!(items.len(), 3);
1488        } else {
1489            panic!("Expected List value");
1490        }
1491    }
1492
1493    #[test]
1494    fn test_group_concat() {
1495        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1496        for s in ["hello", "world", "foo"] {
1497            builder.column_mut(0).unwrap().push_string(s);
1498            builder.advance_row();
1499        }
1500        let chunk = builder.finish();
1501        let mock = MockOperator::new(vec![chunk]);
1502
1503        let agg_expr = AggregateExpr {
1504            function: AggregateFunction::GroupConcat,
1505            column: Some(0),
1506            column2: None,
1507            distinct: false,
1508            alias: None,
1509            percentile: None,
1510        };
1511
1512        let mut agg =
1513            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1514
1515        let result = agg.next().unwrap().unwrap();
1516        let val = result.column(0).unwrap().get_value(0).unwrap();
1517        assert_eq!(val, Value::String("hello world foo".into()));
1518    }
1519
1520    #[test]
1521    fn test_sample() {
1522        let mock = MockOperator::new(vec![create_test_chunk()]);
1523
1524        let agg_expr = AggregateExpr {
1525            function: AggregateFunction::Sample,
1526            column: Some(1),
1527            column2: None,
1528            distinct: false,
1529            alias: None,
1530            percentile: None,
1531        };
1532
1533        let mut agg =
1534            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1535
1536        let result = agg.next().unwrap().unwrap();
1537        // Sample should return the first non-null value (10)
1538        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1539    }
1540
1541    #[test]
1542    fn test_variance_sample() {
1543        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1544
1545        let agg_expr = AggregateExpr {
1546            function: AggregateFunction::Variance,
1547            column: Some(0),
1548            column2: None,
1549            distinct: false,
1550            alias: None,
1551            percentile: None,
1552        };
1553
1554        let mut agg = SimpleAggregateOperator::new(
1555            Box::new(mock),
1556            vec![agg_expr],
1557            vec![LogicalType::Float64],
1558        );
1559
1560        let result = agg.next().unwrap().unwrap();
1561        // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: M2/(n-1) = 32/7 = 4.571
1562        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1563        assert!((variance - 32.0 / 7.0).abs() < 0.01);
1564    }
1565
1566    #[test]
1567    fn test_variance_population() {
1568        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1569
1570        let agg_expr = AggregateExpr {
1571            function: AggregateFunction::VariancePop,
1572            column: Some(0),
1573            column2: None,
1574            distinct: false,
1575            alias: None,
1576            percentile: None,
1577        };
1578
1579        let mut agg = SimpleAggregateOperator::new(
1580            Box::new(mock),
1581            vec![agg_expr],
1582            vec![LogicalType::Float64],
1583        );
1584
1585        let result = agg.next().unwrap().unwrap();
1586        // Population variance: M2/n = 32/8 = 4.0
1587        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1588        assert!((variance - 4.0).abs() < 0.01);
1589    }
1590
1591    #[test]
1592    fn test_variance_single_value() {
1593        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1594        builder.column_mut(0).unwrap().push_int64(42);
1595        builder.advance_row();
1596        let chunk = builder.finish();
1597        let mock = MockOperator::new(vec![chunk]);
1598
1599        let agg_expr = AggregateExpr {
1600            function: AggregateFunction::Variance,
1601            column: Some(0),
1602            column2: None,
1603            distinct: false,
1604            alias: None,
1605            percentile: None,
1606        };
1607
1608        let mut agg = SimpleAggregateOperator::new(
1609            Box::new(mock),
1610            vec![agg_expr],
1611            vec![LogicalType::Float64],
1612        );
1613
1614        let result = agg.next().unwrap().unwrap();
1615        // Sample variance of single value is undefined (null)
1616        assert!(matches!(
1617            result.column(0).unwrap().get_value(0),
1618            Some(Value::Null)
1619        ));
1620    }
1621
1622    #[test]
1623    fn test_empty_aggregation() {
1624        // No input rows: COUNT should be 0, SUM 0, AVG null, MIN/MAX null
1625        let mock = MockOperator::new(vec![]);
1626
1627        let mut agg = SimpleAggregateOperator::new(
1628            Box::new(mock),
1629            vec![
1630                AggregateExpr::count_star(),
1631                AggregateExpr::sum(0),
1632                AggregateExpr::avg(0),
1633                AggregateExpr::min(0),
1634                AggregateExpr::max(0),
1635            ],
1636            vec![
1637                LogicalType::Int64,
1638                LogicalType::Int64,
1639                LogicalType::Float64,
1640                LogicalType::Int64,
1641                LogicalType::Int64,
1642            ],
1643        );
1644
1645        let result = agg.next().unwrap().unwrap();
1646        assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); // COUNT
1647        assert_eq!(result.column(1).unwrap().get_int64(0), Some(0)); // SUM
1648        assert!(matches!(
1649            result.column(2).unwrap().get_value(0),
1650            Some(Value::Null)
1651        )); // AVG
1652        assert!(matches!(
1653            result.column(3).unwrap().get_value(0),
1654            Some(Value::Null)
1655        )); // MIN
1656        assert!(matches!(
1657            result.column(4).unwrap().get_value(0),
1658            Some(Value::Null)
1659        )); // MAX
1660    }
1661
1662    #[test]
1663    fn test_stdev_pop_single_value() {
1664        // Single value should return 0 for population stdev
1665        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1666        builder.column_mut(0).unwrap().push_int64(42);
1667        builder.advance_row();
1668        let chunk = builder.finish();
1669
1670        let mock = MockOperator::new(vec![chunk]);
1671
1672        let mut agg = SimpleAggregateOperator::new(
1673            Box::new(mock),
1674            vec![AggregateExpr::stdev_pop(0)],
1675            vec![LogicalType::Float64],
1676        );
1677
1678        let result = agg.next().unwrap().unwrap();
1679        assert_eq!(result.row_count(), 1);
1680        // Population stdev of single value is 0
1681        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1682        assert!((stdev - 0.0).abs() < 0.01);
1683    }
1684}