Skip to main content

camel_processor/
aggregator.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use tower::Service;
9
10use camel_api::{
11    CamelError,
12    aggregator::{AggregationStrategy, AggregatorConfig, CompletionCondition},
13    body::Body,
14    exchange::Exchange,
15    message::Message,
16};
17
/// Property key set to `true` on the empty placeholder exchange that the
/// aggregator returns while the bucket for a correlation key is still incomplete.
18pub const CAMEL_AGGREGATOR_PENDING: &str = "CamelAggregatorPending";
/// Property key under which the aggregated exchange records how many exchanges
/// were combined (stored as a JSON number).
19pub const CAMEL_AGGREGATED_SIZE: &str = "CamelAggregatedSize";
/// Property key under which the aggregated exchange records the correlation
/// header value that identified the completed bucket.
20pub const CAMEL_AGGREGATED_KEY: &str = "CamelAggregatedKey";
21
22/// Internal bucket structure with timestamp tracking for TTL eviction.
23struct Bucket {
24    exchanges: Vec<Exchange>,
25    #[allow(dead_code)]
26    created_at: Instant,
27    last_updated: Instant,
28}
29
30impl Bucket {
31    fn new() -> Self {
32        let now = Instant::now();
33        Self {
34            exchanges: Vec::new(),
35            created_at: now,
36            last_updated: now,
37        }
38    }
39
40    fn push(&mut self, exchange: Exchange) {
41        self.exchanges.push(exchange);
42        self.last_updated = Instant::now();
43    }
44
45    fn len(&self) -> usize {
46        self.exchanges.len()
47    }
48
49    fn is_expired(&self, ttl: Duration) -> bool {
50        Instant::now().duration_since(self.last_updated) >= ttl
51    }
52}
53
54#[derive(Clone)]
55pub struct AggregatorService {
56    config: AggregatorConfig,
57    buckets: Arc<Mutex<HashMap<String, Bucket>>>,
58}
59
60impl AggregatorService {
61    pub fn new(config: AggregatorConfig) -> Self {
62        Self {
63            config,
64            buckets: Arc::new(Mutex::new(HashMap::new())),
65        }
66    }
67}
68
69impl Service<Exchange> for AggregatorService {
70    type Response = Exchange;
71    type Error = CamelError;
72    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
73
74    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
75        Poll::Ready(Ok(()))
76    }
77
78    fn call(&mut self, exchange: Exchange) -> Self::Future {
79        let config = self.config.clone();
80        let buckets = Arc::clone(&self.buckets);
81
82        Box::pin(async move {
83            // 1. Extract correlation key value from header
84            let key_value = exchange
85                .input
86                .headers
87                .get(&config.header_name)
88                .cloned()
89                .ok_or_else(|| {
90                    CamelError::ProcessorError(format!(
91                        "Aggregator: missing correlation key header '{}'",
92                        config.header_name
93                    ))
94                })?;
95
96            // Serialize to String for use as HashMap key
97            let key_str = serde_json::to_string(&key_value)
98                .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
99
100            // 2. Insert into bucket and check completion (lock scope)
101            let completed_bucket = {
102                let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
103
104                // Evict expired buckets if TTL is configured
105                if let Some(ttl) = config.bucket_ttl {
106                    guard.retain(|_, bucket| !bucket.is_expired(ttl));
107                }
108
109                // Enforce max buckets limit - reject new correlation keys if at limit
110                if let Some(max) = config.max_buckets
111                    && !guard.contains_key(&key_str)
112                    && guard.len() >= max
113                {
114                    tracing::warn!(
115                        max_buckets = max,
116                        correlation_key = %key_str,
117                        "Aggregator reached max buckets limit, rejecting new correlation key"
118                    );
119                    return Err(CamelError::ProcessorError(format!(
120                        "Aggregator reached maximum {} buckets",
121                        max
122                    )));
123                }
124
125                let bucket = guard.entry(key_str.clone()).or_insert_with(Bucket::new);
126                bucket.push(exchange);
127
128                let is_complete = match &config.completion {
129                    CompletionCondition::Size(n) => bucket.len() >= *n,
130                    CompletionCondition::Predicate(pred) => pred(&bucket.exchanges),
131                };
132
133                if is_complete {
134                    guard.remove(&key_str).map(|b| b.exchanges)
135                } else {
136                    None
137                }
138            }; // Mutex released here
139
140            // 3. Emit aggregated exchange or return pending placeholder
141            match completed_bucket {
142                Some(exchanges) => {
143                    let size = exchanges.len();
144                    let mut result = aggregate(exchanges, &config.strategy)?;
145                    result.set_property(CAMEL_AGGREGATED_SIZE, serde_json::json!(size as u64));
146                    result.set_property(CAMEL_AGGREGATED_KEY, key_value);
147                    Ok(result)
148                }
149                None => {
150                    let mut pending = Exchange::new(Message {
151                        headers: Default::default(),
152                        body: Body::Empty,
153                    });
154                    pending.set_property(CAMEL_AGGREGATOR_PENDING, serde_json::json!(true));
155                    Ok(pending)
156                }
157            }
158        })
159    }
160}
161
162fn aggregate(
163    exchanges: Vec<Exchange>,
164    strategy: &AggregationStrategy,
165) -> Result<Exchange, CamelError> {
166    match strategy {
167        AggregationStrategy::CollectAll => {
168            let bodies: Vec<serde_json::Value> = exchanges
169                .into_iter()
170                .map(|e| match e.input.body {
171                    Body::Json(v) => v,
172                    Body::Text(s) => serde_json::Value::String(s),
173                    Body::Bytes(b) => {
174                        serde_json::Value::String(String::from_utf8_lossy(&b).into_owned())
175                    }
176                    Body::Empty => serde_json::Value::Null,
177                })
178                .collect();
179            Ok(Exchange::new(Message {
180                headers: Default::default(),
181                body: Body::Json(serde_json::Value::Array(bodies)),
182            }))
183        }
184        AggregationStrategy::Custom(f) => {
185            let mut iter = exchanges.into_iter();
186            let first = iter.next().ok_or_else(|| {
187                CamelError::ProcessorError("Aggregator: empty bucket".to_string())
188            })?;
189            Ok(iter.fold(first, |acc, next| f(acc, next)))
190        }
191    }
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{
        aggregator::{AggregationStrategy, AggregatorConfig},
        body::Body,
        exchange::Exchange,
        message::Message,
    };
    use tower::ServiceExt;

    /// Builds a text-bodied exchange carrying one string-valued header.
    fn make_exchange(header: &str, value: &str, body: &str) -> Exchange {
        let mut message = Message {
            headers: Default::default(),
            body: Body::Text(body.to_string()),
        };
        message
            .headers
            .insert(header.to_string(), serde_json::json!(value));
        Exchange::new(message)
    }

    /// Config correlating on `orderId` that completes after `n` exchanges.
    fn config_size(n: usize) -> AggregatorConfig {
        AggregatorConfig::correlate_by("orderId")
            .complete_when_size(n)
            .build()
    }

    /// Waits for readiness and pushes one exchange through the service,
    /// handing back the raw result so callers can assert on success or error.
    async fn send(
        svc: &mut AggregatorService,
        exchange: Exchange,
    ) -> Result<Exchange, camel_api::CamelError> {
        svc.ready().await.unwrap().call(exchange).await
    }

    #[tokio::test]
    async fn test_pending_exchange_not_yet_complete() {
        let mut svc = AggregatorService::new(config_size(3));

        let result = send(&mut svc, make_exchange("orderId", "A", "first"))
            .await
            .unwrap();

        assert!(matches!(result.input.body, Body::Empty));
        assert_eq!(
            result.property(CAMEL_AGGREGATOR_PENDING),
            Some(&serde_json::json!(true))
        );
    }

    #[tokio::test]
    async fn test_completes_on_size() {
        let mut svc = AggregatorService::new(config_size(3));

        // The first two exchanges only fill the bucket.
        for _ in 0..2 {
            let pending = send(&mut svc, make_exchange("orderId", "A", "item"))
                .await
                .unwrap();
            assert!(matches!(pending.input.body, Body::Empty));
        }

        // The third one completes it.
        let result = send(&mut svc, make_exchange("orderId", "A", "last"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
        assert_eq!(
            result.property(CAMEL_AGGREGATED_SIZE),
            Some(&serde_json::json!(3u64))
        );
    }

    #[tokio::test]
    async fn test_collect_all_produces_json_array() {
        let mut svc = AggregatorService::new(config_size(2));

        send(&mut svc, make_exchange("orderId", "A", "alpha"))
            .await
            .unwrap();
        let result = send(&mut svc, make_exchange("orderId", "A", "beta"))
            .await
            .unwrap();

        let Body::Json(v) = &result.input.body else {
            panic!("expected Body::Json")
        };
        let arr = v.as_array().unwrap();
        assert_eq!(arr.len(), 2);
        assert_eq!(arr[0], serde_json::json!("alpha"));
        assert_eq!(arr[1], serde_json::json!("beta"));
    }

    #[tokio::test]
    async fn test_two_keys_independent_buckets() {
        // completionSize=3 so we can test that A and B accumulate independently.
        let mut svc = AggregatorService::new(config_size(3));

        // After these, A holds 2 items and B holds 1 — neither is complete.
        for (key, body) in [("A", "a1"), ("B", "b1"), ("A", "a2")] {
            send(&mut svc, make_exchange("orderId", key, body))
                .await
                .unwrap();
        }

        // A now has 3 → completes
        let ra = send(&mut svc, make_exchange("orderId", "A", "a3"))
            .await
            .unwrap();
        assert!(matches!(ra.input.body, Body::Json(_)));

        // B only had 1 → still pending after one more
        let rb = send(&mut svc, make_exchange("orderId", "B", "b_check"))
            .await
            .unwrap();
        assert!(matches!(rb.input.body, Body::Empty));
    }

    #[tokio::test]
    async fn test_bucket_resets_after_completion() {
        let mut svc = AggregatorService::new(config_size(2));

        // Two exchanges complete the first bucket.
        for _ in 0..2 {
            send(&mut svc, make_exchange("orderId", "A", "x"))
                .await
                .unwrap();
        }

        // A new bucket starts for the same key, so the reply is pending again.
        let r = send(&mut svc, make_exchange("orderId", "A", "new"))
            .await
            .unwrap();
        assert!(matches!(r.input.body, Body::Empty));
    }

    #[tokio::test]
    async fn test_completion_size_1_emits_immediately() {
        let mut svc = AggregatorService::new(config_size(1));

        let result = send(&mut svc, make_exchange("orderId", "A", "solo"))
            .await
            .unwrap();

        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_custom_aggregation_strategy() {
        use camel_api::aggregator::AggregationFn;
        use std::sync::Arc;

        // Joins the text bodies of both exchanges with a '+'.
        let f: AggregationFn = Arc::new(|mut acc: Exchange, next: Exchange| {
            let combined = format!(
                "{}+{}",
                acc.input.body.as_text().unwrap_or(""),
                next.input.body.as_text().unwrap_or("")
            );
            acc.input.body = Body::Text(combined);
            acc
        });
        let config = AggregatorConfig::correlate_by("key")
            .complete_when_size(2)
            .strategy(AggregationStrategy::Custom(f))
            .build();
        let mut svc = AggregatorService::new(config);

        send(&mut svc, make_exchange("key", "X", "hello"))
            .await
            .unwrap();
        let result = send(&mut svc, make_exchange("key", "X", "world"))
            .await
            .unwrap();

        assert_eq!(result.input.body.as_text(), Some("hello+world"));
    }

    #[tokio::test]
    async fn test_completion_predicate() {
        let config = AggregatorConfig::correlate_by("key")
            .complete_when(|bucket| {
                bucket
                    .iter()
                    .any(|e| e.input.body.as_text() == Some("DONE"))
            })
            .build();
        let mut svc = AggregatorService::new(config);

        for body in ["first", "second"] {
            send(&mut svc, make_exchange("key", "K", body))
                .await
                .unwrap();
        }
        let result = send(&mut svc, make_exchange("key", "K", "DONE"))
            .await
            .unwrap();

        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_missing_header_returns_error() {
        let mut svc = AggregatorService::new(config_size(2));
        let headerless = Exchange::new(Message {
            headers: Default::default(),
            body: Body::Text("no key".into()),
        });

        let result = send(&mut svc, headerless).await;

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            camel_api::CamelError::ProcessorError(_)
        ));
    }

    #[tokio::test]
    async fn test_cloned_service_shares_state() {
        let svc1 = AggregatorService::new(config_size(2));
        let mut svc2 = svc1.clone();

        // send first exchange via a clone of svc1
        let mut first_handle = svc1.clone();
        send(&mut first_handle, make_exchange("orderId", "A", "from-svc1"))
            .await
            .unwrap();

        // send second exchange via svc2 — should complete because same Arc<Mutex>
        let result = send(&mut svc2, make_exchange("orderId", "A", "from-svc2"))
            .await
            .unwrap();
        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
    }

    #[tokio::test]
    async fn test_camel_aggregated_key_property_set() {
        let mut svc = AggregatorService::new(config_size(1));

        let result = send(&mut svc, make_exchange("orderId", "ORDER-42", "body"))
            .await
            .unwrap();

        assert_eq!(
            result.property(CAMEL_AGGREGATED_KEY),
            Some(&serde_json::json!("ORDER-42"))
        );
    }

    #[tokio::test]
    async fn test_aggregator_enforces_max_buckets() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(2)
            .max_buckets(3)
            .build();
        let mut svc = AggregatorService::new(config);

        // Create 3 different correlation keys (fills limit)
        for i in 0..3 {
            let ex = make_exchange("orderId", &format!("key-{}", i), "body");
            let _ = send(&mut svc, ex).await.unwrap();
        }

        // 4th key should be rejected
        let result = send(&mut svc, make_exchange("orderId", "key-4", "body")).await;

        assert!(result.is_err(), "Should reject when max buckets reached");
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("maximum"),
            "Error message should contain 'maximum': {}",
            err
        );
    }

    #[tokio::test]
    async fn test_max_buckets_allows_existing_key() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(5) // Large size so bucket doesn't complete
            .max_buckets(2)
            .build();
        let mut svc = AggregatorService::new(config);

        // Create 2 different correlation keys (fills limit)
        for (key, body) in [("key-A", "body1"), ("key-B", "body2")] {
            let _ = send(&mut svc, make_exchange("orderId", key, body))
                .await
                .unwrap();
        }

        // Should still allow adding to existing key
        let result = send(&mut svc, make_exchange("orderId", "key-A", "body3")).await;
        assert!(
            result.is_ok(),
            "Should allow adding to existing bucket even at max limit"
        );
    }

    #[tokio::test]
    async fn test_bucket_ttl_eviction() {
        let config = AggregatorConfig::correlate_by("orderId")
            .complete_when_size(10) // Large size so bucket doesn't complete normally
            .bucket_ttl(Duration::from_millis(50))
            .build();
        let mut svc = AggregatorService::new(config);

        // Create a bucket
        let _ = send(&mut svc, make_exchange("orderId", "key-A", "body1"))
            .await
            .unwrap();

        // Wait for TTL to expire
        tokio::time::sleep(Duration::from_millis(100)).await;

        // Create a new bucket - this should trigger eviction of the old one
        let _ = send(&mut svc, make_exchange("orderId", "key-B", "body2"))
            .await
            .unwrap();

        // The expired bucket should have been evicted, so we should be able to
        // add a new key-A bucket again
        let result = send(&mut svc, make_exchange("orderId", "key-A", "body3")).await;
        assert!(result.is_ok(), "Should be able to recreate evicted bucket");
    }
}