1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use tower::Service;
9
10use camel_api::{
11 CamelError,
12 aggregator::{AggregationStrategy, AggregatorConfig, CompletionCondition},
13 body::Body,
14 exchange::Exchange,
15 message::Message,
16};
17
/// Exchange property set to JSON `true` on the empty placeholder exchange the
/// aggregator returns while the correlation bucket is not yet complete.
pub const CAMEL_AGGREGATOR_PENDING: &str = "CamelAggregatorPending";
/// Exchange property holding the number of exchanges merged into the
/// aggregated result (stored as a JSON number).
pub const CAMEL_AGGREGATED_SIZE: &str = "CamelAggregatedSize";
/// Exchange property holding the correlation header value (as found on the
/// incoming exchanges) of the bucket that produced the aggregated result.
pub const CAMEL_AGGREGATED_KEY: &str = "CamelAggregatedKey";
21
22struct Bucket {
24 exchanges: Vec<Exchange>,
25 #[allow(dead_code)]
26 created_at: Instant,
27 last_updated: Instant,
28}
29
30impl Bucket {
31 fn new() -> Self {
32 let now = Instant::now();
33 Self {
34 exchanges: Vec::new(),
35 created_at: now,
36 last_updated: now,
37 }
38 }
39
40 fn push(&mut self, exchange: Exchange) {
41 self.exchanges.push(exchange);
42 self.last_updated = Instant::now();
43 }
44
45 fn len(&self) -> usize {
46 self.exchanges.len()
47 }
48
49 fn is_expired(&self, ttl: Duration) -> bool {
50 Instant::now().duration_since(self.last_updated) >= ttl
51 }
52}
53
54#[derive(Clone)]
55pub struct AggregatorService {
56 config: AggregatorConfig,
57 buckets: Arc<Mutex<HashMap<String, Bucket>>>,
58}
59
60impl AggregatorService {
61 pub fn new(config: AggregatorConfig) -> Self {
62 Self {
63 config,
64 buckets: Arc::new(Mutex::new(HashMap::new())),
65 }
66 }
67}
68
69impl Service<Exchange> for AggregatorService {
70 type Response = Exchange;
71 type Error = CamelError;
72 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
73
74 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
75 Poll::Ready(Ok(()))
76 }
77
78 fn call(&mut self, exchange: Exchange) -> Self::Future {
79 let config = self.config.clone();
80 let buckets = Arc::clone(&self.buckets);
81
82 Box::pin(async move {
83 let key_value = exchange
85 .input
86 .headers
87 .get(&config.header_name)
88 .cloned()
89 .ok_or_else(|| {
90 CamelError::ProcessorError(format!(
91 "Aggregator: missing correlation key header '{}'",
92 config.header_name
93 ))
94 })?;
95
96 let key_str = serde_json::to_string(&key_value)
98 .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
99
100 let completed_bucket = {
102 let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
103
104 if let Some(ttl) = config.bucket_ttl {
106 guard.retain(|_, bucket| !bucket.is_expired(ttl));
107 }
108
109 if let Some(max) = config.max_buckets
111 && !guard.contains_key(&key_str)
112 && guard.len() >= max
113 {
114 tracing::warn!(
115 max_buckets = max,
116 correlation_key = %key_str,
117 "Aggregator reached max buckets limit, rejecting new correlation key"
118 );
119 return Err(CamelError::ProcessorError(format!(
120 "Aggregator reached maximum {} buckets",
121 max
122 )));
123 }
124
125 let bucket = guard.entry(key_str.clone()).or_insert_with(Bucket::new);
126 bucket.push(exchange);
127
128 let is_complete = match &config.completion {
129 CompletionCondition::Size(n) => bucket.len() >= *n,
130 CompletionCondition::Predicate(pred) => pred(&bucket.exchanges),
131 };
132
133 if is_complete {
134 guard.remove(&key_str).map(|b| b.exchanges)
135 } else {
136 None
137 }
138 }; match completed_bucket {
142 Some(exchanges) => {
143 let size = exchanges.len();
144 let mut result = aggregate(exchanges, &config.strategy)?;
145 result.set_property(CAMEL_AGGREGATED_SIZE, serde_json::json!(size as u64));
146 result.set_property(CAMEL_AGGREGATED_KEY, key_value);
147 Ok(result)
148 }
149 None => {
150 let mut pending = Exchange::new(Message {
151 headers: Default::default(),
152 body: Body::Empty,
153 });
154 pending.set_property(CAMEL_AGGREGATOR_PENDING, serde_json::json!(true));
155 Ok(pending)
156 }
157 }
158 })
159 }
160}
161
162fn aggregate(
163 exchanges: Vec<Exchange>,
164 strategy: &AggregationStrategy,
165) -> Result<Exchange, CamelError> {
166 match strategy {
167 AggregationStrategy::CollectAll => {
168 let bodies: Vec<serde_json::Value> = exchanges
169 .into_iter()
170 .map(|e| match e.input.body {
171 Body::Json(v) => v,
172 Body::Text(s) => serde_json::Value::String(s),
173 Body::Bytes(b) => {
174 serde_json::Value::String(String::from_utf8_lossy(&b).into_owned())
175 }
176 Body::Empty => serde_json::Value::Null,
177 })
178 .collect();
179 Ok(Exchange::new(Message {
180 headers: Default::default(),
181 body: Body::Json(serde_json::Value::Array(bodies)),
182 }))
183 }
184 AggregationStrategy::Custom(f) => {
185 let mut iter = exchanges.into_iter();
186 let first = iter.next().ok_or_else(|| {
187 CamelError::ProcessorError("Aggregator: empty bucket".to_string())
188 })?;
189 Ok(iter.fold(first, |acc, next| f(acc, next)))
190 }
191 }
192}
193
194#[cfg(test)]
195mod tests {
196 use super::*;
197 use camel_api::{
198 aggregator::{AggregationStrategy, AggregatorConfig},
199 body::Body,
200 exchange::Exchange,
201 message::Message,
202 };
203 use tower::ServiceExt;
204
205 fn make_exchange(header: &str, value: &str, body: &str) -> Exchange {
206 let mut msg = Message {
207 headers: Default::default(),
208 body: Body::Text(body.to_string()),
209 };
210 msg.headers
211 .insert(header.to_string(), serde_json::json!(value));
212 Exchange::new(msg)
213 }
214
215 fn config_size(n: usize) -> AggregatorConfig {
216 AggregatorConfig::correlate_by("orderId")
217 .complete_when_size(n)
218 .build()
219 }
220
221 #[tokio::test]
222 async fn test_pending_exchange_not_yet_complete() {
223 let mut svc = AggregatorService::new(config_size(3));
224 let ex = make_exchange("orderId", "A", "first");
225 let result = svc.ready().await.unwrap().call(ex).await.unwrap();
226 assert!(matches!(result.input.body, Body::Empty));
227 assert_eq!(
228 result.property(CAMEL_AGGREGATOR_PENDING),
229 Some(&serde_json::json!(true))
230 );
231 }
232
233 #[tokio::test]
234 async fn test_completes_on_size() {
235 let mut svc = AggregatorService::new(config_size(3));
236 for _ in 0..2 {
237 let ex = make_exchange("orderId", "A", "item");
238 let r = svc.ready().await.unwrap().call(ex).await.unwrap();
239 assert!(matches!(r.input.body, Body::Empty));
240 }
241 let ex = make_exchange("orderId", "A", "last");
242 let result = svc.ready().await.unwrap().call(ex).await.unwrap();
243 assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
244 assert_eq!(
245 result.property(CAMEL_AGGREGATED_SIZE),
246 Some(&serde_json::json!(3u64))
247 );
248 }
249
250 #[tokio::test]
251 async fn test_collect_all_produces_json_array() {
252 let mut svc = AggregatorService::new(config_size(2));
253 svc.ready()
254 .await
255 .unwrap()
256 .call(make_exchange("orderId", "A", "alpha"))
257 .await
258 .unwrap();
259 let result = svc
260 .ready()
261 .await
262 .unwrap()
263 .call(make_exchange("orderId", "A", "beta"))
264 .await
265 .unwrap();
266 let Body::Json(v) = &result.input.body else {
267 panic!("expected Body::Json")
268 };
269 let arr = v.as_array().unwrap();
270 assert_eq!(arr.len(), 2);
271 assert_eq!(arr[0], serde_json::json!("alpha"));
272 assert_eq!(arr[1], serde_json::json!("beta"));
273 }
274
275 #[tokio::test]
276 async fn test_two_keys_independent_buckets() {
277 let mut svc = AggregatorService::new(config_size(3));
279 svc.ready()
280 .await
281 .unwrap()
282 .call(make_exchange("orderId", "A", "a1"))
283 .await
284 .unwrap();
285 svc.ready()
286 .await
287 .unwrap()
288 .call(make_exchange("orderId", "B", "b1"))
289 .await
290 .unwrap();
291 svc.ready()
292 .await
293 .unwrap()
294 .call(make_exchange("orderId", "A", "a2"))
295 .await
296 .unwrap();
297 let ra = svc
299 .ready()
300 .await
301 .unwrap()
302 .call(make_exchange("orderId", "A", "a3"))
303 .await
304 .unwrap();
305 assert!(matches!(ra.input.body, Body::Json(_)));
307 let rb = svc
309 .ready()
310 .await
311 .unwrap()
312 .call(make_exchange("orderId", "B", "b_check"))
313 .await
314 .unwrap();
315 assert!(matches!(rb.input.body, Body::Empty));
316 }
317
318 #[tokio::test]
319 async fn test_bucket_resets_after_completion() {
320 let mut svc = AggregatorService::new(config_size(2));
321 svc.ready()
322 .await
323 .unwrap()
324 .call(make_exchange("orderId", "A", "x"))
325 .await
326 .unwrap();
327 svc.ready()
328 .await
329 .unwrap()
330 .call(make_exchange("orderId", "A", "x"))
331 .await
332 .unwrap(); let r = svc
335 .ready()
336 .await
337 .unwrap()
338 .call(make_exchange("orderId", "A", "new"))
339 .await
340 .unwrap();
341 assert!(matches!(r.input.body, Body::Empty)); }
343
344 #[tokio::test]
345 async fn test_completion_size_1_emits_immediately() {
346 let mut svc = AggregatorService::new(config_size(1));
347 let ex = make_exchange("orderId", "A", "solo");
348 let result = svc.ready().await.unwrap().call(ex).await.unwrap();
349 assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
350 }
351
352 #[tokio::test]
353 async fn test_custom_aggregation_strategy() {
354 use camel_api::aggregator::AggregationFn;
355 use std::sync::Arc;
356
357 let f: AggregationFn = Arc::new(|mut acc: Exchange, next: Exchange| {
358 let combined = format!(
359 "{}+{}",
360 acc.input.body.as_text().unwrap_or(""),
361 next.input.body.as_text().unwrap_or("")
362 );
363 acc.input.body = Body::Text(combined);
364 acc
365 });
366 let config = AggregatorConfig::correlate_by("key")
367 .complete_when_size(2)
368 .strategy(AggregationStrategy::Custom(f))
369 .build();
370 let mut svc = AggregatorService::new(config);
371 svc.ready()
372 .await
373 .unwrap()
374 .call(make_exchange("key", "X", "hello"))
375 .await
376 .unwrap();
377 let result = svc
378 .ready()
379 .await
380 .unwrap()
381 .call(make_exchange("key", "X", "world"))
382 .await
383 .unwrap();
384 assert_eq!(result.input.body.as_text(), Some("hello+world"));
385 }
386
387 #[tokio::test]
388 async fn test_completion_predicate() {
389 let config = AggregatorConfig::correlate_by("key")
390 .complete_when(|bucket| {
391 bucket
392 .iter()
393 .any(|e| e.input.body.as_text() == Some("DONE"))
394 })
395 .build();
396 let mut svc = AggregatorService::new(config);
397 svc.ready()
398 .await
399 .unwrap()
400 .call(make_exchange("key", "K", "first"))
401 .await
402 .unwrap();
403 svc.ready()
404 .await
405 .unwrap()
406 .call(make_exchange("key", "K", "second"))
407 .await
408 .unwrap();
409 let result = svc
410 .ready()
411 .await
412 .unwrap()
413 .call(make_exchange("key", "K", "DONE"))
414 .await
415 .unwrap();
416 assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
417 }
418
419 #[tokio::test]
420 async fn test_missing_header_returns_error() {
421 let mut svc = AggregatorService::new(config_size(2));
422 let msg = Message {
423 headers: Default::default(),
424 body: Body::Text("no key".into()),
425 };
426 let ex = Exchange::new(msg);
427 let result = svc.ready().await.unwrap().call(ex).await;
428 assert!(result.is_err());
429 assert!(matches!(
430 result.unwrap_err(),
431 camel_api::CamelError::ProcessorError(_)
432 ));
433 }
434
435 #[tokio::test]
436 async fn test_cloned_service_shares_state() {
437 let svc1 = AggregatorService::new(config_size(2));
438 let mut svc2 = svc1.clone();
439 svc1.clone()
441 .ready()
442 .await
443 .unwrap()
444 .call(make_exchange("orderId", "A", "from-svc1"))
445 .await
446 .unwrap();
447 let result = svc2
449 .ready()
450 .await
451 .unwrap()
452 .call(make_exchange("orderId", "A", "from-svc2"))
453 .await
454 .unwrap();
455 assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
456 }
457
458 #[tokio::test]
459 async fn test_camel_aggregated_key_property_set() {
460 let mut svc = AggregatorService::new(config_size(1));
461 let ex = make_exchange("orderId", "ORDER-42", "body");
462 let result = svc.ready().await.unwrap().call(ex).await.unwrap();
463 assert_eq!(
464 result.property(CAMEL_AGGREGATED_KEY),
465 Some(&serde_json::json!("ORDER-42"))
466 );
467 }
468
469 #[tokio::test]
470 async fn test_aggregator_enforces_max_buckets() {
471 let config = AggregatorConfig::correlate_by("orderId")
472 .complete_when_size(2)
473 .max_buckets(3)
474 .build();
475
476 let mut svc = AggregatorService::new(config);
477
478 for i in 0..3 {
480 let ex = make_exchange("orderId", &format!("key-{}", i), "body");
481 let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
482 }
483
484 let ex = make_exchange("orderId", "key-4", "body");
486 let result = svc.ready().await.unwrap().call(ex).await;
487
488 assert!(result.is_err(), "Should reject when max buckets reached");
489 let err = result.unwrap_err().to_string();
490 assert!(
491 err.contains("maximum"),
492 "Error message should contain 'maximum': {}",
493 err
494 );
495 }
496
497 #[tokio::test]
498 async fn test_max_buckets_allows_existing_key() {
499 let config = AggregatorConfig::correlate_by("orderId")
500 .complete_when_size(5) .max_buckets(2)
502 .build();
503
504 let mut svc = AggregatorService::new(config);
505
506 let ex1 = make_exchange("orderId", "key-A", "body1");
508 let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
509 let ex2 = make_exchange("orderId", "key-B", "body2");
510 let _ = svc.ready().await.unwrap().call(ex2).await.unwrap();
511
512 let ex3 = make_exchange("orderId", "key-A", "body3");
514 let result = svc.ready().await.unwrap().call(ex3).await;
515 assert!(
516 result.is_ok(),
517 "Should allow adding to existing bucket even at max limit"
518 );
519 }
520
521 #[tokio::test]
522 async fn test_bucket_ttl_eviction() {
523 let config = AggregatorConfig::correlate_by("orderId")
524 .complete_when_size(10) .bucket_ttl(Duration::from_millis(50))
526 .build();
527
528 let mut svc = AggregatorService::new(config);
529
530 let ex1 = make_exchange("orderId", "key-A", "body1");
532 let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
533
534 tokio::time::sleep(Duration::from_millis(100)).await;
536
537 let ex2 = make_exchange("orderId", "key-B", "body2");
539 let _ = svc.ready().await.unwrap().call(ex2).await.unwrap();
540
541 let ex3 = make_exchange("orderId", "key-A", "body3");
544 let result = svc.ready().await.unwrap().call(ex3).await;
545 assert!(result.is_ok(), "Should be able to recreate evicted bucket");
546 }
547}