// grafeo_core/execution/operators/push/aggregate.rs

1//! Push-based aggregate operator (pipeline breaker).
2
3use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::operators::accumulator::{AggregateExpr, AggregateFunction, AggregateState};
6use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
7#[cfg(feature = "spill")]
8use crate::execution::spill::{PartitionedState, SpillManager};
9use crate::execution::vector::ValueVector;
10use grafeo_common::types::Value;
11use std::collections::HashMap;
12#[cfg(feature = "spill")]
13use std::io::{Read, Write};
14#[cfg(feature = "spill")]
15use std::sync::Arc;
16
17/// Creates a new [`AggregateState`] from an [`AggregateExpr`].
18fn state_for_expr(expr: &AggregateExpr) -> AggregateState {
19    AggregateState::new(
20        expr.function,
21        expr.distinct,
22        expr.percentile,
23        expr.separator.as_deref(),
24    )
25}
26
27/// Updates a single accumulator from a data chunk row, handling bivariate
28/// functions, `CountNonNull` null-skipping, and `COUNT(*)`.
29fn update_accumulator(
30    acc: &mut AggregateState,
31    expr: &AggregateExpr,
32    chunk: &DataChunk,
33    row: usize,
34) {
35    // Bivariate set functions (COVAR, CORR, REGR_*) need two column values
36    if expr.column2.is_some() {
37        let y_val = expr
38            .column
39            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
40        let x_val = expr
41            .column2
42            .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
43        acc.update_bivariate(y_val, x_val);
44        return;
45    }
46
47    if let Some(col) = expr.column {
48        let val = chunk.column(col).and_then(|c| c.get_value(row));
49        // CountNonNull must skip null values
50        if expr.function == AggregateFunction::CountNonNull
51            && matches!(val, None | Some(Value::Null))
52        {
53            return;
54        }
55        acc.update(val);
56    } else {
57        // COUNT(*)
58        acc.update(None);
59    }
60}
61
/// Hash key for grouping.
///
/// Holds one 64-bit hash per GROUP BY column instead of the values
/// themselves, keeping keys cheap to build, compare, and re-hash.
///
/// NOTE(review): since only hashes are stored, two distinct key tuples
/// colliding on every column hash would be merged into a single group —
/// presumably accepted as a vanishingly unlikely tradeoff; confirm.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GroupKey(Vec<u64>);
65
66impl GroupKey {
67    fn from_row(chunk: &DataChunk, row: usize, group_by: &[usize]) -> Self {
68        let hashes: Vec<u64> = group_by
69            .iter()
70            .map(|&col| {
71                chunk
72                    .column(col)
73                    .and_then(|c| c.get_value(row))
74                    .map_or(0, |v| hash_value(&v))
75            })
76            .collect();
77        Self(hashes)
78    }
79}
80
/// Produces a 64-bit hash for a single [`Value`].
///
/// Each variant is prefixed with a distinct tag byte so values of different
/// types cannot collide merely by sharing a payload bit pattern. Floats are
/// hashed via `to_bits()` since `f64` has no `Hash` impl. `DefaultHasher::new()`
/// uses fixed keys, so results are deterministic within a process; they are
/// only used for in-memory grouping, never persisted.
fn hash_value(value: &Value) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    // Discriminant tag prevents cross-type collisions (e.g. Null vs unknown)
    match value {
        Value::Null => 0u8.hash(&mut hasher),
        Value::Bool(b) => {
            1u8.hash(&mut hasher);
            b.hash(&mut hasher);
        }
        Value::Int64(i) => {
            2u8.hash(&mut hasher);
            i.hash(&mut hasher);
        }
        Value::Float64(f) => {
            3u8.hash(&mut hasher);
            // Hash by bit pattern: covers NaN and distinguishes -0.0 from 0.0.
            f.to_bits().hash(&mut hasher);
        }
        Value::String(s) => {
            4u8.hash(&mut hasher);
            s.hash(&mut hasher);
        }
        Value::Bytes(b) => {
            5u8.hash(&mut hasher);
            b.hash(&mut hasher);
        }
        Value::Timestamp(t) => {
            6u8.hash(&mut hasher);
            t.hash(&mut hasher);
        }
        Value::Date(d) => {
            7u8.hash(&mut hasher);
            d.hash(&mut hasher);
        }
        Value::Time(t) => {
            8u8.hash(&mut hasher);
            t.hash(&mut hasher);
        }
        Value::Duration(d) => {
            9u8.hash(&mut hasher);
            d.hash(&mut hasher);
        }
        Value::ZonedDatetime(zdt) => {
            10u8.hash(&mut hasher);
            zdt.hash(&mut hasher);
        }
        Value::List(list) => {
            11u8.hash(&mut hasher);
            // Length prefix disambiguates nested sequences.
            list.len().hash(&mut hasher);
            for elem in list.iter() {
                hash_value(elem).hash(&mut hasher);
            }
        }
        Value::Map(map) => {
            12u8.hash(&mut hasher);
            map.len().hash(&mut hasher);
            // BTreeMap iterates in key order, so hashing is deterministic
            for (k, v) in map.as_ref() {
                k.as_str().hash(&mut hasher);
                hash_value(v).hash(&mut hasher);
            }
        }
        Value::Vector(vec) => {
            13u8.hash(&mut hasher);
            vec.len().hash(&mut hasher);
            for f in vec.iter() {
                f.to_bits().hash(&mut hasher);
            }
        }
        Value::Path { nodes, edges } => {
            14u8.hash(&mut hasher);
            nodes.len().hash(&mut hasher);
            for n in nodes.iter() {
                hash_value(n).hash(&mut hasher);
            }
            for e in edges.iter() {
                hash_value(e).hash(&mut hasher);
            }
        }
        Value::GCounter(map) => {
            16u8; // (unused literal removed in review? no) --
            15u8.hash(&mut hasher);
            // Sort entries so the hash is independent of map iteration order.
            let mut entries: Vec<_> = map.iter().collect();
            entries.sort_by_key(|(k, _)| *k);
            for (k, v) in entries {
                k.hash(&mut hasher);
                v.hash(&mut hasher);
            }
        }
        Value::OnCounter { pos, neg } => {
            16u8.hash(&mut hasher);
            let mut pos_entries: Vec<_> = pos.iter().collect();
            pos_entries.sort_by_key(|(k, _)| *k);
            for (k, v) in pos_entries {
                k.hash(&mut hasher);
                v.hash(&mut hasher);
            }
            // Positive and negative counters hashed in sequence; the sorted
            // order keeps the result deterministic.
            let mut neg_entries: Vec<_> = neg.iter().collect();
            neg_entries.sort_by_key(|(k, _)| *k);
            for (k, v) in neg_entries {
                k.hash(&mut hasher);
                v.hash(&mut hasher);
            }
        }
        other => {
            // Unknown/future variants: tag + discriminant only (no payload).
            255u8.hash(&mut hasher);
            std::mem::discriminant(other).hash(&mut hasher);
        }
    }
    hasher.finish()
}
193
/// Group state with key values and accumulators.
#[derive(Clone)]
struct GroupState {
    /// Materialized GROUP BY key values, emitted as output columns in finalize.
    key_values: Vec<Value>,
    /// One accumulator per aggregate expression, in expression order.
    accumulators: Vec<AggregateState>,
}
200
/// Push-based aggregate operator.
///
/// This is a pipeline breaker that accumulates all input, groups by key,
/// and produces aggregated output in the finalize phase.
pub struct AggregatePushOperator {
    /// Columns to group by. Empty means global (single-row) aggregation.
    group_by: Vec<usize>,
    /// Aggregate expressions, in output-column order.
    aggregates: Vec<AggregateExpr>,
    /// Group states by hash key (used only when `group_by` is non-empty).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY). `Some` iff `group_by` is empty;
    /// established once in the constructor.
    global_state: Option<Vec<AggregateState>>,
}
215
216impl AggregatePushOperator {
217    /// Create a new aggregate operator.
218    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
219        let global_state = if group_by.is_empty() {
220            Some(aggregates.iter().map(state_for_expr).collect())
221        } else {
222            None
223        };
224
225        Self {
226            group_by,
227            aggregates,
228            groups: HashMap::new(),
229            global_state,
230        }
231    }
232
233    /// Create a simple global aggregate (no GROUP BY).
234    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
235        Self::new(Vec::new(), aggregates)
236    }
237}
238
239impl PushOperator for AggregatePushOperator {
240    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
241        if chunk.is_empty() {
242            return Ok(true);
243        }
244
245        for row in chunk.selected_indices() {
246            if self.group_by.is_empty() {
247                // Global aggregation
248                if let Some(ref mut accumulators) = self.global_state {
249                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
250                        update_accumulator(acc, expr, &chunk, row);
251                    }
252                }
253            } else {
254                // Group by aggregation
255                let key = GroupKey::from_row(&chunk, row, &self.group_by);
256
257                let state = self.groups.entry(key).or_insert_with(|| {
258                    let key_values: Vec<Value> = self
259                        .group_by
260                        .iter()
261                        .map(|&col| {
262                            chunk
263                                .column(col)
264                                .and_then(|c| c.get_value(row))
265                                .unwrap_or(Value::Null)
266                        })
267                        .collect();
268
269                    GroupState {
270                        key_values,
271                        accumulators: self.aggregates.iter().map(state_for_expr).collect(),
272                    }
273                });
274
275                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
276                    update_accumulator(acc, expr, &chunk, row);
277                }
278            }
279        }
280
281        Ok(true)
282    }
283
284    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
285        let num_output_cols = self.group_by.len() + self.aggregates.len();
286        let mut columns: Vec<ValueVector> =
287            (0..num_output_cols).map(|_| ValueVector::new()).collect();
288
289        if self.group_by.is_empty() {
290            // Global aggregation - single row output
291            if let Some(ref accumulators) = self.global_state {
292                for (i, acc) in accumulators.iter().enumerate() {
293                    columns[i].push(acc.finalize());
294                }
295            }
296        } else {
297            // Group by - one row per group
298            for state in self.groups.values() {
299                // Output group key columns
300                for (i, val) in state.key_values.iter().enumerate() {
301                    columns[i].push(val.clone());
302                }
303
304                // Output aggregate results
305                for (i, acc) in state.accumulators.iter().enumerate() {
306                    columns[self.group_by.len() + i].push(acc.finalize());
307                }
308            }
309        }
310
311        if !columns.is_empty() && !columns[0].is_empty() {
312            let chunk = DataChunk::new(columns);
313            sink.consume(chunk)?;
314        }
315
316        Ok(())
317    }
318
319    fn preferred_chunk_size(&self) -> ChunkSizeHint {
320        ChunkSizeHint::Default
321    }
322
323    fn name(&self) -> &'static str {
324        "AggregatePush"
325    }
326}
327
/// Default spill threshold for aggregates (number of groups).
///
/// Once an operator tracks this many distinct groups it begins moving
/// state to disk (see [`SpillableAggregatePushOperator`]).
#[cfg(feature = "spill")]
pub const DEFAULT_AGGREGATE_SPILL_THRESHOLD: usize = 50_000;
331
/// Tag bytes for aggregate state variants used during spill serialization.
///
/// Each tag identifies both the aggregate function AND how to reconstruct
/// the accumulator state so it can continue receiving updates after reload.
///
/// These values form the on-disk format shared by `serialize_group_state`
/// and `deserialize_group_state` — do not renumber or reuse them.
#[cfg(feature = "spill")]
mod spill_tag {
    pub const COUNT: u8 = 0;
    pub const SUM_INT: u8 = 1;
    pub const SUM_FLOAT: u8 = 2;
    pub const AVG: u8 = 3;
    pub const MIN: u8 = 4;
    pub const MAX: u8 = 5;
    pub const FIRST: u8 = 6;
    pub const LAST: u8 = 7;
    pub const COLLECT: u8 = 8;
    /// Fallback: stores finalized value only, cannot resume accumulation.
    pub const FINALIZED: u8 = 255;
}
350
/// Serializes a `GroupState` to bytes.
///
/// Each accumulator is serialized with a tag byte indicating the state variant
/// followed by the internal fields needed to reconstruct a resumable state.
/// For complex variants (StdDev, percentiles, bivariate, etc.) the finalized
/// value is stored instead, since those are rare in spill scenarios.
///
/// Wire layout (all integers little-endian):
/// `[key_count: u64][keys...][acc_count: u64][(tag: u8, payload)...]` —
/// must stay in sync with `deserialize_group_state`.
#[cfg(feature = "spill")]
fn serialize_group_state(state: &GroupState, w: &mut dyn Write) -> std::io::Result<()> {
    use crate::execution::spill::serialize_value;

    // Write key values
    w.write_all(&(state.key_values.len() as u64).to_le_bytes())?;
    for val in &state.key_values {
        serialize_value(val, w)?;
    }

    // Write accumulators with tag bytes
    w.write_all(&(state.accumulators.len() as u64).to_le_bytes())?;
    for acc in &state.accumulators {
        match acc {
            AggregateState::Count(n) => {
                w.write_all(&[spill_tag::COUNT])?;
                w.write_all(&n.to_le_bytes())?;
            }
            AggregateState::SumInt(sum, count) => {
                w.write_all(&[spill_tag::SUM_INT])?;
                w.write_all(&sum.to_le_bytes())?;
                w.write_all(&count.to_le_bytes())?;
            }
            // The Kahan compensation term (`_comp`) is deliberately dropped;
            // the deserializer restarts it at 0.0.
            AggregateState::SumFloat(sum, _comp, count) => {
                w.write_all(&[spill_tag::SUM_FLOAT])?;
                w.write_all(&sum.to_le_bytes())?;
                w.write_all(&count.to_le_bytes())?;
            }
            AggregateState::Avg(sum, count) => {
                w.write_all(&[spill_tag::AVG])?;
                w.write_all(&sum.to_le_bytes())?;
                w.write_all(&count.to_le_bytes())?;
            }
            // DISTINCT variants track a HashSet that can't be serialized compactly.
            // Serialize as finalized to avoid dropping distinct semantics.
            AggregateState::CountDistinct(..)
            | AggregateState::SumIntDistinct(..)
            | AggregateState::SumFloatDistinct(..)
            | AggregateState::AvgDistinct(..)
            | AggregateState::CollectDistinct(..)
            | AggregateState::GroupConcatDistinct(..) => {
                w.write_all(&[spill_tag::FINALIZED])?;
                serialize_value(&acc.finalize(), w)?;
            }
            // NOTE(review): for Min/Max/First/Last, `Some(Value::Null)` and
            // `None` both serialize as Null and deserialize back to `None` —
            // presumably indistinguishable for these aggregates; confirm.
            AggregateState::Min(val) => {
                w.write_all(&[spill_tag::MIN])?;
                serialize_value(&val.clone().unwrap_or(Value::Null), w)?;
            }
            AggregateState::Max(val) => {
                w.write_all(&[spill_tag::MAX])?;
                serialize_value(&val.clone().unwrap_or(Value::Null), w)?;
            }
            AggregateState::First(val) => {
                w.write_all(&[spill_tag::FIRST])?;
                serialize_value(&val.clone().unwrap_or(Value::Null), w)?;
            }
            AggregateState::Last(val) => {
                w.write_all(&[spill_tag::LAST])?;
                serialize_value(&val.clone().unwrap_or(Value::Null), w)?;
            }
            AggregateState::Collect(list) => {
                w.write_all(&[spill_tag::COLLECT])?;
                w.write_all(&(list.len() as u64).to_le_bytes())?;
                for val in list {
                    serialize_value(val, w)?;
                }
            }
            // Complex states: serialize finalized value as fallback
            _ => {
                w.write_all(&[spill_tag::FINALIZED])?;
                serialize_value(&acc.finalize(), w)?;
            }
        }
    }

    Ok(())
}
434
/// Deserializes a `GroupState` from bytes.
///
/// Reconstructs the correct `AggregateState` variant from the tag byte so that
/// reloaded groups can continue accumulating rows. Common variants (Count,
/// SumInt, SumFloat, Avg, Min, Max, First, Last, Collect) are fully resumable.
/// Rare/complex variants fall back to `Frozen(val)`.
///
/// The byte layout mirrors `serialize_group_state` exactly:
/// `[key_count: u64 LE][keys...][acc_count: u64 LE][(tag: u8, payload)...]`.
#[cfg(feature = "spill")]
fn deserialize_group_state(r: &mut dyn Read) -> std::io::Result<GroupState> {
    use crate::execution::spill::deserialize_value;

    // Reads a little-endian u64 (lengths and counts).
    fn read_u64(r: &mut dyn Read) -> std::io::Result<u64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(u64::from_le_bytes(buf))
    }

    // Reads a little-endian i64 (row counts and integer sums).
    fn read_i64(r: &mut dyn Read) -> std::io::Result<i64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(i64::from_le_bytes(buf))
    }

    // Reads a little-endian f64 (float sums).
    fn read_f64(r: &mut dyn Read) -> std::io::Result<f64> {
        let mut buf = [0u8; 8];
        r.read_exact(&mut buf)?;
        Ok(f64::from_le_bytes(buf))
    }

    // Min/Max/First/Last serialize an absent value as `Null`; undo that here.
    fn non_null(val: Value) -> Option<Value> {
        if matches!(val, Value::Null) {
            None
        } else {
            Some(val)
        }
    }

    // Read key values.
    let num_keys = read_u64(r)? as usize;
    let mut key_values = Vec::with_capacity(num_keys);
    for _ in 0..num_keys {
        key_values.push(deserialize_value(r)?);
    }

    // Read accumulators with tag-based reconstruction.
    let num_accumulators = read_u64(r)? as usize;
    let mut accumulators = Vec::with_capacity(num_accumulators);
    for _ in 0..num_accumulators {
        let mut tag = [0u8; 1];
        r.read_exact(&mut tag)?;

        let state = match tag[0] {
            spill_tag::COUNT => AggregateState::Count(read_i64(r)?),
            spill_tag::SUM_INT => {
                let sum = read_i64(r)?;
                let count = read_i64(r)?;
                AggregateState::SumInt(sum, count)
            }
            spill_tag::SUM_FLOAT => {
                let sum = read_f64(r)?;
                let count = read_i64(r)?;
                // Reset Kahan compensation to zero; minor precision loss is acceptable
                AggregateState::SumFloat(sum, 0.0, count)
            }
            spill_tag::AVG => {
                let sum = read_f64(r)?;
                let count = read_i64(r)?;
                AggregateState::Avg(sum, count)
            }
            spill_tag::MIN => AggregateState::Min(non_null(deserialize_value(r)?)),
            spill_tag::MAX => AggregateState::Max(non_null(deserialize_value(r)?)),
            spill_tag::FIRST => AggregateState::First(non_null(deserialize_value(r)?)),
            spill_tag::LAST => AggregateState::Last(non_null(deserialize_value(r)?)),
            spill_tag::COLLECT => {
                let len = read_u64(r)? as usize;
                let mut list = Vec::with_capacity(len);
                for _ in 0..len {
                    list.push(deserialize_value(r)?);
                }
                AggregateState::Collect(list)
            }
            // FINALIZED (and any unknown tag): a finalized value that can no
            // longer accumulate; wrap it so finalize just returns it.
            _ => AggregateState::Frozen(deserialize_value(r)?),
        };

        accumulators.push(state);
    }

    Ok(GroupState {
        key_values,
        accumulators,
    })
}
555
/// Push-based aggregate operator with spilling support.
///
/// Uses partitioned hash table that can spill cold partitions to disk
/// when memory pressure is high.
#[cfg(feature = "spill")]
pub struct SpillableAggregatePushOperator {
    /// Columns to group by. Empty means global (single-row) aggregation.
    group_by: Vec<usize>,
    /// Aggregate expressions, in output-column order.
    aggregates: Vec<AggregateExpr>,
    /// Spill manager (None = no spilling).
    spill_manager: Option<Arc<SpillManager>>,
    /// Partitioned groups (used when spilling is enabled).
    partitioned_groups: Option<PartitionedState<GroupState>>,
    /// Non-partitioned groups (used when spilling is disabled).
    groups: HashMap<GroupKey, GroupState>,
    /// Global accumulator (for no GROUP BY). `Some` iff `group_by` is empty.
    global_state: Option<Vec<AggregateState>>,
    /// Spill threshold (number of groups).
    spill_threshold: usize,
    /// Whether we've switched to partitioned mode. Set at construction by
    /// `with_spilling`, or later by `maybe_spill` when the in-memory map
    /// crosses `spill_threshold`.
    using_partitioned: bool,
}
579
#[cfg(feature = "spill")]
impl SpillableAggregatePushOperator {
    /// Builds the global accumulator set; `Some` only when there is no GROUP BY.
    fn init_global_state(
        group_by: &[usize],
        aggregates: &[AggregateExpr],
    ) -> Option<Vec<AggregateState>> {
        if group_by.is_empty() {
            Some(aggregates.iter().map(state_for_expr).collect())
        } else {
            None
        }
    }

    /// Builds a fresh 256-way partitioned state wired to `manager`.
    fn new_partitioned(manager: &Arc<SpillManager>) -> PartitionedState<GroupState> {
        PartitionedState::new(
            Arc::clone(manager),
            256, // Number of partitions
            serialize_group_state,
            deserialize_group_state,
        )
    }

    /// Create a new spillable aggregate operator (spilling disabled).
    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
        let global_state = Self::init_global_state(&group_by, &aggregates);
        Self {
            group_by,
            aggregates,
            spill_manager: None,
            partitioned_groups: None,
            groups: HashMap::new(),
            global_state,
            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
            using_partitioned: false,
        }
    }

    /// Create a spillable aggregate operator with spilling enabled from the start.
    pub fn with_spilling(
        group_by: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        manager: Arc<SpillManager>,
        threshold: usize,
    ) -> Self {
        let global_state = Self::init_global_state(&group_by, &aggregates);
        let partitioned = Self::new_partitioned(&manager);
        Self {
            group_by,
            aggregates,
            spill_manager: Some(manager),
            partitioned_groups: Some(partitioned),
            groups: HashMap::new(),
            global_state,
            spill_threshold: threshold,
            using_partitioned: true,
        }
    }

    /// Create a simple global aggregate (no GROUP BY).
    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
        Self::new(Vec::new(), aggregates)
    }

    /// Sets the spill threshold (builder style).
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.spill_threshold = threshold;
        self
    }

    /// Spills a partition, or migrates to partitioned mode, when the number
    /// of tracked groups crosses `spill_threshold`.
    fn maybe_spill(&mut self) -> Result<(), OperatorError> {
        // Global aggregation keeps O(1) state; nothing to spill.
        if self.global_state.is_some() {
            return Ok(());
        }

        // Already partitioned: push the largest partition to disk if needed.
        if let Some(partitioned) = self.partitioned_groups.as_mut() {
            if partitioned.total_size() >= self.spill_threshold {
                partitioned
                    .spill_largest()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }
            return Ok(());
        }

        if self.groups.len() < self.spill_threshold {
            return Ok(());
        }

        // Threshold reached without partitioned state: if spilling is
        // configured, migrate the in-memory map into a partitioned table.
        if let Some(manager) = self.spill_manager.as_ref() {
            let mut partitioned = Self::new_partitioned(manager);
            for (_key, state) in self.groups.drain() {
                partitioned
                    .insert(state.key_values.clone(), state)
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }
            self.partitioned_groups = Some(partitioned);
            self.using_partitioned = true;
        }

        Ok(())
    }
}
685
#[cfg(feature = "spill")]
impl PushOperator for SpillableAggregatePushOperator {
    /// Accumulates one chunk, then checks whether spilling is needed.
    ///
    /// Three accumulation paths: global (no GROUP BY), partitioned state
    /// (spilling active), or the plain in-memory hash map.
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global aggregation - same as non-spillable
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        update_accumulator(acc, expr, &chunk, row);
                    }
                }
            } else if self.using_partitioned {
                // Use partitioned state; keys are the materialized key values
                // (not hashes), since partitions may round-trip through disk.
                if let Some(ref mut partitioned) = self.partitioned_groups {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    // Field alias keeps the closure's borrow disjoint from
                    // the mutable borrow of `partitioned_groups`.
                    let aggregates = &self.aggregates;
                    let state = partitioned
                        .get_or_insert_with(key_values.clone(), || GroupState {
                            key_values: key_values.clone(),
                            accumulators: aggregates.iter().map(state_for_expr).collect(),
                        })
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;

                    for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                        update_accumulator(acc, expr, &chunk, row);
                    }
                }
            } else {
                // Use regular hash map
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(state_for_expr).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    update_accumulator(acc, expr, &chunk, row);
                }
            }
        }

        // Check if we need to spill
        self.maybe_spill()?;

        Ok(true)
    }

    /// Emits one output chunk: key columns followed by aggregate columns.
    /// Drains all partitions (reloading spilled ones) in partitioned mode.
    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            // Global aggregation - single row output
            if let Some(ref accumulators) = self.global_state {
                for (i, acc) in accumulators.iter().enumerate() {
                    columns[i].push(acc.finalize());
                }
            }
        } else if self.using_partitioned {
            // Drain partitioned state
            if let Some(ref mut partitioned) = self.partitioned_groups {
                let groups = partitioned
                    .drain_all()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;

                for (_key, state) in groups {
                    // Output group key columns
                    for (i, val) in state.key_values.iter().enumerate() {
                        columns[i].push(val.clone());
                    }

                    // Output aggregate results
                    for (i, acc) in state.accumulators.iter().enumerate() {
                        columns[self.group_by.len() + i].push(acc.finalize());
                    }
                }
            }
        } else {
            // Group by using regular hash map - one row per group
            for state in self.groups.values() {
                // Output group key columns
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                // Output aggregate results
                for (i, acc) in state.accumulators.iter().enumerate() {
                    columns[self.group_by.len() + i].push(acc.finalize());
                }
            }
        }

        // Suppress an entirely empty chunk (no groups, no global row).
        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "SpillableAggregatePush"
    }
}
823
824#[cfg(test)]
825mod tests {
826    use super::*;
827    use crate::execution::operators::accumulator::AggregateFunction;
828    use crate::execution::sink::CollectorSink;
829
830    fn create_test_chunk(values: &[i64]) -> DataChunk {
831        let v: Vec<Value> = values.iter().map(|&i| Value::Int64(i)).collect();
832        let vector = ValueVector::from_values(&v);
833        DataChunk::new(vec![vector])
834    }
835
836    fn create_two_column_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
837        let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
838        let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
839        DataChunk::new(vec![
840            ValueVector::from_values(&v1),
841            ValueVector::from_values(&v2),
842        ])
843    }
844
845    #[test]
846    fn test_global_count() {
847        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
848        let mut sink = CollectorSink::new();
849
850        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
851            .unwrap();
852        agg.finalize(&mut sink).unwrap();
853
854        let chunks = sink.into_chunks();
855        assert_eq!(chunks.len(), 1);
856        assert_eq!(
857            chunks[0].column(0).unwrap().get_value(0),
858            Some(Value::Int64(5))
859        );
860    }
861
862    #[test]
863    fn test_global_sum() {
864        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::sum(0)]);
865        let mut sink = CollectorSink::new();
866
867        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
868            .unwrap();
869        agg.finalize(&mut sink).unwrap();
870
871        let chunks = sink.into_chunks();
872        // AggregateState preserves integer type for SUM of integers
873        assert_eq!(
874            chunks[0].column(0).unwrap().get_value(0),
875            Some(Value::Int64(15))
876        );
877    }
878
879    #[test]
880    fn test_global_min_max() {
881        let mut agg =
882            AggregatePushOperator::global(vec![AggregateExpr::min(0), AggregateExpr::max(0)]);
883        let mut sink = CollectorSink::new();
884
885        agg.push(create_test_chunk(&[3, 1, 4, 1, 5, 9, 2, 6]), &mut sink)
886            .unwrap();
887        agg.finalize(&mut sink).unwrap();
888
889        let chunks = sink.into_chunks();
890        assert_eq!(
891            chunks[0].column(0).unwrap().get_value(0),
892            Some(Value::Int64(1))
893        );
894        assert_eq!(
895            chunks[0].column(1).unwrap().get_value(0),
896            Some(Value::Int64(9))
897        );
898    }
899
900    #[test]
901    fn test_group_by_sum() {
902        // Group by column 0, sum column 1
903        let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
904        let mut sink = CollectorSink::new();
905
906        // Group 1: 10, 20 (sum=30), Group 2: 30, 40 (sum=70)
907        agg.push(
908            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
909            &mut sink,
910        )
911        .unwrap();
912        agg.finalize(&mut sink).unwrap();
913
914        let chunks = sink.into_chunks();
915        assert_eq!(chunks[0].len(), 2); // 2 groups
916    }
917
918    #[test]
919    #[cfg(feature = "spill")]
920    fn test_spillable_aggregate_no_spill() {
921        // When threshold is not reached, should work like normal aggregate
922        let mut agg = SpillableAggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)])
923            .with_threshold(100);
924        let mut sink = CollectorSink::new();
925
926        agg.push(
927            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
928            &mut sink,
929        )
930        .unwrap();
931        agg.finalize(&mut sink).unwrap();
932
933        let chunks = sink.into_chunks();
934        assert_eq!(chunks[0].len(), 2); // 2 groups
935    }
936
937    #[test]
938    #[cfg(feature = "spill")]
939    fn test_spillable_aggregate_with_spilling() {
940        use tempfile::TempDir;
941
942        let temp_dir = TempDir::new().unwrap();
943        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
944
945        // Set very low threshold to force spilling
946        let mut agg = SpillableAggregatePushOperator::with_spilling(
947            vec![0],
948            vec![AggregateExpr::sum(1)],
949            manager,
950            3, // Spill after 3 groups
951        );
952        let mut sink = CollectorSink::new();
953
954        // Create 10 different groups
955        for i in 0..10 {
956            let chunk = create_two_column_chunk(&[i], &[i * 10]);
957            agg.push(chunk, &mut sink).unwrap();
958        }
959        agg.finalize(&mut sink).unwrap();
960
961        let chunks = sink.into_chunks();
962        assert_eq!(chunks.len(), 1);
963        assert_eq!(chunks[0].len(), 10); // 10 groups
964
965        // Verify sums are correct (AggregateState preserves Int64 for integer sums)
966        let mut sums: Vec<i64> = Vec::new();
967        for i in 0..chunks[0].len() {
968            if let Some(Value::Int64(sum)) = chunks[0].column(1).unwrap().get_value(i) {
969                sums.push(sum);
970            }
971        }
972        sums.sort_unstable();
973        assert_eq!(sums, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);
974    }
975
976    #[test]
977    #[cfg(feature = "spill")]
978    fn test_spillable_aggregate_global() {
979        // Global aggregation shouldn't be affected by spilling
980        let mut agg = SpillableAggregatePushOperator::global(vec![AggregateExpr::count_star()]);
981        let mut sink = CollectorSink::new();
982
983        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
984            .unwrap();
985        agg.finalize(&mut sink).unwrap();
986
987        let chunks = sink.into_chunks();
988        assert_eq!(chunks.len(), 1);
989        assert_eq!(
990            chunks[0].column(0).unwrap().get_value(0),
991            Some(Value::Int64(5))
992        );
993    }
994
995    #[test]
996    #[cfg(feature = "spill")]
997    fn test_spillable_aggregate_many_groups() {
998        use tempfile::TempDir;
999
1000        let temp_dir = TempDir::new().unwrap();
1001        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
1002
1003        let mut agg = SpillableAggregatePushOperator::with_spilling(
1004            vec![0],
1005            vec![AggregateExpr::count_star()],
1006            manager,
1007            10, // Very low threshold
1008        );
1009        let mut sink = CollectorSink::new();
1010
1011        // Create 100 different groups
1012        for i in 0..100 {
1013            let chunk = create_test_chunk(&[i]);
1014            agg.push(chunk, &mut sink).unwrap();
1015        }
1016        agg.finalize(&mut sink).unwrap();
1017
1018        let chunks = sink.into_chunks();
1019        assert_eq!(chunks.len(), 1);
1020        assert_eq!(chunks[0].len(), 100); // 100 groups
1021
1022        // Each group should have count = 1
1023        for i in 0..100 {
1024            if let Some(Value::Int64(count)) = chunks[0].column(1).unwrap().get_value(i) {
1025                assert_eq!(count, 1);
1026            }
1027        }
1028    }
1029
1030    // ---------------------------------------------------------------
1031    // hash_value coverage for all Value variants
1032    // ---------------------------------------------------------------
1033
1034    #[test]
1035    fn hash_value_null() {
1036        let h = hash_value(&Value::Null);
1037        assert_ne!(h, 0); // hasher produces non-zero for Null discriminant
1038    }
1039
1040    #[test]
1041    fn hash_value_bool() {
1042        let t = hash_value(&Value::Bool(true));
1043        let f = hash_value(&Value::Bool(false));
1044        assert_ne!(t, f);
1045    }
1046
1047    #[test]
1048    fn hash_value_int64() {
1049        let a = hash_value(&Value::Int64(42));
1050        let b = hash_value(&Value::Int64(43));
1051        assert_ne!(a, b);
1052    }
1053
1054    #[test]
1055    fn hash_value_float64() {
1056        let a = hash_value(&Value::Float64(19.88));
1057        let b = hash_value(&Value::Float64(3.19));
1058        assert_ne!(a, b);
1059    }
1060
1061    #[test]
1062    fn hash_value_string() {
1063        let a = hash_value(&Value::String("hello".into()));
1064        let b = hash_value(&Value::String("world".into()));
1065        assert_ne!(a, b);
1066    }
1067
1068    #[test]
1069    fn hash_value_bytes() {
1070        let a = hash_value(&Value::Bytes(vec![1, 2, 3].into()));
1071        let b = hash_value(&Value::Bytes(vec![4, 5, 6].into()));
1072        assert_ne!(a, b);
1073    }
1074
1075    #[test]
1076    fn hash_value_list() {
1077        let a = hash_value(&Value::List(vec![Value::Int64(1), Value::Int64(2)].into()));
1078        let b = hash_value(&Value::List(vec![Value::Int64(3)].into()));
1079        assert_ne!(a, b);
1080    }
1081
1082    #[test]
1083    fn hash_value_map() {
1084        use grafeo_common::types::PropertyKey;
1085        use std::collections::BTreeMap;
1086        use std::sync::Arc;
1087        let mut map = BTreeMap::new();
1088        map.insert(PropertyKey::new("key"), Value::Int64(42));
1089        let h = hash_value(&Value::Map(Arc::new(map)));
1090        assert_ne!(h, 0);
1091    }
1092
1093    #[test]
1094    fn hash_value_vector() {
1095        let h = hash_value(&Value::Vector(vec![1.0, 2.0, 3.0].into()));
1096        assert_ne!(h, 0);
1097    }
1098
1099    #[test]
1100    fn hash_value_path() {
1101        let h = hash_value(&Value::Path {
1102            nodes: vec![Value::Int64(1), Value::Int64(2)].into(),
1103            edges: vec![Value::Int64(10)].into(),
1104        });
1105        assert_ne!(h, 0);
1106    }
1107
1108    #[test]
1109    fn hash_value_gcounter() {
1110        use std::sync::Arc;
1111        let mut map = std::collections::HashMap::new();
1112        map.insert("replica1".to_string(), 10u64);
1113        let h = hash_value(&Value::GCounter(Arc::new(map)));
1114        assert_ne!(h, 0);
1115    }
1116
1117    #[test]
1118    fn hash_value_on_counter() {
1119        use std::sync::Arc;
1120        let mut pos = std::collections::HashMap::new();
1121        pos.insert("replica1".to_string(), 10u64);
1122        let neg = std::collections::HashMap::new();
1123        let h = hash_value(&Value::OnCounter {
1124            pos: Arc::new(pos),
1125            neg: Arc::new(neg),
1126        });
1127        assert_ne!(h, 0);
1128    }
1129
1130    #[test]
1131    fn hash_value_timestamp() {
1132        use grafeo_common::types::Timestamp;
1133        let h = hash_value(&Value::Timestamp(Timestamp::from_micros(1_700_000_000_000)));
1134        assert_ne!(h, 0);
1135    }
1136
1137    #[test]
1138    fn hash_value_date() {
1139        use grafeo_common::types::Date;
1140        let h = hash_value(&Value::Date(Date::from_days(19000)));
1141        assert_ne!(h, 0);
1142    }
1143
1144    #[test]
1145    fn hash_value_time() {
1146        use grafeo_common::types::Time;
1147        let h = hash_value(&Value::Time(Time::from_hms(12, 0, 0).unwrap()));
1148        assert_ne!(h, 0);
1149    }
1150
1151    #[test]
1152    fn hash_value_duration() {
1153        use grafeo_common::types::Duration;
1154        let h = hash_value(&Value::Duration(Duration::from_days(1)));
1155        assert_ne!(h, 0);
1156    }
1157
1158    #[test]
1159    fn hash_value_zoned_datetime() {
1160        use grafeo_common::types::{Timestamp, ZonedDatetime};
1161        let zdt =
1162            ZonedDatetime::from_timestamp_offset(Timestamp::from_micros(1_700_000_000_000), 3600);
1163        let h = hash_value(&Value::ZonedDatetime(zdt));
1164        assert_ne!(h, 0);
1165    }
1166
1167    // ---------------------------------------------------------------
1168    // AggregateState in push context: advanced functions now work
1169    // ---------------------------------------------------------------
1170
1171    #[test]
1172    fn aggregate_state_last_returns_last_value() {
1173        let mut state = AggregateState::new(AggregateFunction::Last, false, None, None);
1174        state.update(Some(Value::Int64(10)));
1175        state.update(Some(Value::Int64(20)));
1176        assert_eq!(state.finalize(), Value::Int64(20));
1177    }
1178
1179    #[test]
1180    fn aggregate_state_collect_returns_list() {
1181        let mut state = AggregateState::new(AggregateFunction::Collect, false, None, None);
1182        state.update(Some(Value::Int64(1)));
1183        state.update(Some(Value::Int64(2)));
1184        assert_eq!(
1185            state.finalize(),
1186            Value::List(vec![Value::Int64(1), Value::Int64(2)].into())
1187        );
1188    }
1189
1190    #[test]
1191    fn aggregate_state_stdev_returns_value() {
1192        let mut state = AggregateState::new(AggregateFunction::StdDev, false, None, None);
1193        state.update(Some(Value::Float64(2.0)));
1194        state.update(Some(Value::Float64(4.0)));
1195        state.update(Some(Value::Float64(6.0)));
1196        let result = state.finalize();
1197        assert!(matches!(result, Value::Float64(_)));
1198    }
1199
1200    #[test]
1201    fn aggregate_state_first_returns_first_value() {
1202        let mut state = AggregateState::new(AggregateFunction::First, false, None, None);
1203        state.update(Some(Value::Int64(10)));
1204        state.update(Some(Value::Int64(20)));
1205        assert_eq!(state.finalize(), Value::Int64(10));
1206    }
1207
1208    #[test]
1209    fn aggregate_state_avg_empty_returns_null() {
1210        let state = AggregateState::new(AggregateFunction::Avg, false, None, None);
1211        assert_eq!(state.finalize(), Value::Null);
1212    }
1213
1214    #[test]
1215    fn aggregate_state_sum_empty_returns_null() {
1216        let state = AggregateState::new(AggregateFunction::Sum, false, None, None);
1217        assert_eq!(state.finalize(), Value::Null);
1218    }
1219
1220    #[test]
1221    fn aggregate_state_min_max_empty_returns_null() {
1222        let min = AggregateState::new(AggregateFunction::Min, false, None, None);
1223        let max = AggregateState::new(AggregateFunction::Max, false, None, None);
1224        assert_eq!(min.finalize(), Value::Null);
1225        assert_eq!(max.finalize(), Value::Null);
1226    }
1227
1228    #[test]
1229    fn aggregate_state_count_non_null_skips_nulls() {
1230        // CountNonNull maps to the Count(0) state variant, which increments
1231        // unconditionally. Callers (both push and pull operators) must filter
1232        // null values before calling update. This test verifies the expected
1233        // contract: only non-null values are fed to the accumulator.
1234        let mut state = AggregateState::new(AggregateFunction::CountNonNull, false, None, None);
1235        // Simulate what the operator should do: skip nulls, update only non-nulls
1236        // (Value::Null is skipped, Value::Int64(5) is the only non-null)
1237        state.update(Some(Value::Int64(5)));
1238        assert_eq!(state.finalize(), Value::Int64(1));
1239    }
1240
1241    #[test]
1242    fn test_empty_chunk_returns_ok() {
1243        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
1244        let mut sink = CollectorSink::new();
1245        let empty = DataChunk::new(vec![ValueVector::new()]);
1246        let result = agg.push(empty, &mut sink).unwrap();
1247        assert!(result);
1248    }
1249
1250    // ---------------------------------------------------------------
1251    // Spill serialization round-trip tests
1252    // ---------------------------------------------------------------
1253
1254    #[test]
1255    #[cfg(feature = "spill")]
1256    fn spill_roundtrip_count() {
1257        let state = GroupState {
1258            key_values: vec![Value::String("grp".into())],
1259            accumulators: vec![AggregateState::Count(42)],
1260        };
1261        let mut buf = Vec::new();
1262        serialize_group_state(&state, &mut buf).unwrap();
1263        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1264        assert_eq!(restored.key_values, vec![Value::String("grp".into())]);
1265        assert_eq!(restored.accumulators[0].finalize(), Value::Int64(42));
1266    }
1267
1268    #[test]
1269    #[cfg(feature = "spill")]
1270    fn spill_roundtrip_sum_int() {
1271        let state = GroupState {
1272            key_values: vec![Value::Int64(1)],
1273            accumulators: vec![AggregateState::SumInt(100, 5)],
1274        };
1275        let mut buf = Vec::new();
1276        serialize_group_state(&state, &mut buf).unwrap();
1277        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1278        assert_eq!(restored.accumulators[0].finalize(), Value::Int64(100));
1279    }
1280
1281    #[test]
1282    #[cfg(feature = "spill")]
1283    fn spill_roundtrip_sum_float() {
1284        let state = GroupState {
1285            key_values: vec![Value::Int64(1)],
1286            accumulators: vec![AggregateState::SumFloat(3.125, 0.0, 2)],
1287        };
1288        let mut buf = Vec::new();
1289        serialize_group_state(&state, &mut buf).unwrap();
1290        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1291        assert_eq!(restored.accumulators[0].finalize(), Value::Float64(3.125));
1292    }
1293
1294    #[test]
1295    #[cfg(feature = "spill")]
1296    fn spill_roundtrip_avg() {
1297        let state = GroupState {
1298            key_values: vec![Value::Int64(1)],
1299            accumulators: vec![AggregateState::Avg(30.0, 3)],
1300        };
1301        let mut buf = Vec::new();
1302        serialize_group_state(&state, &mut buf).unwrap();
1303        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1304        assert_eq!(restored.accumulators[0].finalize(), Value::Float64(10.0));
1305    }
1306
1307    #[test]
1308    #[cfg(feature = "spill")]
1309    fn spill_roundtrip_min() {
1310        let state = GroupState {
1311            key_values: vec![Value::Int64(1)],
1312            accumulators: vec![AggregateState::Min(Some(Value::Int64(7)))],
1313        };
1314        let mut buf = Vec::new();
1315        serialize_group_state(&state, &mut buf).unwrap();
1316        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1317        assert_eq!(restored.accumulators[0].finalize(), Value::Int64(7));
1318    }
1319
1320    #[test]
1321    #[cfg(feature = "spill")]
1322    fn spill_roundtrip_min_none() {
1323        let state = GroupState {
1324            key_values: vec![Value::Int64(1)],
1325            accumulators: vec![AggregateState::Min(None)],
1326        };
1327        let mut buf = Vec::new();
1328        serialize_group_state(&state, &mut buf).unwrap();
1329        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1330        assert_eq!(restored.accumulators[0].finalize(), Value::Null);
1331    }
1332
1333    #[test]
1334    #[cfg(feature = "spill")]
1335    fn spill_roundtrip_max() {
1336        let state = GroupState {
1337            key_values: vec![Value::Int64(1)],
1338            accumulators: vec![AggregateState::Max(Some(Value::Int64(99)))],
1339        };
1340        let mut buf = Vec::new();
1341        serialize_group_state(&state, &mut buf).unwrap();
1342        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1343        assert_eq!(restored.accumulators[0].finalize(), Value::Int64(99));
1344    }
1345
1346    #[test]
1347    #[cfg(feature = "spill")]
1348    fn spill_roundtrip_first() {
1349        let state = GroupState {
1350            key_values: vec![Value::Int64(1)],
1351            accumulators: vec![AggregateState::First(Some(Value::String("hello".into())))],
1352        };
1353        let mut buf = Vec::new();
1354        serialize_group_state(&state, &mut buf).unwrap();
1355        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1356        assert_eq!(
1357            restored.accumulators[0].finalize(),
1358            Value::String("hello".into())
1359        );
1360    }
1361
1362    #[test]
1363    #[cfg(feature = "spill")]
1364    fn spill_roundtrip_last() {
1365        let state = GroupState {
1366            key_values: vec![Value::Int64(1)],
1367            accumulators: vec![AggregateState::Last(Some(Value::Float64(2.75)))],
1368        };
1369        let mut buf = Vec::new();
1370        serialize_group_state(&state, &mut buf).unwrap();
1371        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1372        assert_eq!(restored.accumulators[0].finalize(), Value::Float64(2.75));
1373    }
1374
1375    #[test]
1376    #[cfg(feature = "spill")]
1377    fn spill_roundtrip_collect() {
1378        let state = GroupState {
1379            key_values: vec![Value::Int64(1)],
1380            accumulators: vec![AggregateState::Collect(vec![
1381                Value::Int64(10),
1382                Value::Int64(20),
1383                Value::Int64(30),
1384            ])],
1385        };
1386        let mut buf = Vec::new();
1387        serialize_group_state(&state, &mut buf).unwrap();
1388        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1389        assert_eq!(
1390            restored.accumulators[0].finalize(),
1391            Value::List(vec![Value::Int64(10), Value::Int64(20), Value::Int64(30)].into())
1392        );
1393    }
1394
1395    #[test]
1396    #[cfg(feature = "spill")]
1397    fn spill_roundtrip_all_variants_combined() {
1398        // A single GroupState with every common accumulator type
1399        let state = GroupState {
1400            key_values: vec![Value::String("combined".into()), Value::Int64(42)],
1401            accumulators: vec![
1402                AggregateState::Count(10),
1403                AggregateState::SumInt(50, 5),
1404                AggregateState::SumFloat(7.5, 0.0, 3),
1405                AggregateState::Avg(20.0, 4),
1406                AggregateState::Min(Some(Value::Int64(1))),
1407                AggregateState::Max(Some(Value::Int64(99))),
1408                AggregateState::First(Some(Value::String("first".into()))),
1409                AggregateState::Last(Some(Value::String("last".into()))),
1410                AggregateState::Collect(vec![Value::Int64(1), Value::Int64(2)]),
1411            ],
1412        };
1413        let mut buf = Vec::new();
1414        serialize_group_state(&state, &mut buf).unwrap();
1415        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1416
1417        assert_eq!(restored.key_values.len(), 2);
1418        assert_eq!(restored.key_values[0], Value::String("combined".into()));
1419        assert_eq!(restored.key_values[1], Value::Int64(42));
1420        assert_eq!(restored.accumulators.len(), 9);
1421
1422        assert_eq!(restored.accumulators[0].finalize(), Value::Int64(10));
1423        assert_eq!(restored.accumulators[1].finalize(), Value::Int64(50));
1424        assert_eq!(restored.accumulators[2].finalize(), Value::Float64(7.5));
1425        assert_eq!(restored.accumulators[3].finalize(), Value::Float64(5.0));
1426        assert_eq!(restored.accumulators[4].finalize(), Value::Int64(1));
1427        assert_eq!(restored.accumulators[5].finalize(), Value::Int64(99));
1428        assert_eq!(
1429            restored.accumulators[6].finalize(),
1430            Value::String("first".into())
1431        );
1432        assert_eq!(
1433            restored.accumulators[7].finalize(),
1434            Value::String("last".into())
1435        );
1436        assert_eq!(
1437            restored.accumulators[8].finalize(),
1438            Value::List(vec![Value::Int64(1), Value::Int64(2)].into())
1439        );
1440    }
1441
1442    // ---------------------------------------------------------------
1443    // DISTINCT variants serialize as FINALIZED
1444    // ---------------------------------------------------------------
1445
1446    #[test]
1447    #[cfg(feature = "spill")]
1448    fn spill_roundtrip_count_distinct() {
1449        use crate::execution::operators::accumulator::HashableValue;
1450        use std::collections::HashSet;
1451
1452        let mut seen = HashSet::new();
1453        seen.insert(HashableValue::from(Value::Int64(1)));
1454        seen.insert(HashableValue::from(Value::Int64(2)));
1455        seen.insert(HashableValue::from(Value::Int64(3)));
1456        let state = GroupState {
1457            key_values: vec![Value::Int64(1)],
1458            accumulators: vec![AggregateState::CountDistinct(3, seen)],
1459        };
1460        let mut buf = Vec::new();
1461        serialize_group_state(&state, &mut buf).unwrap();
1462        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1463        // DISTINCT serializes as FINALIZED, deserialized as Frozen(val)
1464        assert_eq!(restored.accumulators[0].finalize(), Value::Int64(3));
1465    }
1466
1467    #[test]
1468    #[cfg(feature = "spill")]
1469    fn spill_roundtrip_avg_distinct() {
1470        use crate::execution::operators::accumulator::HashableValue;
1471        use std::collections::HashSet;
1472
1473        let mut seen = HashSet::new();
1474        seen.insert(HashableValue::from(Value::Float64(2.0)));
1475        seen.insert(HashableValue::from(Value::Float64(4.0)));
1476        let state = GroupState {
1477            key_values: vec![Value::Int64(1)],
1478            accumulators: vec![AggregateState::AvgDistinct(6.0, 2, seen)],
1479        };
1480        let mut buf = Vec::new();
1481        serialize_group_state(&state, &mut buf).unwrap();
1482        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1483        assert_eq!(restored.accumulators[0].finalize(), Value::Float64(3.0));
1484    }
1485
1486    #[test]
1487    #[cfg(feature = "spill")]
1488    fn spill_roundtrip_collect_distinct() {
1489        use crate::execution::operators::accumulator::HashableValue;
1490        use std::collections::HashSet;
1491
1492        let mut seen = HashSet::new();
1493        seen.insert(HashableValue::from(Value::Int64(10)));
1494        seen.insert(HashableValue::from(Value::Int64(20)));
1495        let state = GroupState {
1496            key_values: vec![Value::Int64(1)],
1497            accumulators: vec![AggregateState::CollectDistinct(
1498                vec![Value::Int64(10), Value::Int64(20)],
1499                seen,
1500            )],
1501        };
1502        let mut buf = Vec::new();
1503        serialize_group_state(&state, &mut buf).unwrap();
1504        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1505        // CollectDistinct finalizes to a List, deserialized via FINALIZED fallback
1506        let result = restored.accumulators[0].finalize();
1507        assert!(matches!(result, Value::List(_)));
1508    }
1509
1510    // ---------------------------------------------------------------
1511    // Complex variants (FINALIZED fallback)
1512    // ---------------------------------------------------------------
1513
1514    #[test]
1515    #[cfg(feature = "spill")]
1516    fn spill_roundtrip_stddev() {
1517        // Build a StdDev state by feeding values
1518        let mut acc = AggregateState::new(AggregateFunction::StdDev, false, None, None);
1519        acc.update(Some(Value::Float64(2.0)));
1520        acc.update(Some(Value::Float64(4.0)));
1521        acc.update(Some(Value::Float64(6.0)));
1522        let expected = acc.finalize();
1523
1524        let state = GroupState {
1525            key_values: vec![Value::Int64(1)],
1526            accumulators: vec![acc],
1527        };
1528        let mut buf = Vec::new();
1529        serialize_group_state(&state, &mut buf).unwrap();
1530        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1531        // Complex variant stored as FINALIZED, restored as Frozen(val)
1532        assert_eq!(restored.accumulators[0].finalize(), expected);
1533    }
1534
1535    #[test]
1536    #[cfg(feature = "spill")]
1537    fn spill_roundtrip_percentile_disc() {
1538        let state = GroupState {
1539            key_values: vec![Value::Int64(1)],
1540            accumulators: vec![AggregateState::PercentileDisc {
1541                values: vec![1.0, 2.0, 3.0, 4.0, 5.0],
1542                percentile: 0.5,
1543            }],
1544        };
1545        let expected = state.accumulators[0].finalize();
1546        let mut buf = Vec::new();
1547        serialize_group_state(&state, &mut buf).unwrap();
1548        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1549        assert_eq!(restored.accumulators[0].finalize(), expected);
1550    }
1551
1552    #[test]
1553    #[cfg(feature = "spill")]
1554    fn spill_roundtrip_group_concat() {
1555        let state = GroupState {
1556            key_values: vec![Value::Int64(1)],
1557            accumulators: vec![AggregateState::GroupConcat(
1558                vec!["alix".to_string(), "gus".to_string(), "vincent".to_string()],
1559                ", ".to_string(),
1560            )],
1561        };
1562        let expected = state.accumulators[0].finalize();
1563        let mut buf = Vec::new();
1564        serialize_group_state(&state, &mut buf).unwrap();
1565        let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1566        assert_eq!(restored.accumulators[0].finalize(), expected);
1567    }
1568
1569    // ---------------------------------------------------------------
1570    // SpillableAggregatePushOperator with Collect
1571    // ---------------------------------------------------------------
1572
1573    #[test]
1574    #[cfg(feature = "spill")]
1575    fn test_spillable_aggregate_collect() {
1576        use tempfile::TempDir;
1577
1578        let temp_dir = TempDir::new().unwrap();
1579        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
1580
1581        let mut agg = SpillableAggregatePushOperator::with_spilling(
1582            vec![0],
1583            vec![AggregateExpr::collect(1)],
1584            manager,
1585            3, // Spill after 3 groups
1586        );
1587        let mut sink = CollectorSink::new();
1588
1589        // Create groups: group 1 collects [10, 20], group 2 collects [30, 40]
1590        agg.push(
1591            create_two_column_chunk(&[1, 2, 1, 2], &[10, 30, 20, 40]),
1592            &mut sink,
1593        )
1594        .unwrap();
1595        // Add more groups to trigger spilling
1596        for i in 3..10 {
1597            agg.push(create_two_column_chunk(&[i], &[i * 10]), &mut sink)
1598                .unwrap();
1599        }
1600        agg.finalize(&mut sink).unwrap();
1601
1602        let chunks = sink.into_chunks();
1603        assert_eq!(chunks.len(), 1);
1604        assert_eq!(chunks[0].len(), 9); // 9 groups
1605
1606        // Find group 1 and verify its collected list
1607        let mut found_group1 = false;
1608        for row in 0..chunks[0].len() {
1609            if let Some(Value::Int64(1)) = chunks[0].column(0).unwrap().get_value(row) {
1610                let collected = chunks[0].column(1).unwrap().get_value(row).unwrap();
1611                if let Value::List(list) = collected {
1612                    assert_eq!(list.len(), 2);
1613                    assert!(list.contains(&Value::Int64(10)));
1614                    assert!(list.contains(&Value::Int64(20)));
1615                    found_group1 = true;
1616                }
1617            }
1618        }
1619        assert!(found_group1, "Group 1 with collected values not found");
1620    }
1621
1622    // ---------------------------------------------------------------
1623    // SpillableAggregatePushOperator with Min/Max
1624    // ---------------------------------------------------------------
1625
1626    #[test]
1627    #[cfg(feature = "spill")]
1628    fn test_spillable_aggregate_min_max() {
1629        use tempfile::TempDir;
1630
1631        let temp_dir = TempDir::new().unwrap();
1632        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
1633
1634        let mut agg = SpillableAggregatePushOperator::with_spilling(
1635            vec![0],
1636            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
1637            manager,
1638            3, // Spill after 3 groups
1639        );
1640        let mut sink = CollectorSink::new();
1641
1642        // Group 1: values 50, 10, 30 => min=10, max=50
1643        // Group 2: values 20, 40 => min=20, max=40
1644        agg.push(
1645            create_two_column_chunk(&[1, 2, 1, 2, 1], &[50, 20, 10, 40, 30]),
1646            &mut sink,
1647        )
1648        .unwrap();
1649
1650        // Add more groups to trigger spilling
1651        for i in 3..10 {
1652            agg.push(create_two_column_chunk(&[i], &[i * 10]), &mut sink)
1653                .unwrap();
1654        }
1655        agg.finalize(&mut sink).unwrap();
1656
1657        let chunks = sink.into_chunks();
1658        assert_eq!(chunks.len(), 1);
1659        assert_eq!(chunks[0].len(), 9); // 9 groups
1660
1661        // Verify group 1: min=10, max=50
1662        let mut found_group1 = false;
1663        for row in 0..chunks[0].len() {
1664            if let Some(Value::Int64(1)) = chunks[0].column(0).unwrap().get_value(row) {
1665                assert_eq!(
1666                    chunks[0].column(1).unwrap().get_value(row),
1667                    Some(Value::Int64(10))
1668                );
1669                assert_eq!(
1670                    chunks[0].column(2).unwrap().get_value(row),
1671                    Some(Value::Int64(50))
1672                );
1673                found_group1 = true;
1674            }
1675        }
1676        assert!(found_group1, "Group 1 with min/max not found");
1677
1678        // Verify group 2: min=20, max=40
1679        let mut found_group2 = false;
1680        for row in 0..chunks[0].len() {
1681            if let Some(Value::Int64(2)) = chunks[0].column(0).unwrap().get_value(row) {
1682                assert_eq!(
1683                    chunks[0].column(1).unwrap().get_value(row),
1684                    Some(Value::Int64(20))
1685                );
1686                assert_eq!(
1687                    chunks[0].column(2).unwrap().get_value(row),
1688                    Some(Value::Int64(40))
1689                );
1690                found_group2 = true;
1691            }
1692        }
1693        assert!(found_group2, "Group 2 with min/max not found");
1694    }
1695
1696    #[test]
1697    #[cfg(feature = "spill")]
1698    fn spill_finalized_frozen_ignores_further_updates() {
1699        let mut acc = AggregateState::new(AggregateFunction::StdDev, false, None, None);
1700        acc.update(Some(Value::Float64(2.0)));
1701        acc.update(Some(Value::Float64(4.0)));
1702        acc.update(Some(Value::Float64(6.0)));
1703        let expected = acc.finalize();
1704
1705        let state = GroupState {
1706            key_values: vec![Value::Int64(1)],
1707            accumulators: vec![acc],
1708        };
1709        let mut buf = Vec::new();
1710        serialize_group_state(&state, &mut buf).unwrap();
1711        let mut restored = deserialize_group_state(&mut &buf[..]).unwrap();
1712
1713        assert!(matches!(
1714            restored.accumulators[0],
1715            AggregateState::Frozen(_)
1716        ));
1717
1718        restored.accumulators[0].update(Some(Value::Float64(100.0)));
1719        restored.accumulators[0].update(Some(Value::Float64(200.0)));
1720
1721        assert_eq!(restored.accumulators[0].finalize(), expected);
1722    }
1723}