// grafeo_core/execution/operators/push/aggregate.rs
1//! Push-based aggregate operator (pipeline breaker).
2
3use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
6#[cfg(feature = "spill")]
7use crate::execution::spill::{PartitionedState, SpillManager};
8use crate::execution::vector::ValueVector;
9use grafeo_common::types::Value;
10use std::collections::HashMap;
11#[cfg(feature = "spill")]
12use std::io::{Read, Write};
13#[cfg(feature = "spill")]
14use std::sync::Arc;
15
/// Aggregation function type.
///
/// All functions except COUNT(*) consume a single input column (see
/// [`AggregateExpr::column`]); nulls are skipped by the accumulator, so
/// column-based aggregates see non-null values only.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Count rows (COUNT(*), when no column is given) or non-null values.
    Count,
    /// Sum of values. Finalizes to `Float64` (or `Null` on empty input).
    Sum,
    /// Minimum value.
    Min,
    /// Maximum value.
    Max,
    /// Average value. Finalizes to `Float64` (or `Null` on empty input).
    Avg,
    /// First (non-null) value in group.
    First,
}
32
/// Aggregate expression.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub function: AggregateFunction,
    /// Column index to aggregate (None for COUNT(*)).
    pub column: Option<usize>,
    /// Whether DISTINCT applies.
    ///
    /// NOTE(review): this flag is not consulted anywhere in this file —
    /// the push operators below aggregate all values regardless of it.
    /// Confirm DISTINCT is enforced elsewhere (or not yet supported).
    pub distinct: bool,
}
43
44impl AggregateExpr {
45    /// Create a COUNT(*) expression.
46    pub fn count_star() -> Self {
47        Self {
48            function: AggregateFunction::Count,
49            column: None,
50            distinct: false,
51        }
52    }
53
54    /// Create a COUNT(column) expression.
55    pub fn count(column: usize) -> Self {
56        Self {
57            function: AggregateFunction::Count,
58            column: Some(column),
59            distinct: false,
60        }
61    }
62
63    /// Create a SUM(column) expression.
64    pub fn sum(column: usize) -> Self {
65        Self {
66            function: AggregateFunction::Sum,
67            column: Some(column),
68            distinct: false,
69        }
70    }
71
72    /// Create a MIN(column) expression.
73    pub fn min(column: usize) -> Self {
74        Self {
75            function: AggregateFunction::Min,
76            column: Some(column),
77            distinct: false,
78        }
79    }
80
81    /// Create a MAX(column) expression.
82    pub fn max(column: usize) -> Self {
83        Self {
84            function: AggregateFunction::Max,
85            column: Some(column),
86            distinct: false,
87        }
88    }
89
90    /// Create an AVG(column) expression.
91    pub fn avg(column: usize) -> Self {
92        Self {
93            function: AggregateFunction::Avg,
94            column: Some(column),
95            distinct: false,
96        }
97    }
98}
99
/// Accumulator for aggregate state.
///
/// One accumulator carries enough state to finalize any supported
/// [`AggregateFunction`]; only the fields relevant to the requested
/// function are read at finalize time.
#[derive(Debug, Clone, Default)]
struct Accumulator {
    // Number of accumulated non-null values; COUNT(*) callers bump this
    // directly without going through `add`.
    count: i64,
    // Running sum of values convertible to f64 (SUM/AVG).
    sum: f64,
    // Current MIN candidate (ordering via `value_utils::is_less_than`).
    min: Option<Value>,
    // Current MAX candidate (ordering via `value_utils::is_greater_than`).
    max: Option<Value>,
    // First non-null value seen (FIRST).
    first: Option<Value>,
}
109
110impl Accumulator {
111    fn new() -> Self {
112        Self {
113            count: 0,
114            sum: 0.0,
115            min: None,
116            max: None,
117            first: None,
118        }
119    }
120
121    fn add(&mut self, value: &Value) {
122        // Skip nulls for aggregates
123        if matches!(value, Value::Null) {
124            return;
125        }
126
127        self.count += 1;
128
129        // Sum (for numeric types)
130        if let Some(n) = value_to_f64(value) {
131            self.sum += n;
132        }
133
134        // Min
135        if self.min.is_none() || compare_for_min(&self.min, value) {
136            self.min = Some(value.clone());
137        }
138
139        // Max
140        if self.max.is_none() || compare_for_max(&self.max, value) {
141            self.max = Some(value.clone());
142        }
143
144        // First
145        if self.first.is_none() {
146            self.first = Some(value.clone());
147        }
148    }
149
150    fn finalize(&mut self, func: AggregateFunction) -> Value {
151        match func {
152            AggregateFunction::Count => Value::Int64(self.count),
153            AggregateFunction::Sum => {
154                if self.count == 0 {
155                    Value::Null
156                } else {
157                    Value::Float64(self.sum)
158                }
159            }
160            AggregateFunction::Min => self.min.take().unwrap_or(Value::Null),
161            AggregateFunction::Max => self.max.take().unwrap_or(Value::Null),
162            AggregateFunction::Avg => {
163                if self.count == 0 {
164                    Value::Null
165                } else {
166                    Value::Float64(self.sum / self.count as f64)
167                }
168            }
169            AggregateFunction::First => self.first.take().unwrap_or(Value::Null),
170        }
171    }
172}
173
174use crate::execution::operators::value_utils::{
175    is_greater_than as compare_for_max, is_less_than as compare_for_min, value_to_f64,
176};
177
/// Hash key for grouping.
///
/// Holds one 64-bit hash per GROUP BY column rather than the values
/// themselves. NOTE(review): because equality is on hashes only, two
/// distinct key tuples that collide would be merged into a single
/// group — confirm this risk is acceptable for this operator (the
/// spillable partitioned path keys by actual values instead).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GroupKey(Vec<u64>);

impl GroupKey {
    /// Builds the key for `row` of `chunk` from the `group_by` columns.
    ///
    /// A missing column or missing value contributes hash 0.
    fn from_row(chunk: &DataChunk, row: usize, group_by: &[usize]) -> Self {
        let hashes: Vec<u64> = group_by
            .iter()
            .map(|&col| {
                chunk
                    .column(col)
                    .and_then(|c| c.get_value(row))
                    .map_or(0, |v| hash_value(&v))
            })
            .collect();
        Self(hashes)
    }
}
196
197fn hash_value(value: &Value) -> u64 {
198    use std::collections::hash_map::DefaultHasher;
199    use std::hash::{Hash, Hasher};
200
201    let mut hasher = DefaultHasher::new();
202    match value {
203        Value::Null => 0u8.hash(&mut hasher),
204        Value::Bool(b) => b.hash(&mut hasher),
205        Value::Int64(i) => i.hash(&mut hasher),
206        Value::Float64(f) => f.to_bits().hash(&mut hasher),
207        Value::String(s) => s.hash(&mut hasher),
208        _ => 0u8.hash(&mut hasher),
209    }
210    hasher.finish()
211}
212
/// Group state with key values and accumulators.
#[derive(Clone)]
struct GroupState {
    // Original (unhashed) GROUP BY values, emitted as the key columns
    // of the output chunk.
    key_values: Vec<Value>,
    // One accumulator per aggregate expression, in expression order.
    accumulators: Vec<Accumulator>,
}
219
/// Push-based aggregate operator.
///
/// This is a pipeline breaker that accumulates all input, groups by key,
/// and produces aggregated output in the finalize phase.
pub struct AggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Group states by hash key (unused when `group_by` is empty).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY); `Some` iff `group_by` is empty.
    global_state: Option<Vec<Accumulator>>,
}
234
235impl AggregatePushOperator {
236    /// Create a new aggregate operator.
237    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
238        let global_state = if group_by.is_empty() {
239            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
240        } else {
241            None
242        };
243
244        Self {
245            group_by,
246            aggregates,
247            groups: HashMap::new(),
248            global_state,
249        }
250    }
251
252    /// Create a simple global aggregate (no GROUP BY).
253    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
254        Self::new(Vec::new(), aggregates)
255    }
256}
257
258impl PushOperator for AggregatePushOperator {
259    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
260        if chunk.is_empty() {
261            return Ok(true);
262        }
263
264        for row in chunk.selected_indices() {
265            if self.group_by.is_empty() {
266                // Global aggregation
267                if let Some(ref mut accumulators) = self.global_state {
268                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
269                        if let Some(col) = expr.column {
270                            if let Some(c) = chunk.column(col)
271                                && let Some(val) = c.get_value(row)
272                            {
273                                acc.add(&val);
274                            }
275                        } else {
276                            // COUNT(*)
277                            acc.count += 1;
278                        }
279                    }
280                }
281            } else {
282                // Group by aggregation
283                let key = GroupKey::from_row(&chunk, row, &self.group_by);
284
285                let state = self.groups.entry(key).or_insert_with(|| {
286                    let key_values: Vec<Value> = self
287                        .group_by
288                        .iter()
289                        .map(|&col| {
290                            chunk
291                                .column(col)
292                                .and_then(|c| c.get_value(row))
293                                .unwrap_or(Value::Null)
294                        })
295                        .collect();
296
297                    GroupState {
298                        key_values,
299                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
300                    }
301                });
302
303                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
304                    if let Some(col) = expr.column {
305                        if let Some(c) = chunk.column(col)
306                            && let Some(val) = c.get_value(row)
307                        {
308                            acc.add(&val);
309                        }
310                    } else {
311                        // COUNT(*)
312                        acc.count += 1;
313                    }
314                }
315            }
316        }
317
318        Ok(true)
319    }
320
321    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
322        let num_output_cols = self.group_by.len() + self.aggregates.len();
323        let mut columns: Vec<ValueVector> =
324            (0..num_output_cols).map(|_| ValueVector::new()).collect();
325
326        if self.group_by.is_empty() {
327            // Global aggregation - single row output
328            if let Some(ref mut accumulators) = self.global_state {
329                for (i, (acc, expr)) in accumulators.iter_mut().zip(&self.aggregates).enumerate() {
330                    columns[i].push(acc.finalize(expr.function));
331                }
332            }
333        } else {
334            // Group by - one row per group
335            for state in self.groups.values_mut() {
336                // Output group key columns
337                for (i, val) in state.key_values.iter().enumerate() {
338                    columns[i].push(val.clone());
339                }
340
341                // Output aggregate results
342                for (i, (acc, expr)) in state
343                    .accumulators
344                    .iter_mut()
345                    .zip(&self.aggregates)
346                    .enumerate()
347                {
348                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
349                }
350            }
351        }
352
353        if !columns.is_empty() && !columns[0].is_empty() {
354            let chunk = DataChunk::new(columns);
355            sink.consume(chunk)?;
356        }
357
358        Ok(())
359    }
360
361    fn preferred_chunk_size(&self) -> ChunkSizeHint {
362        ChunkSizeHint::Default
363    }
364
365    fn name(&self) -> &'static str {
366        "AggregatePush"
367    }
368}
369
/// Default spill threshold for aggregates (number of groups).
///
/// Once this many groups are resident in memory,
/// `SpillableAggregatePushOperator::maybe_spill` begins moving group
/// state to disk (or migrating to partitioned mode).
#[cfg(feature = "spill")]
pub const DEFAULT_AGGREGATE_SPILL_THRESHOLD: usize = 50_000;
373
/// Serializes a GroupState to bytes.
///
/// Layout (all integers little-endian): key count (u64), each key via
/// `serialize_value`; accumulator count (u64); per accumulator: count
/// (i64), sum (f64 bit pattern), then three flagged optionals — min,
/// max, first — each a presence byte optionally followed by a
/// serialized value. Must stay in lock-step with
/// `deserialize_group_state`.
#[cfg(feature = "spill")]
fn serialize_group_state(state: &GroupState, w: &mut dyn Write) -> std::io::Result<()> {
    use crate::execution::spill::serialize_value;

    // Key values, length-prefixed.
    w.write_all(&(state.key_values.len() as u64).to_le_bytes())?;
    for key in &state.key_values {
        serialize_value(key, w)?;
    }

    // Accumulators, length-prefixed.
    w.write_all(&(state.accumulators.len() as u64).to_le_bytes())?;
    for acc in &state.accumulators {
        w.write_all(&acc.count.to_le_bytes())?;
        w.write_all(&acc.sum.to_bits().to_le_bytes())?;

        // The three optional fields share one encoding; the order
        // (min, max, first) must match the reader.
        for optional in [&acc.min, &acc.max, &acc.first] {
            w.write_all(&[u8::from(optional.is_some())])?;
            if let Some(v) = optional {
                serialize_value(v, w)?;
            }
        }
    }

    Ok(())
}
415
/// Deserializes a GroupState from bytes.
///
/// Mirrors the exact layout written by `serialize_group_state`:
/// length-prefixed keys, then length-prefixed accumulators, each with
/// count, sum, and three flagged optionals in the order min, max, first.
#[cfg(feature = "spill")]
fn deserialize_group_state(r: &mut dyn Read) -> std::io::Result<GroupState> {
    use crate::execution::spill::deserialize_value;

    // Scratch buffer reused for every 8-byte little-endian field.
    let mut word = [0u8; 8];

    // Key values, length-prefixed.
    r.read_exact(&mut word)?;
    let num_keys = u64::from_le_bytes(word) as usize;
    let key_values = (0..num_keys)
        .map(|_| deserialize_value(&mut *r))
        .collect::<std::io::Result<Vec<_>>>()?;

    // Accumulators, length-prefixed.
    r.read_exact(&mut word)?;
    let num_accumulators = u64::from_le_bytes(word) as usize;

    let mut accumulators = Vec::with_capacity(num_accumulators);
    for _ in 0..num_accumulators {
        r.read_exact(&mut word)?;
        let count = i64::from_le_bytes(word);

        // The sum was written as the f64's raw bit pattern.
        r.read_exact(&mut word)?;
        let sum = f64::from_bits(u64::from_le_bytes(word));

        // Each optional is a presence byte, then the value when present.
        let mut flag = [0u8; 1];

        r.read_exact(&mut flag)?;
        let min = (flag[0] != 0).then(|| deserialize_value(&mut *r)).transpose()?;

        r.read_exact(&mut flag)?;
        let max = (flag[0] != 0).then(|| deserialize_value(&mut *r)).transpose()?;

        r.read_exact(&mut flag)?;
        let first = (flag[0] != 0).then(|| deserialize_value(&mut *r)).transpose()?;

        accumulators.push(Accumulator {
            count,
            sum,
            min,
            max,
            first,
        });
    }

    Ok(GroupState {
        key_values,
        accumulators,
    })
}
483
/// Push-based aggregate operator with spilling support.
///
/// Uses partitioned hash table that can spill cold partitions to disk
/// when memory pressure is high.
#[cfg(feature = "spill")]
pub struct SpillableAggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Spill manager (None = no spilling).
    spill_manager: Option<Arc<SpillManager>>,
    /// Partitioned groups (used when spilling is enabled).
    partitioned_groups: Option<PartitionedState<GroupState>>,
    /// Non-partitioned groups (used when spilling is disabled, or before
    /// the spill threshold first triggers a migration).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY); `Some` iff `group_by` is empty.
    global_state: Option<Vec<Accumulator>>,
    /// Spill threshold (number of groups).
    spill_threshold: usize,
    /// Whether we've switched to partitioned mode (set eagerly by
    /// `with_spilling`, or lazily by `maybe_spill`).
    using_partitioned: bool,
}
507
#[cfg(feature = "spill")]
impl SpillableAggregatePushOperator {
    /// Create a new spillable aggregate operator.
    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
        // Without GROUP BY columns we keep one global accumulator set.
        let global_state = group_by
            .is_empty()
            .then(|| vec![Accumulator::new(); aggregates.len()]);

        Self {
            group_by,
            aggregates,
            spill_manager: None,
            partitioned_groups: None,
            groups: HashMap::new(),
            global_state,
            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
            using_partitioned: false,
        }
    }

    /// Create a spillable aggregate operator with spilling enabled.
    pub fn with_spilling(
        group_by: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        manager: Arc<SpillManager>,
        threshold: usize,
    ) -> Self {
        let mut op = Self::new(group_by, aggregates);
        // Start directly in partitioned mode so no migration is needed
        // once the threshold is reached.
        op.partitioned_groups = Some(PartitionedState::new(
            Arc::clone(&manager),
            256, // Number of partitions
            serialize_group_state,
            deserialize_group_state,
        ));
        op.spill_manager = Some(manager);
        op.spill_threshold = threshold;
        op.using_partitioned = true;
        op
    }

    /// Create a simple global aggregate (no GROUP BY).
    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
        Self::new(Vec::new(), aggregates)
    }

    /// Sets the spill threshold.
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.spill_threshold = threshold;
        self
    }

    /// Switches to partitioned mode if needed.
    fn maybe_spill(&mut self) -> Result<(), OperatorError> {
        // Global aggregation keeps O(1) state and never needs to spill.
        if self.global_state.is_some() {
            return Ok(());
        }

        // Already partitioned: spill the largest partition once the
        // resident group count crosses the threshold.
        if let Some(partitioned) = self.partitioned_groups.as_mut() {
            if partitioned.total_size() >= self.spill_threshold {
                partitioned
                    .spill_largest()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }
            return Ok(());
        }

        // Still on the plain hash map and below the threshold: nothing
        // to do yet.
        if self.groups.len() < self.spill_threshold {
            return Ok(());
        }

        // Threshold reached and a spill manager is configured: migrate
        // all in-memory groups into a fresh partitioned state.
        if let Some(manager) = &self.spill_manager {
            let mut partitioned = PartitionedState::new(
                Arc::clone(manager),
                256,
                serialize_group_state,
                deserialize_group_state,
            );

            for (_key, state) in self.groups.drain() {
                partitioned
                    .insert(state.key_values.clone(), state)
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }

            self.partitioned_groups = Some(partitioned);
            self.using_partitioned = true;
        }

        Ok(())
    }
}
613
#[cfg(feature = "spill")]
impl PushOperator for SpillableAggregatePushOperator {
    /// Accumulates every selected row, routing it to the global state,
    /// the partitioned state, or the plain hash map, then gives the
    /// operator a chance to spill. Emits nothing until `finalize`.
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation - same as non-spillable
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*): count the row unconditionally.
                            acc.count += 1;
                        }
                    }
                }
            } else if self.using_partitioned {
                // Use partitioned state. Note: this path keys groups by
                // the actual key values (not the hashed GroupKey), so no
                // hash-collision merging can occur here.
                if let Some(ref mut partitioned) = self.partitioned_groups {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    // Borrow `aggregates` separately so the closure does
                    // not capture `self` while `partitioned` is borrowed.
                    let aggregates = &self.aggregates;
                    let state = partitioned
                        .get_or_insert_with(key_values.clone(), || GroupState {
                            key_values: key_values.clone(),
                            accumulators: aggregates.iter().map(|_| Accumulator::new()).collect(),
                        })
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;

                    for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*)
                            acc.count += 1;
                        }
                    }
                }
            } else {
                // Use regular hash map (spilling configured but not yet
                // triggered, or no spill manager at all).
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    // Capture the raw key values for the output columns.
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    if let Some(col) = expr.column {
                        if let Some(c) = chunk.column(col)
                            && let Some(val) = c.get_value(row)
                        {
                            acc.add(&val);
                        }
                    } else {
                        // COUNT(*)
                        acc.count += 1;
                    }
                }
            }
        }

        // Check if we need to spill (runs once per pushed chunk).
        self.maybe_spill()?;

        Ok(true)
    }

    /// Emits the aggregated output: group-key columns first, then one
    /// column per aggregate expression.
    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output
            if let Some(ref mut accumulators) = self.global_state {
                for (i, (acc, expr)) in accumulators.iter_mut().zip(&self.aggregates).enumerate() {
                    columns[i].push(acc.finalize(expr.function));
                }
            }
        } else if self.using_partitioned {
            // Drain partitioned state (presumably reloads spilled
            // partitions — see PartitionedState::drain_all).
            if let Some(ref mut partitioned) = self.partitioned_groups {
                let groups = partitioned
                    .drain_all()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;

                for (_key, mut state) in groups {
                    // Output group key columns
                    for (i, val) in state.key_values.iter().enumerate() {
                        columns[i].push(val.clone());
                    }

                    // Output aggregate results
                    for (i, (acc, expr)) in state
                        .accumulators
                        .iter_mut()
                        .zip(&self.aggregates)
                        .enumerate()
                    {
                        columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                    }
                }
            }
        } else {
            // Group by using regular hash map - one row per group
            for state in self.groups.values_mut() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, (acc, expr)) in state
                    .accumulators
                    .iter_mut()
                    .zip(&self.aggregates)
                    .enumerate()
                {
                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                }
            }
        }

        // Skip the sink entirely when no output rows were produced.
        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "SpillableAggregatePush"
    }
}
785
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::sink::CollectorSink;

    /// Builds a single-column Int64 chunk from `values`.
    fn create_test_chunk(values: &[i64]) -> DataChunk {
        let v: Vec<Value> = values.iter().map(|&i| Value::Int64(i)).collect();
        let vector = ValueVector::from_values(&v);
        DataChunk::new(vec![vector])
    }

    /// Builds a two-column Int64 chunk (typically: group key, value).
    fn create_two_column_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
        DataChunk::new(vec![
            ValueVector::from_values(&v1),
            ValueVector::from_values(&v2),
        ])
    }

    #[test]
    fn test_global_count() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }

    #[test]
    fn test_global_sum() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::sum(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Float64(15.0))
        );
    }

    #[test]
    fn test_global_min_max() {
        let mut agg =
            AggregatePushOperator::global(vec![AggregateExpr::min(0), AggregateExpr::max(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[3, 1, 4, 1, 5, 9, 2, 6]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(1))
        );
        assert_eq!(
            chunks[0].column(1).unwrap().get_value(0),
            Some(Value::Int64(9))
        );
    }

    #[test]
    fn test_group_by_sum() {
        // Group by column 0, sum column 1
        let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
        let mut sink = CollectorSink::new();

        // Group 1: 10, 20 (sum=30), Group 2: 30, 40 (sum=70)
        agg.push(
            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
            &mut sink,
        )
        .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks[0].len(), 2); // 2 groups
    }

    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_no_spill() {
        // When threshold is not reached, should work like normal aggregate
        let mut agg = SpillableAggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)])
            .with_threshold(100);
        let mut sink = CollectorSink::new();

        agg.push(
            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
            &mut sink,
        )
        .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks[0].len(), 2); // 2 groups
    }

    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_with_spilling() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        // Set very low threshold to force spilling
        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::sum(1)],
            manager,
            3, // Spill after 3 groups
        );
        let mut sink = CollectorSink::new();

        // Create 10 different groups
        for i in 0..10 {
            let chunk = create_two_column_chunk(&[i], &[i * 10]);
            agg.push(chunk, &mut sink).unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 10); // 10 groups

        // Verify sums are correct. Panic on any non-Float64 cell so a
        // wrongly-typed column fails the test instead of being skipped.
        let mut sums: Vec<f64> = Vec::new();
        for i in 0..chunks[0].len() {
            match chunks[0].column(1).unwrap().get_value(i) {
                Some(Value::Float64(sum)) => sums.push(sum),
                other => panic!("expected Float64 sum at row {i}, got {other:?}"),
            }
        }
        sums.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(
            sums,
            vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]
        );
    }

    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_global() {
        // Global aggregation shouldn't be affected by spilling
        let mut agg = SpillableAggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }

    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_many_groups() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::count_star()],
            manager,
            10, // Very low threshold
        );
        let mut sink = CollectorSink::new();

        // Create 100 different groups
        for i in 0..100 {
            let chunk = create_test_chunk(&[i]);
            agg.push(chunk, &mut sink).unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 100); // 100 groups

        // Each group should have count = 1. Panic on any non-Int64 cell
        // so the assertion cannot be silently skipped.
        for i in 0..100 {
            match chunks[0].column(1).unwrap().get_value(i) {
                Some(Value::Int64(count)) => assert_eq!(count, 1),
                other => panic!("expected Int64 count at row {i}, got {other:?}"),
            }
        }
    }
}
992}