1use super::chunk::DataChunk;
44use super::operators::OperatorError;
45
46pub trait Collector: Sync {
51 type Fruit: Send;
53
54 type PartitionCollector: PartitionCollector<Fruit = Self::Fruit>;
56
57 fn for_partition(&self, partition_id: usize) -> Self::PartitionCollector;
59
60 fn merge(&self, fruits: Vec<Self::Fruit>) -> Self::Fruit;
62}
63
64pub trait PartitionCollector: Send {
70 type Fruit: Send;
72
73 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError>;
77
78 fn harvest(self) -> Self::Fruit;
82}
83
84#[derive(Debug, Clone, Copy, Default)]
108pub struct CountCollector;
109
110impl Collector for CountCollector {
111 type Fruit = u64;
112 type PartitionCollector = CountPartitionCollector;
113
114 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
115 CountPartitionCollector { count: 0 }
116 }
117
118 fn merge(&self, fruits: Vec<u64>) -> u64 {
119 fruits.into_iter().sum()
120 }
121}
122
123pub struct CountPartitionCollector {
125 count: u64,
126}
127
128impl PartitionCollector for CountPartitionCollector {
129 type Fruit = u64;
130
131 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
132 self.count += chunk.len() as u64;
133 Ok(())
134 }
135
136 fn harvest(self) -> u64 {
137 self.count
138 }
139}
140
141#[derive(Debug, Clone, Default)]
146pub struct MaterializeCollector;
147
148impl Collector for MaterializeCollector {
149 type Fruit = Vec<DataChunk>;
150 type PartitionCollector = MaterializePartitionCollector;
151
152 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
153 MaterializePartitionCollector { chunks: Vec::new() }
154 }
155
156 fn merge(&self, mut fruits: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
157 let total_chunks: usize = fruits.iter().map(|f| f.len()).sum();
158 let mut result = Vec::with_capacity(total_chunks);
159 for fruit in &mut fruits {
160 result.append(fruit);
161 }
162 result
163 }
164}
165
166pub struct MaterializePartitionCollector {
168 chunks: Vec<DataChunk>,
169}
170
171impl PartitionCollector for MaterializePartitionCollector {
172 type Fruit = Vec<DataChunk>;
173
174 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
175 self.chunks.push(chunk.clone());
176 Ok(())
177 }
178
179 fn harvest(self) -> Vec<DataChunk> {
180 self.chunks
181 }
182}
183
/// Keeps at most `limit` rows in total across all partitions.
#[derive(Debug, Clone)]
pub struct LimitCollector {
    /// Maximum number of rows to retain.
    limit: usize,
}

impl LimitCollector {
    /// Creates a collector that retains at most `limit` rows.
    #[must_use]
    pub fn new(limit: usize) -> Self {
        LimitCollector { limit }
    }
}
200
201impl Collector for LimitCollector {
202 type Fruit = (Vec<DataChunk>, usize);
203 type PartitionCollector = LimitPartitionCollector;
204
205 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
206 LimitPartitionCollector {
207 chunks: Vec::new(),
208 limit: self.limit,
209 collected: 0,
210 }
211 }
212
213 fn merge(&self, fruits: Vec<(Vec<DataChunk>, usize)>) -> (Vec<DataChunk>, usize) {
214 let mut result = Vec::new();
215 let mut total = 0;
216
217 for (chunks, _) in fruits {
218 for chunk in chunks {
219 if total >= self.limit {
220 break;
221 }
222 let take = (self.limit - total).min(chunk.len());
223 if take < chunk.len() {
224 result.push(chunk.slice(0, take));
225 } else {
226 result.push(chunk);
227 }
228 total += take;
229 }
230 if total >= self.limit {
231 break;
232 }
233 }
234
235 (result, total)
236 }
237}
238
239pub struct LimitPartitionCollector {
241 chunks: Vec<DataChunk>,
242 limit: usize,
243 collected: usize,
244}
245
246impl PartitionCollector for LimitPartitionCollector {
247 type Fruit = (Vec<DataChunk>, usize);
248
249 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
250 if self.collected >= self.limit {
251 return Ok(());
252 }
253
254 let take = (self.limit - self.collected).min(chunk.len());
255 if take < chunk.len() {
256 self.chunks.push(chunk.slice(0, take));
257 } else {
258 self.chunks.push(chunk.clone());
259 }
260 self.collected += take;
261
262 Ok(())
263 }
264
265 fn harvest(self) -> (Vec<DataChunk>, usize) {
266 (self.chunks, self.collected)
267 }
268}
269
/// Computes count, sum, min and max over a single numeric column.
#[derive(Debug, Clone)]
pub struct StatsCollector {
    /// Index of the column to aggregate.
    column_idx: usize,
}

impl StatsCollector {
    /// Creates a collector that aggregates the column at `column_idx`.
    #[must_use]
    pub fn new(column_idx: usize) -> Self {
        StatsCollector { column_idx }
    }
}
283
/// Running summary statistics for a numeric column.
#[derive(Debug, Clone, Default)]
pub struct CollectorStats {
    /// Number of values recorded.
    pub count: u64,
    /// Sum of the recorded values.
    pub sum: f64,
    /// Smallest recorded value; `None` when nothing has been recorded.
    pub min: Option<f64>,
    /// Largest recorded value; `None` when nothing has been recorded.
    pub max: Option<f64>,
}

impl CollectorStats {
    /// Folds `other` into `self`.
    ///
    /// Counts and sums are added. Each extremum is combined with the matching
    /// `f64` method when both sides have one; otherwise whichever side is
    /// present is kept.
    pub fn merge(&mut self, other: CollectorStats) {
        /// Combines two optional extrema with `pick`, keeping the present one if only one exists.
        fn combine(lhs: Option<f64>, rhs: Option<f64>, pick: fn(f64, f64) -> f64) -> Option<f64> {
            match (lhs, rhs) {
                (Some(a), Some(b)) => Some(pick(a, b)),
                (only, None) | (None, only) => only,
            }
        }

        self.count += other.count;
        self.sum += other.sum;
        self.min = combine(self.min, other.min, f64::min);
        self.max = combine(self.max, other.max, f64::max);
    }

    /// Mean of the recorded values, or `None` when `count` is zero.
    #[must_use]
    pub fn avg(&self) -> Option<f64> {
        (self.count > 0).then(|| self.sum / self.count as f64)
    }
}
324
325impl Collector for StatsCollector {
326 type Fruit = CollectorStats;
327 type PartitionCollector = StatsPartitionCollector;
328
329 fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
330 StatsPartitionCollector {
331 column_idx: self.column_idx,
332 stats: CollectorStats::default(),
333 }
334 }
335
336 fn merge(&self, fruits: Vec<CollectorStats>) -> CollectorStats {
337 let mut result = CollectorStats::default();
338 for fruit in fruits {
339 result.merge(fruit);
340 }
341 result
342 }
343}
344
345pub struct StatsPartitionCollector {
347 column_idx: usize,
348 stats: CollectorStats,
349}
350
351impl PartitionCollector for StatsPartitionCollector {
352 type Fruit = CollectorStats;
353
354 fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
355 let column = chunk.column(self.column_idx).ok_or_else(|| {
356 OperatorError::ColumnNotFound(format!(
357 "column index {} out of bounds (width={})",
358 self.column_idx,
359 chunk.column_count()
360 ))
361 })?;
362
363 for i in 0..chunk.len() {
364 let val = if let Some(f) = column.get_float64(i) {
366 Some(f)
367 } else if let Some(i) = column.get_int64(i) {
368 Some(i as f64)
369 } else if let Some(value) = column.get_value(i) {
370 match value {
372 grafeo_common::types::Value::Int64(i) => Some(i as f64),
373 grafeo_common::types::Value::Float64(f) => Some(f),
374 _ => None,
375 }
376 } else {
377 None
378 };
379
380 if let Some(v) = val {
381 self.stats.count += 1;
382 self.stats.sum += v;
383 self.stats.min = Some(match self.stats.min {
384 Some(m) => m.min(v),
385 None => v,
386 });
387 self.stats.max = Some(match self.stats.max {
388 Some(m) => m.max(v),
389 None => v,
390 });
391 }
392 }
393
394 Ok(())
395 }
396
397 fn harvest(self) -> CollectorStats {
398 self.stats
399 }
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::ValueVector;
    use grafeo_common::types::Value;

    /// Builds a single-column chunk holding the integers `0..size`.
    fn make_test_chunk(size: usize) -> DataChunk {
        let values: Vec<Value> = (0..size).map(|i| Value::from(i as i64)).collect();
        let column = ValueVector::from_values(&values);
        DataChunk::new(vec![column])
    }

    #[test]
    fn test_count_collector() {
        let collector = CountCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let count1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let count2 = pc2.harvest();

        let total = collector.merge(vec![count1, count2]);
        assert_eq!(total, 22);
    }

    #[test]
    fn test_count_collector_no_partitions() {
        assert_eq!(CountCollector.merge(Vec::new()), 0);
    }

    #[test]
    fn test_materialize_collector() {
        let collector = MaterializeCollector;

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let chunks1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(7)).unwrap();
        let chunks2 = pc2.harvest();

        let result = collector.merge(vec![chunks1, chunks2]);
        assert_eq!(result.len(), 3);
        assert_eq!(result.iter().map(|c| c.len()).sum::<usize>(), 22);
    }

    #[test]
    fn test_limit_collector() {
        let collector = LimitCollector::new(12);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(10)).unwrap();
        pc.collect(&make_test_chunk(5)).unwrap();
        let result1 = pc.harvest();

        let mut pc2 = collector.for_partition(1);
        pc2.collect(&make_test_chunk(20)).unwrap();
        let result2 = pc2.harvest();

        let (chunks, total) = collector.merge(vec![result1, result2]);
        assert_eq!(total, 12);

        let actual_rows: usize = chunks.iter().map(|c| c.len()).sum();
        assert_eq!(actual_rows, 12);
    }

    #[test]
    fn test_limit_collector_zero_limit() {
        let collector = LimitCollector::new(0);

        let mut pc = collector.for_partition(0);
        pc.collect(&make_test_chunk(5)).unwrap();
        let (chunks, collected) = pc.harvest();
        assert!(chunks.is_empty());
        assert_eq!(collected, 0);

        let (merged, total) = collector.merge(vec![(chunks, collected)]);
        assert!(merged.is_empty());
        assert_eq!(total, 0);
    }

    #[test]
    fn test_stats_collector() {
        let collector = StatsCollector::new(0);

        let mut pc = collector.for_partition(0);

        let values: Vec<Value> = (0..10).map(|i| Value::from(i as i64)).collect();
        let column = ValueVector::from_values(&values);
        let chunk = DataChunk::new(vec![column]);

        pc.collect(&chunk).unwrap();
        let stats = pc.harvest();

        assert_eq!(stats.count, 10);
        assert!((stats.sum - 45.0).abs() < 0.001);
        assert!((stats.min.unwrap() - 0.0).abs() < 0.001);
        assert!((stats.max.unwrap() - 9.0).abs() < 0.001);
        assert!((stats.avg().unwrap() - 4.5).abs() < 0.001);
    }

    #[test]
    fn test_stats_collector_column_out_of_bounds() {
        // The test chunk has a single column, so index 1 is out of bounds.
        let collector = StatsCollector::new(1);
        let mut pc = collector.for_partition(0);
        assert!(pc.collect(&make_test_chunk(3)).is_err());
    }

    #[test]
    fn test_stats_collector_empty() {
        let collector = StatsCollector::new(0);
        let stats = collector.merge(vec![collector.for_partition(0).harvest()]);
        assert_eq!(stats.count, 0);
        assert!(stats.min.is_none());
        assert!(stats.max.is_none());
        assert!(stats.avg().is_none());
    }

    #[test]
    fn test_stats_merge() {
        let collector = StatsCollector::new(0);

        let mut pc1 = collector.for_partition(0);
        let values1: Vec<Value> = (0..5).map(|i| Value::from(i as i64)).collect();
        let chunk1 = DataChunk::new(vec![ValueVector::from_values(&values1)]);
        pc1.collect(&chunk1).unwrap();

        let mut pc2 = collector.for_partition(1);
        let values2: Vec<Value> = (5..10).map(|i| Value::from(i as i64)).collect();
        let chunk2 = DataChunk::new(vec![ValueVector::from_values(&values2)]);
        pc2.collect(&chunk2).unwrap();

        let stats = collector.merge(vec![pc1.harvest(), pc2.harvest()]);

        assert_eq!(stats.count, 10);
        assert!((stats.min.unwrap() - 0.0).abs() < 0.001);
        assert!((stats.max.unwrap() - 9.0).abs() < 0.001);
    }
}