1use std::sync::Arc;
2use std::time::Duration;
3
4use crate::exchange::Exchange;
5
6pub type AggregationFn = Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync>;
8
9#[derive(Clone)]
11pub enum AggregationStrategy {
12 CollectAll,
14 Custom(AggregationFn),
16}
17
18#[derive(Clone)]
20pub enum CompletionCondition {
21 Size(usize),
23 #[allow(clippy::type_complexity)]
25 Predicate(Arc<dyn Fn(&[Exchange]) -> bool + Send + Sync>),
26}
27
28#[derive(Clone)]
30pub struct AggregatorConfig {
31 pub header_name: String,
33 pub completion: CompletionCondition,
35 pub strategy: AggregationStrategy,
37 pub max_buckets: Option<usize>,
40 pub bucket_ttl: Option<Duration>,
43}
44
45impl AggregatorConfig {
46 pub fn correlate_by(header: impl Into<String>) -> AggregatorConfigBuilder {
48 AggregatorConfigBuilder {
49 header_name: header.into(),
50 completion: None,
51 strategy: AggregationStrategy::CollectAll,
52 max_buckets: None,
53 bucket_ttl: None,
54 }
55 }
56}
57
58pub struct AggregatorConfigBuilder {
60 header_name: String,
61 completion: Option<CompletionCondition>,
62 strategy: AggregationStrategy,
63 max_buckets: Option<usize>,
64 bucket_ttl: Option<Duration>,
65}
66
67impl AggregatorConfigBuilder {
68 pub fn complete_when_size(mut self, n: usize) -> Self {
70 self.completion = Some(CompletionCondition::Size(n));
71 self
72 }
73
74 pub fn complete_when<F>(mut self, predicate: F) -> Self
76 where
77 F: Fn(&[Exchange]) -> bool + Send + Sync + 'static,
78 {
79 self.completion = Some(CompletionCondition::Predicate(Arc::new(predicate)));
80 self
81 }
82
83 pub fn strategy(mut self, strategy: AggregationStrategy) -> Self {
85 self.strategy = strategy;
86 self
87 }
88
89 pub fn max_buckets(mut self, max: usize) -> Self {
92 self.max_buckets = Some(max);
93 self
94 }
95
96 pub fn bucket_ttl(mut self, ttl: Duration) -> Self {
99 self.bucket_ttl = Some(ttl);
100 self
101 }
102
103 pub fn build(self) -> AggregatorConfig {
105 AggregatorConfig {
106 header_name: self.header_name,
107 completion: self.completion.expect("completion condition required"),
108 strategy: self.strategy,
109 max_buckets: self.max_buckets,
110 bucket_ttl: self.bucket_ttl,
111 }
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;

    // `matches!` only evaluates to a `bool`; every use below is wrapped in
    // `assert!` so that a mismatch actually fails the test.

    /// A size-based condition is stored as given and the strategy keeps its
    /// `CollectAll` default.
    #[test]
    fn test_aggregator_config_complete_when_size() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(3)
            .build();
        assert_eq!(config.header_name, "orderId");
        assert!(matches!(config.completion, CompletionCondition::Size(3)));
        assert!(matches!(config.strategy, AggregationStrategy::CollectAll));
    }

    /// A predicate-based condition is stored as `Predicate` and the stored
    /// closure is the one that was supplied.
    #[test]
    fn test_aggregator_config_complete_when_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| bucket.len() >= 2)
            .build();
        match &config.completion {
            // An empty bucket has fewer than two entries, so the supplied
            // predicate must report "not complete".
            CompletionCondition::Predicate(predicate) => assert!(!predicate(&[])),
            CompletionCondition::Size(_) => panic!("expected a predicate completion condition"),
        }
    }

    /// A custom strategy passed to the builder ends up in the built config.
    #[test]
    fn test_aggregator_config_custom_strategy() {
        use std::sync::Arc;
        let f: AggregationFn = Arc::new(|acc, _next| acc);
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(1)
            .strategy(AggregationStrategy::Custom(f))
            .build();
        assert!(matches!(config.strategy, AggregationStrategy::Custom(_)));
    }

    /// Building without any completion condition panics with the documented
    /// message.
    #[test]
    #[should_panic(expected = "completion condition required")]
    fn test_aggregator_config_missing_completion_panics() {
        AggregatorConfig::correlate_by("key").build();
    }
}