Skip to main content

grafeo_core/execution/operators/
aggregate.rs

1//! Aggregation operators for GROUP BY and aggregation functions.
2//!
3//! This module provides:
4//! - [`HashAggregateOperator`]: Hash-based grouping with aggregation functions
5//! - [`SimpleAggregateOperator`]: Global aggregation without GROUP BY
6//!
7//! Shared types ([`AggregateFunction`], [`AggregateExpr`], [`HashableValue`]) live in
8//! the [`super::accumulator`] module.
9
10use indexmap::IndexMap;
11use std::collections::HashSet;
12use std::sync::Arc;
13
14use arcstr::ArcStr;
15use grafeo_common::types::{LogicalType, PropertyKey, Value};
16
17use super::accumulator::{AggregateExpr, AggregateFunction, HashableValue};
18use super::{Operator, OperatorError, OperatorResult};
19use crate::execution::DataChunk;
20use crate::execution::chunk::DataChunkBuilder;
21
/// State for a single aggregation computation.
///
/// Each variant is the running accumulator for one aggregate function. A state is
/// created by `AggregateState::new`, fed values through `update` (or
/// `update_bivariate` for the two-argument functions) and converted into a result
/// [`Value`] by `finalize`.
#[derive(Debug, Clone)]
pub(crate) enum AggregateState {
    /// Count state (number of `update` calls so far).
    Count(i64),
    /// Count distinct state (count, seen values).
    CountDistinct(i64, HashSet<HashableValue>),
    /// Sum state (integer sum, count of values added).
    ///
    /// `update` replaces this with [`AggregateState::SumFloat`] as soon as a
    /// non-integer numeric value is added.
    SumInt(i64, i64),
    /// Sum distinct state (integer sum, count, seen values).
    SumIntDistinct(i64, i64, HashSet<HashableValue>),
    /// Sum state (float sum, Kahan compensation term, count of values added).
    SumFloat(f64, f64, i64),
    /// Sum distinct state (float sum, compensation, count, seen values).
    SumFloatDistinct(f64, f64, i64, HashSet<HashableValue>),
    /// Average state (sum, count).
    Avg(f64, i64),
    /// Average distinct state (sum, count, seen values).
    AvgDistinct(f64, i64, HashSet<HashableValue>),
    /// Min state (smallest value seen so far, `None` before the first value).
    Min(Option<Value>),
    /// Max state (largest value seen so far, `None` before the first value).
    Max(Option<Value>),
    /// First state (first `Some` value passed to `update`).
    First(Option<Value>),
    /// Last state (most recent `Some` value passed to `update`).
    Last(Option<Value>),
    /// Collect state (values in arrival order).
    Collect(Vec<Value>),
    /// Collect distinct state (values, seen).
    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
    StdDev { count: i64, mean: f64, m2: f64 },
    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
    StdDevPop { count: i64, mean: f64, m2: f64 },
    /// Discrete percentile state (values, percentile).
    ///
    /// All numeric inputs are buffered; sorting happens in `finalize`.
    PercentileDisc { values: Vec<f64>, percentile: f64 },
    /// Continuous percentile state (values, percentile).
    ///
    /// All numeric inputs are buffered; sorting and interpolation happen in `finalize`.
    PercentileCont { values: Vec<f64>, percentile: f64 },
    /// GROUP_CONCAT / LISTAGG state (collected string values, separator).
    GroupConcat(Vec<String>, String),
    /// GROUP_CONCAT / LISTAGG distinct state (collected string values, separator, seen).
    GroupConcatDistinct(Vec<String>, String, HashSet<HashableValue>),
    /// SAMPLE state (first `Some` value passed to `update`).
    Sample(Option<Value>),
    /// Sample variance state using Welford's algorithm (count, mean, M2).
    Variance { count: i64, mean: f64, m2: f64 },
    /// Population variance state using Welford's algorithm (count, mean, M2).
    VariancePop { count: i64, mean: f64, m2: f64 },
    /// Two-variable online statistics (Welford generalization for covariance/regression).
    ///
    /// Only `update_bivariate` modifies this state; the single-value `update` is a no-op.
    Bivariate {
        /// Which binary set function this state will finalize to.
        kind: AggregateFunction,
        /// Number of (y, x) pairs accumulated.
        count: i64,
        /// Running mean of the x values.
        mean_x: f64,
        /// Running mean of the y values.
        mean_y: f64,
        /// Running sum of squared deviations of x from its mean.
        m2_x: f64,
        /// Running sum of squared deviations of y from its mean.
        m2_y: f64,
        /// Running co-moment (sum of products of x and y deviations).
        c_xy: f64,
    },
}
83
84impl AggregateState {
85    /// Creates initial state for an aggregation function.
86    pub(crate) fn new(
87        function: AggregateFunction,
88        distinct: bool,
89        percentile: Option<f64>,
90        separator: Option<&str>,
91    ) -> Self {
92        match (function, distinct) {
93            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
94                AggregateState::Count(0)
95            }
96            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
97                AggregateState::CountDistinct(0, HashSet::new())
98            }
99            (AggregateFunction::Sum, false) => AggregateState::SumInt(0, 0),
100            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, 0, HashSet::new()),
101            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
102            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
103            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
104            (AggregateFunction::Max, _) => AggregateState::Max(None),
105            (AggregateFunction::First, _) => AggregateState::First(None),
106            (AggregateFunction::Last, _) => AggregateState::Last(None),
107            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
108            (AggregateFunction::Collect, true) => {
109                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
110            }
111            // Statistical functions (Welford's algorithm for online computation)
112            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
113                count: 0,
114                mean: 0.0,
115                m2: 0.0,
116            },
117            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
118                count: 0,
119                mean: 0.0,
120                m2: 0.0,
121            },
122            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
123                values: Vec::new(),
124                percentile: percentile.unwrap_or(0.5),
125            },
126            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
127                values: Vec::new(),
128                percentile: percentile.unwrap_or(0.5),
129            },
130            (AggregateFunction::GroupConcat, false) => {
131                AggregateState::GroupConcat(Vec::new(), separator.unwrap_or(" ").to_string())
132            }
133            (AggregateFunction::GroupConcat, true) => AggregateState::GroupConcatDistinct(
134                Vec::new(),
135                separator.unwrap_or(" ").to_string(),
136                HashSet::new(),
137            ),
138            (AggregateFunction::Sample, _) => AggregateState::Sample(None),
139            // Binary set functions (all share the same Bivariate state)
140            (
141                AggregateFunction::CovarSamp
142                | AggregateFunction::CovarPop
143                | AggregateFunction::Corr
144                | AggregateFunction::RegrSlope
145                | AggregateFunction::RegrIntercept
146                | AggregateFunction::RegrR2
147                | AggregateFunction::RegrCount
148                | AggregateFunction::RegrSxx
149                | AggregateFunction::RegrSyy
150                | AggregateFunction::RegrSxy
151                | AggregateFunction::RegrAvgx
152                | AggregateFunction::RegrAvgy,
153                _,
154            ) => AggregateState::Bivariate {
155                kind: function,
156                count: 0,
157                mean_x: 0.0,
158                mean_y: 0.0,
159                m2_x: 0.0,
160                m2_y: 0.0,
161                c_xy: 0.0,
162            },
163            (AggregateFunction::Variance, _) => AggregateState::Variance {
164                count: 0,
165                mean: 0.0,
166                m2: 0.0,
167            },
168            (AggregateFunction::VariancePop, _) => AggregateState::VariancePop {
169                count: 0,
170                mean: 0.0,
171                m2: 0.0,
172            },
173        }
174    }
175
176    /// Updates the state with a new value.
177    pub(crate) fn update(&mut self, value: Option<Value>) {
178        match self {
179            AggregateState::Count(count) => {
180                *count += 1;
181            }
182            AggregateState::CountDistinct(count, seen) => {
183                if let Some(ref v) = value {
184                    let hashable = HashableValue::from(v);
185                    if seen.insert(hashable) {
186                        *count += 1;
187                    }
188                }
189            }
190            AggregateState::SumInt(sum, count) => {
191                if let Some(Value::Int64(v)) = value {
192                    *sum += v;
193                    *count += 1;
194                } else if let Some(Value::Float64(v)) = value {
195                    // Convert to float sum, carrying count forward
196                    *self = AggregateState::SumFloat(*sum as f64 + v, 0.0, *count + 1);
197                } else if let Some(ref v) = value {
198                    // RDF stores numeric literals as strings - try to parse
199                    if let Some(num) = value_to_f64(v) {
200                        *self = AggregateState::SumFloat(*sum as f64 + num, 0.0, *count + 1);
201                    }
202                }
203            }
204            AggregateState::SumIntDistinct(sum, count, seen) => {
205                if let Some(ref v) = value {
206                    let hashable = HashableValue::from(v);
207                    if seen.insert(hashable) {
208                        if let Value::Int64(i) = v {
209                            *sum += i;
210                            *count += 1;
211                        } else if let Value::Float64(f) = v {
212                            // Convert to float distinct: move the seen set instead of cloning
213                            let moved_seen = std::mem::take(seen);
214                            *self = AggregateState::SumFloatDistinct(
215                                *sum as f64 + f,
216                                0.0,
217                                *count + 1,
218                                moved_seen,
219                            );
220                        } else if let Some(num) = value_to_f64(v) {
221                            // RDF string-encoded numerics
222                            let moved_seen = std::mem::take(seen);
223                            *self = AggregateState::SumFloatDistinct(
224                                *sum as f64 + num,
225                                0.0,
226                                *count + 1,
227                                moved_seen,
228                            );
229                        }
230                    }
231                }
232            }
233            AggregateState::SumFloat(sum, comp, count) => {
234                if let Some(ref v) = value {
235                    // Use value_to_f64 which now handles strings
236                    if let Some(num) = value_to_f64(v) {
237                        // Kahan compensated summation to reduce rounding error
238                        let y = num - *comp;
239                        let t = *sum + y;
240                        *comp = (t - *sum) - y;
241                        *sum = t;
242                        *count += 1;
243                    }
244                }
245            }
246            AggregateState::SumFloatDistinct(sum, comp, count, seen) => {
247                if let Some(ref v) = value {
248                    let hashable = HashableValue::from(v);
249                    if seen.insert(hashable)
250                        && let Some(num) = value_to_f64(v)
251                    {
252                        let y = num - *comp;
253                        let t = *sum + y;
254                        *comp = (t - *sum) - y;
255                        *sum = t;
256                        *count += 1;
257                    }
258                }
259            }
260            AggregateState::Avg(sum, count) => {
261                if let Some(ref v) = value
262                    && let Some(num) = value_to_f64(v)
263                {
264                    *sum += num;
265                    *count += 1;
266                }
267            }
268            AggregateState::AvgDistinct(sum, count, seen) => {
269                if let Some(ref v) = value {
270                    let hashable = HashableValue::from(v);
271                    if seen.insert(hashable)
272                        && let Some(num) = value_to_f64(v)
273                    {
274                        *sum += num;
275                        *count += 1;
276                    }
277                }
278            }
279            AggregateState::Min(min) => {
280                if let Some(v) = value {
281                    match min {
282                        None => *min = Some(v),
283                        Some(current) => {
284                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
285                                *min = Some(v);
286                            }
287                        }
288                    }
289                }
290            }
291            AggregateState::Max(max) => {
292                if let Some(v) = value {
293                    match max {
294                        None => *max = Some(v),
295                        Some(current) => {
296                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
297                                *max = Some(v);
298                            }
299                        }
300                    }
301                }
302            }
303            AggregateState::First(first) => {
304                if first.is_none() {
305                    *first = value;
306                }
307            }
308            AggregateState::Last(last) => {
309                if value.is_some() {
310                    *last = value;
311                }
312            }
313            AggregateState::Collect(list) => {
314                if let Some(v) = value {
315                    list.push(v);
316                }
317            }
318            AggregateState::CollectDistinct(list, seen) => {
319                if let Some(v) = value {
320                    let hashable = HashableValue::from(&v);
321                    if seen.insert(hashable) {
322                        list.push(v);
323                    }
324                }
325            }
326            // Statistical functions using Welford's online algorithm
327            AggregateState::StdDev { count, mean, m2 }
328            | AggregateState::StdDevPop { count, mean, m2 }
329            | AggregateState::Variance { count, mean, m2 }
330            | AggregateState::VariancePop { count, mean, m2 } => {
331                if let Some(ref v) = value
332                    && let Some(x) = value_to_f64(v)
333                {
334                    *count += 1;
335                    let delta = x - *mean;
336                    *mean += delta / *count as f64;
337                    let delta2 = x - *mean;
338                    *m2 += delta * delta2;
339                }
340            }
341            AggregateState::PercentileDisc { values, .. }
342            | AggregateState::PercentileCont { values, .. } => {
343                if let Some(ref v) = value
344                    && let Some(x) = value_to_f64(v)
345                {
346                    values.push(x);
347                }
348            }
349            AggregateState::GroupConcat(list, _sep) => {
350                if let Some(v) = value {
351                    list.push(agg_value_to_string(&v));
352                }
353            }
354            AggregateState::GroupConcatDistinct(list, _sep, seen) => {
355                if let Some(v) = value {
356                    let hashable = HashableValue::from(&v);
357                    if seen.insert(hashable) {
358                        list.push(agg_value_to_string(&v));
359                    }
360                }
361            }
362            AggregateState::Sample(sample) => {
363                if sample.is_none() {
364                    *sample = value;
365                }
366            }
367            AggregateState::Bivariate { .. } => {
368                // Bivariate functions require two values; use update_bivariate() instead.
369                // Single-value update is a no-op for bivariate state.
370            }
371        }
372    }
373
374    /// Updates a bivariate (two-variable) aggregate state with a pair of values.
375    ///
376    /// Uses the two-variable Welford online algorithm for numerically stable computation
377    /// of covariance and related statistics. Skips the update if either value is null.
378    fn update_bivariate(&mut self, y_val: Option<Value>, x_val: Option<Value>) {
379        if let AggregateState::Bivariate {
380            count,
381            mean_x,
382            mean_y,
383            m2_x,
384            m2_y,
385            c_xy,
386            ..
387        } = self
388        {
389            // Skip if either value is null (SQL semantics: exclude non-pairs)
390            if let (Some(y), Some(x)) = (&y_val, &x_val)
391                && let (Some(y_f), Some(x_f)) = (value_to_f64(y), value_to_f64(x))
392            {
393                *count += 1;
394                let n = *count as f64;
395                let dx = x_f - *mean_x;
396                let dy = y_f - *mean_y;
397                *mean_x += dx / n;
398                *mean_y += dy / n;
399                let dx2 = x_f - *mean_x; // post-update delta
400                let dy2 = y_f - *mean_y; // post-update delta
401                *m2_x += dx * dx2;
402                *m2_y += dy * dy2;
403                *c_xy += dx * dy2;
404            }
405        }
406    }
407
408    /// Finalizes the state and returns the result value.
409    pub(crate) fn finalize(&self) -> Value {
410        match self {
411            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
412                Value::Int64(*count)
413            }
414            AggregateState::SumInt(sum, count) | AggregateState::SumIntDistinct(sum, count, _) => {
415                if *count == 0 {
416                    Value::Null
417                } else {
418                    Value::Int64(*sum)
419                }
420            }
421            AggregateState::SumFloat(sum, _, count)
422            | AggregateState::SumFloatDistinct(sum, _, count, _) => {
423                if *count == 0 {
424                    Value::Null
425                } else {
426                    Value::Float64(*sum)
427                }
428            }
429            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
430                if *count == 0 {
431                    Value::Null
432                } else {
433                    Value::Float64(*sum / *count as f64)
434                }
435            }
436            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
437            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
438            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
439            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
440            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
441                Value::List(list.clone().into())
442            }
443            // Sample standard deviation: sqrt(M2 / (n - 1))
444            AggregateState::StdDev { count, m2, .. } => {
445                if *count < 2 {
446                    Value::Null
447                } else {
448                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
449                }
450            }
451            // Population standard deviation: sqrt(M2 / n)
452            AggregateState::StdDevPop { count, m2, .. } => {
453                if *count == 0 {
454                    Value::Null
455                } else {
456                    Value::Float64((*m2 / *count as f64).sqrt())
457                }
458            }
459            // Sample variance: M2 / (n - 1)
460            AggregateState::Variance { count, m2, .. } => {
461                if *count < 2 {
462                    Value::Null
463                } else {
464                    Value::Float64(*m2 / (*count - 1) as f64)
465                }
466            }
467            // Population variance: M2 / n
468            AggregateState::VariancePop { count, m2, .. } => {
469                if *count == 0 {
470                    Value::Null
471                } else {
472                    Value::Float64(*m2 / *count as f64)
473                }
474            }
475            // Discrete percentile: return actual value at percentile position
476            AggregateState::PercentileDisc { values, percentile } => {
477                if values.is_empty() {
478                    Value::Null
479                } else {
480                    let mut sorted = values.clone();
481                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
482                    // Index calculation per SQL standard: floor(p * (n - 1))
483                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
484                    Value::Float64(sorted[index])
485                }
486            }
487            // Continuous percentile: interpolate between values
488            AggregateState::PercentileCont { values, percentile } => {
489                if values.is_empty() {
490                    Value::Null
491                } else {
492                    let mut sorted = values.clone();
493                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
494                    // Linear interpolation per SQL standard
495                    let rank = percentile * (sorted.len() - 1) as f64;
496                    let lower_idx = rank.floor() as usize;
497                    let upper_idx = rank.ceil() as usize;
498                    if lower_idx == upper_idx {
499                        Value::Float64(sorted[lower_idx])
500                    } else {
501                        let fraction = rank - lower_idx as f64;
502                        let result =
503                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
504                        Value::Float64(result)
505                    }
506                }
507            }
508            // GROUP_CONCAT: join strings with space separator (SPARQL default)
509            AggregateState::GroupConcat(list, sep)
510            | AggregateState::GroupConcatDistinct(list, sep, _) => {
511                Value::String(list.join(sep).into())
512            }
513            // SAMPLE: return the first non-null value seen
514            AggregateState::Sample(sample) => sample.clone().unwrap_or(Value::Null),
515            // Binary set functions: dispatch on kind
516            AggregateState::Bivariate {
517                kind,
518                count,
519                mean_x,
520                mean_y,
521                m2_x,
522                m2_y,
523                c_xy,
524            } => {
525                let n = *count;
526                match kind {
527                    AggregateFunction::CovarSamp => {
528                        if n < 2 {
529                            Value::Null
530                        } else {
531                            Value::Float64(*c_xy / (n - 1) as f64)
532                        }
533                    }
534                    AggregateFunction::CovarPop => {
535                        if n == 0 {
536                            Value::Null
537                        } else {
538                            Value::Float64(*c_xy / n as f64)
539                        }
540                    }
541                    AggregateFunction::Corr => {
542                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
543                            Value::Null
544                        } else {
545                            Value::Float64(*c_xy / (*m2_x * *m2_y).sqrt())
546                        }
547                    }
548                    AggregateFunction::RegrSlope => {
549                        if n == 0 || *m2_x == 0.0 {
550                            Value::Null
551                        } else {
552                            Value::Float64(*c_xy / *m2_x)
553                        }
554                    }
555                    AggregateFunction::RegrIntercept => {
556                        if n == 0 || *m2_x == 0.0 {
557                            Value::Null
558                        } else {
559                            let slope = *c_xy / *m2_x;
560                            Value::Float64(*mean_y - slope * *mean_x)
561                        }
562                    }
563                    AggregateFunction::RegrR2 => {
564                        if n == 0 || *m2_x == 0.0 || *m2_y == 0.0 {
565                            Value::Null
566                        } else {
567                            Value::Float64((*c_xy * *c_xy) / (*m2_x * *m2_y))
568                        }
569                    }
570                    AggregateFunction::RegrCount => Value::Int64(n),
571                    AggregateFunction::RegrSxx => {
572                        if n == 0 {
573                            Value::Null
574                        } else {
575                            Value::Float64(*m2_x)
576                        }
577                    }
578                    AggregateFunction::RegrSyy => {
579                        if n == 0 {
580                            Value::Null
581                        } else {
582                            Value::Float64(*m2_y)
583                        }
584                    }
585                    AggregateFunction::RegrSxy => {
586                        if n == 0 {
587                            Value::Null
588                        } else {
589                            Value::Float64(*c_xy)
590                        }
591                    }
592                    AggregateFunction::RegrAvgx => {
593                        if n == 0 {
594                            Value::Null
595                        } else {
596                            Value::Float64(*mean_x)
597                        }
598                    }
599                    AggregateFunction::RegrAvgy => {
600                        if n == 0 {
601                            Value::Null
602                        } else {
603                            Value::Float64(*mean_y)
604                        }
605                    }
606                    _ => Value::Null, // non-bivariate functions never reach here
607                }
608            }
609        }
610    }
611}
612
613use super::value_utils::{compare_values, value_to_f64};
614
615/// Converts a Value to its string representation for GROUP_CONCAT.
616fn agg_value_to_string(val: &Value) -> String {
617    match val {
618        Value::String(s) => s.to_string(),
619        Value::Int64(i) => i.to_string(),
620        Value::Float64(f) => f.to_string(),
621        Value::Bool(b) => b.to_string(),
622        Value::Null => String::new(),
623        other => format!("{other:?}"),
624    }
625}
626
/// A group key for hash-based aggregation.
///
/// Holds one hashable part per grouping column, in the order the grouping columns
/// were given. Two rows belong to the same group exactly when their keys compare
/// equal.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GroupKey(Vec<GroupKeyPart>);
630
631#[derive(Debug, Clone, PartialEq, Eq, Hash)]
632enum GroupKeyPart {
633    Null,
634    Bool(bool),
635    Int64(i64),
636    String(ArcStr),
637    Bytes(Arc<[u8]>),
638    Date(grafeo_common::types::Date),
639    Time(grafeo_common::types::Time),
640    Timestamp(grafeo_common::types::Timestamp),
641    Duration(grafeo_common::types::Duration),
642    ZonedDatetime(grafeo_common::types::ZonedDatetime),
643    List(Vec<GroupKeyPart>),
644    Map(Vec<(ArcStr, GroupKeyPart)>),
645}
646
647impl GroupKeyPart {
648    fn from_value(v: Value) -> Self {
649        match v {
650            Value::Null => Self::Null,
651            Value::Bool(b) => Self::Bool(b),
652            Value::Int64(i) => Self::Int64(i),
653            Value::Float64(f) => Self::Int64(f.to_bits() as i64),
654            Value::String(s) => Self::String(s.clone()),
655            Value::Bytes(b) => Self::Bytes(b),
656            Value::Date(d) => Self::Date(d),
657            Value::Time(t) => Self::Time(t),
658            Value::Timestamp(ts) => Self::Timestamp(ts),
659            Value::Duration(d) => Self::Duration(d),
660            Value::ZonedDatetime(zdt) => Self::ZonedDatetime(zdt),
661            Value::List(items) => Self::List(items.iter().cloned().map(Self::from_value).collect()),
662            Value::Map(map) => {
663                // BTreeMap already iterates in key order, so this is deterministic
664                let entries: Vec<(ArcStr, GroupKeyPart)> = map
665                    .iter()
666                    .map(|(k, v)| (ArcStr::from(k.as_str()), Self::from_value(v.clone())))
667                    .collect();
668                Self::Map(entries)
669            }
670            // Path, Vector, GCounter, OnCounter: use Debug string as fallback
671            other => Self::String(ArcStr::from(format!("{other:?}"))),
672        }
673    }
674
675    fn to_value(&self) -> Value {
676        match self {
677            Self::Null => Value::Null,
678            Self::Bool(b) => Value::Bool(*b),
679            Self::Int64(i) => Value::Int64(*i),
680            Self::String(s) => Value::String(s.clone()),
681            Self::Bytes(b) => Value::Bytes(Arc::clone(b)),
682            Self::Date(d) => Value::Date(*d),
683            Self::Time(t) => Value::Time(*t),
684            Self::Timestamp(ts) => Value::Timestamp(*ts),
685            Self::Duration(d) => Value::Duration(*d),
686            Self::ZonedDatetime(zdt) => Value::ZonedDatetime(*zdt),
687            Self::List(parts) => {
688                let values: Vec<Value> = parts.iter().map(Self::to_value).collect();
689                Value::List(Arc::from(values.into_boxed_slice()))
690            }
691            Self::Map(entries) => {
692                let map: std::collections::BTreeMap<PropertyKey, Value> = entries
693                    .iter()
694                    .map(|(k, v)| (PropertyKey::new(k.as_str()), v.to_value()))
695                    .collect();
696                Value::Map(Arc::new(map))
697            }
698        }
699    }
700}
701
702impl GroupKey {
703    /// Creates a group key from column values.
704    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
705        let parts: Vec<GroupKeyPart> = group_columns
706            .iter()
707            .map(|&col_idx| {
708                chunk
709                    .column(col_idx)
710                    .and_then(|col| col.get_value(row))
711                    .map_or(GroupKeyPart::Null, GroupKeyPart::from_value)
712            })
713            .collect();
714        GroupKey(parts)
715    }
716
717    /// Converts the group key back to values.
718    fn to_values(&self) -> Vec<Value> {
719        self.0.iter().map(GroupKeyPart::to_value).collect()
720    }
721}
722
/// Hash-based aggregate operator.
///
/// Groups input by key columns and computes aggregations for each group.
/// Rows are read from `child`, keyed by the values of `group_columns`, and each
/// group keeps one [`AggregateState`] per entry of `aggregates`.
pub struct HashAggregateOperator {
    /// Child operator to read from.
    child: Box<dyn Operator>,
    /// Columns to group by (indices into the child's chunks).
    group_columns: Vec<usize>,
    /// Aggregation expressions.
    aggregates: Vec<AggregateExpr>,
    /// Output schema.
    output_schema: Vec<LogicalType>,
    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
    /// Each state vector has one entry per aggregate, in `aggregates` order.
    groups: IndexMap<GroupKey, Vec<AggregateState>>,
    /// Whether aggregation is complete. Starts as `false`.
    aggregation_complete: bool,
    /// Results iterator over owned (group key, states) pairs. Starts as `None`.
    // NOTE(review): presumably filled from `groups` once aggregation finishes - confirm
    // against the part of the operator that produces output.
    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
}
742
743impl HashAggregateOperator {
744    /// Creates a new hash aggregate operator.
745    ///
746    /// # Arguments
747    /// * `child` - Child operator to read from.
748    /// * `group_columns` - Column indices to group by.
749    /// * `aggregates` - Aggregation expressions.
750    /// * `output_schema` - Schema of the output (group columns + aggregate results).
751    pub fn new(
752        child: Box<dyn Operator>,
753        group_columns: Vec<usize>,
754        aggregates: Vec<AggregateExpr>,
755        output_schema: Vec<LogicalType>,
756    ) -> Self {
757        Self {
758            child,
759            group_columns,
760            aggregates,
761            output_schema,
762            groups: IndexMap::new(),
763            aggregation_complete: false,
764            results: None,
765        }
766    }
767
768    /// Performs the aggregation.
769    fn aggregate(&mut self) -> Result<(), OperatorError> {
770        while let Some(chunk) = self.child.next()? {
771            for row in chunk.selected_indices() {
772                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
773
774                // Get or create aggregate states for this group
775                let states = self.groups.entry(key).or_insert_with(|| {
776                    self.aggregates
777                        .iter()
778                        .map(|agg| {
779                            AggregateState::new(
780                                agg.function,
781                                agg.distinct,
782                                agg.percentile,
783                                agg.separator.as_deref(),
784                            )
785                        })
786                        .collect()
787                });
788
789                // Update each aggregate
790                for (i, agg) in self.aggregates.iter().enumerate() {
791                    // Binary set functions: read two column values
792                    if agg.column2.is_some() {
793                        let y_val = agg
794                            .column
795                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
796                        let x_val = agg
797                            .column2
798                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
799                        states[i].update_bivariate(y_val, x_val);
800                        continue;
801                    }
802
803                    let value = match (agg.function, agg.distinct) {
804                        // COUNT(*) without DISTINCT doesn't need a value
805                        (AggregateFunction::Count, false) => None,
806                        // COUNT DISTINCT needs the actual value to track unique values
807                        (AggregateFunction::Count, true) => agg
808                            .column
809                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
810                        _ => agg
811                            .column
812                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
813                    };
814
815                    // For COUNT without DISTINCT, always update. For others, skip nulls.
816                    match (agg.function, agg.distinct) {
817                        (AggregateFunction::Count, false) => states[i].update(None),
818                        (AggregateFunction::Count, true) => {
819                            // COUNT DISTINCT needs the value to track unique values
820                            if value.is_some() && !matches!(value, Some(Value::Null)) {
821                                states[i].update(value);
822                            }
823                        }
824                        (AggregateFunction::CountNonNull, _) => {
825                            if value.is_some() && !matches!(value, Some(Value::Null)) {
826                                states[i].update(value);
827                            }
828                        }
829                        _ => {
830                            if value.is_some() && !matches!(value, Some(Value::Null)) {
831                                states[i].update(value);
832                            }
833                        }
834                    }
835                }
836            }
837        }
838
839        self.aggregation_complete = true;
840
841        // Convert to results iterator (IndexMap::drain takes a range)
842        let results: Vec<_> = self.groups.drain(..).collect();
843        self.results = Some(results.into_iter());
844
845        Ok(())
846    }
847}
848
849impl Operator for HashAggregateOperator {
850    fn next(&mut self) -> OperatorResult {
851        // Perform aggregation if not done
852        if !self.aggregation_complete {
853            self.aggregate()?;
854        }
855
856        // Special case: no groups (global aggregation with no data)
857        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
858            // For global aggregation (no GROUP BY), return one row with initial values
859            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
860
861            for agg in &self.aggregates {
862                let state = AggregateState::new(
863                    agg.function,
864                    agg.distinct,
865                    agg.percentile,
866                    agg.separator.as_deref(),
867                );
868                let value = state.finalize();
869                if let Some(col) = builder.column_mut(self.group_columns.len()) {
870                    col.push_value(value);
871                }
872            }
873            builder.advance_row();
874
875            self.results = Some(Vec::new().into_iter()); // Mark as done
876            return Ok(Some(builder.finish()));
877        }
878
879        let Some(results) = &mut self.results else {
880            return Ok(None);
881        };
882
883        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
884
885        for (key, states) in results.by_ref() {
886            // Output group key columns
887            let key_values = key.to_values();
888            for (i, value) in key_values.into_iter().enumerate() {
889                if let Some(col) = builder.column_mut(i) {
890                    col.push_value(value);
891                }
892            }
893
894            // Output aggregate results
895            for (i, state) in states.iter().enumerate() {
896                let col_idx = self.group_columns.len() + i;
897                if let Some(col) = builder.column_mut(col_idx) {
898                    col.push_value(state.finalize());
899                }
900            }
901
902            builder.advance_row();
903
904            if builder.is_full() {
905                return Ok(Some(builder.finish()));
906            }
907        }
908
909        if builder.row_count() > 0 {
910            Ok(Some(builder.finish()))
911        } else {
912            Ok(None)
913        }
914    }
915
916    fn reset(&mut self) {
917        self.child.reset();
918        self.groups.clear();
919        self.aggregation_complete = false;
920        self.results = None;
921    }
922
923    fn name(&self) -> &'static str {
924        "HashAggregate"
925    }
926}
927
/// Simple (non-grouping) aggregate operator for global aggregations.
///
/// Used when there's no GROUP BY clause - aggregates all input into a single row.
/// The single row is produced on the first call to `next`, also when the child
/// yields no rows at all; later calls return `None` until `reset`.
pub struct SimpleAggregateOperator {
    /// Child operator.
    child: Box<dyn Operator>,
    /// Aggregation expressions, one per output column.
    aggregates: Vec<AggregateExpr>,
    /// Output schema.
    output_schema: Vec<LogicalType>,
    /// Aggregate states, index-aligned with `aggregates`.
    states: Vec<AggregateState>,
    /// Whether aggregation is complete (the result row has been emitted).
    done: bool,
}
943
944impl SimpleAggregateOperator {
945    /// Creates a new simple aggregate operator.
946    pub fn new(
947        child: Box<dyn Operator>,
948        aggregates: Vec<AggregateExpr>,
949        output_schema: Vec<LogicalType>,
950    ) -> Self {
951        let states = aggregates
952            .iter()
953            .map(|agg| {
954                AggregateState::new(
955                    agg.function,
956                    agg.distinct,
957                    agg.percentile,
958                    agg.separator.as_deref(),
959                )
960            })
961            .collect();
962
963        Self {
964            child,
965            aggregates,
966            output_schema,
967            states,
968            done: false,
969        }
970    }
971}
972
973impl Operator for SimpleAggregateOperator {
974    fn next(&mut self) -> OperatorResult {
975        if self.done {
976            return Ok(None);
977        }
978
979        // Process all input
980        while let Some(chunk) = self.child.next()? {
981            for row in chunk.selected_indices() {
982                for (i, agg) in self.aggregates.iter().enumerate() {
983                    // Binary set functions: read two column values
984                    if agg.column2.is_some() {
985                        let y_val = agg
986                            .column
987                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
988                        let x_val = agg
989                            .column2
990                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
991                        self.states[i].update_bivariate(y_val, x_val);
992                        continue;
993                    }
994
995                    let value = match (agg.function, agg.distinct) {
996                        // COUNT(*) without DISTINCT doesn't need a value
997                        (AggregateFunction::Count, false) => None,
998                        // COUNT DISTINCT needs the actual value to track unique values
999                        (AggregateFunction::Count, true) => agg
1000                            .column
1001                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1002                        _ => agg
1003                            .column
1004                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
1005                    };
1006
1007                    match (agg.function, agg.distinct) {
1008                        (AggregateFunction::Count, false) => self.states[i].update(None),
1009                        (AggregateFunction::Count, true) => {
1010                            // COUNT DISTINCT needs the value to track unique values
1011                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1012                                self.states[i].update(value);
1013                            }
1014                        }
1015                        (AggregateFunction::CountNonNull, _) => {
1016                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1017                                self.states[i].update(value);
1018                            }
1019                        }
1020                        _ => {
1021                            if value.is_some() && !matches!(value, Some(Value::Null)) {
1022                                self.states[i].update(value);
1023                            }
1024                        }
1025                    }
1026                }
1027            }
1028        }
1029
1030        // Output single result row
1031        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
1032
1033        for (i, state) in self.states.iter().enumerate() {
1034            if let Some(col) = builder.column_mut(i) {
1035                col.push_value(state.finalize());
1036            }
1037        }
1038        builder.advance_row();
1039
1040        self.done = true;
1041        Ok(Some(builder.finish()))
1042    }
1043
1044    fn reset(&mut self) {
1045        self.child.reset();
1046        self.states = self
1047            .aggregates
1048            .iter()
1049            .map(|agg| {
1050                AggregateState::new(
1051                    agg.function,
1052                    agg.distinct,
1053                    agg.percentile,
1054                    agg.separator.as_deref(),
1055                )
1056            })
1057            .collect();
1058        self.done = false;
1059    }
1060
1061    fn name(&self) -> &'static str {
1062        "SimpleAggregate"
1063    }
1064}
1065
1066#[cfg(test)]
1067mod tests {
1068    use super::*;
1069    use crate::execution::chunk::DataChunkBuilder;
1070
    /// Test double that replays a fixed list of chunks.
    struct MockOperator {
        /// Chunks to hand out, in order. A slot is swapped for an empty chunk
        /// once it has been returned.
        chunks: Vec<DataChunk>,
        /// Index of the next chunk to return.
        position: usize,
    }
1075
1076    impl MockOperator {
1077        fn new(chunks: Vec<DataChunk>) -> Self {
1078            Self {
1079                chunks,
1080                position: 0,
1081            }
1082        }
1083    }
1084
1085    impl Operator for MockOperator {
1086        fn next(&mut self) -> OperatorResult {
1087            if self.position < self.chunks.len() {
1088                let chunk = std::mem::replace(&mut self.chunks[self.position], DataChunk::empty());
1089                self.position += 1;
1090                Ok(Some(chunk))
1091            } else {
1092                Ok(None)
1093            }
1094        }
1095
1096        fn reset(&mut self) {
1097            self.position = 0;
1098        }
1099
1100        fn name(&self) -> &'static str {
1101            "Mock"
1102        }
1103    }
1104
1105    fn create_test_chunk() -> DataChunk {
1106        // Create: [(group, value)] = [(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)]
1107        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1108
1109        let data = [(1i64, 10i64), (1, 20), (2, 30), (2, 40), (2, 50)];
1110        for (group, value) in data {
1111            builder.column_mut(0).unwrap().push_int64(group);
1112            builder.column_mut(1).unwrap().push_int64(value);
1113            builder.advance_row();
1114        }
1115
1116        builder.finish()
1117    }
1118
1119    #[test]
1120    fn test_simple_count() {
1121        let mock = MockOperator::new(vec![create_test_chunk()]);
1122
1123        let mut agg = SimpleAggregateOperator::new(
1124            Box::new(mock),
1125            vec![AggregateExpr::count_star()],
1126            vec![LogicalType::Int64],
1127        );
1128
1129        let result = agg.next().unwrap().unwrap();
1130        assert_eq!(result.row_count(), 1);
1131        assert_eq!(result.column(0).unwrap().get_int64(0), Some(5));
1132
1133        // Should be done
1134        assert!(agg.next().unwrap().is_none());
1135    }
1136
1137    #[test]
1138    fn test_simple_sum() {
1139        let mock = MockOperator::new(vec![create_test_chunk()]);
1140
1141        let mut agg = SimpleAggregateOperator::new(
1142            Box::new(mock),
1143            vec![AggregateExpr::sum(1)], // Sum of column 1
1144            vec![LogicalType::Int64],
1145        );
1146
1147        let result = agg.next().unwrap().unwrap();
1148        assert_eq!(result.row_count(), 1);
1149        // Sum: 10 + 20 + 30 + 40 + 50 = 150
1150        assert_eq!(result.column(0).unwrap().get_int64(0), Some(150));
1151    }
1152
1153    #[test]
1154    fn test_simple_avg() {
1155        let mock = MockOperator::new(vec![create_test_chunk()]);
1156
1157        let mut agg = SimpleAggregateOperator::new(
1158            Box::new(mock),
1159            vec![AggregateExpr::avg(1)],
1160            vec![LogicalType::Float64],
1161        );
1162
1163        let result = agg.next().unwrap().unwrap();
1164        assert_eq!(result.row_count(), 1);
1165        // Avg: 150 / 5 = 30.0
1166        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1167        assert!((avg - 30.0).abs() < 0.001);
1168    }
1169
1170    #[test]
1171    fn test_simple_min_max() {
1172        let mock = MockOperator::new(vec![create_test_chunk()]);
1173
1174        let mut agg = SimpleAggregateOperator::new(
1175            Box::new(mock),
1176            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1177            vec![LogicalType::Int64, LogicalType::Int64],
1178        );
1179
1180        let result = agg.next().unwrap().unwrap();
1181        assert_eq!(result.row_count(), 1);
1182        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10)); // Min
1183        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50)); // Max
1184    }
1185
1186    #[test]
1187    fn test_sum_with_string_values() {
1188        // Test SUM with string values (like RDF stores numeric literals)
1189        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1190        builder.column_mut(0).unwrap().push_string("30");
1191        builder.advance_row();
1192        builder.column_mut(0).unwrap().push_string("25");
1193        builder.advance_row();
1194        builder.column_mut(0).unwrap().push_string("35");
1195        builder.advance_row();
1196        let chunk = builder.finish();
1197
1198        let mock = MockOperator::new(vec![chunk]);
1199        let mut agg = SimpleAggregateOperator::new(
1200            Box::new(mock),
1201            vec![AggregateExpr::sum(0)],
1202            vec![LogicalType::Float64],
1203        );
1204
1205        let result = agg.next().unwrap().unwrap();
1206        assert_eq!(result.row_count(), 1);
1207        // Should parse strings and sum: 30 + 25 + 35 = 90
1208        let sum_val = result.column(0).unwrap().get_float64(0).unwrap();
1209        assert!(
1210            (sum_val - 90.0).abs() < 0.001,
1211            "Expected 90.0, got {}",
1212            sum_val
1213        );
1214    }
1215
1216    #[test]
1217    fn test_grouped_aggregation() {
1218        let mock = MockOperator::new(vec![create_test_chunk()]);
1219
1220        // GROUP BY column 0, SUM(column 1)
1221        let mut agg = HashAggregateOperator::new(
1222            Box::new(mock),
1223            vec![0],                     // Group by column 0
1224            vec![AggregateExpr::sum(1)], // Sum of column 1
1225            vec![LogicalType::Int64, LogicalType::Int64],
1226        );
1227
1228        let mut results: Vec<(i64, i64)> = Vec::new();
1229        while let Some(chunk) = agg.next().unwrap() {
1230            for row in chunk.selected_indices() {
1231                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1232                let sum = chunk.column(1).unwrap().get_int64(row).unwrap();
1233                results.push((group, sum));
1234            }
1235        }
1236
1237        results.sort_by_key(|(g, _)| *g);
1238        assert_eq!(results.len(), 2);
1239        assert_eq!(results[0], (1, 30)); // Group 1: 10 + 20 = 30
1240        assert_eq!(results[1], (2, 120)); // Group 2: 30 + 40 + 50 = 120
1241    }
1242
1243    #[test]
1244    fn test_grouped_count() {
1245        let mock = MockOperator::new(vec![create_test_chunk()]);
1246
1247        // GROUP BY column 0, COUNT(*)
1248        let mut agg = HashAggregateOperator::new(
1249            Box::new(mock),
1250            vec![0],
1251            vec![AggregateExpr::count_star()],
1252            vec![LogicalType::Int64, LogicalType::Int64],
1253        );
1254
1255        let mut results: Vec<(i64, i64)> = Vec::new();
1256        while let Some(chunk) = agg.next().unwrap() {
1257            for row in chunk.selected_indices() {
1258                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1259                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1260                results.push((group, count));
1261            }
1262        }
1263
1264        results.sort_by_key(|(g, _)| *g);
1265        assert_eq!(results.len(), 2);
1266        assert_eq!(results[0], (1, 2)); // Group 1: 2 rows
1267        assert_eq!(results[1], (2, 3)); // Group 2: 3 rows
1268    }
1269
1270    #[test]
1271    fn test_multiple_aggregates() {
1272        let mock = MockOperator::new(vec![create_test_chunk()]);
1273
1274        // GROUP BY column 0, COUNT(*), SUM(column 1), AVG(column 1)
1275        let mut agg = HashAggregateOperator::new(
1276            Box::new(mock),
1277            vec![0],
1278            vec![
1279                AggregateExpr::count_star(),
1280                AggregateExpr::sum(1),
1281                AggregateExpr::avg(1),
1282            ],
1283            vec![
1284                LogicalType::Int64,   // Group key
1285                LogicalType::Int64,   // COUNT
1286                LogicalType::Int64,   // SUM
1287                LogicalType::Float64, // AVG
1288            ],
1289        );
1290
1291        let mut results: Vec<(i64, i64, i64, f64)> = Vec::new();
1292        while let Some(chunk) = agg.next().unwrap() {
1293            for row in chunk.selected_indices() {
1294                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1295                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1296                let sum = chunk.column(2).unwrap().get_int64(row).unwrap();
1297                let avg = chunk.column(3).unwrap().get_float64(row).unwrap();
1298                results.push((group, count, sum, avg));
1299            }
1300        }
1301
1302        results.sort_by_key(|(g, _, _, _)| *g);
1303        assert_eq!(results.len(), 2);
1304
1305        // Group 1: COUNT=2, SUM=30, AVG=15.0
1306        assert_eq!(results[0].0, 1);
1307        assert_eq!(results[0].1, 2);
1308        assert_eq!(results[0].2, 30);
1309        assert!((results[0].3 - 15.0).abs() < 0.001);
1310
1311        // Group 2: COUNT=3, SUM=120, AVG=40.0
1312        assert_eq!(results[1].0, 2);
1313        assert_eq!(results[1].1, 3);
1314        assert_eq!(results[1].2, 120);
1315        assert!((results[1].3 - 40.0).abs() < 0.001);
1316    }
1317
1318    fn create_test_chunk_with_duplicates() -> DataChunk {
1319        // Create data with duplicate values in column 1
1320        // [(group, value)] = [(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)]
1321        // GROUP 1: values [10, 10, 20] -> distinct count = 2
1322        // GROUP 2: values [30, 30, 30] -> distinct count = 1
1323        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
1324
1325        let data = [(1i64, 10i64), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)];
1326        for (group, value) in data {
1327            builder.column_mut(0).unwrap().push_int64(group);
1328            builder.column_mut(1).unwrap().push_int64(value);
1329            builder.advance_row();
1330        }
1331
1332        builder.finish()
1333    }
1334
1335    #[test]
1336    fn test_count_distinct() {
1337        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1338
1339        // COUNT(DISTINCT column 1)
1340        let mut agg = SimpleAggregateOperator::new(
1341            Box::new(mock),
1342            vec![AggregateExpr::count(1).with_distinct()],
1343            vec![LogicalType::Int64],
1344        );
1345
1346        let result = agg.next().unwrap().unwrap();
1347        assert_eq!(result.row_count(), 1);
1348        // Total distinct values: 10, 20, 30 = 3 distinct values
1349        assert_eq!(result.column(0).unwrap().get_int64(0), Some(3));
1350    }
1351
1352    #[test]
1353    fn test_grouped_count_distinct() {
1354        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1355
1356        // GROUP BY column 0, COUNT(DISTINCT column 1)
1357        let mut agg = HashAggregateOperator::new(
1358            Box::new(mock),
1359            vec![0],
1360            vec![AggregateExpr::count(1).with_distinct()],
1361            vec![LogicalType::Int64, LogicalType::Int64],
1362        );
1363
1364        let mut results: Vec<(i64, i64)> = Vec::new();
1365        while let Some(chunk) = agg.next().unwrap() {
1366            for row in chunk.selected_indices() {
1367                let group = chunk.column(0).unwrap().get_int64(row).unwrap();
1368                let count = chunk.column(1).unwrap().get_int64(row).unwrap();
1369                results.push((group, count));
1370            }
1371        }
1372
1373        results.sort_by_key(|(g, _)| *g);
1374        assert_eq!(results.len(), 2);
1375        assert_eq!(results[0], (1, 2)); // Group 1: [10, 10, 20] -> 2 distinct values
1376        assert_eq!(results[1], (2, 1)); // Group 2: [30, 30, 30] -> 1 distinct value
1377    }
1378
1379    #[test]
1380    fn test_sum_distinct() {
1381        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1382
1383        // SUM(DISTINCT column 1)
1384        let mut agg = SimpleAggregateOperator::new(
1385            Box::new(mock),
1386            vec![AggregateExpr::sum(1).with_distinct()],
1387            vec![LogicalType::Int64],
1388        );
1389
1390        let result = agg.next().unwrap().unwrap();
1391        assert_eq!(result.row_count(), 1);
1392        // Sum of distinct values: 10 + 20 + 30 = 60
1393        assert_eq!(result.column(0).unwrap().get_int64(0), Some(60));
1394    }
1395
1396    #[test]
1397    fn test_avg_distinct() {
1398        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1399
1400        // AVG(DISTINCT column 1)
1401        let mut agg = SimpleAggregateOperator::new(
1402            Box::new(mock),
1403            vec![AggregateExpr::avg(1).with_distinct()],
1404            vec![LogicalType::Float64],
1405        );
1406
1407        let result = agg.next().unwrap().unwrap();
1408        assert_eq!(result.row_count(), 1);
1409        // Avg of distinct values: (10 + 20 + 30) / 3 = 20.0
1410        let avg = result.column(0).unwrap().get_float64(0).unwrap();
1411        assert!((avg - 20.0).abs() < 0.001);
1412    }
1413
1414    fn create_statistical_test_chunk() -> DataChunk {
1415        // Create data: [2, 4, 4, 4, 5, 5, 7, 9]
1416        // Mean = 5.0, Sample StdDev = 2.138, Population StdDev = 2.0
1417        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1418
1419        for value in [2i64, 4, 4, 4, 5, 5, 7, 9] {
1420            builder.column_mut(0).unwrap().push_int64(value);
1421            builder.advance_row();
1422        }
1423
1424        builder.finish()
1425    }
1426
1427    #[test]
1428    fn test_stdev_sample() {
1429        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1430
1431        let mut agg = SimpleAggregateOperator::new(
1432            Box::new(mock),
1433            vec![AggregateExpr::stdev(0)],
1434            vec![LogicalType::Float64],
1435        );
1436
1437        let result = agg.next().unwrap().unwrap();
1438        assert_eq!(result.row_count(), 1);
1439        // Sample standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1440        // Mean = 5.0, Variance = 32/7 = 4.571, StdDev = 2.138
1441        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1442        assert!((stdev - 2.138).abs() < 0.01);
1443    }
1444
1445    #[test]
1446    fn test_stdev_population() {
1447        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1448
1449        let mut agg = SimpleAggregateOperator::new(
1450            Box::new(mock),
1451            vec![AggregateExpr::stdev_pop(0)],
1452            vec![LogicalType::Float64],
1453        );
1454
1455        let result = agg.next().unwrap().unwrap();
1456        assert_eq!(result.row_count(), 1);
1457        // Population standard deviation of [2, 4, 4, 4, 5, 5, 7, 9]
1458        // Mean = 5.0, Variance = 32/8 = 4.0, StdDev = 2.0
1459        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1460        assert!((stdev - 2.0).abs() < 0.01);
1461    }
1462
1463    #[test]
1464    fn test_percentile_disc() {
1465        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1466
1467        // Median (50th percentile discrete)
1468        let mut agg = SimpleAggregateOperator::new(
1469            Box::new(mock),
1470            vec![AggregateExpr::percentile_disc(0, 0.5)],
1471            vec![LogicalType::Float64],
1472        );
1473
1474        let result = agg.next().unwrap().unwrap();
1475        assert_eq!(result.row_count(), 1);
1476        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], index = floor(0.5 * 7) = 3, value = 4
1477        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1478        assert!((percentile - 4.0).abs() < 0.01);
1479    }
1480
1481    #[test]
1482    fn test_percentile_cont() {
1483        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1484
1485        // Median (50th percentile continuous)
1486        let mut agg = SimpleAggregateOperator::new(
1487            Box::new(mock),
1488            vec![AggregateExpr::percentile_cont(0, 0.5)],
1489            vec![LogicalType::Float64],
1490        );
1491
1492        let result = agg.next().unwrap().unwrap();
1493        assert_eq!(result.row_count(), 1);
1494        // Sorted: [2, 4, 4, 4, 5, 5, 7, 9], rank = 0.5 * 7 = 3.5
1495        // Interpolate between index 3 (4) and index 4 (5): 4 + 0.5 * (5 - 4) = 4.5
1496        let percentile = result.column(0).unwrap().get_float64(0).unwrap();
1497        assert!((percentile - 4.5).abs() < 0.01);
1498    }
1499
1500    #[test]
1501    fn test_percentile_extremes() {
1502        // Test 0th and 100th percentiles
1503        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1504
1505        let mut agg = SimpleAggregateOperator::new(
1506            Box::new(mock),
1507            vec![
1508                AggregateExpr::percentile_disc(0, 0.0),
1509                AggregateExpr::percentile_disc(0, 1.0),
1510            ],
1511            vec![LogicalType::Float64, LogicalType::Float64],
1512        );
1513
1514        let result = agg.next().unwrap().unwrap();
1515        assert_eq!(result.row_count(), 1);
1516        // 0th percentile = minimum = 2
1517        let p0 = result.column(0).unwrap().get_float64(0).unwrap();
1518        assert!((p0 - 2.0).abs() < 0.01);
1519        // 100th percentile = maximum = 9
1520        let p100 = result.column(1).unwrap().get_float64(0).unwrap();
1521        assert!((p100 - 9.0).abs() < 0.01);
1522    }
1523
1524    #[test]
1525    fn test_stdev_single_value() {
1526        // Single value should return null for sample stdev
1527        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1528        builder.column_mut(0).unwrap().push_int64(42);
1529        builder.advance_row();
1530        let chunk = builder.finish();
1531
1532        let mock = MockOperator::new(vec![chunk]);
1533
1534        let mut agg = SimpleAggregateOperator::new(
1535            Box::new(mock),
1536            vec![AggregateExpr::stdev(0)],
1537            vec![LogicalType::Float64],
1538        );
1539
1540        let result = agg.next().unwrap().unwrap();
1541        assert_eq!(result.row_count(), 1);
1542        // Sample stdev of single value is undefined (null)
1543        assert!(matches!(
1544            result.column(0).unwrap().get_value(0),
1545            Some(Value::Null)
1546        ));
1547    }
1548
1549    #[test]
1550    fn test_first_and_last() {
1551        let mock = MockOperator::new(vec![create_test_chunk()]);
1552
1553        let mut agg = SimpleAggregateOperator::new(
1554            Box::new(mock),
1555            vec![AggregateExpr::first(1), AggregateExpr::last(1)],
1556            vec![LogicalType::Int64, LogicalType::Int64],
1557        );
1558
1559        let result = agg.next().unwrap().unwrap();
1560        assert_eq!(result.row_count(), 1);
1561        // First: 10, Last: 50
1562        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1563        assert_eq!(result.column(1).unwrap().get_int64(0), Some(50));
1564    }
1565
1566    #[test]
1567    fn test_collect() {
1568        let mock = MockOperator::new(vec![create_test_chunk()]);
1569
1570        let mut agg = SimpleAggregateOperator::new(
1571            Box::new(mock),
1572            vec![AggregateExpr::collect(1)],
1573            vec![LogicalType::Any],
1574        );
1575
1576        let result = agg.next().unwrap().unwrap();
1577        let val = result.column(0).unwrap().get_value(0).unwrap();
1578        if let Value::List(items) = val {
1579            assert_eq!(items.len(), 5);
1580        } else {
1581            panic!("Expected List value");
1582        }
1583    }
1584
1585    #[test]
1586    fn test_collect_distinct() {
1587        let mock = MockOperator::new(vec![create_test_chunk_with_duplicates()]);
1588
1589        let mut agg = SimpleAggregateOperator::new(
1590            Box::new(mock),
1591            vec![AggregateExpr::collect(1).with_distinct()],
1592            vec![LogicalType::Any],
1593        );
1594
1595        let result = agg.next().unwrap().unwrap();
1596        let val = result.column(0).unwrap().get_value(0).unwrap();
1597        if let Value::List(items) = val {
1598            // [10, 10, 20, 30, 30, 30] -> distinct: [10, 20, 30]
1599            assert_eq!(items.len(), 3);
1600        } else {
1601            panic!("Expected List value");
1602        }
1603    }
1604
1605    #[test]
1606    fn test_group_concat() {
1607        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
1608        for s in ["hello", "world", "foo"] {
1609            builder.column_mut(0).unwrap().push_string(s);
1610            builder.advance_row();
1611        }
1612        let chunk = builder.finish();
1613        let mock = MockOperator::new(vec![chunk]);
1614
1615        let agg_expr = AggregateExpr {
1616            function: AggregateFunction::GroupConcat,
1617            column: Some(0),
1618            column2: None,
1619            distinct: false,
1620            alias: None,
1621            percentile: None,
1622            separator: None,
1623        };
1624
1625        let mut agg =
1626            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::String]);
1627
1628        let result = agg.next().unwrap().unwrap();
1629        let val = result.column(0).unwrap().get_value(0).unwrap();
1630        assert_eq!(val, Value::String("hello world foo".into()));
1631    }
1632
1633    #[test]
1634    fn test_sample() {
1635        let mock = MockOperator::new(vec![create_test_chunk()]);
1636
1637        let agg_expr = AggregateExpr {
1638            function: AggregateFunction::Sample,
1639            column: Some(1),
1640            column2: None,
1641            distinct: false,
1642            alias: None,
1643            percentile: None,
1644            separator: None,
1645        };
1646
1647        let mut agg =
1648            SimpleAggregateOperator::new(Box::new(mock), vec![agg_expr], vec![LogicalType::Int64]);
1649
1650        let result = agg.next().unwrap().unwrap();
1651        // Sample should return the first non-null value (10)
1652        assert_eq!(result.column(0).unwrap().get_int64(0), Some(10));
1653    }
1654
1655    #[test]
1656    fn test_variance_sample() {
1657        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1658
1659        let agg_expr = AggregateExpr {
1660            function: AggregateFunction::Variance,
1661            column: Some(0),
1662            column2: None,
1663            distinct: false,
1664            alias: None,
1665            percentile: None,
1666            separator: None,
1667        };
1668
1669        let mut agg = SimpleAggregateOperator::new(
1670            Box::new(mock),
1671            vec![agg_expr],
1672            vec![LogicalType::Float64],
1673        );
1674
1675        let result = agg.next().unwrap().unwrap();
1676        // Sample variance of [2, 4, 4, 4, 5, 5, 7, 9]: M2/(n-1) = 32/7 = 4.571
1677        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1678        assert!((variance - 32.0 / 7.0).abs() < 0.01);
1679    }
1680
1681    #[test]
1682    fn test_variance_population() {
1683        let mock = MockOperator::new(vec![create_statistical_test_chunk()]);
1684
1685        let agg_expr = AggregateExpr {
1686            function: AggregateFunction::VariancePop,
1687            column: Some(0),
1688            column2: None,
1689            distinct: false,
1690            alias: None,
1691            percentile: None,
1692            separator: None,
1693        };
1694
1695        let mut agg = SimpleAggregateOperator::new(
1696            Box::new(mock),
1697            vec![agg_expr],
1698            vec![LogicalType::Float64],
1699        );
1700
1701        let result = agg.next().unwrap().unwrap();
1702        // Population variance: M2/n = 32/8 = 4.0
1703        let variance = result.column(0).unwrap().get_float64(0).unwrap();
1704        assert!((variance - 4.0).abs() < 0.01);
1705    }
1706
1707    #[test]
1708    fn test_variance_single_value() {
1709        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1710        builder.column_mut(0).unwrap().push_int64(42);
1711        builder.advance_row();
1712        let chunk = builder.finish();
1713        let mock = MockOperator::new(vec![chunk]);
1714
1715        let agg_expr = AggregateExpr {
1716            function: AggregateFunction::Variance,
1717            column: Some(0),
1718            column2: None,
1719            distinct: false,
1720            alias: None,
1721            percentile: None,
1722            separator: None,
1723        };
1724
1725        let mut agg = SimpleAggregateOperator::new(
1726            Box::new(mock),
1727            vec![agg_expr],
1728            vec![LogicalType::Float64],
1729        );
1730
1731        let result = agg.next().unwrap().unwrap();
1732        // Sample variance of single value is undefined (null)
1733        assert!(matches!(
1734            result.column(0).unwrap().get_value(0),
1735            Some(Value::Null)
1736        ));
1737    }
1738
1739    #[test]
1740    fn test_empty_aggregation() {
1741        // No input rows: COUNT should be 0, SUM/AVG/MIN/MAX should be NULL
1742        // (ISO/IEC 39075 Section 20.9)
1743        let mock = MockOperator::new(vec![]);
1744
1745        let mut agg = SimpleAggregateOperator::new(
1746            Box::new(mock),
1747            vec![
1748                AggregateExpr::count_star(),
1749                AggregateExpr::sum(0),
1750                AggregateExpr::avg(0),
1751                AggregateExpr::min(0),
1752                AggregateExpr::max(0),
1753            ],
1754            vec![
1755                LogicalType::Int64,
1756                LogicalType::Int64,
1757                LogicalType::Float64,
1758                LogicalType::Int64,
1759                LogicalType::Int64,
1760            ],
1761        );
1762
1763        let result = agg.next().unwrap().unwrap();
1764        assert_eq!(result.column(0).unwrap().get_int64(0), Some(0)); // COUNT
1765        assert!(matches!(
1766            result.column(1).unwrap().get_value(0),
1767            Some(Value::Null)
1768        )); // SUM
1769        assert!(matches!(
1770            result.column(2).unwrap().get_value(0),
1771            Some(Value::Null)
1772        )); // AVG
1773        assert!(matches!(
1774            result.column(3).unwrap().get_value(0),
1775            Some(Value::Null)
1776        )); // MIN
1777        assert!(matches!(
1778            result.column(4).unwrap().get_value(0),
1779            Some(Value::Null)
1780        )); // MAX
1781    }
1782
1783    #[test]
1784    fn test_stdev_pop_single_value() {
1785        // Single value should return 0 for population stdev
1786        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
1787        builder.column_mut(0).unwrap().push_int64(42);
1788        builder.advance_row();
1789        let chunk = builder.finish();
1790
1791        let mock = MockOperator::new(vec![chunk]);
1792
1793        let mut agg = SimpleAggregateOperator::new(
1794            Box::new(mock),
1795            vec![AggregateExpr::stdev_pop(0)],
1796            vec![LogicalType::Float64],
1797        );
1798
1799        let result = agg.next().unwrap().unwrap();
1800        assert_eq!(result.row_count(), 1);
1801        // Population stdev of single value is 0
1802        let stdev = result.column(0).unwrap().get_float64(0).unwrap();
1803        assert!((stdev - 0.0).abs() < 0.01);
1804    }
1805}