Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, PropertyKey, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
/// State for a single aggregation computation.
///
/// Used by both the pull-based [`HashAggregateOperator`] and the push-based
/// `AggregatePushOperator`.
/// Supports all [`AggregateFunction`] variants including Welford's algorithm
/// for online statistics, Kahan summation, distinct tracking, and bivariate
/// regression functions.
///
/// A state is created empty by [`AggregateState::new`], fed row by row through
/// [`AggregateState::update`] (or [`AggregateState::update_bivariate`] for the
/// two-argument functions) and turned into a result by
/// [`AggregateState::finalize`].
#[derive(Debug, Clone)]
#[allow(missing_docs)]
pub enum AggregateState {
    /// Count state (number of `update` calls so far).
    Count(i64),
    /// Count distinct state (count, seen values).
    CountDistinct(i64, HashSet<HashableValue>),
    /// Sum state (integer sum, count of values added).
    SumInt(i64, i64),
    /// Sum distinct state (integer sum, count, seen values).
    SumIntDistinct(i64, i64, HashSet<HashableValue>),
    /// Sum state (float sum, Kahan compensation term, count of values added).
    SumFloat(f64, f64, i64),
    /// Sum distinct state (float sum, compensation, count, seen values).
    SumFloatDistinct(f64, f64, i64, HashSet<HashableValue>),
    /// Average state (sum, count).
    Avg(f64, i64),
    /// Average distinct state (sum, count, seen values).
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// Min state (smallest value seen so far, `None` before the first value).
    Min(Option<Value>),
    /// Max state (largest value seen so far, `None` before the first value).
    Max(Option<Value>),
    /// First state (first `Some` value passed to `update`).
    First(Option<Value>),
    /// Last state (most recent `Some` value passed to `update`).
    Last(Option<Value>),
    /// Collect state (values in arrival order).
    Collect(Vec<Value>),
    /// Collect distinct state (values, seen).
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Discrete percentile state (values, percentile).
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Continuous percentile state (values, percentile).
    PercentileCont { values: Vec<f64>, percentile: f64 },
    /// GROUP_CONCAT / LISTAGG state (collected string values, separator).
    GroupConcat(Vec<String>, String),
    /// GROUP_CONCAT / LISTAGG distinct state (collected string values, separator, seen).
    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
    /// SAMPLE state (first non-null value encountered).
    Sample(Option<Value>),
    /// Sample variance state using Welford's algorithm (count, mean, M2).
    Variance { count: i64, mean: f64, m2: f64 },
    /// Population variance state using Welford's algorithm (count, mean, M2).
    VariancePop { count: i64, mean: f64, m2: f64 },
    /// Two-variable online statistics (Welford generalization for covariance/regression).
    Bivariate {
        /// Which binary set function this state will finalize to.
        kind: AggregateFunction,
        /// Number of complete `(y, x)` pairs accumulated.
        count: i64,
        /// Running mean of the x values.
        mean_x: f64,
        /// Running mean of the y values.
        mean_y: f64,
        /// Running sum of squared deviations of x from its mean.
        m2_x: f64,
        /// Running sum of squared deviations of y from its mean.
        m2_y: f64,
        /// Running sum of the products of the x and y deviations (co-moment).
        c_xy: f64,
    },
    /// Immutable finalized value restored from spill. Ignores further updates
    /// so that reloaded groups that were serialized via the FINALIZED fallback
    /// do not silently corrupt their result when more rows arrive.
    Frozen(Value),
}
94
95impl AggregateState {
96    /// Creates initial state for an aggregation function.
97    pub fn new(
98        function: AggregateFunction,
99        distinct: bool,
100        percentile: Option<f64>,
101        separator: Option<&str>,
102    ) -> Self {
103        match (function, distinct) {
104            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
105                AggregateState::Count(0)
106            }
107            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
108                AggregateState::CountDistinct(0, HashSet::new())
109            }
110            (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
111            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
112            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
113            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
114            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
115            (AggregateFunction::Max, _) => AggregateState::Max(None),
116            (AggregateFunction::First, _) => AggregateState::First(None),
117            (AggregateFunction::Last, _) => AggregateState::Last(None),
118            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
119            (AggregateFunction::Collect, true) => {
120                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
121            }
122            // Statistical functions (Welford's algorithm for online computation)
123            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
124                count: 0,
125                mean: 0.0,
126                m2: 0.0,
127            },
128            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
129                count: 0,
130                mean: 0.0,
131                m2: 0.0,
132            },
133            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
134                values: Vec::new(),
135                percentile: percentile.unwrap_or(0.5),
136            },
137            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
138                values: Vec::new(),
139                percentile: percentile.unwrap_or(0.5),
140            },
141            (AggregateFunction::GroupConcat, false) => {
142                AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
143            }
144            (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
145                Vec::new(),
146                separator.unwrap_or(" ").to_string(),
147                HashSet::new(),
148            ),
149            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
150            // Binary set functions (all share the same Bivariate state)
151            (
152                AggregateFunction::CovarSamp
153                | AggregateFunction::CovarPop
154                | AggregateFunction::Corr
155                | AggregateFunction::RegrSlope
156                | AggregateFunction::RegrIntercept
157                | AggregateFunction::RegrR2
158                | AggregateFunction::RegrCount
159                | AggregateFunction::RegrSxx
160                | AggregateFunction::RegrSyy
161                | AggregateFunction::RegrSxy
162                | AggregateFunction::RegrAvgx
163                | AggregateFunction::RegrAvgy,
164                _,
165            ) => AggregateState::Bivariate {
166                kind: function,
167                count: 0,
168                mean_x: 0.0,
169                mean_y: 0.0,
170                m2_x: 0.0,
171                m2_y: 0.0,
172                c_xy: 0.0,
173            },
174            (AggregateFunction::Variance, _) => AggregateState::Variance {
175                count: 0,
176                mean: 0.0,
177                m2: 0.0,
178            },
179            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
180                count: 0,
181                mean: 0.0,
182                m2: 0.0,
183            },
184        }
185    }
186
    /// Updates the state with a new value.
    ///
    /// For `COUNT(*)`, pass `None` to count all rows. For column-specific
    /// aggregates, pass `Some(value)` (nulls are skipped by most functions).
    ///
    /// `None` is ignored by every state except `Count`, which increments on
    /// every call. `Bivariate` and `Frozen` states ignore this method
    /// entirely; bivariate states are fed through `update_bivariate`.
    pub fn update(&mut self, value: Option<Value>) {
        match self {
            // Counts calls, not values: the argument is never inspected.
            AggregateState::Count(count) => {
                *count += 1;
            }
            AggregateState::CountDistinct(count, seen) => {
                if let Some(ref v) = value {
                    let hashable = HashableValue::from(v);
                    // `insert` returns true only for a value not seen before.
                    if seen.insert(hashable) {
                        *count += 1;
                    }
                }
            }
            // Integer sum. Stays integral while only Int64 values arrive and
            // replaces itself with a SumFloat state on the first other
            // numeric value. Values that value_to_f64 rejects are skipped.
            AggregateState::SumInt(sum, count) => {
                if let Some(Value::Int64(v)) = value {
                    *sum += v;
                    *count += 1;
                } else if let Some(Value::Float64(v)) = value {
                    // Convert to float sum, carrying count forward
                    *self = AggregateState::SumFloat(*sum as f64 + v, 0.0, *count + 1);
                } else if let Some(ref v) = value {
                    // RDF stores numeric literals as strings - try to parse
                    if let Some(num) = value_to_f64(v) {
                        *self = AggregateState::SumFloat(*sum as f64 + num, 0.0, *count + 1);
                    }
                }
            }
            // Distinct integer sum. The value is recorded in `seen` before it
            // is type-checked, so a non-numeric value is remembered but
            // contributes nothing to the sum.
            AggregateState::SumIntDistinct(sum, count, seen) => {
                if let Some(ref v) = value {
                    let hashable = HashableValue::from(v);
                    if seen.insert(hashable) {
                        if let Value::Int64(i) = v {
                            *sum += i;
                            *count += 1;
                        } else if let Value::Float64(f) = v {
                            // Convert to float distinct: move the seen set instead of cloning
                            let moved_seen = std::mem::take(seen);
                            *self = AggregateState::SumFloatDistinct(
                                *sum as f64 + f,
                                0.0,
                                *count + 1,
                                moved_seen,
                            );
                        } else if let Some(num) = value_to_f64(v) {
                            // RDF string-encoded numerics
                            let moved_seen = std::mem::take(seen);
                            *self = AggregateState::SumFloatDistinct(
                                *sum as f64 + num,
                                0.0,
                                *count + 1,
                                moved_seen,
                            );
                        }
                    }
                }
            }
            AggregateState::SumFloat(sum, comp, count) => {
                if let Some(ref v) = value {
                    // Use value_to_f64 which now handles strings
                    if let Some(num) = value_to_f64(v) {
                        // Kahan compensated summation to reduce rounding error
                        let y = num - *comp;
                        let t = *sum + y;
                        *comp = (t - *sum) - y;
                        *sum = t;
                        *count += 1;
                    }
                }
            }
            AggregateState::SumFloatDistinct(sum, comp, count, seen) => {
                if let Some(ref v) = value {
                    let hashable = HashableValue::from(v);
                    // Same Kahan update as SumFloat, applied to first
                    // occurrences only.
                    if seen.insert(hashable)
                        && let Some(num) = value_to_f64(v)
                    {
                        let y = num - *comp;
                        let t = *sum + y;
                        *comp = (t - *sum) - y;
                        *sum = t;
                        *count += 1;
                    }
                }
            }
            // AVG accumulates a plain (uncompensated) f64 sum plus a count;
            // the division happens in `finalize`.
            AggregateState::Avg(sum, count) => {
                if let Some(ref v) = value
                    && let Some(num) = value_to_f64(v)
                {
                    *sum += num;
                    *count += 1;
                }
            }
            AggregateState::AvgDistinct(sum, count, seen) => {
                if let Some(ref v) = value {
                    let hashable = HashableValue::from(v);
                    if seen.insert(hashable)
                        && let Some(num) = value_to_f64(v)
                    {
                        *sum += num;
                        *count += 1;
                    }
                }
            }
            // MIN: replace only when compare_values reports strictly Less.
            // A `None` comparison result leaves the current minimum in place.
            AggregateState::Min(min) => {
                if let Some(v) = value {
                    match min {
                        None => *min = Some(v),
                        Some(current) => {
                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
                                *min = Some(v);
                            }
                        }
                    }
                }
            }
            // MAX: mirror image of MIN, replacing only on strictly Greater.
            AggregateState::Max(max) => {
                if let Some(v) = value {
                    match max {
                        None => *max = Some(v),
                        Some(current) => {
                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
                                *max = Some(v);
                            }
                        }
                    }
                }
            }
            // FIRST: keeps the first `Some` it is given. Assigning a `None`
            // while still empty leaves the slot empty.
            AggregateState::First(first) => {
                if first.is_none() {
                    *first = value;
                }
            }
            // LAST: overwritten by every `Some`, untouched by `None`.
            AggregateState::Last(last) => {
                if value.is_some() {
                    *last = value;
                }
            }
            AggregateState::Collect(list) => {
                if let Some(v) = value {
                    list.push(v);
                }
            }
            AggregateState::CollectDistinct(list, seen) => {
                if let Some(v) = value {
                    let hashable = HashableValue::from(&v);
                    if seen.insert(hashable) {
                        list.push(v);
                    }
                }
            }
            // Statistical functions using Welford's online algorithm
            AggregateState::StdDev { count, mean, m2 }
            | AggregateState::StdDevPop { count, mean, m2 }
            | AggregateState::Variance { count, mean, m2 }
            | AggregateState::VariancePop { count, mean, m2 } => {
                if let Some(ref v) = value
                    && let Some(x) = value_to_f64(v)
                {
                    *count += 1;
                    // `delta` uses the mean before this sample, `delta2` the
                    // mean after it; their product is the M2 increment.
                    let delta = x - *mean;
                    *mean += delta / *count as f64;
                    let delta2 = x - *mean;
                    *m2 += delta * delta2;
                }
            }
            // Percentiles buffer every numeric value; the sort and the
            // lookup happen in `finalize`.
            AggregateState::PercentileDisc { values, .. }
            | AggregateState::PercentileCont { values, .. } => {
                if let Some(ref v) = value
                    && let Some(x) = value_to_f64(v)
                {
                    values.push(x);
                }
            }
            // GROUP_CONCAT stringifies eagerly; the separator is applied
            // only when `finalize` joins the list.
            AggregateState::GroupConcat(list, _sep) => {
                if let Some(v) = value {
                    list.push(agg_value_to_string(&v));
                }
            }
            AggregateState::GroupConcatDistinct(list, _sep, seen) => {
                if let Some(v) = value {
                    let hashable = HashableValue::from(&v);
                    if seen.insert(hashable) {
                        list.push(agg_value_to_string(&v));
                    }
                }
            }
            // SAMPLE: same keep-the-first rule as FIRST.
            AggregateState::Sample(sample) => {
                if sample.is_none() {
                    *sample = value;
                }
            }
            AggregateState::Bivariate { .. } => {
                // Bivariate functions require two values; use update_bivariate() instead.
                // Single-value update is a no-op for bivariate state.
            }
            // A frozen state already holds its final value; later rows are
            // deliberately dropped.
            AggregateState::Frozen(_) => {}
        }
    }
388
389    /// Updates a bivariate (two-variable) aggregate state with a pair of values.
390    ///
391    /// Uses the two-variable Welford online algorithm for numerically stable computation
392    /// of covariance and related statistics. Skips the update if either value is null.
393    pub fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
394        if let AggregateState::Bivariate {
395            count,
396            mean_x,
397            mean_y,
398            m2_x,
399            m2_y,
400            c_xy,
401            ..
402        } = self
403        {
404            // Skip if either value is null (SQL semantics: exclude non-pairs)
405            if let (Some(y), Some(x)) = (&y_val, &x_val)
406                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
407            {
408                *count += 1;
409                let n = *count as f64;
410                let dx = x_f - *mean_x;
411                let dy = y_f - *mean_y;
412                *mean_x += dx / n;
413                *mean_y += dy / n;
414                let dx2 = x_f - *mean_x; // post-update delta
415                let dy2 = y_f - *mean_y; // post-update delta
416                *m2_x += dx * dx2;
417                *m2_y += dy * dy2;
418                *c_xy += dx * dy2;
419            }
420        }
421    }
422
423    /// Finalizes the state and returns the result value.
424    pub fn finalize(&self) -> Value {
425        match self {
426            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
427                Value::Int64(*count)
428            }
429            AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
430                if *count == 0 {
431                    Value::Null
432                } else {
433                    Value::Int64(*sum)
434                }
435            }
436            AggregateState::SumFloat(sum, _, count)
437            | AggregateState::SumFloatDistinct(sum, _, count, _) => {
438                if *count == 0 {
439                    Value::Null
440                } else {
441                    Value::Float64(*sum)
442                }
443            }
444            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
445                if *count == 0 {
446                    Value::Null
447                } else {
448                    Value::Float64(*sum / *count as f64)
449                }
450            }
451            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
452            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
453            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
454            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
455            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
456                Value::List(list.clone().into())
457            }
458            // Sample standard deviation: sqrt(M2 / (n - 1))
459            AggregateState::StdDev { count, m2, .. } => {
460                if *count < 2 {
461                    Value::Null
462                } else {
463                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
464                }
465            }
466            // Population standard deviation: sqrt(M2 / n)
467            AggregateState::StdDevPop { count, m2, .. } => {
468                if *count == 0 {
469                    Value::Null
470                } else {
471                    Value::Float64((*m2 / *count as f64).sqrt())
472                }
473            }
474            // Sample variance: M2 / (n - 1)
475            AggregateState::Variance { count, m2, .. } => {
476                if *count < 2 {
477                    Value::Null
478                } else {
479                    Value::Float64(*m2 / (*count - 1) as f64)
480                }
481            }
482            // Population variance: M2 / n
483            AggregateState::VariancePop { count, m2, .. } => {
484                if *count == 0 {
485                    Value::Null
486                } else {
487                    Value::Float64(*m2 / *count as f64)
488                }
489            }
490            // Discrete percentile: return actual value at percentile position
491            AggregateState::PercentileDisc { values, percentile } => {
492                if values.is_empty() {
493                    Value::Null
494                } else {
495                    let mut sorted = values.clone();
496                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
497                    // Index calculation per SQL standard: floor(p * (n - 1))
498                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
499                    Value::Float64(sorted[index])
500                }
501            }
502            // Continuous percentile: interpolate between values
503            AggregateState::PercentileCont { values, percentile } => {
504                if values.is_empty() {
505                    Value::Null
506                } else {
507                    let mut sorted = values.clone();
508                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
509                    // Linear interpolation per SQL standard
510                    let rank = percentile * (sorted.len() - 1) as f64;
511                    let lower_idx = rank.floor() as usize;
512                    let upper_idx = rank.ceil() as usize;
513                    if lower_idx == upper_idx {
514                        Value::Float64(sorted[lower_idx])
515                    } else {
516                        let fraction = rank - lower_idx as f64;
517                        let result =
518                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
519                        Value::Float64(result)
520                    }
521                }
522            }
523            // GROUP_CONCAT: join strings with space separator (SPARQL default)
524            AggregateState::GroupConcat(list, sep)
525            | AggregateState::GroupConcatDistinct(list, sep, _) => {
526                Value::String(list.join(sep).into())
527            }
528            // SAMPLE: return the first non-null value seen
529            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
530            AggregateState::Frozen(val) => val.clone(),
531            // Binary set functions: dispatch on kind
532            AggregateState::Bivariate {
533                kind,
534                count,
535                mean_x,
536                mean_y,
537                m2_x,
538                m2_y,
539                c_xy,
540            } => {
541                let n = *count;
542                match kind {
543                    AggregateFunction::CovarSamp => {
544                        if n < 2 {
545                            Value::Null
546                        } else {
547                            Value::Float64(*c_xy / (n - 1) as f64)
548                        }
549                    }
550                    AggregateFunction::CovarPop => {
551                        if n == 0 {
552                            Value::Null
553                        } else {
554                            Value::Float64(*c_xy / n as f64)
555                        }
556                    }
557                    AggregateFunction::Corr => {
558                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
559                            Value::Null
560                        } else {
561                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
562                        }
563                    }
564                    AggregateFunction::RegrSlope => {
565                        if n == 0 || *m2_x == 0.0 {
566                            Value::Null
567                        } else {
568                            Value::Float64(*c_xy / *m2_x)
569                        }
570                    }
571                    AggregateFunction::RegrIntercept => {
572                        if n == 0 || *m2_x == 0.0 {
573                            Value::Null
574                        } else {
575                            let slope = *c_xy / *m2_x;
576                            Value::Float64(*mean_y - slope * *mean_x)
577                        }
578                    }
579                    AggregateFunction::RegrR2 => {
580                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
581                            Value::Null
582                        } else {
583                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
584                        }
585                    }
586                    AggregateFunction::RegrCount => Value::Int64(n),
587                    AggregateFunction::RegrSxx => {
588                        if n == 0 {
589                            Value::Null
590                        } else {
591                            Value::Float64(*m2_x)
592                        }
593                    }
594                    AggregateFunction::RegrSyy => {
595                        if n == 0 {
596                            Value::Null
597                        } else {
598                            Value::Float64(*m2_y)
599                        }
600                    }
601                    AggregateFunction::RegrSxy => {
602                        if n == 0 {
603                            Value::Null
604                        } else {
605                            Value::Float64(*c_xy)
606                        }
607                    }
608                    AggregateFunction::RegrAvgx => {
609                        if n == 0 {
610                            Value::Null
611                        } else {
612                            Value::Float64(*mean_x)
613                        }
614                    }
615                    AggregateFunction::RegrAvgy => {
616                        if n == 0 {
617                            Value::Null
618                        } else {
619                            Value::Float64(*mean_y)
620                        }
621                    }
622                    _ => Value::Null, // non-bivariate functions never reach here
623                }
624            }
625        }
626    }
627}
628
629use super::value_utils::{compare_values, value_to_f64};
630
631/// Converts a Value to its string representation for GROUP_CONCAT.
632fn agg_value_to_string(val: &Value) -> String {
633    match val {
634        Value::String(s) => s.to_string(),
635        Value::Int64(i) => i.to_string(),
636        Value::Float64(f) => f.to_string(),
637        Value::Bool(b) => b.to_string(),
638        Value::Null => String::new(),
639        other => format!("{other:?}"),
640    }
641}
642
/// A group key for hash-based aggregation.
///
/// Holds one hashable part per GROUP BY column, in the order the columns were
/// given. Two rows belong to the same group exactly when their keys compare
/// equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(Vec<GroupKeyPart>);
646
647#[derive(Debug, Clone, PartialEq, Eq, Hash)]
648enum GroupKeyPart {
649    Null,
650    Bool(bool),
651    Int64(i64),
652    String(ArcStr),
653    Bytes(Arc<[u8]>),
654    Date(grafeo_common::types::Date),
655    Time(grafeo_common::types::Time),
656    Timestamp(grafeo_common::types::Timestamp),
657    Duration(grafeo_common::types::Duration),
658    ZonedDatetime(grafeo_common::types::ZonedDatetime),
659    List(Vec<GroupKeyPart>),
660    Map(Vec<(ArcStr, GroupKeyPart)>),
661}
662
663impl GroupKeyPart {
664    fn from_value(v: Value) -> Self {
665        match v {
666            Value::Null => Self::Null,
667            Value::Bool(b) => Self::Bool(b),
668            Value::Int64(i) => Self::Int64(i),
669            Value::Float64(f) => Self::Int64(f.to_bits() as i64),
670            Value::String(s) => Self::String(s.clone()),
671            Value::Bytes(b) => Self::Bytes(b),
672            Value::Date(d) => Self::Date(d),
673            Value::Time(t) => Self::Time(t),
674            Value::Timestamp(ts) => Self::Timestamp(ts),
675            Value::Duration(d) => Self::Duration(d),
676            Value::ZonedDatetime(zdt) => Self::ZonedDatetime(zdt),
677            Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
678            Value::Map(map) => {
679                // BTreeMap already iterates in key order, so this is deterministic
680                let entries: Vec<(ArcStr, GroupKeyPart)> = map
681                    .iter()
682                    .map(|(k, v)| (ArcStr::from(k.as_str()), Self::from_value(v.clone())))
683                    .collect();
684                Self::Map(entries)
685            }
686            // Path, Vector, GCounter, OnCounter: use Debug string as fallback
687            other => Self::String(ArcStr::from(format!("{other:?}"))),
688        }
689    }
690
691    fn to_value(&self) -> Value {
692        match self {
693            Self::Null => Value::Null,
694            Self::Bool(b) => Value::Bool(*b),
695            Self::Int64(i) => Value::Int64(*i),
696            Self::String(s) => Value::String(s.clone()),
697            Self::Bytes(b) => Value::Bytes(Arc::clone(b)),
698            Self::Date(d) => Value::Date(*d),
699            Self::Time(t) => Value::Time(*t),
700            Self::Timestamp(ts) => Value::Timestamp(*ts),
701            Self::Duration(d) => Value::Duration(*d),
702            Self::ZonedDatetime(zdt) => Value::ZonedDatetime(*zdt),
703            Self::List(parts) => {
704                let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
705                Value::List(Arc::from(values.into_boxed_slice()))
706            }
707            Self::Map(entries) => {
708                let map: std::collections::BTreeMap<PropertyKey, Value> = entries
709                    .iter()
710                    .map(|(k, v)| (PropertyKey::new(k.as_str()), v.to_value()))
711                    .collect();
712                Value::Map(Arc::new(map))
713            }
714        }
715    }
716}
717
718impl GroupKey {
719    /// Creates a group key from column values.
720    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
721        let parts: Vec<GroupKeyPart> = group_columns
722            .iter()
723            .map(|&col_idx| {
724                chunk
725                    .column(col_idx)
726                    .and_then(|col| col.get_value(row))
727                    .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
728            })
729            .collect();
730        GroupKey(parts)
731    }
732
733    /// Converts the group key back to values.
734    fn to_values(&self) -> Vec<Value> {
735        self.0.iter().map(GroupKeyPart::to_value).collect()
736    }
737}
738
739/// Hash-based aggregate operator.
740///
741/// Groups input by key columns and computes aggregations for each group.
742pub struct HashAggregateOperator {
743    /// Child operator to read from.
744    child: Box<dyn Operator>,
745    /// Columns to group by.
746    group_columns: Vec<usize>,
747    /// Aggregation expressions.
748    aggregates: Vec<AggregateExpr>,
749    /// Output schema.
750    output_schema: Vec<LogicalType>,
751    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
752    groups: IndexMap<GroupKey, Vec<AggregateState>>,
753    /// Whether aggregation is complete.
754    aggregation_complete: bool,
755    /// Results iterator.
756    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
757}
758
759impl HashAggregateOperator {
760    /// Creates a new hash aggregate operator.
761    ///
762    /// # Arguments
763    /// * `child` - Child operator to read from.
764    /// * `group_columns` - Column indices to group by.
765    /// * `aggregates` - Aggregation expressions.
766    /// * `output_schema` - Schema of the output (group columns + aggregate results).
767    pub fn new(
768        child: Box<dyn Operator>,
769        group_columns: Vec<usize>,
770        aggregates: Vec<AggregateExpr>,
771        output_schema: Vec<LogicalType>,
772    ) -> Self {
773        Self {
774            child,
775            group_columns,
776            aggregates,
777            output_schema,
778            groups: IndexMap::new(),
779            aggregation_complete: false,
780            results: None,
781        }
782    }
783
784    /// Performs the aggregation.
785    fn aggregate(&mut self) -> Result<(), OperatorError> {
786        while let Some(chunk) = self.child.next()? {
787            for row in chunk.selected_indices() {
788                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
789
790                // Get or create aggregate states for this group
791                let states = self.groups.entry(key).or_insert_with(|| {
792                    self.aggregates
793                        .iter()
794                        .map(|agg| {
795                            AggregateState::new(
796                                agg.function,
797                                agg.distinct,
798                                agg.percentile,
799                                agg.separator.as_deref(),
800                            )
801                        })
802                        .collect()
803                });
804
805                // Update each aggregate
806                for (i, agg) in self.aggregates.iter().enumerate() {
807                    // Binary set functions: read two column values
808                    if agg.column2.is_some() {
809                        let y_val = agg
810                            .column
811                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
812                        let x_val = agg
813                            .column2
814                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
815                        states[i].update_bivariate(y_val, x_val);
816                        continue;
817                    }
818
819                    let value = match (agg.function, agg.distinct) {
820                        // COUNT(*) without DISTINCT doesn't need a value
821                        (AggregateFunction::Count, false) => None,
822                        // COUNT DISTINCT needs the actual value to track unique values
823                        (AggregateFunction::Count, true) => agg
824                            .column
825                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
826                        _ => agg
827                            .column
828                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
829                    };
830
831                    // For COUNT without DISTINCT, always update. For others, skip nulls.
832                    match (agg.function, agg.distinct) {
833                        (AggregateFunction::Count, false) => states[i].update(None),
834                        (AggregateFunction::Count, true) => {
835                            // COUNT DISTINCT needs the value to track unique values
836                            if value.is_some() && !matches!(value, Some(Value::Null)) {
837                                states[i].update(value);
838                            }
839                        }
840                        (AggregateFunction::CountNonNull, _) => {
841                            if value.is_some() && !matches!(value, Some(Value::Null)) {
842                                states[i].update(value);
843                            }
844                        }
845                        _ => {
846                            if value.is_some() && !matches!(value, Some(Value::Null)) {
847                                states[i].update(value);
848                            }
849                        }
850                    }
851                }
852            }
853        }
854
855        self.aggregation_complete = true;
856
857        // Convert to results iterator (IndexMap::drain takes a range)
858        let results: Vec<_> = self.groups.drain(..).collect();
859        self.results = Some(results.into_iter());
860
861        Ok(())
862    }
863}
864
865impl Operator for HashAggregateOperator {
866    fn next(&mut self) -> OperatorResult {
867        // Perform aggregation if not done
868        if !self.aggregation_complete {
869            self.aggregate()?;
870        }
871
872        // Special case: no groups (global aggregation with no data)
873        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
874            // For global aggregation (no GROUP BY), return one row with initial values
875            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
876
877            for agg in &self.aggregates {
878                let state = AggregateState::new(
879                    agg.function,
880                    agg.distinct,
881                    agg.percentile,
882                    agg.separator.as_deref(),
883                );
884                let value = state.finalize();
885                if let Some(col) = builder.column_mut(self.group_columns.len()) {
886                    col.push_value(value);
887                }
888            }
889            builder.advance_row();
890
891            self.results = Some(Vec::new().into_iter()); // Mark as done
892            return Ok(Some(builder.finish()));
893        }
894
895        let Some(results) = &mut self.results else {
896            return Ok(None);
897        };
898
899        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
900
901        for (key, states) in results.by_ref() {
902            // Output group key columns
903            let key_values = key.to_values();
904            for (i, value) in key_values.into_iter().enumerate() {
905                if let Some(col) = builder.column_mut(i) {
906                    col.push_value(value);
907                }
908            }
909
910            // Output aggregate results
911            for (i, state) in states.iter().enumerate() {
912                let col_idx = self.group_columns.len() + i;
913                if let Some(col) = builder.column_mut(col_idx) {
914                    col.push_value(state.finalize());
915                }
916            }
917
918            builder.advance_row();
919
920            if builder.is_full() {
921                return Ok(Some(builder.finish()));
922            }
923        }
924
925        if builder.row_count() > 0 {
926            Ok(Some(builder.finish()))
927        } else {
928            Ok(None)
929        }
930    }
931
932    fn reset(&mut self) {
933        self.child.reset();
934        self.groups.clear();
935        self.aggregation_complete = false;
936        self.results = None;
937    }
938
939    fn name(&self) -> &'static str {
940        "HashAggregate"
941    }
942}
943
944/// Simple (non-grouping) aggregate operator for global aggregations.
945///
946/// Used when there's no GROUP BY clause - aggregates all input into a single row.
947pub struct SimpleAggregateOperator {
948    /// Child operator.
949    child: Box<dyn Operator>,
950    /// Aggregation expressions.
951    aggregates: Vec<AggregateExpr>,
952    /// Output schema.
953    output_schema: Vec<LogicalType>,
954    /// Aggregate states.
955    states: Vec<AggregateState>,
956    /// Whether aggregation is complete.
957    done: bool,
958}
959
960impl SimpleAggregateOperator {
961    /// Creates a new simple aggregate operator.
962    pub fn new(
963        child: Box<dyn Operator>,
964        aggregates: Vec<AggregateExpr>,
965        output_schema: Vec<LogicalType>,
966    ) -> Self {
967        let states = aggregates
968            .iter()
969            .map(|agg| {
970                AggregateState::new(
971                    agg.function,
972                    agg.distinct,
973                    agg.percentile,
974                    agg.separator.as_deref(),
975                )
976            })
977            .collect();
978
979        Self {
980            child,
981            aggregates,
982            output_schema,
983            states,
984            done: false,
985        }
986    }
987}
988
989impl Operator for SimpleAggregateOperator {
990    fn next(&mut self) -> OperatorResult {
991        if self.done {
992            return Ok(None);
993        }
994
995        // Process all input
996        while let Some(chunk) = self.child.next()? {
997            for row in chunk.selected_indices() {
998                for (i, agg) in self.aggregates.iter().enumerate() {
999                    // Binary set functions: read two column values
1000                    if agg.column2.is_some() {
1001                        let y_val = agg
1002                            .column
1003                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
1004                        let x_val = agg
1005                            .column2
1006                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
1007                        self.states[i].update_bivariate(y_val, x_val);
1008                        continue;
1009                    }
1010
1011                    let value = match (agg.function, agg.distinct) {
1012                        // COUNT(*) without DISTINCT doesn't need a value
1013                        (AggregateFunction::Count, false) => None,
1014                        // COUNT DISTINCT needs the actual value to track unique values
1015                        (AggregateFunction::Count, true) => agg
1016                            .column
1017                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1018                        _ => agg
1019                            .column
1020                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1021                    };
1022
1023                    match (agg.function, agg.distinct) {
1024                        (AggregateFunction::Count, false) => self.states[i].update(None),
1025                        (AggregateFunction::Count, true) => {
1026                            // COUNT DISTINCT needs the value to track unique values
1027                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1028                                self.states[i].update(value);
1029                            }
1030                        }
1031                        (AggregateFunction::CountNonNull, _) => {
1032                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1033                                self.states[i].update(value);
1034                            }
1035                        }
1036                        _ => {
1037                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1038                                self.states[i].update(value);
1039                            }
1040                        }
1041                    }
1042                }
1043            }
1044        }
1045
1046        // Output single result row
1047        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
1048
1049        for (i, state) in self.states.iter().enumerate() {
1050            if let Some(col) = builder.column_mut(i) {
1051                col.push_value(state.finalize());
1052            }
1053        }
1054        builder.advance_row();
1055
1056        self.done = true;
1057        Ok(Some(builder.finish()))
1058    }
1059
1060    fn reset(&mut self) {
1061        self.child.reset();
1062        self.states = self
1063            .aggregates
1064            .iter()
1065            .map(|agg| {
1066                AggregateState::new(
1067                    agg.function,
1068                    agg.distinct,
1069                    agg.percentile,
1070                    agg.separator.as_deref(),
1071                )
1072            })
1073            .collect();
1074        self.done = false;
1075    }
1076
1077    fn name(&self) -> &'static str {
1078        "SimpleAggregate"
1079    }
1080}
1081
1082#[cfg(test)]
1083mod tests {
1084    use super::*;
1085    use crate::execution::chunk::DataChunkBuilder;
1086
1087    struct MockOperator {
1088        chunks: Vec<DataChunk>,
1089        position: usize,
1090    }
1091
1092    impl MockOperator {
1093        fn new(chunks: Vec<DataChunk>) -> Self {
1094            Self {
1095                chunks,
1096                position: 0,
1097            }
1098        }
1099    }
1100
1101    impl Operator for MockOperator {
1102        fn next(&mut self) -> OperatorResult {
1103            if self.position < self.chunks.len() {
1104                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1105                self.position += 1;
1106                Ok(Some(chunk))
1107            } else {
1108                Ok(None)
1109            }
1110        }
1111
1112        fn reset(&mut self) {
1113            self.position = 0;
1114        }
1115
1116        fn name(&self) -> &'static str {
1117            "Mock"
1118        }
1119    }
1120
1121    fn create_test_chunk() -> DataChunk {
1122        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
1123        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1124
1125        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1126        for (group, value) in data {
1127            builder.column_mut(0).unwrap().push_int64(group);
1128            builder.column_mut(1).unwrap().push_int64(value);
1129            builder.advance_row();
1130        }
1131
1132        builder.finish()
1133    }
1134
1135    #[test]
1136    fn test_simple_count() {
1137        let mock = MockOperator::new(vec![create_test_chunk()]);
1138
1139        let mut agg = SimpleAggregateOperator::new(
1140            Box::new(mock),
1141            vec![AggregateExpr::count_star()],
1142            vec![LogicalType::Int64],
1143        );
1144
1145        let result = agg.next().unwrap().unwrap();
1146        assert_eq!(result.row_count(), 1);
1147        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1148
1149        // Should be done
1150        assert!(agg.next().unwrap().is_none());
1151    }
1152
1153    #[test]
1154    fn test_simple_sum() {
1155        let mock = MockOperator::new(vec![create_test_chunk()]);
1156
1157        let mut agg = SimpleAggregateOperator::new(
1158            Box::new(mock),
1159            vec![AggregateExpr::sum(1)], // Sum of column 1
1160            vec![LogicalType::Int64],
1161        );
1162
1163        let result = agg.next().unwrap().unwrap();
1164        assert_eq!(result.row_count(), 1);
1165        // Sum: 10 + 20 + 30 + 40 + 50 = 150
1166        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1167    }
1168
1169    #[test]
1170    fn test_simple_avg() {
1171        let mock = MockOperator::new(vec![create_test_chunk()]);
1172
1173        let mut agg = SimpleAggregateOperator::new(
1174            Box::new(mock),
1175            vec![AggregateExpr::avg(1)],
1176            vec![LogicalType::Float64],
1177        );
1178
1179        let result = agg.next().unwrap().unwrap();
1180        assert_eq!(result.row_count(), 1);
1181        // Avg: 150 / 5 = 30.0
1182        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1183        assert!((avg - 30.0).abs() < 0.001);
1184    }
1185
1186    #[test]
1187    fn test_simple_min_max() {
1188        let mock = MockOperator::new(vec![create_test_chunk()]);
1189
1190        let mut agg = SimpleAggregateOperator::new(
1191            Box::new(mock),
1192            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1193            vec![LogicalType::Int64, LogicalType::Int64],
1194        );
1195
1196        let result = agg.next().unwrap().unwrap();
1197        assert_eq!(result.row_count(), 1);
1198        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
1199        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
1200    }
1201
1202    #[test]
1203    fn test_sum_with_string_values() {
1204        // Test SUM with string values (like RDF stores numeric literals)
1205        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1206        builder.column_mut(0).unwrap().push_string("30");
1207        builder.advance_row();
1208        builder.column_mut(0).unwrap().push_string("25");
1209        builder.advance_row();
1210        builder.column_mut(0).unwrap().push_string("35");
1211        builder.advance_row();
1212        let chunk = builder.finish();
1213
1214        let mock = MockOperator::new(vec![chunk]);
1215        let mut agg = SimpleAggregateOperator::new(
1216            Box::new(mock),
1217            vec![AggregateExpr::sum(0)],
1218            vec![LogicalType::Float64],
1219        );
1220
1221        let result = agg.next().unwrap().unwrap();
1222        assert_eq!(result.row_count(), 1);
1223        // Should parse strings and sum: 30 + 25 + 35 = 90
1224        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1225        assert!(
1226            (sum_val - 90.0).abs() < 0.001,
1227            "Expected 90.0, got {}",
1228            sum_val
1229        );
1230    }
1231
1232    #[test]
1233    fn test_grouped_aggregation() {
1234        let mock = MockOperator::new(vec![create_test_chunk()]);
1235
1236        // GROUP BY column 0, SUM(column 1)
1237        let mut agg = HashAggregateOperator::new(
1238            Box::new(mock),
1239            vec![0],                     // Group by column 0
1240            vec![AggregateExpr::sum(1)], // Sum of column 1
1241            vec![LogicalType::Int64, LogicalType::Int64],
1242        );
1243
1244        let mut results: Vec<(i64, i64)> = Vec::new();
1245        while let Some(chunk) = agg.next().unwrap() {
1246            for row in chunk.selected_indices() {
1247                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1248                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1249                results.push((group, sum));
1250            }
1251        }
1252
1253        results.sort_by_key(|(g, _)| *g);
1254        assert_eq!(results.len(), 2);
1255        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
1256        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
1257    }
1258
1259    #[test]
1260    fn test_grouped_count() {
1261        let mock = MockOperator::new(vec![create_test_chunk()]);
1262
1263        // GROUP BY column 0, COUNT(*)
1264        let mut agg = HashAggregateOperator::new(
1265            Box::new(mock),
1266            vec![0],
1267            vec![AggregateExpr::count_star()],
1268            vec![LogicalType::Int64, LogicalType::Int64],
1269        );
1270
1271        let mut results: Vec<(i64, i64)> = Vec::new();
1272        while let Some(chunk) = agg.next().unwrap() {
1273            for row in chunk.selected_indices() {
1274                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1275                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1276                results.push((group, count));
1277            }
1278        }
1279
1280        results.sort_by_key(|(g, _)| *g);
1281        assert_eq!(results.len(), 2);
1282        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
1283        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
1284    }
1285
1286    #[test]
1287    fn test_multiple_aggregates() {
1288        let mock = MockOperator::new(vec![create_test_chunk()]);
1289
1290        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
1291        let mut agg = HashAggregateOperator::new(
1292            Box::new(mock),
1293            vec![0],
1294            vec![
1295                AggregateExpr::count_star(),
1296                AggregateExpr::sum(1),
1297                AggregateExpr::avg(1),
1298            ],
1299            vec![
1300                LogicalType::Int64,   // Group key
1301                LogicalType::Int64,   // COUNT
1302                LogicalType::Int64,   // SUM
1303                LogicalType::Float64, // AVG
1304            ],
1305        );
1306
1307        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1308        while let Some(chunk) = agg.next().unwrap() {
1309            for row in chunk.selected_indices() {
1310                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1311                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1312                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1313                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1314                results.push((group, count, sum, avg));
1315            }
1316        }
1317
1318        results.sort_by_key(|(g, _, _, _)| *g);
1319        assert_eq!(results.len(), 2);
1320
1321        // Group 1: COUNT=2, SUM=30, AVG=15.0
1322        assert_eq!(results[0].0, 1);
1323        assert_eq!(results[0].1, 2);
1324        assert_eq!(results[0].2, 30);
1325        assert!((results[0].3 - 15.0).abs() < 0.001);
1326
1327        // Group 2: COUNT=3, SUM=120, AVG=40.0
1328        assert_eq!(results[1].0, 2);
1329        assert_eq!(results[1].1, 3);
1330        assert_eq!(results[1].2, 120);
1331        assert!((results[1].3 - 40.0).abs() < 0.001);
1332    }
1333
1334    fn create_test_chunk_with_duplicates() -> DataChunk {
1335        // Create data with duplicate values in column 1
1336        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
1337        // GROUP 1: values [10, 10, 20] -> distinct count = 2
1338        // GROUP 2: values [30, 30, 30] -> distinct count = 1
1339        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1340
1341        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1342        for (group, value) in data {
1343            builder.column_mut(0).unwrap().push_int64(group);
1344            builder.column_mut(1).unwrap().push_int64(value);
1345            builder.advance_row();
1346        }
1347
1348        builder.finish()
1349    }
1350
1351    #[test]
1352    fn test_count_distinct() {
1353        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1354
1355        // COUNT(DISTINCT column 1)
1356        let mut agg = SimpleAggregateOperator::new(
1357            Box::new(mock),
1358            vec![AggregateExpr::count(1).with_distinct()],
1359            vec![LogicalType::Int64],
1360        );
1361
1362        let result = agg.next().unwrap().unwrap();
1363        assert_eq!(result.row_count(), 1);
1364        // Total distinct values: 10, 20, 30 = 3 distinct values
1365        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1366    }
1367
1368    #[test]
1369    fn test_grouped_count_distinct() {
1370        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1371
1372        // GROUP BY column 0, COUNT(DISTINCT column 1)
1373        let mut agg = HashAggregateOperator::new(
1374            Box::new(mock),
1375            vec![0],
1376            vec![AggregateExpr::count(1).with_distinct()],
1377            vec![LogicalType::Int64, LogicalType::Int64],
1378        );
1379
1380        let mut results: Vec<(i64, i64)> = Vec::new();
1381        while let Some(chunk) = agg.next().unwrap() {
1382            for row in chunk.selected_indices() {
1383                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1384                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1385                results.push((group, count));
1386            }
1387        }
1388
1389        results.sort_by_key(|(g, _)| *g);
1390        assert_eq!(results.len(), 2);
1391        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
1392        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
1393    }
1394
1395    #[test]
1396    fn test_sum_distinct() {
1397        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1398
1399        // SUM(DISTINCT column 1)
1400        let mut agg = SimpleAggregateOperator::new(
1401            Box::new(mock),
1402            vec![AggregateExpr::sum(1).with_distinct()],
1403            vec![LogicalType::Int64],
1404        );
1405
1406        let result = agg.next().unwrap().unwrap();
1407        assert_eq!(result.row_count(), 1);
1408        // Sum of distinct values: 10 + 20 + 30 = 60
1409        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1410    }
1411
1412    #[test]
1413    fn test_avg_distinct() {
1414        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1415
1416        // AVG(DISTINCT column 1)
1417        let mut agg = SimpleAggregateOperator::new(
1418            Box::new(mock),
1419            vec![AggregateExpr::avg(1).with_distinct()],
1420            vec![LogicalType::Float64],
1421        );
1422
1423        let result = agg.next().unwrap().unwrap();
1424        assert_eq!(result.row_count(), 1);
1425        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1426        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1427        assert!((avg - 20.0).abs() < 0.001);
1428    }
1429
1430    fn create_statistical_test_chunk() -> DataChunk {
1431        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1432        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1433        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1434
1435        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1436            builder.column_mut(0).unwrap().push_int64(value);
1437            builder.advance_row();
1438        }
1439
1440        builder.finish()
1441    }
1442
1443    #[test]
1444    fn test_stdev_sample() {
1445        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1446
1447        let mut agg = SimpleAggregateOperator::new(
1448            Box::new(mock),
1449            vec![AggregateExpr::stdev(0)],
1450            vec![LogicalType::Float64],
1451        );
1452
1453        let result = agg.next().unwrap().unwrap();
1454        assert_eq!(result.row_count(), 1);
1455        // Sample standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1456        // Mean = 5.0, Variance = 32/7 = 4.571, StdDev = 2.138
1457        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1458        assert!((stdev - 2.138).abs() < 0.01);
1459    }
1460
1461    #[test]
1462    fn test_stdev_population() {
1463        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1464
1465        let mut agg = SimpleAggregateOperator::new(
1466            Box::new(mock),
1467            vec![AggregateExpr::stdev_pop(0)],
1468            vec![LogicalType::Float64],
1469        );
1470
1471        let result = agg.next().unwrap().unwrap();
1472        assert_eq!(result.row_count(), 1);
1473        // Population standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1474        // Mean = 5.0, Variance = 32/8 = 4.0, StdDev = 2.0
1475        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1476        assert!((stdev - 2.0).abs() < 0.01);
1477    }
1478
1479    #[test]
1480    fn test_percentile_disc() {
1481        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1482
1483        // Median (50th percentile discrete)
1484        let mut agg = SimpleAggregateOperator::new(
1485            Box::new(mock),
1486            vec![AggregateExpr::percentile_disc(0, 0.5)],
1487            vec![LogicalType::Float64],
1488        );
1489
1490        let result = agg.next().unwrap().unwrap();
1491        assert_eq!(result.row_count(), 1);
1492        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], index = floor(0.5 * 7) = 3, value = 4
1493        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1494        assert!((percentile - 4.0).abs() < 0.01);
1495    }
1496
1497    #[test]
1498    fn test_percentile_cont() {
1499        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1500
1501        // Median (50th percentile continuous)
1502        let mut agg = SimpleAggregateOperator::new(
1503            Box::new(mock),
1504            vec![AggregateExpr::percentile_cont(0, 0.5)],
1505            vec![LogicalType::Float64],
1506        );
1507
1508        let result = agg.next().unwrap().unwrap();
1509        assert_eq!(result.row_count(), 1);
1510        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], rank = 0.5 * 7 = 3.5
1511        // Interpolate between index 3 (4) and index 4 (5): 4 + 0.5 * (5 - 4) = 4.5
1512        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1513        assert!((percentile - 4.5).abs() < 0.01);
1514    }
1515
1516    #[test]
1517    fn test_percentile_extremes() {
1518        // Test 0th and 100th percentiles
1519        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1520
1521        let mut agg = SimpleAggregateOperator::new(
1522            Box::new(mock),
1523            vec![
1524                AggregateExpr::percentile_disc(0, 0.0),
1525                AggregateExpr::percentile_disc(0, 1.0),
1526            ],
1527            vec![LogicalType::Float64, LogicalType::Float64],
1528        );
1529
1530        let result = agg.next().unwrap().unwrap();
1531        assert_eq!(result.row_count(), 1);
1532        // 0th percentile = minimum = 2
1533        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1534        assert!((p0 - 2.0).abs() < 0.01);
1535        // 100th percentile = maximum = 9
1536        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1537        assert!((p100 - 9.0).abs() < 0.01);
1538    }
1539
1540    #[test]
1541    fn test_stdev_single_value() {
1542        // Single value should return null for sample stdev
1543        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1544        builder.column_mut(0).unwrap().push_int64(42);
1545        builder.advance_row();
1546        let chunk = builder.finish();
1547
1548        let mock = MockOperator::new(vec![chunk]);
1549
1550        let mut agg = SimpleAggregateOperator::new(
1551            Box::new(mock),
1552            vec![AggregateExpr::stdev(0)],
1553            vec![LogicalType::Float64],
1554        );
1555
1556        let result = agg.next().unwrap().unwrap();
1557        assert_eq!(result.row_count(), 1);
1558        // Sample stdev of single value is undefined (null)
1559        assert!(matches!(
1560            result.column(0).unwrap().get_value(0),
1561            Some(Value::Null)
1562        ));
1563    }
1564
1565    #[test]
1566    fn test_first_and_last() {
1567        let mock = MockOperator::new(vec![create_test_chunk()]);
1568
1569        let mut agg = SimpleAggregateOperator::new(
1570            Box::new(mock),
1571            vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1572            vec![LogicalType::Int64, LogicalType::Int64],
1573        );
1574
1575        let result = agg.next().unwrap().unwrap();
1576        assert_eq!(result.row_count(), 1);
1577        // First: 10, Last: 50
1578        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1579        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1580    }
1581
1582    #[test]
1583    fn test_collect() {
1584        let mock = MockOperator::new(vec![create_test_chunk()]);
1585
1586        let mut agg = SimpleAggregateOperator::new(
1587            Box::new(mock),
1588            vec![AggregateExpr::collect(1)],
1589            vec![LogicalType::Any],
1590        );
1591
1592        let result = agg.next().unwrap().unwrap();
1593        let val = result.column(0).unwrap().get_value(0).unwrap();
1594        if let Value::List(items) = val {
1595            assert_eq!(items.len(), 5);
1596        } else {
1597            panic!("Expected List value");
1598        }
1599    }
1600
1601    #[test]
1602    fn test_collect_distinct() {
1603        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1604
1605        let mut agg = SimpleAggregateOperator::new(
1606            Box::new(mock),
1607            vec![AggregateExpr::collect(1).with_distinct()],
1608            vec![LogicalType::Any],
1609        );
1610
1611        let result = agg.next().unwrap().unwrap();
1612        let val = result.column(0).unwrap().get_value(0).unwrap();
1613        if let Value::List(items) = val {
1614            // [10, 10, 20, 30, 30, 30] -> distinct: [10, 20, 30]
1615            assert_eq!(items.len(), 3);
1616        } else {
1617            panic!("Expected List value");
1618        }
1619    }
1620
1621    #[test]
1622    fn test_group_concat() {
1623        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1624        for s in ["hello", "world", "foo"] {
1625            builder.column_mut(0).unwrap().push_string(s);
1626            builder.advance_row();
1627        }
1628        let chunk = builder.finish();
1629        let mock = MockOperator::new(vec![chunk]);
1630
1631        let agg_expr = AggregateExpr {
1632            function: AggregateFunction::GroupConcat,
1633            column: Some(0),
1634            column2: None,
1635            distinct: false,
1636            alias: None,
1637            percentile: None,
1638            separator: None,
1639        };
1640
1641        let mut agg =
1642            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1643
1644        let result = agg.next().unwrap().unwrap();
1645        let val = result.column(0).unwrap().get_value(0).unwrap();
1646        assert_eq!(val, Value::String("hello world foo".into()));
1647    }
1648
1649    #[test]
1650    fn test_sample() {
1651        let mock = MockOperator::new(vec![create_test_chunk()]);
1652
1653        let agg_expr = AggregateExpr {
1654            function: AggregateFunction::Sample,
1655            column: Some(1),
1656            column2: None,
1657            distinct: false,
1658            alias: None,
1659            percentile: None,
1660            separator: None,
1661        };
1662
1663        let mut agg =
1664            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1665
1666        let result = agg.next().unwrap().unwrap();
1667        // Sample should return the first non-null value (10)
1668        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1669    }
1670
1671    #[test]
1672    fn test_variance_sample() {
1673        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1674
1675        let agg_expr = AggregateExpr {
1676            function: AggregateFunction::Variance,
1677            column: Some(0),
1678            column2: None,
1679            distinct: false,
1680            alias: None,
1681            percentile: None,
1682            separator: None,
1683        };
1684
1685        let mut agg = SimpleAggregateOperator::new(
1686            Box::new(mock),
1687            vec![agg_expr],
1688            vec![LogicalType::Float64],
1689        );
1690
1691        let result = agg.next().unwrap().unwrap();
1692        // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: M2/(n-1) = 32/7 = 4.571
1693        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1694        assert!((variance - 32.0 / 7.0).abs() < 0.01);
1695    }
1696
1697    #[test]
1698    fn test_variance_population() {
1699        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1700
1701        let agg_expr = AggregateExpr {
1702            function: AggregateFunction::VariancePop,
1703            column: Some(0),
1704            column2: None,
1705            distinct: false,
1706            alias: None,
1707            percentile: None,
1708            separator: None,
1709        };
1710
1711        let mut agg = SimpleAggregateOperator::new(
1712            Box::new(mock),
1713            vec![agg_expr],
1714            vec![LogicalType::Float64],
1715        );
1716
1717        let result = agg.next().unwrap().unwrap();
1718        // Population variance: M2/n = 32/8 = 4.0
1719        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1720        assert!((variance - 4.0).abs() < 0.01);
1721    }
1722
1723    #[test]
1724    fn test_variance_single_value() {
1725        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1726        builder.column_mut(0).unwrap().push_int64(42);
1727        builder.advance_row();
1728        let chunk = builder.finish();
1729        let mock = MockOperator::new(vec![chunk]);
1730
1731        let agg_expr = AggregateExpr {
1732            function: AggregateFunction::Variance,
1733            column: Some(0),
1734            column2: None,
1735            distinct: false,
1736            alias: None,
1737            percentile: None,
1738            separator: None,
1739        };
1740
1741        let mut agg = SimpleAggregateOperator::new(
1742            Box::new(mock),
1743            vec![agg_expr],
1744            vec![LogicalType::Float64],
1745        );
1746
1747        let result = agg.next().unwrap().unwrap();
1748        // Sample variance of single value is undefined (null)
1749        assert!(matches!(
1750            result.column(0).unwrap().get_value(0),
1751            Some(Value::Null)
1752        ));
1753    }
1754
1755    #[test]
1756    fn test_empty_aggregation() {
1757        // No input rows: COUNT should be 0, SUM/AVG/MIN/MAX should be NULL
1758        // (ISO/IEC 39075 Section 20.9)
1759        let mock = MockOperator::new(vec![]);
1760
1761        let mut agg = SimpleAggregateOperator::new(
1762            Box::new(mock),
1763            vec![
1764                AggregateExpr::count_star(),
1765                AggregateExpr::sum(0),
1766                AggregateExpr::avg(0),
1767                AggregateExpr::min(0),
1768                AggregateExpr::max(0),
1769            ],
1770            vec![
1771                LogicalType::Int64,
1772                LogicalType::Int64,
1773                LogicalType::Float64,
1774                LogicalType::Int64,
1775                LogicalType::Int64,
1776            ],
1777        );
1778
1779        let result = agg.next().unwrap().unwrap();
1780        assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); // COUNT
1781        assert!(matches!(
1782            result.column(1).unwrap().get_value(0),
1783            Some(Value::Null)
1784        )); // SUM
1785        assert!(matches!(
1786            result.column(2).unwrap().get_value(0),
1787            Some(Value::Null)
1788        )); // AVG
1789        assert!(matches!(
1790            result.column(3).unwrap().get_value(0),
1791            Some(Value::Null)
1792        )); // MIN
1793        assert!(matches!(
1794            result.column(4).unwrap().get_value(0),
1795            Some(Value::Null)
1796        )); // MAX
1797    }
1798
1799    #[test]
1800    fn test_stdev_pop_single_value() {
1801        // Single value should return 0 for population stdev
1802        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1803        builder.column_mut(0).unwrap().push_int64(42);
1804        builder.advance_row();
1805        let chunk = builder.finish();
1806
1807        let mock = MockOperator::new(vec![chunk]);
1808
1809        let mut agg = SimpleAggregateOperator::new(
1810            Box::new(mock),
1811            vec![AggregateExpr::stdev_pop(0)],
1812            vec![LogicalType::Float64],
1813        );
1814
1815        let result = agg.next().unwrap().unwrap();
1816        assert_eq!(result.row_count(), 1);
1817        // Population stdev of single value is 0
1818        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1819        assert!((stdev - 0.0).abs() < 0.01);
1820    }
1821}