1use std::future::Future;
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use tower::Service;
6
7use camel_api::{
8 Body, BoxProcessor, CamelError, Exchange, MulticastConfig, MulticastStrategy, Value,
9};
10
/// Exchange property set on every per-endpoint copy of a multicast exchange,
/// holding the zero-based position of the endpoint that receives that copy
/// (stored as an integer `Value`).
pub const CAMEL_MULTICAST_INDEX: &str = "CamelMulticastIndex";

/// Exchange property set on every per-endpoint copy of a multicast exchange;
/// it is `Value::Bool(true)` only on the copy sent to the last endpoint.
pub const CAMEL_MULTICAST_COMPLETE: &str = "CamelMulticastComplete";
17
18#[derive(Clone)]
30pub struct MulticastService {
31 endpoints: Vec<BoxProcessor>,
32 config: MulticastConfig,
33}
34
35impl MulticastService {
36 pub fn new(endpoints: Vec<BoxProcessor>, config: MulticastConfig) -> Self {
38 Self { endpoints, config }
39 }
40}
41
42impl Service<Exchange> for MulticastService {
43 type Response = Exchange;
44 type Error = CamelError;
45 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
46
47 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
48 for endpoint in &mut self.endpoints {
50 match endpoint.poll_ready(cx) {
51 Poll::Pending => return Poll::Pending,
52 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
53 Poll::Ready(Ok(())) => {}
54 }
55 }
56 Poll::Ready(Ok(()))
57 }
58
59 fn call(&mut self, exchange: Exchange) -> Self::Future {
60 let original = exchange.clone();
61 let endpoints = self.endpoints.clone();
62 let config = self.config.clone();
63
64 Box::pin(async move {
65 if endpoints.is_empty() {
67 return Ok(original);
68 }
69
70 let total = endpoints.len();
71
72 let results = if config.parallel {
73 process_parallel(exchange, endpoints, config.parallel_limit, total).await
75 } else {
76 process_sequential(exchange, endpoints, config.stop_on_exception, total).await
78 };
79
80 aggregate(results, original, config.aggregation)
82 })
83 }
84}
85
86async fn process_sequential(
89 exchange: Exchange,
90 endpoints: Vec<BoxProcessor>,
91 stop_on_exception: bool,
92 total: usize,
93) -> Vec<Result<Exchange, CamelError>> {
94 let mut results = Vec::with_capacity(endpoints.len());
95
96 for (i, endpoint) in endpoints.into_iter().enumerate() {
97 let mut cloned_exchange = exchange.clone();
99
100 cloned_exchange.set_property(CAMEL_MULTICAST_INDEX, Value::from(i as i64));
102 cloned_exchange.set_property(CAMEL_MULTICAST_COMPLETE, Value::Bool(i == total - 1));
103
104 let mut endpoint = endpoint;
105 match tower::ServiceExt::ready(&mut endpoint).await {
106 Err(e) => {
107 results.push(Err(e));
108 if stop_on_exception {
109 break;
110 }
111 }
112 Ok(svc) => {
113 let result = svc.call(cloned_exchange).await;
114 let is_err = result.is_err();
115 results.push(result);
116 if stop_on_exception && is_err {
117 break;
118 }
119 }
120 }
121 }
122
123 results
124}
125
126async fn process_parallel(
129 exchange: Exchange,
130 endpoints: Vec<BoxProcessor>,
131 parallel_limit: Option<usize>,
132 total: usize,
133) -> Vec<Result<Exchange, CamelError>> {
134 use std::sync::Arc;
135 use tokio::sync::Semaphore;
136
137 let semaphore = parallel_limit.map(|limit| Arc::new(Semaphore::new(limit)));
138
139 let futures: Vec<_> = endpoints
141 .into_iter()
142 .enumerate()
143 .map(|(i, mut endpoint)| {
144 let mut ex = exchange.clone();
145 ex.set_property(CAMEL_MULTICAST_INDEX, Value::from(i as i64));
146 ex.set_property(CAMEL_MULTICAST_COMPLETE, Value::Bool(i == total - 1));
147 let sem = semaphore.clone();
148 async move {
149 let _permit = match &sem {
151 Some(s) => match s.acquire().await {
152 Ok(p) => Some(p),
153 Err(_) => {
154 return Err(CamelError::ProcessorError("semaphore closed".to_string()));
155 }
156 },
157 None => None,
158 };
159
160 tower::ServiceExt::ready(&mut endpoint).await?;
162 endpoint.call(ex).await
163 }
164 })
165 .collect();
166
167 futures::future::join_all(futures).await
169}
170
171fn aggregate(
174 results: Vec<Result<Exchange, CamelError>>,
175 original: Exchange,
176 strategy: MulticastStrategy,
177) -> Result<Exchange, CamelError> {
178 match strategy {
179 MulticastStrategy::LastWins => {
180 results.into_iter().last().unwrap_or_else(|| Ok(original))
183 }
184 MulticastStrategy::CollectAll => {
185 let mut bodies = Vec::new();
187 for result in results {
188 let ex = result?;
189 let value = match &ex.input.body {
190 Body::Text(s) => Value::String(s.clone()),
191 Body::Json(v) => v.clone(),
192 Body::Bytes(b) => Value::String(String::from_utf8_lossy(b).into_owned()),
193 Body::Empty => Value::Null,
194 };
195 bodies.push(value);
196 }
197 let mut out = original;
198 out.input.body = Body::Json(Value::Array(bodies));
199 Ok(out)
200 }
201 MulticastStrategy::Original => Ok(original),
202 MulticastStrategy::Custom(fold_fn) => {
203 let mut iter = results.into_iter();
205 let first = iter.next().unwrap_or_else(|| Ok(original.clone()))?;
206 iter.try_fold(first, |acc, next_result| {
207 let next = next_result?;
208 Ok(fold_fn(acc, next))
209 })
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use std::sync::Arc;
    use tower::ServiceExt;

    /// Builds an exchange whose input message carries `body`.
    fn make_exchange(body: &str) -> Exchange {
        Exchange::new(Message::new(body))
    }

    /// Endpoint that upper-cases a text body and leaves other bodies untouched.
    fn uppercase_processor() -> BoxProcessor {
        BoxProcessor::from_fn(|mut ex: Exchange| {
            Box::pin(async move {
                if let Body::Text(s) = &ex.input.body {
                    ex.input.body = Body::Text(s.to_uppercase());
                }
                Ok(ex)
            })
        })
    }

    /// `count` independent upper-casing endpoints.
    fn uppercase_endpoints(count: usize) -> Vec<BoxProcessor> {
        (0..count).map(|_| uppercase_processor()).collect()
    }

    /// Endpoint that always fails with `ProcessorError("boom")`.
    fn failing_processor() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("boom".into())) })
        })
    }

    /// Endpoint that increments `counter` on every invocation, then echoes the
    /// exchange back (`fail == false`) or fails with a `ProcessorError`.
    fn counting_processor(counter: &Arc<AtomicUsize>, fail: bool) -> BoxProcessor {
        let counter = Arc::clone(counter);
        BoxProcessor::from_fn(move |ex: Exchange| {
            let counter = Arc::clone(&counter);
            Box::pin(async move {
                counter.fetch_add(1, Ordering::SeqCst);
                if fail {
                    Err(CamelError::ProcessorError("fail on 1".into()))
                } else {
                    Ok(ex)
                }
            })
        })
    }

    /// Waits for `svc` to become ready, then multicasts `exchange` through it.
    async fn run(svc: &mut MulticastService, exchange: Exchange) -> Result<Exchange, CamelError> {
        svc.ready()
            .await
            .expect("multicast service should become ready")
            .call(exchange)
            .await
    }

    /// Default config: sequential, last endpoint's result wins.
    #[tokio::test]
    async fn test_multicast_sequential_last_wins() {
        let mut svc = MulticastService::new(uppercase_endpoints(3), MulticastConfig::new());

        let result = run(&mut svc, make_exchange("hello")).await.unwrap();

        assert_eq!(result.input.body.as_text(), Some("HELLO"));
    }

    /// CollectAll gathers every endpoint body into a JSON array.
    #[tokio::test]
    async fn test_multicast_sequential_collect_all() {
        let config = MulticastConfig::new().aggregation(MulticastStrategy::CollectAll);
        let mut svc = MulticastService::new(uppercase_endpoints(3), config);

        let result = run(&mut svc, make_exchange("hello")).await.unwrap();

        let expected = serde_json::json!(["HELLO", "HELLO", "HELLO"]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    /// Original strategy hands back the untouched input exchange.
    #[tokio::test]
    async fn test_multicast_sequential_original() {
        let config = MulticastConfig::new().aggregation(MulticastStrategy::Original);
        let mut svc = MulticastService::new(uppercase_endpoints(3), config);

        let result = run(&mut svc, make_exchange("hello")).await.unwrap();

        assert_eq!(result.input.body.as_text(), Some("hello"));
    }

    /// Custom strategy folds results left-to-right with the user function.
    #[tokio::test]
    async fn test_multicast_sequential_custom_aggregation() {
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|mut acc: Exchange, next: Exchange| {
                let acc_text = acc.input.body.as_text().unwrap_or("").to_string();
                let next_text = next.input.body.as_text().unwrap_or("").to_string();
                acc.input.body = Body::Text(format!("{acc_text}+{next_text}"));
                acc
            });

        let config = MulticastConfig::new().aggregation(MulticastStrategy::Custom(joiner));
        let mut svc = MulticastService::new(uppercase_endpoints(3), config);

        let result = run(&mut svc, make_exchange("a")).await.unwrap();

        assert_eq!(result.input.body.as_text(), Some("A+A+A"));
    }

    /// With stop_on_exception, a mid-list failure surfaces as an error.
    #[tokio::test]
    async fn test_multicast_stop_on_exception() {
        let endpoints = vec![uppercase_processor(), failing_processor(), uppercase_processor()];
        let config = MulticastConfig::new().stop_on_exception(true);
        let mut svc = MulticastService::new(endpoints, config);

        let result = run(&mut svc, make_exchange("hello")).await;

        assert!(result.is_err(), "expected error due to stop_on_exception");
    }

    /// Without stop_on_exception, LastWins ignores an earlier failure.
    #[tokio::test]
    async fn test_multicast_continue_on_exception() {
        let endpoints = vec![uppercase_processor(), failing_processor(), uppercase_processor()];
        let config = MulticastConfig::new()
            .stop_on_exception(false)
            .aggregation(MulticastStrategy::LastWins);
        let mut svc = MulticastService::new(endpoints, config);

        let result = run(&mut svc, make_exchange("hello")).await;

        assert!(result.is_ok(), "last endpoint should succeed");
        assert_eq!(result.unwrap().input.body.as_text(), Some("HELLO"));
    }

    /// stop_on_exception prevents endpoints after the failure from running.
    #[tokio::test]
    async fn test_multicast_stop_on_exception_halts_early() {
        let executed = Arc::new(AtomicUsize::new(0));
        let endpoints = vec![
            counting_processor(&executed, false),
            counting_processor(&executed, true),
            counting_processor(&executed, false),
        ];
        let config = MulticastConfig::new().stop_on_exception(true);
        let mut svc = MulticastService::new(endpoints, config);

        let result = run(&mut svc, make_exchange("x")).await;
        assert!(result.is_err(), "should fail at endpoint 1");

        assert_eq!(
            executed.load(Ordering::SeqCst),
            2,
            "endpoint 2 should not have executed due to stop_on_exception"
        );
    }

    /// Without stop_on_exception, every endpoint runs despite a failure.
    #[tokio::test]
    async fn test_multicast_continue_on_exception_executes_all() {
        let executed = Arc::new(AtomicUsize::new(0));
        let endpoints = vec![
            counting_processor(&executed, false),
            counting_processor(&executed, true),
            counting_processor(&executed, false),
        ];
        let config = MulticastConfig::new()
            .stop_on_exception(false)
            .aggregation(MulticastStrategy::LastWins);
        let mut svc = MulticastService::new(endpoints, config);

        let result = run(&mut svc, make_exchange("x")).await;
        assert!(result.is_ok(), "last endpoint should succeed");

        assert_eq!(
            executed.load(Ordering::SeqCst),
            3,
            "all endpoints should have executed despite error in endpoint 1"
        );
    }

    /// No endpoints: the input exchange comes back intact, properties included.
    #[tokio::test]
    async fn test_multicast_empty_endpoints() {
        let mut svc = MulticastService::new(Vec::new(), MulticastConfig::new());

        let mut ex = make_exchange("hello");
        ex.set_property("marker", Value::Bool(true));

        let result = run(&mut svc, ex).await.unwrap();
        assert_eq!(result.input.body.as_text(), Some("hello"));
        assert_eq!(result.property("marker"), Some(&Value::Bool(true)));
    }

    /// Each endpoint sees its index and whether it is the last one.
    #[tokio::test]
    async fn test_multicast_metadata_properties() {
        let recorder = BoxProcessor::from_fn(|ex: Exchange| {
            Box::pin(async move {
                let idx = ex.property(CAMEL_MULTICAST_INDEX).cloned();
                let complete = ex.property(CAMEL_MULTICAST_COMPLETE).cloned();
                let body = serde_json::json!({
                    "index": idx,
                    "complete": complete,
                });
                let mut out = ex;
                out.input.body = Body::Json(body);
                Ok(out)
            })
        });

        let endpoints = vec![recorder.clone(), recorder.clone(), recorder];
        let config = MulticastConfig::new().aggregation(MulticastStrategy::CollectAll);
        let mut svc = MulticastService::new(endpoints, config);

        let result = run(&mut svc, make_exchange("x")).await.unwrap();

        let expected = serde_json::json!([
            {"index": 0, "complete": false},
            {"index": 1, "complete": false},
            {"index": 2, "complete": true},
        ]);
        match &result.input.body {
            Body::Json(v) => assert_eq!(*v, expected),
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    /// The multicast is only ready once its endpoints are ready.
    #[tokio::test]
    async fn test_poll_ready_delegates_to_endpoints() {
        /// Endpoint whose readiness is controlled by a shared flag.
        #[derive(Clone)]
        struct DelayedReady {
            ready: Arc<AtomicBool>,
        }

        impl Service<Exchange> for DelayedReady {
            type Response = Exchange;
            type Error = CamelError;
            type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

            fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
                if self.ready.load(Ordering::SeqCst) {
                    Poll::Ready(Ok(()))
                } else {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            }

            fn call(&mut self, exchange: Exchange) -> Self::Future {
                Box::pin(async move { Ok(exchange) })
            }
        }

        let ready_flag = Arc::new(AtomicBool::new(false));
        let endpoint = BoxProcessor::new(DelayedReady {
            ready: Arc::clone(&ready_flag),
        });
        let mut svc = MulticastService::new(vec![endpoint], MulticastConfig::new());

        let waker = futures::task::noop_waker();
        let mut cx = Context::from_waker(&waker);

        assert!(
            svc.poll_ready(&mut cx).is_pending(),
            "expected Pending when endpoint not ready"
        );

        ready_flag.store(true, Ordering::SeqCst);

        assert!(
            matches!(svc.poll_ready(&mut cx), Poll::Ready(Ok(()))),
            "expected Ready after endpoint becomes ready"
        );
    }

    /// CollectAll propagates an endpoint error instead of collecting.
    #[tokio::test]
    async fn test_multicast_collect_all_error_propagates() {
        let endpoints = vec![uppercase_processor(), failing_processor(), uppercase_processor()];
        let config = MulticastConfig::new()
            .stop_on_exception(false)
            .aggregation(MulticastStrategy::CollectAll);
        let mut svc = MulticastService::new(endpoints, config);

        let result = run(&mut svc, make_exchange("hello")).await;

        assert!(result.is_err(), "CollectAll should propagate first error");
    }

    /// LastWins returns the error when the last endpoint fails.
    #[tokio::test]
    async fn test_multicast_last_wins_error_last() {
        let endpoints = vec![uppercase_processor(), uppercase_processor(), failing_processor()];
        let config = MulticastConfig::new()
            .stop_on_exception(false)
            .aggregation(MulticastStrategy::LastWins);
        let mut svc = MulticastService::new(endpoints, config);

        let result = run(&mut svc, make_exchange("hello")).await;

        assert!(result.is_err(), "LastWins should return last error");
    }

    /// Custom aggregation propagates an endpoint error.
    #[tokio::test]
    async fn test_multicast_custom_error_propagates() {
        let joiner: Arc<dyn Fn(Exchange, Exchange) -> Exchange + Send + Sync> =
            Arc::new(|acc: Exchange, _next: Exchange| acc);

        let endpoints = vec![uppercase_processor(), failing_processor(), uppercase_processor()];
        let config = MulticastConfig::new()
            .stop_on_exception(false)
            .aggregation(MulticastStrategy::Custom(joiner));
        let mut svc = MulticastService::new(endpoints, config);

        let result = run(&mut svc, make_exchange("hello")).await;

        assert!(result.is_err(), "Custom aggregation should propagate errors");
    }

    /// Parallel mode runs every endpoint and collects all results.
    #[tokio::test]
    async fn test_multicast_parallel_basic() {
        let config = MulticastConfig::new()
            .parallel(true)
            .aggregation(MulticastStrategy::CollectAll);
        let mut svc = MulticastService::new(uppercase_endpoints(2), config);

        let result = run(&mut svc, make_exchange("test")).await.unwrap();

        match &result.input.body {
            Body::Json(v) => {
                let arr = v.as_array().expect("expected array");
                assert_eq!(arr.len(), 2);
                assert!(arr.iter().all(|v| v.as_str() == Some("TEST")));
            }
            other => panic!("expected JSON body, got {other:?}"),
        }
    }

    /// parallel_limit caps the number of endpoints in flight at once.
    #[tokio::test]
    async fn test_multicast_parallel_with_limit() {
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        let endpoints: Vec<BoxProcessor> = (0..4)
            .map(|_| {
                let c = Arc::clone(&concurrent);
                let mc = Arc::clone(&max_concurrent);
                BoxProcessor::from_fn(move |ex: Exchange| {
                    let c = Arc::clone(&c);
                    let mc = Arc::clone(&mc);
                    Box::pin(async move {
                        let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                        mc.fetch_max(current, Ordering::SeqCst);
                        // Yield while still counted as in flight so sibling
                        // endpoints get a chance to start concurrently.
                        tokio::task::yield_now().await;
                        c.fetch_sub(1, Ordering::SeqCst);
                        Ok(ex)
                    })
                })
            })
            .collect();

        let config = MulticastConfig::new().parallel(true).parallel_limit(2);
        let mut svc = MulticastService::new(endpoints, config);

        let result = run(&mut svc, make_exchange("x")).await;
        assert!(result.is_ok(), "all endpoints succeed, so the multicast should too");

        let observed_max = max_concurrent.load(Ordering::SeqCst);
        assert!(
            (1..=2).contains(&observed_max),
            "max concurrency was {observed_max}, expected 1..=2"
        );
    }
}