Skip to main content

camel_processor/
aggregator.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use tokio::sync::mpsc;
9use tokio::task::JoinHandle;
10use tokio_util::sync::CancellationToken;
11use tower::Service;
12
13use camel_api::{
14    CamelError,
15    aggregator::{
16        AggregationStrategy, AggregatorConfig, CompletionCondition, CompletionMode,
17        CompletionReason, CorrelationStrategy,
18    },
19    body::Body,
20    exchange::Exchange,
21    message::Message,
22};
23use camel_language_api::Language;
24
25pub type SharedLanguageRegistry = Arc<std::sync::Mutex<HashMap<String, Arc<dyn Language>>>>;
26
27pub const CAMEL_AGGREGATOR_PENDING: &str = "CamelAggregatorPending";
28pub const CAMEL_AGGREGATED_SIZE: &str = "CamelAggregatedSize";
29pub const CAMEL_AGGREGATED_KEY: &str = "CamelAggregatedKey";
30pub const CAMEL_AGGREGATED_COMPLETION_REASON: &str = "CamelAggregatedCompletionReason";
31
32/// Internal bucket structure with timestamp tracking for TTL eviction.
33struct Bucket {
34    exchanges: Vec<Exchange>,
35    last_updated: Instant,
36}
37
38impl Bucket {
39    fn new() -> Self {
40        Self {
41            exchanges: Vec::new(),
42            last_updated: Instant::now(),
43        }
44    }
45
46    fn push(&mut self, exchange: Exchange) {
47        self.exchanges.push(exchange);
48        self.last_updated = Instant::now();
49    }
50
51    fn is_expired(&self, ttl: Duration) -> bool {
52        Instant::now().duration_since(self.last_updated) >= ttl
53    }
54}
55
56#[derive(Clone)]
57pub struct AggregatorService {
58    config: AggregatorConfig,
59    buckets: Arc<Mutex<HashMap<String, Bucket>>>,
60    timeout_tasks: Arc<Mutex<HashMap<String, CancellationToken>>>,
61    timeout_handles: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
62    late_tx: mpsc::Sender<Exchange>,
63    language_registry: SharedLanguageRegistry,
64    route_cancel: CancellationToken,
65}
66
67impl AggregatorService {
68    pub fn new(
69        config: AggregatorConfig,
70        late_tx: mpsc::Sender<Exchange>,
71        language_registry: SharedLanguageRegistry,
72        route_cancel: CancellationToken,
73    ) -> Self {
74        Self {
75            config,
76            buckets: Arc::new(Mutex::new(HashMap::new())),
77            timeout_tasks: Arc::new(Mutex::new(HashMap::new())),
78            timeout_handles: Arc::new(Mutex::new(HashMap::new())),
79            late_tx,
80            language_registry,
81            route_cancel,
82        }
83    }
84
85    pub fn config(&self) -> &AggregatorConfig {
86        &self.config
87    }
88
89    pub fn has_timeout(&self) -> bool {
90        has_timeout_condition(&self.config.completion)
91    }
92
93    pub fn force_complete_all(&self) {
94        let mut buckets_guard = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
95        let keys: Vec<String> = buckets_guard.keys().cloned().collect();
96
97        for key in keys {
98            if let Some(bucket) = buckets_guard.remove(&key) {
99                if self.config.force_completion_on_stop {
100                    cancel_timeout_task_with_handle(
101                        &key,
102                        &self.timeout_tasks,
103                        &self.timeout_handles,
104                    );
105                    match aggregate(bucket.exchanges, &self.config.strategy) {
106                        Ok(mut result) => {
107                            result.set_property(
108                                CAMEL_AGGREGATED_COMPLETION_REASON,
109                                serde_json::json!(CompletionReason::Stop.as_str()),
110                            );
111                            if self.late_tx.try_send(result).is_err() {
112                                tracing::warn!(
113                                    key = %key,
114                                    "aggregator force-complete emit dropped: late channel full"
115                                );
116                            }
117                        }
118                        Err(e) => {
119                            tracing::error!(
120                                key = %key,
121                                error = %e,
122                                "aggregation failed in force_complete_all"
123                            );
124                        }
125                    }
126                } else {
127                    cancel_timeout_task_with_handle(
128                        &key,
129                        &self.timeout_tasks,
130                        &self.timeout_handles,
131                    );
132                }
133            }
134        }
135    }
136
137    /// Graceful shutdown: cancel all outstanding timeout tasks and await their
138    /// JoinHandles (with a 5s deadline) so that no tasks are leaked.
139    pub async fn shutdown(&self) {
140        // Cancel all timeout cancellation tokens.
141        {
142            let mut guard = self.timeout_tasks.lock().unwrap_or_else(|e| e.into_inner());
143            for token in guard.values() {
144                token.cancel();
145            }
146            guard.clear();
147        };
148
149        // Remove and collect all JoinHandles.
150        let handles: Vec<JoinHandle<()>> = {
151            let mut guard = self
152                .timeout_handles
153                .lock()
154                .unwrap_or_else(|e| e.into_inner());
155            guard.drain().map(|(_, handle)| handle).collect()
156        };
157
158        if handles.is_empty() {
159            return;
160        }
161
162        // Await all handles with a deadline.
163        let _ = tokio::time::timeout(Duration::from_secs(5), async {
164            for handle in handles {
165                let _ = handle.await;
166            }
167        })
168        .await;
169    }
170}
171
172pub fn has_timeout_condition(mode: &CompletionMode) -> bool {
173    match mode {
174        CompletionMode::Single(CompletionCondition::Timeout(_)) => true,
175        CompletionMode::Any(conditions) => conditions
176            .iter()
177            .any(|c| matches!(c, CompletionCondition::Timeout(_))),
178        _ => false,
179    }
180}
181
182impl Service<Exchange> for AggregatorService {
183    type Response = Exchange;
184    type Error = CamelError;
185    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
186
187    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
188        Poll::Ready(Ok(()))
189    }
190
191    fn call(&mut self, exchange: Exchange) -> Self::Future {
192        let config = self.config.clone();
193        let buckets = Arc::clone(&self.buckets);
194        let timeout_tasks = Arc::clone(&self.timeout_tasks);
195        let timeout_handles = Arc::clone(&self.timeout_handles);
196        let late_tx = self.late_tx.clone();
197        let language_registry = Arc::clone(&self.language_registry);
198        let route_cancel = self.route_cancel.clone();
199
200        Box::pin(async move {
201            let key_value =
202                extract_correlation_key(&exchange, &config.correlation, &language_registry).await?;
203
204            let key_str = serde_json::to_string(&key_value)
205                .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
206
207            let completed_bucket = {
208                let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
209
210                if let Some(ttl) = config.bucket_ttl {
211                    guard.retain(|_, bucket| !bucket.is_expired(ttl));
212                }
213
214                if let Some(max) = config.max_buckets
215                    && !guard.contains_key(&key_str)
216                    && guard.len() >= max
217                {
218                    tracing::warn!(
219                        max_buckets = max,
220                        correlation_key = %key_str,
221                        "Aggregator reached max buckets limit, rejecting new correlation key"
222                    );
223                    return Err(CamelError::ProcessorError(format!(
224                        "Aggregator reached maximum {} buckets",
225                        max
226                    )));
227                }
228
229                let bucket = guard.entry(key_str.clone()).or_insert_with(Bucket::new);
230                bucket.push(exchange);
231
232                let (is_complete, reason) =
233                    check_sync_completion(&config.completion, &bucket.exchanges);
234
235                if is_complete {
236                    let exchanges = guard.remove(&key_str).map(|b| b.exchanges);
237                    (exchanges, reason)
238                } else {
239                    (None, CompletionReason::Size) // placeholder; reason unused when None
240                }
241            };
242
243            if completed_bucket.0.is_none() && has_timeout_condition(&config.completion) {
244                let cancel = {
245                    let mut tt_guard = timeout_tasks.lock().unwrap_or_else(|e| e.into_inner());
246                    // Cancel and remove old handle for this key (if any).
247                    if let Some(existing) = tt_guard.get(&key_str) {
248                        existing.cancel();
249                    }
250                    let token = CancellationToken::new();
251                    tt_guard.insert(key_str.clone(), token.clone());
252                    token
253                };
254
255                let timeout_dur = extract_timeout_duration(&config.completion);
256                if let Some(timeout) = timeout_dur {
257                    // Remove old handle if present.
258                    {
259                        let mut hh = timeout_handles.lock().unwrap_or_else(|e| e.into_inner());
260                        if let Some(old) = hh.remove(&key_str) {
261                            old.abort();
262                        }
263                    }
264                    let handle = spawn_timeout_task(
265                        key_str.clone(),
266                        timeout,
267                        cancel,
268                        buckets.clone(),
269                        timeout_tasks.clone(),
270                        timeout_handles.clone(),
271                        late_tx,
272                        config.strategy.clone(),
273                        config.discard_on_timeout,
274                        route_cancel,
275                    );
276                    timeout_handles
277                        .lock()
278                        .unwrap_or_else(|e| e.into_inner())
279                        .insert(key_str.clone(), handle);
280                }
281            }
282
283            if let Some(exchanges) = completed_bucket.0 {
284                cancel_timeout_task_with_handle(&key_str, &timeout_tasks, &timeout_handles);
285                let reason = completed_bucket.1;
286                let size = exchanges.len();
287                let mut result = aggregate(exchanges, &config.strategy)?;
288                result.set_property(CAMEL_AGGREGATED_SIZE, serde_json::json!(size as u64));
289                result.set_property(CAMEL_AGGREGATED_KEY, key_value);
290                result.set_property(
291                    CAMEL_AGGREGATED_COMPLETION_REASON,
292                    serde_json::json!(reason.as_str()),
293                );
294                Ok(result)
295            } else {
296                let mut pending = Exchange::new(Message {
297                    headers: Default::default(),
298                    body: Body::Empty,
299                });
300                pending.set_property(CAMEL_AGGREGATOR_PENDING, serde_json::json!(true));
301                Ok(pending)
302            }
303        })
304    }
305}
306
307async fn extract_correlation_key(
308    exchange: &Exchange,
309    strategy: &CorrelationStrategy,
310    registry: &SharedLanguageRegistry,
311) -> Result<serde_json::Value, CamelError> {
312    match strategy {
313        CorrelationStrategy::HeaderName(h) => {
314            exchange.input.headers.get(h).cloned().ok_or_else(|| {
315                CamelError::ProcessorError(format!(
316                    "Aggregator: missing correlation key header '{}'",
317                    h
318                ))
319            })
320        }
321        CorrelationStrategy::Expression { expr, language } => {
322            let expression = {
323                let reg = registry.lock().unwrap_or_else(|e| e.into_inner());
324                let lang = reg.get(language).ok_or_else(|| {
325                    CamelError::ProcessorError(format!(
326                        "Aggregator: language '{}' not found in registry",
327                        language
328                    ))
329                })?;
330                lang.create_expression(expr)
331                    .map_err(|e| CamelError::ProcessorError(e.to_string()))?
332            };
333            let value = expression
334                .evaluate(exchange)
335                .await
336                .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
337            if value.is_null() {
338                return Err(CamelError::ProcessorError(format!(
339                    "Aggregator: correlation expression '{}' evaluated to null",
340                    expr
341                )));
342            }
343            Ok(value)
344        }
345        CorrelationStrategy::Fn(f) => f(exchange).map(serde_json::Value::String).ok_or_else(|| {
346            CamelError::ProcessorError("Aggregator: correlation function returned None".to_string())
347        }),
348    }
349}
350
351fn check_sync_completion(
352    mode: &CompletionMode,
353    exchanges: &[Exchange],
354) -> (bool, CompletionReason) {
355    match mode {
356        CompletionMode::Single(cond) => check_single(cond, exchanges),
357        CompletionMode::Any(conditions) => {
358            for cond in conditions {
359                if let CompletionCondition::Timeout(_) = cond {
360                    continue;
361                }
362                let (done, reason) = check_single(cond, exchanges);
363                if done {
364                    return (true, reason);
365                }
366            }
367            (false, CompletionReason::Size)
368        }
369    }
370}
371
372fn check_single(cond: &CompletionCondition, exchanges: &[Exchange]) -> (bool, CompletionReason) {
373    match cond {
374        CompletionCondition::Size(n) => (exchanges.len() >= *n, CompletionReason::Size),
375        CompletionCondition::Predicate(pred) => (pred(exchanges), CompletionReason::Predicate),
376        CompletionCondition::Timeout(_) => (false, CompletionReason::Timeout),
377    }
378}
379
380fn extract_timeout_duration(mode: &CompletionMode) -> Option<Duration> {
381    match mode {
382        CompletionMode::Single(CompletionCondition::Timeout(d)) => Some(*d),
383        CompletionMode::Any(conditions) => conditions.iter().find_map(|c| {
384            if let CompletionCondition::Timeout(d) = c {
385                Some(*d)
386            } else {
387                None
388            }
389        }),
390        _ => None,
391    }
392}
393
394fn cancel_timeout_task(key: &str, timeout_tasks: &Arc<Mutex<HashMap<String, CancellationToken>>>) {
395    let mut guard = timeout_tasks.lock().unwrap_or_else(|e| e.into_inner());
396    if let Some(token) = guard.remove(key) {
397        token.cancel();
398    }
399}
400
401/// Also removes the stored JoinHandle for a cancelled/completed timeout task.
402fn cancel_timeout_task_with_handle(
403    key: &str,
404    timeout_tasks: &Arc<Mutex<HashMap<String, CancellationToken>>>,
405    timeout_handles: &Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
406) {
407    cancel_timeout_task(key, timeout_tasks);
408    let mut guard = timeout_handles.lock().unwrap_or_else(|e| e.into_inner());
409    guard.remove(key);
410}
411
412#[allow(clippy::too_many_arguments)]
413fn spawn_timeout_task(
414    key: String,
415    timeout: Duration,
416    cancel: CancellationToken,
417    buckets: Arc<Mutex<HashMap<String, Bucket>>>,
418    timeout_tasks: Arc<Mutex<HashMap<String, CancellationToken>>>,
419    _timeout_handles: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
420    late_tx: mpsc::Sender<Exchange>,
421    strategy: AggregationStrategy,
422    discard: bool,
423    _route_cancel: CancellationToken,
424) -> JoinHandle<()> {
425    let cancel_clone = cancel.clone();
426    tokio::spawn(async move {
427        tokio::select! {
428            _ = tokio::time::sleep(timeout) => {
429                let should_proceed = {
430                    let mut tt_guard = timeout_tasks.lock().unwrap_or_else(|e| e.into_inner());
431                    if cancel_clone.is_cancelled() {
432                        false
433                    } else {
434                        tt_guard.remove(&key);
435                        true
436                    }
437                };
438                if !should_proceed {
439                    return;
440                }
441                let bucket_exchanges = {
442                    let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
443                    guard.remove(&key).map(|b| b.exchanges)
444                };
445                if let Some(exchanges) = bucket_exchanges
446                    && !discard
447                {
448                    match aggregate(exchanges, &strategy) {
449                        Ok(mut result) => {
450                            result.set_property(
451                                CAMEL_AGGREGATED_COMPLETION_REASON,
452                                serde_json::json!(CompletionReason::Timeout.as_str()),
453                            );
454                            if late_tx.try_send(result).is_err() {
455                                tracing::warn!(
456                                    key = %key,
457                                    "aggregator timeout emit dropped: late channel full"
458                                );
459                            }
460                        }
461                        Err(e) => {
462                            tracing::error!(
463                                key = %key,
464                                error = %e,
465                                "aggregation failed in timeout task"
466                            );
467                        }
468                    }
469                }
470            }
471            _ = cancel_clone.cancelled() => {}
472        }
473    })
474}
475
476fn aggregate(
477    exchanges: Vec<Exchange>,
478    strategy: &AggregationStrategy,
479) -> Result<Exchange, CamelError> {
480    match strategy {
481        AggregationStrategy::CollectAll => {
482            let bodies: Vec<serde_json::Value> = exchanges
483                .into_iter()
484                .map(|e| match e.input.body {
485                    Body::Json(v) => v,
486                    Body::Text(s) => serde_json::Value::String(s),
487                    Body::Xml(s) => serde_json::Value::String(s),
488                    Body::Bytes(b) => {
489                        serde_json::Value::String(String::from_utf8_lossy(&b).into_owned())
490                    }
491                    Body::Empty => serde_json::Value::Null,
492                    Body::Stream(s) => serde_json::json!({
493                        "_stream": {
494                            "origin": s.metadata.origin,
495                            "placeholder": true,
496                            "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
497                        }
498                    }),
499                })
500                .collect();
501            Ok(Exchange::new(Message {
502                headers: Default::default(),
503                body: Body::Json(serde_json::Value::Array(bodies)),
504            }))
505        }
506        AggregationStrategy::Custom(f) => {
507            let mut iter = exchanges.into_iter();
508            let first = iter.next().ok_or_else(|| {
509                CamelError::ProcessorError("Aggregator: empty bucket".to_string())
510            })?;
511            Ok(iter.fold(first, |acc, next| f(acc, next)))
512        }
513    }
514}
515
516#[cfg(test)]
517mod tests {
518    use super::*;
519    use std::collections::HashMap;
520
521    use camel_api::{
522        aggregator::{AggregationStrategy, AggregatorConfig},
523        body::Body,
524        exchange::Exchange,
525        message::Message,
526    };
527    use tokio::sync::mpsc;
528    use tokio_util::sync::CancellationToken;
529    use tower::ServiceExt;
530
531    fn make_exchange(header: &str, value: &str, body: &str) -> Exchange {
532        let mut msg = Message {
533            headers: Default::default(),
534            body: Body::Text(body.to_string()),
535        };
536        msg.headers
537            .insert(header.to_string(), serde_json::json!(value));
538        Exchange::new(msg)
539    }
540
541    fn config_size(n: usize) -> AggregatorConfig {
542        AggregatorConfig::correlate_by("orderId")
543            .complete_when_size(n)
544            .build()
545            .unwrap()
546    }
547
548    fn new_test_svc(config: AggregatorConfig) -> AggregatorService {
549        let (tx, _rx) = mpsc::channel(256);
550        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
551        let cancel = CancellationToken::new();
552        AggregatorService::new(config, tx, registry, cancel)
553    }
554
555    #[tokio::test]
556    async fn test_pending_exchange_not_yet_complete() {
557        let mut svc = new_test_svc(config_size(3));
558        let ex = make_exchange("orderId", "A", "first");
559        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
560        assert!(matches!(result.input.body, Body::Empty));
561        assert_eq!(
562            result.property(CAMEL_AGGREGATOR_PENDING),
563            Some(&serde_json::json!(true))
564        );
565    }
566
567    #[tokio::test]
568    async fn test_completes_on_size() {
569        let mut svc = new_test_svc(config_size(3));
570        for _ in 0..2 {
571            let ex = make_exchange("orderId", "A", "item");
572            let r = svc.ready().await.unwrap().call(ex).await.unwrap();
573            assert!(matches!(r.input.body, Body::Empty));
574        }
575        let ex = make_exchange("orderId", "A", "last");
576        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
577        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
578        assert_eq!(
579            result.property(CAMEL_AGGREGATED_SIZE),
580            Some(&serde_json::json!(3u64))
581        );
582    }
583
584    #[tokio::test]
585    async fn test_collect_all_produces_json_array() {
586        let mut svc = new_test_svc(config_size(2));
587        svc.ready()
588            .await
589            .unwrap()
590            .call(make_exchange("orderId", "A", "alpha"))
591            .await
592            .unwrap();
593        let result = svc
594            .ready()
595            .await
596            .unwrap()
597            .call(make_exchange("orderId", "A", "beta"))
598            .await
599            .unwrap();
600        let Body::Json(v) = &result.input.body else {
601            panic!("expected Body::Json")
602        };
603        let arr = v.as_array().unwrap();
604        assert_eq!(arr.len(), 2);
605        assert_eq!(arr[0], serde_json::json!("alpha"));
606        assert_eq!(arr[1], serde_json::json!("beta"));
607    }
608
609    #[tokio::test]
610    async fn test_two_keys_independent_buckets() {
611        // completionSize=3 so we can test that A and B accumulate independently.
612        let mut svc = new_test_svc(config_size(3));
613        svc.ready()
614            .await
615            .unwrap()
616            .call(make_exchange("orderId", "A", "a1"))
617            .await
618            .unwrap();
619        svc.ready()
620            .await
621            .unwrap()
622            .call(make_exchange("orderId", "B", "b1"))
623            .await
624            .unwrap();
625        svc.ready()
626            .await
627            .unwrap()
628            .call(make_exchange("orderId", "A", "a2"))
629            .await
630            .unwrap();
631        // A has 2 items, B has 1 item — neither complete yet
632        let ra = svc
633            .ready()
634            .await
635            .unwrap()
636            .call(make_exchange("orderId", "A", "a3"))
637            .await
638            .unwrap();
639        // A now has 3 → completes
640        assert!(matches!(ra.input.body, Body::Json(_)));
641        // B only has 1 → still pending
642        let rb = svc
643            .ready()
644            .await
645            .unwrap()
646            .call(make_exchange("orderId", "B", "b_check"))
647            .await
648            .unwrap();
649        assert!(matches!(rb.input.body, Body::Empty));
650    }
651
652    #[tokio::test]
653    async fn test_bucket_resets_after_completion() {
654        let mut svc = new_test_svc(config_size(2));
655        svc.ready()
656            .await
657            .unwrap()
658            .call(make_exchange("orderId", "A", "x"))
659            .await
660            .unwrap();
661        svc.ready()
662            .await
663            .unwrap()
664            .call(make_exchange("orderId", "A", "x"))
665            .await
666            .unwrap(); // completes
667        // New bucket starts
668        let r = svc
669            .ready()
670            .await
671            .unwrap()
672            .call(make_exchange("orderId", "A", "new"))
673            .await
674            .unwrap();
675        assert!(matches!(r.input.body, Body::Empty)); // pending again
676    }
677
678    #[tokio::test]
679    async fn test_completion_size_1_emits_immediately() {
680        let mut svc = new_test_svc(config_size(1));
681        let ex = make_exchange("orderId", "A", "solo");
682        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
683        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
684    }
685
686    #[tokio::test]
687    async fn test_custom_aggregation_strategy() {
688        use camel_api::aggregator::AggregationFn;
689        use std::sync::Arc;
690
691        let f: AggregationFn = Arc::new(|mut acc: Exchange, next: Exchange| {
692            let combined = format!(
693                "{}+{}",
694                acc.input.body.as_text().unwrap_or(""),
695                next.input.body.as_text().unwrap_or("")
696            );
697            acc.input.body = Body::Text(combined);
698            acc
699        });
700        let config = AggregatorConfig::correlate_by("key")
701            .complete_when_size(2)
702            .strategy(AggregationStrategy::Custom(f))
703            .build()
704            .unwrap();
705        let mut svc = new_test_svc(config);
706        svc.ready()
707            .await
708            .unwrap()
709            .call(make_exchange("key", "X", "hello"))
710            .await
711            .unwrap();
712        let result = svc
713            .ready()
714            .await
715            .unwrap()
716            .call(make_exchange("key", "X", "world"))
717            .await
718            .unwrap();
719        assert_eq!(result.input.body.as_text(), Some("hello+world"));
720    }
721
722    #[tokio::test]
723    async fn test_completion_predicate() {
724        let config = AggregatorConfig::correlate_by("key")
725            .complete_when(|bucket| {
726                bucket
727                    .iter()
728                    .any(|e| e.input.body.as_text() == Some("DONE"))
729            })
730            .build()
731            .unwrap();
732        let mut svc = new_test_svc(config);
733        svc.ready()
734            .await
735            .unwrap()
736            .call(make_exchange("key", "K", "first"))
737            .await
738            .unwrap();
739        svc.ready()
740            .await
741            .unwrap()
742            .call(make_exchange("key", "K", "second"))
743            .await
744            .unwrap();
745        let result = svc
746            .ready()
747            .await
748            .unwrap()
749            .call(make_exchange("key", "K", "DONE"))
750            .await
751            .unwrap();
752        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
753    }
754
755    #[tokio::test]
756    async fn test_missing_header_returns_error() {
757        let mut svc = new_test_svc(config_size(2));
758        let msg = Message {
759            headers: Default::default(),
760            body: Body::Text("no key".into()),
761        };
762        let ex = Exchange::new(msg);
763        let result = svc.ready().await.unwrap().call(ex).await;
764        assert!(result.is_err());
765        assert!(matches!(
766            result.unwrap_err(),
767            camel_api::CamelError::ProcessorError(_)
768        ));
769    }
770
771    #[tokio::test]
772    async fn test_cloned_service_shares_state() {
773        let svc1 = new_test_svc(config_size(2));
774        let mut svc2 = svc1.clone();
775        // send first exchange via svc1
776        svc1.clone()
777            .ready()
778            .await
779            .unwrap()
780            .call(make_exchange("orderId", "A", "from-svc1"))
781            .await
782            .unwrap();
783        // send second exchange via svc2 — should complete because same Arc<Mutex>
784        let result = svc2
785            .ready()
786            .await
787            .unwrap()
788            .call(make_exchange("orderId", "A", "from-svc2"))
789            .await
790            .unwrap();
791        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
792    }
793
794    #[tokio::test]
795    async fn test_camel_aggregated_key_property_set() {
796        let mut svc = new_test_svc(config_size(1));
797        let ex = make_exchange("orderId", "ORDER-42", "body");
798        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
799        assert_eq!(
800            result.property(CAMEL_AGGREGATED_KEY),
801            Some(&serde_json::json!("ORDER-42"))
802        );
803    }
804
805    #[tokio::test]
806    async fn test_aggregator_enforces_max_buckets() {
807        let config = AggregatorConfig::correlate_by("orderId")
808            .complete_when_size(2)
809            .max_buckets(3)
810            .build()
811            .unwrap();
812
813        let mut svc = new_test_svc(config);
814
815        // Create 3 different correlation keys (fills limit)
816        for i in 0..3 {
817            let ex = make_exchange("orderId", &format!("key-{}", i), "body");
818            let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
819        }
820
821        // 4th key should be rejected
822        let ex = make_exchange("orderId", "key-4", "body");
823        let result = svc.ready().await.unwrap().call(ex).await;
824
825        assert!(result.is_err(), "Should reject when max buckets reached");
826        let err = result.unwrap_err().to_string();
827        assert!(
828            err.contains("maximum"),
829            "Error message should contain 'maximum': {}",
830            err
831        );
832    }
833
834    #[tokio::test]
835    async fn test_max_buckets_allows_existing_key() {
836        let config = AggregatorConfig::correlate_by("orderId")
837            .complete_when_size(5) // Large size so bucket doesn't complete
838            .max_buckets(2)
839            .build()
840            .unwrap();
841
842        let mut svc = new_test_svc(config);
843
844        // Create 2 different correlation keys (fills limit)
845        let ex1 = make_exchange("orderId", "key-A", "body1");
846        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
847        let ex2 = make_exchange("orderId", "key-B", "body2");
848        let _ = svc.ready().await.unwrap().call(ex2).await.unwrap();
849
850        // Should still allow adding to existing key
851        let ex3 = make_exchange("orderId", "key-A", "body3");
852        let result = svc.ready().await.unwrap().call(ex3).await;
853        assert!(
854            result.is_ok(),
855            "Should allow adding to existing bucket even at max limit"
856        );
857    }
858
859    #[tokio::test]
860    async fn test_bucket_ttl_eviction() {
861        let config = AggregatorConfig::correlate_by("orderId")
862            .complete_when_size(10) // Large size so bucket doesn't complete normally
863            .bucket_ttl(Duration::from_millis(50))
864            .build()
865            .unwrap();
866
867        let mut svc = new_test_svc(config);
868
869        // Create a bucket
870        let ex1 = make_exchange("orderId", "key-A", "body1");
871        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
872
873        // Wait for TTL to expire
874        tokio::time::sleep(Duration::from_millis(100)).await;
875
876        // Create a new bucket - this should trigger eviction of the old one
877        let ex2 = make_exchange("orderId", "key-B", "body2");
878        let _ = svc.ready().await.unwrap().call(ex2).await.unwrap();
879
880        // The expired bucket should have been evicted, so we should be able to
881        // add a new key-A bucket again
882        let ex3 = make_exchange("orderId", "key-A", "body3");
883        let result = svc.ready().await.unwrap().call(ex3).await;
884        assert!(result.is_ok(), "Should be able to recreate evicted bucket");
885    }
886
887    #[tokio::test(start_paused = true)]
888    async fn test_timeout_completes_bucket() {
889        let config = AggregatorConfig::correlate_by("key")
890            .complete_on_timeout(Duration::from_millis(100))
891            .build()
892            .unwrap();
893        let mut svc = new_test_svc(config);
894        let ex = make_exchange("key", "A", "data");
895        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
896        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_some());
897
898        tokio::time::sleep(Duration::from_millis(200)).await;
899
900        assert_eq!(
901            svc.buckets.lock().unwrap().len(),
902            0,
903            "bucket should be removed after timeout"
904        );
905    }
906
907    #[tokio::test(start_paused = true)]
908    async fn test_timeout_resets_on_new_exchange() {
909        let config = AggregatorConfig::correlate_by("key")
910            .complete_on_timeout(Duration::from_millis(150))
911            .build()
912            .unwrap();
913        let mut svc = new_test_svc(config);
914
915        let ex1 = make_exchange("key", "A", "first");
916        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
917
918        tokio::time::sleep(Duration::from_millis(100)).await;
919
920        let ex2 = make_exchange("key", "A", "second");
921        let _ = svc.ready().await.unwrap().call(ex2).await.unwrap();
922
923        tokio::time::sleep(Duration::from_millis(100)).await;
924
925        assert_eq!(
926            svc.buckets.lock().unwrap().len(),
927            1,
928            "bucket should still exist — timeout was reset"
929        );
930
931        tokio::time::sleep(Duration::from_millis(100)).await;
932
933        assert_eq!(
934            svc.buckets.lock().unwrap().len(),
935            0,
936            "bucket should be gone after timeout fires"
937        );
938    }
939
940    #[tokio::test]
941    async fn test_composable_size_and_timeout() {
942        let config = AggregatorConfig::correlate_by("key")
943            .complete_on_size_or_timeout(2, Duration::from_millis(200))
944            .build()
945            .unwrap();
946        let mut svc = new_test_svc(config);
947
948        let ex1 = make_exchange("key", "A", "first");
949        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
950        assert!(svc.buckets.lock().unwrap().contains_key("\"A\""));
951
952        let ex2 = make_exchange("key", "A", "second");
953        let result = svc.ready().await.unwrap().call(ex2).await.unwrap();
954        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
955        assert_eq!(
956            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
957            Some(&serde_json::json!("size"))
958        );
959    }
960
961    #[tokio::test(start_paused = true)]
962    async fn test_discard_on_timeout() {
963        let config = AggregatorConfig::correlate_by("key")
964            .complete_on_timeout(Duration::from_millis(50))
965            .discard_on_timeout(true)
966            .build()
967            .unwrap();
968        let (tx, mut rx) = mpsc::channel(256);
969        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
970        let cancel = CancellationToken::new();
971        let mut svc = AggregatorService::new(config, tx, registry, cancel);
972
973        let ex = make_exchange("key", "A", "data");
974        let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
975
976        tokio::time::sleep(Duration::from_millis(100)).await;
977
978        assert!(
979            rx.try_recv().is_err(),
980            "no emit expected with discard_on_timeout"
981        );
982        assert_eq!(svc.buckets.lock().unwrap().len(), 0);
983        assert!(
984            svc.timeout_tasks.lock().unwrap().is_empty(),
985            "timeout task should be cleaned up"
986        );
987    }
988
989    #[tokio::test]
990    async fn test_force_completion_on_stop() {
991        let config = AggregatorConfig::correlate_by("key")
992            .complete_when_size(10)
993            .force_completion_on_stop(true)
994            .build()
995            .unwrap();
996        let (tx, mut rx) = mpsc::channel(256);
997        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
998        let cancel = CancellationToken::new();
999        let svc = AggregatorService::new(config, tx, registry, cancel);
1000
1001        let mut call_svc = svc.clone();
1002        let ex = make_exchange("key", "A", "data");
1003        let _ = call_svc.ready().await.unwrap().call(ex).await.unwrap();
1004
1005        svc.force_complete_all();
1006
1007        let result = rx.try_recv().expect("should emit on force-complete");
1008        assert!(
1009            result.input.body.as_text().is_some() || matches!(result.input.body, Body::Json(_))
1010        );
1011        assert_eq!(
1012            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
1013            Some(&serde_json::json!("stop"))
1014        );
1015    }
1016
1017    #[tokio::test]
1018    async fn test_completion_reason_property_size() {
1019        let config = AggregatorConfig::correlate_by("key")
1020            .complete_when_size(1)
1021            .build()
1022            .unwrap();
1023        let mut svc = new_test_svc(config);
1024        let ex = make_exchange("key", "X", "body");
1025        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
1026        assert_eq!(
1027            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
1028            Some(&serde_json::json!("size"))
1029        );
1030    }
1031
1032    #[tokio::test]
1033    async fn test_completion_reason_property_predicate() {
1034        let config = AggregatorConfig::correlate_by("key")
1035            .complete_when(|_| true)
1036            .build()
1037            .unwrap();
1038        let mut svc = new_test_svc(config);
1039        let ex = make_exchange("key", "X", "body");
1040        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
1041        assert_eq!(
1042            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
1043            Some(&serde_json::json!("predicate"))
1044        );
1045    }
1046
1047    #[tokio::test(start_paused = true)]
1048    async fn test_size_completes_before_timeout() {
1049        let config = AggregatorConfig::correlate_by("key")
1050            .complete_on_size_or_timeout(2, Duration::from_millis(200))
1051            .build()
1052            .unwrap();
1053        let mut svc = new_test_svc(config);
1054
1055        let ex1 = make_exchange("key", "A", "first");
1056        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
1057
1058        let ex2 = make_exchange("key", "A", "second");
1059        let result = svc.ready().await.unwrap().call(ex2).await.unwrap();
1060
1061        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
1062        assert_eq!(
1063            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
1064            Some(&serde_json::json!("size"))
1065        );
1066        assert_eq!(svc.buckets.lock().unwrap().len(), 0);
1067
1068        tokio::time::sleep(Duration::from_millis(300)).await;
1069        assert_eq!(
1070            svc.buckets.lock().unwrap().len(),
1071            0,
1072            "no re-fire after timeout"
1073        );
1074    }
1075
1076    #[tokio::test(start_paused = true)]
1077    async fn test_concurrent_timeout_fire_and_new_exchange() {
1078        let config = AggregatorConfig::correlate_by("key")
1079            .complete_on_size_or_timeout(2, Duration::from_millis(100))
1080            .build()
1081            .unwrap();
1082        let (tx, mut rx) = mpsc::channel(256);
1083        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
1084        let cancel = CancellationToken::new();
1085        let mut svc = AggregatorService::new(config, tx, registry, cancel);
1086
1087        let ex = make_exchange("key", "A", "data");
1088        let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
1089
1090        // Advance time past timeout — timeout task fires and removes bucket
1091        tokio::time::sleep(Duration::from_millis(150)).await;
1092
1093        // New exchange arrives after timeout — starts a fresh bucket
1094        let ex2 = make_exchange("key", "A", "data2");
1095        let result = svc.ready().await.unwrap().call(ex2).await.unwrap();
1096        assert!(
1097            result.property(CAMEL_AGGREGATOR_PENDING).is_some(),
1098            "should be pending in new bucket"
1099        );
1100
1101        // Drain late emits from timeout
1102        let mut late_count = 0;
1103        while rx.try_recv().is_ok() {
1104            late_count += 1;
1105        }
1106        assert_eq!(
1107            late_count, 1,
1108            "exactly 1 late emit from the timed-out bucket"
1109        );
1110    }
1111
1112    #[tokio::test(start_paused = true)]
1113    async fn test_late_channel_full_drops_with_warning() {
1114        let config = AggregatorConfig::correlate_by("key")
1115            .complete_on_timeout(Duration::from_millis(50))
1116            .build()
1117            .unwrap();
1118        let (tx, mut rx) = mpsc::channel(1);
1119        rx.close();
1120        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
1121        let cancel = CancellationToken::new();
1122        let mut svc = AggregatorService::new(config, tx, registry, cancel);
1123
1124        let ex = make_exchange("key", "A", "data");
1125        let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
1126
1127        tokio::time::sleep(Duration::from_millis(100)).await;
1128        assert_eq!(
1129            svc.buckets.lock().unwrap().len(),
1130            0,
1131            "bucket removed despite channel closed"
1132        );
1133    }
1134
1135    #[tokio::test]
1136    async fn test_aggregate_stream_bodies_creates_valid_json() {
1137        use bytes::Bytes;
1138        use camel_api::{Body, StreamBody, StreamMetadata};
1139        use futures::stream;
1140        use tokio::sync::Mutex;
1141
1142        let chunks = vec![Ok(Bytes::from("test"))];
1143        let stream_body = StreamBody {
1144            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
1145            metadata: StreamMetadata {
1146                origin: Some("file:///test.txt".to_string()),
1147                ..Default::default()
1148            },
1149        };
1150
1151        let ex1 = Exchange::new(Message {
1152            headers: Default::default(),
1153            body: Body::Stream(stream_body),
1154        });
1155
1156        let exchanges = vec![ex1];
1157        let result = aggregate(exchanges, &AggregationStrategy::CollectAll);
1158
1159        let exchange = result.expect("Expected Ok result");
1160        assert!(
1161            matches!(exchange.input.body, Body::Json(_)),
1162            "Expected Json body"
1163        );
1164
1165        if let Body::Json(value) = exchange.input.body {
1166            let json_str = serde_json::to_string(&value).unwrap();
1167            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
1168
1169            assert!(parsed.is_array(), "Result should be an array");
1170            let arr = parsed.as_array().unwrap();
1171            assert!(arr[0].is_object(), "First element should be an object");
1172            assert!(
1173                arr[0]["_stream"].is_object(),
1174                "Should contain _stream object"
1175            );
1176            assert_eq!(arr[0]["_stream"]["origin"], "file:///test.txt");
1177            assert_eq!(
1178                arr[0]["_stream"]["placeholder"], true,
1179                "placeholder flag should be true"
1180            );
1181        }
1182    }
1183
1184    #[tokio::test]
1185    async fn test_aggregate_stream_bodies_with_none_origin() {
1186        use bytes::Bytes;
1187        use camel_api::{Body, StreamBody, StreamMetadata};
1188        use futures::stream;
1189        use tokio::sync::Mutex;
1190
1191        let chunks = vec![Ok(Bytes::from("test"))];
1192        let stream_body = StreamBody {
1193            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
1194            metadata: StreamMetadata {
1195                origin: None,
1196                ..Default::default()
1197            },
1198        };
1199
1200        let ex1 = Exchange::new(Message {
1201            headers: Default::default(),
1202            body: Body::Stream(stream_body),
1203        });
1204
1205        let exchanges = vec![ex1];
1206        let result = aggregate(exchanges, &AggregationStrategy::CollectAll);
1207
1208        let exchange = result.expect("Expected Ok result");
1209        assert!(
1210            matches!(exchange.input.body, Body::Json(_)),
1211            "Expected Json body"
1212        );
1213
1214        if let Body::Json(value) = exchange.input.body {
1215            let json_str = serde_json::to_string(&value).unwrap();
1216            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
1217
1218            assert!(parsed.is_array(), "Result should be an array");
1219            let arr = parsed.as_array().unwrap();
1220            assert!(arr[0].is_object(), "First element should be an object");
1221            assert!(
1222                arr[0]["_stream"].is_object(),
1223                "Should contain _stream object"
1224            );
1225            assert_eq!(
1226                arr[0]["_stream"]["origin"],
1227                serde_json::Value::Null,
1228                "origin should be null when None"
1229            );
1230            assert_eq!(
1231                arr[0]["_stream"]["placeholder"], true,
1232                "placeholder flag should be true"
1233            );
1234        }
1235    }
1236
1237    #[tokio::test(start_paused = true)]
1238    async fn test_shutdown_awaits_timeout_handles() {
1239        let config = AggregatorConfig::correlate_by("key")
1240            .complete_on_timeout(Duration::from_millis(100))
1241            .build()
1242            .unwrap();
1243        let (tx, _rx) = mpsc::channel(256);
1244        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
1245        let cancel = CancellationToken::new();
1246        let svc = AggregatorService::new(config, tx, registry, cancel);
1247
1248        // Send an exchange to create a pending bucket with a timeout task.
1249        let mut call_svc = svc.clone();
1250        let ex = make_exchange("key", "A", "data");
1251        let _ = call_svc.ready().await.unwrap().call(ex).await.unwrap();
1252
1253        // Verify timeout handle exists.
1254        assert!(
1255            !svc.timeout_handles.lock().unwrap().is_empty(),
1256            "should have a timeout handle"
1257        );
1258
1259        // Shutdown should complete within the 5s deadline (the timeout task
1260        // gets cancelled so it won't wait for the full 100ms sleep).
1261        svc.shutdown().await;
1262
1263        assert!(
1264            svc.timeout_handles.lock().unwrap().is_empty(),
1265            "all handles should be cleaned up after shutdown"
1266        );
1267    }
1268}