Skip to main content

grafeo_core/execution/
collector.rs

1//! Generic collector trait for parallel aggregation.
2//!
3//! Collectors provide a clean separation between what data to aggregate
4//! and how to execute the aggregation in parallel. The pattern is inspired
5//! by Tantivy's collector architecture.
6//!
7//! # Pattern
8//!
9//! 1. Create partition-local collectors (one per worker thread)
10//! 2. Each collector processes its partition independently (no shared state)
11//! 3. Merge all partition results into a final result
12//!
13//! # Example
14//!
15//! ```ignore
16//! use grafeo_core::execution::collector::{Collector, PartitionCollector, CountCollector};
17//!
18//! let collector = CountCollector;
19//!
20//! // In parallel execution:
21//! let mut partition_collectors: Vec<_> = (0..4)
22//!     .map(|id| collector.for_partition(id))
23//!     .collect();
24//!
25//! // Each partition processes its chunks
26//! for (partition, chunks) in partitions.into_iter().enumerate() {
27//!     for chunk in chunks {
28//!         partition_collectors[partition].collect(&chunk)?;
29//!     }
30//! }
31//!
32//! // Merge results
33//! let fruits: Vec<_> = partition_collectors.into_iter()
34//!     .map(|c| c.harvest())
35//!     .collect();
36//! let total = collector.merge(fruits);
37//! ```
38
39use super::chunk::DataChunk;
40use super::operators::OperatorError;
41
42/// A collector that aggregates results from parallel execution.
43///
44/// Pattern: Create partition-local collectors, process independently,
45/// then merge results. No shared mutable state during collection.
46pub trait Collector: Sync {
47    /// Final result type after merging all partitions.
48    type Fruit: Send;
49
50    /// Partition-local collector type.
51    type PartitionCollector: PartitionCollector<Fruit = Self::Fruit>;
52
53    /// Creates a collector for a single partition (called per-thread).
54    fn for_partition(&self, partition_id: usize) -> Self::PartitionCollector;
55
56    /// Merges results from all partitions (called once at the end).
57    fn merge(&self, fruits: Vec<Self::Fruit>) -> Self::Fruit;
58}
59
60/// Per-partition collector - processes chunks locally.
61///
62/// Each partition collector is created by [`Collector::for_partition`]
63/// and processes data independently. This enables lock-free parallel
64/// execution.
65pub trait PartitionCollector: Send {
66    /// Result type produced by this partition.
67    type Fruit: Send;
68
69    /// Processes a batch of data.
70    ///
71    /// Called repeatedly with chunks from this partition.
72    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError>;
73
74    /// Finalizes and returns the result for this partition.
75    ///
76    /// Called once after all chunks have been processed.
77    fn harvest(self) -> Self::Fruit;
78}
79
80// ============================================================================
81// Built-in Collectors
82// ============================================================================
83
84/// Counts rows across all partitions.
85///
86/// # Example
87///
88/// ```ignore
89/// use grafeo_core::execution::collector::{Collector, CountCollector};
90///
91/// let collector = CountCollector;
92/// let mut pc = collector.for_partition(0);
93///
94/// pc.collect(&chunk1)?;
95/// pc.collect(&chunk2)?;
96///
97/// let count = pc.harvest();
98/// ```
99#[derive(Debug, Clone, Copy, Default)]
100pub struct CountCollector;
101
102impl Collector for CountCollector {
103    type Fruit = u64;
104    type PartitionCollector = CountPartitionCollector;
105
106    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
107        CountPartitionCollector { count: 0 }
108    }
109
110    fn merge(&self, fruits: Vec<u64>) -> u64 {
111        fruits.into_iter().sum()
112    }
113}
114
115/// Partition-local counter.
116pub struct CountPartitionCollector {
117    count: u64,
118}
119
120impl PartitionCollector for CountPartitionCollector {
121    type Fruit = u64;
122
123    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
124        self.count += chunk.len() as u64;
125        Ok(())
126    }
127
128    fn harvest(self) -> u64 {
129        self.count
130    }
131}
132
133/// Collects all chunks (materializes the entire result).
134///
135/// Use this when you need all the data, not just an aggregate.
136/// Be careful with large datasets - this can consume significant memory.
137#[derive(Debug, Clone, Default)]
138pub struct MaterializeCollector;
139
140impl Collector for MaterializeCollector {
141    type Fruit = Vec<DataChunk>;
142    type PartitionCollector = MaterializePartitionCollector;
143
144    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
145        MaterializePartitionCollector { chunks: Vec::new() }
146    }
147
148    fn merge(&self, mut fruits: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
149        let total_chunks: usize = fruits.iter().map(|f| f.len()).sum();
150        let mut result = Vec::with_capacity(total_chunks);
151        for fruit in &mut fruits {
152            result.append(fruit);
153        }
154        result
155    }
156}
157
158/// Partition-local materializer.
159pub struct MaterializePartitionCollector {
160    chunks: Vec<DataChunk>,
161}
162
163impl PartitionCollector for MaterializePartitionCollector {
164    type Fruit = Vec<DataChunk>;
165
166    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
167        self.chunks.push(chunk.clone());
168        Ok(())
169    }
170
171    fn harvest(self) -> Vec<DataChunk> {
172        self.chunks
173    }
174}
175
/// Collects at most the first `limit` rows across all partitions.
///
/// Each partition stops keeping rows once it has gathered `limit` of its
/// own, slicing the chunk that crosses the boundary. The final merge
/// concatenates the partitions' chunks in the order the fruits are supplied
/// and truncates at `limit` rows. The result therefore has exactly `limit`
/// rows only when at least that many were collected, and fewer otherwise.
///
/// Which rows survive depends on the order of the fruits passed to
/// [`Collector::merge`]; no ordering across partitions is implied.
#[derive(Debug, Clone)]
pub struct LimitCollector {
    /// Maximum number of rows kept per partition and in the merged result.
    limit: usize,
}

impl LimitCollector {
    /// Creates a collector that keeps at most `limit` rows.
    ///
    /// A `limit` of `0` produces an empty result.
    #[must_use]
    pub fn new(limit: usize) -> Self {
        Self { limit }
    }
}
192
193impl Collector for LimitCollector {
194    type Fruit = (Vec<DataChunk>, usize);
195    type PartitionCollector = LimitPartitionCollector;
196
197    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
198        LimitPartitionCollector {
199            chunks: Vec::new(),
200            limit: self.limit,
201            collected: 0,
202        }
203    }
204
205    fn merge(&self, fruits: Vec<(Vec<DataChunk>, usize)>) -> (Vec<DataChunk>, usize) {
206        let mut result = Vec::new();
207        let mut total = 0;
208
209        for (chunks, _) in fruits {
210            for chunk in chunks {
211                if total >= self.limit {
212                    break;
213                }
214                let take = (self.limit - total).min(chunk.len());
215                if take < chunk.len() {
216                    result.push(chunk.slice(0, take));
217                } else {
218                    result.push(chunk);
219                }
220                total += take;
221            }
222            if total >= self.limit {
223                break;
224            }
225        }
226
227        (result, total)
228    }
229}
230
231/// Partition-local limiter.
232pub struct LimitPartitionCollector {
233    chunks: Vec<DataChunk>,
234    limit: usize,
235    collected: usize,
236}
237
238impl PartitionCollector for LimitPartitionCollector {
239    type Fruit = (Vec<DataChunk>, usize);
240
241    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
242        if self.collected >= self.limit {
243            return Ok(());
244        }
245
246        let take = (self.limit - self.collected).min(chunk.len());
247        if take < chunk.len() {
248            self.chunks.push(chunk.slice(0, take));
249        } else {
250            self.chunks.push(chunk.clone());
251        }
252        self.collected += take;
253
254        Ok(())
255    }
256
257    fn harvest(self) -> (Vec<DataChunk>, usize) {
258        (self.chunks, self.collected)
259    }
260}
261
/// Computes numeric summary statistics (count, sum, min, max) for one column.
///
/// The column is selected by its position within each [`DataChunk`].
#[derive(Debug, Clone)]
pub struct StatsCollector {
    /// Index of the column to summarize.
    column_idx: usize,
}

impl StatsCollector {
    /// Creates a collector that summarizes the column at `column_idx`.
    #[must_use]
    pub fn new(column_idx: usize) -> Self {
        StatsCollector { column_idx }
    }
}
275
/// Statistics result from [`StatsCollector`].
///
/// The `Default` value is the empty summary: zero count, zero sum, and no extrema.
#[derive(Debug, Clone, Default)]
pub struct CollectorStats {
    /// Number of values folded into these statistics.
    pub count: u64,
    /// Sum of the observed values.
    pub sum: f64,
    /// Smallest observed value, or `None` if nothing was observed.
    pub min: Option<f64>,
    /// Largest observed value, or `None` if nothing was observed.
    pub max: Option<f64>,
}

impl CollectorStats {
    /// Folds `other` into `self`.
    ///
    /// Counts and sums are added. Each extremum keeps the more extreme of the
    /// two sides, and a missing (`None`) extremum on one side defers to the other.
    pub fn merge(&mut self, other: CollectorStats) {
        // Combines two optional extrema with `pick`, treating `None` as "no value yet".
        fn combine(lhs: Option<f64>, rhs: Option<f64>, pick: fn(f64, f64) -> f64) -> Option<f64> {
            match (lhs, rhs) {
                (Some(a), Some(b)) => Some(pick(a, b)),
                (a, b) => a.or(b),
            }
        }

        self.count += other.count;
        self.sum += other.sum;
        self.min = combine(self.min, other.min, f64::min);
        self.max = combine(self.max, other.max, f64::max);
    }

    /// Computes the arithmetic mean, or `None` when no values were observed.
    #[must_use]
    pub fn avg(&self) -> Option<f64> {
        (self.count > 0).then(|| self.sum / self.count as f64)
    }
}
316
317impl Collector for StatsCollector {
318    type Fruit = CollectorStats;
319    type PartitionCollector = StatsPartitionCollector;
320
321    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
322        StatsPartitionCollector {
323            column_idx: self.column_idx,
324            stats: CollectorStats::default(),
325        }
326    }
327
328    fn merge(&self, fruits: Vec<CollectorStats>) -> CollectorStats {
329        let mut result = CollectorStats::default();
330        for fruit in fruits {
331            result.merge(fruit);
332        }
333        result
334    }
335}
336
337/// Partition-local stats collector.
338pub struct StatsPartitionCollector {
339    column_idx: usize,
340    stats: CollectorStats,
341}
342
343impl PartitionCollector for StatsPartitionCollector {
344    type Fruit = CollectorStats;
345
346    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
347        let column = chunk.column(self.column_idx).ok_or_else(|| {
348            OperatorError::ColumnNotFound(format!(
349                "column index {} out of bounds (width={})",
350                self.column_idx,
351                chunk.column_count()
352            ))
353        })?;
354
355        for i in 0..chunk.len() {
356            // Try typed access first (for specialized vectors), then fall back to generic
357            let val = if let Some(f) = column.get_float64(i) {
358                Some(f)
359            } else if let Some(i) = column.get_int64(i) {
360                Some(i as f64)
361            } else if let Some(value) = column.get_value(i) {
362                // Handle Generic vectors - extract numeric value
363                match value {
364                    grafeo_common::types::Value::Int64(i) => Some(i as f64),
365                    grafeo_common::types::Value::Float64(f) => Some(f),
366                    _ => None,
367                }
368            } else {
369                None
370            };
371
372            if let Some(v) = val {
373                self.stats.count += 1;
374                self.stats.sum += v;
375                self.stats.min = Some(match self.stats.min {
376                    Some(m) => m.min(v),
377                    None => v,
378                });
379                self.stats.max = Some(match self.stats.max {
380                    Some(m) => m.max(v),
381                    None => v,
382                });
383            }
384        }
385
386        Ok(())
387    }
388
389    fn harvest(self) -> CollectorStats {
390        self.stats
391    }
392}
393
#[cfg(test)]
mod tests {
    //! Unit tests for the built-in collectors, including error and boundary paths.

    use super::*;
    use crate::execution::ValueVector;
    use grafeo_common::types::Value;

    /// Builds a single-column chunk holding the integers `0..size`.
    fn make_test_chunk(size: usize) -> DataChunk {
        let values: Vec<Value> = (0..size).map(|i| Value::from(i as i64)).collect();
        let column = ValueVector::from_values(&values);
        DataChunk::new(vec![column])
    }

    #[test]
    fn test_count_collector() {
        let collector = CountCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let count1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let count2 = pc2.harvest();

        let total = collector.merge(vec![count1, count2]);
        assert_eq!(total, 22);
    }

    #[test]
    fn test_materialize_collector() {
        let collector = MaterializeCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let chunks1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let chunks2 = pc2.harvest();

        let result = collector.merge(vec![chunks1, chunks2]);
        assert_eq!(result.len(), 3);
        assert_eq!(result.iter().map(|c| c.len()).sum::<usize>(), 22);
    }

    #[test]
    fn test_limit_collector() {
        let collector = LimitCollector::new(12);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap(); // Only 2 more should be taken
        let result1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(20)).unwrap();
        let result2 = pc2.harvest();

        let (chunks, total) = collector.merge(vec![result1, result2]);
        assert_eq!(total, 12);

        let actual_rows: usize = chunks.iter().map(|c| c.len()).sum();
        assert_eq!(actual_rows, 12);
    }

    #[test]
    fn test_limit_collector_zero_limit() {
        let collector = LimitCollector::new(0);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(5)).unwrap();
        let (kept, collected) = pc.harvest();
        assert!(kept.is_empty());
        assert_eq!(collected, 0);

        let (chunks, total) = collector.merge(vec![(kept, collected)]);
        assert!(chunks.is_empty());
        assert_eq!(total, 0);
    }

    #[test]
    fn test_limit_collector_fewer_rows_than_limit() {
        // The limit is an upper bound: with fewer rows available, all are returned.
        let collector = LimitCollector::new(100);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(3)).unwrap();
        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(4)).unwrap();

        let (chunks, total) = collector.merge(vec![pc.harvest(), pc2.harvest()]);
        assert_eq!(total, 7);
        assert_eq!(chunks.iter().map(|c| c.len()).sum::<usize>(), 7);
    }

    #[test]
    fn test_stats_collector() {
        let collector = StatsCollector::new(0);

        let mut pc = collector.for_partition(0);

        // Create chunk with values 0..10
        let values: Vec<Value> = (0..10).map(|i| Value::from(i as i64)).collect();
        let column = ValueVector::from_values(&values);
        let chunk = DataChunk::new(vec![column]);

        pc.collect(&chunk).unwrap();
        let stats = pc.harvest();

        assert_eq!(stats.count, 10);
        assert!((stats.sum - 45.0).abs() < 0.001); // 0+1+2+...+9 = 45
        assert!((stats.min.unwrap() - 0.0).abs() < 0.001);
        assert!((stats.max.unwrap() - 9.0).abs() < 0.001);
        assert!((stats.avg().unwrap() - 4.5).abs() < 0.001);
    }

    #[test]
    fn test_stats_collector_float_column() {
        let collector = StatsCollector::new(0);
        let mut pc = collector.for_partition(0);

        let values = vec![
            Value::Float64(1.5),
            Value::Float64(-2.5),
            Value::Float64(4.0),
        ];
        let chunk = DataChunk::new(vec![ValueVector::from_values(&values)]);
        pc.collect(&chunk).unwrap();
        let stats = pc.harvest();

        assert_eq!(stats.count, 3);
        assert!((stats.sum - 3.0).abs() < 0.001);
        assert!((stats.min.unwrap() - (-2.5)).abs() < 0.001);
        assert!((stats.max.unwrap() - 4.0).abs() < 0.001);
    }

    #[test]
    fn test_stats_collector_missing_column_errors() {
        let collector = StatsCollector::new(3);
        let mut pc = collector.for_partition(0);

        // The test chunk has a single column, so index 3 is out of bounds.
        let err = pc.collect(&make_test_chunk(4)).unwrap_err();
        assert!(matches!(err, OperatorError::ColumnNotFound(_)));
        assert_eq!(pc.harvest().count, 0);
    }

    #[test]
    fn test_stats_empty_partition() {
        let collector = StatsCollector::new(0);
        let stats = collector.merge(vec![collector.for_partition(0).harvest()]);

        assert_eq!(stats.count, 0);
        assert!(stats.min.is_none());
        assert!(stats.max.is_none());
        assert!(stats.avg().is_none());
    }

    #[test]
    fn test_stats_merge() {
        let collector = StatsCollector::new(0);

        // Partition 1: values 0..5
        let mut pc1 = collector.for_partition(0);
        let values1: Vec<Value> = (0..5).map(|i| Value::from(i as i64)).collect();
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&values1)]);
        pc1.collect(&chunk1).unwrap();

        // Partition 2: values 5..10
        let mut pc2 = collector.for_partition(1);
        let values2: Vec<Value> = (5..10).map(|i| Value::from(i as i64)).collect();
        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&values2)]);
        pc2.collect(&chunk2).unwrap();

        let stats = collector.merge(vec![pc1.harvest(), pc2.harvest()]);

        assert_eq!(stats.count, 10);
        assert!((stats.min.unwrap() - 0.0).abs() < 0.001);
        assert!((stats.max.unwrap() - 9.0).abs() < 0.001);
    }
}