1use crate::execution::chunk::DataChunk;
4use crate::execution::operators::OperatorError;
5use crate::execution::operators::accumulator::{AggregateExpr, AggregateFunction, AggregateState};
6use crate::execution::pipeline::{ChunkSizeHint, PushOperator, Sink};
7#[cfg(feature = "spill")]
8use crate::execution::spill::{PartitionedState, SpillManager};
9use crate::execution::vector::ValueVector;
10use grafeo_common::types::Value;
11use std::collections::HashMap;
12#[cfg(feature = "spill")]
13use std::io::{Read, Write};
14#[cfg(feature = "spill")]
15use std::sync::Arc;
16
17fn state_for_expr(expr: &AggregateExpr) -> AggregateState {
19 AggregateState::new(
20 expr.function,
21 expr.distinct,
22 expr.percentile,
23 expr.separator.as_deref(),
24 )
25}
26
27fn update_accumulator(
30 acc: &mut AggregateState,
31 expr: &AggregateExpr,
32 chunk: &DataChunk,
33 row: usize,
34) {
35 if expr.column2.is_some() {
37 let y_val = expr
38 .column
39 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
40 let x_val = expr
41 .column2
42 .and_then(|col| chunk.column(col).and_then(|c| c.get_value(row)));
43 acc.update_bivariate(y_val, x_val);
44 return;
45 }
46
47 if let Some(col) = expr.column {
48 let val = chunk.column(col).and_then(|c| c.get_value(row));
49 if expr.function == AggregateFunction::CountNonNull
51 && matches!(val, None | Some(Value::Null))
52 {
53 return;
54 }
55 acc.update(val);
56 } else {
57 acc.update(None);
59 }
60}
61
/// Hash-based grouping key: one 64-bit hash per GROUP BY column.
///
/// NOTE(review): groups are keyed by hashes only, never by the values
/// themselves, so two distinct values that collide in `hash_value` would
/// be merged into one group — presumably an accepted trade-off here;
/// confirm against the planner's expectations.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GroupKey(Vec<u64>);
65
66impl GroupKey {
67 fn from_row(chunk: &DataChunk, row: usize, group_by: &[usize]) -> Self {
68 let hashes: Vec<u64> = group_by
69 .iter()
70 .map(|&col| {
71 chunk
72 .column(col)
73 .and_then(|c| c.get_value(row))
74 .map_or(0, |v| hash_value(&v))
75 })
76 .collect();
77 Self(hashes)
78 }
79}
80
/// Computes a stable 64-bit hash of `value` for group-key comparison.
///
/// Each variant writes a distinct tag byte before its payload, so values
/// of different types cannot collide by construction. Floats hash their
/// raw bit pattern, so `0.0` and `-0.0` (and distinct NaN payloads) hash
/// differently.
fn hash_value(value: &Value) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();
    match value {
        Value::Null => 0u8.hash(&mut hasher),
        Value::Bool(b) => {
            1u8.hash(&mut hasher);
            b.hash(&mut hasher);
        }
        Value::Int64(i) => {
            2u8.hash(&mut hasher);
            i.hash(&mut hasher);
        }
        Value::Float64(f) => {
            3u8.hash(&mut hasher);
            // Bit-pattern hash: total function over all floats, incl. NaN.
            f.to_bits().hash(&mut hasher);
        }
        Value::String(s) => {
            4u8.hash(&mut hasher);
            s.hash(&mut hasher);
        }
        Value::Bytes(b) => {
            5u8.hash(&mut hasher);
            b.hash(&mut hasher);
        }
        Value::Timestamp(t) => {
            6u8.hash(&mut hasher);
            t.hash(&mut hasher);
        }
        Value::Date(d) => {
            7u8.hash(&mut hasher);
            d.hash(&mut hasher);
        }
        Value::Time(t) => {
            8u8.hash(&mut hasher);
            t.hash(&mut hasher);
        }
        Value::Duration(d) => {
            9u8.hash(&mut hasher);
            d.hash(&mut hasher);
        }
        Value::ZonedDatetime(zdt) => {
            10u8.hash(&mut hasher);
            zdt.hash(&mut hasher);
        }
        Value::List(list) => {
            11u8.hash(&mut hasher);
            // Length first, then each element's hash recursively.
            list.len().hash(&mut hasher);
            for elem in list.iter() {
                hash_value(elem).hash(&mut hasher);
            }
        }
        Value::Map(map) => {
            12u8.hash(&mut hasher);
            map.len().hash(&mut hasher);
            for (k, v) in map.as_ref() {
                k.as_str().hash(&mut hasher);
                hash_value(v).hash(&mut hasher);
            }
        }
        Value::Vector(vec) => {
            13u8.hash(&mut hasher);
            vec.len().hash(&mut hasher);
            for f in vec.iter() {
                f.to_bits().hash(&mut hasher);
            }
        }
        Value::Path { nodes, edges } => {
            14u8.hash(&mut hasher);
            nodes.len().hash(&mut hasher);
            for n in nodes.iter() {
                hash_value(n).hash(&mut hasher);
            }
            for e in edges.iter() {
                hash_value(e).hash(&mut hasher);
            }
        }
        Value::GCounter(map) => {
            15u8.hash(&mut hasher);
            // Sort entries so the result is independent of map iteration order.
            let mut entries: Vec<_> = map.iter().collect();
            entries.sort_by_key(|(k, _)| *k);
            for (k, v) in entries {
                k.hash(&mut hasher);
                v.hash(&mut hasher);
            }
        }
        Value::OnCounter { pos, neg } => {
            16u8.hash(&mut hasher);
            // Same order-independence trick for both halves of the counter.
            let mut pos_entries: Vec<_> = pos.iter().collect();
            pos_entries.sort_by_key(|(k, _)| *k);
            for (k, v) in pos_entries {
                k.hash(&mut hasher);
                v.hash(&mut hasher);
            }
            let mut neg_entries: Vec<_> = neg.iter().collect();
            neg_entries.sort_by_key(|(k, _)| *k);
            for (k, v) in neg_entries {
                k.hash(&mut hasher);
                v.hash(&mut hasher);
            }
        }
        other => {
            // Fallback: hash only the discriminant, so all values of an
            // unlisted variant collide with each other (but not with any
            // other variant).
            255u8.hash(&mut hasher);
            std::mem::discriminant(other).hash(&mut hasher);
        }
    }
    hasher.finish()
}
193
/// Accumulated state for a single group.
#[derive(Clone)]
struct GroupState {
    /// The group's materialized key column values, in `group_by` order.
    key_values: Vec<Value>,
    /// One accumulator per aggregate expression, in expression order.
    accumulators: Vec<AggregateState>,
}
200
/// Push-based hash aggregation operator (in-memory only).
///
/// Absorbs chunks row by row into per-group accumulators and emits a
/// single result chunk on `finalize`.
pub struct AggregatePushOperator {
    /// Indices of the GROUP BY columns in the input chunk.
    group_by: Vec<usize>,
    /// Aggregate expressions evaluated for every group.
    aggregates: Vec<AggregateExpr>,
    /// Per-group state, keyed by the hashed group key.
    groups: HashMap<GroupKey, GroupState>,
    /// Accumulators for the global (no GROUP BY) case; `None` when grouping.
    global_state: Option<Vec<AggregateState>>,
}
215
216impl AggregatePushOperator {
217 pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
219 let global_state = if group_by.is_empty() {
220 Some(aggregates.iter().map(state_for_expr).collect())
221 } else {
222 None
223 };
224
225 Self {
226 group_by,
227 aggregates,
228 groups: HashMap::new(),
229 global_state,
230 }
231 }
232
233 pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
235 Self::new(Vec::new(), aggregates)
236 }
237}
238
239impl PushOperator for AggregatePushOperator {
240 fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
241 if chunk.is_empty() {
242 return Ok(true);
243 }
244
245 for row in chunk.selected_indices() {
246 if self.group_by.is_empty() {
247 if let Some(ref mut accumulators) = self.global_state {
249 for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
250 update_accumulator(acc, expr, &chunk, row);
251 }
252 }
253 } else {
254 let key = GroupKey::from_row(&chunk, row, &self.group_by);
256
257 let state = self.groups.entry(key).or_insert_with(|| {
258 let key_values: Vec<Value> = self
259 .group_by
260 .iter()
261 .map(|&col| {
262 chunk
263 .column(col)
264 .and_then(|c| c.get_value(row))
265 .unwrap_or(Value::Null)
266 })
267 .collect();
268
269 GroupState {
270 key_values,
271 accumulators: self.aggregates.iter().map(state_for_expr).collect(),
272 }
273 });
274
275 for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
276 update_accumulator(acc, expr, &chunk, row);
277 }
278 }
279 }
280
281 Ok(true)
282 }
283
284 fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
285 let num_output_cols = self.group_by.len() + self.aggregates.len();
286 let mut columns: Vec<ValueVector> =
287 (0..num_output_cols).map(|_| ValueVector::new()).collect();
288
289 if self.group_by.is_empty() {
290 if let Some(ref accumulators) = self.global_state {
292 for (i, acc) in accumulators.iter().enumerate() {
293 columns[i].push(acc.finalize());
294 }
295 }
296 } else {
297 for state in self.groups.values() {
299 for (i, val) in state.key_values.iter().enumerate() {
301 columns[i].push(val.clone());
302 }
303
304 for (i, acc) in state.accumulators.iter().enumerate() {
306 columns[self.group_by.len() + i].push(acc.finalize());
307 }
308 }
309 }
310
311 if !columns.is_empty() && !columns[0].is_empty() {
312 let chunk = DataChunk::new(columns);
313 sink.consume(chunk)?;
314 }
315
316 Ok(())
317 }
318
319 fn preferred_chunk_size(&self) -> ChunkSizeHint {
320 ChunkSizeHint::Default
321 }
322
323 fn name(&self) -> &'static str {
324 "AggregatePush"
325 }
326}
327
/// Default number of resident groups before the spill path activates.
#[cfg(feature = "spill")]
pub const DEFAULT_AGGREGATE_SPILL_THRESHOLD: usize = 50_000;
331
/// Lower bound on resident group count before memory-pressure spilling is
/// considered; avoids thrashing tiny hash tables to disk.
#[cfg(feature = "spill")]
const AGGREGATE_MIN_BUFFER_GROUPS: usize = 500;
338
/// On-disk tag bytes identifying each serialized accumulator variant.
///
/// `FINALIZED` marks accumulators whose live state cannot round-trip
/// (e.g. DISTINCT variants): they are written as their finalized value
/// and read back as `AggregateState::Frozen`.
#[cfg(feature = "spill")]
mod spill_tag {
    pub const COUNT: u8 = 0;
    pub const SUM_INT: u8 = 1;
    pub const SUM_FLOAT: u8 = 2;
    pub const AVG: u8 = 3;
    pub const MIN: u8 = 4;
    pub const MAX: u8 = 5;
    pub const FIRST: u8 = 6;
    pub const LAST: u8 = 7;
    pub const COLLECT: u8 = 8;
    pub const FINALIZED: u8 = 255;
}
357
/// Writes a `GroupState` to `w` in the spill file format.
///
/// Layout: key count (u64 LE), each key value, accumulator count
/// (u64 LE), then one tagged accumulator record each (see `spill_tag`).
/// Must stay in lockstep with `deserialize_group_state`.
///
/// Accumulators that cannot round-trip their live state — all DISTINCT
/// variants and anything not explicitly listed — are written as
/// `FINALIZED`: their finalized value is stored, and they resurface as
/// `AggregateState::Frozen`, unable to absorb further rows after a spill.
#[cfg(feature = "spill")]
fn serialize_group_state(state: &GroupState, w: &mut dyn Write) -> std::io::Result<()> {
    use crate::execution::spill::serialize_value;

    w.write_all(&(state.key_values.len() as u64).to_le_bytes())?;
    for val in &state.key_values {
        serialize_value(val, w)?;
    }

    w.write_all(&(state.accumulators.len() as u64).to_le_bytes())?;
    for acc in &state.accumulators {
        match acc {
            AggregateState::Count(n) => {
                w.write_all(&[spill_tag::COUNT])?;
                w.write_all(&n.to_le_bytes())?;
            }
            AggregateState::SumInt(sum, count) => {
                w.write_all(&[spill_tag::SUM_INT])?;
                w.write_all(&sum.to_le_bytes())?;
                w.write_all(&count.to_le_bytes())?;
            }
            AggregateState::SumFloat(sum, _comp, count) => {
                // NOTE(review): the compensation term is dropped on spill,
                // so a restored float sum may lose some precision — confirm
                // this is an accepted trade-off.
                w.write_all(&[spill_tag::SUM_FLOAT])?;
                w.write_all(&sum.to_le_bytes())?;
                w.write_all(&count.to_le_bytes())?;
            }
            AggregateState::Avg(sum, count) => {
                w.write_all(&[spill_tag::AVG])?;
                w.write_all(&sum.to_le_bytes())?;
                w.write_all(&count.to_le_bytes())?;
            }
            AggregateState::CountDistinct(..)
            | AggregateState::SumIntDistinct(..)
            | AggregateState::SumFloatDistinct(..)
            | AggregateState::AvgDistinct(..)
            | AggregateState::CollectDistinct(..)
            | AggregateState::GroupConcatDistinct(..) => {
                // DISTINCT tracking sets cannot round-trip: freeze the result.
                w.write_all(&[spill_tag::FINALIZED])?;
                serialize_value(&acc.finalize(), w)?;
            }
            AggregateState::Min(val) => {
                // `None` is encoded as Null (see matching note in the reader).
                w.write_all(&[spill_tag::MIN])?;
                serialize_value(&val.clone().unwrap_or(Value::Null), w)?;
            }
            AggregateState::Max(val) => {
                w.write_all(&[spill_tag::MAX])?;
                serialize_value(&val.clone().unwrap_or(Value::Null), w)?;
            }
            AggregateState::First(val) => {
                w.write_all(&[spill_tag::FIRST])?;
                serialize_value(&val.clone().unwrap_or(Value::Null), w)?;
            }
            AggregateState::Last(val) => {
                w.write_all(&[spill_tag::LAST])?;
                serialize_value(&val.clone().unwrap_or(Value::Null), w)?;
            }
            AggregateState::Collect(list) => {
                w.write_all(&[spill_tag::COLLECT])?;
                w.write_all(&(list.len() as u64).to_le_bytes())?;
                for val in list {
                    serialize_value(val, w)?;
                }
            }
            _ => {
                // Any other variant: freeze its finalized value.
                w.write_all(&[spill_tag::FINALIZED])?;
                serialize_value(&acc.finalize(), w)?;
            }
        }
    }

    Ok(())
}
441
/// Reads one `GroupState` previously written by `serialize_group_state`.
///
/// Unrecognized tags (including `FINALIZED`) deserialize into
/// `AggregateState::Frozen`, which replays the stored value verbatim at
/// finalize time.
///
/// NOTE(review): MIN/MAX/FIRST/LAST encode `None` as `Value::Null`, so a
/// genuine Null extremum reads back as "no value" — presumably equivalent
/// once finalized; confirm.
#[cfg(feature = "spill")]
fn deserialize_group_state(r: &mut dyn Read) -> std::io::Result<GroupState> {
    use crate::execution::spill::deserialize_value;

    let mut len_buf = [0u8; 8];
    r.read_exact(&mut len_buf)?;
    #[allow(clippy::cast_possible_truncation)]
    let num_keys = u64::from_le_bytes(len_buf) as usize;

    let mut key_values = Vec::with_capacity(num_keys);
    for _ in 0..num_keys {
        key_values.push(deserialize_value(r)?);
    }

    r.read_exact(&mut len_buf)?;
    #[allow(clippy::cast_possible_truncation)]
    let num_accumulators = u64::from_le_bytes(len_buf) as usize;

    let mut accumulators = Vec::with_capacity(num_accumulators);
    for _ in 0..num_accumulators {
        let mut tag = [0u8; 1];
        r.read_exact(&mut tag)?;

        let state = match tag[0] {
            spill_tag::COUNT => {
                let mut buf = [0u8; 8];
                r.read_exact(&mut buf)?;
                AggregateState::Count(i64::from_le_bytes(buf))
            }
            spill_tag::SUM_INT => {
                let mut buf = [0u8; 8];
                r.read_exact(&mut buf)?;
                let sum = i64::from_le_bytes(buf);
                r.read_exact(&mut buf)?;
                let count = i64::from_le_bytes(buf);
                AggregateState::SumInt(sum, count)
            }
            spill_tag::SUM_FLOAT => {
                let mut buf = [0u8; 8];
                r.read_exact(&mut buf)?;
                let sum = f64::from_le_bytes(buf);
                r.read_exact(&mut buf)?;
                let count = i64::from_le_bytes(buf);
                // Compensation term restarts at 0.0 (not spilled).
                AggregateState::SumFloat(sum, 0.0, count)
            }
            spill_tag::AVG => {
                let mut buf = [0u8; 8];
                r.read_exact(&mut buf)?;
                let sum = f64::from_le_bytes(buf);
                r.read_exact(&mut buf)?;
                let count = i64::from_le_bytes(buf);
                AggregateState::Avg(sum, count)
            }
            spill_tag::MIN => {
                // Null was the writer's encoding for "no value observed".
                let val = deserialize_value(r)?;
                let opt = if matches!(val, Value::Null) {
                    None
                } else {
                    Some(val)
                };
                AggregateState::Min(opt)
            }
            spill_tag::MAX => {
                let val = deserialize_value(r)?;
                let opt = if matches!(val, Value::Null) {
                    None
                } else {
                    Some(val)
                };
                AggregateState::Max(opt)
            }
            spill_tag::FIRST => {
                let val = deserialize_value(r)?;
                let opt = if matches!(val, Value::Null) {
                    None
                } else {
                    Some(val)
                };
                AggregateState::First(opt)
            }
            spill_tag::LAST => {
                let val = deserialize_value(r)?;
                let opt = if matches!(val, Value::Null) {
                    None
                } else {
                    Some(val)
                };
                AggregateState::Last(opt)
            }
            spill_tag::COLLECT => {
                let mut buf = [0u8; 8];
                r.read_exact(&mut buf)?;
                #[allow(clippy::cast_possible_truncation)]
                let len = u64::from_le_bytes(buf) as usize;
                let mut list = Vec::with_capacity(len);
                for _ in 0..len {
                    list.push(deserialize_value(r)?);
                }
                AggregateState::Collect(list)
            }
            _ => {
                // FINALIZED (or unknown): replay the stored value as-is.
                let val = deserialize_value(r)?;
                AggregateState::Frozen(val)
            }
        };

        accumulators.push(state);
    }

    Ok(GroupState {
        key_values,
        accumulators,
    })
}
568
/// Hash aggregation operator that can spill group state to disk.
///
/// Starts either on a plain in-memory hash map (migrating to partitioned
/// state when a row-count threshold is crossed) or directly on
/// partitioned state when constructed via `with_spilling` /
/// `with_memory_context`.
#[cfg(feature = "spill")]
pub struct SpillableAggregatePushOperator {
    /// Indices of the GROUP BY columns in the input chunk.
    group_by: Vec<usize>,
    /// Aggregate expressions evaluated for every group.
    aggregates: Vec<AggregateExpr>,
    /// Spill manager used to lazily create partitioned state (row-count mode).
    spill_manager: Option<Arc<SpillManager>>,
    /// Partitioned, spillable group storage; `Some` once spilling is active.
    partitioned_groups: Option<PartitionedState<GroupState>>,
    /// In-memory group storage used before any migration to partitions.
    groups: HashMap<GroupKey, GroupState>,
    /// Accumulators for the global (no GROUP BY) case; never spilled.
    global_state: Option<Vec<AggregateState>>,
    /// Resident-group count that triggers row-count-based spilling.
    spill_threshold: usize,
    /// Whether rows are currently routed to `partitioned_groups`.
    using_partitioned: bool,
    /// Memory-manager context, when memory-aware spilling is enabled.
    memory_ctx: Option<crate::execution::memory::OperatorMemoryContext>,
    /// Shared state through which the memory manager requests eviction.
    spill_state: Option<std::sync::Arc<super::spill_state::OperatorSpillState>>,
    /// Rough estimate of resident bytes, reported to the memory manager.
    estimated_bytes: usize,
}
607
#[cfg(feature = "spill")]
impl SpillableAggregatePushOperator {
    /// Creates an in-memory-only operator; no spill manager is attached, so
    /// the row-count spill path is effectively inert until one is provided
    /// via the other constructors.
    pub fn new(group_by: Vec<usize>, aggregates: Vec<AggregateExpr>) -> Self {
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(state_for_expr).collect())
        } else {
            None
        };

        Self {
            group_by,
            aggregates,
            spill_manager: None,
            partitioned_groups: None,
            groups: HashMap::new(),
            global_state,
            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
            using_partitioned: false,
            memory_ctx: None,
            spill_state: None,
            estimated_bytes: 0,
        }
    }

    /// Creates an operator that starts directly on the partitioned
    /// (spillable) path, spilling its largest partition once `threshold`
    /// groups are resident.
    pub fn with_spilling(
        group_by: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        manager: Arc<SpillManager>,
        threshold: usize,
    ) -> Self {
        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(state_for_expr).collect())
        } else {
            None
        };

        // 256 partitions; group states round-trip via the spill serializers.
        let partitioned = PartitionedState::new(
            Arc::clone(&manager),
            256, serialize_group_state,
            deserialize_group_state,
        );

        Self {
            group_by,
            aggregates,
            spill_manager: Some(manager),
            partitioned_groups: Some(partitioned),
            groups: HashMap::new(),
            global_state,
            spill_threshold: threshold,
            using_partitioned: true,
            memory_ctx: None,
            spill_state: None,
            estimated_bytes: 0,
        }
    }

    /// Creates an operator wired into the memory manager: it registers a
    /// consumer adapter so the manager can request eviction, and spills on
    /// memory pressure rather than on a fixed row-count threshold.
    pub fn with_memory_context(
        group_by: Vec<usize>,
        aggregates: Vec<AggregateExpr>,
        ctx: crate::execution::memory::OperatorMemoryContext,
    ) -> Self {
        use super::spill_state::{OperatorConsumerAdapter, OperatorSpillState};

        let global_state = if group_by.is_empty() {
            Some(aggregates.iter().map(state_for_expr).collect())
        } else {
            None
        };

        let state = std::sync::Arc::new(OperatorSpillState::new(
            "SpillableAggregatePush".to_string(),
        ));
        let adapter =
            std::sync::Arc::new(OperatorConsumerAdapter::new(std::sync::Arc::clone(&state)));
        ctx.register_consumer(adapter);

        let partitioned = PartitionedState::new(
            std::sync::Arc::clone(ctx.spill_manager()),
            256,
            serialize_group_state,
            deserialize_group_state,
        );

        Self {
            group_by,
            aggregates,
            spill_manager: None,
            partitioned_groups: Some(partitioned),
            groups: HashMap::new(),
            global_state,
            spill_threshold: DEFAULT_AGGREGATE_SPILL_THRESHOLD,
            using_partitioned: true,
            memory_ctx: Some(ctx),
            spill_state: Some(state),
            estimated_bytes: 0,
        }
    }

    /// Convenience constructor for a global (ungrouped) aggregation.
    pub fn global(aggregates: Vec<AggregateExpr>) -> Self {
        Self::new(Vec::new(), aggregates)
    }

    /// Overrides the resident-group spill threshold (builder style).
    pub fn with_threshold(mut self, threshold: usize) -> Self {
        self.spill_threshold = threshold;
        self
    }

    /// Decides whether to spill after a chunk has been absorbed. Global
    /// aggregation never spills: its state is O(#aggregates), not O(#rows).
    fn maybe_spill(&mut self) -> Result<(), OperatorError> {
        if self.global_state.is_some() {
            return Ok(());
        }

        if self.spill_state.is_some() {
            self.maybe_spill_memory_aware()
        } else {
            self.maybe_spill_row_count()
        }
    }

    /// Memory-manager-driven policy: spill the largest partition when an
    /// eviction was requested or memory pressure is reported, but only once
    /// enough groups are buffered to make the spill worthwhile.
    fn maybe_spill_memory_aware(&mut self) -> Result<(), OperatorError> {
        let should_spill = if let Some(ref state) = self.spill_state {
            // NOTE(review): take_eviction_request() consumes the request even
            // when the table is below AGGREGATE_MIN_BUFFER_GROUPS, so such a
            // request is silently dropped rather than deferred — confirm
            // this is intended.
            let eviction = state.take_eviction_request().is_some();
            let pressure = self.memory_ctx.as_ref().map_or(false, |c| c.should_spill());

            let group_count = if let Some(ref partitioned) = self.partitioned_groups {
                partitioned.total_size()
            } else {
                self.groups.len()
            };
            let above_minimum = group_count >= AGGREGATE_MIN_BUFFER_GROUPS;

            (eviction || pressure) && above_minimum
        } else {
            false
        };

        if should_spill && let Some(ref mut partitioned) = self.partitioned_groups {
            partitioned
                .spill_largest()
                .map_err(|e| OperatorError::Execution(e.to_string()))?;
        }

        Ok(())
    }

    /// Row-count-driven policy: once the threshold is crossed, either spill
    /// the largest partition, or — if still on the plain hash map — migrate
    /// every group into freshly created partitioned state.
    fn maybe_spill_row_count(&mut self) -> Result<(), OperatorError> {
        if let Some(ref mut partitioned) = self.partitioned_groups {
            if partitioned.total_size() >= self.spill_threshold {
                partitioned
                    .spill_largest()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;
            }
        } else if self.groups.len() >= self.spill_threshold {
            if let Some(ref manager) = self.spill_manager {
                let mut partitioned = PartitionedState::new(
                    Arc::clone(manager),
                    256,
                    serialize_group_state,
                    deserialize_group_state,
                );

                // Re-key by materialized key values; the hashed GroupKey is
                // not used on the partitioned path.
                for (_key, state) in self.groups.drain() {
                    partitioned
                        .insert(state.key_values.clone(), state)
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;
                }

                self.partitioned_groups = Some(partitioned);
                self.using_partitioned = true;
            }
        }

        Ok(())
    }

    /// Detaches this operator from the memory manager (called on finalize).
    fn unregister_consumer(&self) {
        if let (Some(ctx), Some(state)) = (&self.memory_ctx, &self.spill_state) {
            ctx.unregister_consumer(state.name());
        }
    }
}
812
#[cfg(feature = "spill")]
impl PushOperator for SpillableAggregatePushOperator {
    /// Absorbs one chunk, routing each row to the global accumulators, the
    /// partitioned (spillable) state, or the in-memory hash map, then
    /// reports estimated memory usage and checks whether to spill.
    fn push(&mut self, chunk: DataChunk, _sink: &mut dyn Sink) -> Result<bool, OperatorError> {
        if chunk.is_empty() {
            return Ok(true);
        }

        for row in chunk.selected_indices() {
            if self.group_by.is_empty() {
                // Global path: single accumulator set, never spilled.
                if let Some(ref mut accumulators) = self.global_state {
                    for (acc, expr) in accumulators.iter_mut().zip(&self.aggregates) {
                        update_accumulator(acc, expr, &chunk, row);
                    }
                }
            } else if self.using_partitioned {
                // Partitioned path: keyed by the materialized key values.
                if let Some(ref mut partitioned) = self.partitioned_groups {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    let aggregates = &self.aggregates;
                    let state = partitioned
                        .get_or_insert_with(key_values.clone(), || GroupState {
                            key_values: key_values.clone(),
                            accumulators: aggregates.iter().map(state_for_expr).collect(),
                        })
                        .map_err(|e| OperatorError::Execution(e.to_string()))?;

                    for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                        update_accumulator(acc, expr, &chunk, row);
                    }
                }
            } else {
                // In-memory path: keyed by the hashed group key.
                let key = GroupKey::from_row(&chunk, row, &self.group_by);

                let state = self.groups.entry(key).or_insert_with(|| {
                    let key_values: Vec<Value> = self
                        .group_by
                        .iter()
                        .map(|&col| {
                            chunk
                                .column(col)
                                .and_then(|c| c.get_value(row))
                                .unwrap_or(Value::Null)
                        })
                        .collect();

                    GroupState {
                        key_values,
                        accumulators: self.aggregates.iter().map(state_for_expr).collect(),
                    }
                });

                for (acc, expr) in state.accumulators.iter_mut().zip(&self.aggregates) {
                    update_accumulator(acc, expr, &chunk, row);
                }
            }
        }

        if let Some(ref spill_state) = self.spill_state {
            let group_count = if self.using_partitioned {
                self.partitioned_groups
                    .as_ref()
                    .map_or(0, |p| p.total_size())
            } else {
                self.groups.len()
            };
            // Rough per-group estimate: key values + a flat 64 bytes per
            // accumulator + 48 bytes of container overhead. NOTE(review):
            // heuristic constants — presumably tuned empirically; confirm.
            let key_size = self.group_by.len() * std::mem::size_of::<Value>();
            let acc_size = self.aggregates.len() * 64; self.estimated_bytes = group_count * (key_size + acc_size + 48);
            spill_state.set_usage(self.estimated_bytes);
        }

        self.maybe_spill()?;

        Ok(true)
    }

    /// Emits one chunk laid out as [group key columns..., aggregate
    /// columns...], draining any spilled partitions back in, and detaches
    /// from the memory manager. Emits nothing for an empty result.
    fn finalize(&mut self, sink: &mut dyn Sink) -> Result<(), OperatorError> {
        let num_output_cols = self.group_by.len() + self.aggregates.len();
        let mut columns: Vec<ValueVector> =
            (0..num_output_cols).map(|_| ValueVector::new()).collect();

        if self.group_by.is_empty() {
            if let Some(ref accumulators) = self.global_state {
                for (i, acc) in accumulators.iter().enumerate() {
                    columns[i].push(acc.finalize());
                }
            }
        } else if self.using_partitioned {
            if let Some(ref mut partitioned) = self.partitioned_groups {
                // drain_all reloads any spilled partitions from disk.
                let groups = partitioned
                    .drain_all()
                    .map_err(|e| OperatorError::Execution(e.to_string()))?;

                for (_key, state) in groups {
                    for (i, val) in state.key_values.iter().enumerate() {
                        columns[i].push(val.clone());
                    }

                    for (i, acc) in state.accumulators.iter().enumerate() {
                        columns[self.group_by.len() + i].push(acc.finalize());
                    }
                }
            }
        } else {
            for state in self.groups.values() {
                for (i, val) in state.key_values.iter().enumerate() {
                    columns[i].push(val.clone());
                }

                for (i, acc) in state.accumulators.iter().enumerate() {
                    columns[self.group_by.len() + i].push(acc.finalize());
                }
            }
        }

        self.unregister_consumer();

        if !columns.is_empty() && !columns[0].is_empty() {
            let chunk = DataChunk::new(columns);
            sink.consume(chunk)?;
        }

        Ok(())
    }

    fn preferred_chunk_size(&self) -> ChunkSizeHint {
        ChunkSizeHint::Default
    }

    fn name(&self) -> &'static str {
        "SpillableAggregatePush"
    }
}
970
971#[cfg(test)]
972mod tests {
973 use super::*;
974 use crate::execution::operators::accumulator::AggregateFunction;
975 use crate::execution::sink::CollectorSink;
976
977 fn create_test_chunk(values: &[i64]) -> DataChunk {
978 let v: Vec<Value> = values.iter().map(|&i| Value::Int64(i)).collect();
979 let vector = ValueVector::from_values(&v);
980 DataChunk::new(vec![vector])
981 }
982
983 fn create_two_column_chunk(col1: &[i64], col2: &[i64]) -> DataChunk {
984 let v1: Vec<Value> = col1.iter().map(|&i| Value::Int64(i)).collect();
985 let v2: Vec<Value> = col2.iter().map(|&i| Value::Int64(i)).collect();
986 DataChunk::new(vec![
987 ValueVector::from_values(&v1),
988 ValueVector::from_values(&v2),
989 ])
990 }
991
    // --- AggregatePushOperator: global and grouped aggregation ---

    // COUNT(*) over one chunk of 5 rows yields a single row with value 5.
    #[test]
    fn test_global_count() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }

    // SUM over 1..=5 yields 15.
    #[test]
    fn test_global_sum() {
        let mut agg = AggregatePushOperator::global(vec![AggregateExpr::sum(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(15))
        );
    }

    // MIN and MAX computed side by side in one operator.
    #[test]
    fn test_global_min_max() {
        let mut agg =
            AggregatePushOperator::global(vec![AggregateExpr::min(0), AggregateExpr::max(0)]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[3, 1, 4, 1, 5, 9, 2, 6]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(1))
        );
        assert_eq!(
            chunks[0].column(1).unwrap().get_value(0),
            Some(Value::Int64(9))
        );
    }

    // Two distinct keys (1 and 2) produce two output rows.
    #[test]
    fn test_group_by_sum() {
        let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
        let mut sink = CollectorSink::new();

        agg.push(
            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
            &mut sink,
        )
        .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks[0].len(), 2); }

    // --- SpillableAggregatePushOperator ---

    // Threshold far above the group count: behaves like the plain operator.
    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_no_spill() {
        let mut agg = SpillableAggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)])
            .with_threshold(100);
        let mut sink = CollectorSink::new();

        agg.push(
            create_two_column_chunk(&[1, 1, 2, 2], &[10, 20, 30, 40]),
            &mut sink,
        )
        .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks[0].len(), 2); }

    // Threshold of 3 forces spills while 10 distinct groups are pushed;
    // finalize must recover every group from disk.
    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_with_spilling() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::sum(1)],
            manager,
            3, );
        let mut sink = CollectorSink::new();

        for i in 0..10 {
            let chunk = create_two_column_chunk(&[i], &[i * 10]);
            agg.push(chunk, &mut sink).unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        // Group order is unspecified, so collect and sort the sums.
        assert_eq!(chunks[0].len(), 10); let mut sums: Vec<i64> = Vec::new();
        for i in 0..chunks[0].len() {
            if let Some(Value::Int64(sum)) = chunks[0].column(1).unwrap().get_value(i) {
                sums.push(sum);
            }
        }
        sums.sort_unstable();
        assert_eq!(sums, vec![0, 10, 20, 30, 40, 50, 60, 70, 80, 90]);
    }

    // Global aggregation on the spillable operator (never spills).
    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_global() {
        let mut agg = SpillableAggregatePushOperator::global(vec![AggregateExpr::count_star()]);
        let mut sink = CollectorSink::new();

        agg.push(create_test_chunk(&[1, 2, 3, 4, 5]), &mut sink)
            .unwrap();
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(
            chunks[0].column(0).unwrap().get_value(0),
            Some(Value::Int64(5))
        );
    }

    // 100 singleton groups with threshold 10: repeated spills, all groups
    // recovered with count 1 each.
    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_many_groups() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::count_star()],
            manager,
            10, );
        let mut sink = CollectorSink::new();

        for i in 0..100 {
            let chunk = create_test_chunk(&[i]);
            agg.push(chunk, &mut sink).unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 100); for i in 0..100 {
            if let Some(Value::Int64(count)) = chunks[0].column(1).unwrap().get_value(i) {
                assert_eq!(count, 1);
            }
        }
    }

    // --- hash_value: distinct inputs hash distinctly; each variant yields a
    // stable, non-trivial hash ---

    #[test]
    fn hash_value_null() {
        let h = hash_value(&Value::Null);
        assert_ne!(h, 0); }

    #[test]
    fn hash_value_bool() {
        let t = hash_value(&Value::Bool(true));
        let f = hash_value(&Value::Bool(false));
        assert_ne!(t, f);
    }

    #[test]
    fn hash_value_int64() {
        let a = hash_value(&Value::Int64(42));
        let b = hash_value(&Value::Int64(43));
        assert_ne!(a, b);
    }

    #[test]
    fn hash_value_float64() {
        let a = hash_value(&Value::Float64(19.88));
        let b = hash_value(&Value::Float64(3.19));
        assert_ne!(a, b);
    }

    #[test]
    fn hash_value_string() {
        let a = hash_value(&Value::String("hello".into()));
        let b = hash_value(&Value::String("world".into()));
        assert_ne!(a, b);
    }

    #[test]
    fn hash_value_bytes() {
        let a = hash_value(&Value::Bytes(vec![1, 2, 3].into()));
        let b = hash_value(&Value::Bytes(vec![4, 5, 6].into()));
        assert_ne!(a, b);
    }

    #[test]
    fn hash_value_list() {
        let a = hash_value(&Value::List(vec![Value::Int64(1), Value::Int64(2)].into()));
        let b = hash_value(&Value::List(vec![Value::Int64(3)].into()));
        assert_ne!(a, b);
    }

    #[test]
    fn hash_value_map() {
        use grafeo_common::types::PropertyKey;
        use std::collections::BTreeMap;
        use std::sync::Arc;
        let mut map = BTreeMap::new();
        map.insert(PropertyKey::new("key"), Value::Int64(42));
        let h = hash_value(&Value::Map(Arc::new(map)));
        assert_ne!(h, 0);
    }

    #[test]
    fn hash_value_vector() {
        let h = hash_value(&Value::Vector(vec![1.0, 2.0, 3.0].into()));
        assert_ne!(h, 0);
    }

    #[test]
    fn hash_value_path() {
        let h = hash_value(&Value::Path {
            nodes: vec![Value::Int64(1), Value::Int64(2)].into(),
            edges: vec![Value::Int64(10)].into(),
        });
        assert_ne!(h, 0);
    }

    #[test]
    fn hash_value_gcounter() {
        use std::sync::Arc;
        let mut map = std::collections::HashMap::new();
        map.insert("replica1".to_string(), 10u64);
        let h = hash_value(&Value::GCounter(Arc::new(map)));
        assert_ne!(h, 0);
    }

    #[test]
    fn hash_value_on_counter() {
        use std::sync::Arc;
        let mut pos = std::collections::HashMap::new();
        pos.insert("replica1".to_string(), 10u64);
        let neg = std::collections::HashMap::new();
        let h = hash_value(&Value::OnCounter {
            pos: Arc::new(pos),
            neg: Arc::new(neg),
        });
        assert_ne!(h, 0);
    }

    #[test]
    fn hash_value_timestamp() {
        use grafeo_common::types::Timestamp;
        let h = hash_value(&Value::Timestamp(Timestamp::from_micros(1_700_000_000_000)));
        assert_ne!(h, 0);
    }

    #[test]
    fn hash_value_date() {
        use grafeo_common::types::Date;
        let h = hash_value(&Value::Date(Date::from_days(19000)));
        assert_ne!(h, 0);
    }

    #[test]
    fn hash_value_time() {
        use grafeo_common::types::Time;
        let h = hash_value(&Value::Time(Time::from_hms(12, 0, 0).unwrap()));
        assert_ne!(h, 0);
    }

    #[test]
    fn hash_value_duration() {
        use grafeo_common::types::Duration;
        let h = hash_value(&Value::Duration(Duration::from_days(1)));
        assert_ne!(h, 0);
    }

    #[test]
    fn hash_value_zoned_datetime() {
        use grafeo_common::types::{Timestamp, ZonedDatetime};
        let zdt =
            ZonedDatetime::from_timestamp_offset(Timestamp::from_micros(1_700_000_000_000), 3600);
        let h = hash_value(&Value::ZonedDatetime(zdt));
        assert_ne!(h, 0);
    }

    // --- AggregateState behavior (accumulator semantics) ---

    #[test]
    fn aggregate_state_last_returns_last_value() {
        let mut state = AggregateState::new(AggregateFunction::Last, false, None, None);
        state.update(Some(Value::Int64(10)));
        state.update(Some(Value::Int64(20)));
        assert_eq!(state.finalize(), Value::Int64(20));
    }

    #[test]
    fn aggregate_state_collect_returns_list() {
        let mut state = AggregateState::new(AggregateFunction::Collect, false, None, None);
        state.update(Some(Value::Int64(1)));
        state.update(Some(Value::Int64(2)));
        assert_eq!(
            state.finalize(),
            Value::List(vec![Value::Int64(1), Value::Int64(2)].into())
        );
    }

    #[test]
    fn aggregate_state_stdev_returns_value() {
        let mut state = AggregateState::new(AggregateFunction::StdDev, false, None, None);
        state.update(Some(Value::Float64(2.0)));
        state.update(Some(Value::Float64(4.0)));
        state.update(Some(Value::Float64(6.0)));
        let result = state.finalize();
        assert!(matches!(result, Value::Float64(_)));
    }

    #[test]
    fn aggregate_state_first_returns_first_value() {
        let mut state = AggregateState::new(AggregateFunction::First, false, None, None);
        state.update(Some(Value::Int64(10)));
        state.update(Some(Value::Int64(20)));
        assert_eq!(state.finalize(), Value::Int64(10));
    }

    // Empty accumulators finalize to Null rather than 0 / empty values.
    #[test]
    fn aggregate_state_avg_empty_returns_null() {
        let state = AggregateState::new(AggregateFunction::Avg, false, None, None);
        assert_eq!(state.finalize(), Value::Null);
    }

    #[test]
    fn aggregate_state_sum_empty_returns_null() {
        let state = AggregateState::new(AggregateFunction::Sum, false, None, None);
        assert_eq!(state.finalize(), Value::Null);
    }

    #[test]
    fn aggregate_state_min_max_empty_returns_null() {
        let min = AggregateState::new(AggregateFunction::Min, false, None, None);
        let max = AggregateState::new(AggregateFunction::Max, false, None, None);
        assert_eq!(min.finalize(), Value::Null);
        assert_eq!(max.finalize(), Value::Null);
    }
1374
    #[test]
    fn aggregate_state_count_non_null_skips_nulls() {
        // NOTE(review): despite the name, this test never feeds a null —
        // `update_accumulator` filters nulls *before* reaching the state for
        // CountNonNull, so the raw state may count every update; confirm
        // whether null-skipping should be asserted at the operator level
        // instead (see `test_aggregate_count_non_null`).
        let mut state = AggregateState::new(AggregateFunction::CountNonNull, false, None, None);
        state.update(Some(Value::Int64(5)));
        // One non-null update observed -> count of 1.
        assert_eq!(state.finalize(), Value::Int64(1));
    }
1387
1388 #[test]
1389 fn test_empty_chunk_returns_ok() {
1390 let mut agg = AggregatePushOperator::global(vec![AggregateExpr::count_star()]);
1391 let mut sink = CollectorSink::new();
1392 let empty = DataChunk::new(vec![ValueVector::new()]);
1393 let result = agg.push(empty, &mut sink).unwrap();
1394 assert!(result);
1395 }
1396
1397 #[test]
1402 #[cfg(feature = "spill")]
1403 fn spill_roundtrip_count() {
1404 let state = GroupState {
1405 key_values: vec![Value::String("grp".into())],
1406 accumulators: vec![AggregateState::Count(42)],
1407 };
1408 let mut buf = Vec::new();
1409 serialize_group_state(&state, &mut buf).unwrap();
1410 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1411 assert_eq!(restored.key_values, vec![Value::String("grp".into())]);
1412 assert_eq!(restored.accumulators[0].finalize(), Value::Int64(42));
1413 }
1414
1415 #[test]
1416 #[cfg(feature = "spill")]
1417 fn spill_roundtrip_sum_int() {
1418 let state = GroupState {
1419 key_values: vec![Value::Int64(1)],
1420 accumulators: vec![AggregateState::SumInt(100, 5)],
1421 };
1422 let mut buf = Vec::new();
1423 serialize_group_state(&state, &mut buf).unwrap();
1424 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1425 assert_eq!(restored.accumulators[0].finalize(), Value::Int64(100));
1426 }
1427
1428 #[test]
1429 #[cfg(feature = "spill")]
1430 fn spill_roundtrip_sum_float() {
1431 let state = GroupState {
1432 key_values: vec![Value::Int64(1)],
1433 accumulators: vec![AggregateState::SumFloat(3.125, 0.0, 2)],
1434 };
1435 let mut buf = Vec::new();
1436 serialize_group_state(&state, &mut buf).unwrap();
1437 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1438 assert_eq!(restored.accumulators[0].finalize(), Value::Float64(3.125));
1439 }
1440
1441 #[test]
1442 #[cfg(feature = "spill")]
1443 fn spill_roundtrip_avg() {
1444 let state = GroupState {
1445 key_values: vec![Value::Int64(1)],
1446 accumulators: vec![AggregateState::Avg(30.0, 3)],
1447 };
1448 let mut buf = Vec::new();
1449 serialize_group_state(&state, &mut buf).unwrap();
1450 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1451 assert_eq!(restored.accumulators[0].finalize(), Value::Float64(10.0));
1452 }
1453
1454 #[test]
1455 #[cfg(feature = "spill")]
1456 fn spill_roundtrip_min() {
1457 let state = GroupState {
1458 key_values: vec![Value::Int64(1)],
1459 accumulators: vec![AggregateState::Min(Some(Value::Int64(7)))],
1460 };
1461 let mut buf = Vec::new();
1462 serialize_group_state(&state, &mut buf).unwrap();
1463 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1464 assert_eq!(restored.accumulators[0].finalize(), Value::Int64(7));
1465 }
1466
1467 #[test]
1468 #[cfg(feature = "spill")]
1469 fn spill_roundtrip_min_none() {
1470 let state = GroupState {
1471 key_values: vec![Value::Int64(1)],
1472 accumulators: vec![AggregateState::Min(None)],
1473 };
1474 let mut buf = Vec::new();
1475 serialize_group_state(&state, &mut buf).unwrap();
1476 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1477 assert_eq!(restored.accumulators[0].finalize(), Value::Null);
1478 }
1479
1480 #[test]
1481 #[cfg(feature = "spill")]
1482 fn spill_roundtrip_max() {
1483 let state = GroupState {
1484 key_values: vec![Value::Int64(1)],
1485 accumulators: vec![AggregateState::Max(Some(Value::Int64(99)))],
1486 };
1487 let mut buf = Vec::new();
1488 serialize_group_state(&state, &mut buf).unwrap();
1489 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1490 assert_eq!(restored.accumulators[0].finalize(), Value::Int64(99));
1491 }
1492
1493 #[test]
1494 #[cfg(feature = "spill")]
1495 fn spill_roundtrip_first() {
1496 let state = GroupState {
1497 key_values: vec![Value::Int64(1)],
1498 accumulators: vec![AggregateState::First(Some(Value::String("hello".into())))],
1499 };
1500 let mut buf = Vec::new();
1501 serialize_group_state(&state, &mut buf).unwrap();
1502 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1503 assert_eq!(
1504 restored.accumulators[0].finalize(),
1505 Value::String("hello".into())
1506 );
1507 }
1508
1509 #[test]
1510 #[cfg(feature = "spill")]
1511 fn spill_roundtrip_last() {
1512 let state = GroupState {
1513 key_values: vec![Value::Int64(1)],
1514 accumulators: vec![AggregateState::Last(Some(Value::Float64(2.75)))],
1515 };
1516 let mut buf = Vec::new();
1517 serialize_group_state(&state, &mut buf).unwrap();
1518 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1519 assert_eq!(restored.accumulators[0].finalize(), Value::Float64(2.75));
1520 }
1521
1522 #[test]
1523 #[cfg(feature = "spill")]
1524 fn spill_roundtrip_collect() {
1525 let state = GroupState {
1526 key_values: vec![Value::Int64(1)],
1527 accumulators: vec![AggregateState::Collect(vec![
1528 Value::Int64(10),
1529 Value::Int64(20),
1530 Value::Int64(30),
1531 ])],
1532 };
1533 let mut buf = Vec::new();
1534 serialize_group_state(&state, &mut buf).unwrap();
1535 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1536 assert_eq!(
1537 restored.accumulators[0].finalize(),
1538 Value::List(vec![Value::Int64(10), Value::Int64(20), Value::Int64(30)].into())
1539 );
1540 }
1541
1542 #[test]
1543 #[cfg(feature = "spill")]
1544 fn spill_roundtrip_all_variants_combined() {
1545 let state = GroupState {
1547 key_values: vec![Value::String("combined".into()), Value::Int64(42)],
1548 accumulators: vec![
1549 AggregateState::Count(10),
1550 AggregateState::SumInt(50, 5),
1551 AggregateState::SumFloat(7.5, 0.0, 3),
1552 AggregateState::Avg(20.0, 4),
1553 AggregateState::Min(Some(Value::Int64(1))),
1554 AggregateState::Max(Some(Value::Int64(99))),
1555 AggregateState::First(Some(Value::String("first".into()))),
1556 AggregateState::Last(Some(Value::String("last".into()))),
1557 AggregateState::Collect(vec![Value::Int64(1), Value::Int64(2)]),
1558 ],
1559 };
1560 let mut buf = Vec::new();
1561 serialize_group_state(&state, &mut buf).unwrap();
1562 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1563
1564 assert_eq!(restored.key_values.len(), 2);
1565 assert_eq!(restored.key_values[0], Value::String("combined".into()));
1566 assert_eq!(restored.key_values[1], Value::Int64(42));
1567 assert_eq!(restored.accumulators.len(), 9);
1568
1569 assert_eq!(restored.accumulators[0].finalize(), Value::Int64(10));
1570 assert_eq!(restored.accumulators[1].finalize(), Value::Int64(50));
1571 assert_eq!(restored.accumulators[2].finalize(), Value::Float64(7.5));
1572 assert_eq!(restored.accumulators[3].finalize(), Value::Float64(5.0));
1573 assert_eq!(restored.accumulators[4].finalize(), Value::Int64(1));
1574 assert_eq!(restored.accumulators[5].finalize(), Value::Int64(99));
1575 assert_eq!(
1576 restored.accumulators[6].finalize(),
1577 Value::String("first".into())
1578 );
1579 assert_eq!(
1580 restored.accumulators[7].finalize(),
1581 Value::String("last".into())
1582 );
1583 assert_eq!(
1584 restored.accumulators[8].finalize(),
1585 Value::List(vec![Value::Int64(1), Value::Int64(2)].into())
1586 );
1587 }
1588
1589 #[test]
1594 #[cfg(feature = "spill")]
1595 fn spill_roundtrip_count_distinct() {
1596 use crate::execution::operators::accumulator::HashableValue;
1597 use std::collections::HashSet;
1598
1599 let mut seen = HashSet::new();
1600 seen.insert(HashableValue::from(Value::Int64(1)));
1601 seen.insert(HashableValue::from(Value::Int64(2)));
1602 seen.insert(HashableValue::from(Value::Int64(3)));
1603 let state = GroupState {
1604 key_values: vec![Value::Int64(1)],
1605 accumulators: vec![AggregateState::CountDistinct(3, seen)],
1606 };
1607 let mut buf = Vec::new();
1608 serialize_group_state(&state, &mut buf).unwrap();
1609 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1610 assert_eq!(restored.accumulators[0].finalize(), Value::Int64(3));
1612 }
1613
1614 #[test]
1615 #[cfg(feature = "spill")]
1616 fn spill_roundtrip_avg_distinct() {
1617 use crate::execution::operators::accumulator::HashableValue;
1618 use std::collections::HashSet;
1619
1620 let mut seen = HashSet::new();
1621 seen.insert(HashableValue::from(Value::Float64(2.0)));
1622 seen.insert(HashableValue::from(Value::Float64(4.0)));
1623 let state = GroupState {
1624 key_values: vec![Value::Int64(1)],
1625 accumulators: vec![AggregateState::AvgDistinct(6.0, 2, seen)],
1626 };
1627 let mut buf = Vec::new();
1628 serialize_group_state(&state, &mut buf).unwrap();
1629 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1630 assert_eq!(restored.accumulators[0].finalize(), Value::Float64(3.0));
1631 }
1632
1633 #[test]
1634 #[cfg(feature = "spill")]
1635 fn spill_roundtrip_collect_distinct() {
1636 use crate::execution::operators::accumulator::HashableValue;
1637 use std::collections::HashSet;
1638
1639 let mut seen = HashSet::new();
1640 seen.insert(HashableValue::from(Value::Int64(10)));
1641 seen.insert(HashableValue::from(Value::Int64(20)));
1642 let state = GroupState {
1643 key_values: vec![Value::Int64(1)],
1644 accumulators: vec![AggregateState::CollectDistinct(
1645 vec![Value::Int64(10), Value::Int64(20)],
1646 seen,
1647 )],
1648 };
1649 let mut buf = Vec::new();
1650 serialize_group_state(&state, &mut buf).unwrap();
1651 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1652 let result = restored.accumulators[0].finalize();
1654 assert!(matches!(result, Value::List(_)));
1655 }
1656
1657 #[test]
1662 #[cfg(feature = "spill")]
1663 fn spill_roundtrip_stddev() {
1664 let mut acc = AggregateState::new(AggregateFunction::StdDev, false, None, None);
1666 acc.update(Some(Value::Float64(2.0)));
1667 acc.update(Some(Value::Float64(4.0)));
1668 acc.update(Some(Value::Float64(6.0)));
1669 let expected = acc.finalize();
1670
1671 let state = GroupState {
1672 key_values: vec![Value::Int64(1)],
1673 accumulators: vec![acc],
1674 };
1675 let mut buf = Vec::new();
1676 serialize_group_state(&state, &mut buf).unwrap();
1677 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1678 assert_eq!(restored.accumulators[0].finalize(), expected);
1680 }
1681
1682 #[test]
1683 #[cfg(feature = "spill")]
1684 fn spill_roundtrip_percentile_disc() {
1685 let state = GroupState {
1686 key_values: vec![Value::Int64(1)],
1687 accumulators: vec![AggregateState::PercentileDisc {
1688 values: vec![1.0, 2.0, 3.0, 4.0, 5.0],
1689 percentile: 0.5,
1690 }],
1691 };
1692 let expected = state.accumulators[0].finalize();
1693 let mut buf = Vec::new();
1694 serialize_group_state(&state, &mut buf).unwrap();
1695 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1696 assert_eq!(restored.accumulators[0].finalize(), expected);
1697 }
1698
1699 #[test]
1700 #[cfg(feature = "spill")]
1701 fn spill_roundtrip_group_concat() {
1702 let state = GroupState {
1703 key_values: vec![Value::Int64(1)],
1704 accumulators: vec![AggregateState::GroupConcat(
1705 vec!["alix".to_string(), "gus".to_string(), "vincent".to_string()],
1706 ", ".to_string(),
1707 )],
1708 };
1709 let expected = state.accumulators[0].finalize();
1710 let mut buf = Vec::new();
1711 serialize_group_state(&state, &mut buf).unwrap();
1712 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1713 assert_eq!(restored.accumulators[0].finalize(), expected);
1714 }
1715
    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_collect() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        // Group-by column 0, COLLECT column 1; the last argument (3) is the
        // in-memory group threshold, so pushing more groups forces a spill.
        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::collect(1)],
            manager,
            3, );
        let mut sink = CollectorSink::new();

        // Interleaved rows: group 1 collects {10, 20}, group 2 {30, 40}.
        agg.push(
            create_two_column_chunk(&[1, 2, 1, 2], &[10, 30, 20, 40]),
            &mut sink,
        )
        .unwrap();
        // Seven more single-row groups (keys 3..9) push past the threshold.
        for i in 3..10 {
            agg.push(create_two_column_chunk(&[i], &[i * 10]), &mut sink)
                .unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        // 2 interleaved groups + 7 single-row groups = 9 output rows.
        assert_eq!(chunks[0].len(), 9); let mut found_group1 = false;
        // Row order after spilling is unspecified, so scan for group 1 and
        // verify its collected list by membership rather than position.
        for row in 0..chunks[0].len() {
            if let Some(Value::Int64(1)) = chunks[0].column(0).unwrap().get_value(row) {
                let collected = chunks[0].column(1).unwrap().get_value(row).unwrap();
                if let Value::List(list) = collected {
                    assert_eq!(list.len(), 2);
                    assert!(list.contains(&Value::Int64(10)));
                    assert!(list.contains(&Value::Int64(20)));
                    found_group1 = true;
                }
            }
        }
        assert!(found_group1, "Group 1 with collected values not found");
    }
1768
    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_min_max() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        // Group-by column 0 with MIN and MAX over column 1; threshold of 3
        // in-memory groups forces spilling once more groups arrive.
        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::min(1), AggregateExpr::max(1)],
            manager,
            3, );
        let mut sink = CollectorSink::new();

        // Group 1 sees {50, 10, 30} -> min 10, max 50;
        // group 2 sees {20, 40} -> min 20, max 40.
        agg.push(
            create_two_column_chunk(&[1, 2, 1, 2, 1], &[50, 20, 10, 40, 30]),
            &mut sink,
        )
        .unwrap();

        // Seven single-row groups (keys 3..9) to exceed the threshold.
        for i in 3..10 {
            agg.push(create_two_column_chunk(&[i], &[i * 10]), &mut sink)
                .unwrap();
        }
        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        // 2 multi-row groups + 7 single-row groups = 9 output rows.
        assert_eq!(chunks[0].len(), 9); let mut found_group1 = false;
        // Output order after spilling is unspecified; scan for each group.
        for row in 0..chunks[0].len() {
            if let Some(Value::Int64(1)) = chunks[0].column(0).unwrap().get_value(row) {
                assert_eq!(
                    chunks[0].column(1).unwrap().get_value(row),
                    Some(Value::Int64(10))
                );
                assert_eq!(
                    chunks[0].column(2).unwrap().get_value(row),
                    Some(Value::Int64(50))
                );
                found_group1 = true;
            }
        }
        assert!(found_group1, "Group 1 with min/max not found");

        let mut found_group2 = false;
        for row in 0..chunks[0].len() {
            if let Some(Value::Int64(2)) = chunks[0].column(0).unwrap().get_value(row) {
                assert_eq!(
                    chunks[0].column(1).unwrap().get_value(row),
                    Some(Value::Int64(20))
                );
                assert_eq!(
                    chunks[0].column(2).unwrap().get_value(row),
                    Some(Value::Int64(40))
                );
                found_group2 = true;
            }
        }
        assert!(found_group2, "Group 2 with min/max not found");
    }
1842
1843 #[test]
1848 fn test_aggregate_count_non_null() {
1849 let expr = AggregateExpr::count(0);
1851 let mut agg = AggregatePushOperator::global(vec![expr]);
1852 let mut sink = CollectorSink::new();
1853
1854 let mut col = ValueVector::new();
1856 col.push(Value::Int64(10)); col.push(Value::Null);
1858 col.push(Value::Int64(30)); col.push(Value::Null);
1860 col.push(Value::Int64(50)); let chunk = DataChunk::new(vec![col]);
1862
1863 agg.push(chunk, &mut sink).unwrap();
1864 agg.finalize(&mut sink).unwrap();
1865
1866 let chunks = sink.into_chunks();
1867 assert_eq!(chunks.len(), 1);
1868 assert_eq!(
1870 chunks[0].column(0).unwrap().get_value(0),
1871 Some(Value::Int64(3))
1872 );
1873 }
1874
1875 #[test]
1876 fn test_grouped_aggregate_empty_groups() {
1877 let mut agg = AggregatePushOperator::new(vec![0], vec![AggregateExpr::sum(1)]);
1879 let mut sink = CollectorSink::new();
1880
1881 let empty = DataChunk::new(vec![ValueVector::new(), ValueVector::new()]);
1883 agg.push(empty, &mut sink).unwrap();
1884 agg.finalize(&mut sink).unwrap();
1885
1886 let chunks = sink.into_chunks();
1887 assert!(chunks.is_empty());
1889 }
1890
1891 #[test]
1892 #[cfg(feature = "spill")]
1893 fn test_spillable_aggregate_threshold_transition() {
1894 use tempfile::TempDir;
1898
1899 let temp_dir = TempDir::new().unwrap();
1900 let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());
1901
1902 let mut agg = SpillableAggregatePushOperator::with_spilling(
1904 vec![0],
1905 vec![AggregateExpr::count_star()],
1906 manager,
1907 2, );
1909 let mut sink = CollectorSink::new();
1910
1911 for i in 0..5 {
1913 agg.push(create_test_chunk(&[i]), &mut sink).unwrap();
1914 }
1915 agg.finalize(&mut sink).unwrap();
1916
1917 let chunks = sink.into_chunks();
1918 assert_eq!(chunks.len(), 1);
1919 assert_eq!(chunks[0].len(), 5);
1920 }
1921
    #[test]
    #[cfg(feature = "spill")]
    fn spill_finalized_frozen_ignores_further_updates() {
        // StdDev deserializes from spill as Frozen(<finalized value>); this
        // test pins that further update() calls on a Frozen state are no-ops.
        let mut acc = AggregateState::new(AggregateFunction::StdDev, false, None, None);
        acc.update(Some(Value::Float64(2.0)));
        acc.update(Some(Value::Float64(4.0)));
        acc.update(Some(Value::Float64(6.0)));
        // Capture the live result before the accumulator is serialized.
        let expected = acc.finalize();

        let state = GroupState {
            key_values: vec![Value::Int64(1)],
            accumulators: vec![acc],
        };
        let mut buf = Vec::new();
        serialize_group_state(&state, &mut buf).unwrap();
        let mut restored = deserialize_group_state(&mut &buf[..]).unwrap();

        // The restored state is the Frozen variant, not a live StdDev.
        assert!(matches!(
            restored.accumulators[0],
            AggregateState::Frozen(_)
        ));

        // These must be ignored: a frozen state no longer accumulates.
        restored.accumulators[0].update(Some(Value::Float64(100.0)));
        restored.accumulators[0].update(Some(Value::Float64(200.0)));

        assert_eq!(restored.accumulators[0].finalize(), expected);
    }
1949
1950 #[test]
1955 #[cfg(feature = "spill")]
1956 fn test_serialize_deserialize_sum_state() {
1957 let state = GroupState {
1959 key_values: vec![Value::String("Alix".into())],
1960 accumulators: vec![
1961 AggregateState::SumInt(42, 3),
1962 AggregateState::SumFloat(2.72, 0.001, 2),
1963 ],
1964 };
1965 let mut buf = Vec::new();
1966 serialize_group_state(&state, &mut buf).unwrap();
1967 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1968
1969 assert_eq!(restored.key_values, vec![Value::String("Alix".into())]);
1970 assert_eq!(restored.accumulators[0].finalize(), Value::Int64(42));
1971 assert_eq!(restored.accumulators[1].finalize(), Value::Float64(2.72));
1972 }
1973
1974 #[test]
1975 #[cfg(feature = "spill")]
1976 fn test_serialize_deserialize_avg_state() {
1977 let state = GroupState {
1979 key_values: vec![Value::String("Gus".into())],
1980 accumulators: vec![AggregateState::Avg(30.0, 6)],
1981 };
1982 let mut buf = Vec::new();
1983 serialize_group_state(&state, &mut buf).unwrap();
1984 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
1985
1986 assert_eq!(restored.key_values, vec![Value::String("Gus".into())]);
1987 assert_eq!(restored.accumulators[0].finalize(), Value::Float64(5.0));
1988 }
1989
1990 #[test]
1991 #[cfg(feature = "spill")]
1992 fn test_serialize_deserialize_count_state() {
1993 let state = GroupState {
1994 key_values: vec![Value::String("Vincent".into())],
1995 accumulators: vec![AggregateState::Count(17)],
1996 };
1997 let mut buf = Vec::new();
1998 serialize_group_state(&state, &mut buf).unwrap();
1999 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
2000
2001 assert_eq!(restored.key_values, vec![Value::String("Vincent".into())]);
2002 assert_eq!(restored.accumulators[0].finalize(), Value::Int64(17));
2003 }
2004
2005 #[test]
2006 #[cfg(feature = "spill")]
2007 fn test_serialize_deserialize_min_max_state() {
2008 let state = GroupState {
2010 key_values: vec![Value::String("Jules".into())],
2011 accumulators: vec![
2012 AggregateState::Min(Some(Value::String("Amsterdam".into()))),
2013 AggregateState::Max(Some(Value::Float64(99.9))),
2014 AggregateState::Min(None),
2015 AggregateState::Max(None),
2016 ],
2017 };
2018 let mut buf = Vec::new();
2019 serialize_group_state(&state, &mut buf).unwrap();
2020 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
2021
2022 assert_eq!(
2023 restored.accumulators[0].finalize(),
2024 Value::String("Amsterdam".into())
2025 );
2026 assert_eq!(restored.accumulators[1].finalize(), Value::Float64(99.9));
2027 assert_eq!(restored.accumulators[2].finalize(), Value::Null);
2029 assert_eq!(restored.accumulators[3].finalize(), Value::Null);
2030 }
2031
2032 #[test]
2033 #[cfg(feature = "spill")]
2034 fn test_serialize_deserialize_collect_state() {
2035 let state = GroupState {
2037 key_values: vec![Value::String("Mia".into())],
2038 accumulators: vec![AggregateState::Collect(vec![
2039 Value::Int64(1),
2040 Value::String("Berlin".into()),
2041 Value::Float64(2.5),
2042 Value::Bool(true),
2043 ])],
2044 };
2045 let mut buf = Vec::new();
2046 serialize_group_state(&state, &mut buf).unwrap();
2047 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
2048
2049 let result = restored.accumulators[0].finalize();
2050 if let Value::List(list) = result {
2051 assert_eq!(list.len(), 4);
2052 assert_eq!(list[0], Value::Int64(1));
2053 assert_eq!(list[1], Value::String("Berlin".into()));
2054 assert_eq!(list[2], Value::Float64(2.5));
2055 assert_eq!(list[3], Value::Bool(true));
2056 } else {
2057 panic!("expected List, got {result:?}");
2058 }
2059 }
2060
2061 #[test]
2062 #[cfg(feature = "spill")]
2063 fn test_serialize_deserialize_count_distinct() {
2064 use crate::execution::operators::accumulator::HashableValue;
2065 use std::collections::HashSet;
2066
2067 let mut seen = HashSet::new();
2068 seen.insert(HashableValue::from(Value::String("Paris".into())));
2069 seen.insert(HashableValue::from(Value::String("Prague".into())));
2070 seen.insert(HashableValue::from(Value::String("Barcelona".into())));
2071 let state = GroupState {
2072 key_values: vec![Value::String("Butch".into())],
2073 accumulators: vec![AggregateState::CountDistinct(3, seen)],
2074 };
2075 let mut buf = Vec::new();
2076 serialize_group_state(&state, &mut buf).unwrap();
2077 let restored = deserialize_group_state(&mut &buf[..]).unwrap();
2078
2079 assert_eq!(restored.accumulators[0].finalize(), Value::Int64(3));
2081 assert!(
2082 matches!(restored.accumulators[0], AggregateState::Frozen(_)),
2083 "DISTINCT should be deserialized as Frozen"
2084 );
2085 }
2086
2087 #[test]
2092 fn test_global_aggregate_empty_input() {
2093 let mut agg = AggregatePushOperator::global(vec![
2095 AggregateExpr::count_star(),
2096 AggregateExpr::sum(0),
2097 AggregateExpr::min(0),
2098 AggregateExpr::max(0),
2099 ]);
2100 let mut sink = CollectorSink::new();
2101
2102 agg.finalize(&mut sink).unwrap();
2104
2105 let chunks = sink.into_chunks();
2106 assert_eq!(chunks.len(), 1);
2107 assert_eq!(
2109 chunks[0].column(0).unwrap().get_value(0),
2110 Some(Value::Int64(0))
2111 );
2112 assert_eq!(chunks[0].column(1).unwrap().get_value(0), Some(Value::Null));
2114 assert_eq!(chunks[0].column(2).unwrap().get_value(0), Some(Value::Null));
2116 assert_eq!(chunks[0].column(3).unwrap().get_value(0), Some(Value::Null));
2118 }
2119
    #[test]
    #[cfg(feature = "spill")]
    fn test_spillable_aggregate_memory_pressure() {
        use tempfile::TempDir;

        let temp_dir = TempDir::new().unwrap();
        let manager = Arc::new(SpillManager::new(temp_dir.path()).unwrap());

        // Threshold of 2 in-memory groups; the manager Arc is cloned so the
        // test can inspect spill-file activity afterwards.
        let mut agg = SpillableAggregatePushOperator::with_spilling(
            vec![0],
            vec![AggregateExpr::sum(1)],
            Arc::clone(&manager),
            2,
        );
        let mut sink = CollectorSink::new();

        // 20 distinct keys, each summing a single value i * 5.
        for i in 0..20 {
            let chunk = create_two_column_chunk(&[i], &[i * 5]);
            agg.push(chunk, &mut sink).unwrap();
        }

        // With 20 groups against a threshold of 2, spill files must exist
        // before finalize runs.
        assert!(
            manager.active_file_count() > 0,
            "expected spill files to be created under memory pressure"
        );

        agg.finalize(&mut sink).unwrap();

        let chunks = sink.into_chunks();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 20);

        // Output order is unspecified after spilling: gather all sums and
        // compare them sorted against the expected multiset.
        let mut sums: Vec<i64> = Vec::new();
        for i in 0..chunks[0].len() {
            if let Some(Value::Int64(sum)) = chunks[0].column(1).unwrap().get_value(i) {
                sums.push(sum);
            }
        }
        sums.sort_unstable();
        let expected: Vec<i64> = (0..20).map(|i| i * 5).collect();
        assert_eq!(sums, expected);
    }
2170}