1use super::chunk::DataChunk;
40use super::operators::OperatorError;
41
42pub trait Collector: Sync {
47 type Fruit: Send;
49
50 type PartitionCollector: PartitionCollector<Fruit = Self::Fruit>;
52
53 fn for_partition(&self, partition_id: usize) -> Self::PartitionCollector;
55
56 fn merge(&self, fruits: Vec<Self::Fruit>) -> Self::Fruit;
58}
59
60pub trait PartitionCollector: Send {
66 type Fruit: Send;
68
69 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError>;
73
74 fn harvest(self) -> Self::Fruit;
78}
79
80#[derive(Debug, Clone, Copy, Default)]
100pub struct CountCollector;
101
102impl Collector for CountCollector {
103 type Fruit = u64;
104 type PartitionCollector = CountPartitionCollector;
105
106 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
107 CountPartitionCollector { count: 0 }
108 }
109
110 fn merge(&self, fruits: Vec<u64>) -> u64 {
111 fruits.into_iter().sum()
112 }
113}
114
115pub struct CountPartitionCollector {
117 count: u64,
118}
119
120impl PartitionCollector for CountPartitionCollector {
121 type Fruit = u64;
122
123 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
124 self.count += chunk.len() as u64;
125 Ok(())
126 }
127
128 fn harvest(self) -> u64 {
129 self.count
130 }
131}
132
133#[derive(Debug, Clone, Default)]
138pub struct MaterializeCollector;
139
140impl Collector for MaterializeCollector {
141 type Fruit = Vec<DataChunk>;
142 type PartitionCollector = MaterializePartitionCollector;
143
144 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
145 MaterializePartitionCollector { chunks: Vec::new() }
146 }
147
148 fn merge(&self, mut fruits: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
149 let total_chunks: usize = fruits.iter().map(|f| f.len()).sum();
150 let mut result = Vec::with_capacity(total_chunks);
151 for fruit in &mut fruits {
152 result.append(fruit);
153 }
154 result
155 }
156}
157
158pub struct MaterializePartitionCollector {
160 chunks: Vec<DataChunk>,
161}
162
163impl PartitionCollector for MaterializePartitionCollector {
164 type Fruit = Vec<DataChunk>;
165
166 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
167 self.chunks.push(chunk.clone());
168 Ok(())
169 }
170
171 fn harvest(self) -> Vec<DataChunk> {
172 self.chunks
173 }
174}
175
176#[derive(Debug, Clone)]
181pub struct LimitCollector {
182 limit: usize,
183}
184
185impl LimitCollector {
186 #[must_use]
188 pub fn new(limit: usize) -> Self {
189 Self { limit }
190 }
191}
192
193impl Collector for LimitCollector {
194 type Fruit = (Vec<DataChunk>, usize);
195 type PartitionCollector = LimitPartitionCollector;
196
197 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
198 LimitPartitionCollector {
199 chunks: Vec::new(),
200 limit: self.limit,
201 collected: 0,
202 }
203 }
204
205 fn merge(&self, fruits: Vec<(Vec<DataChunk>, usize)>) -> (Vec<DataChunk>, usize) {
206 let mut result = Vec::new();
207 let mut total = 0;
208
209 for (chunks, _) in fruits {
210 for chunk in chunks {
211 if total >= self.limit {
212 break;
213 }
214 let take = (self.limit - total).min(chunk.len());
215 if take < chunk.len() {
216 result.push(chunk.slice(0, take));
217 } else {
218 result.push(chunk);
219 }
220 total += take;
221 }
222 if total >= self.limit {
223 break;
224 }
225 }
226
227 (result, total)
228 }
229}
230
231pub struct LimitPartitionCollector {
233 chunks: Vec<DataChunk>,
234 limit: usize,
235 collected: usize,
236}
237
238impl PartitionCollector for LimitPartitionCollector {
239 type Fruit = (Vec<DataChunk>, usize);
240
241 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
242 if self.collected >= self.limit {
243 return Ok(());
244 }
245
246 let take = (self.limit - self.collected).min(chunk.len());
247 if take < chunk.len() {
248 self.chunks.push(chunk.slice(0, take));
249 } else {
250 self.chunks.push(chunk.clone());
251 }
252 self.collected += take;
253
254 Ok(())
255 }
256
257 fn harvest(self) -> (Vec<DataChunk>, usize) {
258 (self.chunks, self.collected)
259 }
260}
261
/// Collector that gathers [`CollectorStats`] over the numeric values of one column.
#[derive(Debug, Clone)]
pub struct StatsCollector {
    /// Index of the column to summarise (index 0 is the first column).
    column_idx: usize,
}

impl StatsCollector {
    /// Creates a stats collector for the column at `column_idx`.
    #[must_use]
    pub fn new(column_idx: usize) -> Self {
        StatsCollector { column_idx }
    }
}
275
/// Summary statistics over the numeric values observed in one column.
#[derive(Debug, Clone, Default)]
pub struct CollectorStats {
    /// Number of values observed.
    pub count: u64,
    /// Sum of the observed values.
    pub sum: f64,
    /// Smallest observed value; `None` until a value has been observed.
    pub min: Option<f64>,
    /// Largest observed value; `None` until a value has been observed.
    pub max: Option<f64>,
}

impl CollectorStats {
    /// Folds `other` into `self`.
    ///
    /// Counts and sums are added. `min`/`max` keep the extreme of both sides,
    /// or whichever side is present when the other is `None`.
    pub fn merge(&mut self, other: CollectorStats) {
        /// Combines two optional extremes with `pick`, falling back to whichever is present.
        fn combine(lhs: Option<f64>, rhs: Option<f64>, pick: fn(f64, f64) -> f64) -> Option<f64> {
            match (lhs, rhs) {
                (Some(a), Some(b)) => Some(pick(a, b)),
                (a, b) => a.or(b),
            }
        }

        self.count += other.count;
        self.sum += other.sum;
        self.min = combine(self.min, other.min, f64::min);
        self.max = combine(self.max, other.max, f64::max);
    }

    /// Returns the arithmetic mean, or `None` if no values were observed.
    #[must_use]
    pub fn avg(&self) -> Option<f64> {
        (self.count > 0).then(|| self.sum / self.count as f64)
    }
}
316
317impl Collector for StatsCollector {
318 type Fruit = CollectorStats;
319 type PartitionCollector = StatsPartitionCollector;
320
321 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
322 StatsPartitionCollector {
323 column_idx: self.column_idx,
324 stats: CollectorStats::default(),
325 }
326 }
327
328 fn merge(&self, fruits: Vec<CollectorStats>) -> CollectorStats {
329 let mut result = CollectorStats::default();
330 for fruit in fruits {
331 result.merge(fruit);
332 }
333 result
334 }
335}
336
337pub struct StatsPartitionCollector {
339 column_idx: usize,
340 stats: CollectorStats,
341}
342
343impl PartitionCollector for StatsPartitionCollector {
344 type Fruit = CollectorStats;
345
346 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
347 let column = chunk.column(self.column_idx).ok_or_else(|| {
348 OperatorError::ColumnNotFound(format!(
349 "column index {} out of bounds (width={})",
350 self.column_idx,
351 chunk.column_count()
352 ))
353 })?;
354
355 for i in 0..chunk.len() {
356 let val = if let Some(f) = column.get_float64(i) {
358 Some(f)
359 } else if let Some(i) = column.get_int64(i) {
360 Some(i as f64)
361 } else if let Some(value) = column.get_value(i) {
362 match value {
364 grafeo_common::types::Value::Int64(i) => Some(i as f64),
365 grafeo_common::types::Value::Float64(f) => Some(f),
366 _ => None,
367 }
368 } else {
369 None
370 };
371
372 if let Some(v) = val {
373 self.stats.count += 1;
374 self.stats.sum += v;
375 self.stats.min = Some(match self.stats.min {
376 Some(m) => m.min(v),
377 None => v,
378 });
379 self.stats.max = Some(match self.stats.max {
380 Some(m) => m.max(v),
381 None => v,
382 });
383 }
384 }
385
386 Ok(())
387 }
388
389 fn harvest(self) -> CollectorStats {
390 self.stats
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::ValueVector;
    use grafeo_common::types::Value;

    /// Builds a single-column chunk holding the integers `0..size`.
    fn make_test_chunk(size: usize) -> DataChunk {
        let values: Vec<Value> = (0..size).map(|i| Value::from(i as i64)).collect();
        let column = ValueVector::from_values(&values);
        DataChunk::new(vec![column])
    }

    /// Builds a single-column chunk holding the integers in `range`.
    fn make_range_chunk(range: std::ops::Range<i64>) -> DataChunk {
        let values: Vec<Value> = range.map(Value::from).collect();
        DataChunk::new(vec![ValueVector::from_values(&values)])
    }

    /// Asserts that two floats agree to within a small tolerance.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 0.001,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_count_collector() {
        let collector = CountCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let count1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let count2 = pc2.harvest();

        let total = collector.merge(vec![count1, count2]);
        assert_eq!(total, 22);
    }

    #[test]
    fn test_count_collector_no_partitions() {
        assert_eq!(CountCollector.merge(Vec::new()), 0);
    }

    #[test]
    fn test_materialize_collector() {
        let collector = MaterializeCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let chunks1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let chunks2 = pc2.harvest();

        let result = collector.merge(vec![chunks1, chunks2]);
        assert_eq!(result.len(), 3);
        assert_eq!(result.iter().map(|c| c.len()).sum::<usize>(), 22);
    }

    #[test]
    fn test_materialize_collector_no_partitions() {
        assert!(MaterializeCollector.merge(Vec::new()).is_empty());
    }

    #[test]
    fn test_limit_collector() {
        let collector = LimitCollector::new(12);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let result1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(20)).unwrap();
        let result2 = pc2.harvest();

        let (chunks, total) = collector.merge(vec![result1, result2]);
        assert_eq!(total, 12);

        let actual_rows: usize = chunks.iter().map(|c| c.len()).sum();
        assert_eq!(actual_rows, 12);
    }

    #[test]
    fn test_limit_collector_exact_fit() {
        let collector = LimitCollector::new(10);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        // The budget is already spent, so this chunk must be ignored.
        pc.collect(&make_test_chunk(5)).unwrap();
        let (chunks, collected) = pc.harvest();

        assert_eq!(collected, 10);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].len(), 10);
    }

    #[test]
    fn test_limit_collector_zero_limit() {
        let collector = LimitCollector::new(0);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();

        let (chunks, total) = collector.merge(vec![pc.harvest()]);
        assert!(chunks.is_empty());
        assert_eq!(total, 0);
    }

    #[test]
    fn test_stats_collector() {
        let collector = StatsCollector::new(0);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_range_chunk(0..10)).unwrap();
        let stats = pc.harvest();

        assert_eq!(stats.count, 10);
        assert_close(stats.sum, 45.0);
        assert_close(stats.min.unwrap(), 0.0);
        assert_close(stats.max.unwrap(), 9.0);
        assert_close(stats.avg().unwrap(), 4.5);
    }

    #[test]
    fn test_stats_merge() {
        let collector = StatsCollector::new(0);

        let mut pc1 = collector.for_partition(0);
        pc1.collect(&make_range_chunk(0..5)).unwrap();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_range_chunk(5..10)).unwrap();

        let stats = collector.merge(vec![pc1.harvest(), pc2.harvest()]);

        assert_eq!(stats.count, 10);
        assert_close(stats.min.unwrap(), 0.0);
        assert_close(stats.max.unwrap(), 9.0);
    }

    #[test]
    fn test_stats_collector_missing_column() {
        let collector = StatsCollector::new(3);

        let mut pc = collector.for_partition(0);
        assert!(pc.collect(&make_test_chunk(4)).is_err());
    }

    #[test]
    fn test_collector_stats_empty() {
        let empty = CollectorStats::default();
        assert_eq!(empty.count, 0);
        assert!(empty.avg().is_none());

        let mut merged = CollectorStats::default();
        merged.merge(CollectorStats {
            count: 1,
            sum: 2.0,
            min: Some(2.0),
            max: Some(2.0),
        });
        merged.merge(CollectorStats::default());

        assert_eq!(merged.count, 1);
        assert_eq!(merged.min, Some(2.0));
        assert_eq!(merged.max, Some(2.0));
    }
}