Skip to main content

grafeo_core/execution/
collector.rs

1//! Generic collector trait for parallel aggregation.
2//!
3//! Collectors provide a clean separation between what data to aggregate
4//! and how to execute the aggregation in parallel. The pattern is inspired
5//! by Tantivy's collector architecture.
6//!
7//! # Pattern
8//!
9//! 1. Create partition-local collectors (one per worker thread)
10//! 2. Each collector processes its partition independently (no shared state)
11//! 3. Merge all partition results into a final result
12//!
13//! # Example
14//!
15//! ```no_run
16//! use grafeo_core::execution::collector::{Collector, PartitionCollector, CountCollector};
17//! use grafeo_core::execution::DataChunk;
18//!
19//! # fn example(partitions: Vec<Vec<DataChunk>>) -> Result<(), grafeo_core::execution::operators::OperatorError> {
20//! let collector = CountCollector;
21//!
22//! // In parallel execution:
//! let mut partition_collectors: Vec<_> = (0..partitions.len())
24//!     .map(|id| collector.for_partition(id))
25//!     .collect();
26//!
27//! // Each partition processes its chunks
28//! for (partition, chunks) in partitions.into_iter().enumerate() {
29//!     for chunk in chunks {
30//!         partition_collectors[partition].collect(&chunk)?;
31//!     }
32//! }
33//!
34//! // Merge results
35//! let fruits: Vec<_> = partition_collectors.into_iter()
36//!     .map(|c| c.harvest())
37//!     .collect();
38//! let total = collector.merge(fruits);
39//! # Ok(())
40//! # }
41//! ```
42
43use super::chunk::DataChunk;
44use super::operators::OperatorError;
45
46/// A collector that aggregates results from parallel execution.
47///
48/// Pattern: Create partition-local collectors, process independently,
49/// then merge results. No shared mutable state during collection.
50pub trait Collector: Sync {
51    /// Final result type after merging all partitions.
52    type Fruit: Send;
53
54    /// Partition-local collector type.
55    type PartitionCollector: PartitionCollector<Fruit = Self::Fruit>;
56
57    /// Creates a collector for a single partition (called per-thread).
58    fn for_partition(&self, partition_id: usize) -> Self::PartitionCollector;
59
60    /// Merges results from all partitions (called once at the end).
61    fn merge(&self, fruits: Vec<Self::Fruit>) -> Self::Fruit;
62}
63
64/// Per-partition collector - processes chunks locally.
65///
66/// Each partition collector is created by [`Collector::for_partition`]
67/// and processes data independently. This enables lock-free parallel
68/// execution.
69pub trait PartitionCollector: Send {
70    /// Result type produced by this partition.
71    type Fruit: Send;
72
73    /// Processes a batch of data.
74    ///
75    /// Called repeatedly with chunks from this partition.
76    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError>;
77
78    /// Finalizes and returns the result for this partition.
79    ///
80    /// Called once after all chunks have been processed.
81    fn harvest(self) -> Self::Fruit;
82}
83
84// ============================================================================
85// Built-in Collectors
86// ============================================================================
87
/// Counts the rows seen by every partition and adds the counts together.
///
/// # Example
///
/// ```no_run
/// use grafeo_core::execution::collector::{Collector, PartitionCollector, CountCollector};
/// use grafeo_core::execution::DataChunk;
///
/// # fn example(chunk1: DataChunk, chunk2: DataChunk) -> Result<(), grafeo_core::execution::operators::OperatorError> {
/// let collector = CountCollector;
/// let mut pc = collector.for_partition(0);
///
/// pc.collect(&chunk1)?;
/// pc.collect(&chunk2)?;
///
/// let count = pc.harvest();
/// # Ok(())
/// # }
/// ```
#[derive(Default, Debug, Copy, Clone)]
pub struct CountCollector;
109
110impl Collector for CountCollector {
111    type Fruit = u64;
112    type PartitionCollector = CountPartitionCollector;
113
114    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
115        CountPartitionCollector { count: 0 }
116    }
117
118    fn merge(&self, fruits: Vec<u64>) -> u64 {
119        fruits.into_iter().sum()
120    }
121}
122
123/// Partition-local counter.
124pub struct CountPartitionCollector {
125    count: u64,
126}
127
128impl PartitionCollector for CountPartitionCollector {
129    type Fruit = u64;
130
131    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
132        self.count += chunk.len() as u64;
133        Ok(())
134    }
135
136    fn harvest(self) -> u64 {
137        self.count
138    }
139}
140
/// Gathers every chunk from every partition into a single `Vec`, fully
/// materializing the result.
///
/// Reach for this when the rows themselves are needed rather than an
/// aggregate. Memory use grows with the size of the input, so be careful with
/// large datasets.
#[derive(Default, Clone, Debug)]
pub struct MaterializeCollector;
147
148impl Collector for MaterializeCollector {
149    type Fruit = Vec<DataChunk>;
150    type PartitionCollector = MaterializePartitionCollector;
151
152    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
153        MaterializePartitionCollector { chunks: Vec::new() }
154    }
155
156    fn merge(&self, mut fruits: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
157        let total_chunks: usize = fruits.iter().map(|f| f.len()).sum();
158        let mut result = Vec::with_capacity(total_chunks);
159        for fruit in &mut fruits {
160            result.append(fruit);
161        }
162        result
163    }
164}
165
166/// Partition-local materializer.
167pub struct MaterializePartitionCollector {
168    chunks: Vec<DataChunk>,
169}
170
171impl PartitionCollector for MaterializePartitionCollector {
172    type Fruit = Vec<DataChunk>;
173
174    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
175        self.chunks.push(chunk.clone());
176        Ok(())
177    }
178
179    fn harvest(self) -> Vec<DataChunk> {
180        self.chunks
181    }
182}
183
/// Keeps at most `limit` rows across all partitions.
///
/// Each partition stops keeping rows once it holds `limit` of them, so no
/// partition buffers more than `limit` rows. The final merge concatenates the
/// partition results in the order they are passed in and truncates the output
/// at `limit` rows.
///
/// The merged result therefore contains **at most** `limit` rows. It holds
/// fewer when the inputs contain fewer rows. Which rows survive depends on
/// the order of the partition results given to the merge.
#[derive(Debug, Clone)]
pub struct LimitCollector {
    /// Maximum number of rows to keep.
    limit: usize,
}

impl LimitCollector {
    /// Creates a collector that keeps at most `limit` rows.
    ///
    /// A `limit` of `0` yields an empty result.
    #[must_use]
    pub fn new(limit: usize) -> Self {
        Self { limit }
    }
}
200
201impl Collector for LimitCollector {
202    type Fruit = (Vec<DataChunk>, usize);
203    type PartitionCollector = LimitPartitionCollector;
204
205    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
206        LimitPartitionCollector {
207            chunks: Vec::new(),
208            limit: self.limit,
209            collected: 0,
210        }
211    }
212
213    fn merge(&self, fruits: Vec<(Vec<DataChunk>, usize)>) -> (Vec<DataChunk>, usize) {
214        let mut result = Vec::new();
215        let mut total = 0;
216
217        for (chunks, _) in fruits {
218            for chunk in chunks {
219                if total >= self.limit {
220                    break;
221                }
222                let take = (self.limit - total).min(chunk.len());
223                if take < chunk.len() {
224                    result.push(chunk.slice(0, take));
225                } else {
226                    result.push(chunk);
227                }
228                total += take;
229            }
230            if total >= self.limit {
231                break;
232            }
233        }
234
235        (result, total)
236    }
237}
238
239/// Partition-local limiter.
240pub struct LimitPartitionCollector {
241    chunks: Vec<DataChunk>,
242    limit: usize,
243    collected: usize,
244}
245
246impl PartitionCollector for LimitPartitionCollector {
247    type Fruit = (Vec<DataChunk>, usize);
248
249    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
250        if self.collected >= self.limit {
251            return Ok(());
252        }
253
254        let take = (self.limit - self.collected).min(chunk.len());
255        if take < chunk.len() {
256            self.chunks.push(chunk.slice(0, take));
257        } else {
258            self.chunks.push(chunk.clone());
259        }
260        self.collected += take;
261
262        Ok(())
263    }
264
265    fn harvest(self) -> (Vec<DataChunk>, usize) {
266        (self.chunks, self.collected)
267    }
268}
269
/// Computes count, sum, minimum and maximum over one column.
///
/// Only values that can be read as numbers contribute to the statistics.
#[derive(Clone, Debug)]
pub struct StatsCollector {
    /// Index of the column to summarize within each chunk.
    column_idx: usize,
}

impl StatsCollector {
    /// Creates a collector that summarizes the column at `column_idx`.
    #[must_use]
    pub fn new(column_idx: usize) -> Self {
        StatsCollector { column_idx }
    }
}
283
/// Summary statistics produced by [`StatsCollector`].
#[derive(Debug, Clone, Default)]
pub struct CollectorStats {
    /// Number of numeric values that were counted.
    pub count: u64,
    /// Sum of the counted values.
    pub sum: f64,
    /// Smallest counted value, or `None` if nothing was counted.
    pub min: Option<f64>,
    /// Largest counted value, or `None` if nothing was counted.
    pub max: Option<f64>,
}

impl CollectorStats {
    /// Folds `other` into `self`.
    ///
    /// Counts and sums are added. Extremes are combined, and a missing
    /// extreme on either side is treated as "no value" rather than as a
    /// bound.
    pub fn merge(&mut self, other: CollectorStats) {
        // Keep whichever side has a value; when both do, choose with `pick`.
        fn combine(lhs: Option<f64>, rhs: Option<f64>, pick: fn(f64, f64) -> f64) -> Option<f64> {
            match (lhs, rhs) {
                (Some(a), Some(b)) => Some(pick(a, b)),
                (a, b) => a.or(b),
            }
        }

        self.count += other.count;
        self.sum += other.sum;
        self.min = combine(self.min, other.min, f64::min);
        self.max = combine(self.max, other.max, f64::max);
    }

    /// Returns the arithmetic mean `sum / count`, or `None` when `count` is 0.
    #[must_use]
    pub fn avg(&self) -> Option<f64> {
        (self.count > 0).then(|| self.sum / self.count as f64)
    }
}
324
325impl Collector for StatsCollector {
326    type Fruit = CollectorStats;
327    type PartitionCollector = StatsPartitionCollector;
328
329    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
330        StatsPartitionCollector {
331            column_idx: self.column_idx,
332            stats: CollectorStats::default(),
333        }
334    }
335
336    fn merge(&self, fruits: Vec<CollectorStats>) -> CollectorStats {
337        let mut result = CollectorStats::default();
338        for fruit in fruits {
339            result.merge(fruit);
340        }
341        result
342    }
343}
344
345/// Partition-local stats collector.
346pub struct StatsPartitionCollector {
347    column_idx: usize,
348    stats: CollectorStats,
349}
350
351impl PartitionCollector for StatsPartitionCollector {
352    type Fruit = CollectorStats;
353
354    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
355        let column = chunk.column(self.column_idx).ok_or_else(|| {
356            OperatorError::ColumnNotFound(format!(
357                "column index {} out of bounds (width={})",
358                self.column_idx,
359                chunk.column_count()
360            ))
361        })?;
362
363        for i in 0..chunk.len() {
364            // Try typed access first (for specialized vectors), then fall back to generic
365            let val = if let Some(f) = column.get_float64(i) {
366                Some(f)
367            } else if let Some(i) = column.get_int64(i) {
368                Some(i as f64)
369            } else if let Some(value) = column.get_value(i) {
370                // Handle Generic vectors - extract numeric value
371                match value {
372                    grafeo_common::types::Value::Int64(i) => Some(i as f64),
373                    grafeo_common::types::Value::Float64(f) => Some(f),
374                    _ => None,
375                }
376            } else {
377                None
378            };
379
380            if let Some(v) = val {
381                self.stats.count += 1;
382                self.stats.sum += v;
383                self.stats.min = Some(match self.stats.min {
384                    Some(m) => m.min(v),
385                    None => v,
386                });
387                self.stats.max = Some(match self.stats.max {
388                    Some(m) => m.max(v),
389                    None => v,
390                });
391            }
392        }
393
394        Ok(())
395    }
396
397    fn harvest(self) -> CollectorStats {
398        self.stats
399    }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::ValueVector;
    use grafeo_common::types::Value;

    /// Builds a single-column chunk holding the integers `0..size`.
    fn make_test_chunk(size: usize) -> DataChunk {
        let values: Vec<Value> = (0..size).map(|i| Value::from(i as i64)).collect();
        let column = ValueVector::from_values(&values);
        DataChunk::new(vec![column])
    }

    #[test]
    fn test_count_collector() {
        let collector = CountCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let count1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let count2 = pc2.harvest();

        let total = collector.merge(vec![count1, count2]);
        assert_eq!(total, 22);
    }

    #[test]
    fn test_count_collector_no_partitions() {
        assert_eq!(CountCollector.merge(Vec::new()), 0);
    }

    #[test]
    fn test_materialize_collector() {
        let collector = MaterializeCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let chunks1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let chunks2 = pc2.harvest();

        let result = collector.merge(vec![chunks1, chunks2]);
        assert_eq!(result.len(), 3);
        assert_eq!(result.iter().map(|c| c.len()).sum::<usize>(), 22);
    }

    #[test]
    fn test_limit_collector() {
        let collector = LimitCollector::new(12);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap(); // Only 2 more should be taken
        let result1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(20)).unwrap();
        let result2 = pc2.harvest();

        let (chunks, total) = collector.merge(vec![result1, result2]);
        assert_eq!(total, 12);

        let actual_rows: usize = chunks.iter().map(|c| c.len()).sum();
        assert_eq!(actual_rows, 12);
    }

    #[test]
    fn test_limit_collector_zero_limit() {
        let collector = LimitCollector::new(0);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(5)).unwrap();
        let (partition_chunks, partition_rows) = pc.harvest();
        assert!(partition_chunks.is_empty());
        assert_eq!(partition_rows, 0);

        let (chunks, total) = collector.merge(vec![(partition_chunks, partition_rows)]);
        assert!(chunks.is_empty());
        assert_eq!(total, 0);
    }

    #[test]
    fn test_limit_collector_fewer_rows_than_limit() {
        // The limit is an upper bound: small inputs come back whole.
        let collector = LimitCollector::new(100);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(3)).unwrap();

        let (chunks, total) = collector.merge(vec![pc.harvest()]);
        assert_eq!(total, 3);
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 3);
    }

    #[test]
    fn test_stats_collector() {
        let collector = StatsCollector::new(0);

        let mut pc = collector.for_partition(0);

        // Create chunk with values 0..10
        let values: Vec<Value> = (0..10).map(|i| Value::from(i as i64)).collect();
        let column = ValueVector::from_values(&values);
        let chunk = DataChunk::new(vec![column]);

        pc.collect(&chunk).unwrap();
        let stats = pc.harvest();

        assert_eq!(stats.count, 10);
        assert!((stats.sum - 45.0).abs() < 0.001); // 0+1+2+...+9 = 45
        assert!((stats.min.unwrap() - 0.0).abs() < 0.001);
        assert!((stats.max.unwrap() - 9.0).abs() < 0.001);
        assert!((stats.avg().unwrap() - 4.5).abs() < 0.001);
    }

    #[test]
    fn test_stats_collector_float_values() {
        let collector = StatsCollector::new(0);
        let mut pc = collector.for_partition(0);

        let values = vec![Value::Float64(1.5), Value::Float64(-2.0)];
        let chunk = DataChunk::new(vec![ValueVector::from_values(&values)]);
        pc.collect(&chunk).unwrap();
        let stats = pc.harvest();

        assert_eq!(stats.count, 2);
        assert!((stats.sum - (-0.5)).abs() < 0.001);
        assert!((stats.min.unwrap() - (-2.0)).abs() < 0.001);
        assert!((stats.max.unwrap() - 1.5).abs() < 0.001);
    }

    #[test]
    fn test_stats_collector_missing_column() {
        // The test chunk has a single column, so index 1 is out of bounds.
        let collector = StatsCollector::new(1);
        let mut pc = collector.for_partition(0);
        assert!(pc.collect(&make_test_chunk(3)).is_err());
    }

    #[test]
    fn test_stats_collector_no_partitions() {
        let stats = StatsCollector::new(0).merge(Vec::new());
        assert_eq!(stats.count, 0);
        assert!(stats.min.is_none());
        assert!(stats.max.is_none());
        assert!(stats.avg().is_none());
    }

    #[test]
    fn test_stats_merge() {
        let collector = StatsCollector::new(0);

        // Partition 1: values 0..5
        let mut pc1 = collector.for_partition(0);
        let values1: Vec<Value> = (0..5).map(|i| Value::from(i as i64)).collect();
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&values1)]);
        pc1.collect(&chunk1).unwrap();

        // Partition 2: values 5..10
        let mut pc2 = collector.for_partition(1);
        let values2: Vec<Value> = (5..10).map(|i| Value::from(i as i64)).collect();
        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&values2)]);
        pc2.collect(&chunk2).unwrap();

        let stats = collector.merge(vec![pc1.harvest(), pc2.harvest()]);

        assert_eq!(stats.count, 10);
        assert!((stats.min.unwrap() - 0.0).abs() < 0.001);
        assert!((stats.max.unwrap() - 9.0).abs() < 0.001);
    }
}