// grafeo_core/execution/operators/push/aggregate.rs
1//! Push-based aggregate operator (pipeline breaker).
2
3use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
6#[cfg(feature = "spill")]
7use crate::execution::spill::{PartitionedState, SpillManager};
8use crate::execution::vector::ValueVector;
9use grafeo_common::types::Value;
10use std::collections::HashMap;
11#[cfg(feature = "spill")]
12use std::io::{Read, Write};
13#[cfg(feature = "spill")]
14use std::sync::Arc;
15
/// Aggregation function type.
///
/// Whether COUNT counts all rows (`COUNT(*)`) or only non-null values of a
/// column is decided by `AggregateExpr::column`, not by this enum.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Count rows or non-null values.
    Count,
    /// Sum of values.
    Sum,
    /// Minimum value.
    Min,
    /// Maximum value.
    Max,
    /// Average value.
    Avg,
    /// First value in group.
    First,
}
32
/// Aggregate expression.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub function: AggregateFunction,
    /// Column index to aggregate (None for COUNT(*)).
    pub column: Option<usize>,
    /// Whether DISTINCT applies.
    ///
    /// NOTE(review): `distinct` is not consulted by the operators in this
    /// file — confirm it is enforced elsewhere (e.g. by the planner).
    pub distinct: bool,
}
43
44impl AggregateExpr {
45    /// Create a COUNT(*) expression.
46    pub fn count_star() -> Self {
47        Self {
48            function: AggregateFunction::Count,
49            column: None,
50            distinct: false,
51        }
52    }
53
54    /// Create a COUNT(column) expression.
55    pub fn count(column: usize) -> Self {
56        Self {
57            function: AggregateFunction::Count,
58            column: Some(column),
59            distinct: false,
60        }
61    }
62
63    /// Create a SUM(column) expression.
64    pub fn sum(column: usize) -> Self {
65        Self {
66            function: AggregateFunction::Sum,
67            column: Some(column),
68            distinct: false,
69        }
70    }
71
72    /// Create a MIN(column) expression.
73    pub fn min(column: usize) -> Self {
74        Self {
75            function: AggregateFunction::Min,
76            column: Some(column),
77            distinct: false,
78        }
79    }
80
81    /// Create a MAX(column) expression.
82    pub fn max(column: usize) -> Self {
83        Self {
84            function: AggregateFunction::Max,
85            column: Some(column),
86            distinct: false,
87        }
88    }
89
90    /// Create an AVG(column) expression.
91    pub fn avg(column: usize) -> Self {
92        Self {
93            function: AggregateFunction::Avg,
94            column: Some(column),
95            distinct: false,
96        }
97    }
98}
99
/// Accumulator for aggregate state.
///
/// A single accumulator tracks every statistic at once (count, sum, min,
/// max, first); `finalize` picks the one requested by the aggregate
/// function. This keeps the per-row update loop independent of the
/// function type at the cost of some extra bookkeeping per value.
#[derive(Debug, Clone, Default)]
struct Accumulator {
    // Number of non-null values seen; COUNT(*) call sites bump this
    // field directly without going through `add`.
    count: i64,
    // Running sum of values coercible to f64 (see `value_to_f64`);
    // non-coercible values are counted but contribute nothing here.
    sum: f64,
    // Smallest value seen so far, per `compare_for_min` semantics.
    min: Option<Value>,
    // Largest value seen so far, per `compare_for_max` semantics.
    max: Option<Value>,
    // First non-null value seen.
    first: Option<Value>,
}
109
impl Accumulator {
    /// Creates an empty accumulator (same state as `Default`).
    fn new() -> Self {
        Self {
            count: 0,
            sum: 0.0,
            min: None,
            max: None,
            first: None,
        }
    }

    /// Folds one value into every statistic at once.
    ///
    /// Nulls are ignored entirely: they do not count, sum, or participate
    /// in min/max/first.
    fn add(&mut self, value: &Value) {
        // Skip nulls for aggregates
        if matches!(value, Value::Null) {
            return;
        }

        self.count += 1;

        // Sum (for numeric types). Values that cannot be coerced to f64
        // still increment `count` but contribute nothing to `sum`.
        if let Some(n) = value_to_f64(value) {
            self.sum += n;
        }

        // Min. `compare_for_min` already returns true for a `None`
        // current, so the `is_none()` check is a short-circuit fast path.
        if self.min.is_none() || compare_for_min(&self.min, value) {
            self.min = Some(value.clone());
        }

        // Max
        if self.max.is_none() || compare_for_max(&self.max, value) {
            self.max = Some(value.clone());
        }

        // First
        if self.first.is_none() {
            self.first = Some(value.clone());
        }
    }

    /// Produces the final value for `func`.
    ///
    /// Min/max/first are moved out via `take()`, so calling this twice
    /// for the same function yields `Null` the second time. SUM and AVG
    /// return `Null` when no non-null values were seen (SQL-style),
    /// while COUNT returns 0.
    fn finalize(&mut self, func: AggregateFunction) -> Value {
        match func {
            AggregateFunction::Count => Value::Int64(self.count),
            AggregateFunction::Sum => {
                if self.count == 0 {
                    Value::Null
                } else {
                    Value::Float64(self.sum)
                }
            }
            AggregateFunction::Min => self.min.take().unwrap_or(Value::Null),
            AggregateFunction::Max => self.max.take().unwrap_or(Value::Null),
            AggregateFunction::Avg => {
                if self.count == 0 {
                    Value::Null
                } else {
                    Value::Float64(self.sum / self.count as f64)
                }
            }
            AggregateFunction::First => self.first.take().unwrap_or(Value::Null),
        }
    }
}
173
174fn value_to_f64(value: &Value) -> Option<f64> {
175    match value {
176        Value::Int64(i) => Some(*i as f64),
177        Value::Float64(f) => Some(*f),
178        // RDF stores numeric literals as strings - try to parse
179        Value::String(s) => s.parse::<f64>().ok(),
180        _ => None,
181    }
182}
183
184fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
185    match (current, new) {
186        (None, _) => true,
187        (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
188        (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
189        (Some(Value::String(a)), Value::String(b)) => {
190            // Try numeric comparison for RDF values
191            if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
192                b_num < a_num
193            } else {
194                b < a
195            }
196        }
197        // Cross-type comparisons for RDF
198        (Some(Value::String(a)), Value::Int64(b)) => {
199            if let Ok(a_num) = a.parse::<f64>() {
200                (*b as f64) < a_num
201            } else {
202                false
203            }
204        }
205        (Some(Value::Int64(a)), Value::String(b)) => {
206            if let Ok(b_num) = b.parse::<f64>() {
207                b_num < *a as f64
208            } else {
209                false
210            }
211        }
212        _ => false,
213    }
214}
215
216fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
217    match (current, new) {
218        (None, _) => true,
219        (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
220        (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
221        (Some(Value::String(a)), Value::String(b)) => {
222            // Try numeric comparison for RDF values
223            if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
224                b_num > a_num
225            } else {
226                b > a
227            }
228        }
229        // Cross-type comparisons for RDF
230        (Some(Value::String(a)), Value::Int64(b)) => {
231            if let Ok(a_num) = a.parse::<f64>() {
232                (*b as f64) > a_num
233            } else {
234                false
235            }
236        }
237        (Some(Value::Int64(a)), Value::String(b)) => {
238            if let Ok(b_num) = b.parse::<f64>() {
239                b_num > *a as f64
240            } else {
241                false
242            }
243        }
244        _ => false,
245    }
246}
247
/// Hash key for grouping.
///
/// NOTE(review): the key stores only the 64-bit hashes of the group-by
/// values, not the values themselves, so two distinct groups whose values
/// happen to hash identically would be merged. Confirm this collision
/// risk is an accepted trade-off.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GroupKey(Vec<u64>);

impl GroupKey {
    /// Hashes the group-by columns of one row into a key.
    ///
    /// A missing column or missing row value contributes the raw sentinel
    /// 0 (it does not go through `hash_value`).
    fn from_row(chunk: &DataChunk, row: usize, group_by: &[usize]) -> Self {
        let hashes: Vec<u64> = group_by
            .iter()
            .map(|&col| {
                chunk
                    .column(col)
                    .and_then(|c| c.get_value(row))
                    .map_or(0, |v| hash_value(&v))
            })
            .collect();
        Self(hashes)
    }
}
266
267fn hash_value(value: &Value) -> u64 {
268    use std::collections::hash_map::DefaultHasher;
269    use std::hash::{Hash, Hasher};
270
271    let mut hasher = DefaultHasher::new();
272    match value {
273        Value::Null => 0u8.hash(&mut hasher),
274        Value::Bool(b) => b.hash(&mut hasher),
275        Value::Int64(i) => i.hash(&mut hasher),
276        Value::Float64(f) => f.to_bits().hash(&mut hasher),
277        Value::String(s) => s.hash(&mut hasher),
278        _ => 0u8.hash(&mut hasher),
279    }
280    hasher.finish()
281}
282
/// Group state with key values and accumulators.
#[derive(Clone)]
struct GroupState {
    // Concrete group-by values, kept so finalize can emit the key columns
    // (the hash-map key only holds hashes).
    key_values: Vec<Value>,
    // One accumulator per aggregate expression, in expression order.
    accumulators: Vec<Accumulator>,
}
289
/// Push-based aggregate operator.
///
/// This is a pipeline breaker that accumulates all input, groups by key,
/// and produces aggregated output in the finalize phase.
pub struct AggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Group states by hash key. Stays empty when `group_by` is empty.
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY).
    ///
    /// `Some` exactly when `group_by` is empty; one accumulator per
    /// aggregate expression, in expression order.
    global_state: Option<Vec<Accumulator>>,
}
304
305impl AggregatePushOperator {
306    /// Create a new aggregate operator.
307    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
308        let global_state = if group_by.is_empty() {
309            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
310        } else {
311            None
312        };
313
314        Self {
315            group_by,
316            aggregates,
317            groups: HashMap::new(),
318            global_state,
319        }
320    }
321
322    /// Create a simple global aggregate (no GROUP BY).
323    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
324        Self::new(Vec::new(), aggregates)
325    }
326}
327
impl PushOperator for AggregatePushOperator {
    /// Consumes one input chunk, folding each selected row into either the
    /// global accumulators or its group's accumulators. Nothing is emitted
    /// here — this operator is a pipeline breaker; see `finalize`.
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation
                if let Some(ref mut accumulators) = self.global_state {
                    // Accumulators were created in expression order, so a
                    // positional zip pairs each with its expression.
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            // A missing column or missing row value simply
                            // contributes nothing to the accumulator.
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*): count the row regardless of values.
                            acc.count += 1;
                        }
                    }
                }
            } else {
                // Group by aggregation
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                // The closure reads `self.group_by`/`self.aggregates` while
                // `self.groups` is mutably borrowed via `entry`; this relies
                // on Rust 2021 disjoint closure field captures.
                let state = self.groups.entry(key).or_insert_with(|| {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    if let Some(col) = expr.column {
                        if let Some(c) = chunk.column(col)
                            && let Some(val) = c.get_value(row)
                        {
                            acc.add(&val);
                        }
                    } else {
                        // COUNT(*)
                        acc.count += 1;
                    }
                }
            }
        }

        Ok(true)
    }

    /// Emits the aggregated result as a single chunk: group-key columns
    /// first (in `group_by` order), then one column per aggregate.
    ///
    /// Global mode always produces one row (e.g. COUNT over empty input
    /// yields 0); grouped mode produces one row per group in hash-map
    /// iteration order (i.e. unspecified order).
    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output
            if let Some(ref mut accumulators) = self.global_state {
                for (i, (acc, expr)) in accumulators.iter_mut().zip(&self.aggregates).enumerate() {
                    columns[i].push(acc.finalize(expr.function));
                }
            }
        } else {
            // Group by - one row per group
            for state in self.groups.values_mut() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, (acc, expr)) in state
                    .accumulators
                    .iter_mut()
                    .zip(&self.aggregates)
                    .enumerate()
                {
                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                }
            }
        }

        // Suppress an all-empty chunk (e.g. GROUP BY that saw no rows).
        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "AggregatePush"
    }
}
439
/// Default spill threshold for aggregates (number of groups).
///
/// Once this many groups accumulate, `SpillableAggregatePushOperator`
/// either spills its largest partition or migrates the plain hash map
/// into partitioned state (see `maybe_spill`).
#[cfg(feature = "spill")]
pub const DEFAULT_AGGREGATE_SPILL_THRESHOLD: usize = 50_000;
443
/// Serializes a GroupState to bytes.
///
/// Wire format (all integers little-endian): key count (u64), key values,
/// accumulator count (u64), then per accumulator: count (i64), sum bits
/// (u64), and three optional values (min/max/first), each a 1-byte
/// presence flag followed by the value when present.
#[cfg(feature = "spill")]
fn serialize_group_state(state: &GroupState, w: &mut dyn Write) -> std::io::Result<()> {
    use crate::execution::spill::serialize_value;

    // Writes a 1-byte presence flag, then the value when present.
    fn write_optional(opt: &Option<Value>, w: &mut dyn Write) -> std::io::Result<()> {
        use crate::execution::spill::serialize_value;
        w.write_all(&[opt.is_some() as u8])?;
        match opt {
            Some(v) => serialize_value(v, w),
            None => Ok(()),
        }
    }

    // Key values, length-prefixed.
    w.write_all(&(state.key_values.len() as u64).to_le_bytes())?;
    for val in &state.key_values {
        serialize_value(val, w)?;
    }

    // Accumulators, length-prefixed.
    w.write_all(&(state.accumulators.len() as u64).to_le_bytes())?;
    for acc in &state.accumulators {
        w.write_all(&acc.count.to_le_bytes())?;
        // Store the f64 bit pattern so the round trip is exact.
        w.write_all(&acc.sum.to_bits().to_le_bytes())?;
        write_optional(&acc.min, w)?;
        write_optional(&acc.max, w)?;
        write_optional(&acc.first, w)?;
    }

    Ok(())
}
485
/// Deserializes a GroupState from bytes.
///
/// Must mirror the wire format produced by `serialize_group_state`.
#[cfg(feature = "spill")]
fn deserialize_group_state(r: &mut dyn Read) -> std::io::Result<GroupState> {
    use crate::execution::spill::deserialize_value;

    // Reads a little-endian u64; also used for the raw bits of the
    // i64 count and f64 sum fields.
    fn read_u64(r: &mut dyn Read) -> std::io::Result<u64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(u64::from_le_bytes(buf))
    }

    // Reads a 1-byte presence flag, then the value when present.
    fn read_optional(r: &mut dyn Read) -> std::io::Result<Option<Value>> {
        use crate::execution::spill::deserialize_value;
        let mut flag = [0u8; 1];
        r.read_exact(&mut flag)?;
        if flag[0] != 0 {
            Ok(Some(deserialize_value(r)?))
        } else {
            Ok(None)
        }
    }

    // Key values, length-prefixed.
    let num_keys = read_u64(r)? as usize;
    let mut key_values = Vec::with_capacity(num_keys);
    for _ in 0..num_keys {
        key_values.push(deserialize_value(r)?);
    }

    // Accumulators, length-prefixed.
    let num_accumulators = read_u64(r)? as usize;
    let mut accumulators = Vec::with_capacity(num_accumulators);
    for _ in 0..num_accumulators {
        // `count` was written as raw i64 bytes; reinterpret the u64 bits.
        let count = read_u64(r)? as i64;
        let sum = f64::from_bits(read_u64(r)?);
        let min = read_optional(r)?;
        let max = read_optional(r)?;
        let first = read_optional(r)?;

        accumulators.push(Accumulator {
            count,
            sum,
            min,
            max,
            first,
        });
    }

    Ok(GroupState {
        key_values,
        accumulators,
    })
}
553
/// Push-based aggregate operator with spilling support.
///
/// Uses partitioned hash table that can spill cold partitions to disk
/// when memory pressure is high.
#[cfg(feature = "spill")]
pub struct SpillableAggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Spill manager (None = no spilling).
    spill_manager: Option<Arc<SpillManager>>,
    /// Partitioned groups (used when spilling is enabled).
    ///
    /// Keyed by the concrete group values, unlike `groups` which is
    /// keyed by value hashes.
    partitioned_groups: Option<PartitionedState<GroupState>>,
    /// Non-partitioned groups (used when spilling is disabled).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY). `Some` exactly when
    /// `group_by` is empty.
    global_state: Option<Vec<Accumulator>>,
    /// Spill threshold (number of groups).
    spill_threshold: usize,
    /// Whether we've switched to partitioned mode.
    ///
    /// True from construction via `with_spilling`, or set by
    /// `maybe_spill` once `groups` crosses `spill_threshold`.
    using_partitioned: bool,
}
577
#[cfg(feature = "spill")]
impl SpillableAggregatePushOperator {
    /// Create a new spillable aggregate operator.
    ///
    /// No spill manager is attached, so the operator behaves like the
    /// plain `AggregatePushOperator` until `with_spilling` is used
    /// instead (`maybe_spill` is a no-op without a manager).
    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
        } else {
            None
        };

        Self {
            group_by,
            aggregates,
            spill_manager: None,
            partitioned_groups: None,
            groups: HashMap::new(),
            global_state,
            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
            using_partitioned: false,
        }
    }

    /// Create a spillable aggregate operator with spilling enabled.
    ///
    /// Starts directly in partitioned mode, so groups are stored in the
    /// partitioned state from the first chunk onward.
    pub fn with_spilling(
        group_by: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        manager: Arc<SpillManager>,
        threshold: usize,
    ) -> Self {
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
        } else {
            None
        };

        let partitioned = PartitionedState::new(
            Arc::clone(&manager),
            256, // Number of partitions
            serialize_group_state,
            deserialize_group_state,
        );

        Self {
            group_by,
            aggregates,
            spill_manager: Some(manager),
            partitioned_groups: Some(partitioned),
            groups: HashMap::new(),
            global_state,
            spill_threshold: threshold,
            using_partitioned: true,
        }
    }

    /// Create a simple global aggregate (no GROUP BY).
    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
        Self::new(Vec::new(), aggregates)
    }

    /// Sets the spill threshold (number of groups); builder-style.
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.spill_threshold = threshold;
        self
    }

    /// Switches to partitioned mode if needed.
    ///
    /// Global aggregation holds a fixed number of accumulators and never
    /// spills. In partitioned mode, the largest partition is spilled once
    /// `total_size` reaches the threshold. In hash-map mode with a spill
    /// manager configured, all existing groups are migrated into a fresh
    /// partitioned state; without a manager this is a no-op and the hash
    /// map keeps growing.
    fn maybe_spill(&mut self) -> Result<(), OperatorError> {
        if self.global_state.is_some() {
            // Global aggregation doesn't need spilling
            return Ok(());
        }

        // If using partitioned state, check if we need to spill
        if let Some(ref mut partitioned) = self.partitioned_groups {
            if partitioned.total_size() >= self.spill_threshold {
                partitioned
                    .spill_largest()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }
        } else if self.groups.len() >= self.spill_threshold {
            // Not using partitioned state yet, but reached threshold
            // If spilling is configured, switch to partitioned mode
            if let Some(ref manager) = self.spill_manager {
                let mut partitioned = PartitionedState::new(
                    Arc::clone(manager),
                    256,
                    serialize_group_state,
                    deserialize_group_state,
                );

                // Move existing groups to partitioned state, re-keyed by
                // the concrete group values rather than their hashes.
                for (_key, state) in self.groups.drain() {
                    partitioned
                        .insert(state.key_values.clone(), state)
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;
                }

                self.partitioned_groups = Some(partitioned);
                self.using_partitioned = true;
            }
        }

        Ok(())
    }
}
683
#[cfg(feature = "spill")]
impl PushOperator for SpillableAggregatePushOperator {
    /// Consumes one chunk, routing each row to the global accumulators,
    /// the partitioned state, or the plain hash map depending on mode,
    /// then checks whether a spill is due. Emits nothing here.
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation - same as non-spillable
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*): count the row regardless of values.
                            acc.count += 1;
                        }
                    }
                }
            } else if self.using_partitioned {
                // Use partitioned state
                if let Some(ref mut partitioned) = self.partitioned_groups {
                    // Partitioned state is keyed by the concrete group
                    // values (the hash-map path keys by value hashes).
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    // Re-borrow `aggregates` so the closure below does not
                    // capture `self` while `partitioned` is borrowed.
                    let aggregates = &self.aggregates;
                    let state = partitioned
                        .get_or_insert_with(key_values.clone(), || GroupState {
                            key_values: key_values.clone(),
                            accumulators: aggregates.iter().map(|_| Accumulator::new()).collect(),
                        })
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;

                    for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            acc.count += 1;
                        }
                    }
                }
            } else {
                // Use regular hash map
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    if let Some(col) = expr.column {
                        if let Some(c) = chunk.column(col)
                            && let Some(val) = c.get_value(row)
                        {
                            acc.add(&val);
                        }
                    } else {
                        acc.count += 1;
                    }
                }
            }
        }

        // Check if we need to spill
        self.maybe_spill()?;

        Ok(true)
    }

    /// Emits one result chunk: group-key columns first, then one column
    /// per aggregate. In partitioned mode this drains every partition
    /// via `drain_all` (presumably reloading spilled partitions from
    /// disk — semantics live in the spill module; confirm there).
    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output
            if let Some(ref mut accumulators) = self.global_state {
                for (i, (acc, expr)) in accumulators.iter_mut().zip(&self.aggregates).enumerate() {
                    columns[i].push(acc.finalize(expr.function));
                }
            }
        } else if self.using_partitioned {
            // Drain partitioned state
            if let Some(ref mut partitioned) = self.partitioned_groups {
                let groups = partitioned
                    .drain_all()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;

                for (_key, mut state) in groups {
                    // Output group key columns
                    for (i, val) in state.key_values.iter().enumerate() {
                        columns[i].push(val.clone());
                    }

                    // Output aggregate results
                    for (i, (acc, expr)) in state
                        .accumulators
                        .iter_mut()
                        .zip(&self.aggregates)
                        .enumerate()
                    {
                        columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                    }
                }
            }
        } else {
            // Group by using regular hash map - one row per group
            for state in self.groups.values_mut() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, (acc, expr)) in state
                    .accumulators
                    .iter_mut()
                    .zip(&self.aggregates)
                    .enumerate()
                {
                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                }
            }
        }

        // Suppress an all-empty chunk (e.g. GROUP BY that saw no rows).
        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "SpillableAggregatePush"
    }
}
855
856#[cfg(test)]
857mod tests {
858    use super::*;
859    use crate::execution::sink::CollectorSink;
860
    // Builds a one-column chunk of Int64 values.
    fn create_test_chunk(values: &[i64]) -> DataChunk {
        let v: Vec<Value> = values.iter().map(|&i| Value::Int64(i)).collect();
        let vector = ValueVector::from_values(&v);
        DataChunk::new(vec![vector])
    }
866
    // Builds a two-column chunk of Int64 values (e.g. group key + measure).
    fn create_two_column_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
        DataChunk::new(vec![
            ValueVector::from_values(&v1),
            ValueVector::from_values(&v2),
        ])
    }
875
    // COUNT(*) over 5 rows yields a single output row containing 5.
    #[test]
    fn test_global_count() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }
892
    // SUM over Int64 input is reported as Float64: 1+2+3+4+5 = 15.0.
    #[test]
    fn test_global_sum() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::sum(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Float64(15.0))
        );
    }
908
    // MIN and MAX keep the original Value type (Int64 in, Int64 out);
    // column 0 holds the min, column 1 the max.
    #[test]
    fn test_global_min_max() {
        let mut agg =
            AggregatePushOperator::global(vec![AggregateExpr::min(0), AggregateExpr::max(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[3, 1, 4, 1, 5, 9, 2, 6]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(1))
        );
        assert_eq!(
            chunks[0].column(1).unwrap().get_value(0),
            Some(Value::Int64(9))
        );
    }
929
930    #[test]
931    fn test_group_by_sum() {
932        // Group by column 0, sum column 1
933        let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
934        let mut sink = CollectorSink::new();
935
936        // Group 1: 10, 20 (sum=30), Group 2: 30, 40 (sum=70)
937        agg.push(
938            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
939            &mut sink,
940        )
941        .unwrap();
942        agg.finalize(&mut sink).unwrap();
943
944        let chunks = sink.into_chunks();
945        assert_eq!(chunks[0].len(), 2); // 2 groups
946    }
947
948    #[test]
949    #[cfg(feature = "spill")]
950    fn test_spillable_aggregate_no_spill() {
951        // When threshold is not reached, should work like normal aggregate
952        let mut agg = SpillableAggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)])
953            .with_threshold(100);
954        let mut sink = CollectorSink::new();
955
956        agg.push(
957            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
958            &mut sink,
959        )
960        .unwrap();
961        agg.finalize(&mut sink).unwrap();
962
963        let chunks = sink.into_chunks();
964        assert_eq!(chunks[0].len(), 2); // 2 groups
965    }
966
967    #[test]
968    #[cfg(feature = "spill")]
969    fn test_spillable_aggregate_with_spilling() {
970        use tempfile::TempDir;
971
972        let temp_dir = TempDir::new().unwrap();
973        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
974
975        // Set very low threshold to force spilling
976        let mut agg = SpillableAggregatePushOperator::with_spilling(
977            vec![0],
978            vec![AggregateExpr::sum(1)],
979            manager,
980            3, // Spill after 3 groups
981        );
982        let mut sink = CollectorSink::new();
983
984        // Create 10 different groups
985        for i in 0..10 {
986            let chunk = create_two_column_chunk(&[i], &[i * 10]);
987            agg.push(chunk, &mut sink).unwrap();
988        }
989        agg.finalize(&mut sink).unwrap();
990
991        let chunks = sink.into_chunks();
992        assert_eq!(chunks.len(), 1);
993        assert_eq!(chunks[0].len(), 10); // 10 groups
994
995        // Verify sums are correct
996        let mut sums: Vec<f64> = Vec::new();
997        for i in 0..chunks[0].len() {
998            if let Some(Value::Float64(sum)) = chunks[0].column(1).unwrap().get_value(i) {
999                sums.push(sum);
1000            }
1001        }
1002        sums.sort_by(|a, b| a.partial_cmp(b).unwrap());
1003        assert_eq!(
1004            sums,
1005            vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]
1006        );
1007    }
1008
1009    #[test]
1010    #[cfg(feature = "spill")]
1011    fn test_spillable_aggregate_global() {
1012        // Global aggregation shouldn't be affected by spilling
1013        let mut agg = SpillableAggregatePushOperator::global(vec![AggregateExpr::count_star()]);
1014        let mut sink = CollectorSink::new();
1015
1016        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
1017            .unwrap();
1018        agg.finalize(&mut sink).unwrap();
1019
1020        let chunks = sink.into_chunks();
1021        assert_eq!(chunks.len(), 1);
1022        assert_eq!(
1023            chunks[0].column(0).unwrap().get_value(0),
1024            Some(Value::Int64(5))
1025        );
1026    }
1027
1028    #[test]
1029    #[cfg(feature = "spill")]
1030    fn test_spillable_aggregate_many_groups() {
1031        use tempfile::TempDir;
1032
1033        let temp_dir = TempDir::new().unwrap();
1034        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
1035
1036        let mut agg = SpillableAggregatePushOperator::with_spilling(
1037            vec![0],
1038            vec![AggregateExpr::count_star()],
1039            manager,
1040            10, // Very low threshold
1041        );
1042        let mut sink = CollectorSink::new();
1043
1044        // Create 100 different groups
1045        for i in 0..100 {
1046            let chunk = create_test_chunk(&[i]);
1047            agg.push(chunk, &mut sink).unwrap();
1048        }
1049        agg.finalize(&mut sink).unwrap();
1050
1051        let chunks = sink.into_chunks();
1052        assert_eq!(chunks.len(), 1);
1053        assert_eq!(chunks[0].len(), 100); // 100 groups
1054
1055        // Each group should have count = 1
1056        for i in 0..100 {
1057            if let Some(Value::Int64(count)) = chunks[0].column(1).unwrap().get_value(i) {
1058                assert_eq!(count, 1);
1059            }
1060        }
1061    }
1062}