use std::sync::Arc;
use std::time::Duration;
use crate::exchange::Exchange;
/// User-supplied function that merges two exchanges into a single exchange.
///
/// Held behind an `Arc` so that cloning an [`AggregationStrategy`] only bumps a
/// reference count instead of copying the closure.
pub type AggregationFn = Arc<dyn Fn(Exchange, Exchange) -> Exchange + Sync + Send>;
/// How the exchanges that land in one correlation bucket are combined.
///
/// `AggregatorConfig::correlate_by` starts every builder with `CollectAll`;
/// use `AggregatorConfigBuilder::strategy` to choose `Custom` instead.
#[derive(Clone)]
pub enum AggregationStrategy {
/// Keep every exchange of the bucket.
/// NOTE(review): the actual collecting is done by the aggregator, which is
/// not in this file — confirm the output shape there.
CollectAll,
/// Combine exchanges with a user-supplied [`AggregationFn`].
/// Presumably invoked as `(accumulated, next) -> accumulated` — verify against
/// the aggregator implementation.
Custom(AggregationFn),
}
/// Shared predicate over a bucket's exchanges; returns `true` when the bucket
/// should be treated as complete.
///
/// This is the predicate counterpart of [`AggregationFn`]. Naming the type keeps
/// the enum readable without a `clippy::type_complexity` suppression.
pub type CompletionPredicate = Arc<dyn Fn(&[Exchange]) -> bool + Send + Sync>;

/// Rule that decides when a correlation bucket is complete.
#[derive(Clone)]
pub enum CompletionCondition {
    /// Size-based completion.
    /// Presumably the bucket completes once it holds this many exchanges; the
    /// check itself lives in the aggregator (not shown here) — confirm.
    Size(usize),
    /// Predicate-based completion: the predicate receives the bucket's exchanges.
    Predicate(CompletionPredicate),
}
/// Fully-built aggregator configuration.
///
/// Create one with `AggregatorConfig::correlate_by(..)`, which returns an
/// [`AggregatorConfigBuilder`], then call its `build` method.
#[derive(Clone)]
pub struct AggregatorConfig {
/// Header name given to `correlate_by`; presumably its value is the
/// correlation key used to group exchanges into buckets — confirm in the aggregator.
pub header_name: String,
/// When a bucket is considered complete (always set; `build` requires it).
pub completion: CompletionCondition,
/// How a bucket's exchanges are combined (defaults to `CollectAll`).
pub strategy: AggregationStrategy,
/// Optional limit on the number of buckets; `None` unless set via the builder.
/// NOTE(review): enforcement/eviction policy is outside this file — confirm.
pub max_buckets: Option<usize>,
/// Optional time-to-live for a bucket; `None` unless set via the builder.
/// NOTE(review): how expiry is applied is outside this file — confirm.
pub bucket_ttl: Option<Duration>,
}
impl AggregatorConfig {
    /// Begins building a configuration that correlates exchanges by `header`.
    ///
    /// The returned builder starts with the `CollectAll` strategy and no bucket
    /// limits. It has no completion condition yet; one must be supplied before
    /// building.
    pub fn correlate_by(header: impl Into<String>) -> AggregatorConfigBuilder {
        let header_name: String = header.into();
        AggregatorConfigBuilder {
            strategy: AggregationStrategy::CollectAll,
            completion: None,
            bucket_ttl: None,
            max_buckets: None,
            header_name,
        }
    }
}
/// Builder for [`AggregatorConfig`], obtained from `AggregatorConfig::correlate_by`.
///
/// A completion condition must be set (via `complete_when_size` or
/// `complete_when`) before calling `build`, which panics otherwise.
pub struct AggregatorConfigBuilder {
header_name: String,
// `None` until a completion condition is chosen; checked in `build`.
completion: Option<CompletionCondition>,
strategy: AggregationStrategy,
max_buckets: Option<usize>,
bucket_ttl: Option<Duration>,
}
impl AggregatorConfigBuilder {
    /// Sets a size-based completion condition (`CompletionCondition::Size(n)`).
    ///
    /// Replaces any completion condition set earlier. Presumably a bucket
    /// completes once it holds `n` exchanges; the check lives in the aggregator.
    #[must_use]
    pub fn complete_when_size(mut self, n: usize) -> Self {
        self.completion = Some(CompletionCondition::Size(n));
        self
    }

    /// Sets a predicate-based completion condition.
    ///
    /// `predicate` receives the bucket's exchanges as a slice and returns `true`
    /// when the bucket is complete. Replaces any completion condition set earlier.
    #[must_use]
    pub fn complete_when<F>(mut self, predicate: F) -> Self
    where
        F: Fn(&[Exchange]) -> bool + Send + Sync + 'static,
    {
        self.completion = Some(CompletionCondition::Predicate(Arc::new(predicate)));
        self
    }

    /// Overrides the aggregation strategy (the default is `CollectAll`).
    #[must_use]
    pub fn strategy(mut self, strategy: AggregationStrategy) -> Self {
        self.strategy = strategy;
        self
    }

    /// Sets the optional bucket limit stored in `AggregatorConfig::max_buckets`.
    #[must_use]
    pub fn max_buckets(mut self, max: usize) -> Self {
        self.max_buckets = Some(max);
        self
    }

    /// Sets the optional bucket time-to-live stored in `AggregatorConfig::bucket_ttl`.
    #[must_use]
    pub fn bucket_ttl(mut self, ttl: Duration) -> Self {
        self.bucket_ttl = Some(ttl);
        self
    }

    /// Finishes the builder without panicking.
    ///
    /// Returns `None` if no completion condition was set. This is the fallible
    /// counterpart of [`build`](Self::build).
    #[must_use]
    pub fn try_build(self) -> Option<AggregatorConfig> {
        let completion = self.completion?;
        Some(AggregatorConfig {
            header_name: self.header_name,
            completion,
            strategy: self.strategy,
            max_buckets: self.max_buckets,
            bucket_ttl: self.bucket_ttl,
        })
    }

    /// Finishes the builder.
    ///
    /// # Panics
    ///
    /// Panics with `"completion condition required"` if neither
    /// [`complete_when_size`](Self::complete_when_size) nor
    /// [`complete_when`](Self::complete_when) was called. A missing condition is
    /// a programming error in how the builder is used. Call
    /// [`try_build`](Self::try_build) to handle it without panicking.
    #[must_use]
    pub fn build(self) -> AggregatorConfig {
        self.try_build().expect("completion condition required")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aggregator_config_complete_when_size() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(3)
            .build();
        assert_eq!(config.header_name, "orderId");
        // `matches!` only evaluates to a bool; it has to be asserted to check anything.
        assert!(matches!(config.completion, CompletionCondition::Size(3)));
        assert!(matches!(config.strategy, AggregationStrategy::CollectAll));
        assert!(config.max_buckets.is_none());
        assert!(config.bucket_ttl.is_none());
    }

    #[test]
    fn test_aggregator_config_complete_when_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| bucket.len() >= 2)
            .build();
        match &config.completion {
            CompletionCondition::Predicate(predicate) => {
                // An empty bucket needs no Exchange values and must not satisfy `len() >= 2`.
                let empty: &[Exchange] = &[];
                assert!(!predicate(empty));
            }
            CompletionCondition::Size(_) => panic!("expected a predicate completion condition"),
        }
    }

    #[test]
    fn test_aggregator_config_custom_strategy() {
        let f: AggregationFn = Arc::new(|acc, _next| acc);
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(1)
            .strategy(AggregationStrategy::Custom(f))
            .build();
        assert!(matches!(config.strategy, AggregationStrategy::Custom(_)));
    }

    #[test]
    fn test_aggregator_config_limits() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(2)
            .max_buckets(10)
            .bucket_ttl(Duration::from_secs(5))
            .build();
        assert_eq!(config.max_buckets, Some(10));
        assert_eq!(config.bucket_ttl, Some(Duration::from_secs(5)));
    }

    #[test]
    #[should_panic(expected = "completion condition required")]
    fn test_aggregator_config_missing_completion_panics() {
        AggregatorConfig::correlate_by("key").build();
    }
}