1use std::sync::Arc;
2use std::time::Duration;
3
4use crate::error::CamelError;
5use crate::exchange::Exchange;
6
7pub type AggregationFn = Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync>;
9
10pub enum CorrelationStrategy {
12 HeaderName(String),
14 Expression { expr: String, language: String },
16 #[allow(clippy::type_complexity)]
18 Fn(Arc<dyn Fn(&Exchange) -> Option<String> + Send + Sync>),
19}
20
21impl Clone for CorrelationStrategy {
22 fn clone(&self) -> Self {
23 match self {
24 CorrelationStrategy::HeaderName(h) => CorrelationStrategy::HeaderName(h.clone()),
25 CorrelationStrategy::Expression { expr, language } => CorrelationStrategy::Expression {
26 expr: expr.clone(),
27 language: language.clone(),
28 },
29 CorrelationStrategy::Fn(f) => CorrelationStrategy::Fn(Arc::clone(f)),
30 }
31 }
32}
33
34impl std::fmt::Debug for CorrelationStrategy {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 match self {
37 CorrelationStrategy::HeaderName(h) => f.debug_tuple("HeaderName").field(h).finish(),
38 CorrelationStrategy::Expression { expr, language } => f
39 .debug_struct("Expression")
40 .field("expr", expr)
41 .field("language", language)
42 .finish(),
43 CorrelationStrategy::Fn(_) => f.write_str("Fn(..)"),
44 }
45 }
46}
47
48#[derive(Clone)]
50pub enum AggregationStrategy {
51 CollectAll,
53 Custom(AggregationFn),
55}
56
57#[derive(Clone)]
59pub enum CompletionCondition {
60 Size(usize),
62 #[allow(clippy::type_complexity)]
64 Predicate(Arc<dyn Fn(&[Exchange]) -> bool + Send + Sync>),
65 Timeout(Duration),
67}
68
69#[derive(Clone)]
72pub enum CompletionMode {
73 Single(CompletionCondition),
74 Any(Vec<CompletionCondition>),
75}
76
/// Why a bucket was completed; [`CompletionReason::as_str`] gives its lowercase label.
///
/// Fieldless, so it is `Copy` (pass by value) and `Hash` (usable as a map/set key, e.g.
/// for per-reason counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CompletionReason {
    /// A size completion condition was met.
    Size,
    /// A predicate completion condition was met.
    Predicate,
    /// A timeout completion condition was met.
    Timeout,
    /// Completed because the aggregator stopped (presumably tied to
    /// `force_completion_on_stop` — confirm in the processor).
    Stop,
}
84
85impl CompletionReason {
86 pub fn as_str(&self) -> &'static str {
87 match self {
88 CompletionReason::Size => "size",
89 CompletionReason::Predicate => "predicate",
90 CompletionReason::Timeout => "timeout",
91 CompletionReason::Stop => "stop",
92 }
93 }
94}
95
96#[derive(Clone)]
98pub struct AggregatorConfig {
99 pub header_name: String,
101 pub completion: CompletionMode,
103 pub correlation: CorrelationStrategy,
105 pub strategy: AggregationStrategy,
107 pub max_buckets: Option<usize>,
110 pub bucket_ttl: Option<Duration>,
113 pub force_completion_on_stop: bool,
115 pub discard_on_timeout: bool,
117}
118
119impl AggregatorConfig {
120 pub fn correlate_by(header: impl Into<String>) -> AggregatorConfigBuilder {
122 let header_name = header.into();
123 AggregatorConfigBuilder {
124 header_name: header_name.clone(),
125 completion: None,
126 correlation: CorrelationStrategy::HeaderName(header_name),
127 strategy: AggregationStrategy::CollectAll,
128 max_buckets: Some(10_000),
134 bucket_ttl: Some(Duration::from_secs(300)),
135 force_completion_on_stop: false,
136 discard_on_timeout: false,
137 }
138 }
139}
140
141pub struct AggregatorConfigBuilder {
143 header_name: String,
144 completion: Option<CompletionMode>,
145 correlation: CorrelationStrategy,
146 strategy: AggregationStrategy,
147 max_buckets: Option<usize>,
148 bucket_ttl: Option<Duration>,
149 force_completion_on_stop: bool,
150 discard_on_timeout: bool,
151}
152
153impl AggregatorConfigBuilder {
154 pub fn complete_when_size(mut self, n: usize) -> Self {
156 self.completion = Some(CompletionMode::Single(CompletionCondition::Size(n)));
157 self
158 }
159
160 pub fn complete_when<F>(mut self, predicate: F) -> Self
162 where
163 F: Fn(&[Exchange]) -> bool + Send + Sync + 'static,
164 {
165 self.completion = Some(CompletionMode::Single(CompletionCondition::Predicate(
166 Arc::new(predicate),
167 )));
168 self
169 }
170
171 pub fn complete_on_timeout(mut self, duration: Duration) -> Self {
173 self.completion = Some(CompletionMode::Single(CompletionCondition::Timeout(
174 duration,
175 )));
176 self
177 }
178
179 pub fn complete_on_size_or_timeout(mut self, size: usize, timeout: Duration) -> Self {
181 self.completion = Some(CompletionMode::Any(vec![
182 CompletionCondition::Size(size),
183 CompletionCondition::Timeout(timeout),
184 ]));
185 self
186 }
187
188 pub fn force_completion_on_stop(mut self, enabled: bool) -> Self {
190 self.force_completion_on_stop = enabled;
191 self
192 }
193
194 pub fn discard_on_timeout(mut self, enabled: bool) -> Self {
196 self.discard_on_timeout = enabled;
197 self
198 }
199
200 pub fn correlate_by(mut self, header: impl Into<String>) -> Self {
202 let header = header.into();
203 self.header_name = header.clone();
204 self.correlation = CorrelationStrategy::HeaderName(header);
205 self
206 }
207
208 pub fn strategy(mut self, strategy: AggregationStrategy) -> Self {
210 self.strategy = strategy;
211 self
212 }
213
214 pub fn max_buckets(mut self, max: usize) -> Self {
217 self.max_buckets = Some(max);
218 self
219 }
220
221 pub fn bucket_ttl(mut self, ttl: Duration) -> Self {
224 self.bucket_ttl = Some(ttl);
225 self
226 }
227
228 pub fn try_build(self) -> Result<AggregatorConfig, CamelError> {
229 let completion = self.completion.ok_or_else(|| {
242 CamelError::Config(
243 "AggregatorMissingCompletionBound: a completion condition \
244 (complete_when_size, complete_when, complete_on_timeout, \
245 or complete_on_size_or_timeout) is required"
246 .into(),
247 )
248 })?;
249 Ok(AggregatorConfig {
250 header_name: self.header_name,
251 completion,
252 correlation: self.correlation,
253 strategy: self.strategy,
254 max_buckets: self.max_buckets,
255 bucket_ttl: self.bucket_ttl,
256 force_completion_on_stop: self.force_completion_on_stop,
257 discard_on_timeout: self.discard_on_timeout,
258 })
259 }
260
261 pub fn build(self) -> Result<AggregatorConfig, CamelError> {
263 self.try_build()
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aggregator_config_complete_when_size() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(3)
            .build()
            .unwrap();
        assert_eq!(config.header_name, "orderId");
        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Size(3))
        ));
        assert!(matches!(config.strategy, AggregationStrategy::CollectAll));
        // `correlate_by` also installs a header-based correlation on the same header.
        assert!(matches!(
            config.correlation,
            CorrelationStrategy::HeaderName(ref h) if h == "orderId"
        ));
    }

    #[test]
    fn test_aggregator_config_complete_when_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| bucket.len() >= 2)
            .build()
            .unwrap();
        match config.completion {
            CompletionMode::Single(CompletionCondition::Predicate(predicate)) => {
                // Invoke the stored predicate to prove it is the closure supplied above:
                // an empty bucket does not satisfy `len() >= 2`.
                let empty: &[Exchange] = &[];
                assert!(!predicate(empty));
            }
            _ => panic!("expected CompletionMode::Single(CompletionCondition::Predicate(_))"),
        }
    }

    #[test]
    fn test_aggregator_config_custom_strategy() {
        let f: AggregationFn = Arc::new(|acc, _next| acc);
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(1)
            .strategy(AggregationStrategy::Custom(Arc::clone(&f)))
            .build()
            .unwrap();
        match config.strategy {
            // The builder must keep the very closure it was given, not a substitute.
            AggregationStrategy::Custom(stored) => assert!(Arc::ptr_eq(&f, &stored)),
            AggregationStrategy::CollectAll => panic!("expected AggregationStrategy::Custom"),
        }
    }

    #[test]
    fn test_aggregator_config_missing_completion_returns_err() {
        let result = AggregatorConfig::correlate_by("key").build();
        // `AggregatorConfig` is not `Debug`, so `unwrap_err` is unavailable here.
        let err = match result {
            Err(e) => e,
            Ok(_) => panic!("expected error, got Ok"),
        };
        assert!(
            err.to_string().contains("completion"),
            "error message should mention 'completion': {err}"
        );
    }

    #[test]
    fn test_complete_on_size_or_timeout() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_on_size_or_timeout(3, Duration::from_secs(5))
            .build()
            .unwrap();
        assert!(matches!(config.completion, CompletionMode::Any(v) if v.len() == 2));
    }

    #[test]
    fn test_force_completion_on_stop_default() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(1)
            .build()
            .unwrap();
        assert!(!config.force_completion_on_stop);
        assert!(!config.discard_on_timeout);
    }

    #[test]
    fn test_builder_sets_timeout_and_flags_and_limits() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_on_timeout(Duration::from_secs(2))
            .max_buckets(7)
            .bucket_ttl(Duration::from_secs(10))
            .force_completion_on_stop(true)
            .discard_on_timeout(true)
            .build()
            .unwrap();

        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Timeout(d)) if d == Duration::from_secs(2)
        ));
        assert_eq!(config.max_buckets, Some(7));
        assert_eq!(config.bucket_ttl, Some(Duration::from_secs(10)));
        assert!(config.force_completion_on_stop);
        assert!(config.discard_on_timeout);
    }

    #[test]
    fn test_builder_correlate_by_overrides_header_and_strategy() {
        let config = AggregatorConfig::correlate_by("original")
            .correlate_by("override")
            .complete_when_size(1)
            .build()
            .unwrap();

        assert_eq!(config.header_name, "override");
        assert!(matches!(
            config.correlation,
            CorrelationStrategy::HeaderName(ref h) if h == "override"
        ));
    }

    #[test]
    fn test_completion_reason_as_str_all_variants() {
        assert_eq!(CompletionReason::Size.as_str(), "size");
        assert_eq!(CompletionReason::Predicate.as_str(), "predicate");
        assert_eq!(CompletionReason::Timeout.as_str(), "timeout");
        assert_eq!(CompletionReason::Stop.as_str(), "stop");
    }

    #[test]
    fn test_correlation_strategy_clone_and_debug() {
        let strategy = CorrelationStrategy::Expression {
            expr: "${header.orderId}".to_string(),
            language: "simple".to_string(),
        };
        let cloned = strategy.clone();
        assert!(matches!(
            cloned,
            CorrelationStrategy::Expression { ref expr, ref language }
            if expr == "${header.orderId}" && language == "simple"
        ));

        // Debug output of every variant is pinned, not just the closure one.
        assert_eq!(
            format!("{:?}", strategy),
            "Expression { expr: \"${header.orderId}\", language: \"simple\" }"
        );
        assert_eq!(
            format!("{:?}", CorrelationStrategy::HeaderName("orderId".to_string())),
            "HeaderName(\"orderId\")"
        );

        let f = CorrelationStrategy::Fn(Arc::new(|_| Some("k".to_string())));
        assert_eq!(format!("{:?}", f), "Fn(..)");
    }

    #[test]
    fn test_complete_on_size_or_timeout_contains_both_conditions() {
        let config = AggregatorConfig::correlate_by("k")
            .complete_on_size_or_timeout(4, Duration::from_millis(250))
            .build()
            .unwrap();

        match config.completion {
            CompletionMode::Any(conditions) => {
                assert!(matches!(conditions[0], CompletionCondition::Size(4)));
                assert!(matches!(
                    conditions[1],
                    CompletionCondition::Timeout(d) if d == Duration::from_millis(250)
                ));
            }
            _ => panic!("expected CompletionMode::Any"),
        }
    }

    #[test]
    fn test_correlation_strategy_fn_clone_shares_same_arc() {
        let f: Arc<dyn Fn(&Exchange) -> Option<String> + Send + Sync> =
            Arc::new(|_| Some("shared".to_string()));
        let strategy = CorrelationStrategy::Fn(f.clone());
        let cloned = strategy.clone();

        match cloned {
            CorrelationStrategy::Fn(cloned_fn) => assert!(Arc::ptr_eq(&f, &cloned_fn)),
            _ => panic!("expected fn strategy"),
        }
    }

    #[test]
    fn test_builder_correlate_by_overrides_previous() {
        let config = AggregatorConfig::correlate_by("first")
            .correlate_by("second")
            .complete_when_size(2)
            .build()
            .unwrap();

        assert_eq!(config.header_name, "second");
        assert!(
            matches!(config.correlation, CorrelationStrategy::HeaderName(ref h) if h == "second")
        );
    }

    #[test]
    fn test_later_completion_call_replaces_earlier() {
        // Each `complete_*` call overwrites the previous completion rule.
        let config = AggregatorConfig::correlate_by("k")
            .complete_when_size(3)
            .complete_on_timeout(Duration::from_secs(1))
            .build()
            .unwrap();
        assert!(matches!(
            config.completion,
            CompletionMode::Single(CompletionCondition::Timeout(d)) if d == Duration::from_secs(1)
        ));
    }

    #[test]
    fn test_aggregator_try_build_missing_completion_returns_error() {
        let result = AggregatorConfig::correlate_by("key").try_build();
        assert!(result.is_err());
    }

    // Pins the default bucket cap installed by `correlate_by`.
    #[test]
    fn test_default_max_buckets_is_10000() {
        let cfg = AggregatorConfig::correlate_by("k")
            .complete_when_size(1)
            .build()
            .unwrap();
        assert_eq!(cfg.max_buckets, Some(10_000));
    }

    // Pins the default bucket TTL installed by `correlate_by`.
    #[test]
    fn test_default_bucket_ttl_is_300s() {
        let cfg = AggregatorConfig::correlate_by("k")
            .complete_when_size(1)
            .build()
            .unwrap();
        assert_eq!(cfg.bucket_ttl, Some(Duration::from_secs(300)));
    }

    // The builder performs no range check on `max_buckets`; `0` passes `build`.
    #[test]
    fn test_explicit_max_buckets_zero_is_accepted_at_build() {
        let cfg = AggregatorConfig::correlate_by("k")
            .complete_when_size(1)
            .max_buckets(0)
            .build()
            .unwrap();
        assert_eq!(cfg.max_buckets, Some(0));
    }

    // A config without any completion bound must be rejected, with the greppable error code.
    #[test]
    fn test_aggregator_rejects_no_completion_bound() {
        let err = match AggregatorConfig::correlate_by("k").try_build() {
            Err(e) => e,
            Ok(_) => panic!("expected error, got Ok"),
        };
        let msg = err.to_string();
        assert!(
            msg.contains("completion"),
            "expected error mentioning 'completion', got: {msg}"
        );
        assert!(
            msg.contains("AggregatorMissingCompletionBound"),
            "expected error code 'AggregatorMissingCompletionBound', got: {msg}"
        );
    }
}