// grafeo_core/execution/operators/push/aggregate.rs

//! Push-based aggregate operator (pipeline breaker).
2
3use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
6use crate::execution::spill::{PartitionedState, SpillManager};
7use crate::execution::vector::ValueVector;
8use grafeo_common::types::Value;
9use std::collections::HashMap;
10use std::io::{Read, Write};
11use std::sync::Arc;
12
/// The kinds of aggregation the push-based aggregate operators can compute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Number of rows (`COUNT(*)`) or of non-null values (`COUNT(column)`).
    Count,
    /// Total of the values that can be read as numbers.
    Sum,
    /// Smallest value seen.
    Min,
    /// Largest value seen.
    Max,
    /// Mean of the values.
    Avg,
    /// First non-null value encountered in the group.
    First,
}
29
30/// Aggregate expression.
31#[derive(Debug, Clone)]
32pub struct AggregateExpr {
33    /// The aggregate function.
34    pub function: AggregateFunction,
35    /// Column index to aggregate (None for COUNT(*)).
36    pub column: Option<usize>,
37    /// Whether DISTINCT applies.
38    pub distinct: bool,
39}
40
41impl AggregateExpr {
42    /// Create a COUNT(*) expression.
43    pub fn count_star() -> Self {
44        Self {
45            function: AggregateFunction::Count,
46            column: None,
47            distinct: false,
48        }
49    }
50
51    /// Create a COUNT(column) expression.
52    pub fn count(column: usize) -> Self {
53        Self {
54            function: AggregateFunction::Count,
55            column: Some(column),
56            distinct: false,
57        }
58    }
59
60    /// Create a SUM(column) expression.
61    pub fn sum(column: usize) -> Self {
62        Self {
63            function: AggregateFunction::Sum,
64            column: Some(column),
65            distinct: false,
66        }
67    }
68
69    /// Create a MIN(column) expression.
70    pub fn min(column: usize) -> Self {
71        Self {
72            function: AggregateFunction::Min,
73            column: Some(column),
74            distinct: false,
75        }
76    }
77
78    /// Create a MAX(column) expression.
79    pub fn max(column: usize) -> Self {
80        Self {
81            function: AggregateFunction::Max,
82            column: Some(column),
83            distinct: false,
84        }
85    }
86
87    /// Create an AVG(column) expression.
88    pub fn avg(column: usize) -> Self {
89        Self {
90            function: AggregateFunction::Avg,
91            column: Some(column),
92            distinct: false,
93        }
94    }
95}
96
97/// Accumulator for aggregate state.
98#[derive(Debug, Clone, Default)]
99struct Accumulator {
100    count: i64,
101    sum: f64,
102    min: Option<Value>,
103    max: Option<Value>,
104    first: Option<Value>,
105}
106
107impl Accumulator {
108    fn new() -> Self {
109        Self {
110            count: 0,
111            sum: 0.0,
112            min: None,
113            max: None,
114            first: None,
115        }
116    }
117
118    fn add(&mut self, value: &Value) {
119        // Skip nulls for aggregates
120        if matches!(value, Value::Null) {
121            return;
122        }
123
124        self.count += 1;
125
126        // Sum (for numeric types)
127        if let Some(n) = value_to_f64(value) {
128            self.sum += n;
129        }
130
131        // Min
132        if self.min.is_none() || compare_for_min(&self.min, value) {
133            self.min = Some(value.clone());
134        }
135
136        // Max
137        if self.max.is_none() || compare_for_max(&self.max, value) {
138            self.max = Some(value.clone());
139        }
140
141        // First
142        if self.first.is_none() {
143            self.first = Some(value.clone());
144        }
145    }
146
147    fn finalize(&self, func: AggregateFunction) -> Value {
148        match func {
149            AggregateFunction::Count => Value::Int64(self.count),
150            AggregateFunction::Sum => {
151                if self.count == 0 {
152                    Value::Null
153                } else {
154                    Value::Float64(self.sum)
155                }
156            }
157            AggregateFunction::Min => self.min.clone().unwrap_or(Value::Null),
158            AggregateFunction::Max => self.max.clone().unwrap_or(Value::Null),
159            AggregateFunction::Avg => {
160                if self.count == 0 {
161                    Value::Null
162                } else {
163                    Value::Float64(self.sum / self.count as f64)
164                }
165            }
166            AggregateFunction::First => self.first.clone().unwrap_or(Value::Null),
167        }
168    }
169}
170
171fn value_to_f64(value: &Value) -> Option<f64> {
172    match value {
173        Value::Int64(i) => Some(*i as f64),
174        Value::Float64(f) => Some(*f),
175        // RDF stores numeric literals as strings - try to parse
176        Value::String(s) => s.parse::<f64>().ok(),
177        _ => None,
178    }
179}
180
181fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
182    match (current, new) {
183        (None, _) => true,
184        (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
185        (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
186        (Some(Value::String(a)), Value::String(b)) => {
187            // Try numeric comparison for RDF values
188            if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
189                b_num < a_num
190            } else {
191                b < a
192            }
193        }
194        // Cross-type comparisons for RDF
195        (Some(Value::String(a)), Value::Int64(b)) => {
196            if let Ok(a_num) = a.parse::<f64>() {
197                (*b as f64) < a_num
198            } else {
199                false
200            }
201        }
202        (Some(Value::Int64(a)), Value::String(b)) => {
203            if let Ok(b_num) = b.parse::<f64>() {
204                b_num < *a as f64
205            } else {
206                false
207            }
208        }
209        _ => false,
210    }
211}
212
213fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
214    match (current, new) {
215        (None, _) => true,
216        (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
217        (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
218        (Some(Value::String(a)), Value::String(b)) => {
219            // Try numeric comparison for RDF values
220            if let (Ok(a_num), Ok(b_num)) = (a.parse::<f64>(), b.parse::<f64>()) {
221                b_num > a_num
222            } else {
223                b > a
224            }
225        }
226        // Cross-type comparisons for RDF
227        (Some(Value::String(a)), Value::Int64(b)) => {
228            if let Ok(a_num) = a.parse::<f64>() {
229                (*b as f64) > a_num
230            } else {
231                false
232            }
233        }
234        (Some(Value::Int64(a)), Value::String(b)) => {
235            if let Ok(b_num) = b.parse::<f64>() {
236                b_num > *a as f64
237            } else {
238                false
239            }
240        }
241        _ => false,
242    }
243}
244
245/// Hash key for grouping.
246#[derive(Debug, Clone, PartialEq, Eq, Hash)]
247struct GroupKey(Vec<u64>);
248
249impl GroupKey {
250    fn from_row(chunk: &DataChunk, row: usize, group_by: &[usize]) -> Self {
251        let hashes: Vec<u64> = group_by
252            .iter()
253            .map(|&col| {
254                chunk
255                    .column(col)
256                    .and_then(|c| c.get_value(row))
257                    .map(|v| hash_value(&v))
258                    .unwrap_or(0)
259            })
260            .collect();
261        Self(hashes)
262    }
263}
264
265fn hash_value(value: &Value) -> u64 {
266    use std::collections::hash_map::DefaultHasher;
267    use std::hash::{Hash, Hasher};
268
269    let mut hasher = DefaultHasher::new();
270    match value {
271        Value::Null => 0u8.hash(&mut hasher),
272        Value::Bool(b) => b.hash(&mut hasher),
273        Value::Int64(i) => i.hash(&mut hasher),
274        Value::Float64(f) => f.to_bits().hash(&mut hasher),
275        Value::String(s) => s.hash(&mut hasher),
276        _ => 0u8.hash(&mut hasher),
277    }
278    hasher.finish()
279}
280
281/// Group state with key values and accumulators.
282#[derive(Clone)]
283struct GroupState {
284    key_values: Vec<Value>,
285    accumulators: Vec<Accumulator>,
286}
287
288/// Push-based aggregate operator.
289///
290/// This is a pipeline breaker that accumulates all input, groups by key,
291/// and produces aggregated output in the finalize phase.
292pub struct AggregatePushOperator {
293    /// Columns to group by.
294    group_by: Vec<usize>,
295    /// Aggregate expressions.
296    aggregates: Vec<AggregateExpr>,
297    /// Group states by hash key.
298    groups: HashMap<GroupKey, GroupState>,
299    /// Global accumulator (for no GROUP BY).
300    global_state: Option<Vec<Accumulator>>,
301}
302
303impl AggregatePushOperator {
304    /// Create a new aggregate operator.
305    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
306        let global_state = if group_by.is_empty() {
307            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
308        } else {
309            None
310        };
311
312        Self {
313            group_by,
314            aggregates,
315            groups: HashMap::new(),
316            global_state,
317        }
318    }
319
320    /// Create a simple global aggregate (no GROUP BY).
321    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
322        Self::new(Vec::new(), aggregates)
323    }
324}
325
326impl PushOperator for AggregatePushOperator {
327    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
328        if chunk.is_empty() {
329            return Ok(true);
330        }
331
332        for row in chunk.selected_indices() {
333            if self.group_by.is_empty() {
334                // Global aggregation
335                if let Some(ref mut accumulators) = self.global_state {
336                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
337                        if let Some(col) = expr.column {
338                            if let Some(c) = chunk.column(col) {
339                                if let Some(val) = c.get_value(row) {
340                                    acc.add(&val);
341                                }
342                            }
343                        } else {
344                            // COUNT(*)
345                            acc.count += 1;
346                        }
347                    }
348                }
349            } else {
350                // Group by aggregation
351                let key = GroupKey::from_row(&chunk, row, &self.group_by);
352
353                let state = self.groups.entry(key).or_insert_with(|| {
354                    let key_values: Vec<Value> = self
355                        .group_by
356                        .iter()
357                        .map(|&col| {
358                            chunk
359                                .column(col)
360                                .and_then(|c| c.get_value(row))
361                                .unwrap_or(Value::Null)
362                        })
363                        .collect();
364
365                    GroupState {
366                        key_values,
367                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
368                    }
369                });
370
371                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
372                    if let Some(col) = expr.column {
373                        if let Some(c) = chunk.column(col) {
374                            if let Some(val) = c.get_value(row) {
375                                acc.add(&val);
376                            }
377                        }
378                    } else {
379                        // COUNT(*)
380                        acc.count += 1;
381                    }
382                }
383            }
384        }
385
386        Ok(true)
387    }
388
389    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
390        let num_output_cols = self.group_by.len() + self.aggregates.len();
391        let mut columns: Vec<ValueVector> =
392            (0..num_output_cols).map(|_| ValueVector::new()).collect();
393
394        if self.group_by.is_empty() {
395            // Global aggregation - single row output
396            if let Some(ref accumulators) = self.global_state {
397                for (i, (acc, expr)) in accumulators.iter().zip(&self.aggregates).enumerate() {
398                    columns[i].push(acc.finalize(expr.function));
399                }
400            }
401        } else {
402            // Group by - one row per group
403            for state in self.groups.values() {
404                // Output group key columns
405                for (i, val) in state.key_values.iter().enumerate() {
406                    columns[i].push(val.clone());
407                }
408
409                // Output aggregate results
410                for (i, (acc, expr)) in state.accumulators.iter().zip(&self.aggregates).enumerate()
411                {
412                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
413                }
414            }
415        }
416
417        if !columns.is_empty() && !columns[0].is_empty() {
418            let chunk = DataChunk::new(columns);
419            sink.consume(chunk)?;
420        }
421
422        Ok(())
423    }
424
425    fn preferred_chunk_size(&self) -> ChunkSizeHint {
426        ChunkSizeHint::Default
427    }
428
429    fn name(&self) -> &'static str {
430        "AggregatePush"
431    }
432}
433
/// Number of groups at which a spillable aggregate starts spilling, unless a
/// different threshold is configured.
pub const DEFAULT_AGGREGATE_SPILL_THRESHOLD: usize = 50_000;
436
437/// Serializes a GroupState to bytes.
438fn serialize_group_state(state: &GroupState, w: &mut dyn Write) -> std::io::Result<()> {
439    use crate::execution::spill::serialize_value;
440
441    // Write key values
442    w.write_all(&(state.key_values.len() as u64).to_le_bytes())?;
443    for val in &state.key_values {
444        serialize_value(val, w)?;
445    }
446
447    // Write accumulators
448    w.write_all(&(state.accumulators.len() as u64).to_le_bytes())?;
449    for acc in &state.accumulators {
450        w.write_all(&acc.count.to_le_bytes())?;
451        w.write_all(&acc.sum.to_bits().to_le_bytes())?;
452
453        // Min
454        let has_min = acc.min.is_some();
455        w.write_all(&[has_min as u8])?;
456        if let Some(ref v) = acc.min {
457            serialize_value(v, w)?;
458        }
459
460        // Max
461        let has_max = acc.max.is_some();
462        w.write_all(&[has_max as u8])?;
463        if let Some(ref v) = acc.max {
464            serialize_value(v, w)?;
465        }
466
467        // First
468        let has_first = acc.first.is_some();
469        w.write_all(&[has_first as u8])?;
470        if let Some(ref v) = acc.first {
471            serialize_value(v, w)?;
472        }
473    }
474
475    Ok(())
476}
477
478/// Deserializes a GroupState from bytes.
479fn deserialize_group_state(r: &mut dyn Read) -> std::io::Result<GroupState> {
480    use crate::execution::spill::deserialize_value;
481
482    // Read key values
483    let mut len_buf = [0u8; 8];
484    r.read_exact(&mut len_buf)?;
485    let num_keys = u64::from_le_bytes(len_buf) as usize;
486
487    let mut key_values = Vec::with_capacity(num_keys);
488    for _ in 0..num_keys {
489        key_values.push(deserialize_value(r)?);
490    }
491
492    // Read accumulators
493    r.read_exact(&mut len_buf)?;
494    let num_accumulators = u64::from_le_bytes(len_buf) as usize;
495
496    let mut accumulators = Vec::with_capacity(num_accumulators);
497    for _ in 0..num_accumulators {
498        let mut count_buf = [0u8; 8];
499        r.read_exact(&mut count_buf)?;
500        let count = i64::from_le_bytes(count_buf);
501
502        r.read_exact(&mut count_buf)?;
503        let sum = f64::from_bits(u64::from_le_bytes(count_buf));
504
505        // Min
506        let mut flag_buf = [0u8; 1];
507        r.read_exact(&mut flag_buf)?;
508        let min = if flag_buf[0] != 0 {
509            Some(deserialize_value(r)?)
510        } else {
511            None
512        };
513
514        // Max
515        r.read_exact(&mut flag_buf)?;
516        let max = if flag_buf[0] != 0 {
517            Some(deserialize_value(r)?)
518        } else {
519            None
520        };
521
522        // First
523        r.read_exact(&mut flag_buf)?;
524        let first = if flag_buf[0] != 0 {
525            Some(deserialize_value(r)?)
526        } else {
527            None
528        };
529
530        accumulators.push(Accumulator {
531            count,
532            sum,
533            min,
534            max,
535            first,
536        });
537    }
538
539    Ok(GroupState {
540        key_values,
541        accumulators,
542    })
543}
544
545/// Push-based aggregate operator with spilling support.
546///
547/// Uses partitioned hash table that can spill cold partitions to disk
548/// when memory pressure is high.
549pub struct SpillableAggregatePushOperator {
550    /// Columns to group by.
551    group_by: Vec<usize>,
552    /// Aggregate expressions.
553    aggregates: Vec<AggregateExpr>,
554    /// Spill manager (None = no spilling).
555    spill_manager: Option<Arc<SpillManager>>,
556    /// Partitioned groups (used when spilling is enabled).
557    partitioned_groups: Option<PartitionedState<GroupState>>,
558    /// Non-partitioned groups (used when spilling is disabled).
559    groups: HashMap<GroupKey, GroupState>,
560    /// Global accumulator (for no GROUP BY).
561    global_state: Option<Vec<Accumulator>>,
562    /// Spill threshold (number of groups).
563    spill_threshold: usize,
564    /// Whether we've switched to partitioned mode.
565    using_partitioned: bool,
566}
567
568impl SpillableAggregatePushOperator {
569    /// Create a new spillable aggregate operator.
570    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
571        let global_state = if group_by.is_empty() {
572            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
573        } else {
574            None
575        };
576
577        Self {
578            group_by,
579            aggregates,
580            spill_manager: None,
581            partitioned_groups: None,
582            groups: HashMap::new(),
583            global_state,
584            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
585            using_partitioned: false,
586        }
587    }
588
589    /// Create a spillable aggregate operator with spilling enabled.
590    pub fn with_spilling(
591        group_by: Vec<usize>,
592        aggregates: Vec<AggregateExpr>,
593        manager: Arc<SpillManager>,
594        threshold: usize,
595    ) -> Self {
596        let global_state = if group_by.is_empty() {
597            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
598        } else {
599            None
600        };
601
602        let partitioned = PartitionedState::new(
603            Arc::clone(&manager),
604            256, // Number of partitions
605            serialize_group_state,
606            deserialize_group_state,
607        );
608
609        Self {
610            group_by,
611            aggregates,
612            spill_manager: Some(manager),
613            partitioned_groups: Some(partitioned),
614            groups: HashMap::new(),
615            global_state,
616            spill_threshold: threshold,
617            using_partitioned: true,
618        }
619    }
620
621    /// Create a simple global aggregate (no GROUP BY).
622    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
623        Self::new(Vec::new(), aggregates)
624    }
625
626    /// Sets the spill threshold.
627    pub fn with_threshold(mut self, threshold: usize) -> Self {
628        self.spill_threshold = threshold;
629        self
630    }
631
632    /// Switches to partitioned mode if needed.
633    fn maybe_spill(&mut self) -> Result<(), OperatorError> {
634        if self.global_state.is_some() {
635            // Global aggregation doesn't need spilling
636            return Ok(());
637        }
638
639        // If using partitioned state, check if we need to spill
640        if let Some(ref mut partitioned) = self.partitioned_groups {
641            if partitioned.total_size() >= self.spill_threshold {
642                partitioned
643                    .spill_largest()
644                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
645            }
646        } else if self.groups.len() >= self.spill_threshold {
647            // Not using partitioned state yet, but reached threshold
648            // If spilling is configured, switch to partitioned mode
649            if let Some(ref manager) = self.spill_manager {
650                let mut partitioned = PartitionedState::new(
651                    Arc::clone(manager),
652                    256,
653                    serialize_group_state,
654                    deserialize_group_state,
655                );
656
657                // Move existing groups to partitioned state
658                for (_key, state) in self.groups.drain() {
659                    partitioned
660                        .insert(state.key_values.clone(), state)
661                        .map_err(|e| OperatorError::Execution(e.to_string()))?;
662                }
663
664                self.partitioned_groups = Some(partitioned);
665                self.using_partitioned = true;
666            }
667        }
668
669        Ok(())
670    }
671}
672
673impl PushOperator for SpillableAggregatePushOperator {
674    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
675        if chunk.is_empty() {
676            return Ok(true);
677        }
678
679        for row in chunk.selected_indices() {
680            if self.group_by.is_empty() {
681                // Global aggregation - same as non-spillable
682                if let Some(ref mut accumulators) = self.global_state {
683                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
684                        if let Some(col) = expr.column {
685                            if let Some(c) = chunk.column(col) {
686                                if let Some(val) = c.get_value(row) {
687                                    acc.add(&val);
688                                }
689                            }
690                        } else {
691                            acc.count += 1;
692                        }
693                    }
694                }
695            } else if self.using_partitioned {
696                // Use partitioned state
697                if let Some(ref mut partitioned) = self.partitioned_groups {
698                    let key_values: Vec<Value> = self
699                        .group_by
700                        .iter()
701                        .map(|&col| {
702                            chunk
703                                .column(col)
704                                .and_then(|c| c.get_value(row))
705                                .unwrap_or(Value::Null)
706                        })
707                        .collect();
708
709                    let aggregates = &self.aggregates;
710                    let state = partitioned
711                        .get_or_insert_with(key_values.clone(), || GroupState {
712                            key_values: key_values.clone(),
713                            accumulators: aggregates.iter().map(|_| Accumulator::new()).collect(),
714                        })
715                        .map_err(|e| OperatorError::Execution(e.to_string()))?;
716
717                    for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
718                        if let Some(col) = expr.column {
719                            if let Some(c) = chunk.column(col) {
720                                if let Some(val) = c.get_value(row) {
721                                    acc.add(&val);
722                                }
723                            }
724                        } else {
725                            acc.count += 1;
726                        }
727                    }
728                }
729            } else {
730                // Use regular hash map
731                let key = GroupKey::from_row(&chunk, row, &self.group_by);
732
733                let state = self.groups.entry(key).or_insert_with(|| {
734                    let key_values: Vec<Value> = self
735                        .group_by
736                        .iter()
737                        .map(|&col| {
738                            chunk
739                                .column(col)
740                                .and_then(|c| c.get_value(row))
741                                .unwrap_or(Value::Null)
742                        })
743                        .collect();
744
745                    GroupState {
746                        key_values,
747                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
748                    }
749                });
750
751                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
752                    if let Some(col) = expr.column {
753                        if let Some(c) = chunk.column(col) {
754                            if let Some(val) = c.get_value(row) {
755                                acc.add(&val);
756                            }
757                        }
758                    } else {
759                        acc.count += 1;
760                    }
761                }
762            }
763        }
764
765        // Check if we need to spill
766        self.maybe_spill()?;
767
768        Ok(true)
769    }
770
771    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
772        let num_output_cols = self.group_by.len() + self.aggregates.len();
773        let mut columns: Vec<ValueVector> =
774            (0..num_output_cols).map(|_| ValueVector::new()).collect();
775
776        if self.group_by.is_empty() {
777            // Global aggregation - single row output
778            if let Some(ref accumulators) = self.global_state {
779                for (i, (acc, expr)) in accumulators.iter().zip(&self.aggregates).enumerate() {
780                    columns[i].push(acc.finalize(expr.function));
781                }
782            }
783        } else if self.using_partitioned {
784            // Drain partitioned state
785            if let Some(ref mut partitioned) = self.partitioned_groups {
786                let groups = partitioned
787                    .drain_all()
788                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
789
790                for (_key, state) in groups {
791                    // Output group key columns
792                    for (i, val) in state.key_values.iter().enumerate() {
793                        columns[i].push(val.clone());
794                    }
795
796                    // Output aggregate results
797                    for (i, (acc, expr)) in
798                        state.accumulators.iter().zip(&self.aggregates).enumerate()
799                    {
800                        columns[self.group_by.len() + i].push(acc.finalize(expr.function));
801                    }
802                }
803            }
804        } else {
805            // Group by using regular hash map - one row per group
806            for state in self.groups.values() {
807                // Output group key columns
808                for (i, val) in state.key_values.iter().enumerate() {
809                    columns[i].push(val.clone());
810                }
811
812                // Output aggregate results
813                for (i, (acc, expr)) in state.accumulators.iter().zip(&self.aggregates).enumerate()
814                {
815                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
816                }
817            }
818        }
819
820        if !columns.is_empty() && !columns[0].is_empty() {
821            let chunk = DataChunk::new(columns);
822            sink.consume(chunk)?;
823        }
824
825        Ok(())
826    }
827
828    fn preferred_chunk_size(&self) -> ChunkSizeHint {
829        ChunkSizeHint::Default
830    }
831
832    fn name(&self) -> &'static str {
833        "SpillableAggregatePush"
834    }
835}
836
837#[cfg(test)]
838mod tests {
839    use super::*;
840    use crate::execution::sink::CollectorSink;
841
842    fn create_test_chunk(values: &[i64]) -> DataChunk {
843        let v: Vec<Value> = values.iter().map(|&i| Value::Int64(i)).collect();
844        let vector = ValueVector::from_values(&v);
845        DataChunk::new(vec![vector])
846    }
847
848    fn create_two_column_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
849        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
850        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
851        DataChunk::new(vec![
852            ValueVector::from_values(&v1),
853            ValueVector::from_values(&v2),
854        ])
855    }
856
857    #[test]
858    fn test_global_count() {
859        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
860        let mut sink = CollectorSink::new();
861
862        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
863            .unwrap();
864        agg.finalize(&mut sink).unwrap();
865
866        let chunks = sink.into_chunks();
867        assert_eq!(chunks.len(), 1);
868        assert_eq!(
869            chunks[0].column(0).unwrap().get_value(0),
870            Some(Value::Int64(5))
871        );
872    }
873
874    #[test]
875    fn test_global_sum() {
876        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::sum(0)]);
877        let mut sink = CollectorSink::new();
878
879        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
880            .unwrap();
881        agg.finalize(&mut sink).unwrap();
882
883        let chunks = sink.into_chunks();
884        assert_eq!(
885            chunks[0].column(0).unwrap().get_value(0),
886            Some(Value::Float64(15.0))
887        );
888    }
889
890    #[test]
891    fn test_global_min_max() {
892        let mut agg =
893            AggregatePushOperator::global(vec![AggregateExpr::min(0), AggregateExpr::max(0)]);
894        let mut sink = CollectorSink::new();
895
896        agg.push(create_test_chunk(&[3, 1, 4, 1, 5, 9, 2, 6]), &mut sink)
897            .unwrap();
898        agg.finalize(&mut sink).unwrap();
899
900        let chunks = sink.into_chunks();
901        assert_eq!(
902            chunks[0].column(0).unwrap().get_value(0),
903            Some(Value::Int64(1))
904        );
905        assert_eq!(
906            chunks[0].column(1).unwrap().get_value(0),
907            Some(Value::Int64(9))
908        );
909    }
910
911    #[test]
912    fn test_group_by_sum() {
913        // Group by column 0, sum column 1
914        let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
915        let mut sink = CollectorSink::new();
916
917        // Group 1: 10, 20 (sum=30), Group 2: 30, 40 (sum=70)
918        agg.push(
919            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
920            &mut sink,
921        )
922        .unwrap();
923        agg.finalize(&mut sink).unwrap();
924
925        let chunks = sink.into_chunks();
926        assert_eq!(chunks[0].len(), 2); // 2 groups
927    }
928
929    #[test]
930    fn test_spillable_aggregate_no_spill() {
931        // When threshold is not reached, should work like normal aggregate
932        let mut agg = SpillableAggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)])
933            .with_threshold(100);
934        let mut sink = CollectorSink::new();
935
936        agg.push(
937            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
938            &mut sink,
939        )
940        .unwrap();
941        agg.finalize(&mut sink).unwrap();
942
943        let chunks = sink.into_chunks();
944        assert_eq!(chunks[0].len(), 2); // 2 groups
945    }
946
947    #[test]
948    fn test_spillable_aggregate_with_spilling() {
949        use tempfile::TempDir;
950
951        let temp_dir = TempDir::new().unwrap();
952        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
953
954        // Set very low threshold to force spilling
955        let mut agg = SpillableAggregatePushOperator::with_spilling(
956            vec![0],
957            vec![AggregateExpr::sum(1)],
958            manager,
959            3, // Spill after 3 groups
960        );
961        let mut sink = CollectorSink::new();
962
963        // Create 10 different groups
964        for i in 0..10 {
965            let chunk = create_two_column_chunk(&[i], &[i * 10]);
966            agg.push(chunk, &mut sink).unwrap();
967        }
968        agg.finalize(&mut sink).unwrap();
969
970        let chunks = sink.into_chunks();
971        assert_eq!(chunks.len(), 1);
972        assert_eq!(chunks[0].len(), 10); // 10 groups
973
974        // Verify sums are correct
975        let mut sums: Vec<f64> = Vec::new();
976        for i in 0..chunks[0].len() {
977            if let Some(Value::Float64(sum)) = chunks[0].column(1).unwrap().get_value(i) {
978                sums.push(sum);
979            }
980        }
981        sums.sort_by(|a, b| a.partial_cmp(b).unwrap());
982        assert_eq!(
983            sums,
984            vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]
985        );
986    }
987
988    #[test]
989    fn test_spillable_aggregate_global() {
990        // Global aggregation shouldn't be affected by spilling
991        let mut agg = SpillableAggregatePushOperator::global(vec![AggregateExpr::count_star()]);
992        let mut sink = CollectorSink::new();
993
994        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
995            .unwrap();
996        agg.finalize(&mut sink).unwrap();
997
998        let chunks = sink.into_chunks();
999        assert_eq!(chunks.len(), 1);
1000        assert_eq!(
1001            chunks[0].column(0).unwrap().get_value(0),
1002            Some(Value::Int64(5))
1003        );
1004    }
1005
1006    #[test]
1007    fn test_spillable_aggregate_many_groups() {
1008        use tempfile::TempDir;
1009
1010        let temp_dir = TempDir::new().unwrap();
1011        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
1012
1013        let mut agg = SpillableAggregatePushOperator::with_spilling(
1014            vec![0],
1015            vec![AggregateExpr::count_star()],
1016            manager,
1017            10, // Very low threshold
1018        );
1019        let mut sink = CollectorSink::new();
1020
1021        // Create 100 different groups
1022        for i in 0..100 {
1023            let chunk = create_test_chunk(&[i]);
1024            agg.push(chunk, &mut sink).unwrap();
1025        }
1026        agg.finalize(&mut sink).unwrap();
1027
1028        let chunks = sink.into_chunks();
1029        assert_eq!(chunks.len(), 1);
1030        assert_eq!(chunks[0].len(), 100); // 100 groups
1031
1032        // Each group should have count = 1
1033        for i in 0..100 {
1034            if let Some(Value::Int64(count)) = chunks[0].column(1).unwrap().get_value(i) {
1035                assert_eq!(count, 1);
1036            }
1037        }
1038    }
1039}