Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, PropertyKey, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
22/// State for a single aggregation computation.
23///
24/// Used by both the pull-based [`HashAggregateOperator`] and the push-based
25/// `AggregatePushOperator`.
26/// Supports all [`AggregateFunction`] variants including Welford's algorithm
27/// for online statistics, Kahan summation, distinct tracking, and bivariate
28/// regression functions.
29#[derive(Debug, Clone)]
30#[allow(missing_docs)]
31pub enum AggregateState {
32    /// Count state.
33    Count(i64),
34    /// Count distinct state (count, seen values).
35    CountDistinct(i64, HashSet<HashableValue>),
36    /// Sum state (integer sum, count of values added).
37    SumInt(i64, i64),
38    /// Sum distinct state (integer sum, count, seen values).
39    SumIntDistinct(i64, i64, HashSet<HashableValue>),
40    /// Sum state (float sum, count of values added).
41    SumFloat(f64, f64, i64),
42    /// Sum distinct state (float sum, compensation, count, seen values).
43    SumFloatDistinct(f64, f64, i64, HashSet<HashableValue>),
44    /// Average state (sum, count).
45    Avg(f64, i64),
46    /// Average distinct state (sum, count, seen values).
47    AvgDistinct(f64, i64, HashSet<HashableValue>),
48    /// Min state.
49    Min(Option<Value>),
50    /// Max state.
51    Max(Option<Value>),
52    /// First state.
53    First(Option<Value>),
54    /// Last state.
55    Last(Option<Value>),
56    /// Collect state.
57    Collect(Vec<Value>),
58    /// Collect distinct state (values, seen).
59    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
60    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
61    StdDev { count: i64, mean: f64, m2: f64 },
62    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
63    StdDevPop { count: i64, mean: f64, m2: f64 },
64    /// Discrete percentile state (values, percentile).
65    PercentileDisc { values: Vec<f64>, percentile: f64 },
66    /// Continuous percentile state (values, percentile).
67    PercentileCont { values: Vec<f64>, percentile: f64 },
68    /// GROUP_CONCAT / LISTAGG state (collected string values, separator).
69    GroupConcat(Vec<String>, String),
70    /// GROUP_CONCAT / LISTAGG distinct state (collected string values, separator, seen).
71    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
72    /// SAMPLE state (first non-null value encountered).
73    Sample(Option<Value>),
74    /// Sample variance state using Welford's algorithm (count, mean, M2).
75    Variance { count: i64, mean: f64, m2: f64 },
76    /// Population variance state using Welford's algorithm (count, mean, M2).
77    VariancePop { count: i64, mean: f64, m2: f64 },
78    /// Two-variable online statistics (Welford generalization for covariance/regression).
79    Bivariate {
80        /// Which binary set function this state will finalize to.
81        kind: AggregateFunction,
82        count: i64,
83        mean_x: f64,
84        mean_y: f64,
85        m2_x: f64,
86        m2_y: f64,
87        c_xy: f64,
88    },
89    /// Immutable finalized value restored from spill. Ignores further updates
90    /// so that reloaded groups that were serialized via the FINALIZED fallback
91    /// do not silently corrupt their result when more rows arrive.
92    Frozen(Value),
93}
94
95impl AggregateState {
96    /// Creates initial state for an aggregation function.
97    pub fn new(
98        function: AggregateFunction,
99        distinct: bool,
100        percentile: Option<f64>,
101        separator: Option<&str>,
102    ) -> Self {
103        match (function, distinct) {
104            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
105                AggregateState::Count(0)
106            }
107            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
108                AggregateState::CountDistinct(0, HashSet::new())
109            }
110            (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
111            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
112            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
113            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
114            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
115            (AggregateFunction::Max, _) => AggregateState::Max(None),
116            (AggregateFunction::First, _) => AggregateState::First(None),
117            (AggregateFunction::Last, _) => AggregateState::Last(None),
118            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
119            (AggregateFunction::Collect, true) => {
120                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
121            }
122            // Statistical functions (Welford's algorithm for online computation)
123            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
124                count: 0,
125                mean: 0.0,
126                m2: 0.0,
127            },
128            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
129                count: 0,
130                mean: 0.0,
131                m2: 0.0,
132            },
133            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
134                values: Vec::new(),
135                percentile: percentile.unwrap_or(0.5),
136            },
137            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
138                values: Vec::new(),
139                percentile: percentile.unwrap_or(0.5),
140            },
141            (AggregateFunction::GroupConcat, false) => {
142                AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
143            }
144            (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
145                Vec::new(),
146                separator.unwrap_or(" ").to_string(),
147                HashSet::new(),
148            ),
149            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
150            // Binary set functions (all share the same Bivariate state)
151            (
152                AggregateFunction::CovarSamp
153                | AggregateFunction::CovarPop
154                | AggregateFunction::Corr
155                | AggregateFunction::RegrSlope
156                | AggregateFunction::RegrIntercept
157                | AggregateFunction::RegrR2
158                | AggregateFunction::RegrCount
159                | AggregateFunction::RegrSxx
160                | AggregateFunction::RegrSyy
161                | AggregateFunction::RegrSxy
162                | AggregateFunction::RegrAvgx
163                | AggregateFunction::RegrAvgy,
164                _,
165            ) => AggregateState::Bivariate {
166                kind: function,
167                count: 0,
168                mean_x: 0.0,
169                mean_y: 0.0,
170                m2_x: 0.0,
171                m2_y: 0.0,
172                c_xy: 0.0,
173            },
174            (AggregateFunction::Variance, _) => AggregateState::Variance {
175                count: 0,
176                mean: 0.0,
177                m2: 0.0,
178            },
179            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
180                count: 0,
181                mean: 0.0,
182                m2: 0.0,
183            },
184        }
185    }
186
187    /// Updates the state with a new value.
188    ///
189    /// For `COUNT(*)`, pass `None` to count all rows. For column-specific
190    /// aggregates, pass `Some(value)` (nulls are skipped by most functions).
191    pub fn update(&mut self, value: Option<Value>) {
192        match self {
193            AggregateState::Count(count) => {
194                *count += 1;
195            }
196            AggregateState::CountDistinct(count, seen) => {
197                if let Some(ref v) = value {
198                    let hashable = HashableValue::from(v);
199                    if seen.insert(hashable) {
200                        *count += 1;
201                    }
202                }
203            }
204            AggregateState::SumInt(sum, count) => {
205                if let Some(Value::Int64(v)) = value {
206                    *sum += v;
207                    *count += 1;
208                } else if let Some(Value::Float64(v)) = value {
209                    // Convert to float sum, carrying count forward
210                    *self = AggregateState::SumFloat(*sum as f64 + v, 0.0, *count + 1);
211                } else if let Some(ref v) = value {
212                    // RDF stores numeric literals as strings - try to parse
213                    if let Some(num) = value_to_f64(v) {
214                        *self = AggregateState::SumFloat(*sum as f64 + num, 0.0, *count + 1);
215                    }
216                }
217            }
218            AggregateState::SumIntDistinct(sum, count, seen) => {
219                if let Some(ref v) = value {
220                    let hashable = HashableValue::from(v);
221                    if seen.insert(hashable) {
222                        if let Value::Int64(i) = v {
223                            *sum += i;
224                            *count += 1;
225                        } else if let Value::Float64(f) = v {
226                            // Convert to float distinct: move the seen set instead of cloning
227                            let moved_seen = std::mem::take(seen);
228                            *self = AggregateState::SumFloatDistinct(
229                                *sum as f64 + f,
230                                0.0,
231                                *count + 1,
232                                moved_seen,
233                            );
234                        } else if let Some(num) = value_to_f64(v) {
235                            // RDF string-encoded numerics
236                            let moved_seen = std::mem::take(seen);
237                            *self = AggregateState::SumFloatDistinct(
238                                *sum as f64 + num,
239                                0.0,
240                                *count + 1,
241                                moved_seen,
242                            );
243                        }
244                    }
245                }
246            }
247            AggregateState::SumFloat(sum, comp, count) => {
248                if let Some(ref v) = value {
249                    // Use value_to_f64 which now handles strings
250                    if let Some(num) = value_to_f64(v) {
251                        // Kahan compensated summation to reduce rounding error
252                        let y = num - *comp;
253                        let t = *sum + y;
254                        *comp = (t - *sum) - y;
255                        *sum = t;
256                        *count += 1;
257                    }
258                }
259            }
260            AggregateState::SumFloatDistinct(sum, comp, count, seen) => {
261                if let Some(ref v) = value {
262                    let hashable = HashableValue::from(v);
263                    if seen.insert(hashable)
264                        && let Some(num) = value_to_f64(v)
265                    {
266                        let y = num - *comp;
267                        let t = *sum + y;
268                        *comp = (t - *sum) - y;
269                        *sum = t;
270                        *count += 1;
271                    }
272                }
273            }
274            AggregateState::Avg(sum, count) => {
275                if let Some(ref v) = value
276                    && let Some(num) = value_to_f64(v)
277                {
278                    *sum += num;
279                    *count += 1;
280                }
281            }
282            AggregateState::AvgDistinct(sum, count, seen) => {
283                if let Some(ref v) = value {
284                    let hashable = HashableValue::from(v);
285                    if seen.insert(hashable)
286                        && let Some(num) = value_to_f64(v)
287                    {
288                        *sum += num;
289                        *count += 1;
290                    }
291                }
292            }
293            AggregateState::Min(min) => {
294                if let Some(v) = value {
295                    match min {
296                        None => *min = Some(v),
297                        Some(current) => {
298                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
299                                *min = Some(v);
300                            }
301                        }
302                    }
303                }
304            }
305            AggregateState::Max(max) => {
306                if let Some(v) = value {
307                    match max {
308                        None => *max = Some(v),
309                        Some(current) => {
310                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
311                                *max = Some(v);
312                            }
313                        }
314                    }
315                }
316            }
317            AggregateState::First(first) => {
318                if first.is_none() {
319                    *first = value;
320                }
321            }
322            AggregateState::Last(last) => {
323                if value.is_some() {
324                    *last = value;
325                }
326            }
327            AggregateState::Collect(list) => {
328                if let Some(v) = value {
329                    list.push(v);
330                }
331            }
332            AggregateState::CollectDistinct(list, seen) => {
333                if let Some(v) = value {
334                    let hashable = HashableValue::from(&v);
335                    if seen.insert(hashable) {
336                        list.push(v);
337                    }
338                }
339            }
340            // Statistical functions using Welford's online algorithm
341            AggregateState::StdDev { count, mean, m2 }
342            | AggregateState::StdDevPop { count, mean, m2 }
343            | AggregateState::Variance { count, mean, m2 }
344            | AggregateState::VariancePop { count, mean, m2 } => {
345                if let Some(ref v) = value
346                    && let Some(x) = value_to_f64(v)
347                {
348                    *count += 1;
349                    let delta = x - *mean;
350                    *mean += delta / *count as f64;
351                    let delta2 = x - *mean;
352                    *m2 += delta * delta2;
353                }
354            }
355            AggregateState::PercentileDisc { values, .. }
356            | AggregateState::PercentileCont { values, .. } => {
357                if let Some(ref v) = value
358                    && let Some(x) = value_to_f64(v)
359                {
360                    values.push(x);
361                }
362            }
363            AggregateState::GroupConcat(list, _sep) => {
364                if let Some(v) = value {
365                    list.push(agg_value_to_string(&v));
366                }
367            }
368            AggregateState::GroupConcatDistinct(list, _sep, seen) => {
369                if let Some(v) = value {
370                    let hashable = HashableValue::from(&v);
371                    if seen.insert(hashable) {
372                        list.push(agg_value_to_string(&v));
373                    }
374                }
375            }
376            AggregateState::Sample(sample) => {
377                if sample.is_none() {
378                    *sample = value;
379                }
380            }
381            AggregateState::Bivariate { .. } => {
382                // Bivariate functions require two values; use update_bivariate() instead.
383                // Single-value update is a no-op for bivariate state.
384            }
385            AggregateState::Frozen(_) => {}
386        }
387    }
388
389    /// Updates a bivariate (two-variable) aggregate state with a pair of values.
390    ///
391    /// Uses the two-variable Welford online algorithm for numerically stable computation
392    /// of covariance and related statistics. Skips the update if either value is null.
393    pub fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
394        if let AggregateState::Bivariate {
395            count,
396            mean_x,
397            mean_y,
398            m2_x,
399            m2_y,
400            c_xy,
401            ..
402        } = self
403        {
404            // Skip if either value is null (SQL semantics: exclude non-pairs)
405            if let (Some(y), Some(x)) = (&y_val, &x_val)
406                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
407            {
408                *count += 1;
409                let n = *count as f64;
410                let dx = x_f - *mean_x;
411                let dy = y_f - *mean_y;
412                *mean_x += dx / n;
413                *mean_y += dy / n;
414                let dx2 = x_f - *mean_x; // post-update delta
415                let dy2 = y_f - *mean_y; // post-update delta
416                *m2_x += dx * dx2;
417                *m2_y += dy * dy2;
418                *c_xy += dx * dy2;
419            }
420        }
421    }
422
423    /// Finalizes the state and returns the result value.
424    pub fn finalize(&self) -> Value {
425        match self {
426            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
427                Value::Int64(*count)
428            }
429            AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
430                if *count == 0 {
431                    Value::Null
432                } else {
433                    Value::Int64(*sum)
434                }
435            }
436            AggregateState::SumFloat(sum, _, count)
437            | AggregateState::SumFloatDistinct(sum, _, count, _) => {
438                if *count == 0 {
439                    Value::Null
440                } else {
441                    Value::Float64(*sum)
442                }
443            }
444            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
445                if *count == 0 {
446                    Value::Null
447                } else {
448                    Value::Float64(*sum / *count as f64)
449                }
450            }
451            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
452            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
453            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
454            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
455            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
456                Value::List(list.clone().into())
457            }
458            // Sample standard deviation: sqrt(M2 / (n - 1))
459            AggregateState::StdDev { count, m2, .. } => {
460                if *count < 2 {
461                    Value::Null
462                } else {
463                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
464                }
465            }
466            // Population standard deviation: sqrt(M2 / n)
467            AggregateState::StdDevPop { count, m2, .. } => {
468                if *count == 0 {
469                    Value::Null
470                } else {
471                    Value::Float64((*m2 / *count as f64).sqrt())
472                }
473            }
474            // Sample variance: M2 / (n - 1)
475            AggregateState::Variance { count, m2, .. } => {
476                if *count < 2 {
477                    Value::Null
478                } else {
479                    Value::Float64(*m2 / (*count - 1) as f64)
480                }
481            }
482            // Population variance: M2 / n
483            AggregateState::VariancePop { count, m2, .. } => {
484                if *count == 0 {
485                    Value::Null
486                } else {
487                    Value::Float64(*m2 / *count as f64)
488                }
489            }
490            // Discrete percentile: return actual value at percentile position
491            AggregateState::PercentileDisc { values, percentile } => {
492                if values.is_empty() {
493                    Value::Null
494                } else {
495                    let mut sorted = values.clone();
496                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
497                    // Index calculation per SQL standard: floor(p * (n - 1))
498                    // reason: percentile index is bounded by sorted.len(), fits usize
499                    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
500                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
501                    Value::Float64(sorted[index])
502                }
503            }
504            // Continuous percentile: interpolate between values
505            AggregateState::PercentileCont { values, percentile } => {
506                if values.is_empty() {
507                    Value::Null
508                } else {
509                    let mut sorted = values.clone();
510                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
511                    // Linear interpolation per SQL standard
512                    let rank = percentile * (sorted.len() - 1) as f64;
513                    // reason: rank is bounded by sorted.len() - 1, fits usize
514                    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
515                    let lower_idx = rank.floor() as usize;
516                    // reason: rank is a non-negative f64 from percentile calculation, fits usize
517                    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
518                    let upper_idx = rank.ceil() as usize;
519                    if lower_idx == upper_idx {
520                        Value::Float64(sorted[lower_idx])
521                    } else {
522                        let fraction = rank - lower_idx as f64;
523                        let result =
524                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
525                        Value::Float64(result)
526                    }
527                }
528            }
529            // GROUP_CONCAT: join strings with space separator (SPARQL default)
530            AggregateState::GroupConcat(list, sep)
531            | AggregateState::GroupConcatDistinct(list, sep, _) => {
532                Value::String(list.join(sep).into())
533            }
534            // SAMPLE: return the first non-null value seen
535            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
536            AggregateState::Frozen(val) => val.clone(),
537            // Binary set functions: dispatch on kind
538            AggregateState::Bivariate {
539                kind,
540                count,
541                mean_x,
542                mean_y,
543                m2_x,
544                m2_y,
545                c_xy,
546            } => {
547                let n = *count;
548                match kind {
549                    AggregateFunction::CovarSamp => {
550                        if n < 2 {
551                            Value::Null
552                        } else {
553                            Value::Float64(*c_xy / (n - 1) as f64)
554                        }
555                    }
556                    AggregateFunction::CovarPop => {
557                        if n == 0 {
558                            Value::Null
559                        } else {
560                            Value::Float64(*c_xy / n as f64)
561                        }
562                    }
563                    AggregateFunction::Corr => {
564                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
565                            Value::Null
566                        } else {
567                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
568                        }
569                    }
570                    AggregateFunction::RegrSlope => {
571                        if n == 0 || *m2_x == 0.0 {
572                            Value::Null
573                        } else {
574                            Value::Float64(*c_xy / *m2_x)
575                        }
576                    }
577                    AggregateFunction::RegrIntercept => {
578                        if n == 0 || *m2_x == 0.0 {
579                            Value::Null
580                        } else {
581                            let slope = *c_xy / *m2_x;
582                            Value::Float64(*mean_y - slope * *mean_x)
583                        }
584                    }
585                    AggregateFunction::RegrR2 => {
586                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
587                            Value::Null
588                        } else {
589                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
590                        }
591                    }
592                    AggregateFunction::RegrCount => Value::Int64(n),
593                    AggregateFunction::RegrSxx => {
594                        if n == 0 {
595                            Value::Null
596                        } else {
597                            Value::Float64(*m2_x)
598                        }
599                    }
600                    AggregateFunction::RegrSyy => {
601                        if n == 0 {
602                            Value::Null
603                        } else {
604                            Value::Float64(*m2_y)
605                        }
606                    }
607                    AggregateFunction::RegrSxy => {
608                        if n == 0 {
609                            Value::Null
610                        } else {
611                            Value::Float64(*c_xy)
612                        }
613                    }
614                    AggregateFunction::RegrAvgx => {
615                        if n == 0 {
616                            Value::Null
617                        } else {
618                            Value::Float64(*mean_x)
619                        }
620                    }
621                    AggregateFunction::RegrAvgy => {
622                        if n == 0 {
623                            Value::Null
624                        } else {
625                            Value::Float64(*mean_y)
626                        }
627                    }
628                    _ => Value::Null, // non-bivariate functions never reach here
629                }
630            }
631        }
632    }
633}
634
635use super::value_utils::{compare_values, value_to_f64};
636
637/// Converts a Value to its string representation for GROUP_CONCAT.
638fn agg_value_to_string(val: &Value) -> String {
639    match val {
640        Value::String(s) => s.to_string(),
641        Value::Int64(i) => i.to_string(),
642        Value::Float64(f) => f.to_string(),
643        Value::Bool(b) => b.to_string(),
644        Value::Null => String::new(),
645        other => format!("{other:?}"),
646    }
647}
648
649/// A group key for hash-based aggregation.
650#[derive(Debug, Clone, PartialEq, Eq, Hash)]
651pub struct GroupKey(Vec<GroupKeyPart>);
652
653#[derive(Debug, Clone, PartialEq, Eq, Hash)]
654enum GroupKeyPart {
655    Null,
656    Bool(bool),
657    Int64(i64),
658    String(ArcStr),
659    Bytes(Arc<[u8]>),
660    Date(grafeo_common::types::Date),
661    Time(grafeo_common::types::Time),
662    Timestamp(grafeo_common::types::Timestamp),
663    Duration(grafeo_common::types::Duration),
664    ZonedDatetime(grafeo_common::types::ZonedDatetime),
665    List(Vec<GroupKeyPart>),
666    Map(Vec<(ArcStr, GroupKeyPart)>),
667}
668
669impl GroupKeyPart {
670    fn from_value(v: Value) -> Self {
671        match v {
672            Value::Null => Self::Null,
673            Value::Bool(b) => Self::Bool(b),
674            Value::Int64(i) => Self::Int64(i),
675            // reason: intentional bit-level reinterpretation for grouping equality
676            #[allow(clippy::cast_possible_wrap)]
677            Value::Float64(f) => Self::Int64(f.to_bits() as i64),
678            Value::String(s) => Self::String(s.clone()),
679            Value::Bytes(b) => Self::Bytes(b),
680            Value::Date(d) => Self::Date(d),
681            Value::Time(t) => Self::Time(t),
682            Value::Timestamp(ts) => Self::Timestamp(ts),
683            Value::Duration(d) => Self::Duration(d),
684            Value::ZonedDatetime(zdt) => Self::ZonedDatetime(zdt),
685            Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
686            Value::Map(map) => {
687                // BTreeMap already iterates in key order, so this is deterministic
688                let entries: Vec<(ArcStr, GroupKeyPart)> = map
689                    .iter()
690                    .map(|(k, v)| (ArcStr::from(k.as_str()), Self::from_value(v.clone())))
691                    .collect();
692                Self::Map(entries)
693            }
694            // Path, Vector, GCounter, OnCounter: use Debug string as fallback
695            other => Self::String(ArcStr::from(format!("{other:?}"))),
696        }
697    }
698
699    fn to_value(&self) -> Value {
700        match self {
701            Self::Null => Value::Null,
702            Self::Bool(b) => Value::Bool(*b),
703            Self::Int64(i) => Value::Int64(*i),
704            Self::String(s) => Value::String(s.clone()),
705            Self::Bytes(b) => Value::Bytes(Arc::clone(b)),
706            Self::Date(d) => Value::Date(*d),
707            Self::Time(t) => Value::Time(*t),
708            Self::Timestamp(ts) => Value::Timestamp(*ts),
709            Self::Duration(d) => Value::Duration(*d),
710            Self::ZonedDatetime(zdt) => Value::ZonedDatetime(*zdt),
711            Self::List(parts) => {
712                let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
713                Value::List(Arc::from(values.into_boxed_slice()))
714            }
715            Self::Map(entries) => {
716                let map: std::collections::BTreeMap<PropertyKey, Value> = entries
717                    .iter()
718                    .map(|(k, v)| (PropertyKey::new(k.as_str()), v.to_value()))
719                    .collect();
720                Value::Map(Arc::new(map))
721            }
722        }
723    }
724}
725
726impl GroupKey {
727    /// Creates a group key from column values.
728    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
729        let parts: Vec<GroupKeyPart> = group_columns
730            .iter()
731            .map(|&col_idx| {
732                chunk
733                    .column(col_idx)
734                    .and_then(|col| col.get_value(row))
735                    .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
736            })
737            .collect();
738        GroupKey(parts)
739    }
740
741    /// Converts the group key back to values.
742    fn to_values(&self) -> Vec<Value> {
743        self.0.iter().map(GroupKeyPart::to_value).collect()
744    }
745}
746
/// Hash-based aggregate operator.
///
/// Groups input by key columns and computes aggregations for each group.
/// The child is drained completely on the first call to `next`. Groups are
/// then emitted in the order their keys were first seen. Each output row
/// holds the group key columns followed by one finalized value per aggregate.
pub struct HashAggregateOperator {
    /// Child operator to read from.
    child: Box<dyn Operator>,
    /// Column indices (into the child's chunks) to group by.
    group_columns: Vec<usize>,
    /// Aggregation expressions, one output column each.
    aggregates: Vec<AggregateExpr>,
    /// Output schema: group key columns first, then one column per aggregate.
    output_schema: Vec<LogicalType>,
    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
    /// Emptied once aggregation completes; its entries move into `results`.
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Whether the child has been drained and `results` populated.
    aggregation_complete: bool,
    /// Finished groups still waiting to be emitted by `next`.
    /// `None` until aggregation has run (and again after `reset`).
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
766
767impl HashAggregateOperator {
768    /// Creates a new hash aggregate operator.
769    ///
770    /// # Arguments
771    /// * `child` - Child operator to read from.
772    /// * `group_columns` - Column indices to group by.
773    /// * `aggregates` - Aggregation expressions.
774    /// * `output_schema` - Schema of the output (group columns + aggregate results).
775    pub fn new(
776        child: Box<dyn Operator>,
777        group_columns: Vec<usize>,
778        aggregates: Vec<AggregateExpr>,
779        output_schema: Vec<LogicalType>,
780    ) -> Self {
781        Self {
782            child,
783            group_columns,
784            aggregates,
785            output_schema,
786            groups: IndexMap::new(),
787            aggregation_complete: false,
788            results: None,
789        }
790    }
791
792    /// Decomposes this operator for push-based conversion.
793    pub fn into_parts(self) -> (Box<dyn Operator>, Vec<usize>, Vec<AggregateExpr>) {
794        (self.child, self.group_columns, self.aggregates)
795    }
796
797    /// Performs the aggregation.
798    fn aggregate(&mut self) -> Result<(), OperatorError> {
799        while let Some(chunk) = self.child.next()? {
800            for row in chunk.selected_indices() {
801                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
802
803                // Get or create aggregate states for this group
804                let states = self.groups.entry(key).or_insert_with(|| {
805                    self.aggregates
806                        .iter()
807                        .map(|agg| {
808                            AggregateState::new(
809                                agg.function,
810                                agg.distinct,
811                                agg.percentile,
812                                agg.separator.as_deref(),
813                            )
814                        })
815                        .collect()
816                });
817
818                // Update each aggregate
819                for (i, agg) in self.aggregates.iter().enumerate() {
820                    // Binary set functions: read two column values
821                    if agg.column2.is_some() {
822                        let y_val = agg
823                            .column
824                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
825                        let x_val = agg
826                            .column2
827                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
828                        states[i].update_bivariate(y_val, x_val);
829                        continue;
830                    }
831
832                    let value = match (agg.function, agg.distinct) {
833                        // COUNT(*) without DISTINCT doesn't need a value
834                        (AggregateFunction::Count, false) => None,
835                        // COUNT DISTINCT needs the actual value to track unique values
836                        (AggregateFunction::Count, true) => agg
837                            .column
838                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
839                        _ => agg
840                            .column
841                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
842                    };
843
844                    // For COUNT without DISTINCT, always update. For others, skip nulls.
845                    match (agg.function, agg.distinct) {
846                        (AggregateFunction::Count, false) => states[i].update(None),
847                        (AggregateFunction::Count, true) => {
848                            // COUNT DISTINCT needs the value to track unique values
849                            if value.is_some() && !matches!(value, Some(Value::Null)) {
850                                states[i].update(value);
851                            }
852                        }
853                        (AggregateFunction::CountNonNull, _) => {
854                            if value.is_some() && !matches!(value, Some(Value::Null)) {
855                                states[i].update(value);
856                            }
857                        }
858                        _ => {
859                            if value.is_some() && !matches!(value, Some(Value::Null)) {
860                                states[i].update(value);
861                            }
862                        }
863                    }
864                }
865            }
866        }
867
868        self.aggregation_complete = true;
869
870        // Convert to results iterator (IndexMap::drain takes a range)
871        let results: Vec<_> = self.groups.drain(..).collect();
872        self.results = Some(results.into_iter());
873
874        Ok(())
875    }
876}
877
878impl Operator for HashAggregateOperator {
879    fn next(&mut self) -> OperatorResult {
880        // Perform aggregation if not done
881        if !self.aggregation_complete {
882            self.aggregate()?;
883        }
884
885        // Special case: no groups (global aggregation with no data)
886        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
887            // For global aggregation (no GROUP BY), return one row with initial values
888            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
889
890            for agg in &self.aggregates {
891                let state = AggregateState::new(
892                    agg.function,
893                    agg.distinct,
894                    agg.percentile,
895                    agg.separator.as_deref(),
896                );
897                let value = state.finalize();
898                if let Some(col) = builder.column_mut(self.group_columns.len()) {
899                    col.push_value(value);
900                }
901            }
902            builder.advance_row();
903
904            self.results = Some(Vec::new().into_iter()); // Mark as done
905            return Ok(Some(builder.finish()));
906        }
907
908        let Some(results) = &mut self.results else {
909            return Ok(None);
910        };
911
912        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
913
914        for (key, states) in results.by_ref() {
915            // Output group key columns
916            let key_values = key.to_values();
917            for (i, value) in key_values.into_iter().enumerate() {
918                if let Some(col) = builder.column_mut(i) {
919                    col.push_value(value);
920                }
921            }
922
923            // Output aggregate results
924            for (i, state) in states.iter().enumerate() {
925                let col_idx = self.group_columns.len() + i;
926                if let Some(col) = builder.column_mut(col_idx) {
927                    col.push_value(state.finalize());
928                }
929            }
930
931            builder.advance_row();
932
933            if builder.is_full() {
934                return Ok(Some(builder.finish()));
935            }
936        }
937
938        if builder.row_count() > 0 {
939            Ok(Some(builder.finish()))
940        } else {
941            Ok(None)
942        }
943    }
944
945    fn reset(&mut self) {
946        self.child.reset();
947        self.groups.clear();
948        self.aggregation_complete = false;
949        self.results = None;
950    }
951
952    fn name(&self) -> &'static str {
953        "HashAggregate"
954    }
955
956    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
957        self
958    }
959}
960
/// Simple (non-grouping) aggregate operator for global aggregations.
///
/// Used when there's no GROUP BY clause - aggregates all input into a single row.
/// That row is emitted even when the child produces no rows; each column then
/// holds its aggregate's finalized initial state.
pub struct SimpleAggregateOperator {
    /// Child operator.
    child: Box<dyn Operator>,
    /// Aggregation expressions; output column `i` holds aggregate `i`.
    aggregates: Vec<AggregateExpr>,
    /// Output schema.
    output_schema: Vec<LogicalType>,
    /// Aggregate states, one per entry in `aggregates` (same order).
    states: Vec<AggregateState>,
    /// Whether the single result row has already been emitted.
    done: bool,
}
976
977impl SimpleAggregateOperator {
978    /// Creates a new simple aggregate operator.
979    pub fn new(
980        child: Box<dyn Operator>,
981        aggregates: Vec<AggregateExpr>,
982        output_schema: Vec<LogicalType>,
983    ) -> Self {
984        let states = aggregates
985            .iter()
986            .map(|agg| {
987                AggregateState::new(
988                    agg.function,
989                    agg.distinct,
990                    agg.percentile,
991                    agg.separator.as_deref(),
992                )
993            })
994            .collect();
995
996        Self {
997            child,
998            aggregates,
999            output_schema,
1000            states,
1001            done: false,
1002        }
1003    }
1004}
1005
1006impl Operator for SimpleAggregateOperator {
1007    fn next(&mut self) -> OperatorResult {
1008        if self.done {
1009            return Ok(None);
1010        }
1011
1012        // Process all input
1013        while let Some(chunk) = self.child.next()? {
1014            for row in chunk.selected_indices() {
1015                for (i, agg) in self.aggregates.iter().enumerate() {
1016                    // Binary set functions: read two column values
1017                    if agg.column2.is_some() {
1018                        let y_val = agg
1019                            .column
1020                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
1021                        let x_val = agg
1022                            .column2
1023                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
1024                        self.states[i].update_bivariate(y_val, x_val);
1025                        continue;
1026                    }
1027
1028                    let value = match (agg.function, agg.distinct) {
1029                        // COUNT(*) without DISTINCT doesn't need a value
1030                        (AggregateFunction::Count, false) => None,
1031                        // COUNT DISTINCT needs the actual value to track unique values
1032                        (AggregateFunction::Count, true) => agg
1033                            .column
1034                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1035                        _ => agg
1036                            .column
1037                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1038                    };
1039
1040                    match (agg.function, agg.distinct) {
1041                        (AggregateFunction::Count, false) => self.states[i].update(None),
1042                        (AggregateFunction::Count, true) => {
1043                            // COUNT DISTINCT needs the value to track unique values
1044                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1045                                self.states[i].update(value);
1046                            }
1047                        }
1048                        (AggregateFunction::CountNonNull, _) => {
1049                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1050                                self.states[i].update(value);
1051                            }
1052                        }
1053                        _ => {
1054                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1055                                self.states[i].update(value);
1056                            }
1057                        }
1058                    }
1059                }
1060            }
1061        }
1062
1063        // Output single result row
1064        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
1065
1066        for (i, state) in self.states.iter().enumerate() {
1067            if let Some(col) = builder.column_mut(i) {
1068                col.push_value(state.finalize());
1069            }
1070        }
1071        builder.advance_row();
1072
1073        self.done = true;
1074        Ok(Some(builder.finish()))
1075    }
1076
1077    fn reset(&mut self) {
1078        self.child.reset();
1079        self.states = self
1080            .aggregates
1081            .iter()
1082            .map(|agg| {
1083                AggregateState::new(
1084                    agg.function,
1085                    agg.distinct,
1086                    agg.percentile,
1087                    agg.separator.as_deref(),
1088                )
1089            })
1090            .collect();
1091        self.done = false;
1092    }
1093
1094    fn name(&self) -> &'static str {
1095        "SimpleAggregate"
1096    }
1097
1098    fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
1099        self
1100    }
1101}
1102
1103#[cfg(test)]
1104mod tests {
1105    use super::*;
1106    use crate::execution::chunk::DataChunkBuilder;
1107
1108    struct MockOperator {
1109        chunks: Vec<DataChunk>,
1110        position: usize,
1111    }
1112
1113    impl MockOperator {
1114        fn new(chunks: Vec<DataChunk>) -> Self {
1115            Self {
1116                chunks,
1117                position: 0,
1118            }
1119        }
1120    }
1121
1122    impl Operator for MockOperator {
1123        fn next(&mut self) -> OperatorResult {
1124            if self.position < self.chunks.len() {
1125                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1126                self.position += 1;
1127                Ok(Some(chunk))
1128            } else {
1129                Ok(None)
1130            }
1131        }
1132
1133        fn reset(&mut self) {
1134            self.position = 0;
1135        }
1136
1137        fn name(&self) -> &'static str {
1138            "Mock"
1139        }
1140
1141        fn into_any(self: Box<Self>) -> Box<dyn std::any::Any + Send> {
1142            self
1143        }
1144    }
1145
1146    fn create_test_chunk() -> DataChunk {
1147        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
1148        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1149
1150        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1151        for (group, value) in data {
1152            builder.column_mut(0).unwrap().push_int64(group);
1153            builder.column_mut(1).unwrap().push_int64(value);
1154            builder.advance_row();
1155        }
1156
1157        builder.finish()
1158    }
1159
1160    #[test]
1161    fn test_simple_count() {
1162        let mock = MockOperator::new(vec![create_test_chunk()]);
1163
1164        let mut agg = SimpleAggregateOperator::new(
1165            Box::new(mock),
1166            vec![AggregateExpr::count_star()],
1167            vec![LogicalType::Int64],
1168        );
1169
1170        let result = agg.next().unwrap().unwrap();
1171        assert_eq!(result.row_count(), 1);
1172        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1173
1174        // Should be done
1175        assert!(agg.next().unwrap().is_none());
1176    }
1177
1178    #[test]
1179    fn test_simple_sum() {
1180        let mock = MockOperator::new(vec![create_test_chunk()]);
1181
1182        let mut agg = SimpleAggregateOperator::new(
1183            Box::new(mock),
1184            vec![AggregateExpr::sum(1)], // Sum of column 1
1185            vec![LogicalType::Int64],
1186        );
1187
1188        let result = agg.next().unwrap().unwrap();
1189        assert_eq!(result.row_count(), 1);
1190        // Sum: 10 + 20 + 30 + 40 + 50 = 150
1191        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1192    }
1193
1194    #[test]
1195    fn test_simple_avg() {
1196        let mock = MockOperator::new(vec![create_test_chunk()]);
1197
1198        let mut agg = SimpleAggregateOperator::new(
1199            Box::new(mock),
1200            vec![AggregateExpr::avg(1)],
1201            vec![LogicalType::Float64],
1202        );
1203
1204        let result = agg.next().unwrap().unwrap();
1205        assert_eq!(result.row_count(), 1);
1206        // Avg: 150 / 5 = 30.0
1207        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1208        assert!((avg - 30.0).abs() < 0.001);
1209    }
1210
1211    #[test]
1212    fn test_simple_min_max() {
1213        let mock = MockOperator::new(vec![create_test_chunk()]);
1214
1215        let mut agg = SimpleAggregateOperator::new(
1216            Box::new(mock),
1217            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1218            vec![LogicalType::Int64, LogicalType::Int64],
1219        );
1220
1221        let result = agg.next().unwrap().unwrap();
1222        assert_eq!(result.row_count(), 1);
1223        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
1224        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
1225    }
1226
1227    #[test]
1228    fn test_sum_with_string_values() {
1229        // Test SUM with string values (like RDF stores numeric literals)
1230        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1231        builder.column_mut(0).unwrap().push_string("30");
1232        builder.advance_row();
1233        builder.column_mut(0).unwrap().push_string("25");
1234        builder.advance_row();
1235        builder.column_mut(0).unwrap().push_string("35");
1236        builder.advance_row();
1237        let chunk = builder.finish();
1238
1239        let mock = MockOperator::new(vec![chunk]);
1240        let mut agg = SimpleAggregateOperator::new(
1241            Box::new(mock),
1242            vec![AggregateExpr::sum(0)],
1243            vec![LogicalType::Float64],
1244        );
1245
1246        let result = agg.next().unwrap().unwrap();
1247        assert_eq!(result.row_count(), 1);
1248        // Should parse strings and sum: 30 + 25 + 35 = 90
1249        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1250        assert!(
1251            (sum_val - 90.0).abs() < 0.001,
1252            "Expected 90.0, got {}",
1253            sum_val
1254        );
1255    }
1256
1257    #[test]
1258    fn test_grouped_aggregation() {
1259        let mock = MockOperator::new(vec![create_test_chunk()]);
1260
1261        // GROUP BY column 0, SUM(column 1)
1262        let mut agg = HashAggregateOperator::new(
1263            Box::new(mock),
1264            vec![0],                     // Group by column 0
1265            vec![AggregateExpr::sum(1)], // Sum of column 1
1266            vec![LogicalType::Int64, LogicalType::Int64],
1267        );
1268
1269        let mut results: Vec<(i64, i64)> = Vec::new();
1270        while let Some(chunk) = agg.next().unwrap() {
1271            for row in chunk.selected_indices() {
1272                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1273                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1274                results.push((group, sum));
1275            }
1276        }
1277
1278        results.sort_by_key(|(g, _)| *g);
1279        assert_eq!(results.len(), 2);
1280        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
1281        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
1282    }
1283
1284    #[test]
1285    fn test_grouped_count() {
1286        let mock = MockOperator::new(vec![create_test_chunk()]);
1287
1288        // GROUP BY column 0, COUNT(*)
1289        let mut agg = HashAggregateOperator::new(
1290            Box::new(mock),
1291            vec![0],
1292            vec![AggregateExpr::count_star()],
1293            vec![LogicalType::Int64, LogicalType::Int64],
1294        );
1295
1296        let mut results: Vec<(i64, i64)> = Vec::new();
1297        while let Some(chunk) = agg.next().unwrap() {
1298            for row in chunk.selected_indices() {
1299                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1300                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1301                results.push((group, count));
1302            }
1303        }
1304
1305        results.sort_by_key(|(g, _)| *g);
1306        assert_eq!(results.len(), 2);
1307        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
1308        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
1309    }
1310
1311    #[test]
1312    fn test_multiple_aggregates() {
1313        let mock = MockOperator::new(vec![create_test_chunk()]);
1314
1315        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
1316        let mut agg = HashAggregateOperator::new(
1317            Box::new(mock),
1318            vec![0],
1319            vec![
1320                AggregateExpr::count_star(),
1321                AggregateExpr::sum(1),
1322                AggregateExpr::avg(1),
1323            ],
1324            vec![
1325                LogicalType::Int64,   // Group key
1326                LogicalType::Int64,   // COUNT
1327                LogicalType::Int64,   // SUM
1328                LogicalType::Float64, // AVG
1329            ],
1330        );
1331
1332        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1333        while let Some(chunk) = agg.next().unwrap() {
1334            for row in chunk.selected_indices() {
1335                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1336                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1337                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1338                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1339                results.push((group, count, sum, avg));
1340            }
1341        }
1342
1343        results.sort_by_key(|(g, _, _, _)| *g);
1344        assert_eq!(results.len(), 2);
1345
1346        // Group 1: COUNT=2, SUM=30, AVG=15.0
1347        assert_eq!(results[0].0, 1);
1348        assert_eq!(results[0].1, 2);
1349        assert_eq!(results[0].2, 30);
1350        assert!((results[0].3 - 15.0).abs() < 0.001);
1351
1352        // Group 2: COUNT=3, SUM=120, AVG=40.0
1353        assert_eq!(results[1].0, 2);
1354        assert_eq!(results[1].1, 3);
1355        assert_eq!(results[1].2, 120);
1356        assert!((results[1].3 - 40.0).abs() < 0.001);
1357    }
1358
1359    fn create_test_chunk_with_duplicates() -> DataChunk {
1360        // Create data with duplicate values in column 1
1361        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
1362        // GROUP 1: values [10, 10, 20] -> distinct count = 2
1363        // GROUP 2: values [30, 30, 30] -> distinct count = 1
1364        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1365
1366        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1367        for (group, value) in data {
1368            builder.column_mut(0).unwrap().push_int64(group);
1369            builder.column_mut(1).unwrap().push_int64(value);
1370            builder.advance_row();
1371        }
1372
1373        builder.finish()
1374    }
1375
1376    #[test]
1377    fn test_count_distinct() {
1378        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1379
1380        // COUNT(DISTINCT column 1)
1381        let mut agg = SimpleAggregateOperator::new(
1382            Box::new(mock),
1383            vec![AggregateExpr::count(1).with_distinct()],
1384            vec![LogicalType::Int64],
1385        );
1386
1387        let result = agg.next().unwrap().unwrap();
1388        assert_eq!(result.row_count(), 1);
1389        // Total distinct values: 10, 20, 30 = 3 distinct values
1390        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1391    }
1392
1393    #[test]
1394    fn test_grouped_count_distinct() {
1395        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1396
1397        // GROUP BY column 0, COUNT(DISTINCT column 1)
1398        let mut agg = HashAggregateOperator::new(
1399            Box::new(mock),
1400            vec![0],
1401            vec![AggregateExpr::count(1).with_distinct()],
1402            vec![LogicalType::Int64, LogicalType::Int64],
1403        );
1404
1405        let mut results: Vec<(i64, i64)> = Vec::new();
1406        while let Some(chunk) = agg.next().unwrap() {
1407            for row in chunk.selected_indices() {
1408                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1409                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1410                results.push((group, count));
1411            }
1412        }
1413
1414        results.sort_by_key(|(g, _)| *g);
1415        assert_eq!(results.len(), 2);
1416        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
1417        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
1418    }
1419
1420    #[test]
1421    fn test_sum_distinct() {
1422        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1423
1424        // SUM(DISTINCT column 1)
1425        let mut agg = SimpleAggregateOperator::new(
1426            Box::new(mock),
1427            vec![AggregateExpr::sum(1).with_distinct()],
1428            vec![LogicalType::Int64],
1429        );
1430
1431        let result = agg.next().unwrap().unwrap();
1432        assert_eq!(result.row_count(), 1);
1433        // Sum of distinct values: 10 + 20 + 30 = 60
1434        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1435    }
1436
1437    #[test]
1438    fn test_avg_distinct() {
1439        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1440
1441        // AVG(DISTINCT column 1)
1442        let mut agg = SimpleAggregateOperator::new(
1443            Box::new(mock),
1444            vec![AggregateExpr::avg(1).with_distinct()],
1445            vec![LogicalType::Float64],
1446        );
1447
1448        let result = agg.next().unwrap().unwrap();
1449        assert_eq!(result.row_count(), 1);
1450        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1451        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1452        assert!((avg - 20.0).abs() < 0.001);
1453    }
1454
1455    fn create_statistical_test_chunk() -> DataChunk {
1456        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1457        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1458        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1459
1460        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1461            builder.column_mut(0).unwrap().push_int64(value);
1462            builder.advance_row();
1463        }
1464
1465        builder.finish()
1466    }
1467
1468    #[test]
1469    fn test_stdev_sample() {
1470        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1471
1472        let mut agg = SimpleAggregateOperator::new(
1473            Box::new(mock),
1474            vec![AggregateExpr::stdev(0)],
1475            vec![LogicalType::Float64],
1476        );
1477
1478        let result = agg.next().unwrap().unwrap();
1479        assert_eq!(result.row_count(), 1);
1480        // Sample standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1481        // Mean = 5.0, Variance = 32/7 = 4.571, StdDev = 2.138
1482        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1483        assert!((stdev - 2.138).abs() < 0.01);
1484    }
1485
1486    #[test]
1487    fn test_stdev_population() {
1488        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1489
1490        let mut agg = SimpleAggregateOperator::new(
1491            Box::new(mock),
1492            vec![AggregateExpr::stdev_pop(0)],
1493            vec![LogicalType::Float64],
1494        );
1495
1496        let result = agg.next().unwrap().unwrap();
1497        assert_eq!(result.row_count(), 1);
1498        // Population standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1499        // Mean = 5.0, Variance = 32/8 = 4.0, StdDev = 2.0
1500        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1501        assert!((stdev - 2.0).abs() < 0.01);
1502    }
1503
1504    #[test]
1505    fn test_percentile_disc() {
1506        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1507
1508        // Median (50th percentile discrete)
1509        let mut agg = SimpleAggregateOperator::new(
1510            Box::new(mock),
1511            vec![AggregateExpr::percentile_disc(0, 0.5)],
1512            vec![LogicalType::Float64],
1513        );
1514
1515        let result = agg.next().unwrap().unwrap();
1516        assert_eq!(result.row_count(), 1);
1517        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], index = floor(0.5 * 7) = 3, value = 4
1518        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1519        assert!((percentile - 4.0).abs() < 0.01);
1520    }
1521
1522    #[test]
1523    fn test_percentile_cont() {
1524        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1525
1526        // Median (50th percentile continuous)
1527        let mut agg = SimpleAggregateOperator::new(
1528            Box::new(mock),
1529            vec![AggregateExpr::percentile_cont(0, 0.5)],
1530            vec![LogicalType::Float64],
1531        );
1532
1533        let result = agg.next().unwrap().unwrap();
1534        assert_eq!(result.row_count(), 1);
1535        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], rank = 0.5 * 7 = 3.5
1536        // Interpolate between index 3 (4) and index 4 (5): 4 + 0.5 * (5 - 4) = 4.5
1537        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1538        assert!((percentile - 4.5).abs() < 0.01);
1539    }
1540
1541    #[test]
1542    fn test_percentile_extremes() {
1543        // Test 0th and 100th percentiles
1544        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1545
1546        let mut agg = SimpleAggregateOperator::new(
1547            Box::new(mock),
1548            vec![
1549                AggregateExpr::percentile_disc(0, 0.0),
1550                AggregateExpr::percentile_disc(0, 1.0),
1551            ],
1552            vec![LogicalType::Float64, LogicalType::Float64],
1553        );
1554
1555        let result = agg.next().unwrap().unwrap();
1556        assert_eq!(result.row_count(), 1);
1557        // 0th percentile = minimum = 2
1558        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1559        assert!((p0 - 2.0).abs() < 0.01);
1560        // 100th percentile = maximum = 9
1561        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1562        assert!((p100 - 9.0).abs() < 0.01);
1563    }
1564
1565    #[test]
1566    fn test_stdev_single_value() {
1567        // Single value should return null for sample stdev
1568        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1569        builder.column_mut(0).unwrap().push_int64(42);
1570        builder.advance_row();
1571        let chunk = builder.finish();
1572
1573        let mock = MockOperator::new(vec![chunk]);
1574
1575        let mut agg = SimpleAggregateOperator::new(
1576            Box::new(mock),
1577            vec![AggregateExpr::stdev(0)],
1578            vec![LogicalType::Float64],
1579        );
1580
1581        let result = agg.next().unwrap().unwrap();
1582        assert_eq!(result.row_count(), 1);
1583        // Sample stdev of single value is undefined (null)
1584        assert!(matches!(
1585            result.column(0).unwrap().get_value(0),
1586            Some(Value::Null)
1587        ));
1588    }
1589
1590    #[test]
1591    fn test_first_and_last() {
1592        let mock = MockOperator::new(vec![create_test_chunk()]);
1593
1594        let mut agg = SimpleAggregateOperator::new(
1595            Box::new(mock),
1596            vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1597            vec![LogicalType::Int64, LogicalType::Int64],
1598        );
1599
1600        let result = agg.next().unwrap().unwrap();
1601        assert_eq!(result.row_count(), 1);
1602        // First: 10, Last: 50
1603        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1604        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1605    }
1606
1607    #[test]
1608    fn test_collect() {
1609        let mock = MockOperator::new(vec![create_test_chunk()]);
1610
1611        let mut agg = SimpleAggregateOperator::new(
1612            Box::new(mock),
1613            vec![AggregateExpr::collect(1)],
1614            vec![LogicalType::Any],
1615        );
1616
1617        let result = agg.next().unwrap().unwrap();
1618        let val = result.column(0).unwrap().get_value(0).unwrap();
1619        if let Value::List(items) = val {
1620            assert_eq!(items.len(), 5);
1621        } else {
1622            panic!("Expected List value");
1623        }
1624    }
1625
1626    #[test]
1627    fn test_collect_distinct() {
1628        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1629
1630        let mut agg = SimpleAggregateOperator::new(
1631            Box::new(mock),
1632            vec![AggregateExpr::collect(1).with_distinct()],
1633            vec![LogicalType::Any],
1634        );
1635
1636        let result = agg.next().unwrap().unwrap();
1637        let val = result.column(0).unwrap().get_value(0).unwrap();
1638        if let Value::List(items) = val {
1639            // [10, 10, 20, 30, 30, 30] -> distinct: [10, 20, 30]
1640            assert_eq!(items.len(), 3);
1641        } else {
1642            panic!("Expected List value");
1643        }
1644    }
1645
1646    #[test]
1647    fn test_group_concat() {
1648        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1649        for s in ["hello", "world", "foo"] {
1650            builder.column_mut(0).unwrap().push_string(s);
1651            builder.advance_row();
1652        }
1653        let chunk = builder.finish();
1654        let mock = MockOperator::new(vec![chunk]);
1655
1656        let agg_expr = AggregateExpr {
1657            function: AggregateFunction::GroupConcat,
1658            column: Some(0),
1659            column2: None,
1660            distinct: false,
1661            alias: None,
1662            percentile: None,
1663            separator: None,
1664        };
1665
1666        let mut agg =
1667            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1668
1669        let result = agg.next().unwrap().unwrap();
1670        let val = result.column(0).unwrap().get_value(0).unwrap();
1671        assert_eq!(val, Value::String("hello world foo".into()));
1672    }
1673
1674    #[test]
1675    fn test_sample() {
1676        let mock = MockOperator::new(vec![create_test_chunk()]);
1677
1678        let agg_expr = AggregateExpr {
1679            function: AggregateFunction::Sample,
1680            column: Some(1),
1681            column2: None,
1682            distinct: false,
1683            alias: None,
1684            percentile: None,
1685            separator: None,
1686        };
1687
1688        let mut agg =
1689            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1690
1691        let result = agg.next().unwrap().unwrap();
1692        // Sample should return the first non-null value (10)
1693        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1694    }
1695
1696    #[test]
1697    fn test_variance_sample() {
1698        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1699
1700        let agg_expr = AggregateExpr {
1701            function: AggregateFunction::Variance,
1702            column: Some(0),
1703            column2: None,
1704            distinct: false,
1705            alias: None,
1706            percentile: None,
1707            separator: None,
1708        };
1709
1710        let mut agg = SimpleAggregateOperator::new(
1711            Box::new(mock),
1712            vec![agg_expr],
1713            vec![LogicalType::Float64],
1714        );
1715
1716        let result = agg.next().unwrap().unwrap();
1717        // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: M2/(n-1) = 32/7 = 4.571
1718        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1719        assert!((variance - 32.0 / 7.0).abs() < 0.01);
1720    }
1721
1722    #[test]
1723    fn test_variance_population() {
1724        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1725
1726        let agg_expr = AggregateExpr {
1727            function: AggregateFunction::VariancePop,
1728            column: Some(0),
1729            column2: None,
1730            distinct: false,
1731            alias: None,
1732            percentile: None,
1733            separator: None,
1734        };
1735
1736        let mut agg = SimpleAggregateOperator::new(
1737            Box::new(mock),
1738            vec![agg_expr],
1739            vec![LogicalType::Float64],
1740        );
1741
1742        let result = agg.next().unwrap().unwrap();
1743        // Population variance: M2/n = 32/8 = 4.0
1744        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1745        assert!((variance - 4.0).abs() < 0.01);
1746    }
1747
1748    #[test]
1749    fn test_variance_single_value() {
1750        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1751        builder.column_mut(0).unwrap().push_int64(42);
1752        builder.advance_row();
1753        let chunk = builder.finish();
1754        let mock = MockOperator::new(vec![chunk]);
1755
1756        let agg_expr = AggregateExpr {
1757            function: AggregateFunction::Variance,
1758            column: Some(0),
1759            column2: None,
1760            distinct: false,
1761            alias: None,
1762            percentile: None,
1763            separator: None,
1764        };
1765
1766        let mut agg = SimpleAggregateOperator::new(
1767            Box::new(mock),
1768            vec![agg_expr],
1769            vec![LogicalType::Float64],
1770        );
1771
1772        let result = agg.next().unwrap().unwrap();
1773        // Sample variance of single value is undefined (null)
1774        assert!(matches!(
1775            result.column(0).unwrap().get_value(0),
1776            Some(Value::Null)
1777        ));
1778    }
1779
1780    #[test]
1781    fn test_empty_aggregation() {
1782        // No input rows: COUNT should be 0, SUM/AVG/MIN/MAX should be NULL
1783        // (ISO/IEC 39075 Section 20.9)
1784        let mock = MockOperator::new(vec![]);
1785
1786        let mut agg = SimpleAggregateOperator::new(
1787            Box::new(mock),
1788            vec![
1789                AggregateExpr::count_star(),
1790                AggregateExpr::sum(0),
1791                AggregateExpr::avg(0),
1792                AggregateExpr::min(0),
1793                AggregateExpr::max(0),
1794            ],
1795            vec![
1796                LogicalType::Int64,
1797                LogicalType::Int64,
1798                LogicalType::Float64,
1799                LogicalType::Int64,
1800                LogicalType::Int64,
1801            ],
1802        );
1803
1804        let result = agg.next().unwrap().unwrap();
1805        assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); // COUNT
1806        assert!(matches!(
1807            result.column(1).unwrap().get_value(0),
1808            Some(Value::Null)
1809        )); // SUM
1810        assert!(matches!(
1811            result.column(2).unwrap().get_value(0),
1812            Some(Value::Null)
1813        )); // AVG
1814        assert!(matches!(
1815            result.column(3).unwrap().get_value(0),
1816            Some(Value::Null)
1817        )); // MIN
1818        assert!(matches!(
1819            result.column(4).unwrap().get_value(0),
1820            Some(Value::Null)
1821        )); // MAX
1822    }
1823
1824    #[test]
1825    fn test_stdev_pop_single_value() {
1826        // Single value should return 0 for population stdev
1827        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1828        builder.column_mut(0).unwrap().push_int64(42);
1829        builder.advance_row();
1830        let chunk = builder.finish();
1831
1832        let mock = MockOperator::new(vec![chunk]);
1833
1834        let mut agg = SimpleAggregateOperator::new(
1835            Box::new(mock),
1836            vec![AggregateExpr::stdev_pop(0)],
1837            vec![LogicalType::Float64],
1838        );
1839
1840        let result = agg.next().unwrap().unwrap();
1841        assert_eq!(result.row_count(), 1);
1842        // Population stdev of single value is 0
1843        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1844        assert!((stdev - 0.0).abs() < 0.01);
1845    }
1846
1847    #[test]
1848    fn test_hash_aggregate_into_any() {
1849        let mock = MockOperator::new(vec![]);
1850        let op = HashAggregateOperator::new(
1851            Box::new(mock),
1852            vec![0],
1853            vec![AggregateExpr::count_star()],
1854            vec![LogicalType::Int64, LogicalType::Int64],
1855        );
1856        let any = Box::new(op).into_any();
1857        assert!(any.downcast::<HashAggregateOperator>().is_ok());
1858    }
1859
1860    #[test]
1861    fn test_simple_aggregate_into_any() {
1862        let mock = MockOperator::new(vec![]);
1863        let op = SimpleAggregateOperator::new(
1864            Box::new(mock),
1865            vec![AggregateExpr::count_star()],
1866            vec![LogicalType::Int64],
1867        );
1868        let any = Box::new(op).into_any();
1869        assert!(any.downcast::<SimpleAggregateOperator>().is_ok());
1870    }
1871
1872    #[test]
1873    fn test_hash_aggregate_into_parts() {
1874        let mock = MockOperator::new(vec![]);
1875        let op = HashAggregateOperator::new(
1876            Box::new(mock),
1877            vec![0, 2],
1878            vec![AggregateExpr::sum(1), AggregateExpr::count_star()],
1879            vec![LogicalType::Int64, LogicalType::Int64, LogicalType::Int64],
1880        );
1881        let (mut child, group_columns, aggregates) = op.into_parts();
1882        assert_eq!(group_columns, vec![0, 2]);
1883        assert_eq!(aggregates.len(), 2);
1884        assert!(child.next().unwrap().is_none());
1885    }
1886}