Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, PropertyKey, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
/// Running state for one aggregate expression within one group.
///
/// Built by `AggregateState::new`, fed through `update` (or `update_bivariate` for the
/// `Bivariate` variant) and turned into a result by `finalize`. `SumInt` / `SumIntDistinct`
/// replace themselves with `SumFloat` / `SumFloatDistinct` once a `Float64` (or another
/// non-`Int64` value that `value_to_f64` can convert) arrives.
#[derive(Debug, Clone)]
pub(crate) enum AggregateState {
    /// COUNT state: incremented on every `update` call, whether or not a value is present.
    Count(i64),
    /// COUNT DISTINCT state (number of distinct non-`None` values, set of values seen).
    CountDistinct(i64, HashSet<HashableValue>),
    /// Integer SUM state (running `i64` sum, number of values added).
    SumInt(i64, i64),
    /// Integer SUM DISTINCT state (running `i64` sum, number of values added, values seen).
    SumIntDistinct(i64, i64, HashSet<HashableValue>),
    /// Floating-point SUM state (running `f64` sum, number of values added).
    SumFloat(f64, i64),
    /// Floating-point SUM DISTINCT state (running `f64` sum, number of values added, values seen).
    SumFloatDistinct(f64, i64, HashSet<HashableValue>),
    /// AVG state (running sum, number of values added); finalized as sum / count.
    Avg(f64, i64),
    /// AVG DISTINCT state (running sum, number of values added, values seen).
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// MIN state: smallest value so far per `compare_values`, `None` before the first value.
    Min(Option<Value>),
    /// MAX state: largest value so far per `compare_values`, `None` before the first value.
    Max(Option<Value>),
    /// FIRST state: the first non-`None` value passed to `update`.
    First(Option<Value>),
    /// LAST state: the most recent non-`None` value passed to `update`.
    Last(Option<Value>),
    /// COLLECT state: every non-`None` value, in arrival order.
    Collect(Vec<Value>),
    /// COLLECT DISTINCT state (distinct values in first-seen order, set of values seen).
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Sample standard deviation via Welford's algorithm (count, running mean, M2 = sum of
    /// squared deviations from the mean).
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Population standard deviation via Welford's algorithm (count, running mean, M2).
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Discrete percentile state (every numeric value seen, requested percentile; `new`
    /// defaults the percentile to `0.5`).
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Continuous percentile state (every numeric value seen, requested percentile; `new`
    /// defaults the percentile to `0.5`).
    PercentileCont { values: Vec<f64>, percentile: f64 },
    /// GROUP_CONCAT / LISTAGG state (value strings in arrival order, separator; `new`
    /// defaults the separator to a single space).
    GroupConcat(Vec<String>, String),
    /// GROUP_CONCAT / LISTAGG DISTINCT state (distinct value strings, separator, values seen).
    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
    /// SAMPLE state: the first non-`None` value encountered.
    Sample(Option<Value>),
    /// Sample variance via Welford's algorithm (count, running mean, M2).
    Variance { count: i64, mean: f64, m2: f64 },
    /// Population variance via Welford's algorithm (count, running mean, M2).
    VariancePop { count: i64, mean: f64, m2: f64 },
    /// Two-variable online statistics (Welford generalization) shared by the covariance,
    /// correlation and REGR_* functions.
    Bivariate {
        /// Which binary set function this state will finalize to.
        kind: AggregateFunction,
        /// Number of (y, x) pairs where both values converted to `f64`.
        count: i64,
        /// Running mean of x.
        mean_x: f64,
        /// Running mean of y.
        mean_y: f64,
        /// Sum of squared deviations of x from its mean (Sxx).
        m2_x: f64,
        /// Sum of squared deviations of y from its mean (Syy).
        m2_y: f64,
        /// Co-moment: sum of (x - mean_x) * (y - mean_y) (Sxy).
        c_xy: f64,
    },
}
83
84impl AggregateState {
85    /// Creates initial state for an aggregation function.
86    pub(crate) fn new(
87        function: AggregateFunction,
88        distinct: bool,
89        percentile: Option<f64>,
90        separator: Option<&str>,
91    ) -> Self {
92        match (function, distinct) {
93            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
94                AggregateState::Count(0)
95            }
96            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
97                AggregateState::CountDistinct(0, HashSet::new())
98            }
99            (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
100            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
101            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
102            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
103            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
104            (AggregateFunction::Max, _) => AggregateState::Max(None),
105            (AggregateFunction::First, _) => AggregateState::First(None),
106            (AggregateFunction::Last, _) => AggregateState::Last(None),
107            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
108            (AggregateFunction::Collect, true) => {
109                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
110            }
111            // Statistical functions (Welford's algorithm for online computation)
112            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
113                count: 0,
114                mean: 0.0,
115                m2: 0.0,
116            },
117            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
118                count: 0,
119                mean: 0.0,
120                m2: 0.0,
121            },
122            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
123                values: Vec::new(),
124                percentile: percentile.unwrap_or(0.5),
125            },
126            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
127                values: Vec::new(),
128                percentile: percentile.unwrap_or(0.5),
129            },
130            (AggregateFunction::GroupConcat, false) => {
131                AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
132            }
133            (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
134                Vec::new(),
135                separator.unwrap_or(" ").to_string(),
136                HashSet::new(),
137            ),
138            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
139            // Binary set functions (all share the same Bivariate state)
140            (
141                AggregateFunction::CovarSamp
142                | AggregateFunction::CovarPop
143                | AggregateFunction::Corr
144                | AggregateFunction::RegrSlope
145                | AggregateFunction::RegrIntercept
146                | AggregateFunction::RegrR2
147                | AggregateFunction::RegrCount
148                | AggregateFunction::RegrSxx
149                | AggregateFunction::RegrSyy
150                | AggregateFunction::RegrSxy
151                | AggregateFunction::RegrAvgx
152                | AggregateFunction::RegrAvgy,
153                _,
154            ) => AggregateState::Bivariate {
155                kind: function,
156                count: 0,
157                mean_x: 0.0,
158                mean_y: 0.0,
159                m2_x: 0.0,
160                m2_y: 0.0,
161                c_xy: 0.0,
162            },
163            (AggregateFunction::Variance, _) => AggregateState::Variance {
164                count: 0,
165                mean: 0.0,
166                m2: 0.0,
167            },
168            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
169                count: 0,
170                mean: 0.0,
171                m2: 0.0,
172            },
173        }
174    }
175
176    /// Updates the state with a new value.
177    pub(crate) fn update(&mut self, value: Option<Value>) {
178        match self {
179            AggregateState::Count(count) => {
180                *count += 1;
181            }
182            AggregateState::CountDistinct(count, seen) => {
183                if let Some(ref v) = value {
184                    let hashable = HashableValue::from(v);
185                    if seen.insert(hashable) {
186                        *count += 1;
187                    }
188                }
189            }
190            AggregateState::SumInt(sum, count) => {
191                if let Some(Value::Int64(v)) = value {
192                    *sum += v;
193                    *count += 1;
194                } else if let Some(Value::Float64(v)) = value {
195                    // Convert to float sum, carrying count forward
196                    *self = AggregateState::SumFloat(*sum as f64 + v, *count + 1);
197                } else if let Some(ref v) = value {
198                    // RDF stores numeric literals as strings - try to parse
199                    if let Some(num) = value_to_f64(v) {
200                        *self = AggregateState::SumFloat(*sum as f64 + num, *count + 1);
201                    }
202                }
203            }
204            AggregateState::SumIntDistinct(sum, count, seen) => {
205                if let Some(ref v) = value {
206                    let hashable = HashableValue::from(v);
207                    if seen.insert(hashable) {
208                        if let Value::Int64(i) = v {
209                            *sum += i;
210                            *count += 1;
211                        } else if let Value::Float64(f) = v {
212                            // Convert to float distinct: move the seen set instead of cloning
213                            let moved_seen = std::mem::take(seen);
214                            *self = AggregateState::SumFloatDistinct(
215                                *sum as f64 + f,
216                                *count + 1,
217                                moved_seen,
218                            );
219                        } else if let Some(num) = value_to_f64(v) {
220                            // RDF string-encoded numerics
221                            let moved_seen = std::mem::take(seen);
222                            *self = AggregateState::SumFloatDistinct(
223                                *sum as f64 + num,
224                                *count + 1,
225                                moved_seen,
226                            );
227                        }
228                    }
229                }
230            }
231            AggregateState::SumFloat(sum, count) => {
232                if let Some(ref v) = value {
233                    // Use value_to_f64 which now handles strings
234                    if let Some(num) = value_to_f64(v) {
235                        *sum += num;
236                        *count += 1;
237                    }
238                }
239            }
240            AggregateState::SumFloatDistinct(sum, count, seen) => {
241                if let Some(ref v) = value {
242                    let hashable = HashableValue::from(v);
243                    if seen.insert(hashable)
244                        && let Some(num) = value_to_f64(v)
245                    {
246                        *sum += num;
247                        *count += 1;
248                    }
249                }
250            }
251            AggregateState::Avg(sum, count) => {
252                if let Some(ref v) = value
253                    && let Some(num) = value_to_f64(v)
254                {
255                    *sum += num;
256                    *count += 1;
257                }
258            }
259            AggregateState::AvgDistinct(sum, count, seen) => {
260                if let Some(ref v) = value {
261                    let hashable = HashableValue::from(v);
262                    if seen.insert(hashable)
263                        && let Some(num) = value_to_f64(v)
264                    {
265                        *sum += num;
266                        *count += 1;
267                    }
268                }
269            }
270            AggregateState::Min(min) => {
271                if let Some(v) = value {
272                    match min {
273                        None => *min = Some(v),
274                        Some(current) => {
275                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
276                                *min = Some(v);
277                            }
278                        }
279                    }
280                }
281            }
282            AggregateState::Max(max) => {
283                if let Some(v) = value {
284                    match max {
285                        None => *max = Some(v),
286                        Some(current) => {
287                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
288                                *max = Some(v);
289                            }
290                        }
291                    }
292                }
293            }
294            AggregateState::First(first) => {
295                if first.is_none() {
296                    *first = value;
297                }
298            }
299            AggregateState::Last(last) => {
300                if value.is_some() {
301                    *last = value;
302                }
303            }
304            AggregateState::Collect(list) => {
305                if let Some(v) = value {
306                    list.push(v);
307                }
308            }
309            AggregateState::CollectDistinct(list, seen) => {
310                if let Some(v) = value {
311                    let hashable = HashableValue::from(&v);
312                    if seen.insert(hashable) {
313                        list.push(v);
314                    }
315                }
316            }
317            // Statistical functions using Welford's online algorithm
318            AggregateState::StdDev { count, mean, m2 }
319            | AggregateState::StdDevPop { count, mean, m2 }
320            | AggregateState::Variance { count, mean, m2 }
321            | AggregateState::VariancePop { count, mean, m2 } => {
322                if let Some(ref v) = value
323                    && let Some(x) = value_to_f64(v)
324                {
325                    *count += 1;
326                    let delta = x - *mean;
327                    *mean += delta / *count as f64;
328                    let delta2 = x - *mean;
329                    *m2 += delta * delta2;
330                }
331            }
332            AggregateState::PercentileDisc { values, .. }
333            | AggregateState::PercentileCont { values, .. } => {
334                if let Some(ref v) = value
335                    && let Some(x) = value_to_f64(v)
336                {
337                    values.push(x);
338                }
339            }
340            AggregateState::GroupConcat(list, _sep) => {
341                if let Some(v) = value {
342                    list.push(agg_value_to_string(&v));
343                }
344            }
345            AggregateState::GroupConcatDistinct(list, _sep, seen) => {
346                if let Some(v) = value {
347                    let hashable = HashableValue::from(&v);
348                    if seen.insert(hashable) {
349                        list.push(agg_value_to_string(&v));
350                    }
351                }
352            }
353            AggregateState::Sample(sample) => {
354                if sample.is_none() {
355                    *sample = value;
356                }
357            }
358            AggregateState::Bivariate { .. } => {
359                // Bivariate functions require two values; use update_bivariate() instead.
360                // Single-value update is a no-op for bivariate state.
361            }
362        }
363    }
364
365    /// Updates a bivariate (two-variable) aggregate state with a pair of values.
366    ///
367    /// Uses the two-variable Welford online algorithm for numerically stable computation
368    /// of covariance and related statistics. Skips the update if either value is null.
369    fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
370        if let AggregateState::Bivariate {
371            count,
372            mean_x,
373            mean_y,
374            m2_x,
375            m2_y,
376            c_xy,
377            ..
378        } = self
379        {
380            // Skip if either value is null (SQL semantics: exclude non-pairs)
381            if let (Some(y), Some(x)) = (&y_val, &x_val)
382                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
383            {
384                *count += 1;
385                let n = *count as f64;
386                let dx = x_f - *mean_x;
387                let dy = y_f - *mean_y;
388                *mean_x += dx / n;
389                *mean_y += dy / n;
390                let dx2 = x_f - *mean_x; // post-update delta
391                let dy2 = y_f - *mean_y; // post-update delta
392                *m2_x += dx * dx2;
393                *m2_y += dy * dy2;
394                *c_xy += dx * dy2;
395            }
396        }
397    }
398
399    /// Finalizes the state and returns the result value.
400    pub(crate) fn finalize(&self) -> Value {
401        match self {
402            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
403                Value::Int64(*count)
404            }
405            AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
406                if *count == 0 {
407                    Value::Null
408                } else {
409                    Value::Int64(*sum)
410                }
411            }
412            AggregateState::SumFloat(sum, count)
413            | AggregateState::SumFloatDistinct(sum, count, _) => {
414                if *count == 0 {
415                    Value::Null
416                } else {
417                    Value::Float64(*sum)
418                }
419            }
420            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
421                if *count == 0 {
422                    Value::Null
423                } else {
424                    Value::Float64(*sum / *count as f64)
425                }
426            }
427            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
428            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
429            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
430            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
431            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
432                Value::List(list.clone().into())
433            }
434            // Sample standard deviation: sqrt(M2 / (n - 1))
435            AggregateState::StdDev { count, m2, .. } => {
436                if *count < 2 {
437                    Value::Null
438                } else {
439                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
440                }
441            }
442            // Population standard deviation: sqrt(M2 / n)
443            AggregateState::StdDevPop { count, m2, .. } => {
444                if *count == 0 {
445                    Value::Null
446                } else {
447                    Value::Float64((*m2 / *count as f64).sqrt())
448                }
449            }
450            // Sample variance: M2 / (n - 1)
451            AggregateState::Variance { count, m2, .. } => {
452                if *count < 2 {
453                    Value::Null
454                } else {
455                    Value::Float64(*m2 / (*count - 1) as f64)
456                }
457            }
458            // Population variance: M2 / n
459            AggregateState::VariancePop { count, m2, .. } => {
460                if *count == 0 {
461                    Value::Null
462                } else {
463                    Value::Float64(*m2 / *count as f64)
464                }
465            }
466            // Discrete percentile: return actual value at percentile position
467            AggregateState::PercentileDisc { values, percentile } => {
468                if values.is_empty() {
469                    Value::Null
470                } else {
471                    let mut sorted = values.clone();
472                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
473                    // Index calculation per SQL standard: floor(p * (n - 1))
474                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
475                    Value::Float64(sorted[index])
476                }
477            }
478            // Continuous percentile: interpolate between values
479            AggregateState::PercentileCont { values, percentile } => {
480                if values.is_empty() {
481                    Value::Null
482                } else {
483                    let mut sorted = values.clone();
484                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
485                    // Linear interpolation per SQL standard
486                    let rank = percentile * (sorted.len() - 1) as f64;
487                    let lower_idx = rank.floor() as usize;
488                    let upper_idx = rank.ceil() as usize;
489                    if lower_idx == upper_idx {
490                        Value::Float64(sorted[lower_idx])
491                    } else {
492                        let fraction = rank - lower_idx as f64;
493                        let result =
494                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
495                        Value::Float64(result)
496                    }
497                }
498            }
499            // GROUP_CONCAT: join strings with space separator (SPARQL default)
500            AggregateState::GroupConcat(list, sep)
501            | AggregateState::GroupConcatDistinct(list, sep, _) => {
502                Value::String(list.join(sep).into())
503            }
504            // SAMPLE: return the first non-null value seen
505            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
506            // Binary set functions: dispatch on kind
507            AggregateState::Bivariate {
508                kind,
509                count,
510                mean_x,
511                mean_y,
512                m2_x,
513                m2_y,
514                c_xy,
515            } => {
516                let n = *count;
517                match kind {
518                    AggregateFunction::CovarSamp => {
519                        if n < 2 {
520                            Value::Null
521                        } else {
522                            Value::Float64(*c_xy / (n - 1) as f64)
523                        }
524                    }
525                    AggregateFunction::CovarPop => {
526                        if n == 0 {
527                            Value::Null
528                        } else {
529                            Value::Float64(*c_xy / n as f64)
530                        }
531                    }
532                    AggregateFunction::Corr => {
533                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
534                            Value::Null
535                        } else {
536                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
537                        }
538                    }
539                    AggregateFunction::RegrSlope => {
540                        if n == 0 || *m2_x == 0.0 {
541                            Value::Null
542                        } else {
543                            Value::Float64(*c_xy / *m2_x)
544                        }
545                    }
546                    AggregateFunction::RegrIntercept => {
547                        if n == 0 || *m2_x == 0.0 {
548                            Value::Null
549                        } else {
550                            let slope = *c_xy / *m2_x;
551                            Value::Float64(*mean_y - slope * *mean_x)
552                        }
553                    }
554                    AggregateFunction::RegrR2 => {
555                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
556                            Value::Null
557                        } else {
558                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
559                        }
560                    }
561                    AggregateFunction::RegrCount => Value::Int64(n),
562                    AggregateFunction::RegrSxx => {
563                        if n == 0 {
564                            Value::Null
565                        } else {
566                            Value::Float64(*m2_x)
567                        }
568                    }
569                    AggregateFunction::RegrSyy => {
570                        if n == 0 {
571                            Value::Null
572                        } else {
573                            Value::Float64(*m2_y)
574                        }
575                    }
576                    AggregateFunction::RegrSxy => {
577                        if n == 0 {
578                            Value::Null
579                        } else {
580                            Value::Float64(*c_xy)
581                        }
582                    }
583                    AggregateFunction::RegrAvgx => {
584                        if n == 0 {
585                            Value::Null
586                        } else {
587                            Value::Float64(*mean_x)
588                        }
589                    }
590                    AggregateFunction::RegrAvgy => {
591                        if n == 0 {
592                            Value::Null
593                        } else {
594                            Value::Float64(*mean_y)
595                        }
596                    }
597                    _ => Value::Null, // non-bivariate functions never reach here
598                }
599            }
600        }
601    }
602}
603
604use super::value_utils::{compare_values, value_to_f64};
605
606/// Converts a Value to its string representation for GROUP_CONCAT.
607fn agg_value_to_string(val: &Value) -> String {
608    match val {
609        Value::String(s) => s.to_string(),
610        Value::Int64(i) => i.to_string(),
611        Value::Float64(f) => f.to_string(),
612        Value::Bool(b) => b.to_string(),
613        Value::Null => String::new(),
614        other => format!("{other:?}"),
615    }
616}
617
618/// A group key for hash-based aggregation.
619#[derive(Debug, Clone, PartialEq, Eq, Hash)]
620pub struct GroupKey(Vec<GroupKeyPart>);
621
622#[derive(Debug, Clone, PartialEq, Eq, Hash)]
623enum GroupKeyPart {
624    Null,
625    Bool(bool),
626    Int64(i64),
627    String(ArcStr),
628    Bytes(Arc<[u8]>),
629    Date(grafeo_common::types::Date),
630    Time(grafeo_common::types::Time),
631    Timestamp(grafeo_common::types::Timestamp),
632    Duration(grafeo_common::types::Duration),
633    ZonedDatetime(grafeo_common::types::ZonedDatetime),
634    List(Vec<GroupKeyPart>),
635    Map(Vec<(ArcStr, GroupKeyPart)>),
636}
637
638impl GroupKeyPart {
639    fn from_value(v: Value) -> Self {
640        match v {
641            Value::Null => Self::Null,
642            Value::Bool(b) => Self::Bool(b),
643            Value::Int64(i) => Self::Int64(i),
644            Value::Float64(f) => Self::Int64(f.to_bits() as i64),
645            Value::String(s) => Self::String(s.clone()),
646            Value::Bytes(b) => Self::Bytes(b),
647            Value::Date(d) => Self::Date(d),
648            Value::Time(t) => Self::Time(t),
649            Value::Timestamp(ts) => Self::Timestamp(ts),
650            Value::Duration(d) => Self::Duration(d),
651            Value::ZonedDatetime(zdt) => Self::ZonedDatetime(zdt),
652            Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
653            Value::Map(map) => {
654                // BTreeMap already iterates in key order, so this is deterministic
655                let entries: Vec<(ArcStr, GroupKeyPart)> = map
656                    .iter()
657                    .map(|(k, v)| (ArcStr::from(k.as_str()), Self::from_value(v.clone())))
658                    .collect();
659                Self::Map(entries)
660            }
661            // Path, Vector, GCounter, OnCounter: use Debug string as fallback
662            other => Self::String(ArcStr::from(format!("{other:?}"))),
663        }
664    }
665
666    fn to_value(&self) -> Value {
667        match self {
668            Self::Null => Value::Null,
669            Self::Bool(b) => Value::Bool(*b),
670            Self::Int64(i) => Value::Int64(*i),
671            Self::String(s) => Value::String(s.clone()),
672            Self::Bytes(b) => Value::Bytes(Arc::clone(b)),
673            Self::Date(d) => Value::Date(*d),
674            Self::Time(t) => Value::Time(*t),
675            Self::Timestamp(ts) => Value::Timestamp(*ts),
676            Self::Duration(d) => Value::Duration(*d),
677            Self::ZonedDatetime(zdt) => Value::ZonedDatetime(*zdt),
678            Self::List(parts) => {
679                let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
680                Value::List(Arc::from(values.into_boxed_slice()))
681            }
682            Self::Map(entries) => {
683                let map: std::collections::BTreeMap<PropertyKey, Value> = entries
684                    .iter()
685                    .map(|(k, v)| (PropertyKey::new(k.as_str()), v.to_value()))
686                    .collect();
687                Value::Map(Arc::new(map))
688            }
689        }
690    }
691}
692
693impl GroupKey {
694    /// Creates a group key from column values.
695    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
696        let parts: Vec<GroupKeyPart> = group_columns
697            .iter()
698            .map(|&col_idx| {
699                chunk
700                    .column(col_idx)
701                    .and_then(|col| col.get_value(row))
702                    .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
703            })
704            .collect();
705        GroupKey(parts)
706    }
707
708    /// Converts the group key back to values.
709    fn to_values(&self) -> Vec<Value> {
710        self.0.iter().map(GroupKeyPart::to_value).collect()
711    }
712}
713
/// Hash-based aggregate operator.
///
/// Pulls chunks from its child, assigns each selected row to a group keyed by the values
/// in `group_columns`, and keeps one `AggregateState` per aggregate expression per group.
pub struct HashAggregateOperator {
    /// Upstream operator supplying the input chunks.
    child: Box<dyn Operator>,
    /// Indices of the input columns that form the group key.
    group_columns: Vec<usize>,
    /// Aggregate expressions evaluated for every group.
    aggregates: Vec<AggregateExpr>,
    /// Output column types: group columns followed by aggregate results.
    output_schema: Vec<LogicalType>,
    /// Per-group aggregate states. `IndexMap` keeps groups in first-seen (insertion)
    /// order, so iteration order is deterministic.
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Whether aggregation is complete (presumably set once the child is drained; confirm).
    aggregation_complete: bool,
    /// Iterator over finished `(key, states)` pairs; `None` until results are produced.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
733
734impl HashAggregateOperator {
735    /// Creates a new hash aggregate operator.
736    ///
737    /// # Arguments
738    /// * `child` - Child operator to read from.
739    /// * `group_columns` - Column indices to group by.
740    /// * `aggregates` - Aggregation expressions.
741    /// * `output_schema` - Schema of the output (group columns + aggregate results).
742    pub fn new(
743        child: Box<dyn Operator>,
744        group_columns: Vec<usize>,
745        aggregates: Vec<AggregateExpr>,
746        output_schema: Vec<LogicalType>,
747    ) -> Self {
748        Self {
749            child,
750            group_columns,
751            aggregates,
752            output_schema,
753            groups: IndexMap::new(),
754            aggregation_complete: false,
755            results: None,
756        }
757    }
758
759    /// Performs the aggregation.
760    fn aggregate(&mut self) -> Result<(), OperatorError> {
761        while let Some(chunk) = self.child.next()? {
762            for row in chunk.selected_indices() {
763                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
764
765                // Get or create aggregate states for this group
766                let states = self.groups.entry(key).or_insert_with(|| {
767                    self.aggregates
768                        .iter()
769                        .map(|agg| {
770                            AggregateState::new(
771                                agg.function,
772                                agg.distinct,
773                                agg.percentile,
774                                agg.separator.as_deref(),
775                            )
776                        })
777                        .collect()
778                });
779
780                // Update each aggregate
781                for (i, agg) in self.aggregates.iter().enumerate() {
782                    // Binary set functions: read two column values
783                    if agg.column2.is_some() {
784                        let y_val = agg
785                            .column
786                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
787                        let x_val = agg
788                            .column2
789                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
790                        states[i].update_bivariate(y_val, x_val);
791                        continue;
792                    }
793
794                    let value = match (agg.function, agg.distinct) {
795                        // COUNT(*) without DISTINCT doesn't need a value
796                        (AggregateFunction::Count, false) => None,
797                        // COUNT DISTINCT needs the actual value to track unique values
798                        (AggregateFunction::Count, true) => agg
799                            .column
800                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
801                        _ => agg
802                            .column
803                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
804                    };
805
806                    // For COUNT without DISTINCT, always update. For others, skip nulls.
807                    match (agg.function, agg.distinct) {
808                        (AggregateFunction::Count, false) => states[i].update(None),
809                        (AggregateFunction::Count, true) => {
810                            // COUNT DISTINCT needs the value to track unique values
811                            if value.is_some() && !matches!(value, Some(Value::Null)) {
812                                states[i].update(value);
813                            }
814                        }
815                        (AggregateFunction::CountNonNull, _) => {
816                            if value.is_some() && !matches!(value, Some(Value::Null)) {
817                                states[i].update(value);
818                            }
819                        }
820                        _ => {
821                            if value.is_some() && !matches!(value, Some(Value::Null)) {
822                                states[i].update(value);
823                            }
824                        }
825                    }
826                }
827            }
828        }
829
830        self.aggregation_complete = true;
831
832        // Convert to results iterator (IndexMap::drain takes a range)
833        let results: Vec<_> = self.groups.drain(..).collect();
834        self.results = Some(results.into_iter());
835
836        Ok(())
837    }
838}
839
840impl Operator for HashAggregateOperator {
841    fn next(&mut self) -> OperatorResult {
842        // Perform aggregation if not done
843        if !self.aggregation_complete {
844            self.aggregate()?;
845        }
846
847        // Special case: no groups (global aggregation with no data)
848        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
849            // For global aggregation (no GROUP BY), return one row with initial values
850            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
851
852            for agg in &self.aggregates {
853                let state = AggregateState::new(
854                    agg.function,
855                    agg.distinct,
856                    agg.percentile,
857                    agg.separator.as_deref(),
858                );
859                let value = state.finalize();
860                if let Some(col) = builder.column_mut(self.group_columns.len()) {
861                    col.push_value(value);
862                }
863            }
864            builder.advance_row();
865
866            self.results = Some(Vec::new().into_iter()); // Mark as done
867            return Ok(Some(builder.finish()));
868        }
869
870        let Some(results) = &mut self.results else {
871            return Ok(None);
872        };
873
874        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
875
876        for (key, states) in results.by_ref() {
877            // Output group key columns
878            let key_values = key.to_values();
879            for (i, value) in key_values.into_iter().enumerate() {
880                if let Some(col) = builder.column_mut(i) {
881                    col.push_value(value);
882                }
883            }
884
885            // Output aggregate results
886            for (i, state) in states.iter().enumerate() {
887                let col_idx = self.group_columns.len() + i;
888                if let Some(col) = builder.column_mut(col_idx) {
889                    col.push_value(state.finalize());
890                }
891            }
892
893            builder.advance_row();
894
895            if builder.is_full() {
896                return Ok(Some(builder.finish()));
897            }
898        }
899
900        if builder.row_count() > 0 {
901            Ok(Some(builder.finish()))
902        } else {
903            Ok(None)
904        }
905    }
906
907    fn reset(&mut self) {
908        self.child.reset();
909        self.groups.clear();
910        self.aggregation_complete = false;
911        self.results = None;
912    }
913
914    fn name(&self) -> &'static str {
915        "HashAggregate"
916    }
917}
918
919/// Simple (non-grouping) aggregate operator for global aggregations.
920///
921/// Used when there's no GROUP BY clause - aggregates all input into a single row.
922pub struct SimpleAggregateOperator {
923    /// Child operator.
924    child: Box<dyn Operator>,
925    /// Aggregation expressions.
926    aggregates: Vec<AggregateExpr>,
927    /// Output schema.
928    output_schema: Vec<LogicalType>,
929    /// Aggregate states.
930    states: Vec<AggregateState>,
931    /// Whether aggregation is complete.
932    done: bool,
933}
934
935impl SimpleAggregateOperator {
936    /// Creates a new simple aggregate operator.
937    pub fn new(
938        child: Box<dyn Operator>,
939        aggregates: Vec<AggregateExpr>,
940        output_schema: Vec<LogicalType>,
941    ) -> Self {
942        let states = aggregates
943            .iter()
944            .map(|agg| {
945                AggregateState::new(
946                    agg.function,
947                    agg.distinct,
948                    agg.percentile,
949                    agg.separator.as_deref(),
950                )
951            })
952            .collect();
953
954        Self {
955            child,
956            aggregates,
957            output_schema,
958            states,
959            done: false,
960        }
961    }
962}
963
964impl Operator for SimpleAggregateOperator {
965    fn next(&mut self) -> OperatorResult {
966        if self.done {
967            return Ok(None);
968        }
969
970        // Process all input
971        while let Some(chunk) = self.child.next()? {
972            for row in chunk.selected_indices() {
973                for (i, agg) in self.aggregates.iter().enumerate() {
974                    // Binary set functions: read two column values
975                    if agg.column2.is_some() {
976                        let y_val = agg
977                            .column
978                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
979                        let x_val = agg
980                            .column2
981                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
982                        self.states[i].update_bivariate(y_val, x_val);
983                        continue;
984                    }
985
986                    let value = match (agg.function, agg.distinct) {
987                        // COUNT(*) without DISTINCT doesn't need a value
988                        (AggregateFunction::Count, false) => None,
989                        // COUNT DISTINCT needs the actual value to track unique values
990                        (AggregateFunction::Count, true) => agg
991                            .column
992                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
993                        _ => agg
994                            .column
995                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
996                    };
997
998                    match (agg.function, agg.distinct) {
999                        (AggregateFunction::Count, false) => self.states[i].update(None),
1000                        (AggregateFunction::Count, true) => {
1001                            // COUNT DISTINCT needs the value to track unique values
1002                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1003                                self.states[i].update(value);
1004                            }
1005                        }
1006                        (AggregateFunction::CountNonNull, _) => {
1007                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1008                                self.states[i].update(value);
1009                            }
1010                        }
1011                        _ => {
1012                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1013                                self.states[i].update(value);
1014                            }
1015                        }
1016                    }
1017                }
1018            }
1019        }
1020
1021        // Output single result row
1022        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
1023
1024        for (i, state) in self.states.iter().enumerate() {
1025            if let Some(col) = builder.column_mut(i) {
1026                col.push_value(state.finalize());
1027            }
1028        }
1029        builder.advance_row();
1030
1031        self.done = true;
1032        Ok(Some(builder.finish()))
1033    }
1034
1035    fn reset(&mut self) {
1036        self.child.reset();
1037        self.states = self
1038            .aggregates
1039            .iter()
1040            .map(|agg| {
1041                AggregateState::new(
1042                    agg.function,
1043                    agg.distinct,
1044                    agg.percentile,
1045                    agg.separator.as_deref(),
1046                )
1047            })
1048            .collect();
1049        self.done = false;
1050    }
1051
1052    fn name(&self) -> &'static str {
1053        "SimpleAggregate"
1054    }
1055}
1056
1057#[cfg(test)]
1058mod tests {
1059    use super::*;
1060    use crate::execution::chunk::DataChunkBuilder;
1061
1062    struct MockOperator {
1063        chunks: Vec<DataChunk>,
1064        position: usize,
1065    }
1066
1067    impl MockOperator {
1068        fn new(chunks: Vec<DataChunk>) -> Self {
1069            Self {
1070                chunks,
1071                position: 0,
1072            }
1073        }
1074    }
1075
1076    impl Operator for MockOperator {
1077        fn next(&mut self) -> OperatorResult {
1078            if self.position < self.chunks.len() {
1079                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1080                self.position += 1;
1081                Ok(Some(chunk))
1082            } else {
1083                Ok(None)
1084            }
1085        }
1086
1087        fn reset(&mut self) {
1088            self.position = 0;
1089        }
1090
1091        fn name(&self) -> &'static str {
1092            "Mock"
1093        }
1094    }
1095
1096    fn create_test_chunk() -> DataChunk {
1097        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
1098        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1099
1100        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1101        for (group, value) in data {
1102            builder.column_mut(0).unwrap().push_int64(group);
1103            builder.column_mut(1).unwrap().push_int64(value);
1104            builder.advance_row();
1105        }
1106
1107        builder.finish()
1108    }
1109
1110    #[test]
1111    fn test_simple_count() {
1112        let mock = MockOperator::new(vec![create_test_chunk()]);
1113
1114        let mut agg = SimpleAggregateOperator::new(
1115            Box::new(mock),
1116            vec![AggregateExpr::count_star()],
1117            vec![LogicalType::Int64],
1118        );
1119
1120        let result = agg.next().unwrap().unwrap();
1121        assert_eq!(result.row_count(), 1);
1122        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1123
1124        // Should be done
1125        assert!(agg.next().unwrap().is_none());
1126    }
1127
1128    #[test]
1129    fn test_simple_sum() {
1130        let mock = MockOperator::new(vec![create_test_chunk()]);
1131
1132        let mut agg = SimpleAggregateOperator::new(
1133            Box::new(mock),
1134            vec![AggregateExpr::sum(1)], // Sum of column 1
1135            vec![LogicalType::Int64],
1136        );
1137
1138        let result = agg.next().unwrap().unwrap();
1139        assert_eq!(result.row_count(), 1);
1140        // Sum: 10 + 20 + 30 + 40 + 50 = 150
1141        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1142    }
1143
1144    #[test]
1145    fn test_simple_avg() {
1146        let mock = MockOperator::new(vec![create_test_chunk()]);
1147
1148        let mut agg = SimpleAggregateOperator::new(
1149            Box::new(mock),
1150            vec![AggregateExpr::avg(1)],
1151            vec![LogicalType::Float64],
1152        );
1153
1154        let result = agg.next().unwrap().unwrap();
1155        assert_eq!(result.row_count(), 1);
1156        // Avg: 150 / 5 = 30.0
1157        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1158        assert!((avg - 30.0).abs() < 0.001);
1159    }
1160
1161    #[test]
1162    fn test_simple_min_max() {
1163        let mock = MockOperator::new(vec![create_test_chunk()]);
1164
1165        let mut agg = SimpleAggregateOperator::new(
1166            Box::new(mock),
1167            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1168            vec![LogicalType::Int64, LogicalType::Int64],
1169        );
1170
1171        let result = agg.next().unwrap().unwrap();
1172        assert_eq!(result.row_count(), 1);
1173        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
1174        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
1175    }
1176
1177    #[test]
1178    fn test_sum_with_string_values() {
1179        // Test SUM with string values (like RDF stores numeric literals)
1180        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1181        builder.column_mut(0).unwrap().push_string("30");
1182        builder.advance_row();
1183        builder.column_mut(0).unwrap().push_string("25");
1184        builder.advance_row();
1185        builder.column_mut(0).unwrap().push_string("35");
1186        builder.advance_row();
1187        let chunk = builder.finish();
1188
1189        let mock = MockOperator::new(vec![chunk]);
1190        let mut agg = SimpleAggregateOperator::new(
1191            Box::new(mock),
1192            vec![AggregateExpr::sum(0)],
1193            vec![LogicalType::Float64],
1194        );
1195
1196        let result = agg.next().unwrap().unwrap();
1197        assert_eq!(result.row_count(), 1);
1198        // Should parse strings and sum: 30 + 25 + 35 = 90
1199        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1200        assert!(
1201            (sum_val - 90.0).abs() < 0.001,
1202            "Expected 90.0, got {}",
1203            sum_val
1204        );
1205    }
1206
1207    #[test]
1208    fn test_grouped_aggregation() {
1209        let mock = MockOperator::new(vec![create_test_chunk()]);
1210
1211        // GROUP BY column 0, SUM(column 1)
1212        let mut agg = HashAggregateOperator::new(
1213            Box::new(mock),
1214            vec![0],                     // Group by column 0
1215            vec![AggregateExpr::sum(1)], // Sum of column 1
1216            vec![LogicalType::Int64, LogicalType::Int64],
1217        );
1218
1219        let mut results: Vec<(i64, i64)> = Vec::new();
1220        while let Some(chunk) = agg.next().unwrap() {
1221            for row in chunk.selected_indices() {
1222                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1223                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1224                results.push((group, sum));
1225            }
1226        }
1227
1228        results.sort_by_key(|(g, _)| *g);
1229        assert_eq!(results.len(), 2);
1230        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
1231        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
1232    }
1233
1234    #[test]
1235    fn test_grouped_count() {
1236        let mock = MockOperator::new(vec![create_test_chunk()]);
1237
1238        // GROUP BY column 0, COUNT(*)
1239        let mut agg = HashAggregateOperator::new(
1240            Box::new(mock),
1241            vec![0],
1242            vec![AggregateExpr::count_star()],
1243            vec![LogicalType::Int64, LogicalType::Int64],
1244        );
1245
1246        let mut results: Vec<(i64, i64)> = Vec::new();
1247        while let Some(chunk) = agg.next().unwrap() {
1248            for row in chunk.selected_indices() {
1249                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1250                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1251                results.push((group, count));
1252            }
1253        }
1254
1255        results.sort_by_key(|(g, _)| *g);
1256        assert_eq!(results.len(), 2);
1257        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
1258        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
1259    }
1260
1261    #[test]
1262    fn test_multiple_aggregates() {
1263        let mock = MockOperator::new(vec![create_test_chunk()]);
1264
1265        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
1266        let mut agg = HashAggregateOperator::new(
1267            Box::new(mock),
1268            vec![0],
1269            vec![
1270                AggregateExpr::count_star(),
1271                AggregateExpr::sum(1),
1272                AggregateExpr::avg(1),
1273            ],
1274            vec![
1275                LogicalType::Int64,   // Group key
1276                LogicalType::Int64,   // COUNT
1277                LogicalType::Int64,   // SUM
1278                LogicalType::Float64, // AVG
1279            ],
1280        );
1281
1282        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1283        while let Some(chunk) = agg.next().unwrap() {
1284            for row in chunk.selected_indices() {
1285                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1286                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1287                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1288                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1289                results.push((group, count, sum, avg));
1290            }
1291        }
1292
1293        results.sort_by_key(|(g, _, _, _)| *g);
1294        assert_eq!(results.len(), 2);
1295
1296        // Group 1: COUNT=2, SUM=30, AVG=15.0
1297        assert_eq!(results[0].0, 1);
1298        assert_eq!(results[0].1, 2);
1299        assert_eq!(results[0].2, 30);
1300        assert!((results[0].3 - 15.0).abs() < 0.001);
1301
1302        // Group 2: COUNT=3, SUM=120, AVG=40.0
1303        assert_eq!(results[1].0, 2);
1304        assert_eq!(results[1].1, 3);
1305        assert_eq!(results[1].2, 120);
1306        assert!((results[1].3 - 40.0).abs() < 0.001);
1307    }
1308
1309    fn create_test_chunk_with_duplicates() -> DataChunk {
1310        // Create data with duplicate values in column 1
1311        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
1312        // GROUP 1: values [10, 10, 20] -> distinct count = 2
1313        // GROUP 2: values [30, 30, 30] -> distinct count = 1
1314        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1315
1316        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1317        for (group, value) in data {
1318            builder.column_mut(0).unwrap().push_int64(group);
1319            builder.column_mut(1).unwrap().push_int64(value);
1320            builder.advance_row();
1321        }
1322
1323        builder.finish()
1324    }
1325
1326    #[test]
1327    fn test_count_distinct() {
1328        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1329
1330        // COUNT(DISTINCT column 1)
1331        let mut agg = SimpleAggregateOperator::new(
1332            Box::new(mock),
1333            vec![AggregateExpr::count(1).with_distinct()],
1334            vec![LogicalType::Int64],
1335        );
1336
1337        let result = agg.next().unwrap().unwrap();
1338        assert_eq!(result.row_count(), 1);
1339        // Total distinct values: 10, 20, 30 = 3 distinct values
1340        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1341    }
1342
1343    #[test]
1344    fn test_grouped_count_distinct() {
1345        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1346
1347        // GROUP BY column 0, COUNT(DISTINCT column 1)
1348        let mut agg = HashAggregateOperator::new(
1349            Box::new(mock),
1350            vec![0],
1351            vec![AggregateExpr::count(1).with_distinct()],
1352            vec![LogicalType::Int64, LogicalType::Int64],
1353        );
1354
1355        let mut results: Vec<(i64, i64)> = Vec::new();
1356        while let Some(chunk) = agg.next().unwrap() {
1357            for row in chunk.selected_indices() {
1358                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1359                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1360                results.push((group, count));
1361            }
1362        }
1363
1364        results.sort_by_key(|(g, _)| *g);
1365        assert_eq!(results.len(), 2);
1366        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
1367        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
1368    }
1369
1370    #[test]
1371    fn test_sum_distinct() {
1372        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1373
1374        // SUM(DISTINCT column 1)
1375        let mut agg = SimpleAggregateOperator::new(
1376            Box::new(mock),
1377            vec![AggregateExpr::sum(1).with_distinct()],
1378            vec![LogicalType::Int64],
1379        );
1380
1381        let result = agg.next().unwrap().unwrap();
1382        assert_eq!(result.row_count(), 1);
1383        // Sum of distinct values: 10 + 20 + 30 = 60
1384        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1385    }
1386
1387    #[test]
1388    fn test_avg_distinct() {
1389        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1390
1391        // AVG(DISTINCT column 1)
1392        let mut agg = SimpleAggregateOperator::new(
1393            Box::new(mock),
1394            vec![AggregateExpr::avg(1).with_distinct()],
1395            vec![LogicalType::Float64],
1396        );
1397
1398        let result = agg.next().unwrap().unwrap();
1399        assert_eq!(result.row_count(), 1);
1400        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1401        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1402        assert!((avg - 20.0).abs() < 0.001);
1403    }
1404
1405    fn create_statistical_test_chunk() -> DataChunk {
1406        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1407        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1408        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1409
1410        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1411            builder.column_mut(0).unwrap().push_int64(value);
1412            builder.advance_row();
1413        }
1414
1415        builder.finish()
1416    }
1417
1418    #[test]
1419    fn test_stdev_sample() {
1420        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1421
1422        let mut agg = SimpleAggregateOperator::new(
1423            Box::new(mock),
1424            vec![AggregateExpr::stdev(0)],
1425            vec![LogicalType::Float64],
1426        );
1427
1428        let result = agg.next().unwrap().unwrap();
1429        assert_eq!(result.row_count(), 1);
1430        // Sample standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1431        // Mean = 5.0, Variance = 32/7 = 4.571, StdDev = 2.138
1432        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1433        assert!((stdev - 2.138).abs() < 0.01);
1434    }
1435
1436    #[test]
1437    fn test_stdev_population() {
1438        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1439
1440        let mut agg = SimpleAggregateOperator::new(
1441            Box::new(mock),
1442            vec![AggregateExpr::stdev_pop(0)],
1443            vec![LogicalType::Float64],
1444        );
1445
1446        let result = agg.next().unwrap().unwrap();
1447        assert_eq!(result.row_count(), 1);
1448        // Population standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1449        // Mean = 5.0, Variance = 32/8 = 4.0, StdDev = 2.0
1450        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1451        assert!((stdev - 2.0).abs() < 0.01);
1452    }
1453
1454    #[test]
1455    fn test_percentile_disc() {
1456        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1457
1458        // Median (50th percentile discrete)
1459        let mut agg = SimpleAggregateOperator::new(
1460            Box::new(mock),
1461            vec![AggregateExpr::percentile_disc(0, 0.5)],
1462            vec![LogicalType::Float64],
1463        );
1464
1465        let result = agg.next().unwrap().unwrap();
1466        assert_eq!(result.row_count(), 1);
1467        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], index = floor(0.5 * 7) = 3, value = 4
1468        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1469        assert!((percentile - 4.0).abs() < 0.01);
1470    }
1471
1472    #[test]
1473    fn test_percentile_cont() {
1474        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1475
1476        // Median (50th percentile continuous)
1477        let mut agg = SimpleAggregateOperator::new(
1478            Box::new(mock),
1479            vec![AggregateExpr::percentile_cont(0, 0.5)],
1480            vec![LogicalType::Float64],
1481        );
1482
1483        let result = agg.next().unwrap().unwrap();
1484        assert_eq!(result.row_count(), 1);
1485        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], rank = 0.5 * 7 = 3.5
1486        // Interpolate between index 3 (4) and index 4 (5): 4 + 0.5 * (5 - 4) = 4.5
1487        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1488        assert!((percentile - 4.5).abs() < 0.01);
1489    }
1490
1491    #[test]
1492    fn test_percentile_extremes() {
1493        // Test 0th and 100th percentiles
1494        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1495
1496        let mut agg = SimpleAggregateOperator::new(
1497            Box::new(mock),
1498            vec![
1499                AggregateExpr::percentile_disc(0, 0.0),
1500                AggregateExpr::percentile_disc(0, 1.0),
1501            ],
1502            vec![LogicalType::Float64, LogicalType::Float64],
1503        );
1504
1505        let result = agg.next().unwrap().unwrap();
1506        assert_eq!(result.row_count(), 1);
1507        // 0th percentile = minimum = 2
1508        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1509        assert!((p0 - 2.0).abs() < 0.01);
1510        // 100th percentile = maximum = 9
1511        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1512        assert!((p100 - 9.0).abs() < 0.01);
1513    }
1514
1515    #[test]
1516    fn test_stdev_single_value() {
1517        // Single value should return null for sample stdev
1518        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1519        builder.column_mut(0).unwrap().push_int64(42);
1520        builder.advance_row();
1521        let chunk = builder.finish();
1522
1523        let mock = MockOperator::new(vec![chunk]);
1524
1525        let mut agg = SimpleAggregateOperator::new(
1526            Box::new(mock),
1527            vec![AggregateExpr::stdev(0)],
1528            vec![LogicalType::Float64],
1529        );
1530
1531        let result = agg.next().unwrap().unwrap();
1532        assert_eq!(result.row_count(), 1);
1533        // Sample stdev of single value is undefined (null)
1534        assert!(matches!(
1535            result.column(0).unwrap().get_value(0),
1536            Some(Value::Null)
1537        ));
1538    }
1539
1540    #[test]
1541    fn test_first_and_last() {
1542        let mock = MockOperator::new(vec![create_test_chunk()]);
1543
1544        let mut agg = SimpleAggregateOperator::new(
1545            Box::new(mock),
1546            vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1547            vec![LogicalType::Int64, LogicalType::Int64],
1548        );
1549
1550        let result = agg.next().unwrap().unwrap();
1551        assert_eq!(result.row_count(), 1);
1552        // First: 10, Last: 50
1553        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1554        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1555    }
1556
1557    #[test]
1558    fn test_collect() {
1559        let mock = MockOperator::new(vec![create_test_chunk()]);
1560
1561        let mut agg = SimpleAggregateOperator::new(
1562            Box::new(mock),
1563            vec![AggregateExpr::collect(1)],
1564            vec![LogicalType::Any],
1565        );
1566
1567        let result = agg.next().unwrap().unwrap();
1568        let val = result.column(0).unwrap().get_value(0).unwrap();
1569        if let Value::List(items) = val {
1570            assert_eq!(items.len(), 5);
1571        } else {
1572            panic!("Expected List value");
1573        }
1574    }
1575
1576    #[test]
1577    fn test_collect_distinct() {
1578        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1579
1580        let mut agg = SimpleAggregateOperator::new(
1581            Box::new(mock),
1582            vec![AggregateExpr::collect(1).with_distinct()],
1583            vec![LogicalType::Any],
1584        );
1585
1586        let result = agg.next().unwrap().unwrap();
1587        let val = result.column(0).unwrap().get_value(0).unwrap();
1588        if let Value::List(items) = val {
1589            // [10, 10, 20, 30, 30, 30] -> distinct: [10, 20, 30]
1590            assert_eq!(items.len(), 3);
1591        } else {
1592            panic!("Expected List value");
1593        }
1594    }
1595
1596    #[test]
1597    fn test_group_concat() {
1598        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1599        for s in ["hello", "world", "foo"] {
1600            builder.column_mut(0).unwrap().push_string(s);
1601            builder.advance_row();
1602        }
1603        let chunk = builder.finish();
1604        let mock = MockOperator::new(vec![chunk]);
1605
1606        let agg_expr = AggregateExpr {
1607            function: AggregateFunction::GroupConcat,
1608            column: Some(0),
1609            column2: None,
1610            distinct: false,
1611            alias: None,
1612            percentile: None,
1613            separator: None,
1614        };
1615
1616        let mut agg =
1617            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1618
1619        let result = agg.next().unwrap().unwrap();
1620        let val = result.column(0).unwrap().get_value(0).unwrap();
1621        assert_eq!(val, Value::String("hello world foo".into()));
1622    }
1623
1624    #[test]
1625    fn test_sample() {
1626        let mock = MockOperator::new(vec![create_test_chunk()]);
1627
1628        let agg_expr = AggregateExpr {
1629            function: AggregateFunction::Sample,
1630            column: Some(1),
1631            column2: None,
1632            distinct: false,
1633            alias: None,
1634            percentile: None,
1635            separator: None,
1636        };
1637
1638        let mut agg =
1639            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1640
1641        let result = agg.next().unwrap().unwrap();
1642        // Sample should return the first non-null value (10)
1643        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1644    }
1645
1646    #[test]
1647    fn test_variance_sample() {
1648        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1649
1650        let agg_expr = AggregateExpr {
1651            function: AggregateFunction::Variance,
1652            column: Some(0),
1653            column2: None,
1654            distinct: false,
1655            alias: None,
1656            percentile: None,
1657            separator: None,
1658        };
1659
1660        let mut agg = SimpleAggregateOperator::new(
1661            Box::new(mock),
1662            vec![agg_expr],
1663            vec![LogicalType::Float64],
1664        );
1665
1666        let result = agg.next().unwrap().unwrap();
1667        // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: M2/(n-1) = 32/7 = 4.571
1668        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1669        assert!((variance - 32.0 / 7.0).abs() < 0.01);
1670    }
1671
1672    #[test]
1673    fn test_variance_population() {
1674        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1675
1676        let agg_expr = AggregateExpr {
1677            function: AggregateFunction::VariancePop,
1678            column: Some(0),
1679            column2: None,
1680            distinct: false,
1681            alias: None,
1682            percentile: None,
1683            separator: None,
1684        };
1685
1686        let mut agg = SimpleAggregateOperator::new(
1687            Box::new(mock),
1688            vec![agg_expr],
1689            vec![LogicalType::Float64],
1690        );
1691
1692        let result = agg.next().unwrap().unwrap();
1693        // Population variance: M2/n = 32/8 = 4.0
1694        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1695        assert!((variance - 4.0).abs() < 0.01);
1696    }
1697
1698    #[test]
1699    fn test_variance_single_value() {
1700        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1701        builder.column_mut(0).unwrap().push_int64(42);
1702        builder.advance_row();
1703        let chunk = builder.finish();
1704        let mock = MockOperator::new(vec![chunk]);
1705
1706        let agg_expr = AggregateExpr {
1707            function: AggregateFunction::Variance,
1708            column: Some(0),
1709            column2: None,
1710            distinct: false,
1711            alias: None,
1712            percentile: None,
1713            separator: None,
1714        };
1715
1716        let mut agg = SimpleAggregateOperator::new(
1717            Box::new(mock),
1718            vec![agg_expr],
1719            vec![LogicalType::Float64],
1720        );
1721
1722        let result = agg.next().unwrap().unwrap();
1723        // Sample variance of single value is undefined (null)
1724        assert!(matches!(
1725            result.column(0).unwrap().get_value(0),
1726            Some(Value::Null)
1727        ));
1728    }
1729
1730    #[test]
1731    fn test_empty_aggregation() {
1732        // No input rows: COUNT should be 0, SUM/AVG/MIN/MAX should be NULL
1733        // (ISO/IEC 39075 Section 20.9)
1734        let mock = MockOperator::new(vec![]);
1735
1736        let mut agg = SimpleAggregateOperator::new(
1737            Box::new(mock),
1738            vec![
1739                AggregateExpr::count_star(),
1740                AggregateExpr::sum(0),
1741                AggregateExpr::avg(0),
1742                AggregateExpr::min(0),
1743                AggregateExpr::max(0),
1744            ],
1745            vec![
1746                LogicalType::Int64,
1747                LogicalType::Int64,
1748                LogicalType::Float64,
1749                LogicalType::Int64,
1750                LogicalType::Int64,
1751            ],
1752        );
1753
1754        let result = agg.next().unwrap().unwrap();
1755        assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); // COUNT
1756        assert!(matches!(
1757            result.column(1).unwrap().get_value(0),
1758            Some(Value::Null)
1759        )); // SUM
1760        assert!(matches!(
1761            result.column(2).unwrap().get_value(0),
1762            Some(Value::Null)
1763        )); // AVG
1764        assert!(matches!(
1765            result.column(3).unwrap().get_value(0),
1766            Some(Value::Null)
1767        )); // MIN
1768        assert!(matches!(
1769            result.column(4).unwrap().get_value(0),
1770            Some(Value::Null)
1771        )); // MAX
1772    }
1773
1774    #[test]
1775    fn test_stdev_pop_single_value() {
1776        // Single value should return 0 for population stdev
1777        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1778        builder.column_mut(0).unwrap().push_int64(42);
1779        builder.advance_row();
1780        let chunk = builder.finish();
1781
1782        let mock = MockOperator::new(vec![chunk]);
1783
1784        let mut agg = SimpleAggregateOperator::new(
1785            Box::new(mock),
1786            vec![AggregateExpr::stdev_pop(0)],
1787            vec![LogicalType::Float64],
1788        );
1789
1790        let result = agg.next().unwrap().unwrap();
1791        assert_eq!(result.row_count(), 1);
1792        // Population stdev of single value is 0
1793        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1794        assert!((stdev - 0.0).abs() < 0.01);
1795    }
1796}