Skip to main content

grafeo_core/execution/
collector.rs

//! Generic collector trait for parallel aggregation.
//!
//! Collectors provide a clean separation between what data to aggregate
//! and how to execute the aggregation in parallel. The pattern is inspired
//! by Tantivy's collector architecture.
//!
//! # Pattern
//!
//! 1. Create partition-local collectors (one per worker thread)
//! 2. Each collector processes its partition independently (no shared state)
//! 3. Merge all partition results into a final result
//!
//! # Example
//!
//! ```no_run
//! use grafeo_core::execution::collector::{Collector, PartitionCollector, CountCollector};
//! use grafeo_core::execution::DataChunk;
//!
//! # fn example(partitions: Vec<Vec<DataChunk>>) -> Result<(), grafeo_core::execution::operators::OperatorError> {
//! let collector = CountCollector;
//!
//! // In parallel execution:
//! let mut partition_collectors: Vec<_> = (0..4)
//!     .map(|id| collector.for_partition(id))
//!     .collect();
//!
//! // Each partition processes its chunks
//! for (partition, chunks) in partitions.into_iter().enumerate() {
//!     for chunk in chunks {
//!         partition_collectors[partition].collect(&chunk)?;
//!     }
//! }
//!
//! // Merge results
//! let fruits: Vec<_> = partition_collectors.into_iter()
//!     .map(|c| c.harvest())
//!     .collect();
//! let total = collector.merge(fruits);
//! # Ok(())
//! # }
//! ```
42
43use super::chunk::DataChunk;
44use super::operators::OperatorError;
45
46/// A collector that aggregates results from parallel execution.
47///
48/// Pattern: Create partition-local collectors, process independently,
49/// then merge results. No shared mutable state during collection.
50pub trait Collector: Sync {
51    /// Final result type after merging all partitions.
52    type Fruit: Send;
53
54    /// Partition-local collector type.
55    type PartitionCollector: PartitionCollector<Fruit = Self::Fruit>;
56
57    /// Creates a collector for a single partition (called per-thread).
58    fn for_partition(&self, partition_id: usize) -> Self::PartitionCollector;
59
60    /// Merges results from all partitions (called once at the end).
61    fn merge(&self, fruits: Vec<Self::Fruit>) -> Self::Fruit;
62}
63
64/// Per-partition collector - processes chunks locally.
65///
66/// Each partition collector is created by [`Collector::for_partition`]
67/// and processes data independently. This enables lock-free parallel
68/// execution.
69pub trait PartitionCollector: Send {
70    /// Result type produced by this partition.
71    type Fruit: Send;
72
73    /// Processes a batch of data.
74    ///
75    /// Called repeatedly with chunks from this partition.
76    ///
77    /// # Errors
78    ///
79    /// Returns `Err` if the chunk cannot be processed (e.g., type mismatch).
80    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError>;
81
82    /// Finalizes and returns the result for this partition.
83    ///
84    /// Called once after all chunks have been processed.
85    fn harvest(self) -> Self::Fruit;
86}
87
// ============================================================================
// Built-in Collectors
// ============================================================================
91
/// Counts rows across all partitions.
///
/// Each partition adds up the lengths of the chunks it receives; merging sums
/// the per-partition totals.
///
/// # Examples
///
/// ```no_run
/// use grafeo_core::execution::collector::{Collector, PartitionCollector, CountCollector};
/// use grafeo_core::execution::DataChunk;
///
/// # fn example(chunk1: DataChunk, chunk2: DataChunk) -> Result<(), grafeo_core::execution::operators::OperatorError> {
/// let collector = CountCollector;
/// let mut pc = collector.for_partition(0);
///
/// pc.collect(&chunk1)?;
/// pc.collect(&chunk2)?;
///
/// let count = pc.harvest();
/// # Ok(())
/// # }
/// ```
#[derive(Debug, Clone, Copy, Default)]
pub struct CountCollector;
113
114impl Collector for CountCollector {
115    type Fruit = u64;
116    type PartitionCollector = CountPartitionCollector;
117
118    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
119        CountPartitionCollector { count: 0 }
120    }
121
122    fn merge(&self, fruits: Vec<u64>) -> u64 {
123        fruits.into_iter().sum()
124    }
125}
126
/// Partition-local counter.
#[derive(Debug)]
pub struct CountPartitionCollector {
    /// Rows counted so far in this partition.
    count: u64,
}
131
132impl PartitionCollector for CountPartitionCollector {
133    type Fruit = u64;
134
135    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
136        self.count += chunk.len() as u64;
137        Ok(())
138    }
139
140    fn harvest(self) -> u64 {
141        self.count
142    }
143}
144
/// Collects all chunks (materializes the entire result).
///
/// Use this when you need all the data, not just an aggregate.
/// Be careful with large datasets - this can consume significant memory.
///
/// This is a field-less marker type, so it is `Copy` like [`CountCollector`].
#[derive(Debug, Clone, Copy, Default)]
pub struct MaterializeCollector;
151
152impl Collector for MaterializeCollector {
153    type Fruit = Vec<DataChunk>;
154    type PartitionCollector = MaterializePartitionCollector;
155
156    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
157        MaterializePartitionCollector { chunks: Vec::new() }
158    }
159
160    fn merge(&self, mut fruits: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
161        let total_chunks: usize = fruits.iter().map(|f| f.len()).sum();
162        let mut result = Vec::with_capacity(total_chunks);
163        for fruit in &mut fruits {
164            result.append(fruit);
165        }
166        result
167    }
168}
169
170/// Partition-local materializer.
171pub struct MaterializePartitionCollector {
172    chunks: Vec<DataChunk>,
173}
174
175impl PartitionCollector for MaterializePartitionCollector {
176    type Fruit = Vec<DataChunk>;
177
178    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
179        self.chunks.push(chunk.clone());
180        Ok(())
181    }
182
183    fn harvest(self) -> Vec<DataChunk> {
184        self.chunks
185    }
186}
187
/// Collects the first N rows across all partitions.
///
/// Each partition stops keeping rows once it has collected `limit` of them.
/// The final merge truncates the combined output, so at most `limit` rows are
/// returned (fewer when the input holds fewer rows).
#[derive(Debug, Clone)]
pub struct LimitCollector {
    /// Maximum number of rows to return.
    limit: usize,
}

impl LimitCollector {
    /// Creates a collector that limits output to `limit` rows.
    #[must_use]
    pub fn new(limit: usize) -> Self {
        Self { limit }
    }
}
204
205impl Collector for LimitCollector {
206    type Fruit = (Vec<DataChunk>, usize);
207    type PartitionCollector = LimitPartitionCollector;
208
209    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
210        LimitPartitionCollector {
211            chunks: Vec::new(),
212            limit: self.limit,
213            collected: 0,
214        }
215    }
216
217    fn merge(&self, fruits: Vec<(Vec<DataChunk>, usize)>) -> (Vec<DataChunk>, usize) {
218        let mut result = Vec::new();
219        let mut total = 0;
220
221        for (chunks, _) in fruits {
222            for chunk in chunks {
223                if total >= self.limit {
224                    break;
225                }
226                let take = (self.limit - total).min(chunk.len());
227                if take < chunk.len() {
228                    result.push(chunk.slice(0, take));
229                } else {
230                    result.push(chunk);
231                }
232                total += take;
233            }
234            if total >= self.limit {
235                break;
236            }
237        }
238
239        (result, total)
240    }
241}
242
243/// Partition-local limiter.
244pub struct LimitPartitionCollector {
245    chunks: Vec<DataChunk>,
246    limit: usize,
247    collected: usize,
248}
249
250impl PartitionCollector for LimitPartitionCollector {
251    type Fruit = (Vec<DataChunk>, usize);
252
253    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
254        if self.collected >= self.limit {
255            return Ok(());
256        }
257
258        let take = (self.limit - self.collected).min(chunk.len());
259        if take < chunk.len() {
260            self.chunks.push(chunk.slice(0, take));
261        } else {
262            self.chunks.push(chunk.clone());
263        }
264        self.collected += take;
265
266        Ok(())
267    }
268
269    fn harvest(self) -> (Vec<DataChunk>, usize) {
270        (self.chunks, self.collected)
271    }
272}
273
/// Collects statistics (count, sum, min, max) for a column.
///
/// The column is identified by its index within each chunk.
#[derive(Debug, Clone)]
pub struct StatsCollector {
    /// Index of the column to compute statistics for.
    column_idx: usize,
}

impl StatsCollector {
    /// Creates a collector that computes statistics for the given column.
    #[must_use]
    pub fn new(column_idx: usize) -> Self {
        Self { column_idx }
    }
}
287
/// Statistics result from [`StatsCollector`].
///
/// The `Default` value means "no values seen": zero count, zero sum, and no
/// minimum or maximum.
#[derive(Debug, Clone, Default)]
pub struct CollectorStats {
    /// Number of values that contributed to the statistics.
    pub count: u64,
    /// Sum of the contributing values.
    pub sum: f64,
    /// Smallest contributing value, or `None` if there were none.
    pub min: Option<f64>,
    /// Largest contributing value, or `None` if there were none.
    pub max: Option<f64>,
}

impl CollectorStats {
    /// Merges another stats into this one.
    ///
    /// Counts and sums are added. The minimum and maximum are taken over both
    /// sides, and a side without a value leaves the other side's value as is.
    pub fn merge(&mut self, other: CollectorStats) {
        self.count += other.count;
        self.sum += other.sum;
        self.min = Self::combine(self.min, other.min, f64::min);
        self.max = Self::combine(self.max, other.max, f64::max);
    }

    /// Computes the average (mean) value.
    ///
    /// Returns `None` when `count` is zero, so it never divides by zero.
    #[must_use]
    pub fn avg(&self) -> Option<f64> {
        if self.count > 0 {
            Some(self.sum / self.count as f64)
        } else {
            None
        }
    }

    /// Combines two optional extremes with `pick`; a lone value wins as is.
    fn combine(lhs: Option<f64>, rhs: Option<f64>, pick: fn(f64, f64) -> f64) -> Option<f64> {
        match (lhs, rhs) {
            (Some(a), Some(b)) => Some(pick(a, b)),
            (Some(v), None) | (None, Some(v)) => Some(v),
            (None, None) => None,
        }
    }
}
328
329impl Collector for StatsCollector {
330    type Fruit = CollectorStats;
331    type PartitionCollector = StatsPartitionCollector;
332
333    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
334        StatsPartitionCollector {
335            column_idx: self.column_idx,
336            stats: CollectorStats::default(),
337        }
338    }
339
340    fn merge(&self, fruits: Vec<CollectorStats>) -> CollectorStats {
341        let mut result = CollectorStats::default();
342        for fruit in fruits {
343            result.merge(fruit);
344        }
345        result
346    }
347}
348
349/// Partition-local stats collector.
350pub struct StatsPartitionCollector {
351    column_idx: usize,
352    stats: CollectorStats,
353}
354
355impl PartitionCollector for StatsPartitionCollector {
356    type Fruit = CollectorStats;
357
358    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
359        let column = chunk.column(self.column_idx).ok_or_else(|| {
360            OperatorError::ColumnNotFound(format!(
361                "column index {} out of bounds (width={})",
362                self.column_idx,
363                chunk.column_count()
364            ))
365        })?;
366
367        for i in 0..chunk.len() {
368            // Try typed access first (for specialized vectors), then fall back to generic
369            let val = if let Some(f) = column.get_float64(i) {
370                Some(f)
371            } else if let Some(i) = column.get_int64(i) {
372                Some(i as f64)
373            } else if let Some(value) = column.get_value(i) {
374                // Handle Generic vectors - extract numeric value
375                match value {
376                    grafeo_common::types::Value::Int64(i) => Some(i as f64),
377                    grafeo_common::types::Value::Float64(f) => Some(f),
378                    _ => None,
379                }
380            } else {
381                None
382            };
383
384            if let Some(v) = val {
385                self.stats.count += 1;
386                self.stats.sum += v;
387                self.stats.min = Some(match self.stats.min {
388                    Some(m) => m.min(v),
389                    None => v,
390                });
391                self.stats.max = Some(match self.stats.max {
392                    Some(m) => m.max(v),
393                    None => v,
394                });
395            }
396        }
397
398        Ok(())
399    }
400
401    fn harvest(self) -> CollectorStats {
402        self.stats
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::ValueVector;
    use grafeo_common::types::Value;

    /// Builds a single-column chunk holding the integers in `range`.
    fn make_range_chunk(range: std::ops::Range<i64>) -> DataChunk {
        let values: Vec<Value> = range.map(Value::from).collect();
        let column = ValueVector::from_values(&values);
        DataChunk::new(vec![column])
    }

    /// Builds a single-column chunk holding the integers `0..size`.
    // reason: test sizes are small, fit i64
    #[allow(clippy::cast_possible_wrap)]
    fn make_test_chunk(size: usize) -> DataChunk {
        make_range_chunk(0..size as i64)
    }

    #[test]
    fn test_count_collector() {
        let collector = CountCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let count1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let count2 = pc2.harvest();

        let total = collector.merge(vec![count1, count2]);
        assert_eq!(total, 22);
    }

    #[test]
    fn test_materialize_collector() {
        let collector = MaterializeCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let chunks1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let chunks2 = pc2.harvest();

        let result = collector.merge(vec![chunks1, chunks2]);
        assert_eq!(result.len(), 3);
        assert_eq!(result.iter().map(|c| c.len()).sum::<usize>(), 22);
    }

    #[test]
    fn test_limit_collector() {
        let collector = LimitCollector::new(12);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap(); // Only 2 more should be taken
        let result1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(20)).unwrap();
        let result2 = pc2.harvest();

        let (chunks, total) = collector.merge(vec![result1, result2]);
        assert_eq!(total, 12);

        let actual_rows: usize = chunks.iter().map(|c| c.len()).sum();
        assert_eq!(actual_rows, 12);
    }

    #[test]
    fn test_stats_collector() {
        let collector = StatsCollector::new(0);

        let mut pc = collector.for_partition(0);

        // Chunk with values 0..10
        pc.collect(&make_test_chunk(10)).unwrap();
        let stats = pc.harvest();

        assert_eq!(stats.count, 10);
        assert!((stats.sum - 45.0).abs() < 0.001); // 0+1+2+...+9 = 45
        assert!((stats.min.unwrap() - 0.0).abs() < 0.001);
        assert!((stats.max.unwrap() - 9.0).abs() < 0.001);
        assert!((stats.avg().unwrap() - 4.5).abs() < 0.001);
    }

    #[test]
    fn test_stats_merge() {
        let collector = StatsCollector::new(0);

        // Partition 1: values 0..5
        let mut pc1 = collector.for_partition(0);
        pc1.collect(&make_range_chunk(0..5)).unwrap();

        // Partition 2: values 5..10
        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_range_chunk(5..10)).unwrap();

        let stats = collector.merge(vec![pc1.harvest(), pc2.harvest()]);

        assert_eq!(stats.count, 10);
        assert!((stats.sum - 45.0).abs() < 0.001); // 0+1+2+...+9 = 45
        assert!((stats.min.unwrap() - 0.0).abs() < 0.001);
        assert!((stats.max.unwrap() - 9.0).abs() < 0.001);
        assert!((stats.avg().unwrap() - 4.5).abs() < 0.001);
    }
}