use super::chunk::DataChunk;
use super::operators::OperatorError;
/// Factory and combiner for per-partition result collection.
///
/// An implementor hands out one [`PartitionCollector`] per partition through
/// `for_partition` and later combines the values those collectors produce
/// through `merge`. Implementors must be `Sync`.
///
/// NOTE(review): the `Sync` bound here and the `Send` bounds on the associated
/// types presumably exist so partitions can be processed on separate threads;
/// the driver is not visible in this file — confirm.
pub trait Collector: Sync {
/// Result type, used both for a single partition and for the merged total.
type Fruit: Send;
/// Per-partition collector type; its `Fruit` must be the same as `Self::Fruit`.
type PartitionCollector: PartitionCollector<Fruit = Self::Fruit>;
/// Creates the collector that will receive the chunks of partition `partition_id`.
fn for_partition(&self, partition_id: usize) -> Self::PartitionCollector;
/// Combines the per-partition fruits into a single final fruit.
fn merge(&self, fruits: Vec<Self::Fruit>) -> Self::Fruit;
}
/// Accumulates the chunks of a single partition and yields a result at the end.
///
/// Implementors must be `Send`.
pub trait PartitionCollector: Send {
/// The value produced by `harvest`; must be `Send`.
type Fruit: Send;
/// Feeds one chunk into the collector.
///
/// # Errors
///
/// Returns an [`OperatorError`] when the implementation cannot process `chunk`.
fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError>;
/// Consumes the collector and returns what it accumulated.
fn harvest(self) -> Self::Fruit;
}
/// Collector whose result is a single `u64`: the sum of the lengths of every
/// chunk seen across all partitions.
#[derive(Debug, Clone, Copy, Default)]
pub struct CountCollector;

impl Collector for CountCollector {
    type Fruit = u64;
    type PartitionCollector = CountPartitionCollector;

    /// Returns a fresh partition collector whose counter starts at zero.
    /// The partition id is not used.
    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
        let count = 0;
        CountPartitionCollector { count }
    }

    /// Adds up the per-partition counts.
    fn merge(&self, fruits: Vec<u64>) -> u64 {
        fruits.iter().sum::<u64>()
    }
}
/// Per-partition state for [`CountCollector`]: a running total of chunk lengths.
pub struct CountPartitionCollector {
    count: u64,
}

impl PartitionCollector for CountPartitionCollector {
    type Fruit = u64;

    /// Adds `chunk.len()` to the running total. Never fails.
    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
        let added = chunk.len() as u64;
        self.count += added;
        Ok(())
    }

    /// Returns the accumulated total.
    fn harvest(self) -> u64 {
        let Self { count } = self;
        count
    }
}
/// Collector that keeps every chunk and returns them all as one `Vec`.
#[derive(Debug, Clone, Default)]
pub struct MaterializeCollector;

impl Collector for MaterializeCollector {
    type Fruit = Vec<DataChunk>;
    type PartitionCollector = MaterializePartitionCollector;

    /// Returns a partition collector with an empty chunk list.
    /// The partition id is not used.
    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
        MaterializePartitionCollector { chunks: Vec::new() }
    }

    /// Concatenates the per-partition chunk lists, keeping the order of
    /// `fruits` and the order of the chunks inside each fruit.
    fn merge(&self, fruits: Vec<Vec<DataChunk>>) -> Vec<DataChunk> {
        // Size the output once so the concatenation never reallocates.
        let capacity = fruits.iter().map(Vec::len).sum::<usize>();
        let mut merged = Vec::with_capacity(capacity);
        merged.extend(fruits.into_iter().flatten());
        merged
    }
}
/// Per-partition state for [`MaterializeCollector`]: the chunks seen so far.
pub struct MaterializePartitionCollector {
    chunks: Vec<DataChunk>,
}

impl PartitionCollector for MaterializePartitionCollector {
    type Fruit = Vec<DataChunk>;

    /// Stores a clone of `chunk`. Never fails.
    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
        let owned = chunk.clone();
        self.chunks.push(owned);
        Ok(())
    }

    /// Returns the stored chunks in the order they were collected.
    fn harvest(self) -> Vec<DataChunk> {
        let Self { chunks } = self;
        chunks
    }
}
/// Collector configuration carrying a `limit` that caps how much is kept.
#[derive(Debug, Clone)]
pub struct LimitCollector {
    /// Upper bound copied into every partition collector and applied again on merge.
    limit: usize,
}

impl LimitCollector {
    /// Creates a collector with the given `limit`.
    #[must_use]
    pub fn new(limit: usize) -> Self {
        LimitCollector { limit }
    }
}
impl Collector for LimitCollector {
    type Fruit = (Vec<DataChunk>, usize);
    type PartitionCollector = LimitPartitionCollector;

    /// Returns an empty partition collector that enforces the same `limit`.
    /// The partition id is not used.
    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
        let limit = self.limit;
        LimitPartitionCollector {
            chunks: Vec::new(),
            limit,
            collected: 0,
        }
    }

    /// Walks the partitions' chunks in the order given and keeps them until
    /// `limit` is reached, slicing the chunk that crosses the limit.
    ///
    /// Returns the kept chunks together with the sum of their lengths. The
    /// counts reported by the partitions are ignored; the total is recomputed.
    fn merge(&self, fruits: Vec<(Vec<DataChunk>, usize)>) -> (Vec<DataChunk>, usize) {
        let mut kept = Vec::new();
        let mut total = 0;
        let every_chunk = fruits.into_iter().flat_map(|(chunks, _)| chunks);
        for chunk in every_chunk {
            if total >= self.limit {
                break;
            }
            let take = (self.limit - total).min(chunk.len());
            // Only the chunk that crosses the limit needs to be cut down.
            let piece = if take < chunk.len() {
                chunk.slice(0, take)
            } else {
                chunk
            };
            kept.push(piece);
            total += take;
        }
        (kept, total)
    }
}
/// Per-partition state for [`LimitCollector`].
pub struct LimitPartitionCollector {
    /// Chunks (or chunk prefixes) kept so far.
    chunks: Vec<DataChunk>,
    /// Maximum total length this partition may keep.
    limit: usize,
    /// Total length of everything in `chunks`.
    collected: usize,
}

impl PartitionCollector for LimitPartitionCollector {
    type Fruit = (Vec<DataChunk>, usize);

    /// Keeps as much of `chunk` as still fits under the limit.
    ///
    /// Once the limit is reached, further chunks are ignored. A chunk that
    /// fits entirely is cloned; one that crosses the limit is sliced down to
    /// the remaining capacity. Never fails.
    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
        let remaining = self.limit.saturating_sub(self.collected);
        if remaining == 0 {
            return Ok(());
        }
        let available = chunk.len();
        let take = remaining.min(available);
        let piece = if take < available {
            chunk.slice(0, take)
        } else {
            chunk.clone()
        };
        self.chunks.push(piece);
        self.collected += take;
        Ok(())
    }

    /// Returns the kept chunks and their total length.
    fn harvest(self) -> (Vec<DataChunk>, usize) {
        let Self {
            chunks, collected, ..
        } = self;
        (chunks, collected)
    }
}
/// Collector configuration naming the column whose values are summarized.
#[derive(Debug, Clone)]
pub struct StatsCollector {
    /// Index of the column read from every chunk.
    column_idx: usize,
}

impl StatsCollector {
    /// Creates a collector that summarizes the column at `column_idx`.
    #[must_use]
    pub fn new(column_idx: usize) -> Self {
        StatsCollector { column_idx }
    }
}
/// Running summary of a set of numeric values.
#[derive(Debug, Clone, Default)]
pub struct CollectorStats {
    /// Number of values summarized.
    pub count: u64,
    /// Sum of the values.
    pub sum: f64,
    /// Smallest value, or `None` when nothing has been recorded.
    pub min: Option<f64>,
    /// Largest value, or `None` when nothing has been recorded.
    pub max: Option<f64>,
}

impl CollectorStats {
    /// Combines two optional bounds: applies `pick` when both are present,
    /// otherwise returns whichever one exists (or `None`).
    fn combine_bound(
        mine: Option<f64>,
        theirs: Option<f64>,
        pick: fn(f64, f64) -> f64,
    ) -> Option<f64> {
        match (mine, theirs) {
            (Some(a), Some(b)) => Some(pick(a, b)),
            (a, b) => a.or(b),
        }
    }

    /// Folds `other` into `self`: counts and sums are added, and the bounds
    /// are widened to cover both summaries.
    pub fn merge(&mut self, other: CollectorStats) {
        let CollectorStats {
            count,
            sum,
            min,
            max,
        } = other;
        self.count += count;
        self.sum += sum;
        self.min = Self::combine_bound(self.min, min, f64::min);
        self.max = Self::combine_bound(self.max, max, f64::max);
    }

    /// Returns `sum / count`, or `None` when `count` is zero.
    #[must_use]
    pub fn avg(&self) -> Option<f64> {
        match self.count {
            0 => None,
            n => Some(self.sum / n as f64),
        }
    }
}
impl Collector for StatsCollector {
    type Fruit = CollectorStats;
    type PartitionCollector = StatsPartitionCollector;

    /// Returns a partition collector for the configured column with empty
    /// statistics. The partition id is not used.
    fn for_partition(&self, _partition_id: usize) -> Self::PartitionCollector {
        let column_idx = self.column_idx;
        StatsPartitionCollector {
            column_idx,
            stats: CollectorStats::default(),
        }
    }

    /// Merges every partition's statistics into one summary, starting from
    /// an empty `CollectorStats`.
    fn merge(&self, fruits: Vec<CollectorStats>) -> CollectorStats {
        fruits
            .into_iter()
            .fold(CollectorStats::default(), |mut combined, part| {
                combined.merge(part);
                combined
            })
    }
}
/// Per-partition state for [`StatsCollector`].
pub struct StatsPartitionCollector {
    /// Index of the column read from every chunk.
    column_idx: usize,
    /// Statistics accumulated so far.
    stats: CollectorStats,
}

impl PartitionCollector for StatsPartitionCollector {
    type Fruit = CollectorStats;

    /// Folds the numeric values of the configured column into the statistics.
    ///
    /// Each position is read as a float first, then as an integer, then as a
    /// generic value holding `Int64` or `Float64`. Integers are converted with
    /// `as f64`. Positions that yield none of these are skipped.
    ///
    /// # Errors
    ///
    /// Returns `OperatorError::ColumnNotFound` when the chunk has no column at
    /// the configured index.
    fn collect(&mut self, chunk: &DataChunk) -> Result<(), OperatorError> {
        let column = match chunk.column(self.column_idx) {
            Some(found) => found,
            None => {
                return Err(OperatorError::ColumnNotFound(format!(
                    "column index {} out of bounds (width={})",
                    self.column_idx,
                    chunk.column_count()
                )));
            }
        };

        for row in 0..chunk.len() {
            // Try the typed accessors first and fall back to the generic value.
            let numeric = column
                .get_float64(row)
                .or_else(|| column.get_int64(row).map(|n| n as f64))
                .or_else(|| {
                    column.get_value(row).and_then(|value| match value {
                        grafeo_common::types::Value::Int64(n) => Some(n as f64),
                        grafeo_common::types::Value::Float64(f) => Some(f),
                        _ => None,
                    })
                });

            if let Some(v) = numeric {
                let stats = &mut self.stats;
                stats.count += 1;
                stats.sum += v;
                stats.min = Some(stats.min.map_or(v, |lowest| lowest.min(v)));
                stats.max = Some(stats.max.map_or(v, |highest| highest.max(v)));
            }
        }
        Ok(())
    }

    /// Returns the accumulated statistics.
    fn harvest(self) -> CollectorStats {
        let Self { stats, .. } = self;
        stats
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::execution::ValueVector;
    use grafeo_common::types::Value;

    /// Tolerance used when comparing floating-point statistics.
    const EPSILON: f64 = 0.001;

    /// Builds a one-column chunk holding the integers of `range`.
    fn int_chunk(range: std::ops::Range<i64>) -> DataChunk {
        let values: Vec<Value> = range.map(Value::from).collect();
        let column = ValueVector::from_values(&values);
        DataChunk::new(vec![column])
    }

    /// Builds a one-column chunk holding `0..size`.
    fn make_test_chunk(size: usize) -> DataChunk {
        int_chunk(0..size as i64)
    }

    /// Feeds `chunks` through a fresh partition collector and harvests it.
    fn run_partition<C: Collector>(
        collector: &C,
        partition_id: usize,
        chunks: &[DataChunk],
    ) -> C::Fruit {
        let mut pc = collector.for_partition(partition_id);
        for chunk in chunks {
            pc.collect(chunk).unwrap();
        }
        pc.harvest()
    }

    /// Sums the lengths of `chunks`.
    fn total_len(chunks: &[DataChunk]) -> usize {
        chunks.iter().map(|c| c.len()).sum::<usize>()
    }

    /// Asserts that `actual` is within `EPSILON` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPSILON);
    }

    #[test]
    fn test_count_collector() {
        let collector = CountCollector;
        let count1 = run_partition(&collector, 0, &[make_test_chunk(10), make_test_chunk(5)]);
        let count2 = run_partition(&collector, 1, &[make_test_chunk(7)]);

        let total = collector.merge(vec![count1, count2]);
        assert_eq!(total, 22);
    }

    #[test]
    fn test_materialize_collector() {
        let collector = MaterializeCollector;
        let chunks1 = run_partition(&collector, 0, &[make_test_chunk(10), make_test_chunk(5)]);
        let chunks2 = run_partition(&collector, 1, &[make_test_chunk(7)]);

        let result = collector.merge(vec![chunks1, chunks2]);
        assert_eq!(result.len(), 3);
        assert_eq!(total_len(&result), 22);
    }

    #[test]
    fn test_limit_collector() {
        let collector = LimitCollector::new(12);
        // The second chunk of partition 0 crosses the limit and is cut down.
        let result1 = run_partition(&collector, 0, &[make_test_chunk(10), make_test_chunk(5)]);
        let result2 = run_partition(&collector, 1, &[make_test_chunk(20)]);

        let (chunks, total) = collector.merge(vec![result1, result2]);
        assert_eq!(total, 12);
        let actual_rows = total_len(&chunks);
        assert_eq!(actual_rows, 12);
    }

    #[test]
    fn test_stats_collector() {
        let collector = StatsCollector::new(0);
        let stats = run_partition(&collector, 0, &[int_chunk(0..10)]);

        assert_eq!(stats.count, 10);
        assert_close(stats.sum, 45.0);
        assert_close(stats.min.unwrap(), 0.0);
        assert_close(stats.max.unwrap(), 9.0);
        assert_close(stats.avg().unwrap(), 4.5);
    }

    #[test]
    fn test_stats_merge() {
        let collector = StatsCollector::new(0);
        let low = run_partition(&collector, 0, &[int_chunk(0..5)]);
        let high = run_partition(&collector, 1, &[int_chunk(5..10)]);

        let stats = collector.merge(vec![low, high]);
        assert_eq!(stats.count, 10);
        assert_close(stats.min.unwrap(), 0.0);
        assert_close(stats.max.unwrap(), 9.0);
    }
}