1use super::chunk::DataChunk;
44use super::operators::OperatorError;
45
/// Factory and combiner for per-partition collection.
///
/// A `Collector` hands out one `PartitionCollector` per partition through
/// `for_partition` and later combines the per-partition results ("fruits")
/// into a single fruit through `merge`.
///
/// Implementors must be `Sync`.
// NOTE(review): the `Sync` bound is presumably there so one collector can be
// shared by several worker threads — confirm against the executor.
pub trait Collector: Sync {
    /// Result type produced by every partition and by the final merge.
    type Fruit: Send;

    /// Per-partition collector type; its `Fruit` must be the same type as
    /// this collector's `Fruit`.
    type PartitionCollector: PartitionCollector<Fruit = Self::Fruit>;

    /// Creates a collector for the partition identified by `partition_id`.
    fn for_partition(&self, partition_id: usize) -> Self::PartitionCollector;

    /// Combines the fruits harvested from the individual partitions into one.
    fn merge(&self, fruits: Vec<Self::Fruit>) -> Self::Fruit;
}
63
/// Accumulator that is fed chunks one at a time and finally yields a fruit.
///
/// Instances are obtained from `Collector::for_partition`. Implementors must
/// be `Send`.
pub trait PartitionCollector: Send {
    /// Result type returned by `harvest`.
    type Fruit: Send;

    /// Feeds one chunk into the collector.
    ///
    /// # Errors
    ///
    /// Returns an `OperatorError` when the implementation cannot process
    /// `chunk`.
    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError>;

    /// Consumes the collector and returns what it has accumulated.
    fn harvest(self) -> Self::Fruit;
}
87
88#[derive(Debug, Clone, Copy, Default)]
112pub struct CountCollector;
113
114impl Collector for CountCollector {
115 type Fruit = u64;
116 type PartitionCollector = CountPartitionCollector;
117
118 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
119 CountPartitionCollector { count: 0 }
120 }
121
122 fn merge(&self, fruits: Vec<u64>) -> u64 {
123 fruits.into_iter().sum()
124 }
125}
126
127pub struct CountPartitionCollector {
129 count: u64,
130}
131
132impl PartitionCollector for CountPartitionCollector {
133 type Fruit = u64;
134
135 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
136 self.count += chunk.len() as u64;
137 Ok(())
138 }
139
140 fn harvest(self) -> u64 {
141 self.count
142 }
143}
144
145#[derive(Debug, Clone, Default)]
150pub struct MaterializeCollector;
151
152impl Collector for MaterializeCollector {
153 type Fruit = Vec<DataChunk>;
154 type PartitionCollector = MaterializePartitionCollector;
155
156 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
157 MaterializePartitionCollector { chunks: Vec::new() }
158 }
159
160 fn merge(&self, mut fruits: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
161 let total_chunks: usize = fruits.iter().map(|f| f.len()).sum();
162 let mut result = Vec::with_capacity(total_chunks);
163 for fruit in &mut fruits {
164 result.append(fruit);
165 }
166 result
167 }
168}
169
170pub struct MaterializePartitionCollector {
172 chunks: Vec<DataChunk>,
173}
174
175impl PartitionCollector for MaterializePartitionCollector {
176 type Fruit = Vec<DataChunk>;
177
178 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
179 self.chunks.push(chunk.clone());
180 Ok(())
181 }
182
183 fn harvest(self) -> Vec<DataChunk> {
184 self.chunks
185 }
186}
187
/// Collector configuration carrying an upper bound on collected rows, where a
/// chunk's row count is its `DataChunk::len()`.
#[derive(Debug, Clone)]
pub struct LimitCollector {
    /// Maximum number of rows to keep.
    limit: usize,
}

impl LimitCollector {
    /// Creates a collector that keeps at most `limit` rows.
    #[must_use]
    pub fn new(limit: usize) -> Self {
        LimitCollector { limit }
    }
}
204
205impl Collector for LimitCollector {
206 type Fruit = (Vec<DataChunk>, usize);
207 type PartitionCollector = LimitPartitionCollector;
208
209 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
210 LimitPartitionCollector {
211 chunks: Vec::new(),
212 limit: self.limit,
213 collected: 0,
214 }
215 }
216
217 fn merge(&self, fruits: Vec<(Vec<DataChunk>, usize)>) -> (Vec<DataChunk>, usize) {
218 let mut result = Vec::new();
219 let mut total = 0;
220
221 for (chunks, _) in fruits {
222 for chunk in chunks {
223 if total >= self.limit {
224 break;
225 }
226 let take = (self.limit - total).min(chunk.len());
227 if take < chunk.len() {
228 result.push(chunk.slice(0, take));
229 } else {
230 result.push(chunk);
231 }
232 total += take;
233 }
234 if total >= self.limit {
235 break;
236 }
237 }
238
239 (result, total)
240 }
241}
242
243pub struct LimitPartitionCollector {
245 chunks: Vec<DataChunk>,
246 limit: usize,
247 collected: usize,
248}
249
250impl PartitionCollector for LimitPartitionCollector {
251 type Fruit = (Vec<DataChunk>, usize);
252
253 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
254 if self.collected >= self.limit {
255 return Ok(());
256 }
257
258 let take = (self.limit - self.collected).min(chunk.len());
259 if take < chunk.len() {
260 self.chunks.push(chunk.slice(0, take));
261 } else {
262 self.chunks.push(chunk.clone());
263 }
264 self.collected += take;
265
266 Ok(())
267 }
268
269 fn harvest(self) -> (Vec<DataChunk>, usize) {
270 (self.chunks, self.collected)
271 }
272}
273
/// Collector configuration naming the column whose numeric values are
/// summarised.
#[derive(Debug, Clone)]
pub struct StatsCollector {
    /// Index of the column to read from each chunk.
    column_idx: usize,
}

impl StatsCollector {
    /// Creates a collector that reads the column at `column_idx`.
    #[must_use]
    pub fn new(column_idx: usize) -> Self {
        StatsCollector { column_idx }
    }
}
287
/// Running count, sum, minimum and maximum of a stream of `f64` values.
///
/// The default value is the empty summary: zero count, zero sum, and no
/// minimum or maximum.
#[derive(Debug, Clone, Default)]
pub struct CollectorStats {
    /// Number of values summarised.
    pub count: u64,
    /// Sum of the values.
    pub sum: f64,
    /// Smallest value, or `None` when nothing has been recorded.
    pub min: Option<f64>,
    /// Largest value, or `None` when nothing has been recorded.
    pub max: Option<f64>,
}

impl CollectorStats {
    /// Folds `other` into `self`: counts and sums are added, and the minimum
    /// and maximum become the extreme of both sides (or whichever side has
    /// one, if only one does).
    pub fn merge(&mut self, other: CollectorStats) {
        let CollectorStats {
            count,
            sum,
            min,
            max,
        } = other;

        self.count += count;
        self.sum += sum;

        self.min = match (self.min, min) {
            (Some(ours), Some(theirs)) => Some(ours.min(theirs)),
            // At most one side is present here; keep it if there is one.
            (ours, theirs) => ours.or(theirs),
        };
        self.max = match (self.max, max) {
            (Some(ours), Some(theirs)) => Some(ours.max(theirs)),
            (ours, theirs) => ours.or(theirs),
        };
    }

    /// Returns `sum / count`, or `None` when `count` is zero.
    #[must_use]
    pub fn avg(&self) -> Option<f64> {
        match self.count {
            0 => None,
            n => Some(self.sum / n as f64),
        }
    }
}
328
329impl Collector for StatsCollector {
330 type Fruit = CollectorStats;
331 type PartitionCollector = StatsPartitionCollector;
332
333 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
334 StatsPartitionCollector {
335 column_idx: self.column_idx,
336 stats: CollectorStats::default(),
337 }
338 }
339
340 fn merge(&self, fruits: Vec<CollectorStats>) -> CollectorStats {
341 let mut result = CollectorStats::default();
342 for fruit in fruits {
343 result.merge(fruit);
344 }
345 result
346 }
347}
348
349pub struct StatsPartitionCollector {
351 column_idx: usize,
352 stats: CollectorStats,
353}
354
355impl PartitionCollector for StatsPartitionCollector {
356 type Fruit = CollectorStats;
357
358 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
359 let column = chunk.column(self.column_idx).ok_or_else(|| {
360 OperatorError::ColumnNotFound(format!(
361 "column index {} out of bounds (width={})",
362 self.column_idx,
363 chunk.column_count()
364 ))
365 })?;
366
367 for i in 0..chunk.len() {
368 let val = if let Some(f) = column.get_float64(i) {
370 Some(f)
371 } else if let Some(i) = column.get_int64(i) {
372 Some(i as f64)
373 } else if let Some(value) = column.get_value(i) {
374 match value {
376 grafeo_common::types::Value::Int64(i) => Some(i as f64),
377 grafeo_common::types::Value::Float64(f) => Some(f),
378 _ => None,
379 }
380 } else {
381 None
382 };
383
384 if let Some(v) = val {
385 self.stats.count += 1;
386 self.stats.sum += v;
387 self.stats.min = Some(match self.stats.min {
388 Some(m) => m.min(v),
389 None => v,
390 });
391 self.stats.max = Some(match self.stats.max {
392 Some(m) => m.max(v),
393 None => v,
394 });
395 }
396 }
397
398 Ok(())
399 }
400
401 fn harvest(self) -> CollectorStats {
402 self.stats
403 }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::ValueVector;
    use grafeo_common::types::Value;

    /// Builds a single-column chunk holding the `i64` values `0..size`.
    fn make_test_chunk(size: usize) -> DataChunk {
        let mut values: Vec<Value> = Vec::with_capacity(size);
        for i in 0..size {
            values.push(Value::from(i as i64));
        }
        DataChunk::new(vec![ValueVector::from_values(&values)])
    }

    /// Asserts that `actual` is within 0.001 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.001);
    }

    #[test]
    fn test_count_collector() {
        let collector = CountCollector;

        // Partition 0 sees 10 + 5 rows.
        let mut first = collector.for_partition(0);
        first.collect(&make_test_chunk(10)).unwrap();
        first.collect(&make_test_chunk(5)).unwrap();
        let first_count = first.harvest();

        // Partition 1 sees 7 rows.
        let mut second = collector.for_partition(1);
        second.collect(&make_test_chunk(7)).unwrap();
        let second_count = second.harvest();

        let total = collector.merge(vec![first_count, second_count]);
        assert_eq!(total, 22);
    }

    #[test]
    fn test_materialize_collector() {
        let collector = MaterializeCollector;

        let mut first = collector.for_partition(0);
        first.collect(&make_test_chunk(10)).unwrap();
        first.collect(&make_test_chunk(5)).unwrap();
        let first_chunks = first.harvest();

        let mut second = collector.for_partition(1);
        second.collect(&make_test_chunk(7)).unwrap();
        let second_chunks = second.harvest();

        let merged = collector.merge(vec![first_chunks, second_chunks]);
        assert_eq!(merged.len(), 3);

        let row_total: usize = merged.iter().map(|chunk| chunk.len()).sum();
        assert_eq!(row_total, 22);
    }

    #[test]
    fn test_limit_collector() {
        let collector = LimitCollector::new(12);

        // Partition 0 is offered 15 rows against a limit of 12.
        let mut first = collector.for_partition(0);
        first.collect(&make_test_chunk(10)).unwrap();
        first.collect(&make_test_chunk(5)).unwrap();
        let first_fruit = first.harvest();

        // Partition 1 is offered 20 rows against the same limit.
        let mut second = collector.for_partition(1);
        second.collect(&make_test_chunk(20)).unwrap();
        let second_fruit = second.harvest();

        let (chunks, total) = collector.merge(vec![first_fruit, second_fruit]);
        assert_eq!(total, 12);

        let actual_rows: usize = chunks.iter().map(|chunk| chunk.len()).sum();
        assert_eq!(actual_rows, 12);
    }

    #[test]
    fn test_stats_collector() {
        let collector = StatsCollector::new(0);
        let mut partition = collector.for_partition(0);

        // One column holding the integers 0..10.
        let chunk = make_test_chunk(10);
        partition.collect(&chunk).unwrap();
        let stats = partition.harvest();

        assert_eq!(stats.count, 10);
        assert_close(stats.sum, 45.0);
        assert_close(stats.min.unwrap(), 0.0);
        assert_close(stats.max.unwrap(), 9.0);
        assert_close(stats.avg().unwrap(), 4.5);
    }

    #[test]
    fn test_stats_merge() {
        let collector = StatsCollector::new(0);

        // Partition 0 holds 0..5.
        let mut low = collector.for_partition(0);
        let low_chunk = make_test_chunk(5);
        low.collect(&low_chunk).unwrap();

        // Partition 1 holds 5..10.
        let mut high = collector.for_partition(1);
        let high_values: Vec<Value> = (5..10).map(|i| Value::from(i as i64)).collect();
        let high_chunk = DataChunk::new(vec![ValueVector::from_values(&high_values)]);
        high.collect(&high_chunk).unwrap();

        let stats = collector.merge(vec![low.harvest(), high.harvest()]);

        assert_eq!(stats.count, 10);
        assert_close(stats.min.unwrap(), 0.0);
        assert_close(stats.max.unwrap(), 9.0);
    }
}