use std::sync::Arc;
use std::time::Duration;
use crate::exchange::Exchange;
/// Shared, thread-safe closure used by [`AggregationStrategy::Custom`] to merge two exchanges
/// into one.
///
/// The argument order is presumably `(accumulated, next)` — confirm against the aggregator
/// that invokes it. The explicit `'static` bound spells out the default object lifetime that
/// `Arc<dyn ...>` already has in a type alias.
pub type AggregationFn = Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync + 'static>;
/// User-supplied correlation function: maps an exchange to its correlation key.
///
/// Returning `None` presumably means the exchange carries no usable key — confirm in the
/// aggregator. This is a named alias (like [`AggregationFn`]) so the enum variant below no
/// longer needs an unexplained `clippy::type_complexity` allowance.
pub type CorrelationFn = Arc<dyn Fn(&Exchange) -> Option<String> + Send + Sync>;

/// How the correlation key that selects an exchange's aggregation bucket is obtained.
pub enum CorrelationStrategy {
    /// Correlate on the header with this name.
    HeaderName(String),
    /// Correlate on the result of expression `expr` written in `language`.
    ///
    /// Evaluation is not implemented in this module; which languages are supported is
    /// decided elsewhere.
    Expression { expr: String, language: String },
    /// Correlate with an arbitrary function of the exchange.
    Fn(CorrelationFn),
}
impl Clone for CorrelationStrategy {
fn clone(&self) -> Self {
match self {
CorrelationStrategy::HeaderName(h) => CorrelationStrategy::HeaderName(h.clone()),
CorrelationStrategy::Expression { expr, language } => CorrelationStrategy::Expression {
expr: expr.clone(),
language: language.clone(),
},
CorrelationStrategy::Fn(f) => CorrelationStrategy::Fn(Arc::clone(f)),
}
}
}
impl std::fmt::Debug for CorrelationStrategy {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CorrelationStrategy::HeaderName(h) => f.debug_tuple("HeaderName").field(h).finish(),
CorrelationStrategy::Expression { expr, language } => f
.debug_struct("Expression")
.field("expr", expr)
.field("language", language)
.finish(),
CorrelationStrategy::Fn(_) => f.write_str("Fn(..)"),
}
}
}
/// How the exchanges gathered in one bucket are combined.
#[derive(Clone)]
pub enum AggregationStrategy {
    /// Presumably keeps every exchange of the bucket as-is; this is the builder's default.
    CollectAll,
    /// Combines the bucket's exchanges with a user-supplied [`AggregationFn`].
    Custom(AggregationFn),
}

// Hand-written because `AggregationFn` is a trait object with no `Debug`; the closure is shown
// as an opaque `Custom(..)`, matching how `CorrelationStrategy::Fn` is rendered.
impl std::fmt::Debug for AggregationStrategy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AggregationStrategy::CollectAll => f.write_str("CollectAll"),
            AggregationStrategy::Custom(_) => f.write_str("Custom(..)"),
        }
    }
}
/// Shared, thread-safe predicate evaluated against the exchanges of a bucket.
///
/// This is a named alias (like [`AggregationFn`]) so [`CompletionCondition::Predicate`] no
/// longer needs an unexplained `clippy::type_complexity` allowance.
pub type CompletionPredicate = Arc<dyn Fn(&[Exchange]) -> bool + Send + Sync>;

/// One rule that can mark an aggregation bucket as complete.
#[derive(Clone)]
pub enum CompletionCondition {
    /// Complete once the bucket reaches this many exchanges.
    ///
    /// Whether the check is `==` or `>=` is decided by the aggregator — confirm there.
    Size(usize),
    /// Complete when the predicate returns `true` for the bucket's exchanges (presumably
    /// those collected so far).
    Predicate(CompletionPredicate),
    /// Complete after this duration.
    ///
    /// The reference point (bucket creation vs. last arrival) is decided by the aggregator —
    /// confirm there.
    Timeout(Duration),
}

// Hand-written because the predicate is a trait object with no `Debug`; it is shown as the
// opaque `Predicate(..)`. The other variants format exactly as a derive would.
impl std::fmt::Debug for CompletionCondition {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            CompletionCondition::Size(n) => f.debug_tuple("Size").field(n).finish(),
            CompletionCondition::Predicate(_) => f.write_str("Predicate(..)"),
            CompletionCondition::Timeout(d) => f.debug_tuple("Timeout").field(d).finish(),
        }
    }
}
/// The completion rule(s) configured for an aggregator.
///
/// Each `complete_*` setter on the builder stores exactly one `CompletionMode`.
#[derive(Clone)]
pub enum CompletionMode {
    /// A single completion condition.
    Single(CompletionCondition),
    /// Several conditions. Presumably the bucket completes as soon as any one of them holds —
    /// the builder uses this for "size or timeout".
    Any(Vec<CompletionCondition>),
}
/// Why an aggregation bucket was completed.
///
/// This is a fieldless enum, so it is `Copy` and `Hash`: it can be passed by value and used as
/// a map/set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompletionReason {
    /// Presumably: the bucket reached its configured size.
    Size,
    /// Presumably: the completion predicate returned `true`.
    Predicate,
    /// Presumably: the completion timeout elapsed.
    Timeout,
    /// Presumably: completion was forced because the aggregator stopped.
    Stop,
}

impl CompletionReason {
    /// Returns the fixed lowercase label for this reason: `"size"`, `"predicate"`,
    /// `"timeout"` or `"stop"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            CompletionReason::Size => "size",
            CompletionReason::Predicate => "predicate",
            CompletionReason::Timeout => "timeout",
            CompletionReason::Stop => "stop",
        }
    }
}

/// Formats the reason as its [`CompletionReason::as_str`] label.
///
/// `pad` is used so width and alignment flags (e.g. `{:>8}`) are honored.
impl std::fmt::Display for CompletionReason {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.pad(self.as_str())
    }
}
/// Complete aggregator configuration, normally produced by [`AggregatorConfigBuilder::build`].
#[derive(Clone)]
pub struct AggregatorConfig {
    /// Header name passed to `correlate_by`. With header-based correlation it is the same
    /// name that `correlation` holds.
    pub header_name: String,
    /// The rule(s) that decide when a bucket is complete.
    pub completion: CompletionMode,
    /// How each exchange's correlation key is obtained.
    pub correlation: CorrelationStrategy,
    /// How a bucket's exchanges are combined (the builder defaults to `CollectAll`).
    pub strategy: AggregationStrategy,
    /// Optional limit on the number of buckets (unset by the builder by default).
    ///
    /// The enforcement policy lives in the aggregator — confirm there.
    pub max_buckets: Option<usize>,
    /// Optional bucket time-to-live (unset by the builder by default).
    ///
    /// Expiry handling lives in the aggregator — confirm there.
    pub bucket_ttl: Option<Duration>,
    /// Presumably: complete pending buckets when the aggregator stops. The builder defaults
    /// this to `false`.
    pub force_completion_on_stop: bool,
    /// Presumably: drop, rather than emit, buckets that complete by timeout. The builder
    /// defaults this to `false`.
    pub discard_on_timeout: bool,
}
impl AggregatorConfig {
    /// Starts a builder that correlates exchanges by the header named `header`.
    ///
    /// The builder starts with these defaults:
    /// - `header_name` and a [`CorrelationStrategy::HeaderName`] strategy, both set to `header`;
    /// - [`AggregationStrategy::CollectAll`];
    /// - no bucket limit and no bucket TTL;
    /// - both stop/timeout flags off;
    /// - no completion condition, which must be set before building.
    pub fn correlate_by(header: impl Into<String>) -> AggregatorConfigBuilder {
        let name: String = header.into();
        // Struct-literal fields are evaluated in the order written, so the clone happens
        // before `name` is moved into `header_name`.
        AggregatorConfigBuilder {
            correlation: CorrelationStrategy::HeaderName(name.clone()),
            header_name: name,
            strategy: AggregationStrategy::CollectAll,
            completion: None,
            max_buckets: None,
            bucket_ttl: None,
            force_completion_on_stop: false,
            discard_on_timeout: false,
        }
    }
}
/// Fluent builder for [`AggregatorConfig`]; obtain one from [`AggregatorConfig::correlate_by`].
///
/// Every setter consumes the builder and returns it. A chain that never reaches `build()`
/// therefore silently discards its configuration, which is why the type is `#[must_use]`.
#[must_use = "builders do nothing until `build()` is called"]
pub struct AggregatorConfigBuilder {
    /// Header name recorded in the final config.
    header_name: String,
    /// Completion mode; `None` until a `complete_*` setter runs (required by `build`).
    completion: Option<CompletionMode>,
    /// Correlation strategy for the final config.
    correlation: CorrelationStrategy,
    /// Aggregation strategy for the final config.
    strategy: AggregationStrategy,
    /// Optional bucket limit.
    max_buckets: Option<usize>,
    /// Optional bucket time-to-live.
    bucket_ttl: Option<Duration>,
    /// Copied verbatim into `AggregatorConfig::force_completion_on_stop`.
    force_completion_on_stop: bool,
    /// Copied verbatim into `AggregatorConfig::discard_on_timeout`.
    discard_on_timeout: bool,
}
impl AggregatorConfigBuilder {
pub fn complete_when_size(mut self, n: usize) -> Self {
self.completion = Some(CompletionMode::Single(CompletionCondition::Size(n)));
self
}
pub fn complete_when<F>(mut self, predicate: F) -> Self
where
F: Fn(&[Exchange]) -> bool + Send + Sync + 'static,
{
self.completion = Some(CompletionMode::Single(CompletionCondition::Predicate(
Arc::new(predicate),
)));
self
}
pub fn complete_on_timeout(mut self, duration: Duration) -> Self {
self.completion = Some(CompletionMode::Single(CompletionCondition::Timeout(
duration,
)));
self
}
pub fn complete_on_size_or_timeout(mut self, size: usize, timeout: Duration) -> Self {
self.completion = Some(CompletionMode::Any(vec![
CompletionCondition::Size(size),
CompletionCondition::Timeout(timeout),
]));
self
}
pub fn force_completion_on_stop(mut self, enabled: bool) -> Self {
self.force_completion_on_stop = enabled;
self
}
pub fn discard_on_timeout(mut self, enabled: bool) -> Self {
self.discard_on_timeout = enabled;
self
}
pub fn correlate_by(mut self, header: impl Into<String>) -> Self {
let header = header.into();
self.header_name = header.clone();
self.correlation = CorrelationStrategy::HeaderName(header);
self
}
pub fn strategy(mut self, strategy: AggregationStrategy) -> Self {
self.strategy = strategy;
self
}
pub fn max_buckets(mut self, max: usize) -> Self {
self.max_buckets = Some(max);
self
}
pub fn bucket_ttl(mut self, ttl: Duration) -> Self {
self.bucket_ttl = Some(ttl);
self
}
pub fn build(self) -> AggregatorConfig {
AggregatorConfig {
header_name: self.header_name,
completion: self.completion.expect("completion condition required"),
correlation: self.correlation,
strategy: self.strategy,
max_buckets: self.max_buckets,
bucket_ttl: self.bucket_ttl,
force_completion_on_stop: self.force_completion_on_stop,
discard_on_timeout: self.discard_on_timeout,
}
}
}
#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's `Arc`, `Duration` and `Exchange` imports.
    use super::*;

    #[test]
    fn test_aggregator_config_complete_when_size() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(3)
            .build();
        assert_eq!(config.header_name, "orderId");
        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Size(3))
        ));
        assert!(matches!(config.strategy, AggregationStrategy::CollectAll));
    }

    #[test]
    fn test_aggregator_config_complete_when_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| bucket.len() >= 2)
            .build();
        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Predicate(_))
        ));
    }

    #[test]
    fn test_aggregator_config_custom_strategy() {
        let f: AggregationFn = Arc::new(|acc, _next| acc);
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(1)
            .strategy(AggregationStrategy::Custom(f))
            .build();
        assert!(matches!(config.strategy, AggregationStrategy::Custom(_)));
    }

    #[test]
    #[should_panic(expected = "completion condition required")]
    fn test_aggregator_config_missing_completion_panics() {
        AggregatorConfig::correlate_by("key").build();
    }

    #[test]
    fn test_complete_on_size_or_timeout() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_on_size_or_timeout(3, Duration::from_secs(5))
            .build();
        assert!(matches!(config.completion, CompletionMode::Any(v) if v.len() == 2));
    }

    #[test]
    fn test_complete_on_size_or_timeout_conditions() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_on_size_or_timeout(3, Duration::from_secs(5))
            .build();
        match config.completion {
            CompletionMode::Any(conditions) => {
                assert!(matches!(conditions[0], CompletionCondition::Size(3)));
                assert!(matches!(
                    conditions[1],
                    CompletionCondition::Timeout(d) if d == Duration::from_secs(5)
                ));
            }
            CompletionMode::Single(_) => panic!("expected CompletionMode::Any"),
        }
    }

    #[test]
    fn test_complete_on_timeout() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_on_timeout(Duration::from_millis(250))
            .build();
        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Timeout(d))
                if d == Duration::from_millis(250)
        ));
    }

    #[test]
    fn test_last_completion_setter_wins() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(5)
            .complete_on_timeout(Duration::from_secs(1))
            .build();
        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Timeout(_))
        ));
    }

    #[test]
    fn test_builder_correlate_by_overrides_header() {
        let config = AggregatorConfig::correlate_by("first")
            .correlate_by("second")
            .complete_when_size(1)
            .build();
        assert_eq!(config.header_name, "second");
        assert!(matches!(
            config.correlation,
            CorrelationStrategy::HeaderName(ref h) if h.as_str() == "second"
        ));
    }

    #[test]
    fn test_builder_optional_limits_and_flags() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(2)
            .max_buckets(10)
            .bucket_ttl(Duration::from_secs(30))
            .force_completion_on_stop(true)
            .discard_on_timeout(true)
            .build();
        assert_eq!(config.max_buckets, Some(10));
        assert_eq!(config.bucket_ttl, Some(Duration::from_secs(30)));
        assert!(config.force_completion_on_stop);
        assert!(config.discard_on_timeout);
    }

    #[test]
    fn test_force_completion_on_stop_default() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(1)
            .build();
        assert!(!config.force_completion_on_stop);
        assert!(!config.discard_on_timeout);
        assert_eq!(config.max_buckets, None);
        assert_eq!(config.bucket_ttl, None);
    }

    #[test]
    fn test_completion_reason_as_str() {
        assert_eq!(CompletionReason::Size.as_str(), "size");
        assert_eq!(CompletionReason::Predicate.as_str(), "predicate");
        assert_eq!(CompletionReason::Timeout.as_str(), "timeout");
        assert_eq!(CompletionReason::Stop.as_str(), "stop");
    }

    #[test]
    fn test_correlation_strategy_debug() {
        let header = CorrelationStrategy::HeaderName("id".to_string());
        assert_eq!(format!("{:?}", header), "HeaderName(\"id\")");

        let expression = CorrelationStrategy::Expression {
            expr: "a".to_string(),
            language: "simple".to_string(),
        };
        assert_eq!(
            format!("{:?}", expression),
            "Expression { expr: \"a\", language: \"simple\" }"
        );

        let func = CorrelationStrategy::Fn(Arc::new(|_: &Exchange| -> Option<String> { None }));
        assert_eq!(format!("{:?}", func), "Fn(..)");
    }
}