// camel_processor/multicast.rs
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use tower::Service;

use camel_api::{
    Body, BoxProcessor, CamelError, Exchange, MulticastConfig, MulticastStrategy, Value,
};
14
/// Exchange property key under which each multicast copy records the
/// zero-based position of the endpoint it is dispatched to.
pub const CAMEL_MULTICAST_INDEX: &str = "CamelMulticastIndex";

/// Exchange property key holding a boolean that is `true` only on the copy
/// dispatched to the last endpoint of the multicast.
pub const CAMEL_MULTICAST_COMPLETE: &str = "CamelMulticastComplete";
21
22#[derive(Clone)]
34pub struct MulticastService {
35 endpoints: Vec<BoxProcessor>,
36 config: MulticastConfig,
37}
38
39impl MulticastService {
40 pub fn new(endpoints: Vec<BoxProcessor>, config: MulticastConfig) -> Result<Self, CamelError> {
42 config.validate()?;
43 Ok(Self { endpoints, config })
44 }
45}
46
47impl Service<Exchange> for MulticastService {
48 type Response = Exchange;
49 type Error = CamelError;
50 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
51
52 fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
53 Poll::Ready(Ok(()))
58 }
59
60 fn call(&mut self, exchange: Exchange) -> Self::Future {
61 let original = exchange.clone();
62 let endpoints = self.endpoints.clone();
63 let config = self.config.clone();
64
65 Box::pin(async move {
66 if endpoints.is_empty() {
68 return Ok(original);
69 }
70
71 let total = endpoints.len();
72
73 let results = if config.parallel {
74 process_parallel(exchange, endpoints, config.parallel_limit, total).await
76 } else {
77 process_sequential(exchange, endpoints, config.stop_on_exception, total).await
79 };
80
81 aggregate(results, original, config.aggregation)
83 })
84 }
85}
86
87async fn process_sequential(
90 exchange: Exchange,
91 endpoints: Vec<BoxProcessor>,
92 stop_on_exception: bool,
93 total: usize,
94) -> Vec<Result<Exchange, CamelError>> {
95 let mut results = Vec::with_capacity(endpoints.len());
96
97 for (i, endpoint) in endpoints.into_iter().enumerate() {
98 let mut cloned_exchange = exchange.clone();
100
101 cloned_exchange.set_property(CAMEL_MULTICAST_INDEX, Value::from(i as i64));
103 cloned_exchange.set_property(CAMEL_MULTICAST_COMPLETE, Value::Bool(i == total - 1));
104
105 let mut endpoint = endpoint;
106 match tower::ServiceExt::ready(&mut endpoint).await {
107 Err(e) => {
108 results.push(Err(e));
109 if stop_on_exception {
110 break;
111 }
112 }
113 Ok(svc) => {
114 let result = svc.call(cloned_exchange).await;
115 let is_err = result.is_err();
116 results.push(result);
117 if stop_on_exception && is_err {
118 break;
119 }
120 }
121 }
122 }
123
124 results
125}
126
127async fn process_parallel(
130 exchange: Exchange,
131 endpoints: Vec<BoxProcessor>,
132 parallel_limit: Option<usize>,
133 total: usize,
134) -> Vec<Result<Exchange, CamelError>> {
135 use std::sync::Arc;
136 use tokio::sync::Semaphore;
137
138 let semaphore = parallel_limit.map(|limit| Arc::new(Semaphore::new(limit)));
139
140 let futures: Vec<_> = endpoints
142 .into_iter()
143 .enumerate()
144 .map(|(i, mut endpoint)| {
145 let mut ex = exchange.clone();
146 ex.set_property(CAMEL_MULTICAST_INDEX, Value::from(i as i64));
147 ex.set_property(CAMEL_MULTICAST_COMPLETE, Value::Bool(i == total - 1));
148 let sem = semaphore.clone();
149 async move {
150 let _permit = match &sem {
152 Some(s) => match s.acquire().await {
153 Ok(p) => Some(p),
154 Err(_) => {
155 return Err(CamelError::ProcessorError("semaphore closed".to_string()));
156 }
157 },
158 None => None,
159 };
160
161 tower::ServiceExt::ready(&mut endpoint).await?;
164 endpoint.call(ex).await
165 }
166 })
167 .collect();
168
169 futures::future::join_all(futures).await
171}
172
173fn aggregate(
176 results: Vec<Result<Exchange, CamelError>>,
177 original: Exchange,
178 strategy: MulticastStrategy,
179) -> Result<Exchange, CamelError> {
180 match strategy {
181 MulticastStrategy::LastWins => {
182 results.into_iter().last().unwrap_or_else(|| Ok(original))
185 }
186 MulticastStrategy::CollectAll => {
187 let mut bodies = Vec::new();
189 for result in results {
190 let ex = result?;
191 let value = match &ex.input.body {
192 Body::Text(s) => Value::String(s.clone()),
193 Body::Json(v) => v.clone(),
194 Body::Xml(s) => Value::String(s.clone()),
195 Body::Bytes(b) => Value::String(String::from_utf8_lossy(b).into_owned()),
196 Body::Empty => Value::Null,
197 Body::Stream(s) => serde_json::json!({
198 "_stream": {
199 "origin": s.metadata.origin,
200 "placeholder": true,
201 "hint": "Materialize exchange body with .into_bytes() before multicast aggregation"
202 }
203 }),
204 };
205 bodies.push(value);
206 }
207 let mut out = original;
208 out.input.body = Body::Json(Value::Array(bodies));
209 Ok(out)
210 }
211 MulticastStrategy::Original => Ok(original),
212 MulticastStrategy::Custom(fold_fn) => {
213 let mut iter = results.into_iter();
215 let first = iter.next().unwrap_or_else(|| Ok(original.clone()))?;
216 iter.try_fold(first, |acc, next_result| {
217 let next = next_result?;
218 Ok(fold_fn(acc, next))
219 })
220 }
221 }
222}
223
// Unit tests are kept in a sibling file and compiled only for test builds.
#[cfg(test)]
#[path = "multicast_tests.rs"]
mod tests;