Skip to main content

camel_processor/
aggregator.rs

1use std::collections::HashMap;
2use std::future::Future;
3use std::pin::Pin;
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6use std::time::{Duration, Instant};
7
8use tokio::sync::mpsc;
9use tokio::task::JoinHandle;
10use tokio_util::sync::CancellationToken;
11use tower::Service;
12
13use async_trait::async_trait;
14use camel_api::{
15    CamelError, StepLifecycle, StepShutdownReason,
16    aggregator::{
17        AggregationStrategy, AggregatorConfig, CompletionCondition, CompletionMode,
18        CompletionReason, CorrelationStrategy,
19    },
20    body::Body,
21    exchange::Exchange,
22    message::Message,
23};
24use camel_language_api::Language;
25
/// Shared, mutex-guarded map from a language name to its [`Language`]
/// implementation. The aggregator looks languages up here by name when the
/// correlation strategy is an expression.
pub type SharedLanguageRegistry = Arc<std::sync::Mutex<HashMap<String, Arc<dyn Language>>>>;
27
/// Property set to `true` on the empty placeholder exchange returned while a
/// bucket is still collecting.
pub const CAMEL_AGGREGATOR_PENDING: &str = "CamelAggregatorPending";
/// Property holding the number of exchanges folded into a completed aggregate.
pub const CAMEL_AGGREGATED_SIZE: &str = "CamelAggregatedSize";
/// Property holding the correlation key value of a completed aggregate.
pub const CAMEL_AGGREGATED_KEY: &str = "CamelAggregatedKey";
/// Property holding the string form of the `CompletionReason` that closed the
/// bucket.
pub const CAMEL_AGGREGATED_COMPLETION_REASON: &str = "CamelAggregatedCompletionReason";
32
/// Internal bucket structure with timestamp tracking for TTL eviction.
///
/// One bucket holds the exchanges collected so far for a single correlation
/// key.
struct Bucket {
    /// Exchanges accumulated for this key, in arrival order.
    exchanges: Vec<Exchange>,
    /// Time of creation or of the most recent `push`; `is_expired` measures
    /// the TTL from this instant.
    last_updated: Instant,
}
38
39impl Bucket {
40    fn new() -> Self {
41        Self {
42            exchanges: Vec::new(),
43            last_updated: Instant::now(),
44        }
45    }
46
47    fn push(&mut self, exchange: Exchange) {
48        self.exchanges.push(exchange);
49        self.last_updated = Instant::now();
50    }
51
52    fn is_expired(&self, ttl: Duration) -> bool {
53        Instant::now().duration_since(self.last_updated) >= ttl
54    }
55}
56
/// Tower service that groups incoming exchanges into per-key buckets and emits
/// one aggregated exchange when a bucket completes.
///
/// Cloning is cheap: the bucket map and the timeout bookkeeping live behind
/// `Arc`s, so every clone operates on the same state.
#[derive(Clone)]
pub struct AggregatorService {
    /// Correlation, completion and aggregation settings.
    config: AggregatorConfig,
    /// Open buckets, keyed by the JSON-serialized correlation key.
    buckets: Arc<Mutex<HashMap<String, Bucket>>>,
    /// Cancellation token of the currently armed timeout task for each key.
    timeout_tasks: Arc<Mutex<HashMap<String, CancellationToken>>>,
    /// Join handle of the currently armed timeout task for each key.
    timeout_handles: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
    /// Channel for aggregates completed outside of `call` (timeout expiry and
    /// forced completion).
    late_tx: mpsc::Sender<Exchange>,
    /// Languages available for expression-based correlation.
    language_registry: SharedLanguageRegistry,
    /// Route-level cancellation token, handed to every spawned timeout task.
    route_cancel: CancellationToken,
}
67
68impl std::fmt::Debug for AggregatorService {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        f.debug_struct("AggregatorService").finish_non_exhaustive()
71    }
72}
73
74impl AggregatorService {
75    pub fn new(
76        config: AggregatorConfig,
77        late_tx: mpsc::Sender<Exchange>,
78        language_registry: SharedLanguageRegistry,
79        route_cancel: CancellationToken,
80    ) -> Self {
81        Self {
82            config,
83            buckets: Arc::new(Mutex::new(HashMap::new())),
84            timeout_tasks: Arc::new(Mutex::new(HashMap::new())),
85            timeout_handles: Arc::new(Mutex::new(HashMap::new())),
86            late_tx,
87            language_registry,
88            route_cancel,
89        }
90    }
91
92    pub fn config(&self) -> &AggregatorConfig {
93        &self.config
94    }
95
96    pub fn has_timeout(&self) -> bool {
97        has_timeout_condition(&self.config.completion)
98    }
99
100    pub fn force_complete_all(&self) {
101        let mut buckets_guard = self.buckets.lock().unwrap_or_else(|e| e.into_inner());
102        let keys: Vec<String> = buckets_guard.keys().cloned().collect();
103
104        for key in keys {
105            if let Some(bucket) = buckets_guard.remove(&key) {
106                if self.config.force_completion_on_stop {
107                    cancel_timeout_task_with_handle(
108                        &key,
109                        &self.timeout_tasks,
110                        &self.timeout_handles,
111                    );
112                    match aggregate(bucket.exchanges, &self.config.strategy) {
113                        Ok(mut result) => {
114                            result.set_property(
115                                CAMEL_AGGREGATED_COMPLETION_REASON,
116                                serde_json::json!(CompletionReason::Stop.as_str()),
117                            );
118                            if self.late_tx.try_send(result).is_err() {
119                                tracing::warn!(
120                                    key = %key,
121                                    "aggregator force-complete emit dropped: late channel full"
122                                );
123                            }
124                        }
125                        Err(e) => {
126                            // log-policy: handler-owned
127                            tracing::warn!(
128                                key = %key,
129                                error = %e,
130                                "aggregation failed in force_complete_all"
131                            );
132                        }
133                    }
134                } else {
135                    cancel_timeout_task_with_handle(
136                        &key,
137                        &self.timeout_tasks,
138                        &self.timeout_handles,
139                    );
140                }
141            }
142        }
143    }
144
145    /// Graceful shutdown: cancel all outstanding timeout tasks and await their
146    /// JoinHandles (with a 5s deadline) so that no tasks are leaked.
147    pub(crate) async fn shutdown_inner(&self) {
148        // Cancel all timeout cancellation tokens.
149        {
150            let mut guard = self.timeout_tasks.lock().unwrap_or_else(|e| e.into_inner());
151            for token in guard.values() {
152                token.cancel();
153            }
154            guard.clear();
155        };
156
157        // Remove and collect all JoinHandles.
158        let handles: Vec<JoinHandle<()>> = {
159            let mut guard = self
160                .timeout_handles
161                .lock()
162                .unwrap_or_else(|e| e.into_inner());
163            guard.drain().map(|(_, handle)| handle).collect()
164        };
165
166        if handles.is_empty() {
167            return;
168        }
169
170        // Await all handles with a deadline.
171        let _ = tokio::time::timeout(Duration::from_secs(5), async {
172            for handle in handles {
173                let _ = handle.await;
174            }
175        })
176        .await;
177    }
178}
179
180#[async_trait]
181impl StepLifecycle for AggregatorService {
182    fn name(&self) -> &'static str {
183        "aggregator"
184    }
185
186    async fn shutdown(&self, reason: StepShutdownReason) -> Result<(), CamelError> {
187        tracing::debug!(reason = ?reason, "Aggregator shutdown via StepLifecycle");
188        self.shutdown_inner().await;
189        Ok(())
190    }
191}
192
193pub fn has_timeout_condition(mode: &CompletionMode) -> bool {
194    match mode {
195        CompletionMode::Single(CompletionCondition::Timeout(_)) => true,
196        CompletionMode::Any(conditions) => conditions
197            .iter()
198            .any(|c| matches!(c, CompletionCondition::Timeout(_))),
199        _ => false,
200    }
201}
202
203impl Service<Exchange> for AggregatorService {
204    type Response = Exchange;
205    type Error = CamelError;
206    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
207
208    fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), CamelError>> {
209        Poll::Ready(Ok(()))
210    }
211
212    fn call(&mut self, exchange: Exchange) -> Self::Future {
213        let config = self.config.clone();
214        let buckets = Arc::clone(&self.buckets);
215        let timeout_tasks = Arc::clone(&self.timeout_tasks);
216        let timeout_handles = Arc::clone(&self.timeout_handles);
217        let late_tx = self.late_tx.clone();
218        let language_registry = Arc::clone(&self.language_registry);
219        let route_cancel = self.route_cancel.clone();
220
221        Box::pin(async move {
222            let key_value =
223                extract_correlation_key(&exchange, &config.correlation, &language_registry).await?;
224
225            let key_str = serde_json::to_string(&key_value)
226                .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
227
228            let completed_bucket = {
229                let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
230
231                if let Some(ttl) = config.bucket_ttl {
232                    guard.retain(|_, bucket| !bucket.is_expired(ttl));
233                }
234
235                if let Some(max) = config.max_buckets
236                    && !guard.contains_key(&key_str)
237                    && guard.len() >= max
238                {
239                    tracing::warn!(
240                        max_buckets = max,
241                        correlation_key = %key_str,
242                        "Aggregator reached max buckets limit, rejecting new correlation key"
243                    );
244                    return Err(CamelError::ProcessorError(format!(
245                        "Aggregator reached maximum {} buckets",
246                        max
247                    )));
248                }
249
250                let bucket = guard.entry(key_str.clone()).or_insert_with(Bucket::new);
251                bucket.push(exchange);
252
253                let (is_complete, reason) =
254                    check_sync_completion(&config.completion, &bucket.exchanges);
255
256                if is_complete {
257                    let exchanges = guard.remove(&key_str).map(|b| b.exchanges);
258                    (exchanges, reason)
259                } else {
260                    (None, CompletionReason::Size) // placeholder; reason unused when None
261                }
262            };
263
264            if completed_bucket.0.is_none() && has_timeout_condition(&config.completion) {
265                let cancel = {
266                    let mut tt_guard = timeout_tasks.lock().unwrap_or_else(|e| e.into_inner());
267                    // Cancel and remove old handle for this key (if any).
268                    if let Some(existing) = tt_guard.get(&key_str) {
269                        existing.cancel();
270                    }
271                    let token = CancellationToken::new();
272                    tt_guard.insert(key_str.clone(), token.clone());
273                    token
274                };
275
276                let timeout_dur = extract_timeout_duration(&config.completion);
277                if let Some(timeout) = timeout_dur {
278                    // Remove old handle if present.
279                    {
280                        let mut hh = timeout_handles.lock().unwrap_or_else(|e| e.into_inner());
281                        if let Some(old) = hh.remove(&key_str) {
282                            old.abort();
283                        }
284                    }
285                    let handle = spawn_timeout_task(
286                        key_str.clone(),
287                        timeout,
288                        cancel,
289                        buckets.clone(),
290                        timeout_tasks.clone(),
291                        timeout_handles.clone(),
292                        late_tx,
293                        config.strategy.clone(),
294                        config.discard_on_timeout,
295                        route_cancel,
296                    );
297                    timeout_handles
298                        .lock()
299                        .unwrap_or_else(|e| e.into_inner())
300                        .insert(key_str.clone(), handle);
301                }
302            }
303
304            if let Some(exchanges) = completed_bucket.0 {
305                cancel_timeout_task_with_handle(&key_str, &timeout_tasks, &timeout_handles);
306                let reason = completed_bucket.1;
307                let size = exchanges.len();
308                let mut result = aggregate(exchanges, &config.strategy)?;
309                result.set_property(CAMEL_AGGREGATED_SIZE, serde_json::json!(size as u64));
310                result.set_property(CAMEL_AGGREGATED_KEY, key_value);
311                result.set_property(
312                    CAMEL_AGGREGATED_COMPLETION_REASON,
313                    serde_json::json!(reason.as_str()),
314                );
315                Ok(result)
316            } else {
317                let mut pending = Exchange::new(Message {
318                    headers: Default::default(),
319                    body: Body::Empty,
320                });
321                pending.set_property(CAMEL_AGGREGATOR_PENDING, serde_json::json!(true));
322                Ok(pending)
323            }
324        })
325    }
326}
327
328async fn extract_correlation_key(
329    exchange: &Exchange,
330    strategy: &CorrelationStrategy,
331    registry: &SharedLanguageRegistry,
332) -> Result<serde_json::Value, CamelError> {
333    match strategy {
334        CorrelationStrategy::HeaderName(h) => {
335            exchange.input.headers.get(h).cloned().ok_or_else(|| {
336                CamelError::ProcessorError(format!(
337                    "Aggregator: missing correlation key header '{}'",
338                    h
339                ))
340            })
341        }
342        CorrelationStrategy::Expression { expr, language } => {
343            let expression = {
344                let reg = registry.lock().unwrap_or_else(|e| e.into_inner());
345                let lang = reg.get(language).ok_or_else(|| {
346                    CamelError::ProcessorError(format!(
347                        "Aggregator: language '{}' not found in registry",
348                        language
349                    ))
350                })?;
351                lang.create_expression(expr)
352                    .map_err(|e| CamelError::ProcessorError(e.to_string()))?
353            };
354            let value = expression
355                .evaluate(exchange)
356                .await
357                .map_err(|e| CamelError::ProcessorError(e.to_string()))?;
358            if value.is_null() {
359                return Err(CamelError::ProcessorError(format!(
360                    "Aggregator: correlation expression '{}' evaluated to null",
361                    expr
362                )));
363            }
364            Ok(value)
365        }
366        CorrelationStrategy::Fn(f) => f(exchange).map(serde_json::Value::String).ok_or_else(|| {
367            CamelError::ProcessorError("Aggregator: correlation function returned None".to_string())
368        }),
369    }
370}
371
372fn check_sync_completion(
373    mode: &CompletionMode,
374    exchanges: &[Exchange],
375) -> (bool, CompletionReason) {
376    match mode {
377        CompletionMode::Single(cond) => check_single(cond, exchanges),
378        CompletionMode::Any(conditions) => {
379            for cond in conditions {
380                if let CompletionCondition::Timeout(_) = cond {
381                    continue;
382                }
383                let (done, reason) = check_single(cond, exchanges);
384                if done {
385                    return (true, reason);
386                }
387            }
388            (false, CompletionReason::Size)
389        }
390    }
391}
392
393fn check_single(cond: &CompletionCondition, exchanges: &[Exchange]) -> (bool, CompletionReason) {
394    match cond {
395        CompletionCondition::Size(n) => (exchanges.len() >= *n, CompletionReason::Size),
396        CompletionCondition::Predicate(pred) => (pred(exchanges), CompletionReason::Predicate),
397        CompletionCondition::Timeout(_) => (false, CompletionReason::Timeout),
398    }
399}
400
401fn extract_timeout_duration(mode: &CompletionMode) -> Option<Duration> {
402    match mode {
403        CompletionMode::Single(CompletionCondition::Timeout(d)) => Some(*d),
404        CompletionMode::Any(conditions) => conditions.iter().find_map(|c| {
405            if let CompletionCondition::Timeout(d) = c {
406                Some(*d)
407            } else {
408                None
409            }
410        }),
411        _ => None,
412    }
413}
414
415fn cancel_timeout_task(key: &str, timeout_tasks: &Arc<Mutex<HashMap<String, CancellationToken>>>) {
416    let mut guard = timeout_tasks.lock().unwrap_or_else(|e| e.into_inner());
417    if let Some(token) = guard.remove(key) {
418        token.cancel();
419    }
420}
421
422/// Also removes the stored JoinHandle for a cancelled/completed timeout task.
423fn cancel_timeout_task_with_handle(
424    key: &str,
425    timeout_tasks: &Arc<Mutex<HashMap<String, CancellationToken>>>,
426    timeout_handles: &Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
427) {
428    cancel_timeout_task(key, timeout_tasks);
429    let mut guard = timeout_handles.lock().unwrap_or_else(|e| e.into_inner());
430    guard.remove(key);
431}
432
433#[allow(clippy::too_many_arguments)]
434fn spawn_timeout_task(
435    key: String,
436    timeout: Duration,
437    cancel: CancellationToken,
438    buckets: Arc<Mutex<HashMap<String, Bucket>>>,
439    timeout_tasks: Arc<Mutex<HashMap<String, CancellationToken>>>,
440    timeout_handles: Arc<Mutex<HashMap<String, JoinHandle<()>>>>,
441    late_tx: mpsc::Sender<Exchange>,
442    strategy: AggregationStrategy,
443    discard: bool,
444    _route_cancel: CancellationToken,
445) -> JoinHandle<()> {
446    let cancel_clone = cancel.clone();
447    tokio::spawn(async move {
448        tokio::select! {
449            _ = tokio::time::sleep(timeout) => {
450                let should_proceed = {
451                    let mut tt_guard = timeout_tasks.lock().unwrap_or_else(|e| e.into_inner());
452                    if cancel_clone.is_cancelled() {
453                        false
454                    } else {
455                        tt_guard.remove(&key);
456                        true
457                    }
458                };
459                if !should_proceed {
460                    return;
461                }
462                // Clean up our own JoinHandle from the map on natural completion.
463                // Without this, the handle entry leaks until route shutdown.
464                {
465                    let mut hh = timeout_handles.lock().unwrap_or_else(|e| e.into_inner());
466                    hh.remove(&key);
467                }
468                let bucket_exchanges = {
469                    let mut guard = buckets.lock().unwrap_or_else(|e| e.into_inner());
470                    guard.remove(&key).map(|b| b.exchanges)
471                };
472                if let Some(exchanges) = bucket_exchanges
473                    && !discard
474                {
475                    match aggregate(exchanges, &strategy) {
476                        Ok(mut result) => {
477                            result.set_property(
478                                CAMEL_AGGREGATED_COMPLETION_REASON,
479                                serde_json::json!(CompletionReason::Timeout.as_str()),
480                            );
481                            if late_tx.try_send(result).is_err() {
482                                tracing::warn!(
483                                    key = %key,
484                                    "aggregator timeout emit dropped: late channel full"
485                                );
486                            }
487                        }
488                        Err(e) => {
489                            // log-policy: handler-owned
490                            tracing::warn!(
491                                key = %key,
492                                error = %e,
493                                "aggregation failed in timeout task"
494                            );
495                        }
496                    }
497                }
498            }
499            _ = cancel_clone.cancelled() => {}
500        }
501    })
502}
503
504fn aggregate(
505    exchanges: Vec<Exchange>,
506    strategy: &AggregationStrategy,
507) -> Result<Exchange, CamelError> {
508    match strategy {
509        AggregationStrategy::CollectAll => {
510            let bodies: Vec<serde_json::Value> = exchanges
511                .into_iter()
512                .map(|e| match e.input.body {
513                    Body::Json(v) => v,
514                    Body::Text(s) => serde_json::Value::String(s),
515                    Body::Xml(s) => serde_json::Value::String(s),
516                    Body::Bytes(b) => {
517                        serde_json::Value::String(String::from_utf8_lossy(&b).into_owned())
518                    }
519                    Body::Empty => serde_json::Value::Null,
520                    Body::Stream(s) => serde_json::json!({
521                        "_stream": {
522                            "origin": s.metadata.origin,
523                            "placeholder": true,
524                            "hint": "Materialize exchange body with .into_bytes() before aggregation if content needed"
525                        }
526                    }),
527                })
528                .collect();
529            Ok(Exchange::new(Message {
530                headers: Default::default(),
531                body: Body::Json(serde_json::Value::Array(bodies)),
532            }))
533        }
534        AggregationStrategy::Custom(f) => {
535            let mut iter = exchanges.into_iter();
536            let first = iter.next().ok_or_else(|| {
537                CamelError::ProcessorError("Aggregator: empty bucket".to_string())
538            })?;
539            Ok(iter.fold(first, |acc, next| f(acc, next)))
540        }
541    }
542}
543
544#[cfg(test)]
545mod tests {
546    use super::*;
547    use std::collections::HashMap;
548
549    use camel_api::{
550        StepLifecycle, StepShutdownReason,
551        aggregator::{AggregationStrategy, AggregatorConfig},
552        body::Body,
553        exchange::Exchange,
554        message::Message,
555    };
556    use tokio::sync::mpsc;
557    use tokio_util::sync::CancellationToken;
558    use tower::ServiceExt;
559
560    fn make_exchange(header: &str, value: &str, body: &str) -> Exchange {
561        let mut msg = Message {
562            headers: Default::default(),
563            body: Body::Text(body.to_string()),
564        };
565        msg.headers
566            .insert(header.to_string(), serde_json::json!(value));
567        Exchange::new(msg)
568    }
569
570    fn config_size(n: usize) -> AggregatorConfig {
571        AggregatorConfig::correlate_by("orderId")
572            .complete_when_size(n)
573            .build()
574            .unwrap()
575    }
576
577    fn new_test_svc(config: AggregatorConfig) -> AggregatorService {
578        let (tx, _rx) = mpsc::channel(256);
579        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
580        let cancel = CancellationToken::new();
581        AggregatorService::new(config, tx, registry, cancel)
582    }
583
584    #[tokio::test]
585    async fn test_pending_exchange_not_yet_complete() {
586        let mut svc = new_test_svc(config_size(3));
587        let ex = make_exchange("orderId", "A", "first");
588        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
589        assert!(matches!(result.input.body, Body::Empty));
590        assert_eq!(
591            result.property(CAMEL_AGGREGATOR_PENDING),
592            Some(&serde_json::json!(true))
593        );
594    }
595
596    #[tokio::test]
597    async fn test_completes_on_size() {
598        let mut svc = new_test_svc(config_size(3));
599        for _ in 0..2 {
600            let ex = make_exchange("orderId", "A", "item");
601            let r = svc.ready().await.unwrap().call(ex).await.unwrap();
602            assert!(matches!(r.input.body, Body::Empty));
603        }
604        let ex = make_exchange("orderId", "A", "last");
605        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
606        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
607        assert_eq!(
608            result.property(CAMEL_AGGREGATED_SIZE),
609            Some(&serde_json::json!(3u64))
610        );
611    }
612
613    #[tokio::test]
614    async fn test_collect_all_produces_json_array() {
615        let mut svc = new_test_svc(config_size(2));
616        svc.ready()
617            .await
618            .unwrap()
619            .call(make_exchange("orderId", "A", "alpha"))
620            .await
621            .unwrap();
622        let result = svc
623            .ready()
624            .await
625            .unwrap()
626            .call(make_exchange("orderId", "A", "beta"))
627            .await
628            .unwrap();
629        let Body::Json(v) = &result.input.body else {
630            panic!("expected Body::Json")
631        };
632        let arr = v.as_array().unwrap();
633        assert_eq!(arr.len(), 2);
634        assert_eq!(arr[0], serde_json::json!("alpha"));
635        assert_eq!(arr[1], serde_json::json!("beta"));
636    }
637
638    #[tokio::test]
639    async fn test_two_keys_independent_buckets() {
640        // completionSize=3 so we can test that A and B accumulate independently.
641        let mut svc = new_test_svc(config_size(3));
642        svc.ready()
643            .await
644            .unwrap()
645            .call(make_exchange("orderId", "A", "a1"))
646            .await
647            .unwrap();
648        svc.ready()
649            .await
650            .unwrap()
651            .call(make_exchange("orderId", "B", "b1"))
652            .await
653            .unwrap();
654        svc.ready()
655            .await
656            .unwrap()
657            .call(make_exchange("orderId", "A", "a2"))
658            .await
659            .unwrap();
660        // A has 2 items, B has 1 item — neither complete yet
661        let ra = svc
662            .ready()
663            .await
664            .unwrap()
665            .call(make_exchange("orderId", "A", "a3"))
666            .await
667            .unwrap();
668        // A now has 3 → completes
669        assert!(matches!(ra.input.body, Body::Json(_)));
670        // B only has 1 → still pending
671        let rb = svc
672            .ready()
673            .await
674            .unwrap()
675            .call(make_exchange("orderId", "B", "b_check"))
676            .await
677            .unwrap();
678        assert!(matches!(rb.input.body, Body::Empty));
679    }
680
681    #[tokio::test]
682    async fn test_bucket_resets_after_completion() {
683        let mut svc = new_test_svc(config_size(2));
684        svc.ready()
685            .await
686            .unwrap()
687            .call(make_exchange("orderId", "A", "x"))
688            .await
689            .unwrap();
690        svc.ready()
691            .await
692            .unwrap()
693            .call(make_exchange("orderId", "A", "x"))
694            .await
695            .unwrap(); // completes
696        // New bucket starts
697        let r = svc
698            .ready()
699            .await
700            .unwrap()
701            .call(make_exchange("orderId", "A", "new"))
702            .await
703            .unwrap();
704        assert!(matches!(r.input.body, Body::Empty)); // pending again
705    }
706
707    #[tokio::test]
708    async fn test_completion_size_1_emits_immediately() {
709        let mut svc = new_test_svc(config_size(1));
710        let ex = make_exchange("orderId", "A", "solo");
711        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
712        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
713    }
714
715    #[tokio::test]
716    async fn test_custom_aggregation_strategy() {
717        use camel_api::aggregator::AggregationFn;
718        use std::sync::Arc;
719
720        let f: AggregationFn = Arc::new(|mut acc: Exchange, next: Exchange| {
721            let combined = format!(
722                "{}+{}",
723                acc.input.body.as_text().unwrap_or(""),
724                next.input.body.as_text().unwrap_or("")
725            );
726            acc.input.body = Body::Text(combined);
727            acc
728        });
729        let config = AggregatorConfig::correlate_by("key")
730            .complete_when_size(2)
731            .strategy(AggregationStrategy::Custom(f))
732            .build()
733            .unwrap();
734        let mut svc = new_test_svc(config);
735        svc.ready()
736            .await
737            .unwrap()
738            .call(make_exchange("key", "X", "hello"))
739            .await
740            .unwrap();
741        let result = svc
742            .ready()
743            .await
744            .unwrap()
745            .call(make_exchange("key", "X", "world"))
746            .await
747            .unwrap();
748        assert_eq!(result.input.body.as_text(), Some("hello+world"));
749    }
750
751    #[tokio::test]
752    async fn test_completion_predicate() {
753        let config = AggregatorConfig::correlate_by("key")
754            .complete_when(|bucket| {
755                bucket
756                    .iter()
757                    .any(|e| e.input.body.as_text() == Some("DONE"))
758            })
759            .build()
760            .unwrap();
761        let mut svc = new_test_svc(config);
762        svc.ready()
763            .await
764            .unwrap()
765            .call(make_exchange("key", "K", "first"))
766            .await
767            .unwrap();
768        svc.ready()
769            .await
770            .unwrap()
771            .call(make_exchange("key", "K", "second"))
772            .await
773            .unwrap();
774        let result = svc
775            .ready()
776            .await
777            .unwrap()
778            .call(make_exchange("key", "K", "DONE"))
779            .await
780            .unwrap();
781        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
782    }
783
784    #[tokio::test]
785    async fn test_missing_header_returns_error() {
786        let mut svc = new_test_svc(config_size(2));
787        let msg = Message {
788            headers: Default::default(),
789            body: Body::Text("no key".into()),
790        };
791        let ex = Exchange::new(msg);
792        let result = svc.ready().await.unwrap().call(ex).await;
793        assert!(result.is_err());
794        assert!(matches!(
795            result.unwrap_err(),
796            camel_api::CamelError::ProcessorError(_)
797        ));
798    }
799
800    #[tokio::test]
801    async fn test_cloned_service_shares_state() {
802        let svc1 = new_test_svc(config_size(2));
803        let mut svc2 = svc1.clone();
804        // send first exchange via svc1
805        svc1.clone()
806            .ready()
807            .await
808            .unwrap()
809            .call(make_exchange("orderId", "A", "from-svc1"))
810            .await
811            .unwrap();
812        // send second exchange via svc2 — should complete because same Arc<Mutex>
813        let result = svc2
814            .ready()
815            .await
816            .unwrap()
817            .call(make_exchange("orderId", "A", "from-svc2"))
818            .await
819            .unwrap();
820        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
821    }
822
823    #[tokio::test]
824    async fn test_camel_aggregated_key_property_set() {
825        let mut svc = new_test_svc(config_size(1));
826        let ex = make_exchange("orderId", "ORDER-42", "body");
827        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
828        assert_eq!(
829            result.property(CAMEL_AGGREGATED_KEY),
830            Some(&serde_json::json!("ORDER-42"))
831        );
832    }
833
834    #[tokio::test]
835    async fn test_aggregator_enforces_max_buckets() {
836        let config = AggregatorConfig::correlate_by("orderId")
837            .complete_when_size(2)
838            .max_buckets(3)
839            .build()
840            .unwrap();
841
842        let mut svc = new_test_svc(config);
843
844        // Create 3 different correlation keys (fills limit)
845        for i in 0..3 {
846            let ex = make_exchange("orderId", &format!("key-{}", i), "body");
847            let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
848        }
849
850        // 4th key should be rejected
851        let ex = make_exchange("orderId", "key-4", "body");
852        let result = svc.ready().await.unwrap().call(ex).await;
853
854        assert!(result.is_err(), "Should reject when max buckets reached");
855        let err = result.unwrap_err().to_string();
856        assert!(
857            err.contains("maximum"),
858            "Error message should contain 'maximum': {}",
859            err
860        );
861    }
862
863    #[tokio::test]
864    async fn test_max_buckets_allows_existing_key() {
865        let config = AggregatorConfig::correlate_by("orderId")
866            .complete_when_size(5) // Large size so bucket doesn't complete
867            .max_buckets(2)
868            .build()
869            .unwrap();
870
871        let mut svc = new_test_svc(config);
872
873        // Create 2 different correlation keys (fills limit)
874        let ex1 = make_exchange("orderId", "key-A", "body1");
875        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
876        let ex2 = make_exchange("orderId", "key-B", "body2");
877        let _ = svc.ready().await.unwrap().call(ex2).await.unwrap();
878
879        // Should still allow adding to existing key
880        let ex3 = make_exchange("orderId", "key-A", "body3");
881        let result = svc.ready().await.unwrap().call(ex3).await;
882        assert!(
883            result.is_ok(),
884            "Should allow adding to existing bucket even at max limit"
885        );
886    }
887
888    #[tokio::test]
889    async fn test_bucket_ttl_eviction() {
890        let config = AggregatorConfig::correlate_by("orderId")
891            .complete_when_size(10) // Large size so bucket doesn't complete normally
892            .bucket_ttl(Duration::from_millis(50))
893            .build()
894            .unwrap();
895
896        let mut svc = new_test_svc(config);
897
898        // Create a bucket
899        let ex1 = make_exchange("orderId", "key-A", "body1");
900        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
901
902        // Wait for TTL to expire
903        tokio::time::sleep(Duration::from_millis(100)).await;
904
905        // Create a new bucket - this should trigger eviction of the old one
906        let ex2 = make_exchange("orderId", "key-B", "body2");
907        let _ = svc.ready().await.unwrap().call(ex2).await.unwrap();
908
909        // The expired bucket should have been evicted, so we should be able to
910        // add a new key-A bucket again
911        let ex3 = make_exchange("orderId", "key-A", "body3");
912        let result = svc.ready().await.unwrap().call(ex3).await;
913        assert!(result.is_ok(), "Should be able to recreate evicted bucket");
914    }
915
916    #[tokio::test(start_paused = true)]
917    async fn test_timeout_completes_bucket() {
918        let config = AggregatorConfig::correlate_by("key")
919            .complete_on_timeout(Duration::from_millis(100))
920            .build()
921            .unwrap();
922        let mut svc = new_test_svc(config);
923        let ex = make_exchange("key", "A", "data");
924        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
925        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_some());
926
927        tokio::time::sleep(Duration::from_millis(200)).await;
928
929        assert_eq!(
930            svc.buckets.lock().unwrap().len(),
931            0,
932            "bucket should be removed after timeout"
933        );
934    }
935
936    #[tokio::test(start_paused = true)]
937    async fn test_timeout_resets_on_new_exchange() {
938        let config = AggregatorConfig::correlate_by("key")
939            .complete_on_timeout(Duration::from_millis(150))
940            .build()
941            .unwrap();
942        let mut svc = new_test_svc(config);
943
944        let ex1 = make_exchange("key", "A", "first");
945        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
946
947        tokio::time::sleep(Duration::from_millis(100)).await;
948
949        let ex2 = make_exchange("key", "A", "second");
950        let _ = svc.ready().await.unwrap().call(ex2).await.unwrap();
951
952        tokio::time::sleep(Duration::from_millis(100)).await;
953
954        assert_eq!(
955            svc.buckets.lock().unwrap().len(),
956            1,
957            "bucket should still exist — timeout was reset"
958        );
959
960        tokio::time::sleep(Duration::from_millis(100)).await;
961
962        assert_eq!(
963            svc.buckets.lock().unwrap().len(),
964            0,
965            "bucket should be gone after timeout fires"
966        );
967    }
968
969    #[tokio::test]
970    async fn test_composable_size_and_timeout() {
971        let config = AggregatorConfig::correlate_by("key")
972            .complete_on_size_or_timeout(2, Duration::from_millis(200))
973            .build()
974            .unwrap();
975        let mut svc = new_test_svc(config);
976
977        let ex1 = make_exchange("key", "A", "first");
978        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
979        assert!(svc.buckets.lock().unwrap().contains_key("\"A\""));
980
981        let ex2 = make_exchange("key", "A", "second");
982        let result = svc.ready().await.unwrap().call(ex2).await.unwrap();
983        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
984        assert_eq!(
985            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
986            Some(&serde_json::json!("size"))
987        );
988    }
989
990    #[tokio::test(start_paused = true)]
991    async fn test_discard_on_timeout() {
992        let config = AggregatorConfig::correlate_by("key")
993            .complete_on_timeout(Duration::from_millis(50))
994            .discard_on_timeout(true)
995            .build()
996            .unwrap();
997        let (tx, mut rx) = mpsc::channel(256);
998        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
999        let cancel = CancellationToken::new();
1000        let mut svc = AggregatorService::new(config, tx, registry, cancel);
1001
1002        let ex = make_exchange("key", "A", "data");
1003        let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
1004
1005        tokio::time::sleep(Duration::from_millis(100)).await;
1006
1007        assert!(
1008            rx.try_recv().is_err(),
1009            "no emit expected with discard_on_timeout"
1010        );
1011        assert_eq!(svc.buckets.lock().unwrap().len(), 0);
1012        assert!(
1013            svc.timeout_tasks.lock().unwrap().is_empty(),
1014            "timeout task should be cleaned up"
1015        );
1016    }
1017
1018    #[tokio::test]
1019    async fn test_force_completion_on_stop() {
1020        let config = AggregatorConfig::correlate_by("key")
1021            .complete_when_size(10)
1022            .force_completion_on_stop(true)
1023            .build()
1024            .unwrap();
1025        let (tx, mut rx) = mpsc::channel(256);
1026        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
1027        let cancel = CancellationToken::new();
1028        let svc = AggregatorService::new(config, tx, registry, cancel);
1029
1030        let mut call_svc = svc.clone();
1031        let ex = make_exchange("key", "A", "data");
1032        let _ = call_svc.ready().await.unwrap().call(ex).await.unwrap();
1033
1034        svc.force_complete_all();
1035
1036        let result = rx.try_recv().expect("should emit on force-complete");
1037        assert!(
1038            result.input.body.as_text().is_some() || matches!(result.input.body, Body::Json(_))
1039        );
1040        assert_eq!(
1041            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
1042            Some(&serde_json::json!("stop"))
1043        );
1044    }
1045
1046    #[tokio::test]
1047    async fn test_completion_reason_property_size() {
1048        let config = AggregatorConfig::correlate_by("key")
1049            .complete_when_size(1)
1050            .build()
1051            .unwrap();
1052        let mut svc = new_test_svc(config);
1053        let ex = make_exchange("key", "X", "body");
1054        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
1055        assert_eq!(
1056            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
1057            Some(&serde_json::json!("size"))
1058        );
1059    }
1060
1061    #[tokio::test]
1062    async fn test_completion_reason_property_predicate() {
1063        let config = AggregatorConfig::correlate_by("key")
1064            .complete_when(|_| true)
1065            .build()
1066            .unwrap();
1067        let mut svc = new_test_svc(config);
1068        let ex = make_exchange("key", "X", "body");
1069        let result = svc.ready().await.unwrap().call(ex).await.unwrap();
1070        assert_eq!(
1071            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
1072            Some(&serde_json::json!("predicate"))
1073        );
1074    }
1075
1076    #[tokio::test(start_paused = true)]
1077    async fn test_size_completes_before_timeout() {
1078        let config = AggregatorConfig::correlate_by("key")
1079            .complete_on_size_or_timeout(2, Duration::from_millis(200))
1080            .build()
1081            .unwrap();
1082        let mut svc = new_test_svc(config);
1083
1084        let ex1 = make_exchange("key", "A", "first");
1085        let _ = svc.ready().await.unwrap().call(ex1).await.unwrap();
1086
1087        let ex2 = make_exchange("key", "A", "second");
1088        let result = svc.ready().await.unwrap().call(ex2).await.unwrap();
1089
1090        assert!(result.property(CAMEL_AGGREGATOR_PENDING).is_none());
1091        assert_eq!(
1092            result.property(CAMEL_AGGREGATED_COMPLETION_REASON),
1093            Some(&serde_json::json!("size"))
1094        );
1095        assert_eq!(svc.buckets.lock().unwrap().len(), 0);
1096
1097        tokio::time::sleep(Duration::from_millis(300)).await;
1098        assert_eq!(
1099            svc.buckets.lock().unwrap().len(),
1100            0,
1101            "no re-fire after timeout"
1102        );
1103    }
1104
1105    #[tokio::test(start_paused = true)]
1106    async fn test_concurrent_timeout_fire_and_new_exchange() {
1107        let config = AggregatorConfig::correlate_by("key")
1108            .complete_on_size_or_timeout(2, Duration::from_millis(100))
1109            .build()
1110            .unwrap();
1111        let (tx, mut rx) = mpsc::channel(256);
1112        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
1113        let cancel = CancellationToken::new();
1114        let mut svc = AggregatorService::new(config, tx, registry, cancel);
1115
1116        let ex = make_exchange("key", "A", "data");
1117        let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
1118
1119        // Advance time past timeout — timeout task fires and removes bucket
1120        tokio::time::sleep(Duration::from_millis(150)).await;
1121
1122        // New exchange arrives after timeout — starts a fresh bucket
1123        let ex2 = make_exchange("key", "A", "data2");
1124        let result = svc.ready().await.unwrap().call(ex2).await.unwrap();
1125        assert!(
1126            result.property(CAMEL_AGGREGATOR_PENDING).is_some(),
1127            "should be pending in new bucket"
1128        );
1129
1130        // Drain late emits from timeout
1131        let mut late_count = 0;
1132        while rx.try_recv().is_ok() {
1133            late_count += 1;
1134        }
1135        assert_eq!(
1136            late_count, 1,
1137            "exactly 1 late emit from the timed-out bucket"
1138        );
1139    }
1140
1141    #[tokio::test(start_paused = true)]
1142    async fn test_late_channel_full_drops_with_warning() {
1143        let config = AggregatorConfig::correlate_by("key")
1144            .complete_on_timeout(Duration::from_millis(50))
1145            .build()
1146            .unwrap();
1147        let (tx, mut rx) = mpsc::channel(1);
1148        rx.close();
1149        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
1150        let cancel = CancellationToken::new();
1151        let mut svc = AggregatorService::new(config, tx, registry, cancel);
1152
1153        let ex = make_exchange("key", "A", "data");
1154        let _ = svc.ready().await.unwrap().call(ex).await.unwrap();
1155
1156        tokio::time::sleep(Duration::from_millis(100)).await;
1157        assert_eq!(
1158            svc.buckets.lock().unwrap().len(),
1159            0,
1160            "bucket removed despite channel closed"
1161        );
1162    }
1163
1164    #[tokio::test]
1165    async fn test_aggregate_stream_bodies_creates_valid_json() {
1166        use bytes::Bytes;
1167        use camel_api::{Body, StreamBody, StreamMetadata};
1168        use futures::stream;
1169        use tokio::sync::Mutex;
1170
1171        let chunks = vec![Ok(Bytes::from("test"))];
1172        let stream_body = StreamBody {
1173            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
1174            metadata: StreamMetadata {
1175                origin: Some("file:///test.txt".to_string()),
1176                ..Default::default()
1177            },
1178        };
1179
1180        let ex1 = Exchange::new(Message {
1181            headers: Default::default(),
1182            body: Body::Stream(stream_body),
1183        });
1184
1185        let exchanges = vec![ex1];
1186        let result = aggregate(exchanges, &AggregationStrategy::CollectAll);
1187
1188        let exchange = result.expect("Expected Ok result");
1189        assert!(
1190            matches!(exchange.input.body, Body::Json(_)),
1191            "Expected Json body"
1192        );
1193
1194        if let Body::Json(value) = exchange.input.body {
1195            let json_str = serde_json::to_string(&value).unwrap();
1196            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
1197
1198            assert!(parsed.is_array(), "Result should be an array");
1199            let arr = parsed.as_array().unwrap();
1200            assert!(arr[0].is_object(), "First element should be an object");
1201            assert!(
1202                arr[0]["_stream"].is_object(),
1203                "Should contain _stream object"
1204            );
1205            assert_eq!(arr[0]["_stream"]["origin"], "file:///test.txt");
1206            assert_eq!(
1207                arr[0]["_stream"]["placeholder"], true,
1208                "placeholder flag should be true"
1209            );
1210        }
1211    }
1212
1213    #[tokio::test]
1214    async fn test_aggregate_stream_bodies_with_none_origin() {
1215        use bytes::Bytes;
1216        use camel_api::{Body, StreamBody, StreamMetadata};
1217        use futures::stream;
1218        use tokio::sync::Mutex;
1219
1220        let chunks = vec![Ok(Bytes::from("test"))];
1221        let stream_body = StreamBody {
1222            stream: Arc::new(Mutex::new(Some(Box::pin(stream::iter(chunks))))),
1223            metadata: StreamMetadata {
1224                origin: None,
1225                ..Default::default()
1226            },
1227        };
1228
1229        let ex1 = Exchange::new(Message {
1230            headers: Default::default(),
1231            body: Body::Stream(stream_body),
1232        });
1233
1234        let exchanges = vec![ex1];
1235        let result = aggregate(exchanges, &AggregationStrategy::CollectAll);
1236
1237        let exchange = result.expect("Expected Ok result");
1238        assert!(
1239            matches!(exchange.input.body, Body::Json(_)),
1240            "Expected Json body"
1241        );
1242
1243        if let Body::Json(value) = exchange.input.body {
1244            let json_str = serde_json::to_string(&value).unwrap();
1245            let parsed: serde_json::Value = serde_json::from_str(&json_str).unwrap();
1246
1247            assert!(parsed.is_array(), "Result should be an array");
1248            let arr = parsed.as_array().unwrap();
1249            assert!(arr[0].is_object(), "First element should be an object");
1250            assert!(
1251                arr[0]["_stream"].is_object(),
1252                "Should contain _stream object"
1253            );
1254            assert_eq!(
1255                arr[0]["_stream"]["origin"],
1256                serde_json::Value::Null,
1257                "origin should be null when None"
1258            );
1259            assert_eq!(
1260                arr[0]["_stream"]["placeholder"], true,
1261                "placeholder flag should be true"
1262            );
1263        }
1264    }
1265
1266    #[tokio::test]
1267    async fn timeout_completion_clears_handle_from_map() {
1268        // Regression: Oracle audit found that natural timeout completion removed
1269        // the bucket but left the JoinHandle in `timeout_handles`, leaking the
1270        // entry until route shutdown. After fix, the timeout task itself cleans
1271        // its handle from the map on natural completion.
1272        let config = AggregatorConfig::correlate_by("key")
1273            .complete_on_timeout(Duration::from_millis(50))
1274            .build()
1275            .unwrap();
1276        let (tx, _rx) = mpsc::channel(256);
1277        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
1278        let cancel = CancellationToken::new();
1279        let svc = AggregatorService::new(config, tx, registry, cancel);
1280
1281        // Send an exchange to create a pending bucket with a timeout task.
1282        let mut call_svc = svc.clone();
1283        let ex = make_exchange("key", "A", "data");
1284        let _ = call_svc.ready().await.unwrap().call(ex).await.unwrap();
1285        assert!(
1286            !svc.timeout_handles.lock().unwrap().is_empty(),
1287            "handle should exist while timeout pending"
1288        );
1289
1290        // Wait real time for the 50ms timeout to fire + spawned task to complete.
1291        tokio::time::sleep(Duration::from_millis(200)).await;
1292
1293        assert!(
1294            svc.timeout_handles.lock().unwrap().is_empty(),
1295            "handle should be cleared from map after natural timeout completion (was leak)"
1296        );
1297    }
1298
1299    #[tokio::test]
1300    async fn aggregator_shutdown_via_trait_dispatch() {
1301        // RED: builds an AggregatorService, dispatches through Arc<dyn StepLifecycle>,
1302        // and asserts idempotent shutdown works.
1303        let config = AggregatorConfig::correlate_by("key")
1304            .complete_when_size(10)
1305            .build()
1306            .unwrap();
1307        let (tx, _rx) = mpsc::channel(256);
1308        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
1309        let cancel = CancellationToken::new();
1310        let svc = AggregatorService::new(config, tx, registry, cancel);
1311
1312        let step: Arc<dyn StepLifecycle> = Arc::new(svc);
1313        step.shutdown(StepShutdownReason::RouteStop)
1314            .await
1315            .expect("first shutdown should succeed");
1316        step.shutdown(StepShutdownReason::RouteStop)
1317            .await
1318            .expect("second shutdown (idempotent) should succeed");
1319    }
1320
1321    #[tokio::test(start_paused = true)]
1322    async fn test_shutdown_awaits_timeout_handles() {
1323        let config = AggregatorConfig::correlate_by("key")
1324            .complete_on_timeout(Duration::from_millis(100))
1325            .build()
1326            .unwrap();
1327        let (tx, _rx) = mpsc::channel(256);
1328        let registry: SharedLanguageRegistry = Arc::new(std::sync::Mutex::new(HashMap::new()));
1329        let cancel = CancellationToken::new();
1330        let svc = AggregatorService::new(config, tx, registry, cancel);
1331
1332        // Send an exchange to create a pending bucket with a timeout task.
1333        let mut call_svc = svc.clone();
1334        let ex = make_exchange("key", "A", "data");
1335        let _ = call_svc.ready().await.unwrap().call(ex).await.unwrap();
1336
1337        // Verify timeout handle exists.
1338        assert!(
1339            !svc.timeout_handles.lock().unwrap().is_empty(),
1340            "should have a timeout handle"
1341        );
1342
1343        // Shutdown should complete within the 5s deadline (the timeout task
1344        // gets cancelled so it won't wait for the full 100ms sleep).
1345        svc.shutdown_inner().await;
1346
1347        assert!(
1348            svc.timeout_handles.lock().unwrap().is_empty(),
1349            "all handles should be cleaned up after shutdown"
1350        );
1351    }
1352}