1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use tower::Service;
9
10use camel_api::{
11 CamelError,
12 aggregator::{AggregationStrategy, AggregatorConfig, CompletionCondition},
13 body::Body,
14 exchange::Exchange,
15 message::Message,
16};
17
/// Property set to `true` on the empty placeholder exchange the aggregator
/// returns while a correlation bucket is still incomplete.
pub const CAMEL_AGGREGATOR_PENDING: &str = "CamelAggregatorPending";

/// Property on an aggregated exchange holding how many exchanges were merged.
pub const CAMEL_AGGREGATED_SIZE: &str = "CamelAggregatedSize";

/// Property on an aggregated exchange holding the correlation header value
/// whose bucket just completed.
pub const CAMEL_AGGREGATED_KEY: &str = "CamelAggregatedKey";
21
22struct Bucket {
24 exchanges: Vec<Exchange>,
25 #[allow(dead_code)]
26 created_at: Instant,
27 last_updated: Instant,
28}
29
30impl Bucket {
31 fn new() -> Self {
32 let now = Instant::now();
33 Self {
34 exchanges: Vec::new(),
35 created_at: now,
36 last_updated: now,
37 }
38 }
39
40 fn push(&mut self, exchange: Exchange) {
41 self.exchanges.push(exchange);
42 self.last_updated = Instant::now();
43 }
44
45 fn len(&self) -> usize {
46 self.exchanges.len()
47 }
48
49 fn is_expired(&self, ttl: Duration) -> bool {
50 Instant::now().duration_since(self.last_updated) >= ttl
51 }
52}
53
54#[derive(Clone)]
55pub struct AggregatorService {
56 config: AggregatorConfig,
57 buckets: Arc<Mutex<HashMap<String, Bucket>>>,
58}
59
60impl AggregatorService {
61 pub fn new(config: AggregatorConfig) -> Self {
62 Self {
63 config,
64 buckets: Arc::new(Mutex::new(HashMap::new())),
65 }
66 }
67}
68
69impl Service<Exchange> for AggregatorService {
70 type Response = Exchange;
71 type Error = CamelError;
72 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
73
74 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
75 Poll::Ready(Ok(()))
76 }
77
78 fn call(&mut self, exchange: Exchange) -> Self::Future {
79 let config = self.config.clone();
80 let buckets = Arc::clone(&self.buckets);
81
82 Box::pin(async move {
83 let key_value = exchange
85 .input
86 .headers
87 .get(&config.header_name)
88 .cloned()
89 .ok_or_else(|| {
90 CamelError::ProcessorError(format!(
91 "Aggregator: missing correlation key header '{}'",
92 config.header_name
93 ))
94 })?;
95
96 let key_str = serde_json::to_string(&key_value)
98 .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
99
100 let completed_bucket = {
102 let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
103
104 if let Some(ttl) = config.bucket_ttl {
106 guard.retain(|_, bucket| !bucket.is_expired(ttl));
107 }
108
109 if let Some(max) = config.max_buckets
111 && !guard.contains_key(&key_str)
112 && guard.len() >= max
113 {
114 tracing::warn!(
115 max_buckets = max,
116 correlation_key = %key_str,
117 "Aggregator reached max buckets limit, rejecting new correlation key"
118 );
119 return Err(CamelError::ProcessorError(format!(
120 "Aggregator reached maximum {} buckets",
121 max
122 )));
123 }
124
125 let bucket = guard.entry(key_str.clone()).or_insert_with(Bucket::new);
126 bucket.push(exchange);
127
128 let is_complete = match &config.completion {
129 CompletionCondition::Size(n) => bucket.len() >= *n,
130 CompletionCondition::Predicate(pred) => pred(&bucket.exchanges),
131 };
132
133 if is_complete {
134 guard.remove(&key_str).map(|b| b.exchanges)
135 } else {
136 None
137 }
138 }; match completed_bucket {
142 Some(exchanges) => {
143 let size = exchanges.len();
144 let mut result = aggregate(exchanges, &config.strategy)?;
145 result.set_property(CAMEL_AGGREGATED_SIZE, serde_json::json!(size as u64));
146 result.set_property(CAMEL_AGGREGATED_KEY, key_value);
147 Ok(result)
148 }
149 None => {
150 let mut pending = Exchange::new(Message {
151 headers: Default::default(),
152 body: Body::Empty,
153 });
154 pending.set_property(CAMEL_AGGREGATOR_PENDING, serde_json::json!(true));
155 Ok(pending)
156 }
157 }
158 })
159 }
160}
161
162fn aggregate(
163 exchanges: Vec<Exchange>,
164 strategy: &AggregationStrategy,
165) -> Result<Exchange, CamelError> {
166 match strategy {
167 AggregationStrategy::CollectAll => {
168 let bodies: Vec<serde_json::Value> = exchanges
169 .into_iter()
170 .map(|e| match e.input.body {
171 Body::Json(v) => v,
172 Body::Text(s) => serde_json::Value::String(s),
173 Body::Xml(s) => serde_json::Value::String(s),
174 Body::Bytes(b) => {
175 serde_json::Value::String(String::from_utf8_lossy(&b).into_owned())
176 }
177 Body::Empty => serde_json::Value::Null,
178 Body::Stream(s) => serde_json::json!({
179 "_stream": {
180 "origin": s.metadata.origin,
181 "placeholder": true,
182 "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
183 }
184 }),
185 })
186 .collect();
187 Ok(Exchange::new(Message {
188 headers: Default::default(),
189 body: Body::Json(serde_json::Value::Array(bodies)),
190 }))
191 }
192 AggregationStrategy::Custom(f) => {
193 let mut iter = exchanges.into_iter();
194 let first = iter.next().ok_or_else(|| {
195 CamelError::ProcessorError("Aggregator: empty bucket".to_string())
196 })?;
197 Ok(iter.fold(first, |acc, next| f(acc, next)))
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{
        aggregator::{AggregationStrategy, AggregatorConfig},
        body::Body,
        exchange::Exchange,
        message::Message,
    };
    use tower::ServiceExt;

    /// Build a text exchange whose `header` carries `value` as a JSON string.
    fn make_exchange(header: &str, value: &str, body: &str) -> Exchange {
        let mut msg = Message {
            headers: Default::default(),
            body: Body::Text(body.to_string()),
        };
        msg.headers
            .insert(header.to_string(), serde_json::json!(value));
        Exchange::new(msg)
    }

    /// Correlate on `orderId` and complete once a bucket holds `n` exchanges.
    fn config_size(n: usize) -> AggregatorConfig {
        AggregatorConfig::correlate_by("orderId")
            .complete_when_size(n)
            .build()
    }

    /// Wait for readiness, then push one exchange through the service.
    async fn send(svc: &mut AggregatorService, ex: Exchange) -> Result<Exchange, CamelError> {
        svc.ready().await?.call(ex).await
    }

    #[tokio::test]
    async fn test_pending_exchange_not_yet_complete() {
        let mut svc = AggregatorService::new(config_size(3));
        let result = send(&mut svc, make_exchange("orderId", "A", "first"))
            .await
            .unwrap();
        assert!(matches!(result.input.body, Body::Empty));
        assert_eq!(
            result.property(CAMEL_AGGREGATOR_PENDING),
            Some(&serde_json::json!(true))
        );
    }

    #[tokio::test]
    async fn test_completes_on_size() {
        let mut svc = AggregatorService::new(config_size(3));
        for _ in 0..2 {
            let r = send(&mut svc, make_exchange("orderId", "A", "item"))
                .await
                .unwrap();
            assert!(matches!(r.input.body, Body::Empty));
        }
        let result = send(&mut svc, make_exchange("orderId", "A", "last"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
        assert_eq!(
            result.property(CAMEL_AGGREGATED_SIZE),
            Some(&serde_json::json!(3u64))
        );
    }

    #[tokio::test]
    async fn test_collect_all_produces_json_array() {
        let mut svc = AggregatorService::new(config_size(2));
        send(&mut svc, make_exchange("orderId", "A", "alpha"))
            .await
            .unwrap();
        let result = send(&mut svc, make_exchange("orderId", "A", "beta"))
            .await
            .unwrap();
        let Body::Json(v) = &result.input.body else {
            panic!("expected Body::Json")
        };
        let arr = v.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0], serde_json::json!("alpha"));
        assert_eq!(arr[1], serde_json::json!("beta"));
    }

    #[tokio::test]
    async fn test_two_keys_independent_buckets() {
        let mut svc = AggregatorService::new(config_size(3));
        send(&mut svc, make_exchange("orderId", "A", "a1")).await.unwrap();
        send(&mut svc, make_exchange("orderId", "B", "b1")).await.unwrap();
        send(&mut svc, make_exchange("orderId", "A", "a2")).await.unwrap();

        // Third "A" completes A's bucket...
        let ra = send(&mut svc, make_exchange("orderId", "A", "a3"))
            .await
            .unwrap();
        assert!(matches!(ra.input.body, Body::Json(_)));

        // ...while B only holds two exchanges and stays pending.
        let rb = send(&mut svc, make_exchange("orderId", "B", "b_check"))
            .await
            .unwrap();
        assert!(matches!(rb.input.body, Body::Empty));
    }

    #[tokio::test]
    async fn test_bucket_resets_after_completion() {
        let mut svc = AggregatorService::new(config_size(2));
        send(&mut svc, make_exchange("orderId", "A", "x")).await.unwrap();
        let done = send(&mut svc, make_exchange("orderId", "A", "x"))
            .await
            .unwrap();
        assert!(done.property(CAMEL_AGGREGATOR_PENDING).is_none());

        // The completed bucket was removed, so the next "A" starts a new one.
        let r = send(&mut svc, make_exchange("orderId", "A", "new"))
            .await
            .unwrap();
        assert!(matches!(r.input.body, Body::Empty));
    }

    #[tokio::test]
    async fn test_completion_size_1_emits_immediately() {
        let mut svc = AggregatorService::new(config_size(1));
        let result = send(&mut svc, make_exchange("orderId", "A", "solo"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_custom_aggregation_strategy() {
        use camel_api::aggregator::AggregationFn;

        let f: AggregationFn = Arc::new(|mut acc: Exchange, next: Exchange| {
            let combined = format!(
                "{}+{}",
                acc.input.body.as_text().unwrap_or(""),
                next.input.body.as_text().unwrap_or("")
            );
            acc.input.body = Body::Text(combined);
            acc
        });
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(2)
            .strategy(AggregationStrategy::Custom(f))
            .build();
        let mut svc = AggregatorService::new(config);
        send(&mut svc, make_exchange("key", "X", "hello")).await.unwrap();
        let result = send(&mut svc, make_exchange("key", "X", "world"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("hello+world"));
    }

    #[tokio::test]
    async fn test_completion_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| {
                bucket
                    .iter()
                    .any(|e| e.input.body.as_text() == Some("DONE"))
            })
            .build();
        let mut svc = AggregatorService::new(config);
        send(&mut svc, make_exchange("key", "K", "first")).await.unwrap();
        send(&mut svc, make_exchange("key", "K", "second")).await.unwrap();
        let result = send(&mut svc, make_exchange("key", "K", "DONE"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_missing_header_returns_error() {
        let mut svc = AggregatorService::new(config_size(2));
        let ex = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Text("no key".into()),
        });
        let result = send(&mut svc, ex).await;
        assert!(matches!(result, Err(CamelError::ProcessorError(_))));
    }

    #[tokio::test]
    async fn test_cloned_service_shares_state() {
        let svc1 = AggregatorService::new(config_size(2));
        let mut svc2 = svc1.clone();
        send(&mut svc1.clone(), make_exchange("orderId", "A", "from-svc1"))
            .await
            .unwrap();
        let result = send(&mut svc2, make_exchange("orderId", "A", "from-svc2"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_camel_aggregated_key_property_set() {
        let mut svc = AggregatorService::new(config_size(1));
        let result = send(&mut svc, make_exchange("orderId", "ORDER-42", "body"))
            .await
            .unwrap();
        assert_eq!(
            result.property(CAMEL_AGGREGATED_KEY),
            Some(&serde_json::json!("ORDER-42"))
        );
    }

    #[tokio::test]
    async fn test_aggregator_enforces_max_buckets() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(2)
            .max_buckets(3)
            .build();
        let mut svc = AggregatorService::new(config);

        for i in 0..3 {
            send(&mut svc, make_exchange("orderId", &format!("key-{i}"), "body"))
                .await
                .unwrap();
        }

        let err = send(&mut svc, make_exchange("orderId", "key-4", "body"))
            .await
            .expect_err("Should reject when max buckets reached")
            .to_string();
        assert!(
            err.contains("maximum"),
            "Error message should contain 'maximum': {err}"
        );
    }

    #[tokio::test]
    async fn test_max_buckets_allows_existing_key() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(5)
            .max_buckets(2)
            .build();
        let mut svc = AggregatorService::new(config);

        send(&mut svc, make_exchange("orderId", "key-A", "body1"))
            .await
            .unwrap();
        send(&mut svc, make_exchange("orderId", "key-B", "body2"))
            .await
            .unwrap();

        let result = send(&mut svc, make_exchange("orderId", "key-A", "body3")).await;
        assert!(
            result.is_ok(),
            "Should allow adding to existing bucket even at max limit"
        );
    }

    #[tokio::test]
    async fn test_bucket_ttl_eviction() {
        // Two properties of the config make eviction observable:
        // - max_buckets(1): a second key is admitted only if the first bucket
        //   was evicted;
        // - size 2: a re-sent key stays pending only if its old bucket is gone.
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(2)
            .max_buckets(1)
            .bucket_ttl(Duration::from_millis(50))
            .build();
        let mut svc = AggregatorService::new(config);

        send(&mut svc, make_exchange("orderId", "key-A", "body1"))
            .await
            .unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;

        let rb = send(&mut svc, make_exchange("orderId", "key-B", "body2"))
            .await
            .expect("expired key-A bucket should have been evicted");
        assert!(rb.property(CAMEL_AGGREGATOR_PENDING).is_some());

        tokio::time::sleep(Duration::from_millis(100)).await;

        let ra = send(&mut svc, make_exchange("orderId", "key-A", "body3"))
            .await
            .expect("Should be able to recreate evicted bucket");
        assert_eq!(
            ra.property(CAMEL_AGGREGATOR_PENDING),
            Some(&serde_json::json!(true)),
            "recreated bucket must start empty, not complete with the stale exchange"
        );
    }

    /// Exchange whose body is an unread single-chunk stream with `origin`.
    fn stream_exchange(origin: Option<&str>) -> Exchange {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(tokio::sync::Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: origin.map(str::to_string),
                ..Default::default()
            },
        };
        Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        })
    }

    /// Run `CollectAll` over one stream exchange and check the shared shape of
    /// the result. Returns the placeholder element for further checks.
    fn collect_single_stream(origin: Option<&str>) -> serde_json::Value {
        let exchange = aggregate(vec![stream_exchange(origin)], &AggregationStrategy::CollectAll)
            .expect("Expected Ok result");
        let Body::Json(value) = exchange.input.body else {
            panic!("Expected Json body");
        };

        // Round-trip through text to prove the placeholder serializes cleanly.
        let json_str = serde_json::to_string(&value).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

        let arr = parsed.as_array().expect("Result should be an array");
        assert!(arr[0].is_object(), "First element should be an object");
        assert!(
            arr[0]["_stream"].is_object(),
            "Should contain _stream object"
        );
        assert_eq!(
            arr[0]["_stream"]["placeholder"], true,
            "placeholder flag should be true"
        );
        arr[0].clone()
    }

    #[tokio::test]
    async fn test_aggregate_stream_bodies_creates_valid_json() {
        let placeholder = collect_single_stream(Some("file:///test.txt"));
        assert_eq!(placeholder["_stream"]["origin"], "file:///test.txt");
    }

    #[tokio::test]
    async fn test_aggregate_stream_bodies_with_none_origin() {
        let placeholder = collect_single_stream(None);
        assert_eq!(
            placeholder["_stream"]["origin"],
            serde_json::Value::Null,
            "origin should be null when None"
        );
    }
}