Skip to main content

camel_processor/
aggregator.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use tower::Service;
9
10use camel_api::{
11    CamelError,
12    aggregator::{AggregationStrategy, AggregatorConfig, CompletionCondition},
13    body::Body,
14    exchange::Exchange,
15    message::Message,
16};
17
/// Property name set to `true` on the empty placeholder exchange that is returned
/// while a bucket has not yet met its completion condition.
pub const CAMEL_AGGREGATOR_PENDING: &str = "CamelAggregatorPending";
/// Property name under which a completed aggregate records how many exchanges
/// were combined into it.
pub const CAMEL_AGGREGATED_SIZE: &str = "CamelAggregatedSize";
/// Property name under which a completed aggregate records the correlation
/// header value its exchanges shared.
pub const CAMEL_AGGREGATED_KEY: &str = "CamelAggregatedKey";
21
22/// Internal bucket structure with timestamp tracking for TTL eviction.
23struct Bucket {
24    exchanges: Vec<Exchange>,
25    #[allow(dead_code)]
26    created_at: Instant,
27    last_updated: Instant,
28}
29
30impl Bucket {
31    fn new() -> Self {
32        let now = Instant::now();
33        Self {
34            exchanges: Vec::new(),
35            created_at: now,
36            last_updated: now,
37        }
38    }
39
40    fn push(&mut self, exchange: Exchange) {
41        self.exchanges.push(exchange);
42        self.last_updated = Instant::now();
43    }
44
45    fn len(&self) -> usize {
46        self.exchanges.len()
47    }
48
49    fn is_expired(&self, ttl: Duration) -> bool {
50        Instant::now().duration_since(self.last_updated) >= ttl
51    }
52}
53
54#[derive(Clone)]
55pub struct AggregatorService {
56    config: AggregatorConfig,
57    buckets: Arc<Mutex<HashMap<String, Bucket>>>,
58}
59
60impl AggregatorService {
61    pub fn new(config: AggregatorConfig) -> Self {
62        Self {
63            config,
64            buckets: Arc::new(Mutex::new(HashMap::new())),
65        }
66    }
67}
68
69impl Service<Exchange> for AggregatorService {
70    type Response = Exchange;
71    type Error = CamelError;
72    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
73
74    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
75        Poll::Ready(Ok(()))
76    }
77
78    fn call(&mut self, exchange: Exchange) -> Self::Future {
79        let config = self.config.clone();
80        let buckets = Arc::clone(&self.buckets);
81
82        Box::pin(async move {
83            // 1. Extract correlation key value from header
84            let key_value = exchange
85                .input
86                .headers
87                .get(&config.header_name)
88                .cloned()
89                .ok_or_else(|| {
90                    CamelError::ProcessorError(format!(
91                        "Aggregator: missing correlation key header '{}'",
92                        config.header_name
93                    ))
94                })?;
95
96            // Serialize to String for use as HashMap key
97            let key_str = serde_json::to_string(&key_value)
98                .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
99
100            // 2. Insert into bucket and check completion (lock scope)
101            let completed_bucket = {
102                let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
103
104                // Evict expired buckets if TTL is configured
105                if let Some(ttl) = config.bucket_ttl {
106                    guard.retain(|_, bucket| !bucket.is_expired(ttl));
107                }
108
109                // Enforce max buckets limit - reject new correlation keys if at limit
110                if let Some(max) = config.max_buckets
111                    && !guard.contains_key(&key_str)
112                    && guard.len() >= max
113                {
114                    tracing::warn!(
115                        max_buckets = max,
116                        correlation_key = %key_str,
117                        "Aggregator reached max buckets limit, rejecting new correlation key"
118                    );
119                    return Err(CamelError::ProcessorError(format!(
120                        "Aggregator reached maximum {} buckets",
121                        max
122                    )));
123                }
124
125                let bucket = guard.entry(key_str.clone()).or_insert_with(Bucket::new);
126                bucket.push(exchange);
127
128                let is_complete = match &config.completion {
129                    CompletionCondition::Size(n) => bucket.len() >= *n,
130                    CompletionCondition::Predicate(pred) => pred(&bucket.exchanges),
131                };
132
133                if is_complete {
134                    guard.remove(&key_str).map(|b| b.exchanges)
135                } else {
136                    None
137                }
138            }; // Mutex released here
139
140            // 3. Emit aggregated exchange or return pending placeholder
141            match completed_bucket {
142                Some(exchanges) => {
143                    let size = exchanges.len();
144                    let mut result = aggregate(exchanges, &config.strategy)?;
145                    result.set_property(CAMEL_AGGREGATED_SIZE, serde_json::json!(size as u64));
146                    result.set_property(CAMEL_AGGREGATED_KEY, key_value);
147                    Ok(result)
148                }
149                None => {
150                    let mut pending = Exchange::new(Message {
151                        headers: Default::default(),
152                        body: Body::Empty,
153                    });
154                    pending.set_property(CAMEL_AGGREGATOR_PENDING, serde_json::json!(true));
155                    Ok(pending)
156                }
157            }
158        })
159    }
160}
161
162fn aggregate(
163    exchanges: Vec<Exchange>,
164    strategy: &AggregationStrategy,
165) -> Result<Exchange, CamelError> {
166    match strategy {
167        AggregationStrategy::CollectAll => {
168            let bodies: Vec<serde_json::Value> = exchanges
169                .into_iter()
170                .map(|e| match e.input.body {
171                    Body::Json(v) => v,
172                    Body::Text(s) => serde_json::Value::String(s),
173                    Body::Bytes(b) => {
174                        serde_json::Value::String(String::from_utf8_lossy(&b).into_owned())
175                    }
176                    Body::Empty => serde_json::Value::Null,
177                    Body::Stream(s) => serde_json::json!({
178                        "_stream": {
179                            "origin": s.metadata.origin,
180                            "placeholder": true,
181                            "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
182                        }
183                    }),
184                })
185                .collect();
186            Ok(Exchange::new(Message {
187                headers: Default::default(),
188                body: Body::Json(serde_json::Value::Array(bodies)),
189            }))
190        }
191        AggregationStrategy::Custom(f) => {
192            let mut iter = exchanges.into_iter();
193            let first = iter.next().ok_or_else(|| {
194                CamelError::ProcessorError("Aggregator: empty bucket".to_string())
195            })?;
196            Ok(iter.fold(first, |acc, next| f(acc, next)))
197        }
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{
        aggregator::{AggregationStrategy, AggregatorConfig},
        body::Body,
        exchange::Exchange,
        message::Message,
    };
    use tower::ServiceExt;

    /// Builds a text-bodied exchange carrying a single string-valued header.
    fn make_exchange(header: &str, value: &str, body: &str) -> Exchange {
        let mut msg = Message {
            headers: Default::default(),
            body: Body::Text(body.to_string()),
        };
        msg.headers
            .insert(header.to_string(), serde_json::json!(value));
        Exchange::new(msg)
    }

    /// Configuration correlating on `orderId` and completing after `n` exchanges.
    fn config_size(n: usize) -> AggregatorConfig {
        AggregatorConfig::correlate_by("orderId")
            .complete_when_size(n)
            .build()
    }

    /// Waits for readiness, submits `exchange`, and returns the raw outcome.
    async fn try_send(
        svc: &mut AggregatorService,
        exchange: Exchange,
    ) -> Result<Exchange, CamelError> {
        svc.ready().await.unwrap().call(exchange).await
    }

    /// Like [`try_send`], but panics if the service reports an error.
    async fn send(svc: &mut AggregatorService, exchange: Exchange) -> Exchange {
        try_send(svc, exchange).await.unwrap()
    }

    /// True when `exchange` is the placeholder returned for an incomplete bucket.
    fn is_pending(exchange: &Exchange) -> bool {
        exchange.property(CAMEL_AGGREGATOR_PENDING) == Some(&serde_json::json!(true))
    }

    /// Runs `CollectAll` over one stream-bodied exchange and returns the resulting
    /// JSON after a serialize/parse round trip (proving it is valid JSON).
    fn collect_single_stream(origin: Option<String>) -> serde_json::Value {
        use bytes::Bytes;
        use camel_api::{StreamBody, StreamMetadata};
        use futures::stream;
        use tokio::sync::Mutex;

        let chunks = vec![Ok(Bytes::from("test"))];
        let stream_body = StreamBody {
            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
            metadata: StreamMetadata {
                origin,
                ..Default::default()
            },
        };
        let exchange = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Stream(stream_body),
        });

        let aggregated = aggregate(vec![exchange], &AggregationStrategy::CollectAll)
            .expect("Expected Ok result");
        let Body::Json(value) = aggregated.input.body else {
            panic!("Expected Json body")
        };
        let json_str = serde_json::to_string(&value).unwrap();
        serde_json::from_str(&json_str).unwrap()
    }

    #[tokio::test]
    async fn test_pending_exchange_not_yet_complete() {
        let mut svc = AggregatorService::new(config_size(3));
        let result = send(&mut svc, make_exchange("orderId", "A", "first")).await;
        assert!(matches!(result.input.body, Body::Empty));
        assert_eq!(
            result.property(CAMEL_AGGREGATOR_PENDING),
            Some(&serde_json::json!(true))
        );
    }

    #[tokio::test]
    async fn test_completes_on_size() {
        let mut svc = AggregatorService::new(config_size(3));
        for _ in 0..2 {
            let partial = send(&mut svc, make_exchange("orderId", "A", "item")).await;
            assert!(matches!(partial.input.body, Body::Empty));
        }
        let result = send(&mut svc, make_exchange("orderId", "A", "last")).await;
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
        assert_eq!(
            result.property(CAMEL_AGGREGATED_SIZE),
            Some(&serde_json::json!(3u64))
        );
    }

    #[tokio::test]
    async fn test_collect_all_produces_json_array() {
        let mut svc = AggregatorService::new(config_size(2));
        send(&mut svc, make_exchange("orderId", "A", "alpha")).await;
        let result = send(&mut svc, make_exchange("orderId", "A", "beta")).await;
        let Body::Json(v) = &result.input.body else {
            panic!("expected Body::Json")
        };
        let arr = v.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0], serde_json::json!("alpha"));
        assert_eq!(arr[1], serde_json::json!("beta"));
    }

    #[tokio::test]
    async fn test_two_keys_independent_buckets() {
        // completionSize=3 so we can test that A and B accumulate independently.
        let mut svc = AggregatorService::new(config_size(3));
        send(&mut svc, make_exchange("orderId", "A", "a1")).await;
        send(&mut svc, make_exchange("orderId", "B", "b1")).await;
        send(&mut svc, make_exchange("orderId", "A", "a2")).await;
        // A has 2 items, B has 1 item — neither complete yet.
        let ra = send(&mut svc, make_exchange("orderId", "A", "a3")).await;
        // A now has 3 → completes.
        assert!(matches!(ra.input.body, Body::Json(_)));
        // B reaches only 2 of 3 with this exchange → still pending.
        let rb = send(&mut svc, make_exchange("orderId", "B", "b_check")).await;
        assert!(matches!(rb.input.body, Body::Empty));
    }

    #[tokio::test]
    async fn test_bucket_resets_after_completion() {
        let mut svc = AggregatorService::new(config_size(2));
        send(&mut svc, make_exchange("orderId", "A", "x")).await;
        send(&mut svc, make_exchange("orderId", "A", "x")).await; // completes
        // New bucket starts
        let r = send(&mut svc, make_exchange("orderId", "A", "new")).await;
        assert!(matches!(r.input.body, Body::Empty)); // pending again
    }

    #[tokio::test]
    async fn test_completion_size_1_emits_immediately() {
        let mut svc = AggregatorService::new(config_size(1));
        let result = send(&mut svc, make_exchange("orderId", "A", "solo")).await;
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_custom_aggregation_strategy() {
        use camel_api::aggregator::AggregationFn;
        use std::sync::Arc;

        let f: AggregationFn = Arc::new(|mut acc: Exchange, next: Exchange| {
            let combined = format!(
                "{}+{}",
                acc.input.body.as_text().unwrap_or(""),
                next.input.body.as_text().unwrap_or("")
            );
            acc.input.body = Body::Text(combined);
            acc
        });
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(2)
            .strategy(AggregationStrategy::Custom(f))
            .build();
        let mut svc = AggregatorService::new(config);
        send(&mut svc, make_exchange("key", "X", "hello")).await;
        let result = send(&mut svc, make_exchange("key", "X", "world")).await;
        assert_eq!(result.input.body.as_text(), Some("hello+world"));
    }

    #[tokio::test]
    async fn test_completion_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| {
                bucket
                    .iter()
                    .any(|e| e.input.body.as_text() == Some("DONE"))
            })
            .build();
        let mut svc = AggregatorService::new(config);
        send(&mut svc, make_exchange("key", "K", "first")).await;
        send(&mut svc, make_exchange("key", "K", "second")).await;
        let result = send(&mut svc, make_exchange("key", "K", "DONE")).await;
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_missing_header_returns_error() {
        let mut svc = AggregatorService::new(config_size(2));
        let ex = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Text("no key".into()),
        });
        let result = try_send(&mut svc, ex).await;
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            camel_api::CamelError::ProcessorError(_)
        ));
    }

    #[tokio::test]
    async fn test_cloned_service_shares_state() {
        let mut svc1 = AggregatorService::new(config_size(2));
        let mut svc2 = svc1.clone();
        // send first exchange via svc1
        send(&mut svc1, make_exchange("orderId", "A", "from-svc1")).await;
        // send second exchange via svc2 — should complete because same Arc<Mutex>
        let result = send(&mut svc2, make_exchange("orderId", "A", "from-svc2")).await;
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_camel_aggregated_key_property_set() {
        let mut svc = AggregatorService::new(config_size(1));
        let result = send(&mut svc, make_exchange("orderId", "ORDER-42", "body")).await;
        assert_eq!(
            result.property(CAMEL_AGGREGATED_KEY),
            Some(&serde_json::json!("ORDER-42"))
        );
    }

    #[tokio::test]
    async fn test_aggregator_enforces_max_buckets() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(2)
            .max_buckets(3)
            .build();
        let mut svc = AggregatorService::new(config);

        // Create 3 different correlation keys (fills limit)
        for i in 0..3 {
            let ex = make_exchange("orderId", &format!("key-{}", i), "body");
            let _ = send(&mut svc, ex).await;
        }

        // A further, previously unseen key should be rejected
        let result = try_send(&mut svc, make_exchange("orderId", "key-4", "body")).await;

        assert!(result.is_err(), "Should reject when max buckets reached");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("maximum"),
            "Error message should contain 'maximum': {}",
            err
        );
    }

    #[tokio::test]
    async fn test_max_buckets_allows_existing_key() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(5) // Large size so bucket doesn't complete
            .max_buckets(2)
            .build();
        let mut svc = AggregatorService::new(config);

        // Create 2 different correlation keys (fills limit)
        let _ = send(&mut svc, make_exchange("orderId", "key-A", "body1")).await;
        let _ = send(&mut svc, make_exchange("orderId", "key-B", "body2")).await;

        // Should still allow adding to existing key
        let result = try_send(&mut svc, make_exchange("orderId", "key-A", "body3")).await;
        assert!(
            result.is_ok(),
            "Should allow adding to existing bucket even at max limit"
        );
    }

    #[tokio::test]
    async fn test_bucket_ttl_eviction() {
        // Completion size 2 makes eviction observable: had the stale bucket
        // survived, the second key-A exchange would complete it instead of
        // opening a fresh, still-pending bucket.
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(2)
            .bucket_ttl(Duration::from_millis(50))
            .build();
        let mut svc = AggregatorService::new(config);

        // Open a bucket that will be left to expire.
        let first = send(&mut svc, make_exchange("orderId", "key-A", "stale")).await;
        assert!(is_pending(&first), "First exchange should be pending");

        // Wait for TTL to expire
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Eviction runs at the start of the next call, before the insert.
        let second = send(&mut svc, make_exchange("orderId", "key-A", "fresh")).await;
        assert!(
            is_pending(&second),
            "Expired bucket should have been evicted, so key-A must start over"
        );
    }

    #[tokio::test]
    async fn test_aggregate_stream_bodies_creates_valid_json() {
        let parsed = collect_single_stream(Some("file:///test.txt".to_string()));

        assert!(parsed.is_array(), "Result should be an array");
        let arr = parsed.as_array().unwrap();
        assert!(arr[0].is_object(), "First element should be an object");
        assert!(
            arr[0]["_stream"].is_object(),
            "Should contain _stream object"
        );
        assert_eq!(arr[0]["_stream"]["origin"], "file:///test.txt");
        assert_eq!(
            arr[0]["_stream"]["placeholder"], true,
            "placeholder flag should be true"
        );
    }

    #[tokio::test]
    async fn test_aggregate_stream_bodies_with_none_origin() {
        let parsed = collect_single_stream(None);

        assert!(parsed.is_array(), "Result should be an array");
        let arr = parsed.as_array().unwrap();
        assert!(arr[0].is_object(), "First element should be an object");
        assert!(
            arr[0]["_stream"].is_object(),
            "Should contain _stream object"
        );
        assert_eq!(
            arr[0]["_stream"]["origin"],
            serde_json::Value::Null,
            "origin should be null when None"
        );
        assert_eq!(
            arr[0]["_stream"]["placeholder"], true,
            "placeholder flag should be true"
        );
    }
}