Skip to main content

grafeo_core/execution/operators/push/
aggregate.rs

1//! Push-based aggregate operator (pipeline breaker).
2
3use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::operators::accumulator::{AggregateExpr, AggregateFunction};
6use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
7#[cfg(feature = "spill")]
8use crate::execution::spill::{PartitionedState, SpillManager};
9use crate::execution::vector::ValueVector;
10use grafeo_common::types::Value;
11use std::collections::HashMap;
12#[cfg(feature = "spill")]
13use std::io::{Read, Write};
14#[cfg(feature = "spill")]
15use std::sync::Arc;
16
/// Accumulator for aggregate state.
///
/// One instance tracks enough running state for a single group (or the
/// single global group) to finalize COUNT, SUM, AVG, MIN, MAX, and FIRST.
/// Other aggregate functions are not supported here (see `finalize`).
#[derive(Debug, Clone, Default)]
struct Accumulator {
    // Number of non-null values folded in; incremented directly for COUNT(*).
    count: i64,
    // Running sum of values convertible to f64 (SUM, and AVG numerator).
    sum: f64,
    // Smallest value seen so far, if any.
    min: Option<Value>,
    // Largest value seen so far, if any.
    max: Option<Value>,
    // First non-null value seen (FIRST).
    first: Option<Value>,
}
26
27impl Accumulator {
28    fn new() -> Self {
29        Self {
30            count: 0,
31            sum: 0.0,
32            min: None,
33            max: None,
34            first: None,
35        }
36    }
37
38    fn add(&mut self, value: &Value) {
39        // Skip nulls for aggregates
40        if matches!(value, Value::Null) {
41            return;
42        }
43
44        self.count += 1;
45
46        // Sum (for numeric types)
47        if let Some(n) = value_to_f64(value) {
48            self.sum += n;
49        }
50
51        // Min
52        if self.min.is_none() || compare_for_min(&self.min, value) {
53            self.min = Some(value.clone());
54        }
55
56        // Max
57        if self.max.is_none() || compare_for_max(&self.max, value) {
58            self.max = Some(value.clone());
59        }
60
61        // First
62        if self.first.is_none() {
63            self.first = Some(value.clone());
64        }
65    }
66
67    fn finalize(&mut self, func: AggregateFunction) -> Value {
68        match func {
69            AggregateFunction::Count | AggregateFunction::CountNonNull => Value::Int64(self.count),
70            AggregateFunction::Sum => {
71                if self.count == 0 {
72                    Value::Null
73                } else {
74                    Value::Float64(self.sum)
75                }
76            }
77            AggregateFunction::Min => self.min.take().unwrap_or(Value::Null),
78            AggregateFunction::Max => self.max.take().unwrap_or(Value::Null),
79            AggregateFunction::Avg => {
80                if self.count == 0 {
81                    Value::Null
82                } else {
83                    Value::Float64(self.sum / self.count as f64)
84                }
85            }
86            AggregateFunction::First => self.first.take().unwrap_or(Value::Null),
87            // Advanced functions not supported in push-based accumulator;
88            // these are handled by the pull-based AggregateState instead.
89            AggregateFunction::Last
90            | AggregateFunction::Collect
91            | AggregateFunction::StdDev
92            | AggregateFunction::StdDevPop
93            | AggregateFunction::Variance
94            | AggregateFunction::VariancePop
95            | AggregateFunction::PercentileDisc
96            | AggregateFunction::PercentileCont
97            | AggregateFunction::GroupConcat
98            | AggregateFunction::Sample
99            | AggregateFunction::CovarSamp
100            | AggregateFunction::CovarPop
101            | AggregateFunction::Corr
102            | AggregateFunction::RegrSlope
103            | AggregateFunction::RegrIntercept
104            | AggregateFunction::RegrR2
105            | AggregateFunction::RegrCount
106            | AggregateFunction::RegrSxx
107            | AggregateFunction::RegrSyy
108            | AggregateFunction::RegrSxy
109            | AggregateFunction::RegrAvgx
110            | AggregateFunction::RegrAvgy => Value::Null,
111        }
112    }
113}
114
115use crate::execution::operators::value_utils::{
116    is_greater_than as compare_for_max, is_less_than as compare_for_min, value_to_f64,
117};
118
/// Hash key for grouping.
///
/// Stores one 64-bit hash per GROUP BY column rather than the values
/// themselves (the materialized values live in `GroupState::key_values`).
/// NOTE(review): distinct values that collide under `hash_value` would be
/// merged into a single group; presumably acceptable for 64-bit hashes,
/// but worth confirming this trade-off is intended.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GroupKey(Vec<u64>);

impl GroupKey {
    /// Builds the key for `row` by hashing each GROUP BY column's value.
    /// A missing column or missing value hashes to the constant 0
    /// (distinct from an explicit `Value::Null`, which is tag-hashed).
    fn from_row(chunk: &DataChunk, row: usize, group_by: &[usize]) -> Self {
        let hashes: Vec<u64> = group_by
            .iter()
            .map(|&col| {
                chunk
                    .column(col)
                    .and_then(|c| c.get_value(row))
                    .map_or(0, |v| hash_value(&v))
            })
            .collect();
        Self(hashes)
    }
}
137
/// Computes a stable 64-bit hash of a `Value` for use in `GroupKey`s.
///
/// Every variant hashes a unique tag byte first, so values of different
/// types cannot collide merely by sharing payload bits. Collections hash
/// their length plus each element; map-like variants are ordered before
/// hashing so the result is deterministic.
fn hash_value(value: &Value) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    // Discriminant tag prevents cross-type collisions (e.g. Null vs unknown)
    match value {
        Value::Null => 0u8.hash(&mut hasher),
        Value::Bool(b) => {
            1u8.hash(&mut hasher);
            b.hash(&mut hasher);
        }
        Value::Int64(i) => {
            2u8.hash(&mut hasher);
            i.hash(&mut hasher);
        }
        Value::Float64(f) => {
            3u8.hash(&mut hasher);
            // f64 has no Hash impl; hash the raw bit pattern instead.
            // Note this makes 0.0 and -0.0 hash differently.
            f.to_bits().hash(&mut hasher);
        }
        Value::String(s) => {
            4u8.hash(&mut hasher);
            s.hash(&mut hasher);
        }
        Value::Bytes(b) => {
            5u8.hash(&mut hasher);
            b.hash(&mut hasher);
        }
        Value::Timestamp(t) => {
            6u8.hash(&mut hasher);
            t.hash(&mut hasher);
        }
        Value::Date(d) => {
            7u8.hash(&mut hasher);
            d.hash(&mut hasher);
        }
        Value::Time(t) => {
            8u8.hash(&mut hasher);
            t.hash(&mut hasher);
        }
        Value::Duration(d) => {
            9u8.hash(&mut hasher);
            d.hash(&mut hasher);
        }
        Value::ZonedDatetime(zdt) => {
            10u8.hash(&mut hasher);
            zdt.hash(&mut hasher);
        }
        Value::List(list) => {
            11u8.hash(&mut hasher);
            // Length prefix plus recursive element hashes.
            list.len().hash(&mut hasher);
            for elem in list.iter() {
                hash_value(elem).hash(&mut hasher);
            }
        }
        Value::Map(map) => {
            12u8.hash(&mut hasher);
            map.len().hash(&mut hasher);
            // BTreeMap iterates in key order, so hashing is deterministic
            for (k, v) in map.as_ref() {
                k.as_str().hash(&mut hasher);
                hash_value(v).hash(&mut hasher);
            }
        }
        Value::Vector(vec) => {
            13u8.hash(&mut hasher);
            vec.len().hash(&mut hasher);
            for f in vec.iter() {
                // Same bit-pattern trick as Float64 above.
                f.to_bits().hash(&mut hasher);
            }
        }
        Value::Path { nodes, edges } => {
            14u8.hash(&mut hasher);
            // Only the node count is length-prefixed; edge count is implied
            // by the node count for a path.
            nodes.len().hash(&mut hasher);
            for n in nodes.iter() {
                hash_value(n).hash(&mut hasher);
            }
            for e in edges.iter() {
                hash_value(e).hash(&mut hasher);
            }
        }
        Value::GCounter(map) => {
            15u8.hash(&mut hasher);
            // Sort entries so the underlying map's iteration order cannot
            // change the hash.
            let mut entries: Vec<_> = map.iter().collect();
            entries.sort_by_key(|(k, _)| *k);
            for (k, v) in entries {
                k.hash(&mut hasher);
                v.hash(&mut hasher);
            }
        }
        Value::OnCounter { pos, neg } => {
            16u8.hash(&mut hasher);
            // Positive and negative counter maps are sorted independently
            // for deterministic hashing, as with GCounter.
            let mut pos_entries: Vec<_> = pos.iter().collect();
            pos_entries.sort_by_key(|(k, _)| *k);
            for (k, v) in pos_entries {
                k.hash(&mut hasher);
                v.hash(&mut hasher);
            }
            let mut neg_entries: Vec<_> = neg.iter().collect();
            neg_entries.sort_by_key(|(k, _)| *k);
            for (k, v) in neg_entries {
                k.hash(&mut hasher);
                v.hash(&mut hasher);
            }
        }
        other => {
            // Fallback for variants added later: hash only the discriminant.
            // NOTE(review): all values of such a variant hash identically,
            // so they would all land in one group.
            255u8.hash(&mut hasher);
            std::mem::discriminant(other).hash(&mut hasher);
        }
    }
    hasher.finish()
}
250
/// Group state with key values and accumulators.
#[derive(Clone)]
struct GroupState {
    // Materialized GROUP BY values for this group, emitted as the leading
    // output columns in finalize.
    key_values: Vec<Value>,
    // One accumulator per aggregate expression, in expression order.
    accumulators: Vec<Accumulator>,
}
257
/// Push-based aggregate operator.
///
/// This is a pipeline breaker that accumulates all input, groups by key,
/// and produces aggregated output in the finalize phase.
pub struct AggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Group states by hash key. Used only when `group_by` is non-empty.
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY). `Some` exactly when
    /// `group_by` is empty; mutually exclusive with `groups`.
    global_state: Option<Vec<Accumulator>>,
}
272
273impl AggregatePushOperator {
274    /// Create a new aggregate operator.
275    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
276        let global_state = if group_by.is_empty() {
277            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
278        } else {
279            None
280        };
281
282        Self {
283            group_by,
284            aggregates,
285            groups: HashMap::new(),
286            global_state,
287        }
288    }
289
290    /// Create a simple global aggregate (no GROUP BY).
291    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
292        Self::new(Vec::new(), aggregates)
293    }
294}
295
296impl PushOperator for AggregatePushOperator {
297    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
298        if chunk.is_empty() {
299            return Ok(true);
300        }
301
302        for row in chunk.selected_indices() {
303            if self.group_by.is_empty() {
304                // Global aggregation
305                if let Some(ref mut accumulators) = self.global_state {
306                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
307                        if let Some(col) = expr.column {
308                            if let Some(c) = chunk.column(col)
309                                && let Some(val) = c.get_value(row)
310                            {
311                                acc.add(&val);
312                            }
313                        } else {
314                            // COUNT(*)
315                            acc.count += 1;
316                        }
317                    }
318                }
319            } else {
320                // Group by aggregation
321                let key = GroupKey::from_row(&chunk, row, &self.group_by);
322
323                let state = self.groups.entry(key).or_insert_with(|| {
324                    let key_values: Vec<Value> = self
325                        .group_by
326                        .iter()
327                        .map(|&col| {
328                            chunk
329                                .column(col)
330                                .and_then(|c| c.get_value(row))
331                                .unwrap_or(Value::Null)
332                        })
333                        .collect();
334
335                    GroupState {
336                        key_values,
337                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
338                    }
339                });
340
341                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
342                    if let Some(col) = expr.column {
343                        if let Some(c) = chunk.column(col)
344                            && let Some(val) = c.get_value(row)
345                        {
346                            acc.add(&val);
347                        }
348                    } else {
349                        // COUNT(*)
350                        acc.count += 1;
351                    }
352                }
353            }
354        }
355
356        Ok(true)
357    }
358
359    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
360        let num_output_cols = self.group_by.len() + self.aggregates.len();
361        let mut columns: Vec<ValueVector> =
362            (0..num_output_cols).map(|_| ValueVector::new()).collect();
363
364        if self.group_by.is_empty() {
365            // Global aggregation - single row output
366            if let Some(ref mut accumulators) = self.global_state {
367                for (i, (acc, expr)) in accumulators.iter_mut().zip(&self.aggregates).enumerate() {
368                    columns[i].push(acc.finalize(expr.function));
369                }
370            }
371        } else {
372            // Group by - one row per group
373            for state in self.groups.values_mut() {
374                // Output group key columns
375                for (i, val) in state.key_values.iter().enumerate() {
376                    columns[i].push(val.clone());
377                }
378
379                // Output aggregate results
380                for (i, (acc, expr)) in state
381                    .accumulators
382                    .iter_mut()
383                    .zip(&self.aggregates)
384                    .enumerate()
385                {
386                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
387                }
388            }
389        }
390
391        if !columns.is_empty() && !columns[0].is_empty() {
392            let chunk = DataChunk::new(columns);
393            sink.consume(chunk)?;
394        }
395
396        Ok(())
397    }
398
399    fn preferred_chunk_size(&self) -> ChunkSizeHint {
400        ChunkSizeHint::Default
401    }
402
403    fn name(&self) -> &'static str {
404        "AggregatePush"
405    }
406}
407
/// Default spill threshold for aggregates (number of groups).
///
/// Once a grouped aggregation holds this many in-memory groups, the
/// spillable operator starts partitioning/spilling state to disk.
#[cfg(feature = "spill")]
pub const DEFAULT_AGGREGATE_SPILL_THRESHOLD: usize = 50_000;
411
/// Serializes a GroupState to bytes.
///
/// Layout: u64-LE key count, each key value, u64-LE accumulator count,
/// then per accumulator: i64-LE count, f64 bits LE, and three optional
/// values (min/max/first) each prefixed with a 1/0 presence byte.
#[cfg(feature = "spill")]
fn serialize_group_state(state: &GroupState, w: &mut dyn Write) -> std::io::Result<()> {
    use crate::execution::spill::serialize_value;

    // Writes a presence byte (1/0) followed by the value when present.
    fn write_opt(v: &Option<Value>, w: &mut dyn Write) -> std::io::Result<()> {
        use crate::execution::spill::serialize_value;
        w.write_all(&[u8::from(v.is_some())])?;
        if let Some(val) = v {
            serialize_value(val, w)?;
        }
        Ok(())
    }

    // Key values: length-prefixed sequence.
    w.write_all(&(state.key_values.len() as u64).to_le_bytes())?;
    for val in &state.key_values {
        serialize_value(val, w)?;
    }

    // Accumulators: length prefix, then fixed fields followed by optionals.
    w.write_all(&(state.accumulators.len() as u64).to_le_bytes())?;
    for acc in &state.accumulators {
        w.write_all(&acc.count.to_le_bytes())?;
        // Store the exact f64 bit pattern to round-trip NaN/-0.0 faithfully.
        w.write_all(&acc.sum.to_bits().to_le_bytes())?;
        write_opt(&acc.min, w)?;
        write_opt(&acc.max, w)?;
        write_opt(&acc.first, w)?;
    }

    Ok(())
}
453
/// Deserializes a GroupState from bytes.
///
/// Inverse of `serialize_group_state`; reads the same layout: u64-LE key
/// count, key values, u64-LE accumulator count, then per accumulator the
/// count, sum bits, and three flag-prefixed optional values.
#[cfg(feature = "spill")]
fn deserialize_group_state(r: &mut dyn Read) -> std::io::Result<GroupState> {
    use crate::execution::spill::deserialize_value;

    // Reads one little-endian 64-bit word.
    fn read_u64(r: &mut dyn Read) -> std::io::Result<u64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(u64::from_le_bytes(buf))
    }

    // Reads a presence flag byte, then the value if the flag is nonzero.
    fn read_opt(r: &mut dyn Read) -> std::io::Result<Option<Value>> {
        use crate::execution::spill::deserialize_value;
        let mut flag = [0u8; 1];
        r.read_exact(&mut flag)?;
        if flag[0] != 0 {
            Ok(Some(deserialize_value(r)?))
        } else {
            Ok(None)
        }
    }

    // Key values.
    let num_keys = read_u64(r)? as usize;
    let mut key_values = Vec::with_capacity(num_keys);
    for _ in 0..num_keys {
        key_values.push(deserialize_value(r)?);
    }

    // Accumulators.
    let num_accumulators = read_u64(r)? as usize;
    let mut accumulators = Vec::with_capacity(num_accumulators);
    for _ in 0..num_accumulators {
        // Reinterpret the same 8 bytes as i64 (count) / f64 bits (sum).
        let count = read_u64(r)? as i64;
        let sum = f64::from_bits(read_u64(r)?);
        let min = read_opt(r)?;
        let max = read_opt(r)?;
        let first = read_opt(r)?;

        accumulators.push(Accumulator {
            count,
            sum,
            min,
            max,
            first,
        });
    }

    Ok(GroupState {
        key_values,
        accumulators,
    })
}
521
/// Push-based aggregate operator with spilling support.
///
/// Uses partitioned hash table that can spill cold partitions to disk
/// when memory pressure is high.
#[cfg(feature = "spill")]
pub struct SpillableAggregatePushOperator {
    /// Columns to group by.
    group_by: Vec<usize>,
    /// Aggregate expressions.
    aggregates: Vec<AggregateExpr>,
    /// Spill manager (None = no spilling).
    spill_manager: Option<Arc<SpillManager>>,
    /// Partitioned groups (used when spilling is enabled).
    /// Only consulted when `using_partitioned` is true.
    partitioned_groups: Option<PartitionedState<GroupState>>,
    /// Non-partitioned groups (used when spilling is disabled).
    /// Only consulted when `using_partitioned` is false.
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY).
    global_state: Option<Vec<Accumulator>>,
    /// Spill threshold (number of groups).
    spill_threshold: usize,
    /// Whether we've switched to partitioned mode. Set at construction by
    /// `with_spilling`, or lazily by `maybe_spill` once the threshold hits.
    using_partitioned: bool,
}
545
#[cfg(feature = "spill")]
impl SpillableAggregatePushOperator {
    /// Number of partitions for the spillable hash table. Shared by
    /// `with_spilling` and the lazy switch in `maybe_spill`.
    const SPILL_PARTITIONS: usize = 256;

    /// Create a new spillable aggregate operator.
    ///
    /// Spilling stays disabled until a `SpillManager` is attached (see
    /// [`Self::with_spilling`]); without one this behaves like the plain
    /// in-memory aggregate.
    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
        // Global (no GROUP BY) aggregation needs one accumulator per
        // aggregate expression and never spills.
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(|_| Accumulator::new()).collect())
        } else {
            None
        };

        Self {
            group_by,
            aggregates,
            spill_manager: None,
            partitioned_groups: None,
            groups: HashMap::new(),
            global_state,
            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
            using_partitioned: false,
        }
    }

    /// Create a spillable aggregate operator with spilling enabled.
    ///
    /// Starts directly in partitioned mode so groups land in spillable
    /// partitions from the first pushed chunk.
    pub fn with_spilling(
        group_by: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        manager: Arc<SpillManager>,
        threshold: usize,
    ) -> Self {
        let partitioned = PartitionedState::new(
            Arc::clone(&manager),
            Self::SPILL_PARTITIONS,
            serialize_group_state,
            deserialize_group_state,
        );

        // Delegate shared field setup to `new` rather than duplicating it.
        let mut op = Self::new(group_by, aggregates);
        op.spill_manager = Some(manager);
        op.partitioned_groups = Some(partitioned);
        op.spill_threshold = threshold;
        op.using_partitioned = true;
        op
    }

    /// Create a simple global aggregate (no GROUP BY).
    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
        Self::new(Vec::new(), aggregates)
    }

    /// Sets the spill threshold (number of in-memory groups).
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.spill_threshold = threshold;
        self
    }

    /// Spills state or switches to partitioned mode when the in-memory
    /// group count crosses `spill_threshold`.
    fn maybe_spill(&mut self) -> Result<(), OperatorError> {
        if self.global_state.is_some() {
            // Global aggregation holds only O(#aggregates) state;
            // spilling is unnecessary.
            return Ok(());
        }

        // If using partitioned state, check if we need to spill
        if let Some(ref mut partitioned) = self.partitioned_groups {
            if partitioned.total_size() >= self.spill_threshold {
                partitioned
                    .spill_largest()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }
        } else if self.groups.len() >= self.spill_threshold {
            // Not using partitioned state yet, but reached threshold.
            // If spilling is configured, switch to partitioned mode;
            // otherwise keep growing the plain hash map.
            if let Some(ref manager) = self.spill_manager {
                let mut partitioned = PartitionedState::new(
                    Arc::clone(manager),
                    Self::SPILL_PARTITIONS,
                    serialize_group_state,
                    deserialize_group_state,
                );

                // Move existing groups to partitioned state, keyed by
                // their materialized GROUP BY values.
                for (_key, state) in self.groups.drain() {
                    partitioned
                        .insert(state.key_values.clone(), state)
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;
                }

                self.partitioned_groups = Some(partitioned);
                self.using_partitioned = true;
            }
        }

        Ok(())
    }
}
651
#[cfg(feature = "spill")]
impl PushOperator for SpillableAggregatePushOperator {
    /// Consumes a chunk, routing each row to the global accumulators, the
    /// partitioned (spillable) state, or the plain hash map depending on
    /// the current mode. Emits nothing here: all output is produced in
    /// `finalize` (pipeline breaker).
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation - same as non-spillable
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*): rows count even when values are null.
                            acc.count += 1;
                        }
                    }
                }
            } else if self.using_partitioned {
                // Use partitioned state
                if let Some(ref mut partitioned) = self.partitioned_groups {
                    // Materialize the GROUP BY values; they serve as both
                    // the partition key and the output key columns.
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    // Borrow only the aggregates field so the closure below
                    // doesn't capture `self` while `partitioned` already
                    // holds a mutable borrow of it.
                    let aggregates = &self.aggregates;
                    let state = partitioned
                        .get_or_insert_with(key_values.clone(), || GroupState {
                            key_values: key_values.clone(),
                            accumulators: aggregates.iter().map(|_| Accumulator::new()).collect(),
                        })
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;

                    for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                        if let Some(col) = expr.column {
                            if let Some(c) = chunk.column(col)
                                && let Some(val) = c.get_value(row)
                            {
                                acc.add(&val);
                            }
                        } else {
                            // COUNT(*)
                            acc.count += 1;
                        }
                    }
                }
            } else {
                // Use regular hash map
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    // Materialize the raw key values once for output later.
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(|_| Accumulator::new()).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    if let Some(col) = expr.column {
                        if let Some(c) = chunk.column(col)
                            && let Some(val) = c.get_value(row)
                        {
                            acc.add(&val);
                        }
                    } else {
                        // COUNT(*)
                        acc.count += 1;
                    }
                }
            }
        }

        // Check if we need to spill (or switch to partitioned mode) now
        // that this chunk's groups are accounted for.
        self.maybe_spill()?;

        Ok(true)
    }

    /// Emits aggregated rows: key columns first, then aggregate columns.
    /// In partitioned mode, all partitions are drained via `drain_all`
    /// (which presumably reloads spilled partitions from disk — see
    /// `PartitionedState`) before each group is finalized.
    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output
            if let Some(ref mut accumulators) = self.global_state {
                for (i, (acc, expr)) in accumulators.iter_mut().zip(&self.aggregates).enumerate() {
                    columns[i].push(acc.finalize(expr.function));
                }
            }
        } else if self.using_partitioned {
            // Drain partitioned state
            if let Some(ref mut partitioned) = self.partitioned_groups {
                let groups = partitioned
                    .drain_all()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;

                for (_key, mut state) in groups {
                    // Output group key columns
                    for (i, val) in state.key_values.iter().enumerate() {
                        columns[i].push(val.clone());
                    }

                    // Output aggregate results
                    for (i, (acc, expr)) in state
                        .accumulators
                        .iter_mut()
                        .zip(&self.aggregates)
                        .enumerate()
                    {
                        columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                    }
                }
            }
        } else {
            // Group by using regular hash map - one row per group
            for state in self.groups.values_mut() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, (acc, expr)) in state
                    .accumulators
                    .iter_mut()
                    .zip(&self.aggregates)
                    .enumerate()
                {
                    columns[self.group_by.len() + i].push(acc.finalize(expr.function));
                }
            }
        }

        // Emit a chunk only when at least one result row exists.
        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "SpillableAggregatePush"
    }
}
823
824#[cfg(test)]
825mod tests {
826    use super::*;
827    use crate::execution::sink::CollectorSink;
828
829    fn create_test_chunk(values: &[i64]) -> DataChunk {
830        let v: Vec<Value> = values.iter().map(|&i| Value::Int64(i)).collect();
831        let vector = ValueVector::from_values(&v);
832        DataChunk::new(vec![vector])
833    }
834
835    fn create_two_column_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
836        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
837        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
838        DataChunk::new(vec![
839            ValueVector::from_values(&v1),
840            ValueVector::from_values(&v2),
841        ])
842    }
843
844    #[test]
845    fn test_global_count() {
846        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
847        let mut sink = CollectorSink::new();
848
849        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
850            .unwrap();
851        agg.finalize(&mut sink).unwrap();
852
853        let chunks = sink.into_chunks();
854        assert_eq!(chunks.len(), 1);
855        assert_eq!(
856            chunks[0].column(0).unwrap().get_value(0),
857            Some(Value::Int64(5))
858        );
859    }
860
861    #[test]
862    fn test_global_sum() {
863        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::sum(0)]);
864        let mut sink = CollectorSink::new();
865
866        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
867            .unwrap();
868        agg.finalize(&mut sink).unwrap();
869
870        let chunks = sink.into_chunks();
871        assert_eq!(
872            chunks[0].column(0).unwrap().get_value(0),
873            Some(Value::Float64(15.0))
874        );
875    }
876
    #[test]
    fn test_global_min_max() {
        // MIN and MAX over the same column, computed in a single pass.
        let mut agg =
            AggregatePushOperator::global(vec![AggregateExpr::min(0), AggregateExpr::max(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[3, 1, 4, 1, 5, 9, 2, 6]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        // Output column order follows the aggregate expression order:
        // column 0 = MIN, column 1 = MAX.
        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(1))
        );
        assert_eq!(
            chunks[0].column(1).unwrap().get_value(0),
            Some(Value::Int64(9))
        );
    }
897
898    #[test]
899    fn test_group_by_sum() {
900        // Group by column 0, sum column 1
901        let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
902        let mut sink = CollectorSink::new();
903
904        // Group 1: 10, 20 (sum=30), Group 2: 30, 40 (sum=70)
905        agg.push(
906            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
907            &mut sink,
908        )
909        .unwrap();
910        agg.finalize(&mut sink).unwrap();
911
912        let chunks = sink.into_chunks();
913        assert_eq!(chunks[0].len(), 2); // 2 groups
914    }
915
916    #[test]
917    #[cfg(feature = "spill")]
918    fn test_spillable_aggregate_no_spill() {
919        // When threshold is not reached, should work like normal aggregate
920        let mut agg = SpillableAggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)])
921            .with_threshold(100);
922        let mut sink = CollectorSink::new();
923
924        agg.push(
925            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
926            &mut sink,
927        )
928        .unwrap();
929        agg.finalize(&mut sink).unwrap();
930
931        let chunks = sink.into_chunks();
932        assert_eq!(chunks[0].len(), 2); // 2 groups
933    }
934
935    #[test]
936    #[cfg(feature = "spill")]
937    fn test_spillable_aggregate_with_spilling() {
938        use tempfile::TempDir;
939
940        let temp_dir = TempDir::new().unwrap();
941        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
942
943        // Set very low threshold to force spilling
944        let mut agg = SpillableAggregatePushOperator::with_spilling(
945            vec![0],
946            vec![AggregateExpr::sum(1)],
947            manager,
948            3, // Spill after 3 groups
949        );
950        let mut sink = CollectorSink::new();
951
952        // Create 10 different groups
953        for i in 0..10 {
954            let chunk = create_two_column_chunk(&[i], &[i * 10]);
955            agg.push(chunk, &mut sink).unwrap();
956        }
957        agg.finalize(&mut sink).unwrap();
958
959        let chunks = sink.into_chunks();
960        assert_eq!(chunks.len(), 1);
961        assert_eq!(chunks[0].len(), 10); // 10 groups
962
963        // Verify sums are correct
964        let mut sums: Vec<f64> = Vec::new();
965        for i in 0..chunks[0].len() {
966            if let Some(Value::Float64(sum)) = chunks[0].column(1).unwrap().get_value(i) {
967                sums.push(sum);
968            }
969        }
970        sums.sort_by(|a, b| a.partial_cmp(b).unwrap());
971        assert_eq!(
972            sums,
973            vec![0.0, 10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0, 90.0]
974        );
975    }
976
977    #[test]
978    #[cfg(feature = "spill")]
979    fn test_spillable_aggregate_global() {
980        // Global aggregation shouldn't be affected by spilling
981        let mut agg = SpillableAggregatePushOperator::global(vec![AggregateExpr::count_star()]);
982        let mut sink = CollectorSink::new();
983
984        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
985            .unwrap();
986        agg.finalize(&mut sink).unwrap();
987
988        let chunks = sink.into_chunks();
989        assert_eq!(chunks.len(), 1);
990        assert_eq!(
991            chunks[0].column(0).unwrap().get_value(0),
992            Some(Value::Int64(5))
993        );
994    }
995
996    #[test]
997    #[cfg(feature = "spill")]
998    fn test_spillable_aggregate_many_groups() {
999        use tempfile::TempDir;
1000
1001        let temp_dir = TempDir::new().unwrap();
1002        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
1003
1004        let mut agg = SpillableAggregatePushOperator::with_spilling(
1005            vec![0],
1006            vec![AggregateExpr::count_star()],
1007            manager,
1008            10, // Very low threshold
1009        );
1010        let mut sink = CollectorSink::new();
1011
1012        // Create 100 different groups
1013        for i in 0..100 {
1014            let chunk = create_test_chunk(&[i]);
1015            agg.push(chunk, &mut sink).unwrap();
1016        }
1017        agg.finalize(&mut sink).unwrap();
1018
1019        let chunks = sink.into_chunks();
1020        assert_eq!(chunks.len(), 1);
1021        assert_eq!(chunks[0].len(), 100); // 100 groups
1022
1023        // Each group should have count = 1
1024        for i in 0..100 {
1025            if let Some(Value::Int64(count)) = chunks[0].column(1).unwrap().get_value(i) {
1026                assert_eq!(count, 1);
1027            }
1028        }
1029    }
1030
1031    // ---------------------------------------------------------------
1032    // hash_value coverage for all Value variants
1033    // ---------------------------------------------------------------
1034
1035    #[test]
1036    fn hash_value_null() {
1037        let h = hash_value(&Value::Null);
1038        assert_ne!(h, 0); // hasher produces non-zero for Null discriminant
1039    }
1040
1041    #[test]
1042    fn hash_value_bool() {
1043        let t = hash_value(&Value::Bool(true));
1044        let f = hash_value(&Value::Bool(false));
1045        assert_ne!(t, f);
1046    }
1047
1048    #[test]
1049    fn hash_value_int64() {
1050        let a = hash_value(&Value::Int64(42));
1051        let b = hash_value(&Value::Int64(43));
1052        assert_ne!(a, b);
1053    }
1054
1055    #[test]
1056    fn hash_value_float64() {
1057        let a = hash_value(&Value::Float64(19.88));
1058        let b = hash_value(&Value::Float64(3.19));
1059        assert_ne!(a, b);
1060    }
1061
1062    #[test]
1063    fn hash_value_string() {
1064        let a = hash_value(&Value::String("hello".into()));
1065        let b = hash_value(&Value::String("world".into()));
1066        assert_ne!(a, b);
1067    }
1068
1069    #[test]
1070    fn hash_value_bytes() {
1071        let a = hash_value(&Value::Bytes(vec![1, 2, 3].into()));
1072        let b = hash_value(&Value::Bytes(vec![4, 5, 6].into()));
1073        assert_ne!(a, b);
1074    }
1075
1076    #[test]
1077    fn hash_value_list() {
1078        let a = hash_value(&Value::List(vec![Value::Int64(1), Value::Int64(2)].into()));
1079        let b = hash_value(&Value::List(vec![Value::Int64(3)].into()));
1080        assert_ne!(a, b);
1081    }
1082
1083    #[test]
1084    fn hash_value_map() {
1085        use grafeo_common::types::PropertyKey;
1086        use std::collections::BTreeMap;
1087        use std::sync::Arc;
1088        let mut map = BTreeMap::new();
1089        map.insert(PropertyKey::new("key"), Value::Int64(42));
1090        let h = hash_value(&Value::Map(Arc::new(map)));
1091        assert_ne!(h, 0);
1092    }
1093
1094    #[test]
1095    fn hash_value_vector() {
1096        let h = hash_value(&Value::Vector(vec![1.0, 2.0, 3.0].into()));
1097        assert_ne!(h, 0);
1098    }
1099
1100    #[test]
1101    fn hash_value_path() {
1102        let h = hash_value(&Value::Path {
1103            nodes: vec![Value::Int64(1), Value::Int64(2)].into(),
1104            edges: vec![Value::Int64(10)].into(),
1105        });
1106        assert_ne!(h, 0);
1107    }
1108
1109    #[test]
1110    fn hash_value_gcounter() {
1111        use std::sync::Arc;
1112        let mut map = std::collections::HashMap::new();
1113        map.insert("replica1".to_string(), 10u64);
1114        let h = hash_value(&Value::GCounter(Arc::new(map)));
1115        assert_ne!(h, 0);
1116    }
1117
1118    #[test]
1119    fn hash_value_on_counter() {
1120        use std::sync::Arc;
1121        let mut pos = std::collections::HashMap::new();
1122        pos.insert("replica1".to_string(), 10u64);
1123        let neg = std::collections::HashMap::new();
1124        let h = hash_value(&Value::OnCounter {
1125            pos: Arc::new(pos),
1126            neg: Arc::new(neg),
1127        });
1128        assert_ne!(h, 0);
1129    }
1130
1131    #[test]
1132    fn hash_value_timestamp() {
1133        use grafeo_common::types::Timestamp;
1134        let h = hash_value(&Value::Timestamp(Timestamp::from_micros(1_700_000_000_000)));
1135        assert_ne!(h, 0);
1136    }
1137
1138    #[test]
1139    fn hash_value_date() {
1140        use grafeo_common::types::Date;
1141        let h = hash_value(&Value::Date(Date::from_days(19000)));
1142        assert_ne!(h, 0);
1143    }
1144
1145    #[test]
1146    fn hash_value_time() {
1147        use grafeo_common::types::Time;
1148        let h = hash_value(&Value::Time(Time::from_hms(12, 0, 0).unwrap()));
1149        assert_ne!(h, 0);
1150    }
1151
1152    #[test]
1153    fn hash_value_duration() {
1154        use grafeo_common::types::Duration;
1155        let h = hash_value(&Value::Duration(Duration::from_days(1)));
1156        assert_ne!(h, 0);
1157    }
1158
1159    #[test]
1160    fn hash_value_zoned_datetime() {
1161        use grafeo_common::types::{Timestamp, ZonedDatetime};
1162        let zdt =
1163            ZonedDatetime::from_timestamp_offset(Timestamp::from_micros(1_700_000_000_000), 3600);
1164        let h = hash_value(&Value::ZonedDatetime(zdt));
1165        assert_ne!(h, 0);
1166    }
1167
1168    // ---------------------------------------------------------------
1169    // Accumulator finalize for advanced functions (fallback to Null)
1170    // ---------------------------------------------------------------
1171
1172    #[test]
1173    fn finalize_advanced_functions_return_null() {
1174        let advanced = [
1175            AggregateFunction::Last,
1176            AggregateFunction::Collect,
1177            AggregateFunction::StdDev,
1178            AggregateFunction::StdDevPop,
1179            AggregateFunction::Variance,
1180            AggregateFunction::VariancePop,
1181            AggregateFunction::PercentileDisc,
1182            AggregateFunction::PercentileCont,
1183            AggregateFunction::GroupConcat,
1184            AggregateFunction::Sample,
1185            AggregateFunction::CovarSamp,
1186            AggregateFunction::CovarPop,
1187            AggregateFunction::Corr,
1188            AggregateFunction::RegrSlope,
1189            AggregateFunction::RegrIntercept,
1190            AggregateFunction::RegrR2,
1191            AggregateFunction::RegrCount,
1192            AggregateFunction::RegrSxx,
1193            AggregateFunction::RegrSyy,
1194            AggregateFunction::RegrSxy,
1195            AggregateFunction::RegrAvgx,
1196            AggregateFunction::RegrAvgy,
1197        ];
1198
1199        for func in advanced {
1200            let mut acc = Accumulator::new();
1201            acc.add(&Value::Int64(42));
1202            let result = acc.finalize(func);
1203            assert_eq!(
1204                result,
1205                Value::Null,
1206                "Advanced function {func:?} should return Null in push accumulator"
1207            );
1208        }
1209    }
1210
1211    #[test]
1212    fn finalize_first_returns_first_value() {
1213        let mut acc = Accumulator::new();
1214        acc.add(&Value::Int64(10));
1215        acc.add(&Value::Int64(20));
1216        assert_eq!(acc.finalize(AggregateFunction::First), Value::Int64(10));
1217    }
1218
1219    #[test]
1220    fn finalize_avg_empty_returns_null() {
1221        let mut acc = Accumulator::new();
1222        assert_eq!(acc.finalize(AggregateFunction::Avg), Value::Null);
1223    }
1224
1225    #[test]
1226    fn finalize_sum_empty_returns_null() {
1227        let mut acc = Accumulator::new();
1228        assert_eq!(acc.finalize(AggregateFunction::Sum), Value::Null);
1229    }
1230
1231    #[test]
1232    fn finalize_min_max_empty_returns_null() {
1233        let mut acc_min = Accumulator::new();
1234        let mut acc_max = Accumulator::new();
1235        assert_eq!(acc_min.finalize(AggregateFunction::Min), Value::Null);
1236        assert_eq!(acc_max.finalize(AggregateFunction::Max), Value::Null);
1237    }
1238
1239    #[test]
1240    fn accumulator_skips_nulls() {
1241        let mut acc = Accumulator::new();
1242        acc.add(&Value::Null);
1243        acc.add(&Value::Int64(5));
1244        acc.add(&Value::Null);
1245        assert_eq!(acc.count, 1);
1246        assert_eq!(acc.finalize(AggregateFunction::Count), Value::Int64(1));
1247    }
1248
1249    #[test]
1250    fn test_empty_chunk_returns_ok() {
1251        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
1252        let mut sink = CollectorSink::new();
1253        let empty = DataChunk::new(vec![ValueVector::new()]);
1254        let result = agg.push(empty, &mut sink).unwrap();
1255        assert!(result);
1256    }
1257}