Skip to main content

grafeo_core/execution/operators/
aggregate.rs

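//! Aggregation operators for GROUP BY and aggregation functions.
//!
//! This module provides:
//! - `HashAggregateOperator`: hash-based grouping with aggregation functions
//! - `SimpleAggregateOperator`: global aggregation (no GROUP BY) into a single row
//! - Aggregation functions: COUNT, SUM, AVG, MIN, MAX, FIRST, LAST, COLLECT,
//!   STDEV/STDEVP and PERCENTILE_DISC/PERCENTILE_CONT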
6
7use indexmap::IndexMap;
8use std::collections::HashSet;
9
10use grafeo_common::types::{LogicalType, Value};
11
/// Hashable projection of a [`Value`], used to track already-seen values for
/// DISTINCT aggregates.
///
/// `Value` itself cannot be hashed (it contains floats), so each supported
/// variant is mapped onto an `Eq + Hash` representation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
enum HashableValue {
    /// The NULL value.
    Null,
    /// A boolean.
    Bool(bool),
    /// A 64-bit signed integer.
    Int64(i64),
    /// A 64-bit float, identified by its raw IEEE 754 bit pattern.
    Float64Bits(u64),
    /// A string value.
    String(String),
    /// Any other value, identified by its `Debug` rendering.
    Other(String),
}
22
23impl From<&Value> for HashableValue {
24    fn from(v: &Value) -> Self {
25        match v {
26            Value::Null => HashableValue::Null,
27            Value::Bool(b) => HashableValue::Bool(*b),
28            Value::Int64(i) => HashableValue::Int64(*i),
29            Value::Float64(f) => HashableValue::Float64Bits(f.to_bits()),
30            Value::String(s) => HashableValue::String(s.to_string()),
31            other => HashableValue::Other(format!("{other:?}")),
32        }
33    }
34}
35
36impl From<Value> for HashableValue {
37    fn from(v: Value) -> Self {
38        Self::from(&v)
39    }
40}
41
42use super::{Operator, OperatorError, OperatorResult};
43use crate::execution::DataChunk;
44use crate::execution::chunk::DataChunkBuilder;
45
/// The aggregation functions supported by the aggregate operators.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// COUNT(*): number of input rows, NULLs included.
    Count,
    /// COUNT(column): number of non-null values.
    CountNonNull,
    /// SUM: total of the numeric values.
    Sum,
    /// AVG: arithmetic mean of the numeric values.
    Avg,
    /// MIN: smallest value.
    Min,
    /// MAX: largest value.
    Max,
    /// FIRST: first value encountered in the group.
    First,
    /// LAST: last value encountered in the group.
    Last,
    /// COLLECT: gathers the values into a list.
    Collect,
    /// STDEV: sample standard deviation.
    StdDev,
    /// STDEVP: population standard deviation.
    StdDevPop,
    /// PERCENTILE_DISC: discrete percentile (an actual input value).
    PercentileDisc,
    /// PERCENTILE_CONT: continuous percentile (interpolated between values).
    PercentileCont,
}
76
77/// An aggregation expression.
78#[derive(Debug, Clone)]
79pub struct AggregateExpr {
80    /// The aggregation function.
81    pub function: AggregateFunction,
82    /// Column index to aggregate (None for COUNT(*)).
83    pub column: Option<usize>,
84    /// Whether to aggregate distinct values only.
85    pub distinct: bool,
86    /// Output alias (for naming the result column).
87    pub alias: Option<String>,
88    /// Percentile parameter for PERCENTILE_DISC/PERCENTILE_CONT (0.0 to 1.0).
89    pub percentile: Option<f64>,
90}
91
92impl AggregateExpr {
93    /// Creates a COUNT(*) expression.
94    pub fn count_star() -> Self {
95        Self {
96            function: AggregateFunction::Count,
97            column: None,
98            distinct: false,
99            alias: None,
100            percentile: None,
101        }
102    }
103
104    /// Creates a COUNT(column) expression.
105    pub fn count(column: usize) -> Self {
106        Self {
107            function: AggregateFunction::CountNonNull,
108            column: Some(column),
109            distinct: false,
110            alias: None,
111            percentile: None,
112        }
113    }
114
115    /// Creates a SUM(column) expression.
116    pub fn sum(column: usize) -> Self {
117        Self {
118            function: AggregateFunction::Sum,
119            column: Some(column),
120            distinct: false,
121            alias: None,
122            percentile: None,
123        }
124    }
125
126    /// Creates an AVG(column) expression.
127    pub fn avg(column: usize) -> Self {
128        Self {
129            function: AggregateFunction::Avg,
130            column: Some(column),
131            distinct: false,
132            alias: None,
133            percentile: None,
134        }
135    }
136
137    /// Creates a MIN(column) expression.
138    pub fn min(column: usize) -> Self {
139        Self {
140            function: AggregateFunction::Min,
141            column: Some(column),
142            distinct: false,
143            alias: None,
144            percentile: None,
145        }
146    }
147
148    /// Creates a MAX(column) expression.
149    pub fn max(column: usize) -> Self {
150        Self {
151            function: AggregateFunction::Max,
152            column: Some(column),
153            distinct: false,
154            alias: None,
155            percentile: None,
156        }
157    }
158
159    /// Creates a FIRST(column) expression.
160    pub fn first(column: usize) -> Self {
161        Self {
162            function: AggregateFunction::First,
163            column: Some(column),
164            distinct: false,
165            alias: None,
166            percentile: None,
167        }
168    }
169
170    /// Creates a LAST(column) expression.
171    pub fn last(column: usize) -> Self {
172        Self {
173            function: AggregateFunction::Last,
174            column: Some(column),
175            distinct: false,
176            alias: None,
177            percentile: None,
178        }
179    }
180
181    /// Creates a COLLECT(column) expression.
182    pub fn collect(column: usize) -> Self {
183        Self {
184            function: AggregateFunction::Collect,
185            column: Some(column),
186            distinct: false,
187            alias: None,
188            percentile: None,
189        }
190    }
191
192    /// Creates a STDEV(column) expression (sample standard deviation).
193    pub fn stdev(column: usize) -> Self {
194        Self {
195            function: AggregateFunction::StdDev,
196            column: Some(column),
197            distinct: false,
198            alias: None,
199            percentile: None,
200        }
201    }
202
203    /// Creates a STDEVP(column) expression (population standard deviation).
204    pub fn stdev_pop(column: usize) -> Self {
205        Self {
206            function: AggregateFunction::StdDevPop,
207            column: Some(column),
208            distinct: false,
209            alias: None,
210            percentile: None,
211        }
212    }
213
214    /// Creates a PERCENTILE_DISC(column, percentile) expression.
215    ///
216    /// # Arguments
217    /// * `column` - Column index to aggregate
218    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
219    pub fn percentile_disc(column: usize, percentile: f64) -> Self {
220        Self {
221            function: AggregateFunction::PercentileDisc,
222            column: Some(column),
223            distinct: false,
224            alias: None,
225            percentile: Some(percentile.clamp(0.0, 1.0)),
226        }
227    }
228
229    /// Creates a PERCENTILE_CONT(column, percentile) expression.
230    ///
231    /// # Arguments
232    /// * `column` - Column index to aggregate
233    /// * `percentile` - Percentile value between 0.0 and 1.0 (e.g., 0.5 for median)
234    pub fn percentile_cont(column: usize, percentile: f64) -> Self {
235        Self {
236            function: AggregateFunction::PercentileCont,
237            column: Some(column),
238            distinct: false,
239            alias: None,
240            percentile: Some(percentile.clamp(0.0, 1.0)),
241        }
242    }
243
244    /// Sets the distinct flag.
245    pub fn with_distinct(mut self) -> Self {
246        self.distinct = true;
247        self
248    }
249
250    /// Sets the output alias.
251    pub fn with_alias(mut self, alias: impl Into<String>) -> Self {
252        self.alias = Some(alias.into());
253        self
254    }
255}
256
257/// State for a single aggregation computation.
258#[derive(Debug, Clone)]
259enum AggregateState {
260    /// Count state.
261    Count(i64),
262    /// Count distinct state (count, seen values).
263    CountDistinct(i64, HashSet<HashableValue>),
264    /// Sum state (integer).
265    SumInt(i64),
266    /// Sum distinct state (integer, seen values).
267    SumIntDistinct(i64, HashSet<HashableValue>),
268    /// Sum state (float).
269    SumFloat(f64),
270    /// Sum distinct state (float, seen values).
271    SumFloatDistinct(f64, HashSet<HashableValue>),
272    /// Average state (sum, count).
273    Avg(f64, i64),
274    /// Average distinct state (sum, count, seen values).
275    AvgDistinct(f64, i64, HashSet<HashableValue>),
276    /// Min state.
277    Min(Option<Value>),
278    /// Max state.
279    Max(Option<Value>),
280    /// First state.
281    First(Option<Value>),
282    /// Last state.
283    Last(Option<Value>),
284    /// Collect state.
285    Collect(Vec<Value>),
286    /// Collect distinct state (values, seen).
287    CollectDistinct(Vec<Value>, HashSet<HashableValue>),
288    /// Sample standard deviation state using Welford's algorithm (count, mean, M2).
289    StdDev { count: i64, mean: f64, m2: f64 },
290    /// Population standard deviation state using Welford's algorithm (count, mean, M2).
291    StdDevPop { count: i64, mean: f64, m2: f64 },
292    /// Discrete percentile state (values, percentile).
293    PercentileDisc { values: Vec<f64>, percentile: f64 },
294    /// Continuous percentile state (values, percentile).
295    PercentileCont { values: Vec<f64>, percentile: f64 },
296}
297
298impl AggregateState {
299    /// Creates initial state for an aggregation function.
300    fn new(function: AggregateFunction, distinct: bool, percentile: Option<f64>) -> Self {
301        match (function, distinct) {
302            (AggregateFunction::Count | AggregateFunction::CountNonNull, false) => {
303                AggregateState::Count(0)
304            }
305            (AggregateFunction::Count | AggregateFunction::CountNonNull, true) => {
306                AggregateState::CountDistinct(0, HashSet::new())
307            }
308            (AggregateFunction::Sum, false) => AggregateState::SumInt(0),
309            (AggregateFunction::Sum, true) => AggregateState::SumIntDistinct(0, HashSet::new()),
310            (AggregateFunction::Avg, false) => AggregateState::Avg(0.0, 0),
311            (AggregateFunction::Avg, true) => AggregateState::AvgDistinct(0.0, 0, HashSet::new()),
312            (AggregateFunction::Min, _) => AggregateState::Min(None), // MIN/MAX don't need distinct
313            (AggregateFunction::Max, _) => AggregateState::Max(None),
314            (AggregateFunction::First, _) => AggregateState::First(None),
315            (AggregateFunction::Last, _) => AggregateState::Last(None),
316            (AggregateFunction::Collect, false) => AggregateState::Collect(Vec::new()),
317            (AggregateFunction::Collect, true) => {
318                AggregateState::CollectDistinct(Vec::new(), HashSet::new())
319            }
320            // Statistical functions (Welford's algorithm for online computation)
321            (AggregateFunction::StdDev, _) => AggregateState::StdDev {
322                count: 0,
323                mean: 0.0,
324                m2: 0.0,
325            },
326            (AggregateFunction::StdDevPop, _) => AggregateState::StdDevPop {
327                count: 0,
328                mean: 0.0,
329                m2: 0.0,
330            },
331            (AggregateFunction::PercentileDisc, _) => AggregateState::PercentileDisc {
332                values: Vec::new(),
333                percentile: percentile.unwrap_or(0.5),
334            },
335            (AggregateFunction::PercentileCont, _) => AggregateState::PercentileCont {
336                values: Vec::new(),
337                percentile: percentile.unwrap_or(0.5),
338            },
339        }
340    }
341
342    /// Updates the state with a new value.
343    fn update(&mut self, value: Option<Value>) {
344        match self {
345            AggregateState::Count(count) => {
346                *count += 1;
347            }
348            AggregateState::CountDistinct(count, seen) => {
349                if let Some(ref v) = value {
350                    let hashable = HashableValue::from(v);
351                    if seen.insert(hashable) {
352                        *count += 1;
353                    }
354                }
355            }
356            AggregateState::SumInt(sum) => {
357                if let Some(Value::Int64(v)) = value {
358                    *sum += v;
359                } else if let Some(Value::Float64(v)) = value {
360                    // Convert to float sum
361                    *self = AggregateState::SumFloat(*sum as f64 + v);
362                } else if let Some(ref v) = value {
363                    // RDF stores numeric literals as strings - try to parse
364                    if let Some(num) = value_to_f64(v) {
365                        *self = AggregateState::SumFloat(*sum as f64 + num);
366                    }
367                }
368            }
369            AggregateState::SumIntDistinct(sum, seen) => {
370                if let Some(ref v) = value {
371                    let hashable = HashableValue::from(v);
372                    if seen.insert(hashable) {
373                        if let Value::Int64(i) = v {
374                            *sum += i;
375                        } else if let Value::Float64(f) = v {
376                            // Convert to float distinct
377                            let seen_clone = seen.clone();
378                            *self = AggregateState::SumFloatDistinct(*sum as f64 + f, seen_clone);
379                        } else if let Some(num) = value_to_f64(v) {
380                            // RDF string-encoded numerics
381                            let seen_clone = seen.clone();
382                            *self = AggregateState::SumFloatDistinct(*sum as f64 + num, seen_clone);
383                        }
384                    }
385                }
386            }
387            AggregateState::SumFloat(sum) => {
388                if let Some(ref v) = value {
389                    // Use value_to_f64 which now handles strings
390                    if let Some(num) = value_to_f64(v) {
391                        *sum += num;
392                    }
393                }
394            }
395            AggregateState::SumFloatDistinct(sum, seen) => {
396                if let Some(ref v) = value {
397                    let hashable = HashableValue::from(v);
398                    if seen.insert(hashable) {
399                        if let Some(num) = value_to_f64(v) {
400                            *sum += num;
401                        }
402                    }
403                }
404            }
405            AggregateState::Avg(sum, count) => {
406                if let Some(ref v) = value {
407                    if let Some(num) = value_to_f64(v) {
408                        *sum += num;
409                        *count += 1;
410                    }
411                }
412            }
413            AggregateState::AvgDistinct(sum, count, seen) => {
414                if let Some(ref v) = value {
415                    let hashable = HashableValue::from(v);
416                    if seen.insert(hashable) {
417                        if let Some(num) = value_to_f64(v) {
418                            *sum += num;
419                            *count += 1;
420                        }
421                    }
422                }
423            }
424            AggregateState::Min(min) => {
425                if let Some(v) = value {
426                    match min {
427                        None => *min = Some(v),
428                        Some(current) => {
429                            if compare_values(&v, current) == Some(std::cmp::Ordering::Less) {
430                                *min = Some(v);
431                            }
432                        }
433                    }
434                }
435            }
436            AggregateState::Max(max) => {
437                if let Some(v) = value {
438                    match max {
439                        None => *max = Some(v),
440                        Some(current) => {
441                            if compare_values(&v, current) == Some(std::cmp::Ordering::Greater) {
442                                *max = Some(v);
443                            }
444                        }
445                    }
446                }
447            }
448            AggregateState::First(first) => {
449                if first.is_none() {
450                    *first = value;
451                }
452            }
453            AggregateState::Last(last) => {
454                if value.is_some() {
455                    *last = value;
456                }
457            }
458            AggregateState::Collect(list) => {
459                if let Some(v) = value {
460                    list.push(v);
461                }
462            }
463            AggregateState::CollectDistinct(list, seen) => {
464                if let Some(v) = value {
465                    let hashable = HashableValue::from(&v);
466                    if seen.insert(hashable) {
467                        list.push(v);
468                    }
469                }
470            }
471            // Statistical functions using Welford's online algorithm
472            AggregateState::StdDev { count, mean, m2 }
473            | AggregateState::StdDevPop { count, mean, m2 } => {
474                if let Some(ref v) = value {
475                    if let Some(x) = value_to_f64(v) {
476                        *count += 1;
477                        let delta = x - *mean;
478                        *mean += delta / *count as f64;
479                        let delta2 = x - *mean;
480                        *m2 += delta * delta2;
481                    }
482                }
483            }
484            AggregateState::PercentileDisc { values, .. }
485            | AggregateState::PercentileCont { values, .. } => {
486                if let Some(ref v) = value {
487                    if let Some(x) = value_to_f64(v) {
488                        values.push(x);
489                    }
490                }
491            }
492        }
493    }
494
495    /// Finalizes the state and returns the result value.
496    fn finalize(&self) -> Value {
497        match self {
498            AggregateState::Count(count) | AggregateState::CountDistinct(count, _) => {
499                Value::Int64(*count)
500            }
501            AggregateState::SumInt(sum) | AggregateState::SumIntDistinct(sum, _) => {
502                Value::Int64(*sum)
503            }
504            AggregateState::SumFloat(sum) | AggregateState::SumFloatDistinct(sum, _) => {
505                Value::Float64(*sum)
506            }
507            AggregateState::Avg(sum, count) | AggregateState::AvgDistinct(sum, count, _) => {
508                if *count == 0 {
509                    Value::Null
510                } else {
511                    Value::Float64(*sum / *count as f64)
512                }
513            }
514            AggregateState::Min(min) => min.clone().unwrap_or(Value::Null),
515            AggregateState::Max(max) => max.clone().unwrap_or(Value::Null),
516            AggregateState::First(first) => first.clone().unwrap_or(Value::Null),
517            AggregateState::Last(last) => last.clone().unwrap_or(Value::Null),
518            AggregateState::Collect(list) | AggregateState::CollectDistinct(list, _) => {
519                Value::List(list.clone().into())
520            }
521            // Sample standard deviation: sqrt(M2 / (n - 1))
522            AggregateState::StdDev { count, m2, .. } => {
523                if *count < 2 {
524                    Value::Null
525                } else {
526                    Value::Float64((*m2 / (*count - 1) as f64).sqrt())
527                }
528            }
529            // Population standard deviation: sqrt(M2 / n)
530            AggregateState::StdDevPop { count, m2, .. } => {
531                if *count == 0 {
532                    Value::Null
533                } else {
534                    Value::Float64((*m2 / *count as f64).sqrt())
535                }
536            }
537            // Discrete percentile: return actual value at percentile position
538            AggregateState::PercentileDisc { values, percentile } => {
539                if values.is_empty() {
540                    Value::Null
541                } else {
542                    let mut sorted = values.clone();
543                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
544                    // Index calculation per SQL standard: floor(p * (n - 1))
545                    let index = (percentile * (sorted.len() - 1) as f64).floor() as usize;
546                    Value::Float64(sorted[index])
547                }
548            }
549            // Continuous percentile: interpolate between values
550            AggregateState::PercentileCont { values, percentile } => {
551                if values.is_empty() {
552                    Value::Null
553                } else {
554                    let mut sorted = values.clone();
555                    sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
556                    // Linear interpolation per SQL standard
557                    let rank = percentile * (sorted.len() - 1) as f64;
558                    let lower_idx = rank.floor() as usize;
559                    let upper_idx = rank.ceil() as usize;
560                    if lower_idx == upper_idx {
561                        Value::Float64(sorted[lower_idx])
562                    } else {
563                        let fraction = rank - lower_idx as f64;
564                        let result =
565                            sorted[lower_idx] + fraction * (sorted[upper_idx] - sorted[lower_idx]);
566                        Value::Float64(result)
567                    }
568                }
569            }
570        }
571    }
572}
573
574/// Convert a value to f64 for numeric aggregations.
575/// Supports RDF values stored as strings by attempting numeric parsing.
576fn value_to_f64(value: &Value) -> Option<f64> {
577    match value {
578        Value::Int64(i) => Some(*i as f64),
579        Value::Float64(f) => Some(*f),
580        // RDF stores numeric literals as strings - try to parse them
581        Value::String(s) => s.parse::<f64>().ok(),
582        _ => None,
583    }
584}
585
586/// Compare two values.
587/// Supports RDF values stored as strings by attempting numeric parsing.
588fn compare_values(a: &Value, b: &Value) -> Option<std::cmp::Ordering> {
589    match (a, b) {
590        (Value::Int64(a), Value::Int64(b)) => Some(a.cmp(b)),
591        (Value::Float64(a), Value::Float64(b)) => a.partial_cmp(b),
592        (Value::String(a), Value::String(b)) => {
593            // Try numeric comparison first if both look like numbers
594            if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
595                a_num.partial_cmp(&b_num)
596            } else {
597                Some(a.cmp(b))
598            }
599        }
600        (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
601        (Value::Int64(a), Value::Float64(b)) => (*a as f64).partial_cmp(b),
602        (Value::Float64(a), Value::Int64(b)) => a.partial_cmp(&(*b as f64)),
603        // String-to-numeric comparisons for RDF
604        (Value::String(s), Value::Int64(i)) => s.parse::<f64>().ok()?.partial_cmp(&(*i as f64)),
605        (Value::String(s), Value::Float64(f)) => s.parse::<f64>().ok()?.partial_cmp(f),
606        (Value::Int64(i), Value::String(s)) => (*i as f64).partial_cmp(&s.parse::<f64>().ok()?),
607        (Value::Float64(f), Value::String(s)) => f.partial_cmp(&s.parse::<f64>().ok()?),
608        _ => None,
609    }
610}
611
612/// A group key for hash-based aggregation.
613#[derive(Debug, Clone, PartialEq, Eq, Hash)]
614pub struct GroupKey(Vec<GroupKeyPart>);
615
616#[derive(Debug, Clone, PartialEq, Eq, Hash)]
617enum GroupKeyPart {
618    Null,
619    Bool(bool),
620    Int64(i64),
621    String(String),
622}
623
624impl GroupKey {
625    /// Creates a group key from column values.
626    fn from_row(chunk: &DataChunk, row: usize, group_columns: &[usize]) -> Self {
627        let parts: Vec<GroupKeyPart> = group_columns
628            .iter()
629            .map(|&col_idx| {
630                chunk
631                    .column(col_idx)
632                    .and_then(|col| col.get_value(row))
633                    .map(|v| match v {
634                        Value::Null => GroupKeyPart::Null,
635                        Value::Bool(b) => GroupKeyPart::Bool(b),
636                        Value::Int64(i) => GroupKeyPart::Int64(i),
637                        Value::Float64(f) => GroupKeyPart::Int64(f.to_bits() as i64),
638                        Value::String(s) => GroupKeyPart::String(s.to_string()),
639                        _ => GroupKeyPart::String(format!("{v:?}")),
640                    })
641                    .unwrap_or(GroupKeyPart::Null)
642            })
643            .collect();
644        GroupKey(parts)
645    }
646
647    /// Converts the group key back to values.
648    fn to_values(&self) -> Vec<Value> {
649        self.0
650            .iter()
651            .map(|part| match part {
652                GroupKeyPart::Null => Value::Null,
653                GroupKeyPart::Bool(b) => Value::Bool(*b),
654                GroupKeyPart::Int64(i) => Value::Int64(*i),
655                GroupKeyPart::String(s) => Value::String(s.clone().into()),
656            })
657            .collect()
658    }
659}
660
661/// Hash-based aggregate operator.
662///
663/// Groups input by key columns and computes aggregations for each group.
664pub struct HashAggregateOperator {
665    /// Child operator to read from.
666    child: Box<dyn Operator>,
667    /// Columns to group by.
668    group_columns: Vec<usize>,
669    /// Aggregation expressions.
670    aggregates: Vec<AggregateExpr>,
671    /// Output schema.
672    output_schema: Vec<LogicalType>,
673    /// Ordered map: group key -> aggregate states (IndexMap for deterministic iteration order).
674    groups: IndexMap<GroupKey, Vec<AggregateState>>,
675    /// Whether aggregation is complete.
676    aggregation_complete: bool,
677    /// Results iterator.
678    results: Option<std::vec::IntoIter<(GroupKey, Vec<AggregateState>)>>,
679}
680
681impl HashAggregateOperator {
682    /// Creates a new hash aggregate operator.
683    ///
684    /// # Arguments
685    /// * `child` - Child operator to read from.
686    /// * `group_columns` - Column indices to group by.
687    /// * `aggregates` - Aggregation expressions.
688    /// * `output_schema` - Schema of the output (group columns + aggregate results).
689    pub fn new(
690        child: Box<dyn Operator>,
691        group_columns: Vec<usize>,
692        aggregates: Vec<AggregateExpr>,
693        output_schema: Vec<LogicalType>,
694    ) -> Self {
695        Self {
696            child,
697            group_columns,
698            aggregates,
699            output_schema,
700            groups: IndexMap::new(),
701            aggregation_complete: false,
702            results: None,
703        }
704    }
705
706    /// Performs the aggregation.
707    fn aggregate(&mut self) -> Result<(), OperatorError> {
708        while let Some(chunk) = self.child.next()? {
709            for row in chunk.selected_indices() {
710                let key = GroupKey::from_row(&chunk, row, &self.group_columns);
711
712                // Get or create aggregate states for this group
713                let states = self.groups.entry(key).or_insert_with(|| {
714                    self.aggregates
715                        .iter()
716                        .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
717                        .collect()
718                });
719
720                // Update each aggregate
721                for (i, agg) in self.aggregates.iter().enumerate() {
722                    let value = match (agg.function, agg.distinct) {
723                        // COUNT(*) without DISTINCT doesn't need a value
724                        (AggregateFunction::Count, false) => None,
725                        // COUNT DISTINCT needs the actual value to track unique values
726                        (AggregateFunction::Count, true) => agg
727                            .column
728                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
729                        _ => agg
730                            .column
731                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
732                    };
733
734                    // For COUNT without DISTINCT, always update. For others, skip nulls.
735                    match (agg.function, agg.distinct) {
736                        (AggregateFunction::Count, false) => states[i].update(None),
737                        (AggregateFunction::Count, true) => {
738                            // COUNT DISTINCT needs the value to track unique values
739                            if value.is_some() && !matches!(value, Some(Value::Null)) {
740                                states[i].update(value);
741                            }
742                        }
743                        (AggregateFunction::CountNonNull, _) => {
744                            if value.is_some() && !matches!(value, Some(Value::Null)) {
745                                states[i].update(value);
746                            }
747                        }
748                        _ => {
749                            if value.is_some() && !matches!(value, Some(Value::Null)) {
750                                states[i].update(value);
751                            }
752                        }
753                    }
754                }
755            }
756        }
757
758        self.aggregation_complete = true;
759
760        // Convert to results iterator (IndexMap::drain takes a range)
761        let results: Vec<_> = self.groups.drain(..).collect();
762        self.results = Some(results.into_iter());
763
764        Ok(())
765    }
766}
767
768impl Operator for HashAggregateOperator {
769    fn next(&mut self) -> OperatorResult {
770        // Perform aggregation if not done
771        if !self.aggregation_complete {
772            self.aggregate()?;
773        }
774
775        // Special case: no groups (global aggregation with no data)
776        if self.groups.is_empty() && self.results.is_none() && self.group_columns.is_empty() {
777            // For global aggregation (no GROUP BY), return one row with initial values
778            let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
779
780            for agg in &self.aggregates {
781                let state = AggregateState::new(agg.function, agg.distinct, agg.percentile);
782                let value = state.finalize();
783                if let Some(col) = builder.column_mut(self.group_columns.len()) {
784                    col.push_value(value);
785                }
786            }
787            builder.advance_row();
788
789            self.results = Some(Vec::new().into_iter()); // Mark as done
790            return Ok(Some(builder.finish()));
791        }
792
793        let results = match &mut self.results {
794            Some(r) => r,
795            None => return Ok(None),
796        };
797
798        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 2048);
799
800        for (key, states) in results.by_ref() {
801            // Output group key columns
802            let key_values = key.to_values();
803            for (i, value) in key_values.into_iter().enumerate() {
804                if let Some(col) = builder.column_mut(i) {
805                    col.push_value(value);
806                }
807            }
808
809            // Output aggregate results
810            for (i, state) in states.iter().enumerate() {
811                let col_idx = self.group_columns.len() + i;
812                if let Some(col) = builder.column_mut(col_idx) {
813                    col.push_value(state.finalize());
814                }
815            }
816
817            builder.advance_row();
818
819            if builder.is_full() {
820                return Ok(Some(builder.finish()));
821            }
822        }
823
824        if builder.row_count() > 0 {
825            Ok(Some(builder.finish()))
826        } else {
827            Ok(None)
828        }
829    }
830
831    fn reset(&mut self) {
832        self.child.reset();
833        self.groups.clear();
834        self.aggregation_complete = false;
835        self.results = None;
836    }
837
838    fn name(&self) -> &'static str {
839        "HashAggregate"
840    }
841}
842
843/// Simple (non-grouping) aggregate operator for global aggregations.
844///
845/// Used when there's no GROUP BY clause - aggregates all input into a single row.
846pub struct SimpleAggregateOperator {
847    /// Child operator.
848    child: Box<dyn Operator>,
849    /// Aggregation expressions.
850    aggregates: Vec<AggregateExpr>,
851    /// Output schema.
852    output_schema: Vec<LogicalType>,
853    /// Aggregate states.
854    states: Vec<AggregateState>,
855    /// Whether aggregation is complete.
856    done: bool,
857}
858
859impl SimpleAggregateOperator {
860    /// Creates a new simple aggregate operator.
861    pub fn new(
862        child: Box<dyn Operator>,
863        aggregates: Vec<AggregateExpr>,
864        output_schema: Vec<LogicalType>,
865    ) -> Self {
866        let states = aggregates
867            .iter()
868            .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
869            .collect();
870
871        Self {
872            child,
873            aggregates,
874            output_schema,
875            states,
876            done: false,
877        }
878    }
879}
880
881impl Operator for SimpleAggregateOperator {
882    fn next(&mut self) -> OperatorResult {
883        if self.done {
884            return Ok(None);
885        }
886
887        // Process all input
888        while let Some(chunk) = self.child.next()? {
889            for row in chunk.selected_indices() {
890                for (i, agg) in self.aggregates.iter().enumerate() {
891                    let value = match (agg.function, agg.distinct) {
892                        // COUNT(*) without DISTINCT doesn't need a value
893                        (AggregateFunction::Count, false) => None,
894                        // COUNT DISTINCT needs the actual value to track unique values
895                        (AggregateFunction::Count, true) => agg
896                            .column
897                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
898                        _ => agg
899                            .column
900                            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row))),
901                    };
902
903                    match (agg.function, agg.distinct) {
904                        (AggregateFunction::Count, false) => self.states[i].update(None),
905                        (AggregateFunction::Count, true) => {
906                            // COUNT DISTINCT needs the value to track unique values
907                            if value.is_some() && !matches!(value, Some(Value::Null)) {
908                                self.states[i].update(value);
909                            }
910                        }
911                        (AggregateFunction::CountNonNull, _) => {
912                            if value.is_some() && !matches!(value, Some(Value::Null)) {
913                                self.states[i].update(value);
914                            }
915                        }
916                        _ => {
917                            if value.is_some() && !matches!(value, Some(Value::Null)) {
918                                self.states[i].update(value);
919                            }
920                        }
921                    }
922                }
923            }
924        }
925
926        // Output single result row
927        let mut builder = DataChunkBuilder::with_capacity(&self.output_schema, 1);
928
929        for (i, state) in self.states.iter().enumerate() {
930            if let Some(col) = builder.column_mut(i) {
931                col.push_value(state.finalize());
932            }
933        }
934        builder.advance_row();
935
936        self.done = true;
937        Ok(Some(builder.finish()))
938    }
939
940    fn reset(&mut self) {
941        self.child.reset();
942        self.states = self
943            .aggregates
944            .iter()
945            .map(|agg| AggregateState::new(agg.function, agg.distinct, agg.percentile))
946            .collect();
947        self.done = false;
948    }
949
950    fn name(&self) -> &'static str {
951        "SimpleAggregate"
952    }
953}
954
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::chunk::DataChunkBuilder;

    /// Operator double that hands out a fixed list of chunks, one per `next` call.
    ///
    /// Each delivered chunk is moved out and replaced with `DataChunk::empty()`.
    /// As a result, `reset` only rewinds the cursor: replaying after a reset
    /// yields those placeholders rather than the original data.
    struct MockOperator {
        chunks: Vec<DataChunk>,
        position: usize,
    }

    impl MockOperator {
        fn new(chunks: Vec<DataChunk>) -> Self {
            MockOperator {
                chunks,
                position: 0,
            }
        }
    }

    impl Operator for MockOperator {
        fn next(&mut self) -> OperatorResult {
            match self.chunks.get_mut(self.position) {
                Some(slot) => {
                    let chunk = std::mem::replace(slot, DataChunk::empty());
                    self.position += 1;
                    Ok(Some(chunk))
                }
                None => Ok(None),
            }
        }

        fn reset(&mut self) {
            self.position = 0;
        }

        fn name(&self) -> &'static str {
            "Mock"
        }
    }

    /// Builds a two-column `Int64` chunk from `(group, value)` pairs.
    fn int_pair_chunk(rows: &[(i64, i64)]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64, LogicalType::Int64]);
        for &(group, value) in rows {
            builder.column_mut(0).unwrap().push_int64(group);
            builder.column_mut(1).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Builds a single-column `Int64` chunk.
    fn int_column_chunk(values: &[i64]) -> DataChunk {
        let mut builder = DataChunkBuilder::new(&[LogicalType::Int64]);
        for &value in values {
            builder.column_mut(0).unwrap().push_int64(value);
            builder.advance_row();
        }
        builder.finish()
    }

    /// Reads row 0 of column `col` as a float.
    fn float_at(chunk: &DataChunk, col: usize) -> f64 {
        chunk.column(col).unwrap().get_float64(0).unwrap()
    }

    /// Runs a `SimpleAggregateOperator` over one input chunk, asserts that exactly
    /// one row comes back, and returns that chunk.
    fn run_simple(
        input: DataChunk,
        aggregates: Vec<AggregateExpr>,
        schema: Vec<LogicalType>,
    ) -> DataChunk {
        let mut op = SimpleAggregateOperator::new(
            Box::new(MockOperator::new(vec![input])),
            aggregates,
            schema,
        );
        let out = op.next().unwrap().unwrap();
        assert_eq!(out.row_count(), 1);
        out
    }

    /// Runs a `HashAggregateOperator` grouping on column 0 with one `Int64`
    /// aggregate, and returns `(group, aggregate)` pairs sorted by group.
    fn run_grouped_pairs(input: DataChunk, aggregate: AggregateExpr) -> Vec<(i64, i64)> {
        let mut op = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![input])),
            vec![0],
            vec![aggregate],
            vec![LogicalType::Int64, LogicalType::Int64],
        );

        let mut pairs = Vec::new();
        while let Some(out) = op.next().unwrap() {
            for row in out.selected_indices() {
                let group = out.column(0).unwrap().get_int64(row).unwrap();
                let value = out.column(1).unwrap().get_int64(row).unwrap();
                pairs.push((group, value));
            }
        }
        pairs.sort_by_key(|&(group, _)| group);
        pairs
    }

    fn create_test_chunk() -> DataChunk {
        // (group, value): group 1 -> [10, 20], group 2 -> [30, 40, 50]
        int_pair_chunk(&[(1, 10), (1, 20), (2, 30), (2, 40), (2, 50)])
    }

    #[test]
    fn test_simple_count() {
        let mut op = SimpleAggregateOperator::new(
            Box::new(MockOperator::new(vec![create_test_chunk()])),
            vec![AggregateExpr::count_star()],
            vec![LogicalType::Int64],
        );

        let out = op.next().unwrap().unwrap();
        assert_eq!(out.row_count(), 1);
        assert_eq!(out.column(0).unwrap().get_int64(0), Some(5));

        // The single global row is emitted once; afterwards the operator is exhausted.
        assert!(op.next().unwrap().is_none());
    }

    #[test]
    fn test_simple_sum() {
        let out = run_simple(
            create_test_chunk(),
            vec![AggregateExpr::sum(1)],
            vec![LogicalType::Int64],
        );
        // 10 + 20 + 30 + 40 + 50
        assert_eq!(out.column(0).unwrap().get_int64(0), Some(150));
    }

    #[test]
    fn test_simple_avg() {
        let out = run_simple(
            create_test_chunk(),
            vec![AggregateExpr::avg(1)],
            vec![LogicalType::Float64],
        );
        // 150 / 5
        assert!((float_at(&out, 0) - 30.0).abs() < 0.001);
    }

    #[test]
    fn test_simple_min_max() {
        let out = run_simple(
            create_test_chunk(),
            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
            vec![LogicalType::Int64, LogicalType::Int64],
        );
        assert_eq!(out.column(0).unwrap().get_int64(0), Some(10)); // MIN
        assert_eq!(out.column(1).unwrap().get_int64(0), Some(50)); // MAX
    }

    #[test]
    fn test_sum_with_string_values() {
        // Numeric literals stored as strings (as an RDF store keeps them) must still sum.
        let mut builder = DataChunkBuilder::new(&[LogicalType::String]);
        for text in ["30", "25", "35"] {
            builder.column_mut(0).unwrap().push_string(text);
            builder.advance_row();
        }

        let out = run_simple(
            builder.finish(),
            vec![AggregateExpr::sum(0)],
            vec![LogicalType::Float64],
        );
        // 30 + 25 + 35
        let sum_val = float_at(&out, 0);
        assert!(
            (sum_val - 90.0).abs() < 0.001,
            "Expected 90.0, got {}",
            sum_val
        );
    }

    #[test]
    fn test_grouped_aggregation() {
        // GROUP BY column 0, SUM(column 1)
        let pairs = run_grouped_pairs(create_test_chunk(), AggregateExpr::sum(1));
        // Group 1: 10 + 20; group 2: 30 + 40 + 50
        assert_eq!(pairs, vec![(1, 30), (2, 120)]);
    }

    #[test]
    fn test_grouped_count() {
        // GROUP BY column 0, COUNT(*)
        let pairs = run_grouped_pairs(create_test_chunk(), AggregateExpr::count_star());
        assert_eq!(pairs, vec![(1, 2), (2, 3)]);
    }

    #[test]
    fn test_multiple_aggregates() {
        // GROUP BY column 0 with COUNT(*), SUM(column 1), AVG(column 1)
        let mut op = HashAggregateOperator::new(
            Box::new(MockOperator::new(vec![create_test_chunk()])),
            vec![0],
            vec![
                AggregateExpr::count_star(),
                AggregateExpr::sum(1),
                AggregateExpr::avg(1),
            ],
            vec![
                LogicalType::Int64,   // group key
                LogicalType::Int64,   // COUNT
                LogicalType::Int64,   // SUM
                LogicalType::Float64, // AVG
            ],
        );

        let mut rows: Vec<(i64, i64, i64, f64)> = Vec::new();
        while let Some(out) = op.next().unwrap() {
            for row in out.selected_indices() {
                let int_at = |col: usize| out.column(col).unwrap().get_int64(row).unwrap();
                let avg = out.column(3).unwrap().get_float64(row).unwrap();
                rows.push((int_at(0), int_at(1), int_at(2), avg));
            }
        }
        rows.sort_by_key(|r| r.0);
        assert_eq!(rows.len(), 2);

        // Group 1: COUNT=2, SUM=30, AVG=15.0
        let (group, count, sum, avg) = rows[0];
        assert_eq!(group, 1);
        assert_eq!(count, 2);
        assert_eq!(sum, 30);
        assert!((avg - 15.0).abs() < 0.001);

        // Group 2: COUNT=3, SUM=120, AVG=40.0
        let (group, count, sum, avg) = rows[1];
        assert_eq!(group, 2);
        assert_eq!(count, 3);
        assert_eq!(sum, 120);
        assert!((avg - 40.0).abs() < 0.001);
    }

    fn create_test_chunk_with_duplicates() -> DataChunk {
        // Group 1: [10, 10, 20] -> 2 distinct values
        // Group 2: [30, 30, 30] -> 1 distinct value
        int_pair_chunk(&[(1, 10), (1, 10), (1, 20), (2, 30), (2, 30), (2, 30)])
    }

    #[test]
    fn test_count_distinct() {
        // COUNT(DISTINCT column 1)
        let out = run_simple(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::count(1).with_distinct()],
            vec![LogicalType::Int64],
        );
        // Distinct values overall: 10, 20, 30
        assert_eq!(out.column(0).unwrap().get_int64(0), Some(3));
    }

    #[test]
    fn test_grouped_count_distinct() {
        // GROUP BY column 0, COUNT(DISTINCT column 1)
        let pairs = run_grouped_pairs(
            create_test_chunk_with_duplicates(),
            AggregateExpr::count(1).with_distinct(),
        );
        assert_eq!(pairs, vec![(1, 2), (2, 1)]);
    }

    #[test]
    fn test_sum_distinct() {
        // SUM(DISTINCT column 1)
        let out = run_simple(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::sum(1).with_distinct()],
            vec![LogicalType::Int64],
        );
        // 10 + 20 + 30
        assert_eq!(out.column(0).unwrap().get_int64(0), Some(60));
    }

    #[test]
    fn test_avg_distinct() {
        // AVG(DISTINCT column 1)
        let out = run_simple(
            create_test_chunk_with_duplicates(),
            vec![AggregateExpr::avg(1).with_distinct()],
            vec![LogicalType::Float64],
        );
        // (10 + 20 + 30) / 3
        assert!((float_at(&out, 0) - 20.0).abs() < 0.001);
    }

    fn create_statistical_test_chunk() -> DataChunk {
        // [2, 4, 4, 4, 5, 5, 7, 9]: mean 5.0, squared deviations sum to 32,
        // so sample stdev = sqrt(32/7) ~= 2.138 and population stdev = sqrt(32/8) = 2.0
        int_column_chunk(&[2, 4, 4, 4, 5, 5, 7, 9])
    }

    #[test]
    fn test_stdev_sample() {
        let out = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );
        assert!((float_at(&out, 0) - 2.138).abs() < 0.01);
    }

    #[test]
    fn test_stdev_population() {
        let out = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );
        assert!((float_at(&out, 0) - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_percentile_disc() {
        // Median, discrete: index floor(0.5 * 7) = 3 of the sorted data -> 4
        let out = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_disc(0, 0.5)],
            vec![LogicalType::Float64],
        );
        assert!((float_at(&out, 0) - 4.0).abs() < 0.01);
    }

    #[test]
    fn test_percentile_cont() {
        // Median, continuous: rank 0.5 * 7 = 3.5, interpolating 4 and 5 -> 4.5
        let out = run_simple(
            create_statistical_test_chunk(),
            vec![AggregateExpr::percentile_cont(0, 0.5)],
            vec![LogicalType::Float64],
        );
        assert!((float_at(&out, 0) - 4.5).abs() < 0.01);
    }

    #[test]
    fn test_percentile_extremes() {
        // The 0th and 100th percentiles are the minimum and maximum.
        let out = run_simple(
            create_statistical_test_chunk(),
            vec![
                AggregateExpr::percentile_disc(0, 0.0),
                AggregateExpr::percentile_disc(0, 1.0),
            ],
            vec![LogicalType::Float64, LogicalType::Float64],
        );
        assert!((float_at(&out, 0) - 2.0).abs() < 0.01);
        assert!((float_at(&out, 1) - 9.0).abs() < 0.01);
    }

    #[test]
    fn test_stdev_single_value() {
        // Sample stdev of a single value is undefined and must come back as NULL.
        let out = run_simple(
            int_column_chunk(&[42]),
            vec![AggregateExpr::stdev(0)],
            vec![LogicalType::Float64],
        );
        assert!(matches!(
            out.column(0).unwrap().get_value(0),
            Some(Value::Null)
        ));
    }

    #[test]
    fn test_stdev_pop_single_value() {
        // Population stdev of a single value is 0.
        let out = run_simple(
            int_column_chunk(&[42]),
            vec![AggregateExpr::stdev_pop(0)],
            vec![LogicalType::Float64],
        );
        assert!((float_at(&out, 0) - 0.0).abs() < 0.01);
    }
}