use std::collections::HashMap;
use std::hash::Hash;
use std::convert::Infallible;
use crate::distribution::{Distribution, DistributionError};
/// An accumulator that is fed `Input` values one at a time and then turned
/// into a final result.
pub trait PartialAggregate<Input> {
    /// The result produced once accumulation is finished.
    type Output;

    /// Feeds one more `item` into the accumulator.
    fn add(&mut self, item: Input);

    /// Consumes the accumulator and returns the aggregated result.
    fn finish(self) -> Self::Output;
}
/// A value that can absorb another value of the same type in place.
pub trait Mergeable {
    /// Error reported when a merge is rejected.
    type Error;

    /// Folds `other` into `self`, consuming `other`.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when the implementation rejects the merge.
    fn merge_from(&mut self, other: Self) -> Result<(), Self::Error>;
}
/// Lets a [`Distribution`] absorb another one via its inherent `merge` method.
///
/// Any [`DistributionError`] reported by `merge` is passed straight through.
impl<T> Mergeable for Distribution<T>
where
    T: Clone,
{
    type Error = DistributionError;

    fn merge_from(&mut self, incoming: Self) -> Result<(), Self::Error> {
        // `merge` borrows its argument, so the owned `incoming` is dropped on return.
        self.merge(&incoming)
    }
}
/// Accumulator holding a running total of the inputs added to it.
///
/// `Sum::default()` starts the total at `T::default()`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct Sum<T> {
    /// Total accumulated so far.
    value: T,
}
/// Sums `u64` inputs into a `u64` total.
impl PartialAggregate<u64> for Sum<u64> {
    type Output = u64;

    /// Adds `input` to the running total.
    ///
    /// Uses saturating addition, so a total that would overflow is clamped to
    /// `u64::MAX` in every build profile. A plain `+=` would panic in debug
    /// builds and silently wrap to a small, wrong total in release builds.
    fn add(&mut self, input: u64) {
        self.value = self.value.saturating_add(input);
    }

    /// Returns the accumulated total.
    fn finish(self) -> Self::Output {
        self.value
    }
}
/// Groups values by key, merging values that share a key via [`Mergeable`].
#[derive(Debug, Clone)]
pub struct CollectionAggregator<K, V> {
    /// One (merged) value per distinct key.
    values: HashMap<K, V>,
}

// Written by hand instead of derived: `#[derive(Default)]` would require
// `K: Default` and `V: Default`, even though an empty map needs neither.
impl<K, V> Default for CollectionAggregator<K, V> {
    fn default() -> Self {
        Self {
            values: HashMap::new(),
        }
    }
}
impl<K, V> CollectionAggregator<K, V>
where
    K: Eq + Hash,
    V: Mergeable,
{
    /// Creates an aggregator that holds no keys.
    pub fn new() -> Self {
        Self {
            values: HashMap::new(),
        }
    }

    /// Records `value` under `key`.
    ///
    /// If `key` has not been seen, `value` is stored as-is. Otherwise it is
    /// merged into the value already held for `key` with
    /// [`Mergeable::merge_from`]. The existing key is kept and the passed-in
    /// `key` is dropped.
    ///
    /// # Errors
    ///
    /// Returns the error from `merge_from` if the merge fails. In that case
    /// `value` has been consumed, and the stored value is whatever `merge_from`
    /// left behind.
    pub fn add(&mut self, key: K, value: V) -> Result<(), V::Error> {
        use std::collections::hash_map::Entry;

        // `entry` hashes `key` once for both the hit and the miss path. The
        // previous `get_mut` + `insert` pair hashed it twice for new keys.
        match self.values.entry(key) {
            Entry::Occupied(mut slot) => slot.get_mut().merge_from(value),
            Entry::Vacant(slot) => {
                slot.insert(value);
                Ok(())
            }
        }
    }

    /// Consumes the aggregator and returns the merged value for each key.
    pub fn into_inner(self) -> HashMap<K, V> {
        self.values
    }
}
/// Merges two `u64` values by adding them.
///
/// The sum saturates at `u64::MAX` because `Error` is [`Infallible`], so
/// overflow cannot be reported. A plain `+=` would instead panic in debug
/// builds and silently wrap around in release builds.
impl Mergeable for u64 {
    type Error = Infallible;

    fn merge_from(&mut self, other: Self) -> Result<(), Self::Error> {
        *self = self.saturating_add(other);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::distribution::{Bucket, Distribution};

    #[test]
    fn collection_aggregator_merges_values_by_key() {
        let left = Distribution::<()>::new(vec![Bucket {
            range: 0.0..10.0,
            count: 1,
        }]);
        let right = Distribution::<()>::new(vec![Bucket {
            range: 0.0..10.0,
            count: 2,
        }]);
        let mut aggregator = CollectionAggregator::new();
        aggregator.add("rpc", left).unwrap();
        aggregator.add("rpc", right).unwrap();
        assert_eq!(aggregator.into_inner()["rpc"].total_count(), 3);
    }

    /// Values under different keys must not be merged with each other.
    #[test]
    fn collection_aggregator_keeps_distinct_keys_separate() {
        let mut aggregator = CollectionAggregator::new();
        aggregator.add("a", 1_u64).unwrap();
        aggregator.add("b", 10_u64).unwrap();
        aggregator.add("a", 2_u64).unwrap();
        let values = aggregator.into_inner();
        assert_eq!(values.len(), 2);
        assert_eq!(values["a"], 3);
        assert_eq!(values["b"], 10);
    }

    #[test]
    fn sum_accumulates_inputs() {
        let mut sum = Sum::<u64>::default();
        sum.add(5_u64);
        sum.add(7_u64);
        assert_eq!(sum.finish(), 12);
    }
}