1use std::future::Future;
7use std::pin::Pin;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
10use std::time::Duration;
11use tokio::task::JoinSet;
12
13use camel_api::{Exchange, Value};
14
15use crate::multicast::{CAMEL_MULTICAST_COMPLETE, CAMEL_MULTICAST_INDEX};
16
17#[derive(Clone)]
31pub struct MulticastSegment {
32 pub branches: Vec<camel_api::OutcomeSegment>,
33 pub parallel: bool,
34 pub parallel_limit: Option<usize>,
36 pub stop_on_exception: bool,
46 pub timeout: Option<Duration>,
48 pub aggregator: Arc<dyn Fn(Vec<Exchange>) -> Exchange + Send + Sync>,
49}
50
51impl camel_api::OutcomePipeline for MulticastSegment {
52 fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
53 Box::new(self.clone())
54 }
55
56 fn run<'a>(
57 &'a mut self,
58 exchange: Exchange,
59 ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
60 Box::pin(async move {
61 if self.parallel {
62 parallel_multicast(self, exchange).await
63 } else {
64 sequential_multicast(self, exchange).await
65 }
66 })
67 }
68}
69
70async fn sequential_multicast(
73 seg: &mut MulticastSegment,
74 exchange: Exchange,
75) -> camel_api::PipelineOutcome {
76 let mut outputs = Vec::new();
77 let mut last_error: Option<camel_api::CamelError> = None;
78 let total = seg.branches.len();
79 for (i, branch) in seg.branches.iter_mut().enumerate() {
80 let mut ex = exchange.clone();
82 ex.set_property(CAMEL_MULTICAST_INDEX, Value::from(i as i64));
83 ex.set_property(CAMEL_MULTICAST_COMPLETE, Value::Bool(i == total - 1));
84 match branch.run(ex).await {
85 camel_api::PipelineOutcome::Completed(ex) => outputs.push(ex),
86 camel_api::PipelineOutcome::Stopped(ex) => {
87 return camel_api::PipelineOutcome::Stopped(ex);
88 }
89 camel_api::PipelineOutcome::Failed(err) => {
90 if seg.stop_on_exception {
91 return camel_api::PipelineOutcome::Failed(err);
92 }
93 last_error = Some(err);
95 }
96 }
97 }
98 if let Some(err) = last_error {
99 return camel_api::PipelineOutcome::Failed(err);
100 }
101 camel_api::PipelineOutcome::Completed((seg.aggregator)(outputs))
102}
103
104async fn parallel_multicast(
111 seg: &mut MulticastSegment,
112 exchange: Exchange,
113) -> camel_api::PipelineOutcome {
114 use std::sync::Arc;
115 use tokio::sync::Semaphore;
116
117 let stopped_seen = Arc::new(AtomicBool::new(false));
118 let stopped_idx = Arc::new(AtomicUsize::new(usize::MAX));
119 let semaphore = seg
120 .parallel_limit
121 .filter(|&limit| limit > 0)
122 .map(|limit| Arc::new(Semaphore::new(limit)));
123 let timeout = seg.timeout;
124 let stop_on_exception = seg.stop_on_exception;
125 let total = seg.branches.len();
126
127 let mut set: JoinSet<(usize, Option<camel_api::PipelineOutcome>)> = JoinSet::new();
128
129 for (idx, mut branch) in seg.branches.clone().into_iter().enumerate() {
130 let stopped_seen = Arc::clone(&stopped_seen);
131 let stopped_idx = Arc::clone(&stopped_idx);
132 let sem = semaphore.clone();
133 let mut ex = exchange.clone();
135 ex.set_property(CAMEL_MULTICAST_INDEX, Value::from(idx as i64));
136 ex.set_property(CAMEL_MULTICAST_COMPLETE, Value::Bool(idx == total - 1));
137 set.spawn(async move {
138 if stopped_seen.load(Ordering::SeqCst) {
140 return (idx, None);
141 }
142 let _permit: Option<tokio::sync::OwnedSemaphorePermit> = match &sem {
144 Some(s) => match Arc::clone(s).acquire_owned().await {
145 Ok(p) => Some(p),
146 Err(_) => {
147 return (
148 idx,
149 Some(camel_api::PipelineOutcome::Failed(
150 camel_api::CamelError::ProcessorError("semaphore closed".into()),
151 )),
152 );
153 }
154 },
155 None => None,
156 };
157 if stopped_seen.load(Ordering::SeqCst) {
159 return (idx, None);
160 }
161
162 let outcome = async {
164 let outcome = branch.run(ex).await;
165 if let camel_api::PipelineOutcome::Stopped(_) = &outcome {
166 loop {
168 let cur = stopped_idx.load(Ordering::SeqCst);
169 if idx >= cur {
170 break;
171 }
172 match stopped_idx.compare_exchange_weak(
173 cur,
174 idx,
175 Ordering::SeqCst,
176 Ordering::SeqCst,
177 ) {
178 Ok(_) => break,
179 Err(actual) => {
180 if actual <= idx {
181 break;
182 }
183 }
184 }
185 }
186 stopped_seen.store(true, Ordering::SeqCst);
187 }
188 outcome
189 };
190
191 let outcome = if let Some(dur) = timeout {
192 match tokio::time::timeout(dur, outcome).await {
193 Ok(o) => o,
194 Err(_elapsed) => {
195 camel_api::PipelineOutcome::Failed(camel_api::CamelError::ProcessorError(
196 format!("multicast branch {idx} timed out after {dur:?}"),
197 ))
198 }
199 }
200 } else {
201 outcome.await
202 };
203
204 (idx, Some(outcome))
205 });
206 }
207
208 let mut results: Vec<(usize, camel_api::PipelineOutcome)> = Vec::new();
210 while let Some(res) = set.join_next().await {
211 if let Ok((idx, Some(o))) = res {
212 results.push((idx, o));
213 }
214 }
215
216 if stopped_seen.load(Ordering::SeqCst) {
218 let winning_idx = stopped_idx.load(Ordering::SeqCst);
219 if winning_idx == usize::MAX {
220 tracing::warn!(
221 target: "camel.phase4.multicast",
222 "stopped_seen=true but stopped_idx=usize::MAX — race; falling back to pre-multicast exchange"
223 );
224 return camel_api::PipelineOutcome::Stopped(exchange);
225 }
226 let stopped_ex = results
227 .iter()
228 .find(|(idx, _)| *idx == winning_idx)
229 .and_then(|(_, o)| match o {
230 camel_api::PipelineOutcome::Stopped(ex) => Some(ex.clone()),
231 _ => None,
232 });
233 if let Some(ex) = stopped_ex {
234 return camel_api::PipelineOutcome::Stopped(ex);
235 }
236 tracing::warn!(
237 target: "camel.phase4.multicast",
238 winning_idx = winning_idx,
239 "winning_idx not found — falling back to pre-multicast exchange"
240 );
241 return camel_api::PipelineOutcome::Stopped(exchange);
242 }
243
244 results.sort_by_key(|(idx, _)| *idx);
248 if stop_on_exception {
249 let mut first_failed: Option<(usize, camel_api::CamelError)> = None;
250 for (idx, o) in &results {
251 if let camel_api::PipelineOutcome::Failed(err) = o
252 && first_failed
253 .as_ref()
254 .map(|(i, _)| *i > *idx)
255 .unwrap_or(true)
256 {
257 first_failed = Some((*idx, err.clone()));
258 }
259 }
260 if let Some((_, err)) = first_failed {
261 return camel_api::PipelineOutcome::Failed(err);
262 }
263 } else {
264 let mut last_error: Option<camel_api::CamelError> = None;
266 for (_, o) in &results {
267 if let camel_api::PipelineOutcome::Failed(err) = o {
268 last_error = Some(err.clone());
269 }
270 }
271 if let Some(err) = last_error {
272 return camel_api::PipelineOutcome::Failed(err);
273 }
274 }
275
276 let completed: Vec<Exchange> = results
278 .into_iter()
279 .filter_map(|(_, o)| match o {
280 camel_api::PipelineOutcome::Completed(ex) => Some(ex),
281 _ => None,
282 })
283 .collect();
284 camel_api::PipelineOutcome::Completed((seg.aggregator)(completed))
285}
286
#[cfg(test)]
mod tests {
    // Behavioural tests for `MulticastSegment` in sequential and parallel modes.
    use super::*;
    use camel_api::{Message, OutcomePipeline, OutcomeSegment, PipelineOutcome};
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Branch that increments `counter` on each invocation and completes with its input.
    fn counting_passing_body(counter: Arc<AtomicUsize>) -> OutcomeSegment {
        // `usize::MAX` is used as a "never fail" sentinel; the counter presumably never gets there.
        counting_body(counter, usize::MAX)
    }

    /// Branch that increments the shared `counter` each time `run` is called. It fails
    /// with `ProcessorError("fail at {n}")` when the pre-increment value equals `fail_at`,
    /// and otherwise completes with the input exchange.
    ///
    /// Because several branches can share one counter, `fail_at` indexes the overall
    /// invocation order across those branches, not this branch's own call count.
    fn counting_body(counter: Arc<AtomicUsize>, fail_at: usize) -> OutcomeSegment {
        #[derive(Clone)]
        struct CountBody {
            counter: Arc<AtomicUsize>,
            fail_at: usize,
        }
        impl camel_api::OutcomePipeline for CountBody {
            fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
                Box::new(self.clone())
            }
            fn run<'a>(
                &'a mut self,
                exchange: Exchange,
            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = PipelineOutcome> + Send + 'a>>
            {
                // Counted when `run` is called, before the returned future is polled.
                let count = self.counter.fetch_add(1, Ordering::SeqCst);
                let fail_at = self.fail_at;
                Box::pin(async move {
                    if count == fail_at {
                        PipelineOutcome::Failed(camel_api::CamelError::ProcessorError(format!(
                            "fail at {count}"
                        )))
                    } else {
                        PipelineOutcome::Completed(exchange)
                    }
                })
            }
        }
        OutcomeSegment::new(Box::new(CountBody { counter, fail_at }))
    }

    // Sequential mode with stop_on_exception=true: the second invocation fails, so the
    // multicast must return Failed without invoking the third branch (2 invocations).
    #[tokio::test]
    async fn multicast_sequential_stop_on_exception_true() {
        let invocations = Arc::new(AtomicUsize::new(0));
        let mut seg = MulticastSegment {
            branches: vec![
                counting_passing_body(Arc::clone(&invocations)),
                // Fails on the 2nd overall invocation of the shared counter.
                counting_body(Arc::clone(&invocations), 1),
                counting_passing_body(Arc::clone(&invocations)),
            ],
            parallel: false,
            parallel_limit: None,
            stop_on_exception: true,
            timeout: None,
            aggregator: Arc::new(|exchanges: Vec<Exchange>| {
                exchanges.into_iter().last().unwrap_or_default()
            }),
        };

        let ex = Exchange::new(Message::new("test"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        assert!(
            matches!(result, PipelineOutcome::Failed(_)),
            "stop_on_exception=true should propagate failure"
        );
        assert_eq!(invocations.load(Ordering::SeqCst), 2);
    }

    // Sequential mode with stop_on_exception=false: all three branches still run
    // (3 invocations), and the failure is reported after the last branch.
    #[tokio::test]
    async fn multicast_sequential_stop_on_exception_false() {
        let invocations = Arc::new(AtomicUsize::new(0));
        let mut seg = MulticastSegment {
            branches: vec![
                counting_passing_body(Arc::clone(&invocations)),
                // Fails on the 2nd overall invocation of the shared counter.
                counting_body(Arc::clone(&invocations), 1),
                counting_passing_body(Arc::clone(&invocations)),
            ],
            parallel: false,
            parallel_limit: None,
            stop_on_exception: false,
            timeout: None,
            aggregator: Arc::new(|exchanges: Vec<Exchange>| {
                exchanges.into_iter().last().unwrap_or_default()
            }),
        };

        let ex = Exchange::new(Message::new("test"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        assert!(
            // Deferred error: surfaced once every branch has run.
            matches!(result, PipelineOutcome::Failed(_)),
            "should propagate error at end"
        );
        assert_eq!(invocations.load(Ordering::SeqCst), 3);
    }

    // Parallel mode with parallel_limit=2: six branches track a live "in-flight" count,
    // and the observed maximum must never exceed the limit.
    #[tokio::test(flavor = "multi_thread")]
    async fn multicast_parallel_limit_enforcement() {
        let concurrent = Arc::new(AtomicUsize::new(0));
        let max_concurrent = Arc::new(AtomicUsize::new(0));

        #[derive(Clone)]
        struct LimitedBody {
            concurrent: Arc<AtomicUsize>,
            max_concurrent: Arc<AtomicUsize>,
        }
        impl camel_api::OutcomePipeline for LimitedBody {
            fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
                Box::new(self.clone())
            }
            fn run<'a>(
                &'a mut self,
                exchange: Exchange,
            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = PipelineOutcome> + Send + 'a>>
            {
                let c = Arc::clone(&self.concurrent);
                let mc = Arc::clone(&self.max_concurrent);
                Box::pin(async move {
                    // Enter: bump the in-flight count and record the high-water mark.
                    let current = c.fetch_add(1, Ordering::SeqCst) + 1;
                    mc.fetch_max(current, Ordering::SeqCst);
                    // Yield so other branches get a chance to overlap with this one.
                    tokio::task::yield_now().await;
                    // Leave.
                    c.fetch_sub(1, Ordering::SeqCst);
                    PipelineOutcome::Completed(exchange)
                })
            }
        }

        let target: Arc<dyn Fn(Vec<Exchange>) -> Exchange + Send + Sync> =
            Arc::new(|exchanges: Vec<Exchange>| exchanges.into_iter().last().unwrap_or_default());

        let mut seg = MulticastSegment {
            branches: (0..6)
                .map(|_| {
                    OutcomeSegment::new(Box::new(LimitedBody {
                        concurrent: Arc::clone(&concurrent),
                        max_concurrent: Arc::clone(&max_concurrent),
                    }))
                })
                .collect(),
            parallel: true,
            parallel_limit: Some(2),
            stop_on_exception: true,
            timeout: None,
            aggregator: target,
        };

        let ex = Exchange::new(Message::new("test"));
        let result = OutcomePipeline::run(&mut seg, ex).await;
        assert!(
            matches!(result, PipelineOutcome::Completed(_)),
            "Expected Completed, got {result:?}"
        );

        assert!(
            max_concurrent.load(Ordering::SeqCst) <= 2,
            "parallel_limit=2 but observed max concurrency {}",
            max_concurrent.load(Ordering::SeqCst)
        );
    }

    // Parallel mode with a 50ms per-branch timeout: a branch that sleeps 200ms must turn
    // the whole multicast into Failed (stop_on_exception=true).
    #[tokio::test(flavor = "multi_thread")]
    async fn multicast_timeout_exceeded() {
        #[derive(Clone)]
        // Sleeps well past the configured timeout before completing.
        struct SlowBody;
        impl camel_api::OutcomePipeline for SlowBody {
            fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
                Box::new(self.clone())
            }
            fn run<'a>(
                &'a mut self,
                exchange: Exchange,
            ) -> std::pin::Pin<Box<dyn std::future::Future<Output = PipelineOutcome> + Send + 'a>>
            {
                Box::pin(async move {
                    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
                    PipelineOutcome::Completed(exchange)
                })
            }
        }

        let target: Arc<dyn Fn(Vec<Exchange>) -> Exchange + Send + Sync> =
            Arc::new(|exchanges: Vec<Exchange>| exchanges.into_iter().last().unwrap_or_default());

        let mut seg = MulticastSegment {
            branches: vec![
                OutcomeSegment::new(Box::new(SlowBody)),
                counting_passing_body(Arc::new(AtomicUsize::new(0))),
            ],
            parallel: true,
            parallel_limit: None,
            stop_on_exception: true,
            timeout: Some(std::time::Duration::from_millis(50)),
            aggregator: target,
        };

        let ex = Exchange::new(Message::new("test"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        assert!(
            // The timed-out branch is reported as a failure.
            matches!(result, PipelineOutcome::Failed(_)),
            "Expected Failed due to timeout, got {result:?}"
        );
    }

    // Parallel mode with stop_on_exception=false: branches 0 and 2 fail, and the error
    // from the highest-index failing branch ("err2") must be the one propagated.
    #[tokio::test(flavor = "multi_thread")]
    async fn multicast_parallel_stop_on_exception_false_propagates_last_error() {
        fn always_pass_body() -> OutcomeSegment {
            // Completes immediately with its input exchange.
            #[derive(Clone)]
            struct PassBody;
            impl camel_api::OutcomePipeline for PassBody {
                fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
                    Box::new(PassBody)
                }
                fn run<'a>(
                    &'a mut self,
                    exchange: Exchange,
                ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                    Box::pin(async move { PipelineOutcome::Completed(exchange) })
                }
            }
            OutcomeSegment::new(Box::new(PassBody))
        }
        // Always fails with `ProcessorError(msg)`, ignoring its input.
        fn always_fail_body(msg: &'static str) -> OutcomeSegment {
            #[derive(Clone)]
            struct FailBody {
                msg: &'static str,
            }
            impl camel_api::OutcomePipeline for FailBody {
                fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
                    Box::new(self.clone())
                }
                fn run<'a>(
                    &'a mut self,
                    _exchange: Exchange,
                ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                    let msg = self.msg;
                    Box::pin(async move {
                        PipelineOutcome::Failed(camel_api::CamelError::ProcessorError(
                            msg.to_string(),
                        ))
                    })
                }
            }
            OutcomeSegment::new(Box::new(FailBody { msg }))
        }

        let target: Arc<dyn Fn(Vec<Exchange>) -> Exchange + Send + Sync> =
            Arc::new(|exchanges: Vec<Exchange>| exchanges.into_iter().last().unwrap_or_default());

        let mut seg = MulticastSegment {
            branches: vec![
                always_fail_body("err1"), // index 0
                always_pass_body(),       // index 1
                always_fail_body("err2"), // index 2: expected to win under last-wins
            ],
            parallel: true,
            parallel_limit: None,
            stop_on_exception: false,
            timeout: None,
            aggregator: target,
        };

        let ex = Exchange::new(Message::new("test"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        match result {
            // Completion order is nondeterministic; the check relies on index ordering.
            PipelineOutcome::Failed(err) => {
                let msg = format!("{err}");
                assert!(
                    msg.contains("err2"),
                    "Expected last error 'err2' (from highest-index branch), got: {msg}"
                );
            }
            other => panic!("Expected Failed(err2) with last-wins semantics, got {other:?}"),
        }
    }

    // Parallel mode with stop_on_exception=false and a 50ms timeout: the slow branch's
    // timeout must still be reported as Failed even though the other branch completes.
    #[tokio::test(flavor = "multi_thread")]
    async fn multicast_parallel_timeout_stop_on_exception_false_propagates_timeout_error() {
        #[derive(Clone)]
        struct SlowBody;
        impl camel_api::OutcomePipeline for SlowBody {
            fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
                Box::new(SlowBody)
            }
            fn run<'a>(
                &'a mut self,
                exchange: Exchange,
            ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                Box::pin(async move {
                    tokio::time::sleep(std::time::Duration::from_millis(200)).await;
                    PipelineOutcome::Completed(exchange)
                })
            }
        }
        #[derive(Clone)]
        struct FastPassBody;
        impl camel_api::OutcomePipeline for FastPassBody {
            fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
                Box::new(FastPassBody)
            }
            fn run<'a>(
                &'a mut self,
                exchange: Exchange,
            ) -> Pin<Box<dyn Future<Output = PipelineOutcome> + Send + 'a>> {
                Box::pin(async move { PipelineOutcome::Completed(exchange) })
            }
        }

        let target: Arc<dyn Fn(Vec<Exchange>) -> Exchange + Send + Sync> =
            Arc::new(|exchanges: Vec<Exchange>| exchanges.into_iter().last().unwrap_or_default());

        let mut seg = MulticastSegment {
            branches: vec![
                OutcomeSegment::new(Box::new(SlowBody)),     // times out
                OutcomeSegment::new(Box::new(FastPassBody)), // completes
            ],
            parallel: true,
            parallel_limit: None,
            stop_on_exception: false,
            timeout: Some(std::time::Duration::from_millis(50)),
            aggregator: target,
        };

        let ex = Exchange::new(Message::new("test"));
        let result = OutcomePipeline::run(&mut seg, ex).await;

        assert!(
            // A timeout is an error even when stop_on_exception is disabled.
            matches!(result, PipelineOutcome::Failed(_)),
            "Expected Failed due to timeout with stop_on_exception=false, got {result:?}"
        );
    }
}