// grafeo_core/execution/operators/push/aggregate.rs
1//! Push-based aggregate operator (pipeline breaker).
2
3use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::operators::accumulator::{AggregateExpr, AggregateFunction};
6use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
7#[cfg(feature = "spill")]
8use crate::execution::spill::{PartitionedState, SpillManager};
9use crate::execution::vector::ValueVector;
10use grafeo_common::types::Value;
11use std::collections::HashMap;
12#[cfg(feature = "spill")]
13use std::io::{Read, Write};
14#[cfg(feature = "spill")]
15use std::sync::Arc;
16
/// Accumulator for aggregate state.
///
/// One accumulator tracks enough running state to answer COUNT, SUM, MIN,
/// MAX, AVG and FIRST in a single pass; `finalize` picks the relevant field.
#[derive(Debug, Clone, Default)]
struct Accumulator {
    /// Number of non-null values folded in via `add` (incremented directly
    /// for COUNT(*), where nulls are counted too).
    count: i64,
    /// Running sum of values convertible to f64; non-numeric values are skipped.
    sum: f64,
    /// Smallest value seen so far, if any.
    min: Option<Value>,
    /// Largest value seen so far, if any.
    max: Option<Value>,
    /// First non-null value seen, if any.
    first: Option<Value>,
}
26
27impl Accumulator {
28    fn new() -> Self {
29        Self {
30            count: 0,
31            sum: 0.0,
32            min: None,
33            max: None,
34            first: None,
35        }
36    }
37
38    fn add(&mut self, value: &Value) {
39        // Skip nulls for aggregates
40        if matches!(value, Value::Null) {
41            return;
42        }
43
44        self.count += 1;
45
46        // Sum (for numeric types)
47        if let Some(n) = value_to_f64(value) {
48            self.sum += n;
49        }
50
51        // Min
52        if self.min.is_none() || compare_for_min(&self.min, value) {
53            self.min = Some(value.clone());
54        }
55
56        // Max
57        if self.max.is_none() || compare_for_max(&self.max, value) {
58            self.max = Some(value.clone());
59        }
60
61        // First
62        if self.first.is_none() {
63            self.first = Some(value.clone());
64        }
65    }
66
67    fn finalize(&mut self, func: AggregateFunction) -> Value {
68        match func {
69            AggregateFunction::Count | AggregateFunction::CountNonNull => Value::Int64(self.count),
70            AggregateFunction::Sum => {
71                if self.count == 0 {
72                    Value::Null
73                } else {
74                    Value::Float64(self.sum)
75                }
76            }
77            AggregateFunction::Min => self.min.take().unwrap_or(Value::Null),
78            AggregateFunction::Max => self.max.take().unwrap_or(Value::Null),
79            AggregateFunction::Avg => {
80                if self.count == 0 {
81                    Value::Null
82                } else {
83                    Value::Float64(self.sum / self.count as f64)
84                }
85            }
86            AggregateFunction::First => self.first.take().unwrap_or(Value::Null),
87            // Advanced functions not supported in push-based accumulator;
88            // these are handled by the pull-based AggregateState instead.
89            AggregateFunction::Last
90            | AggregateFunction::Collect
91            | AggregateFunction::StdDev
92            | AggregateFunction::StdDevPop
93            | AggregateFunction::PercentileDisc
94            | AggregateFunction::PercentileCont => Value::Null,
95        }
96    }
97}
98
99use crate::execution::operators::value_utils::{
100    is_greater_than as compare_for_max, is_less_than as compare_for_min, value_to_f64,
101};
102
103/// Hash key for grouping.
104#[derive(Debug, Clone, PartialEq, Eq, Hash)]
105struct GroupKey(Vec<u64>);
106
107impl GroupKey {
108    fn from_row(chunk: &DataChunk, row: usize, group_by: &[usize]) -> Self {
109        let hashes: Vec<u64> = group_by
110            .iter()
111            .map(|&col| {
112                chunk
113                    .column(col)
114                    .and_then(|c| c.get_value(row))
115                    .map_or(0, |v| hash_value(&v))
116            })
117            .collect();
118        Self(hashes)
119    }
120}
121
122fn hash_value(value: &Value) -> u64 {
123    use std::collections::hash_map::DefaultHasher;
124    use std::hash::{Hash, Hasher};
125
126    let mut hasher = DefaultHasher::new();
127    match value {
128        Value::Null => 0u8.hash(&mut hasher),
129        Value::Bool(b) => b.hash(&mut hasher),
130        Value::Int64(i) => i.hash(&mut hasher),
131        Value::Float64(f) => f.to_bits().hash(&mut hasher),
132        Value::String(s) => s.hash(&mut hasher),
133        _ => 0u8.hash(&mut hasher),
134    }
135    hasher.finish()
136}
137
/// Group state with key values and accumulators.
#[derive(Clone)]
struct GroupState {
    /// Original (un-hashed) group-key values, emitted as the leading output columns.
    key_values: Vec<Value>,
    /// One accumulator per aggregate expression, in expression order.
    accumulators: Vec<Accumulator>,
}
144
/// Push-based aggregate operator.
///
/// This is a pipeline breaker that accumulates all input, groups by key,
/// and produces aggregated output in the finalize phase.
pub struct AggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Group states by hash key (used only when `group_by` is non-empty).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulators, one per aggregate (used when `group_by` is empty).
    global_state: Option<Vec<Accumulator>>,
}
159
160impl AggregatePushOperator {
161    /// Create a new aggregate operator.
162    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
163        let global_state = if group_by.is_empty() {
164            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
165        } else {
166            None
167        };
168
169        Self {
170            group_by,
171            aggregates,
172            groups: HashMap::new(),
173            global_state,
174        }
175    }
176
177    /// Create a simple global aggregate (no GROUP BY).
178    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
179        Self::new(Vec::new(), aggregates)
180    }
181}
182
183impl PushOperator for AggregatePushOperator {
184    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
185        if chunk.is_empty() {
186            return Ok(true);
187        }
188
189        for row in chunk.selected_indices() {
190            if self.group_by.is_empty() {
191                // Global aggregation
192                if let Some(ref mut accumulators) = self.global_state {
193                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
194                        if let Some(col) = expr.column {
195                            if let Some(c) = chunk.column(col)
196                                && let Some(val) = c.get_value(row)
197                            {
198                                acc.add(&val);
199                            }
200                        } else {
201                            // COUNT(*)
202                            acc.count += 1;
203                        }
204                    }
205                }
206            } else {
207                // Group by aggregation
208                let key = GroupKey::from_row(&chunk, row, &self.group_by);
209
210                let state = self.groups.entry(key).or_insert_with(|| {
211                    let key_values: Vec<Value> = self
212                        .group_by
213                        .iter()
214                        .map(|&col| {
215                            chunk
216                                .column(col)
217                                .and_then(|c| c.get_value(row))
218                                .unwrap_or(Value::Null)
219                        })
220                        .collect();
221
222                    GroupState {
223                        key_values,
224                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
225                    }
226                });
227
228                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
229                    if let Some(col) = expr.column {
230                        if let Some(c) = chunk.column(col)
231                            && let Some(val) = c.get_value(row)
232                        {
233                            acc.add(&val);
234                        }
235                    } else {
236                        // COUNT(*)
237                        acc.count += 1;
238                    }
239                }
240            }
241        }
242
243        Ok(true)
244    }
245
246    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
247        let num_output_cols = self.group_by.len() + self.aggregates.len();
248        let mut columns: Vec<ValueVector> =
249            (0..num_output_cols).map(|_| ValueVector::new()).collect();
250
251        if self.group_by.is_empty() {
252            // Global aggregation - single row output
253            if let Some(ref mut accumulators) = self.global_state {
254                for (i, (acc, expr)) in accumulators.iter_mut().zip(&self.aggregates).enumerate() {
255                    columns[i].push(acc.finalize(expr.function));
256                }
257            }
258        } else {
259            // Group by - one row per group
260            for state in self.groups.values_mut() {
261                // Output group key columns
262                for (i, val) in state.key_values.iter().enumerate() {
263                    columns[i].push(val.clone());
264                }
265
266                // Output aggregate results
267                for (i, (acc, expr)) in state
268                    .accumulators
269                    .iter_mut()
270                    .zip(&self.aggregates)
271                    .enumerate()
272                {
273                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
274                }
275            }
276        }
277
278        if !columns.is_empty() && !columns[0].is_empty() {
279            let chunk = DataChunk::new(columns);
280            sink.consume(chunk)?;
281        }
282
283        Ok(())
284    }
285
286    fn preferred_chunk_size(&self) -> ChunkSizeHint {
287        ChunkSizeHint::Default
288    }
289
290    fn name(&self) -> &'static str {
291        "AggregatePush"
292    }
293}
294
/// Default spill threshold for aggregates (number of groups).
///
/// Once this many groups are resident in memory, `maybe_spill` either spills
/// the largest partition or migrates the plain hash map to partitioned mode.
#[cfg(feature = "spill")]
pub const DEFAULT_AGGREGATE_SPILL_THRESHOLD: usize = 50_000;
298
/// Serializes a GroupState to bytes.
///
/// Layout: key count (u64 LE) + each key value, then accumulator count
/// (u64 LE) followed by each accumulator as count (i64 LE), sum bit pattern
/// (u64 LE), and three flag-prefixed optional values (min, max, first).
#[cfg(feature = "spill")]
fn serialize_group_state(state: &GroupState, w: &mut dyn Write) -> std::io::Result<()> {
    use crate::execution::spill::serialize_value;

    // Writes a one-byte presence flag, then the value when present.
    fn write_opt(opt: &Option<Value>, w: &mut dyn Write) -> std::io::Result<()> {
        use crate::execution::spill::serialize_value;
        w.write_all(&[opt.is_some() as u8])?;
        if let Some(v) = opt {
            serialize_value(v, w)?;
        }
        Ok(())
    }

    // Key values: length prefix, then each value.
    w.write_all(&(state.key_values.len() as u64).to_le_bytes())?;
    for val in &state.key_values {
        serialize_value(val, w)?;
    }

    // Accumulators: length prefix, then fixed fields plus optionals.
    w.write_all(&(state.accumulators.len() as u64).to_le_bytes())?;
    for acc in &state.accumulators {
        w.write_all(&acc.count.to_le_bytes())?;
        w.write_all(&acc.sum.to_bits().to_le_bytes())?;
        write_opt(&acc.min, w)?;
        write_opt(&acc.max, w)?;
        write_opt(&acc.first, w)?;
    }

    Ok(())
}
340
/// Deserializes a GroupState from bytes (inverse of `serialize_group_state`).
#[cfg(feature = "spill")]
fn deserialize_group_state(r: &mut dyn Read) -> std::io::Result<GroupState> {
    use crate::execution::spill::deserialize_value;

    // Reads a little-endian 8-byte field (lengths, counts, f64 bit patterns).
    fn read_u64(r: &mut dyn Read) -> std::io::Result<u64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(u64::from_le_bytes(buf))
    }

    // Reads a one-byte presence flag, then the value when the flag is set.
    fn read_opt(r: &mut dyn Read) -> std::io::Result<Option<Value>> {
        use crate::execution::spill::deserialize_value;
        let mut flag = [0u8; 1];
        r.read_exact(&mut flag)?;
        if flag[0] != 0 {
            Ok(Some(deserialize_value(r)?))
        } else {
            Ok(None)
        }
    }

    // Key values.
    let num_keys = read_u64(r)? as usize;
    let mut key_values = Vec::with_capacity(num_keys);
    for _ in 0..num_keys {
        key_values.push(deserialize_value(r)?);
    }

    // Accumulators.
    let num_accumulators = read_u64(r)? as usize;
    let mut accumulators = Vec::with_capacity(num_accumulators);
    for _ in 0..num_accumulators {
        // `count` was written as i64 LE; reinterpreting the same 8 bytes.
        let count = read_u64(r)? as i64;
        let sum = f64::from_bits(read_u64(r)?);
        let min = read_opt(r)?;
        let max = read_opt(r)?;
        let first = read_opt(r)?;

        accumulators.push(Accumulator {
            count,
            sum,
            min,
            max,
            first,
        });
    }

    Ok(GroupState {
        key_values,
        accumulators,
    })
}
408
/// Push-based aggregate operator with spilling support.
///
/// Uses partitioned hash table that can spill cold partitions to disk
/// when memory pressure is high.
///
/// Operates in one of two modes: a plain in-memory hash map (`groups`), or
/// partitioned state (`partitioned_groups`) once `using_partitioned` is set —
/// either immediately via `with_spilling`, or lazily when the group count
/// crosses `spill_threshold` and a spill manager is configured.
#[cfg(feature = "spill")]
pub struct SpillableAggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Spill manager (None = no spilling).
    spill_manager: Option<Arc<SpillManager>>,
    /// Partitioned groups (used when spilling is enabled).
    partitioned_groups: Option<PartitionedState<GroupState>>,
    /// Non-partitioned groups (used when spilling is disabled).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY).
    global_state: Option<Vec<Accumulator>>,
    /// Spill threshold (number of groups).
    spill_threshold: usize,
    /// Whether we've switched to partitioned mode.
    using_partitioned: bool,
}
432
#[cfg(feature = "spill")]
impl SpillableAggregatePushOperator {
    /// Number of hash partitions for the spillable state.
    ///
    /// Was a magic `256` repeated at two construction sites; a single
    /// constant keeps `with_spilling` and `maybe_spill` consistent.
    const NUM_PARTITIONS: usize = 256;

    /// Builds the global accumulator set when no GROUP BY is present.
    fn make_global_state(
        group_by: &[usize],
        aggregates: &[AggregateExpr],
    ) -> Option<Vec<Accumulator>> {
        group_by
            .is_empty()
            .then(|| aggregates.iter().map(|_| Accumulator::new()).collect())
    }

    /// Builds a fresh partitioned state bound to `manager`.
    fn make_partitioned(manager: &Arc<SpillManager>) -> PartitionedState<GroupState> {
        PartitionedState::new(
            Arc::clone(manager),
            Self::NUM_PARTITIONS,
            serialize_group_state,
            deserialize_group_state,
        )
    }

    /// Create a new spillable aggregate operator (spilling stays disabled
    /// because no spill manager is attached).
    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
        let global_state = Self::make_global_state(&group_by, &aggregates);

        Self {
            group_by,
            aggregates,
            spill_manager: None,
            partitioned_groups: None,
            groups: HashMap::new(),
            global_state,
            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
            using_partitioned: false,
        }
    }

    /// Create a spillable aggregate operator with spilling enabled; starts
    /// in partitioned mode immediately.
    pub fn with_spilling(
        group_by: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        manager: Arc<SpillManager>,
        threshold: usize,
    ) -> Self {
        let global_state = Self::make_global_state(&group_by, &aggregates);
        let partitioned = Self::make_partitioned(&manager);

        Self {
            group_by,
            aggregates,
            spill_manager: Some(manager),
            partitioned_groups: Some(partitioned),
            groups: HashMap::new(),
            global_state,
            spill_threshold: threshold,
            using_partitioned: true,
        }
    }

    /// Create a simple global aggregate (no GROUP BY).
    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
        Self::new(Vec::new(), aggregates)
    }

    /// Sets the spill threshold (number of in-memory groups).
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.spill_threshold = threshold;
        self
    }

    /// Spills or switches to partitioned mode once the threshold is hit.
    fn maybe_spill(&mut self) -> Result<(), OperatorError> {
        // Global aggregation holds O(#aggregates) state; nothing to spill.
        if self.global_state.is_some() {
            return Ok(());
        }

        if let Some(ref mut partitioned) = self.partitioned_groups {
            // Already partitioned: push the largest partition to disk when
            // total resident groups cross the threshold.
            if partitioned.total_size() >= self.spill_threshold {
                partitioned
                    .spill_largest()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }
        } else if self.groups.len() >= self.spill_threshold {
            // Threshold reached in plain hash-map mode: migrate everything
            // into partitioned state if a spill manager is configured.
            if let Some(ref manager) = self.spill_manager {
                let mut partitioned = Self::make_partitioned(manager);

                for (_key, state) in self.groups.drain() {
                    partitioned
                        .insert(state.key_values.clone(), state)
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;
                }

                self.partitioned_groups = Some(partitioned);
                self.using_partitioned = true;
            }
        }

        Ok(())
    }
}
538
#[cfg(feature = "spill")]
impl PushOperator for SpillableAggregatePushOperator {
    /// Accumulates a chunk, dispatching per row to one of three modes:
    /// global accumulators (no GROUP BY), partitioned state (spilling), or
    /// the plain hash map. Checks the spill threshold once per chunk.
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation - same as non-spillable
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*): count every row, nulls included.
                            acc.count += 1;
                        }
                    }
                }
            } else if self.using_partitioned {
                // Use partitioned state
                if let Some(ref mut partitioned) = self.partitioned_groups {
                    // Partitioned lookups key on the actual values, not hashes.
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    // Rebind so the closure borrows only `aggregates`, not `self`.
                    let aggregates = &self.aggregates;
                    let state = partitioned
                        .get_or_insert_with(key_values.clone(), || GroupState {
                            key_values: key_values.clone(),
                            accumulators: aggregates.iter().map(|_| Accumulator::new()).collect(),
                        })
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;

                    for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*)
                            acc.count += 1;
                        }
                    }
                }
            } else {
                // Use regular hash map
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    if let Some(col) = expr.column {
                        if let Some(c) = chunk.column(col)
                            && let Some(val) = c.get_value(row)
                        {
                            acc.add(&val);
                        }
                    } else {
                        // COUNT(*)
                        acc.count += 1;
                    }
                }
            }
        }

        // Check if we need to spill
        self.maybe_spill()?;

        Ok(true)
    }

    /// Emits one row (global) or one row per group, reading spilled
    /// partitions back from disk via `drain_all` when in partitioned mode.
    /// Group output order is unspecified.
    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output
            if let Some(ref mut accumulators) = self.global_state {
                for (i, (acc, expr)) in accumulators.iter_mut().zip(&self.aggregates).enumerate() {
                    columns[i].push(acc.finalize(expr.function));
                }
            }
        } else if self.using_partitioned {
            // Drain partitioned state (restores any partitions spilled to disk)
            if let Some(ref mut partitioned) = self.partitioned_groups {
                let groups = partitioned
                    .drain_all()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;

                for (_key, mut state) in groups {
                    // Output group key columns
                    for (i, val) in state.key_values.iter().enumerate() {
                        columns[i].push(val.clone());
                    }

                    // Output aggregate results
                    for (i, (acc, expr)) in state
                        .accumulators
                        .iter_mut()
                        .zip(&self.aggregates)
                        .enumerate()
                    {
                        columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                    }
                }
            }
        } else {
            // Group by using regular hash map - one row per group
            for state in self.groups.values_mut() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, (acc, expr)) in state
                    .accumulators
                    .iter_mut()
                    .zip(&self.aggregates)
                    .enumerate()
                {
                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                }
            }
        }

        // Suppress an all-empty chunk (e.g. grouped aggregation on no input).
        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "SpillableAggregatePush"
    }
}
710
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::sink::CollectorSink;

    // Builds a single-column chunk of Int64 values.
    fn create_test_chunk(values: &[i64]) -> DataChunk {
        let v: Vec<Value> = values.iter().map(|&i| Value::Int64(i)).collect();
        let vector = ValueVector::from_values(&v);
        DataChunk::new(vec![vector])
    }

    // Builds a two-column chunk of Int64 values (e.g. key column + value column).
    fn create_two_column_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
        DataChunk::new(vec![
            ValueVector::from_values(&v1),
            ValueVector::from_values(&v2),
        ])
    }

    #[test]
    fn test_global_count() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }

    #[test]
    fn test_global_sum() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::sum(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        // SUM is always emitted as Float64, even for integer input.
        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Float64(15.0))
        );
    }

    #[test]
    fn test_global_min_max() {
        let mut agg =
            AggregatePushOperator::global(vec![AggregateExpr::min(0), AggregateExpr::max(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[3, 1, 4, 1, 5, 9, 2, 6]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        // MIN/MAX preserve the original value type (Int64 here).
        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(1))
        );
        assert_eq!(
            chunks[0].column(1).unwrap().get_value(0),
            Some(Value::Int64(9))
        );
    }

    #[test]
    fn test_group_by_sum() {
        // Group by column 0, sum column 1
        let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
        let mut sink = CollectorSink::new();

        // Group 1: 10, 20 (sum=30), Group 2: 30, 40 (sum=70)
        agg.push(
            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
            &mut sink,
        )
        .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks[0].len(), 2); // 2 groups
    }

    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_no_spill() {
        // When threshold is not reached, should work like normal aggregate
        let mut agg = SpillableAggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)])
            .with_threshold(100);
        let mut sink = CollectorSink::new();

        agg.push(
            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
            &mut sink,
        )
        .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks[0].len(), 2); // 2 groups
    }

    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_with_spilling() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        // Set very low threshold to force spilling
        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::sum(1)],
            manager,
            3, // Spill after 3 groups
        );
        let mut sink = CollectorSink::new();

        // Create 10 different groups
        for i in 0..10 {
            let chunk = create_two_column_chunk(&[i], &[i * 10]);
            agg.push(chunk, &mut sink).unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 10); // 10 groups

        // Verify sums are correct; sorted because group output order is unspecified.
        let mut sums: Vec<f64> = Vec::new();
        for i in 0..chunks[0].len() {
            if let Some(Value::Float64(sum)) = chunks[0].column(1).unwrap().get_value(i) {
                sums.push(sum);
            }
        }
        sums.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(
            sums,
            vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]
        );
    }

    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_global() {
        // Global aggregation shouldn't be affected by spilling
        let mut agg = SpillableAggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }

    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_many_groups() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::count_star()],
            manager,
            10, // Very low threshold
        );
        let mut sink = CollectorSink::new();

        // Create 100 different groups
        for i in 0..100 {
            let chunk = create_test_chunk(&[i]);
            agg.push(chunk, &mut sink).unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 100); // 100 groups

        // Each group should have count = 1
        for i in 0..100 {
            if let Some(Value::Int64(count)) = chunks[0].column(1).unwrap().get_value(i) {
                assert_eq!(count, 1);
            }
        }
    }
}