1use std::sync::Arc;
2use std::time::Duration;
3
4use crate::exchange::Exchange;
5
6pub type AggregationFn = Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync>;
8
9pub enum CorrelationStrategy {
11 HeaderName(String),
13 Expression { expr: String, language: String },
15 #[allow(clippy::type_complexity)]
17 Fn(Arc<dyn Fn(&Exchange) -> Option<String> + Send + Sync>),
18}
19
20impl Clone for CorrelationStrategy {
21 fn clone(&self) -> Self {
22 match self {
23 CorrelationStrategy::HeaderName(h) => CorrelationStrategy::HeaderName(h.clone()),
24 CorrelationStrategy::Expression { expr, language } => CorrelationStrategy::Expression {
25 expr: expr.clone(),
26 language: language.clone(),
27 },
28 CorrelationStrategy::Fn(f) => CorrelationStrategy::Fn(Arc::clone(f)),
29 }
30 }
31}
32
33impl std::fmt::Debug for CorrelationStrategy {
34 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
35 match self {
36 CorrelationStrategy::HeaderName(h) => f.debug_tuple("HeaderName").field(h).finish(),
37 CorrelationStrategy::Expression { expr, language } => f
38 .debug_struct("Expression")
39 .field("expr", expr)
40 .field("language", language)
41 .finish(),
42 CorrelationStrategy::Fn(_) => f.write_str("Fn(..)"),
43 }
44 }
45}
46
47#[derive(Clone)]
49pub enum AggregationStrategy {
50 CollectAll,
52 Custom(AggregationFn),
54}
55
56#[derive(Clone)]
58pub enum CompletionCondition {
59 Size(usize),
61 #[allow(clippy::type_complexity)]
63 Predicate(Arc<dyn Fn(&[Exchange]) -> bool + Send + Sync>),
64 Timeout(Duration),
66}
67
68#[derive(Clone)]
71pub enum CompletionMode {
72 Single(CompletionCondition),
73 Any(Vec<CompletionCondition>),
74}
75
/// Reason label for a completion.
///
/// NOTE(review): presumably reported by the aggregator when it emits a bucket; `Stop` likely
/// relates to `force_completion_on_stop` — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CompletionReason {
    /// Label `"size"`.
    Size,
    /// Label `"predicate"`.
    Predicate,
    /// Label `"timeout"`.
    Timeout,
    /// Label `"stop"`.
    Stop,
}

impl CompletionReason {
    /// Returns the lowercase, `'static` label for this reason.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Size => "size",
            Self::Predicate => "predicate",
            Self::Timeout => "timeout",
            Self::Stop => "stop",
        }
    }
}
94
95#[derive(Clone)]
97pub struct AggregatorConfig {
98 pub header_name: String,
100 pub completion: CompletionMode,
102 pub correlation: CorrelationStrategy,
104 pub strategy: AggregationStrategy,
106 pub max_buckets: Option<usize>,
109 pub bucket_ttl: Option<Duration>,
112 pub force_completion_on_stop: bool,
114 pub discard_on_timeout: bool,
116}
117
118impl AggregatorConfig {
119 pub fn correlate_by(header: impl Into<String>) -> AggregatorConfigBuilder {
121 let header_name = header.into();
122 AggregatorConfigBuilder {
123 header_name: header_name.clone(),
124 completion: None,
125 correlation: CorrelationStrategy::HeaderName(header_name),
126 strategy: AggregationStrategy::CollectAll,
127 max_buckets: None,
128 bucket_ttl: None,
129 force_completion_on_stop: false,
130 discard_on_timeout: false,
131 }
132 }
133}
134
135pub struct AggregatorConfigBuilder {
137 header_name: String,
138 completion: Option<CompletionMode>,
139 correlation: CorrelationStrategy,
140 strategy: AggregationStrategy,
141 max_buckets: Option<usize>,
142 bucket_ttl: Option<Duration>,
143 force_completion_on_stop: bool,
144 discard_on_timeout: bool,
145}
146
147impl AggregatorConfigBuilder {
148 pub fn complete_when_size(mut self, n: usize) -> Self {
150 self.completion = Some(CompletionMode::Single(CompletionCondition::Size(n)));
151 self
152 }
153
154 pub fn complete_when<F>(mut self, predicate: F) -> Self
156 where
157 F: Fn(&[Exchange]) -> bool + Send + Sync + 'static,
158 {
159 self.completion = Some(CompletionMode::Single(CompletionCondition::Predicate(
160 Arc::new(predicate),
161 )));
162 self
163 }
164
165 pub fn complete_on_timeout(mut self, duration: Duration) -> Self {
167 self.completion = Some(CompletionMode::Single(CompletionCondition::Timeout(
168 duration,
169 )));
170 self
171 }
172
173 pub fn complete_on_size_or_timeout(mut self, size: usize, timeout: Duration) -> Self {
175 self.completion = Some(CompletionMode::Any(vec![
176 CompletionCondition::Size(size),
177 CompletionCondition::Timeout(timeout),
178 ]));
179 self
180 }
181
182 pub fn force_completion_on_stop(mut self, enabled: bool) -> Self {
184 self.force_completion_on_stop = enabled;
185 self
186 }
187
188 pub fn discard_on_timeout(mut self, enabled: bool) -> Self {
190 self.discard_on_timeout = enabled;
191 self
192 }
193
194 pub fn correlate_by(mut self, header: impl Into<String>) -> Self {
196 let header = header.into();
197 self.header_name = header.clone();
198 self.correlation = CorrelationStrategy::HeaderName(header);
199 self
200 }
201
202 pub fn strategy(mut self, strategy: AggregationStrategy) -> Self {
204 self.strategy = strategy;
205 self
206 }
207
208 pub fn max_buckets(mut self, max: usize) -> Self {
211 self.max_buckets = Some(max);
212 self
213 }
214
215 pub fn bucket_ttl(mut self, ttl: Duration) -> Self {
218 self.bucket_ttl = Some(ttl);
219 self
220 }
221
222 pub fn build(self) -> AggregatorConfig {
224 AggregatorConfig {
225 header_name: self.header_name,
226 completion: self.completion.expect("completion condition required"),
227 correlation: self.correlation,
228 strategy: self.strategy,
229 max_buckets: self.max_buckets,
230 bucket_ttl: self.bucket_ttl,
231 force_completion_on_stop: self.force_completion_on_stop,
232 discard_on_timeout: self.discard_on_timeout,
233 }
234 }
235}
236
#[cfg(test)]
mod tests {
    // `super::*` also brings the parent's `Arc`, `Duration` and `Exchange` imports into scope.
    use super::*;

    #[test]
    fn test_aggregator_config_complete_when_size() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(3)
            .build();
        assert_eq!(config.header_name, "orderId");
        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Size(3))
        ));
        assert!(matches!(config.strategy, AggregationStrategy::CollectAll));
    }

    #[test]
    fn test_aggregator_config_complete_when_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| bucket.len() >= 2)
            .build();
        match config.completion {
            CompletionMode::Single(CompletionCondition::Predicate(predicate)) => {
                // The stored closure is the one supplied: an empty slice does not satisfy it.
                let empty: &[Exchange] = &[];
                assert!(!predicate(empty));
            }
            _ => panic!("expected a single predicate completion condition"),
        }
    }

    #[test]
    fn test_aggregator_config_custom_strategy() {
        let f: AggregationFn = Arc::new(|acc, _next| acc);
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(1)
            .strategy(AggregationStrategy::Custom(f))
            .build();
        assert!(matches!(config.strategy, AggregationStrategy::Custom(_)));
    }

    #[test]
    #[should_panic(expected = "completion condition required")]
    fn test_aggregator_config_missing_completion_panics() {
        AggregatorConfig::correlate_by("key").build();
    }

    #[test]
    fn test_complete_on_size_or_timeout() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_on_size_or_timeout(3, Duration::from_secs(5))
            .build();
        match config.completion {
            CompletionMode::Any(conditions) => {
                assert_eq!(conditions.len(), 2);
                assert!(matches!(&conditions[0], CompletionCondition::Size(3)));
                assert!(matches!(
                    &conditions[1],
                    CompletionCondition::Timeout(d) if *d == Duration::from_secs(5)
                ));
            }
            _ => panic!("expected an Any completion mode"),
        }
    }

    #[test]
    fn test_complete_on_timeout() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_on_timeout(Duration::from_millis(250))
            .build();
        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Timeout(d))
                if d == Duration::from_millis(250)
        ));
    }

    #[test]
    fn test_later_completion_setter_replaces_earlier() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(3)
            .complete_on_timeout(Duration::from_secs(1))
            .build();
        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Timeout(_))
        ));
    }

    #[test]
    fn test_builder_correlate_by_overrides_header() {
        let config = AggregatorConfig::correlate_by("first")
            .correlate_by("second")
            .complete_when_size(1)
            .build();
        assert_eq!(config.header_name, "second");
        assert!(matches!(
            &config.correlation,
            CorrelationStrategy::HeaderName(h) if h == "second"
        ));
    }

    #[test]
    fn test_force_completion_on_stop_default() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(1)
            .build();
        assert!(!config.force_completion_on_stop);
        assert!(!config.discard_on_timeout);
        assert_eq!(config.max_buckets, None);
        assert_eq!(config.bucket_ttl, None);
    }

    #[test]
    fn test_optional_settings_are_applied() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(1)
            .max_buckets(10)
            .bucket_ttl(Duration::from_secs(30))
            .force_completion_on_stop(true)
            .discard_on_timeout(true)
            .build();
        assert_eq!(config.max_buckets, Some(10));
        assert_eq!(config.bucket_ttl, Some(Duration::from_secs(30)));
        assert!(config.force_completion_on_stop);
        assert!(config.discard_on_timeout);
    }

    #[test]
    fn test_completion_reason_as_str() {
        assert_eq!(CompletionReason::Size.as_str(), "size");
        assert_eq!(CompletionReason::Predicate.as_str(), "predicate");
        assert_eq!(CompletionReason::Timeout.as_str(), "timeout");
        assert_eq!(CompletionReason::Stop.as_str(), "stop");
    }

    #[test]
    fn test_correlation_strategy_debug() {
        let header = CorrelationStrategy::HeaderName("orderId".to_string());
        assert_eq!(format!("{:?}", header), "HeaderName(\"orderId\")");

        let expression = CorrelationStrategy::Expression {
            expr: "x".to_string(),
            language: "simple".to_string(),
        };
        assert_eq!(
            format!("{:?}", expression),
            "Expression { expr: \"x\", language: \"simple\" }"
        );

        let func =
            CorrelationStrategy::Fn(Arc::new(|_: &Exchange| -> Option<String> { None }));
        assert_eq!(format!("{:?}", func), "Fn(..)");
    }
}