use crate::record::Record;
/// Boxed, thread-safe error type returned by aggregators and status deaggregators.
pub type Error = Box<dyn std::error::Error + Sync + Send + 'static>;
/// Outcome of offering one input to an aggregator.
#[derive(Debug)]
pub enum TryPush<I, T> {
    /// The aggregator had no room for the input; the input is handed back unchanged.
    NoCapacity(I),
    /// The input was accepted; carries the tag that later identifies it during deaggregation.
    Aggregated(T),
}

impl<I, T> TryPush<I, T> {
    /// Returns the input carried by [`TryPush::NoCapacity`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is [`TryPush::Aggregated`]. Thanks to `#[track_caller]` the panic
    /// location reported is the caller's, mirroring `Option::unwrap`.
    #[track_caller]
    pub fn unwrap_input(self) -> I {
        match self {
            Self::NoCapacity(input) => input,
            Self::Aggregated(_) => panic!("Aggregated"),
        }
    }

    /// Returns the tag carried by [`TryPush::Aggregated`].
    ///
    /// # Panics
    ///
    /// Panics if `self` is [`TryPush::NoCapacity`]. Thanks to `#[track_caller]` the panic
    /// location reported is the caller's, mirroring `Option::unwrap`.
    #[track_caller]
    pub fn unwrap_tag(self) -> T {
        match self {
            Self::NoCapacity(_) => panic!("NoCapacity"),
            Self::Aggregated(tag) => tag,
        }
    }
}
/// Collects inputs into batches of [`Record`]s.
///
/// Each accepted input receives a tag. After a batch is flushed, the returned
/// [`StatusDeaggregator`] maps a batch-level status back to individual inputs using that tag.
pub trait Aggregator: Send + 'static {
/// Type of a single input offered via `try_push`.
type Input: Send;
/// Opaque per-input handle returned on successful aggregation.
type Tag: Send + std::fmt::Debug;
/// Deaggregator produced by `flush`; it must accept this aggregator's tags.
type StatusDeaggregator: StatusDeaggregator<Tag = Self::Tag>;
/// Offers `record` to the current batch.
///
/// Returns `TryPush::Aggregated(tag)` if the input was accepted, or
/// `TryPush::NoCapacity(record)` to hand the input back to the caller.
fn try_push(&mut self, record: Self::Input) -> Result<TryPush<Self::Input, Self::Tag>, Error>;
/// Takes the current batch's records together with a deaggregator for their statuses.
///
/// NOTE(review): whether the aggregator is reset afterwards is implementation-defined;
/// `RecordAggregator` starts a fresh, empty batch.
fn flush(&mut self) -> Result<(Vec<Record>, Self::StatusDeaggregator), Error>;
}
/// Maps a batch-level status back to the status of one aggregated input.
pub trait StatusDeaggregator: Send + Sync + std::fmt::Debug {
/// Per-input status produced by `deaggregate`.
type Status;
/// Tag handed out by the matching `Aggregator::try_push`.
type Tag: Send;
/// Extracts the status of the input identified by `tag` from the batch result `input`.
///
/// NOTE(review): `input` presumably holds one `i64` per record of the flushed batch
/// (`RecordAggregatorStatusDeaggregator` indexes it by tag); confirm against the caller.
fn deaggregate(&self, input: &[i64], tag: Self::Tag) -> Result<Self::Status, Error>;
}
/// Shorthand for the per-input status type produced by an [`Aggregator`]'s deaggregator.
///
/// Implemented automatically for every [`Aggregator`], so callers can write
/// `<A as AggregatorStatus>::Status` instead of spelling out the full projection.
pub trait AggregatorStatus {
    /// Status type produced by the aggregator's [`StatusDeaggregator`].
    type Status;
}

impl<A: Aggregator> AggregatorStatus for A {
    type Status = <A::StatusDeaggregator as StatusDeaggregator>::Status;
}
/// State of the batch currently being built by `RecordAggregator`.
///
/// The `Default` value is an empty batch; `flush` swaps it in via `std::mem::take`.
#[derive(Debug, Default)]
struct AggregatorState {
/// Sum of `Record::approximate_size` over all entries of `records`.
batch_size: usize,
/// Records accepted so far; a record's index here is the tag returned for it.
records: Vec<Record>,
}
/// [`Aggregator`] that batches [`Record`]s while their combined approximate size stays within a limit.
#[derive(Debug)]
pub struct RecordAggregator {
/// Inclusive upper bound on the sum of `Record::approximate_size` for a single batch.
max_batch_size: usize,
/// Batch currently being accumulated.
state: AggregatorState,
}
impl Aggregator for RecordAggregator {
    type Input = Record;
    type Tag = usize;
    type StatusDeaggregator = RecordAggregatorStatusDeaggregator;

    /// Appends `record` to the current batch if the batch stays within `max_batch_size`.
    ///
    /// On success, returns `TryPush::Aggregated(tag)`, where `tag` is the record's index in
    /// the batch. Otherwise returns the record via `TryPush::NoCapacity`. A record larger
    /// than `max_batch_size` is therefore never accepted, not even into an empty batch.
    fn try_push(&mut self, record: Self::Input) -> Result<TryPush<Self::Input, Self::Tag>, Error> {
        let record_size: usize = record.approximate_size();

        // `checked_add` prevents a huge size from wrapping around in release builds,
        // where overflow does not trap, and slipping past the capacity check.
        let new_batch_size = match self.state.batch_size.checked_add(record_size) {
            Some(size) if size <= self.max_batch_size => size,
            _ => return Ok(TryPush::NoCapacity(record)),
        };

        let tag = self.state.records.len();
        self.state.batch_size = new_batch_size;
        self.state.records.push(record);
        Ok(TryPush::Aggregated(tag))
    }

    /// Takes all records of the current batch and resets the aggregator to an empty batch.
    ///
    /// The returned deaggregator is stateless because tags are already batch indices.
    fn flush(&mut self) -> Result<(Vec<Record>, Self::StatusDeaggregator), Error> {
        let state = std::mem::take(&mut self.state);
        Ok((state.records, RecordAggregatorStatusDeaggregator::default()))
    }
}
impl RecordAggregator {
pub fn new(max_batch_size: usize) -> Self {
Self {
max_batch_size,
state: Default::default(),
}
}
}
/// Deaggregator for batches produced by `RecordAggregator`.
///
/// It is stateless: the tag of each record is its index within the batch, so
/// deaggregation is a lookup by that index.
#[derive(Debug, Default, Clone, Copy)]
pub struct RecordAggregatorStatusDeaggregator {}

impl StatusDeaggregator for RecordAggregatorStatusDeaggregator {
    type Status = i64;
    type Tag = usize;

    /// Returns `input[tag]`.
    ///
    /// # Errors
    ///
    /// Returns an error instead of panicking when `tag` is out of range for `input`,
    /// for example when the batch result has fewer entries than records were aggregated.
    fn deaggregate(&self, input: &[i64], tag: Self::Tag) -> Result<Self::Status, Error> {
        input.get(tag).copied().ok_or_else(|| {
            format!(
                "tag {} out of bounds for deaggregation input of length {}",
                tag,
                input.len()
            )
            .into()
        })
    }
}
#[cfg(test)]
mod tests {
    use chrono::{TimeZone, Utc};

    use super::*;

    #[test]
    fn test_record_aggregator() {
        let r1 = Record {
            key: Some(vec![0; 45]),
            value: Some(vec![0; 2]),
            headers: Default::default(),
            timestamp: Utc.timestamp_millis_opt(1337).unwrap(),
        };
        let r2 = Record {
            value: Some(vec![0; 34]),
            ..r1.clone()
        };

        // The scenario below uses a batch limit of `2 * size(r1)` and relies on:
        // - `r1` followed by `r2` NOT fitting into one batch (size(r1) < size(r2)),
        // - `r2` alone fitting into an empty batch (size(r2) < 2 * size(r1)).
        assert!(r1.approximate_size() < r2.approximate_size());
        assert!(r2.approximate_size() < r1.approximate_size() * 2);

        let mut aggregator = RecordAggregator::new(r1.approximate_size() * 2);
        let t1 = aggregator.try_push(r1.clone()).unwrap().unwrap_tag();
        let t2 = aggregator.try_push(r1.clone()).unwrap().unwrap_tag();
        aggregator.try_push(r1.clone()).unwrap().unwrap_input();
        aggregator.try_push(r1.clone()).unwrap().unwrap_input();
        let (records, deagg) = aggregator.flush().unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(deagg.deaggregate(&[10, 20], t1).unwrap(), 10);
        assert_eq!(deagg.deaggregate(&[10, 20], t2).unwrap(), 20);

        let t1 = aggregator.try_push(r1.clone()).unwrap().unwrap_tag();
        let (records, deagg) = aggregator.flush().unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(deagg.deaggregate(&[10], t1).unwrap(), 10);

        let t1 = aggregator.try_push(r1.clone()).unwrap().unwrap_tag();
        let t2 = aggregator.try_push(r1.clone()).unwrap().unwrap_tag();
        let (records, deagg) = aggregator.flush().unwrap();
        assert_eq!(records.len(), 2);
        assert_eq!(deagg.deaggregate(&[10, 20], t1).unwrap(), 10);
        assert_eq!(deagg.deaggregate(&[10, 20], t2).unwrap(), 20);

        // Flushing an empty aggregator yields an empty batch.
        let (records, _deagg) = aggregator.flush().unwrap();
        assert_eq!(records.len(), 0);

        aggregator.try_push(r1.clone()).unwrap().unwrap_tag();
        aggregator.try_push(r2.clone()).unwrap().unwrap_input();
        assert_eq!(aggregator.flush().unwrap().0.len(), 1);
        aggregator.try_push(r2.clone()).unwrap().unwrap_tag();

        // A record larger than the limit is rejected even by an empty aggregator.
        let mut aggregator = RecordAggregator::new(r1.approximate_size());
        aggregator.try_push(r2).unwrap().unwrap_input();
    }

    #[test]
    fn test_unwrap_input_ok() {
        assert_eq!(TryPush::<i8, i8>::NoCapacity(42).unwrap_input(), 42,);
    }

    #[test]
    #[should_panic(expected = "Aggregated")]
    fn test_unwrap_input_panic() {
        TryPush::<i8, i8>::Aggregated(42).unwrap_input();
    }

    #[test]
    fn test_unwrap_tag_ok() {
        assert_eq!(TryPush::<i8, i8>::Aggregated(42).unwrap_tag(), 42,);
    }

    #[test]
    #[should_panic(expected = "NoCapacity")]
    fn test_unwrap_tag_panic() {
        TryPush::<i8, i8>::NoCapacity(42).unwrap_tag();
    }
}