Skip to main content

camel_processor/
aggregator.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use tower::Service;
9
10use camel_api::{
11    CamelError,
12    aggregator::{AggregationStrategy, AggregatorConfig, CompletionCondition},
13    body::Body,
14    exchange::Exchange,
15    message::Message,
16};
17
/// Exchange property set to `true` on the empty placeholder exchange that the
/// aggregator returns while a bucket is still incomplete.
pub const CAMEL_AGGREGATOR_PENDING: &str = "CamelAggregatorPending";
/// Exchange property set on an aggregated exchange: the number of exchanges
/// that were combined to produce it.
pub const CAMEL_AGGREGATED_SIZE: &str = "CamelAggregatedSize";
/// Exchange property set on an aggregated exchange: the correlation key value
/// (as read from the correlation header) of the completed bucket.
pub const CAMEL_AGGREGATED_KEY: &str = "CamelAggregatedKey";
21
22/// Internal bucket structure with timestamp tracking for TTL eviction.
23struct Bucket {
24    exchanges: Vec<Exchange>,
25    #[allow(dead_code)]
26    created_at: Instant,
27    last_updated: Instant,
28}
29
30impl Bucket {
31    fn new() -> Self {
32        let now = Instant::now();
33        Self {
34            exchanges: Vec::new(),
35            created_at: now,
36            last_updated: now,
37        }
38    }
39
40    fn push(&mut self, exchange: Exchange) {
41        self.exchanges.push(exchange);
42        self.last_updated = Instant::now();
43    }
44
45    fn len(&self) -> usize {
46        self.exchanges.len()
47    }
48
49    fn is_expired(&self, ttl: Duration) -> bool {
50        Instant::now().duration_since(self.last_updated) >= ttl
51    }
52}
53
54#[derive(Clone)]
55pub struct AggregatorService {
56    config: AggregatorConfig,
57    buckets: Arc<Mutex<HashMap<String, Bucket>>>,
58}
59
60impl AggregatorService {
61    pub fn new(config: AggregatorConfig) -> Self {
62        Self {
63            config,
64            buckets: Arc::new(Mutex::new(HashMap::new())),
65        }
66    }
67}
68
69impl Service<Exchange> for AggregatorService {
70    type Response = Exchange;
71    type Error = CamelError;
72    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
73
74    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
75        Poll::Ready(Ok(()))
76    }
77
78    fn call(&mut self, exchange: Exchange) -> Self::Future {
79        let config = self.config.clone();
80        let buckets = Arc::clone(&self.buckets);
81
82        Box::pin(async move {
83            // 1. Extract correlation key value from header
84            let key_value = exchange
85                .input
86                .headers
87                .get(&config.header_name)
88                .cloned()
89                .ok_or_else(|| {
90                    CamelError::ProcessorError(format!(
91                        "Aggregator: missing correlation key header '{}'",
92                        config.header_name
93                    ))
94                })?;
95
96            // Serialize to String for use as HashMap key
97            let key_str = serde_json::to_string(&key_value)
98                .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
99
100            // 2. Insert into bucket and check completion (lock scope)
101            let completed_bucket = {
102                let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
103
104                // Evict expired buckets if TTL is configured
105                if let Some(ttl) = config.bucket_ttl {
106                    guard.retain(|_, bucket| !bucket.is_expired(ttl));
107                }
108
109                // Enforce max buckets limit - reject new correlation keys if at limit
110                if let Some(max) = config.max_buckets
111                    && !guard.contains_key(&key_str)
112                    && guard.len() >= max
113                {
114                    tracing::warn!(
115                        max_buckets = max,
116                        correlation_key = %key_str,
117                        "Aggregator reached max buckets limit, rejecting new correlation key"
118                    );
119                    return Err(CamelError::ProcessorError(format!(
120                        "Aggregator reached maximum {} buckets",
121                        max
122                    )));
123                }
124
125                let bucket = guard.entry(key_str.clone()).or_insert_with(Bucket::new);
126                bucket.push(exchange);
127
128                let is_complete = match &config.completion {
129                    CompletionCondition::Size(n) => bucket.len() >= *n,
130                    CompletionCondition::Predicate(pred) => pred(&bucket.exchanges),
131                };
132
133                if is_complete {
134                    guard.remove(&key_str).map(|b| b.exchanges)
135                } else {
136                    None
137                }
138            }; // Mutex released here
139
140            // 3. Emit aggregated exchange or return pending placeholder
141            match completed_bucket {
142                Some(exchanges) => {
143                    let size = exchanges.len();
144                    let mut result = aggregate(exchanges, &config.strategy)?;
145                    result.set_property(CAMEL_AGGREGATED_SIZE, serde_json::json!(size as u64));
146                    result.set_property(CAMEL_AGGREGATED_KEY, key_value);
147                    Ok(result)
148                }
149                None => {
150                    let mut pending = Exchange::new(Message {
151                        headers: Default::default(),
152                        body: Body::Empty,
153                    });
154                    pending.set_property(CAMEL_AGGREGATOR_PENDING, serde_json::json!(true));
155                    Ok(pending)
156                }
157            }
158        })
159    }
160}
161
162fn aggregate(
163    exchanges: Vec<Exchange>,
164    strategy: &AggregationStrategy,
165) -> Result<Exchange, CamelError> {
166    match strategy {
167        AggregationStrategy::CollectAll => {
168            let bodies: Vec<serde_json::Value> = exchanges
169                .into_iter()
170                .map(|e| match e.input.body {
171                    Body::Json(v) => v,
172                    Body::Text(s) => serde_json::Value::String(s),
173                    Body::Xml(s) => serde_json::Value::String(s),
174                    Body::Bytes(b) => {
175                        serde_json::Value::String(String::from_utf8_lossy(&b).into_owned())
176                    }
177                    Body::Empty => serde_json::Value::Null,
178                    Body::Stream(s) => serde_json::json!({
179                        "_stream": {
180                            "origin": s.metadata.origin,
181                            "placeholder": true,
182                            "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
183                        }
184                    }),
185                })
186                .collect();
187            Ok(Exchange::new(Message {
188                headers: Default::default(),
189                body: Body::Json(serde_json::Value::Array(bodies)),
190            }))
191        }
192        AggregationStrategy::Custom(f) => {
193            let mut iter = exchanges.into_iter();
194            let first = iter.next().ok_or_else(|| {
195                CamelError::ProcessorError("Aggregator: empty bucket".to_string())
196            })?;
197            Ok(iter.fold(first, |acc, next| f(acc, next)))
198        }
199    }
200}
201
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{
        aggregator::{AggregationStrategy, AggregatorConfig},
        body::Body,
        exchange::Exchange,
        message::Message,
    };
    use tower::ServiceExt;

    /// Builds a text-bodied exchange carrying a single string header.
    fn make_exchange(header: &str, value: &str, body: &str) -> Exchange {
        let mut msg = Message {
            headers: Default::default(),
            body: Body::Text(body.to_string()),
        };
        msg.headers
            .insert(header.to_string(), serde_json::json!(value));
        Exchange::new(msg)
    }

    /// Config correlating on `orderId` that completes after `n` exchanges.
    fn config_size(n: usize) -> AggregatorConfig {
        AggregatorConfig::correlate_by("orderId")
            .complete_when_size(n)
            .build()
    }

    /// Drives one exchange through the service: waits for readiness (which
    /// must succeed) and returns the outcome of the call itself.
    async fn send(
        svc: &mut AggregatorService,
        ex: Exchange,
    ) -> Result<Exchange, camel_api::CamelError> {
        svc.ready()
            .await
            .expect("aggregator should always be ready")
            .call(ex)
            .await
    }

    /// Aggregates one stream-bodied exchange with `CollectAll`, round-trips
    /// the result through JSON text, checks the common placeholder shape and
    /// returns the `_stream` object for origin-specific assertions.
    fn stream_placeholder(origin: Option<String>) -> serde_json::Value {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin,
                ..Default::default()
            },
        };
        let exchange = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        });

        let aggregated = aggregate(vec![exchange], &AggregationStrategy::CollectAll)
            .expect("Expected Ok result");
        let Body::Json(value) = aggregated.input.body else {
            panic!("Expected Json body")
        };

        // Round trip proves the placeholder is serializable, valid JSON.
        let json_str = serde_json::to_string(&value).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();

        assert!(parsed.is_array(), "Result should be an array");
        let arr = parsed.as_array().unwrap();
        assert!(arr[0].is_object(), "First element should be an object");
        assert!(
            arr[0]["_stream"].is_object(),
            "Should contain _stream object"
        );
        assert_eq!(
            arr[0]["_stream"]["placeholder"], true,
            "placeholder flag should be true"
        );
        arr[0]["_stream"].clone()
    }

    #[tokio::test]
    async fn test_pending_exchange_not_yet_complete() {
        let mut svc = AggregatorService::new(config_size(3));
        let result = send(&mut svc, make_exchange("orderId", "A", "first"))
            .await
            .unwrap();
        assert!(matches!(result.input.body, Body::Empty));
        assert_eq!(
            result.property(CAMEL_AGGREGATOR_PENDING),
            Some(&serde_json::json!(true))
        );
    }

    #[tokio::test]
    async fn test_completes_on_size() {
        let mut svc = AggregatorService::new(config_size(3));
        for _ in 0..2 {
            let r = send(&mut svc, make_exchange("orderId", "A", "item"))
                .await
                .unwrap();
            assert!(matches!(r.input.body, Body::Empty));
        }
        let result = send(&mut svc, make_exchange("orderId", "A", "last"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
        assert_eq!(
            result.property(CAMEL_AGGREGATED_SIZE),
            Some(&serde_json::json!(3u64))
        );
    }

    #[tokio::test]
    async fn test_collect_all_produces_json_array() {
        let mut svc = AggregatorService::new(config_size(2));
        send(&mut svc, make_exchange("orderId", "A", "alpha"))
            .await
            .unwrap();
        let result = send(&mut svc, make_exchange("orderId", "A", "beta"))
            .await
            .unwrap();
        let Body::Json(v) = &result.input.body else {
            panic!("expected Body::Json")
        };
        let arr = v.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0], serde_json::json!("alpha"));
        assert_eq!(arr[1], serde_json::json!("beta"));
    }

    #[tokio::test]
    async fn test_two_keys_independent_buckets() {
        // completionSize=3 so we can test that A and B accumulate independently.
        let mut svc = AggregatorService::new(config_size(3));
        send(&mut svc, make_exchange("orderId", "A", "a1"))
            .await
            .unwrap();
        send(&mut svc, make_exchange("orderId", "B", "b1"))
            .await
            .unwrap();
        send(&mut svc, make_exchange("orderId", "A", "a2"))
            .await
            .unwrap();
        // A has 2 items, B has 1 item — neither complete yet
        let ra = send(&mut svc, make_exchange("orderId", "A", "a3"))
            .await
            .unwrap();
        // A now has 3 → completes
        assert!(matches!(ra.input.body, Body::Json(_)));
        // B only has 1 → still pending
        let rb = send(&mut svc, make_exchange("orderId", "B", "b_check"))
            .await
            .unwrap();
        assert!(matches!(rb.input.body, Body::Empty));
    }

    #[tokio::test]
    async fn test_bucket_resets_after_completion() {
        let mut svc = AggregatorService::new(config_size(2));
        send(&mut svc, make_exchange("orderId", "A", "x"))
            .await
            .unwrap();
        send(&mut svc, make_exchange("orderId", "A", "x"))
            .await
            .unwrap(); // completes
        // New bucket starts
        let r = send(&mut svc, make_exchange("orderId", "A", "new"))
            .await
            .unwrap();
        assert!(matches!(r.input.body, Body::Empty)); // pending again
    }

    #[tokio::test]
    async fn test_completion_size_1_emits_immediately() {
        let mut svc = AggregatorService::new(config_size(1));
        let result = send(&mut svc, make_exchange("orderId", "A", "solo"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_custom_aggregation_strategy() {
        use camel_api::aggregator::AggregationFn;
        use std::sync::Arc;

        let f: AggregationFn = Arc::new(|mut acc: Exchange, next: Exchange| {
            let combined = format!(
                "{}+{}",
                acc.input.body.as_text().unwrap_or(""),
                next.input.body.as_text().unwrap_or("")
            );
            acc.input.body = Body::Text(combined);
            acc
        });
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(2)
            .strategy(AggregationStrategy::Custom(f))
            .build();
        let mut svc = AggregatorService::new(config);
        send(&mut svc, make_exchange("key", "X", "hello"))
            .await
            .unwrap();
        let result = send(&mut svc, make_exchange("key", "X", "world"))
            .await
            .unwrap();
        assert_eq!(result.input.body.as_text(), Some("hello+world"));
    }

    #[tokio::test]
    async fn test_completion_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| {
                bucket
                    .iter()
                    .any(|e| e.input.body.as_text() == Some("DONE"))
            })
            .build();
        let mut svc = AggregatorService::new(config);
        send(&mut svc, make_exchange("key", "K", "first"))
            .await
            .unwrap();
        send(&mut svc, make_exchange("key", "K", "second"))
            .await
            .unwrap();
        let result = send(&mut svc, make_exchange("key", "K", "DONE"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_missing_header_returns_error() {
        let mut svc = AggregatorService::new(config_size(2));
        let ex = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Text("no key".into()),
        });
        let result = send(&mut svc, ex).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            camel_api::CamelError::ProcessorError(_)
        ));
    }

    #[tokio::test]
    async fn test_cloned_service_shares_state() {
        let mut svc1 = AggregatorService::new(config_size(2));
        let mut svc2 = svc1.clone();
        // send first exchange via svc1
        send(&mut svc1, make_exchange("orderId", "A", "from-svc1"))
            .await
            .unwrap();
        // send second exchange via svc2 — should complete because same Arc<Mutex>
        let result = send(&mut svc2, make_exchange("orderId", "A", "from-svc2"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_camel_aggregated_key_property_set() {
        let mut svc = AggregatorService::new(config_size(1));
        let result = send(&mut svc, make_exchange("orderId", "ORDER-42", "body"))
            .await
            .unwrap();
        assert_eq!(
            result.property(CAMEL_AGGREGATED_KEY),
            Some(&serde_json::json!("ORDER-42"))
        );
    }

    #[tokio::test]
    async fn test_aggregator_enforces_max_buckets() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(2)
            .max_buckets(3)
            .build();

        let mut svc = AggregatorService::new(config);

        // Create 3 different correlation keys (fills limit)
        for i in 0..3 {
            let ex = make_exchange("orderId", &format!("key-{}", i), "body");
            let _ = send(&mut svc, ex).await.unwrap();
        }

        // 4th key should be rejected
        let result = send(&mut svc, make_exchange("orderId", "key-4", "body")).await;

        assert!(result.is_err(), "Should reject when max buckets reached");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("maximum"),
            "Error message should contain 'maximum': {}",
            err
        );
    }

    #[tokio::test]
    async fn test_max_buckets_allows_existing_key() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(5) // Large size so bucket doesn't complete
            .max_buckets(2)
            .build();

        let mut svc = AggregatorService::new(config);

        // Create 2 different correlation keys (fills limit)
        let _ = send(&mut svc, make_exchange("orderId", "key-A", "body1"))
            .await
            .unwrap();
        let _ = send(&mut svc, make_exchange("orderId", "key-B", "body2"))
            .await
            .unwrap();

        // Should still allow adding to existing key
        let result = send(&mut svc, make_exchange("orderId", "key-A", "body3")).await;
        assert!(
            result.is_ok(),
            "Should allow adding to existing bucket even at max limit"
        );
    }

    #[tokio::test]
    async fn test_bucket_ttl_eviction() {
        // Size 2 makes eviction observable: a second key-A exchange completes
        // the bucket only if the first one is still there.
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(2)
            .bucket_ttl(Duration::from_millis(50))
            .build();

        let mut svc = AggregatorService::new(config);

        // Open a key-A bucket holding one exchange
        let _ = send(&mut svc, make_exchange("orderId", "key-A", "body1"))
            .await
            .unwrap();

        // Wait for TTL to expire
        tokio::time::sleep(Duration::from_millis(100)).await;

        // A call for another key - this should trigger eviction of the old one
        let _ = send(&mut svc, make_exchange("orderId", "key-B", "body2"))
            .await
            .unwrap();

        // With the stale bucket evicted, key-A starts over with a single
        // exchange and must still be pending. Had the old bucket survived,
        // this exchange would have completed it.
        let result = send(&mut svc, make_exchange("orderId", "key-A", "body3"))
            .await
            .expect("Should be able to recreate evicted bucket");
        assert_eq!(
            result.property(CAMEL_AGGREGATOR_PENDING),
            Some(&serde_json::json!(true)),
            "Expired bucket should have been evicted, leaving key-A pending"
        );
    }

    #[tokio::test]
    async fn test_aggregate_stream_bodies_creates_valid_json() {
        let placeholder = stream_placeholder(Some("file:///test.txt".to_string()));
        assert_eq!(placeholder["origin"], "file:///test.txt");
    }

    #[tokio::test]
    async fn test_aggregate_stream_bodies_with_none_origin() {
        let placeholder = stream_placeholder(None);
        assert_eq!(
            placeholder["origin"],
            serde_json::Value::Null,
            "origin should be null when None"
        );
    }
}