// grafeo_core/execution/operators/push/aggregate.rs
1//! Push-based aggregate operator (pipeline breaker).
2
3use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
6#[cfg(feature = "spill")]
7use crate::execution::spill::{PartitionedState, SpillManager};
8use crate::execution::vector::ValueVector;
9use grafeo_common::types::Value;
10use std::collections::HashMap;
11#[cfg(feature = "spill")]
12use std::io::{Read, Write};
13#[cfg(feature = "spill")]
14use std::sync::Arc;
15
/// Aggregation function type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Count rows (COUNT(*)) or non-null values (COUNT(col)).
    Count,
    /// Sum of values; yields NULL when no non-null value was seen.
    Sum,
    /// Minimum value.
    Min,
    /// Maximum value.
    Max,
    /// Average value; yields NULL when no non-null value was seen.
    Avg,
    /// First non-null value seen in the group (depends on input order).
    First,
}
32
/// Aggregate expression.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub function: AggregateFunction,
    /// Column index to aggregate (None for COUNT(*)).
    pub column: Option<usize>,
    /// Whether DISTINCT applies.
    /// NOTE(review): this flag is never read by the operators in this file —
    /// confirm DISTINCT handling happens elsewhere (e.g. in the planner).
    pub distinct: bool,
}
43
44impl AggregateExpr {
45    /// Create a COUNT(*) expression.
46    pub fn count_star() -> Self {
47        Self {
48            function: AggregateFunction::Count,
49            column: None,
50            distinct: false,
51        }
52    }
53
54    /// Create a COUNT(column) expression.
55    pub fn count(column: usize) -> Self {
56        Self {
57            function: AggregateFunction::Count,
58            column: Some(column),
59            distinct: false,
60        }
61    }
62
63    /// Create a SUM(column) expression.
64    pub fn sum(column: usize) -> Self {
65        Self {
66            function: AggregateFunction::Sum,
67            column: Some(column),
68            distinct: false,
69        }
70    }
71
72    /// Create a MIN(column) expression.
73    pub fn min(column: usize) -> Self {
74        Self {
75            function: AggregateFunction::Min,
76            column: Some(column),
77            distinct: false,
78        }
79    }
80
81    /// Create a MAX(column) expression.
82    pub fn max(column: usize) -> Self {
83        Self {
84            function: AggregateFunction::Max,
85            column: Some(column),
86            distinct: false,
87        }
88    }
89
90    /// Create an AVG(column) expression.
91    pub fn avg(column: usize) -> Self {
92        Self {
93            function: AggregateFunction::Avg,
94            column: Some(column),
95            distinct: false,
96        }
97    }
98}
99
/// Accumulator for aggregate state.
///
/// A single accumulator tracks every statistic at once (count, sum, min,
/// max, first); `finalize` then selects the one the expression asked for.
#[derive(Debug, Clone, Default)]
struct Accumulator {
    // Number of non-null values added; COUNT(*) bumps this field directly.
    count: i64,
    // Running sum of values that coerced to f64; non-numeric non-null values
    // are counted but contribute nothing here.
    sum: f64,
    // Smallest value seen so far (ordering rules: `compare_for_min`).
    min: Option<Value>,
    // Largest value seen so far (ordering rules: `compare_for_max`).
    max: Option<Value>,
    // First non-null value seen.
    first: Option<Value>,
}
109
110impl Accumulator {
111    fn new() -> Self {
112        Self {
113            count: 0,
114            sum: 0.0,
115            min: None,
116            max: None,
117            first: None,
118        }
119    }
120
121    fn add(&mut self, value: &Value) {
122        // Skip nulls for aggregates
123        if matches!(value, Value::Null) {
124            return;
125        }
126
127        self.count += 1;
128
129        // Sum (for numeric types)
130        if let Some(n) = value_to_f64(value) {
131            self.sum += n;
132        }
133
134        // Min
135        if self.min.is_none() || compare_for_min(&self.min, value) {
136            self.min = Some(value.clone());
137        }
138
139        // Max
140        if self.max.is_none() || compare_for_max(&self.max, value) {
141            self.max = Some(value.clone());
142        }
143
144        // First
145        if self.first.is_none() {
146            self.first = Some(value.clone());
147        }
148    }
149
150    fn finalize(&self, func: AggregateFunction) -> Value {
151        match func {
152            AggregateFunction::Count => Value::Int64(self.count),
153            AggregateFunction::Sum => {
154                if self.count == 0 {
155                    Value::Null
156                } else {
157                    Value::Float64(self.sum)
158                }
159            }
160            AggregateFunction::Min => self.min.clone().unwrap_or(Value::Null),
161            AggregateFunction::Max => self.max.clone().unwrap_or(Value::Null),
162            AggregateFunction::Avg => {
163                if self.count == 0 {
164                    Value::Null
165                } else {
166                    Value::Float64(self.sum / self.count as f64)
167                }
168            }
169            AggregateFunction::First => self.first.clone().unwrap_or(Value::Null),
170        }
171    }
172}
173
174fn value_to_f64(value: &Value) -> Option<f64> {
175    match value {
176        Value::Int64(i) => Some(*i as f64),
177        Value::Float64(f) => Some(*f),
178        // RDF stores numeric literals as strings - try to parse
179        Value::String(s) => s.parse::<f64>().ok(),
180        _ => None,
181    }
182}
183
184fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
185    match (current, new) {
186        (None, _) => true,
187        (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
188        (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
189        (Some(Value::String(a)), Value::String(b)) => {
190            // Try numeric comparison for RDF values
191            if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
192                b_num < a_num
193            } else {
194                b < a
195            }
196        }
197        // Cross-type comparisons for RDF
198        (Some(Value::String(a)), Value::Int64(b)) => {
199            if let Ok(a_num) = a.parse::<f64>() {
200                (*b as f64) < a_num
201            } else {
202                false
203            }
204        }
205        (Some(Value::Int64(a)), Value::String(b)) => {
206            if let Ok(b_num) = b.parse::<f64>() {
207                b_num < *a as f64
208            } else {
209                false
210            }
211        }
212        _ => false,
213    }
214}
215
216fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
217    match (current, new) {
218        (None, _) => true,
219        (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
220        (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
221        (Some(Value::String(a)), Value::String(b)) => {
222            // Try numeric comparison for RDF values
223            if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
224                b_num > a_num
225            } else {
226                b > a
227            }
228        }
229        // Cross-type comparisons for RDF
230        (Some(Value::String(a)), Value::Int64(b)) => {
231            if let Ok(a_num) = a.parse::<f64>() {
232                (*b as f64) > a_num
233            } else {
234                false
235            }
236        }
237        (Some(Value::Int64(a)), Value::String(b)) => {
238            if let Ok(b_num) = b.parse::<f64>() {
239                b_num > *a as f64
240            } else {
241                false
242            }
243        }
244        _ => false,
245    }
246}
247
/// Hash key for grouping.
///
/// Stores only the 64-bit hash of each group-by value — the representative
/// values themselves live in `GroupState::key_values`.
/// NOTE(review): two distinct values with colliding hashes would be merged
/// into one group; probability is low but nonzero.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GroupKey(Vec<u64>);

impl GroupKey {
    /// Builds the key for `row` by hashing each group-by column's value.
    /// A missing column or missing value hashes to 0, so both cases are
    /// indistinguishable from each other in the key.
    fn from_row(chunk: &DataChunk, row: usize, group_by: &[usize]) -> Self {
        let hashes: Vec<u64> = group_by
            .iter()
            .map(|&col| {
                chunk
                    .column(col)
                    .and_then(|c| c.get_value(row))
                    .map_or(0, |v| hash_value(&v))
            })
            .collect();
        Self(hashes)
    }
}
266
267fn hash_value(value: &Value) -> u64 {
268    use std::collections::hash_map::DefaultHasher;
269    use std::hash::{Hash, Hasher};
270
271    let mut hasher = DefaultHasher::new();
272    match value {
273        Value::Null => 0u8.hash(&mut hasher),
274        Value::Bool(b) => b.hash(&mut hasher),
275        Value::Int64(i) => i.hash(&mut hasher),
276        Value::Float64(f) => f.to_bits().hash(&mut hasher),
277        Value::String(s) => s.hash(&mut hasher),
278        _ => 0u8.hash(&mut hasher),
279    }
280    hasher.finish()
281}
282
/// Group state with key values and accumulators.
#[derive(Clone)]
struct GroupState {
    // Representative group-by values, in group-by column order; captured
    // from the first row that created the group.
    key_values: Vec<Value>,
    // One accumulator per aggregate expression, in expression order.
    accumulators: Vec<Accumulator>,
}
289
/// Push-based aggregate operator.
///
/// This is a pipeline breaker that accumulates all input, groups by key,
/// and produces aggregated output in the finalize phase.
pub struct AggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Group states keyed by the hash of the group-by values
    /// (unused when `group_by` is empty).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulators, one per aggregate expression
    /// (`Some` only when there is no GROUP BY).
    global_state: Option<Vec<Accumulator>>,
}
304
305impl AggregatePushOperator {
306    /// Create a new aggregate operator.
307    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
308        let global_state = if group_by.is_empty() {
309            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
310        } else {
311            None
312        };
313
314        Self {
315            group_by,
316            aggregates,
317            groups: HashMap::new(),
318            global_state,
319        }
320    }
321
322    /// Create a simple global aggregate (no GROUP BY).
323    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
324        Self::new(Vec::new(), aggregates)
325    }
326}
327
impl PushOperator for AggregatePushOperator {
    /// Consumes a chunk, folding every selected row into the aggregate state.
    ///
    /// Nothing is emitted to `_sink` here — this operator is a pipeline
    /// breaker and only produces output in `finalize`.
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*): no column to read, just count the row.
                            acc.count += 1;
                        }
                    }
                }
            } else {
                // Group by aggregation: locate (or create) this row's group.
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    // First row of a new group: capture its key values as the
                    // group's representative values for the output columns.
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    if let Some(col) = expr.column {
                        if let Some(c) = chunk.column(col)
                            && let Some(val) = c.get_value(row)
                        {
                            acc.add(&val);
                        }
                    } else {
                        // COUNT(*)
                        acc.count += 1;
                    }
                }
            }
        }

        Ok(true)
    }

    /// Emits the aggregated result as one chunk: group-by key columns first,
    /// then one column per aggregate expression.
    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output (emitted even when no
            // input arrived, e.g. COUNT(*) over empty input yields 0).
            if let Some(ref accumulators) = self.global_state {
                for (i, (acc, expr)) in accumulators.iter().zip(&self.aggregates).enumerate() {
                    columns[i].push(acc.finalize(expr.function));
                }
            }
        } else {
            // Group by - one row per group (HashMap order, i.e. unspecified).
            for state in self.groups.values() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, (acc, expr)) in state.accumulators.iter().zip(&self.aggregates).enumerate()
                {
                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                }
            }
        }

        // Skip the sink when there is nothing to emit (e.g. grouped
        // aggregation that saw zero input rows).
        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "AggregatePush"
    }
}
435
/// Default spill threshold for aggregates (number of groups).
///
/// Once the in-memory group count reaches this value,
/// `SpillableAggregatePushOperator::maybe_spill` starts moving state to disk.
#[cfg(feature = "spill")]
pub const DEFAULT_AGGREGATE_SPILL_THRESHOLD: usize = 50_000;
439
/// Serializes a GroupState to bytes.
///
/// Wire layout: key count (u64 LE), the key values, accumulator count
/// (u64 LE), then per accumulator: count (i64 LE), sum bit pattern (u64 LE),
/// and three presence-flagged optional values (min, max, first). Must stay
/// in lockstep with `deserialize_group_state`.
#[cfg(feature = "spill")]
fn serialize_group_state(state: &GroupState, w: &mut dyn Write) -> std::io::Result<()> {
    use crate::execution::spill::serialize_value;

    // Writes a one-byte presence flag, then the value when present.
    fn write_opt(opt: &Option<Value>, w: &mut dyn Write) -> std::io::Result<()> {
        use crate::execution::spill::serialize_value;
        w.write_all(&[opt.is_some() as u8])?;
        match opt {
            Some(v) => serialize_value(v, w),
            None => Ok(()),
        }
    }

    // Key values, length-prefixed.
    w.write_all(&(state.key_values.len() as u64).to_le_bytes())?;
    for val in &state.key_values {
        serialize_value(val, w)?;
    }

    // Accumulators, length-prefixed.
    w.write_all(&(state.accumulators.len() as u64).to_le_bytes())?;
    for acc in &state.accumulators {
        w.write_all(&acc.count.to_le_bytes())?;
        w.write_all(&acc.sum.to_bits().to_le_bytes())?;
        // min / max / first, each flag-prefixed.
        write_opt(&acc.min, w)?;
        write_opt(&acc.max, w)?;
        write_opt(&acc.first, w)?;
    }

    Ok(())
}
481
/// Deserializes a GroupState from bytes.
///
/// Exact inverse of `serialize_group_state`; see that function for the
/// wire layout.
#[cfg(feature = "spill")]
fn deserialize_group_state(r: &mut dyn Read) -> std::io::Result<GroupState> {
    use crate::execution::spill::deserialize_value;

    // Reads one little-endian u64 (lengths, counts, f64 bit patterns).
    fn read_u64(r: &mut dyn Read) -> std::io::Result<u64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(u64::from_le_bytes(buf))
    }

    // Reads a presence flag and, when set, the value that follows.
    fn read_opt(r: &mut dyn Read) -> std::io::Result<Option<Value>> {
        use crate::execution::spill::deserialize_value;
        let mut flag = [0u8; 1];
        r.read_exact(&mut flag)?;
        if flag[0] != 0 {
            Ok(Some(deserialize_value(r)?))
        } else {
            Ok(None)
        }
    }

    // Key values.
    let num_keys = read_u64(r)? as usize;
    let mut key_values = Vec::with_capacity(num_keys);
    for _ in 0..num_keys {
        key_values.push(deserialize_value(r)?);
    }

    // Accumulators.
    let num_accumulators = read_u64(r)? as usize;
    let mut accumulators = Vec::with_capacity(num_accumulators);
    for _ in 0..num_accumulators {
        // `as i64` reinterprets the same little-endian bytes, identical to
        // i64::from_le_bytes on the raw buffer.
        let count = read_u64(r)? as i64;
        let sum = f64::from_bits(read_u64(r)?);
        let min = read_opt(r)?;
        let max = read_opt(r)?;
        let first = read_opt(r)?;

        accumulators.push(Accumulator {
            count,
            sum,
            min,
            max,
            first,
        });
    }

    Ok(GroupState {
        key_values,
        accumulators,
    })
}
549
/// Push-based aggregate operator with spilling support.
///
/// Uses partitioned hash table that can spill cold partitions to disk
/// when memory pressure is high.
#[cfg(feature = "spill")]
pub struct SpillableAggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Spill manager (None = no spilling).
    spill_manager: Option<Arc<SpillManager>>,
    /// Partitioned groups (used when spilling is enabled). Populated either
    /// at construction (`with_spilling`) or lazily by `maybe_spill`.
    partitioned_groups: Option<PartitionedState<GroupState>>,
    /// Non-partitioned groups (used when spilling is disabled or before the
    /// spill threshold is reached).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY).
    global_state: Option<Vec<Accumulator>>,
    /// Spill threshold (number of groups).
    spill_threshold: usize,
    /// Whether we've switched to partitioned mode; selects which of
    /// `partitioned_groups`/`groups` the push path uses.
    using_partitioned: bool,
}
573
#[cfg(feature = "spill")]
impl SpillableAggregatePushOperator {
    /// Create a new spillable aggregate operator.
    ///
    /// No spill manager is attached here, so the operator stays fully in
    /// memory; construct via `with_spilling` to actually enable spilling.
    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
        } else {
            None
        };

        Self {
            group_by,
            aggregates,
            spill_manager: None,
            partitioned_groups: None,
            groups: HashMap::new(),
            global_state,
            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
            using_partitioned: false,
        }
    }

    /// Create a spillable aggregate operator with spilling enabled.
    ///
    /// Starts directly in partitioned mode (256 partitions) so partitions
    /// can be spilled as soon as `threshold` groups accumulate.
    pub fn with_spilling(
        group_by: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        manager: Arc<SpillManager>,
        threshold: usize,
    ) -> Self {
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
        } else {
            None
        };

        let partitioned = PartitionedState::new(
            Arc::clone(&manager),
            256, // Number of partitions
            serialize_group_state,
            deserialize_group_state,
        );

        Self {
            group_by,
            aggregates,
            spill_manager: Some(manager),
            partitioned_groups: Some(partitioned),
            groups: HashMap::new(),
            global_state,
            spill_threshold: threshold,
            using_partitioned: true,
        }
    }

    /// Create a simple global aggregate (no GROUP BY).
    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
        Self::new(Vec::new(), aggregates)
    }

    /// Sets the spill threshold (builder-style; consumes and returns self).
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.spill_threshold = threshold;
        self
    }

    /// Switches to partitioned mode if needed.
    ///
    /// Called after each pushed chunk: spills the largest partition when
    /// already partitioned, or — once the in-memory group count crosses
    /// `spill_threshold` and a manager is configured — migrates the hash map
    /// into a fresh partitioned state.
    fn maybe_spill(&mut self) -> Result<(), OperatorError> {
        if self.global_state.is_some() {
            // Global aggregation keeps a fixed, tiny amount of state;
            // spilling is never necessary.
            return Ok(());
        }

        // If using partitioned state, check if we need to spill
        if let Some(ref mut partitioned) = self.partitioned_groups {
            if partitioned.total_size() >= self.spill_threshold {
                partitioned
                    .spill_largest()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }
        } else if self.groups.len() >= self.spill_threshold {
            // Not using partitioned state yet, but reached threshold
            // If spilling is configured, switch to partitioned mode
            if let Some(ref manager) = self.spill_manager {
                let mut partitioned = PartitionedState::new(
                    Arc::clone(manager),
                    256,
                    serialize_group_state,
                    deserialize_group_state,
                );

                // Move existing groups to partitioned state
                for (_key, state) in self.groups.drain() {
                    partitioned
                        .insert(state.key_values.clone(), state)
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;
                }

                self.partitioned_groups = Some(partitioned);
                self.using_partitioned = true;
            }
        }

        Ok(())
    }
}
679
#[cfg(feature = "spill")]
impl PushOperator for SpillableAggregatePushOperator {
    /// Folds every selected row into the aggregate state, then checks
    /// whether state has grown enough to warrant spilling.
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation - same as non-spillable
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*)
                            acc.count += 1;
                        }
                    }
                }
            } else if self.using_partitioned {
                // Use partitioned state
                if let Some(ref mut partitioned) = self.partitioned_groups {
                    // Partitioned state is keyed by the actual key values
                    // (not a hash), so materialize them for this row.
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    // Borrowed separately so the closure below does not
                    // capture `self` while `partitioned` is borrowed mutably.
                    let aggregates = &self.aggregates;
                    let state = partitioned
                        .get_or_insert_with(key_values.clone(), || GroupState {
                            key_values: key_values.clone(),
                            accumulators: aggregates.iter().map(|_| Accumulator::new()).collect(),
                        })
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;

                    for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*)
                            acc.count += 1;
                        }
                    }
                }
            } else {
                // Use regular hash map
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    // First row of a new group: capture its key values.
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    if let Some(col) = expr.column {
                        if let Some(c) = chunk.column(col)
                            && let Some(val) = c.get_value(row)
                        {
                            acc.add(&val);
                        }
                    } else {
                        // COUNT(*)
                        acc.count += 1;
                    }
                }
            }
        }

        // Check if we need to spill
        self.maybe_spill()?;

        Ok(true)
    }

    /// Emits the aggregated result: group-by key columns first, then one
    /// column per aggregate expression. In partitioned mode, spilled
    /// partitions are drained back from disk first.
    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output
            if let Some(ref accumulators) = self.global_state {
                for (i, (acc, expr)) in accumulators.iter().zip(&self.aggregates).enumerate() {
                    columns[i].push(acc.finalize(expr.function));
                }
            }
        } else if self.using_partitioned {
            // Drain partitioned state (reloads any partitions on disk)
            if let Some(ref mut partitioned) = self.partitioned_groups {
                let groups = partitioned
                    .drain_all()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;

                for (_key, state) in groups {
                    // Output group key columns
                    for (i, val) in state.key_values.iter().enumerate() {
                        columns[i].push(val.clone());
                    }

                    // Output aggregate results
                    for (i, (acc, expr)) in
                        state.accumulators.iter().zip(&self.aggregates).enumerate()
                    {
                        columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                    }
                }
            }
        } else {
            // Group by using regular hash map - one row per group
            for state in self.groups.values() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, (acc, expr)) in state.accumulators.iter().zip(&self.aggregates).enumerate()
                {
                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                }
            }
        }

        // Skip the sink when there is nothing to emit.
        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "SpillableAggregatePush"
    }
}
844
845#[cfg(test)]
846mod tests {
847    use super::*;
848    use crate::execution::sink::CollectorSink;
849
    /// Builds a single-column chunk of Int64 values.
    fn create_test_chunk(values: &[i64]) -> DataChunk {
        let v: Vec<Value> = values.iter().map(|&i| Value::Int64(i)).collect();
        let vector = ValueVector::from_values(&v);
        DataChunk::new(vec![vector])
    }
855
    /// Builds a two-column chunk of Int64 values (same length assumed for
    /// both slices; used as group key + aggregated value).
    fn create_two_column_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
        DataChunk::new(vec![
            ValueVector::from_values(&v1),
            ValueVector::from_values(&v2),
        ])
    }
864
    #[test]
    fn test_global_count() {
        // COUNT(*) over five rows must emit a single row containing 5.
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }
881
    #[test]
    fn test_global_sum() {
        // SUM(1..=5) = 15; sums are always reported as Float64.
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::sum(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Float64(15.0))
        );
    }
897
    #[test]
    fn test_global_min_max() {
        // MIN and MAX computed together over one column: min=1, max=9.
        // MIN/MAX keep the original value type (Int64 here, not Float64).
        let mut agg =
            AggregatePushOperator::global(vec![AggregateExpr::min(0), AggregateExpr::max(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[3, 1, 4, 1, 5, 9, 2, 6]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(1))
        );
        assert_eq!(
            chunks[0].column(1).unwrap().get_value(0),
            Some(Value::Int64(9))
        );
    }
918
    #[test]
    fn test_group_by_sum() {
        // Group by column 0, sum column 1
        let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
        let mut sink = CollectorSink::new();

        // Group 1: 10, 20 (sum=30), Group 2: 30, 40 (sum=70)
        agg.push(
            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
            &mut sink,
        )
        .unwrap();
        agg.finalize(&mut sink).unwrap();

        // Group output order is unspecified (HashMap iteration), so only the
        // row count is asserted here.
        let chunks = sink.into_chunks();
        assert_eq!(chunks[0].len(), 2); // 2 groups
    }
936
937    #[test]
938    #[cfg(feature = "spill")]
939    fn test_spillable_aggregate_no_spill() {
940        // When threshold is not reached, should work like normal aggregate
941        let mut agg = SpillableAggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)])
942            .with_threshold(100);
943        let mut sink = CollectorSink::new();
944
945        agg.push(
946            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
947            &mut sink,
948        )
949        .unwrap();
950        agg.finalize(&mut sink).unwrap();
951
952        let chunks = sink.into_chunks();
953        assert_eq!(chunks[0].len(), 2); // 2 groups
954    }
955
956    #[test]
957    #[cfg(feature = "spill")]
958    fn test_spillable_aggregate_with_spilling() {
959        use tempfile::TempDir;
960
961        let temp_dir = TempDir::new().unwrap();
962        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
963
964        // Set very low threshold to force spilling
965        let mut agg = SpillableAggregatePushOperator::with_spilling(
966            vec![0],
967            vec![AggregateExpr::sum(1)],
968            manager,
969            3, // Spill after 3 groups
970        );
971        let mut sink = CollectorSink::new();
972
973        // Create 10 different groups
974        for i in 0..10 {
975            let chunk = create_two_column_chunk(&[i], &[i * 10]);
976            agg.push(chunk, &mut sink).unwrap();
977        }
978        agg.finalize(&mut sink).unwrap();
979
980        let chunks = sink.into_chunks();
981        assert_eq!(chunks.len(), 1);
982        assert_eq!(chunks[0].len(), 10); // 10 groups
983
984        // Verify sums are correct
985        let mut sums: Vec<f64> = Vec::new();
986        for i in 0..chunks[0].len() {
987            if let Some(Value::Float64(sum)) = chunks[0].column(1).unwrap().get_value(i) {
988                sums.push(sum);
989            }
990        }
991        sums.sort_by(|a, b| a.partial_cmp(b).unwrap());
992        assert_eq!(
993            sums,
994            vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]
995        );
996    }
997
998    #[test]
999    #[cfg(feature = "spill")]
1000    fn test_spillable_aggregate_global() {
1001        // Global aggregation shouldn't be affected by spilling
1002        let mut agg = SpillableAggregatePushOperator::global(vec![AggregateExpr::count_star()]);
1003        let mut sink = CollectorSink::new();
1004
1005        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
1006            .unwrap();
1007        agg.finalize(&mut sink).unwrap();
1008
1009        let chunks = sink.into_chunks();
1010        assert_eq!(chunks.len(), 1);
1011        assert_eq!(
1012            chunks[0].column(0).unwrap().get_value(0),
1013            Some(Value::Int64(5))
1014        );
1015    }
1016
1017    #[test]
1018    #[cfg(feature = "spill")]
1019    fn test_spillable_aggregate_many_groups() {
1020        use tempfile::TempDir;
1021
1022        let temp_dir = TempDir::new().unwrap();
1023        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
1024
1025        let mut agg = SpillableAggregatePushOperator::with_spilling(
1026            vec![0],
1027            vec![AggregateExpr::count_star()],
1028            manager,
1029            10, // Very low threshold
1030        );
1031        let mut sink = CollectorSink::new();
1032
1033        // Create 100 different groups
1034        for i in 0..100 {
1035            let chunk = create_test_chunk(&[i]);
1036            agg.push(chunk, &mut sink).unwrap();
1037        }
1038        agg.finalize(&mut sink).unwrap();
1039
1040        let chunks = sink.into_chunks();
1041        assert_eq!(chunks.len(), 1);
1042        assert_eq!(chunks[0].len(), 100); // 100 groups
1043
1044        // Each group should have count = 1
1045        for i in 0..100 {
1046            if let Some(Value::Int64(count)) = chunks[0].column(1).unwrap().get_value(i) {
1047                assert_eq!(count, 1);
1048            }
1049        }
1050    }
1051}