1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use tower::Service;
9
10use camel_api::{
11 CamelError,
12 aggregator::{AggregationStrategy, AggregatorConfig, CompletionCondition},
13 body::Body,
14 exchange::Exchange,
15 message::Message,
16};
17
/// Exchange property set to `true` on the placeholder exchange returned while
/// the correlation bucket is still incomplete.
pub const CAMEL_AGGREGATOR_PENDING: &str = "CamelAggregatorPending";
/// Exchange property holding the number of exchanges merged into an aggregated result.
pub const CAMEL_AGGREGATED_SIZE: &str = "CamelAggregatedSize";
/// Exchange property holding the correlation-key header value of an aggregated result.
pub const CAMEL_AGGREGATED_KEY: &str = "CamelAggregatedKey";
21
22struct Bucket {
24 exchanges: Vec<Exchange>,
25 #[allow(dead_code)]
26 created_at: Instant,
27 last_updated: Instant,
28}
29
30impl Bucket {
31 fn new() -> Self {
32 let now = Instant::now();
33 Self {
34 exchanges: Vec::new(),
35 created_at: now,
36 last_updated: now,
37 }
38 }
39
40 fn push(&mut self, exchange: Exchange) {
41 self.exchanges.push(exchange);
42 self.last_updated = Instant::now();
43 }
44
45 fn len(&self) -> usize {
46 self.exchanges.len()
47 }
48
49 fn is_expired(&self, ttl: Duration) -> bool {
50 Instant::now().duration_since(self.last_updated) >= ttl
51 }
52}
53
54#[derive(Clone)]
55pub struct AggregatorService {
56 config: AggregatorConfig,
57 buckets: Arc<Mutex<HashMap<String, Bucket>>>,
58}
59
60impl AggregatorService {
61 pub fn new(config: AggregatorConfig) -> Self {
62 Self {
63 config,
64 buckets: Arc::new(Mutex::new(HashMap::new())),
65 }
66 }
67}
68
69impl Service<Exchange> for AggregatorService {
70 type Response = Exchange;
71 type Error = CamelError;
72 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
73
74 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
75 Poll::Ready(Ok(()))
76 }
77
78 fn call(&mut self, exchange: Exchange) -> Self::Future {
79 let config = self.config.clone();
80 let buckets = Arc::clone(&self.buckets);
81
82 Box::pin(async move {
83 let key_value = exchange
85 .input
86 .headers
87 .get(&config.header_name)
88 .cloned()
89 .ok_or_else(|| {
90 CamelError::ProcessorError(format!(
91 "Aggregator: missing correlation key header '{}'",
92 config.header_name
93 ))
94 })?;
95
96 let key_str = serde_json::to_string(&key_value)
98 .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
99
100 let completed_bucket = {
102 let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
103
104 if let Some(ttl) = config.bucket_ttl {
106 guard.retain(|_, bucket| !bucket.is_expired(ttl));
107 }
108
109 if let Some(max) = config.max_buckets
111 && !guard.contains_key(&key_str)
112 && guard.len() >= max
113 {
114 tracing::warn!(
115 max_buckets = max,
116 correlation_key = %key_str,
117 "Aggregator reached max buckets limit, rejecting new correlation key"
118 );
119 return Err(CamelError::ProcessorError(format!(
120 "Aggregator reached maximum {} buckets",
121 max
122 )));
123 }
124
125 let bucket = guard.entry(key_str.clone()).or_insert_with(Bucket::new);
126 bucket.push(exchange);
127
128 let is_complete = match &config.completion {
129 CompletionCondition::Size(n) => bucket.len() >= *n,
130 CompletionCondition::Predicate(pred) => pred(&bucket.exchanges),
131 };
132
133 if is_complete {
134 guard.remove(&key_str).map(|b| b.exchanges)
135 } else {
136 None
137 }
138 }; match completed_bucket {
142 Some(exchanges) => {
143 let size = exchanges.len();
144 let mut result = aggregate(exchanges, &config.strategy)?;
145 result.set_property(CAMEL_AGGREGATED_SIZE, serde_json::json!(size as u64));
146 result.set_property(CAMEL_AGGREGATED_KEY, key_value);
147 Ok(result)
148 }
149 None => {
150 let mut pending = Exchange::new(Message {
151 headers: Default::default(),
152 body: Body::Empty,
153 });
154 pending.set_property(CAMEL_AGGREGATOR_PENDING, serde_json::json!(true));
155 Ok(pending)
156 }
157 }
158 })
159 }
160}
161
162fn aggregate(
163 exchanges: Vec<Exchange>,
164 strategy: &AggregationStrategy,
165) -> Result<Exchange, CamelError> {
166 match strategy {
167 AggregationStrategy::CollectAll => {
168 let bodies: Vec<serde_json::Value> = exchanges
169 .into_iter()
170 .map(|e| match e.input.body {
171 Body::Json(v) => v,
172 Body::Text(s) => serde_json::Value::String(s),
173 Body::Bytes(b) => {
174 serde_json::Value::String(String::from_utf8_lossy(&b).into_owned())
175 }
176 Body::Empty => serde_json::Value::Null,
177 Body::Stream(s) => serde_json::json!({
178 "_stream": {
179 "origin": s.metadata.origin,
180 "placeholder": true,
181 "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
182 }
183 }),
184 })
185 .collect();
186 Ok(Exchange::new(Message {
187 headers: Default::default(),
188 body: Body::Json(serde_json::Value::Array(bodies)),
189 }))
190 }
191 AggregationStrategy::Custom(f) => {
192 let mut iter = exchanges.into_iter();
193 let first = iter.next().ok_or_else(|| {
194 CamelError::ProcessorError("Aggregator: empty bucket".to_string())
195 })?;
196 Ok(iter.fold(first, |acc, next| f(acc, next)))
197 }
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{
        aggregator::{AggregationStrategy, AggregatorConfig},
        body::Body,
        exchange::Exchange,
        message::Message,
    };
    use tower::ServiceExt;

    /// Builds an exchange with a text body and a single JSON-string header.
    fn make_exchange(header: &str, value: &str, body: &str) -> Exchange {
        let mut msg = Message {
            headers: Default::default(),
            body: Body::Text(body.to_string()),
        };
        msg.headers
            .insert(header.to_string(), serde_json::json!(value));
        Exchange::new(msg)
    }

    /// Config correlating on `orderId` that completes after `n` exchanges.
    fn config_size(n: usize) -> AggregatorConfig {
        AggregatorConfig::correlate_by("orderId")
            .complete_when_size(n)
            .build()
    }

    #[tokio::test]
    async fn test_pending_exchange_not_yet_complete() {
        let mut svc = AggregatorService::new(config_size(3));
        let ex = make_exchange("orderId", "A", "first");
        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert!(matches!(result.input.body, Body::Empty));
        assert_eq!(
            result.property(CAMEL_AGGREGATOR_PENDING),
            Some(&serde_json::json!(true))
        );
    }

    #[tokio::test]
    async fn test_completes_on_size() {
        let mut svc = AggregatorService::new(config_size(3));
        for _ in 0..2 {
            let ex = make_exchange("orderId", "A", "item");
            let r = svc.ready().await.unwrap().call(ex).await.unwrap();
            assert!(matches!(r.input.body, Body::Empty));
        }
        let ex = make_exchange("orderId", "A", "last");
        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
        assert_eq!(
            result.property(CAMEL_AGGREGATED_SIZE),
            Some(&serde_json::json!(3u64))
        );
    }

    #[tokio::test]
    async fn test_collect_all_produces_json_array() {
        let mut svc = AggregatorService::new(config_size(2));
        svc.ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "alpha"))
            .await
            .unwrap();
        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "beta"))
            .await
            .unwrap();
        let Body::Json(v) = &result.input.body else {
            panic!("expected Body::Json")
        };
        let arr = v.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0], serde_json::json!("alpha"));
        assert_eq!(arr[1], serde_json::json!("beta"));
    }

    #[tokio::test]
    async fn test_two_keys_independent_buckets() {
        let mut svc = AggregatorService::new(config_size(3));
        svc.ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "a1"))
            .await
            .unwrap();
        svc.ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "B", "b1"))
            .await
            .unwrap();
        svc.ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "a2"))
            .await
            .unwrap();
        let ra = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "a3"))
            .await
            .unwrap();
        assert!(matches!(ra.input.body, Body::Json(_)));
        let rb = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "B", "b_check"))
            .await
            .unwrap();
        assert!(matches!(rb.input.body, Body::Empty));
    }

    #[tokio::test]
    async fn test_bucket_resets_after_completion() {
        let mut svc = AggregatorService::new(config_size(2));
        svc.ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "x"))
            .await
            .unwrap();
        let completed = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "x"))
            .await
            .unwrap();
        assert!(
            completed.property(CAMEL_AGGREGATOR_PENDING).is_none(),
            "Second exchange should complete the first bucket"
        );
        let r = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "new"))
            .await
            .unwrap();
        assert!(matches!(r.input.body, Body::Empty));
    }

    #[tokio::test]
    async fn test_completion_size_1_emits_immediately() {
        let mut svc = AggregatorService::new(config_size(1));
        let ex = make_exchange("orderId", "A", "solo");
        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_custom_aggregation_strategy() {
        use camel_api::aggregator::AggregationFn;
        use std::sync::Arc;

        let f: AggregationFn = Arc::new(|mut acc: Exchange, next: Exchange| {
            let combined = format!(
                "{}+{}",
                acc.input.body.as_text().unwrap_or(""),
                next.input.body.as_text().unwrap_or("")
            );
            acc.input.body = Body::Text(combined);
            acc
        });
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(2)
            .strategy(AggregationStrategy::Custom(f))
            .build();
        let mut svc = AggregatorService::new(config);
        svc.ready()
            .await
            .unwrap()
            .call(make_exchange("key", "X", "hello"))
            .await
            .unwrap();
        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("key", "X", "world"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("hello+world"));
    }

    #[tokio::test]
    async fn test_completion_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| {
                bucket
                    .iter()
                    .any(|e| e.input.body.as_text() == Some("DONE"))
            })
            .build();
        let mut svc = AggregatorService::new(config);
        svc.ready()
            .await
            .unwrap()
            .call(make_exchange("key", "K", "first"))
            .await
            .unwrap();
        svc.ready()
            .await
            .unwrap()
            .call(make_exchange("key", "K", "second"))
            .await
            .unwrap();
        let result = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("key", "K", "DONE"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_missing_header_returns_error() {
        let mut svc = AggregatorService::new(config_size(2));
        let msg = Message {
            headers: Default::default(),
            body: Body::Text("no key".into()),
        };
        let ex = Exchange::new(msg);
        let result = svc.ready().await.unwrap().call(ex).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            camel_api::CamelError::ProcessorError(_)
        ));
    }

    #[tokio::test]
    async fn test_cloned_service_shares_state() {
        let svc1 = AggregatorService::new(config_size(2));
        let mut svc2 = svc1.clone();
        svc1.clone()
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "from-svc1"))
            .await
            .unwrap();
        let result = svc2
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "A", "from-svc2"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_camel_aggregated_key_property_set() {
        let mut svc = AggregatorService::new(config_size(1));
        let ex = make_exchange("orderId", "ORDER-42", "body");
        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
        assert_eq!(
            result.property(CAMEL_AGGREGATED_KEY),
            Some(&serde_json::json!("ORDER-42"))
        );
    }

    #[tokio::test]
    async fn test_aggregator_enforces_max_buckets() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(2)
            .max_buckets(3)
            .build();

        let mut svc = AggregatorService::new(config);

        for i in 0..3 {
            let ex = make_exchange("orderId", &format!("key-{}", i), "body");
            let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
        }

        let ex = make_exchange("orderId", "key-4", "body");
        let result = svc.ready().await.unwrap().call(ex).await;

        assert!(result.is_err(), "Should reject when max buckets reached");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("maximum"),
            "Error message should contain 'maximum': {}",
            err
        );
    }

    #[tokio::test]
    async fn test_max_buckets_allows_existing_key() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(5)
            .max_buckets(2)
            .build();

        let mut svc = AggregatorService::new(config);

        let ex1 = make_exchange("orderId", "key-A", "body1");
        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
        let ex2 = make_exchange("orderId", "key-B", "body2");
        let _ = svc.ready().await.unwrap().call(ex2).await.unwrap();

        let ex3 = make_exchange("orderId", "key-A", "body3");
        let result = svc.ready().await.unwrap().call(ex3).await;
        assert!(
            result.is_ok(),
            "Should allow adding to existing bucket even at max limit"
        );
    }

    #[tokio::test]
    async fn test_bucket_ttl_eviction() {
        // Size 2 makes eviction observable: if key-A's stale bucket survived,
        // the second key-A exchange would complete it instead of starting over.
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(2)
            .bucket_ttl(Duration::from_millis(50))
            .build();

        let mut svc = AggregatorService::new(config);

        let first = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "key-A", "body1"))
            .await
            .unwrap();
        assert_eq!(
            first.property(CAMEL_AGGREGATOR_PENDING),
            Some(&serde_json::json!(true))
        );

        // Outlive the TTL so key-A's bucket is expired on the next call.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let second = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "key-A", "body2"))
            .await
            .unwrap();
        assert_eq!(
            second.property(CAMEL_AGGREGATOR_PENDING),
            Some(&serde_json::json!(true)),
            "Expired bucket should be evicted, so this exchange starts a fresh bucket"
        );

        // The recreated bucket still completes normally.
        let third = svc
            .ready()
            .await
            .unwrap()
            .call(make_exchange("orderId", "key-A", "body3"))
            .await
            .unwrap();
        assert_eq!(
            third.property(CAMEL_AGGREGATED_SIZE),
            Some(&serde_json::json!(2u64))
        );
    }

    #[tokio::test]
    async fn test_aggregate_stream_bodies_creates_valid_json() {
        use bytes::Bytes;
        use camel_api::{Body, StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: Some("file:///test.txt".to_string()),
                ..Default::default()
            },
        };

        let ex1 = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        });

        let exchanges = vec![ex1];
        let result = aggregate(exchanges, &AggregationStrategy::CollectAll);

        let exchange = result.expect("Expected Ok result");
        assert!(
            matches!(exchange.input.body, Body::Json(_)),
            "Expected Json body"
        );

        if let Body::Json(value) = exchange.input.body {
            let json_str = serde_json::to_string(&value).unwrap();
            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

            assert!(parsed.is_array(), "Result should be an array");
            let arr = parsed.as_array().unwrap();
            assert!(arr[0].is_object(), "First element should be an object");
            assert!(
                arr[0]["_stream"].is_object(),
                "Should contain _stream object"
            );
            assert_eq!(arr[0]["_stream"]["origin"], "file:///test.txt");
            assert_eq!(
                arr[0]["_stream"]["placeholder"], true,
                "placeholder flag should be true"
            );
        }
    }

    #[tokio::test]
    async fn test_aggregate_stream_bodies_with_none_origin() {
        use bytes::Bytes;
        use camel_api::{Body, StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin: None,
                ..Default::default()
            },
        };

        let ex1 = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        });

        let exchanges = vec![ex1];
        let result = aggregate(exchanges, &AggregationStrategy::CollectAll);

        let exchange = result.expect("Expected Ok result");
        assert!(
            matches!(exchange.input.body, Body::Json(_)),
            "Expected Json body"
        );

        if let Body::Json(value) = exchange.input.body {
            let json_str = serde_json::to_string(&value).unwrap();
            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

            assert!(parsed.is_array(), "Result should be an array");
            let arr = parsed.as_array().unwrap();
            assert!(arr[0].is_object(), "First element should be an object");
            assert!(
                arr[0]["_stream"].is_object(),
                "Should contain _stream object"
            );
            assert_eq!(
                arr[0]["_stream"]["origin"],
                serde_json::Value::Null,
                "origin should be null when None"
            );
            assert_eq!(
                arr[0]["_stream"]["placeholder"], true,
                "placeholder flag should be true"
            );
        }
    }
}