graphos_core/execution/operators/push/aggregate.rs

//! Push-based aggregate operator (pipeline breaker).
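//!
//! Input chunks are consumed eagerly in `push` (nothing reaches the sink
//! during that phase); the grouped results are emitted in `finalize`.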

use crate::execution::chunk::DataChunk;
use crate::execution::operators::OperatorError;
use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
use crate::execution::spill::{PartitionedState, SpillManager};
use crate::execution::vector::ValueVector;
use graphos_common::types::Value;
use std::collections::HashMap;
use std::io::{Read, Write};
use std::sync::Arc;

/// Aggregation function type.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AggregateFunction {
    /// Count rows or non-null values.
    Count,
    /// Sum of values.
    Sum,
    /// Minimum value.
    Min,
    /// Maximum value.
    Max,
    /// Average value.
    Avg,
    /// First value in group.
    First,
}

/// Aggregate expression.
#[derive(Debug, Clone)]
pub struct AggregateExpr {
    /// The aggregate function.
    pub function: AggregateFunction,
    /// Column index to aggregate (None for COUNT(*)).
    pub column: Option<usize>,
    /// Whether DISTINCT applies.
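    /// (Set to `false` by every constructor below and not yet consulted
    /// by the operators in this module.)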
    pub distinct: bool,
}

impl AggregateExpr {
    /// Create a COUNT(*) expression.
    pub fn count_star() -> Self {
        Self {
            function: AggregateFunction::Count,
            column: None,
            distinct: false,
        }
    }

    /// Create a COUNT(column) expression.
    pub fn count(column: usize) -> Self {
        Self {
            function: AggregateFunction::Count,
            column: Some(column),
            distinct: false,
        }
    }

    /// Create a SUM(column) expression.
    pub fn sum(column: usize) -> Self {
        Self {
            function: AggregateFunction::Sum,
            column: Some(column),
            distinct: false,
        }
    }

    /// Create a MIN(column) expression.
    pub fn min(column: usize) -> Self {
        Self {
            function: AggregateFunction::Min,
            column: Some(column),
            distinct: false,
        }
    }

    /// Create a MAX(column) expression.
    pub fn max(column: usize) -> Self {
        Self {
            function: AggregateFunction::Max,
            column: Some(column),
            distinct: false,
        }
    }

    /// Create an AVG(column) expression.
    pub fn avg(column: usize) -> Self {
        Self {
            function: AggregateFunction::Avg,
            column: Some(column),
            distinct: false,
        }
    }
}

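// Example (sketch): the aggregate list for `SELECT key, COUNT(*), SUM(val)`,
// with `val` in column 1 of the input chunks, would be
// `vec![AggregateExpr::count_star(), AggregateExpr::sum(1)]`.
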
/// Accumulator for aggregate state.
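///
/// Every statistic is updated for each non-null value regardless of which
/// function is later finalized; this keeps `add` simple at the cost of
/// tracking min/max/first candidates that may go unused.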
#[derive(Debug, Clone, Default)]
struct Accumulator {
    count: i64,
    sum: f64,
    min: Option<Value>,
    max: Option<Value>,
    first: Option<Value>,
}

impl Accumulator {
    fn new() -> Self {
        Self {
            count: 0,
            sum: 0.0,
            min: None,
            max: None,
            first: None,
        }
    }

    fn add(&mut self, value: &Value) {
        // Skip nulls for aggregates
        if matches!(value, Value::Null) {
            return;
        }

        self.count += 1;

        // Sum (for numeric types)
        if let Some(n) = value_to_f64(value) {
            self.sum += n;
        }

        // Min
        if self.min.is_none() || compare_for_min(&self.min, value) {
            self.min = Some(value.clone());
        }

        // Max
        if self.max.is_none() || compare_for_max(&self.max, value) {
            self.max = Some(value.clone());
        }

        // First
        if self.first.is_none() {
            self.first = Some(value.clone());
        }
    }

    fn finalize(&self, func: AggregateFunction) -> Value {
        match func {
            AggregateFunction::Count => Value::Int64(self.count),
            AggregateFunction::Sum => {
                if self.count == 0 {
                    Value::Null
                } else {
                    Value::Float64(self.sum)
                }
            }
            AggregateFunction::Min => self.min.clone().unwrap_or(Value::Null),
            AggregateFunction::Max => self.max.clone().unwrap_or(Value::Null),
            AggregateFunction::Avg => {
                if self.count == 0 {
                    Value::Null
                } else {
                    Value::Float64(self.sum / self.count as f64)
                }
            }
            AggregateFunction::First => self.first.clone().unwrap_or(Value::Null),
        }
    }
}

fn value_to_f64(value: &Value) -> Option<f64> {
    match value {
        Value::Int64(i) => Some(*i as f64),
        Value::Float64(f) => Some(*f),
        _ => None,
    }
}

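// NOTE: ordering is only defined within a single type; a mixed-type pair
// (e.g. Int64 vs. Float64) falls through to `false` and never replaces the
// current extreme.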
fn compare_for_min(current: &Option<Value>, new: &Value) -> bool {
    match (current, new) {
        (None, _) => true,
        (Some(Value::Int64(a)), Value::Int64(b)) => b < a,
        (Some(Value::Float64(a)), Value::Float64(b)) => b < a,
        (Some(Value::String(a)), Value::String(b)) => b < a,
        _ => false,
    }
}

fn compare_for_max(current: &Option<Value>, new: &Value) -> bool {
    match (current, new) {
        (None, _) => true,
        (Some(Value::Int64(a)), Value::Int64(b)) => b > a,
        (Some(Value::Float64(a)), Value::Float64(b)) => b > a,
        (Some(Value::String(a)), Value::String(b)) => b > a,
        _ => false,
    }
}

/// Hash key for grouping.
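///
/// Only the per-column hashes are stored, so two distinct keys whose
/// columns all collide would be merged into one group (the emitted key
/// values come from the first row seen). Note also that `hash_value`
/// below maps Null and every unhandled variant to the same hash.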
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GroupKey(Vec<u64>);

impl GroupKey {
    fn from_row(chunk: &DataChunk, row: usize, group_by: &[usize]) -> Self {
        let hashes: Vec<u64> = group_by
            .iter()
            .map(|&col| {
                chunk
                    .column(col)
                    .and_then(|c| c.get_value(row))
                    .map(|v| hash_value(&v))
                    .unwrap_or(0)
            })
            .collect();
        Self(hashes)
    }
}

fn hash_value(value: &Value) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    match value {
        Value::Null => 0u8.hash(&mut hasher),
        Value::Bool(b) => b.hash(&mut hasher),
        Value::Int64(i) => i.hash(&mut hasher),
        Value::Float64(f) => f.to_bits().hash(&mut hasher),
        Value::String(s) => s.hash(&mut hasher),
        _ => 0u8.hash(&mut hasher),
    }
    hasher.finish()
}

/// Group state with key values and accumulators.
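///
/// `key_values` holds the concrete group-by values (captured from the
/// group's first row) for output; `accumulators` runs parallel to the
/// operator's `aggregates` list.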
#[derive(Clone)]
struct GroupState {
    key_values: Vec<Value>,
    accumulators: Vec<Accumulator>,
}

/// Push-based aggregate operator.
///
/// This is a pipeline breaker that accumulates all input, groups by key,
/// and produces aggregated output in the finalize phase.
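///
/// # Example
///
/// A minimal usage sketch (doc-test `ignore`d: `CollectorSink` is the
/// crate-internal sink used by the tests below, and `chunk` stands for
/// any input [`DataChunk`]):
///
/// ```ignore
/// let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
/// let mut sink = CollectorSink::new();
/// agg.push(chunk, &mut sink)?; // accumulates only; no output yet
/// agg.finalize(&mut sink)?;    // emits one row per group
/// ```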
pub struct AggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Group states by hash key.
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY).
    global_state: Option<Vec<Accumulator>>,
}

impl AggregatePushOperator {
    /// Create a new aggregate operator.
    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
        } else {
            None
        };

        Self {
            group_by,
            aggregates,
            groups: HashMap::new(),
            global_state,
        }
    }

    /// Create a simple global aggregate (no GROUP BY).
    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
        Self::new(Vec::new(), aggregates)
    }
}

impl PushOperator for AggregatePushOperator {
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col) {
                                if let Some(val) = c.get_value(row) {
                                    acc.add(&val);
                                }
                            }
                        } else {
                            // COUNT(*)
                            acc.count += 1;
                        }
                    }
                }
            } else {
                // Group by aggregation
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    if let Some(col) = expr.column {
                        if let Some(c) = chunk.column(col) {
                            if let Some(val) = c.get_value(row) {
                                acc.add(&val);
                            }
                        }
                    } else {
                        // COUNT(*)
                        acc.count += 1;
                    }
                }
            }
        }

        Ok(true)
    }

    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output
            if let Some(ref accumulators) = self.global_state {
                for (i, (acc, expr)) in accumulators.iter().zip(&self.aggregates).enumerate() {
                    columns[i].push(acc.finalize(expr.function));
                }
            }
        } else {
            // Group by - one row per group
            for state in self.groups.values() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, (acc, expr)) in state.accumulators.iter().zip(&self.aggregates).enumerate()
                {
                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                }
            }
        }

        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "AggregatePush"
    }
}

/// Default spill threshold for aggregates (number of groups).
pub const DEFAULT_AGGREGATE_SPILL_THRESHOLD: usize = 50_000;

/// Serializes a GroupState to bytes.
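///
/// Layout (integers little-endian): a `u64` key count followed by each key
/// `Value`; a `u64` accumulator count; then per accumulator an `i64` count,
/// the `f64` sum as raw bits, and the min/max/first `Option<Value>`s, each
/// prefixed by a one-byte presence flag.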
fn serialize_group_state(state: &GroupState, w: &mut dyn Write) -> std::io::Result<()> {
    use crate::execution::spill::serialize_value;

    // Write key values
    w.write_all(&(state.key_values.len() as u64).to_le_bytes())?;
    for val in &state.key_values {
        serialize_value(val, w)?;
    }

    // Write accumulators
    w.write_all(&(state.accumulators.len() as u64).to_le_bytes())?;
    for acc in &state.accumulators {
        w.write_all(&acc.count.to_le_bytes())?;
        w.write_all(&acc.sum.to_bits().to_le_bytes())?;

        // Min
        let has_min = acc.min.is_some();
        w.write_all(&[has_min as u8])?;
        if let Some(ref v) = acc.min {
            serialize_value(v, w)?;
        }

        // Max
        let has_max = acc.max.is_some();
        w.write_all(&[has_max as u8])?;
        if let Some(ref v) = acc.max {
            serialize_value(v, w)?;
        }

        // First
        let has_first = acc.first.is_some();
        w.write_all(&[has_first as u8])?;
        if let Some(ref v) = acc.first {
            serialize_value(v, w)?;
        }
    }

    Ok(())
}

/// Deserializes a GroupState from bytes.
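///
/// Must mirror the byte layout written by `serialize_group_state` exactly.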
fn deserialize_group_state(r: &mut dyn Read) -> std::io::Result<GroupState> {
    use crate::execution::spill::deserialize_value;

    // Read key values
    let mut len_buf = [0u8; 8];
    r.read_exact(&mut len_buf)?;
    let num_keys = u64::from_le_bytes(len_buf) as usize;

    let mut key_values = Vec::with_capacity(num_keys);
    for _ in 0..num_keys {
        key_values.push(deserialize_value(r)?);
    }

    // Read accumulators
    r.read_exact(&mut len_buf)?;
    let num_accumulators = u64::from_le_bytes(len_buf) as usize;

    let mut accumulators = Vec::with_capacity(num_accumulators);
    for _ in 0..num_accumulators {
        let mut count_buf = [0u8; 8];
        r.read_exact(&mut count_buf)?;
        let count = i64::from_le_bytes(count_buf);

        r.read_exact(&mut count_buf)?;
        let sum = f64::from_bits(u64::from_le_bytes(count_buf));

        // Min
        let mut flag_buf = [0u8; 1];
        r.read_exact(&mut flag_buf)?;
        let min = if flag_buf[0] != 0 {
            Some(deserialize_value(r)?)
        } else {
            None
        };

        // Max
        r.read_exact(&mut flag_buf)?;
        let max = if flag_buf[0] != 0 {
            Some(deserialize_value(r)?)
        } else {
            None
        };

        // First
        r.read_exact(&mut flag_buf)?;
        let first = if flag_buf[0] != 0 {
            Some(deserialize_value(r)?)
        } else {
            None
        };

        accumulators.push(Accumulator {
            count,
            sum,
            min,
            max,
            first,
        });
    }

    Ok(GroupState {
        key_values,
        accumulators,
    })
}

/// Push-based aggregate operator with spilling support.
///
/// Uses a partitioned hash table that can spill cold partitions to disk
/// when memory pressure is high.
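///
/// # Example
///
/// A sketch of the spill-enabled construction (doc-test `ignore`d; the
/// temp-dir setup mirrors the tests at the bottom of this file):
///
/// ```ignore
/// let manager = Arc::new(SpillManager::new(temp_dir.path())?);
/// let mut agg = SpillableAggregatePushOperator::with_spilling(
///     vec![0],                           // group by column 0
///     vec![AggregateExpr::count_star()],
///     manager,
///     DEFAULT_AGGREGATE_SPILL_THRESHOLD,
/// );
/// ```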
pub struct SpillableAggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Spill manager (None = no spilling).
    spill_manager: Option<Arc<SpillManager>>,
    /// Partitioned groups (used when spilling is enabled).
    partitioned_groups: Option<PartitionedState<GroupState>>,
    /// Non-partitioned groups (used when spilling is disabled).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY).
    global_state: Option<Vec<Accumulator>>,
    /// Spill threshold (number of groups).
    spill_threshold: usize,
    /// Whether we've switched to partitioned mode.
    using_partitioned: bool,
}

impl SpillableAggregatePushOperator {
    /// Create a new spillable aggregate operator.
    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
        } else {
            None
        };

        Self {
            group_by,
            aggregates,
            spill_manager: None,
            partitioned_groups: None,
            groups: HashMap::new(),
            global_state,
            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
            using_partitioned: false,
        }
    }

    /// Create a spillable aggregate operator with spilling enabled.
    pub fn with_spilling(
        group_by: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        manager: Arc<SpillManager>,
        threshold: usize,
    ) -> Self {
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
        } else {
            None
        };

        let partitioned = PartitionedState::new(
            Arc::clone(&manager),
            256, // Number of partitions
            serialize_group_state,
            deserialize_group_state,
        );

        Self {
            group_by,
            aggregates,
            spill_manager: Some(manager),
            partitioned_groups: Some(partitioned),
            groups: HashMap::new(),
            global_state,
            spill_threshold: threshold,
            using_partitioned: true,
        }
    }

    /// Create a simple global aggregate (no GROUP BY).
    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
        Self::new(Vec::new(), aggregates)
    }

    /// Sets the spill threshold (in number of groups).
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.spill_threshold = threshold;
        self
    }

    /// Spills the largest partition, or switches to partitioned mode, once
    /// the group count reaches the threshold.
    fn maybe_spill(&mut self) -> Result<(), OperatorError> {
        if self.global_state.is_some() {
            // Global aggregation doesn't need spilling
            return Ok(());
        }

        // If using partitioned state, check if we need to spill
        if let Some(ref mut partitioned) = self.partitioned_groups {
            if partitioned.total_size() >= self.spill_threshold {
                partitioned
                    .spill_largest()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }
        } else if self.groups.len() >= self.spill_threshold {
            // Not using partitioned state yet, but reached threshold
            // If spilling is configured, switch to partitioned mode
            if let Some(ref manager) = self.spill_manager {
                let mut partitioned = PartitionedState::new(
                    Arc::clone(manager),
                    256,
                    serialize_group_state,
                    deserialize_group_state,
                );

                // Move existing groups to partitioned state
                for (_key, state) in self.groups.drain() {
                    partitioned
                        .insert(state.key_values.clone(), state)
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;
                }

                self.partitioned_groups = Some(partitioned);
                self.using_partitioned = true;
            }
        }

        Ok(())
    }
}

impl PushOperator for SpillableAggregatePushOperator {
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation - same as non-spillable
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col) {
                                if let Some(val) = c.get_value(row) {
                                    acc.add(&val);
                                }
                            }
                        } else {
                            acc.count += 1;
                        }
                    }
                }
            } else if self.using_partitioned {
                // Use partitioned state
                if let Some(ref mut partitioned) = self.partitioned_groups {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    let aggregates = &self.aggregates;
                    let state = partitioned
                        .get_or_insert_with(key_values.clone(), || GroupState {
                            key_values: key_values.clone(),
                            accumulators: aggregates.iter().map(|_| Accumulator::new()).collect(),
                        })
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;

                    for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col) {
                                if let Some(val) = c.get_value(row) {
                                    acc.add(&val);
                                }
                            }
                        } else {
                            acc.count += 1;
                        }
                    }
                }
            } else {
                // Use regular hash map
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    if let Some(col) = expr.column {
                        if let Some(c) = chunk.column(col) {
                            if let Some(val) = c.get_value(row) {
                                acc.add(&val);
                            }
                        }
                    } else {
                        acc.count += 1;
                    }
                }
            }
        }

        // Check if we need to spill
        self.maybe_spill()?;

        Ok(true)
    }

    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output
            if let Some(ref accumulators) = self.global_state {
                for (i, (acc, expr)) in accumulators.iter().zip(&self.aggregates).enumerate() {
                    columns[i].push(acc.finalize(expr.function));
                }
            }
        } else if self.using_partitioned {
            // Drain partitioned state
            if let Some(ref mut partitioned) = self.partitioned_groups {
                let groups = partitioned
                    .drain_all()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;

                for (_key, state) in groups {
                    // Output group key columns
                    for (i, val) in state.key_values.iter().enumerate() {
                        columns[i].push(val.clone());
                    }

                    // Output aggregate results
                    for (i, (acc, expr)) in
                        state.accumulators.iter().zip(&self.aggregates).enumerate()
                    {
                        columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                    }
                }
            }
        } else {
            // Group by using regular hash map - one row per group
            for state in self.groups.values() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, (acc, expr)) in state.accumulators.iter().zip(&self.aggregates).enumerate()
                {
                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                }
            }
        }

        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "SpillableAggregatePush"
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::sink::CollectorSink;

    fn create_test_chunk(values: &[i64]) -> DataChunk {
        let v: Vec<Value> = values.iter().map(|&i| Value::Int64(i)).collect();
        let vector = ValueVector::from_values(&v);
        DataChunk::new(vec![vector])
    }

    fn create_two_column_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
        DataChunk::new(vec![
            ValueVector::from_values(&v1),
            ValueVector::from_values(&v2),
        ])
    }

    #[test]
    fn test_global_count() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }

    #[test]
    fn test_global_sum() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::sum(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Float64(15.0))
        );
    }

    #[test]
    fn test_global_min_max() {
        let mut agg =
            AggregatePushOperator::global(vec![AggregateExpr::min(0), AggregateExpr::max(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[3, 1, 4, 1, 5, 9, 2, 6]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(1))
        );
        assert_eq!(
            chunks[0].column(1).unwrap().get_value(0),
            Some(Value::Int64(9))
        );
    }

    #[test]
    fn test_group_by_sum() {
        // Group by column 0, sum column 1
        let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
        let mut sink = CollectorSink::new();

        // Group 1: 10, 20 (sum=30), Group 2: 30, 40 (sum=70)
        agg.push(
            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
            &mut sink,
        )
        .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks[0].len(), 2); // 2 groups
    }

    #[test]
    fn test_spillable_aggregate_no_spill() {
        // When threshold is not reached, should work like normal aggregate
        let mut agg = SpillableAggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)])
            .with_threshold(100);
        let mut sink = CollectorSink::new();

        agg.push(
            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
            &mut sink,
        )
        .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks[0].len(), 2); // 2 groups
    }

    #[test]
    fn test_spillable_aggregate_with_spilling() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        // Set very low threshold to force spilling
        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::sum(1)],
            manager,
            3, // Spill after 3 groups
        );
        let mut sink = CollectorSink::new();

        // Create 10 different groups
        for i in 0..10 {
            let chunk = create_two_column_chunk(&[i], &[i * 10]);
            agg.push(chunk, &mut sink).unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 10); // 10 groups

        // Verify sums are correct
        let mut sums: Vec<f64> = Vec::new();
        for i in 0..chunks[0].len() {
            if let Some(Value::Float64(sum)) = chunks[0].column(1).unwrap().get_value(i) {
                sums.push(sum);
            }
        }
        sums.sort_by(|a, b| a.partial_cmp(b).unwrap());
        assert_eq!(
            sums,
            vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]
        );
    }

    #[test]
    fn test_spillable_aggregate_global() {
        // Global aggregation shouldn't be affected by spilling
        let mut agg = SpillableAggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }

    #[test]
    fn test_spillable_aggregate_many_groups() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::count_star()],
            manager,
            10, // Very low threshold
        );
        let mut sink = CollectorSink::new();

        // Create 100 different groups
        for i in 0..100 {
            let chunk = create_test_chunk(&[i]);
            agg.push(chunk, &mut sink).unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 100); // 100 groups

        // Each group should have count = 1
        for i in 0..100 {
            if let Some(Value::Int64(count)) = chunks[0].column(1).unwrap().get_value(i) {
                assert_eq!(count, 1);
            }
        }
    }
}