1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::task::{Context, Poll};
6
7use tower::Service;
8use tower::ServiceExt;
9
10use camel_api::{
11 BoxProcessor, CamelError, Exchange, LoadBalanceStrategy, LoadBalancerConfig, Value,
12};
13
14use crate::multicast::{CAMEL_MULTICAST_COMPLETE, CAMEL_MULTICAST_INDEX};
15
16#[derive(Clone)]
17pub struct LoadBalancerService {
18 endpoints: Vec<BoxProcessor>,
19 config: LoadBalancerConfig,
20 round_robin_index: Arc<AtomicUsize>,
21 failover_index: Arc<AtomicUsize>,
22}
23
24impl LoadBalancerService {
25 pub fn new(endpoints: Vec<BoxProcessor>, config: LoadBalancerConfig) -> Self {
26 Self {
27 endpoints,
28 config,
29 round_robin_index: Arc::new(AtomicUsize::new(0)),
30 failover_index: Arc::new(AtomicUsize::new(0)),
31 }
32 }
33}
34
35impl Service<Exchange> for LoadBalancerService {
36 type Response = Exchange;
37 type Error = CamelError;
38 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
39
40 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
41 for endpoint in &mut self.endpoints {
42 match endpoint.poll_ready(cx) {
43 Poll::Pending => return Poll::Pending,
44 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
45 Poll::Ready(Ok(())) => {}
46 }
47 }
48 Poll::Ready(Ok(()))
49 }
50
51 fn call(&mut self, exchange: Exchange) -> Self::Future {
52 let endpoints = self.endpoints.clone();
53 let config = self.config.clone();
54 let round_robin_index = self.round_robin_index.clone();
55 let failover_index = self.failover_index.clone();
56
57 Box::pin(async move {
58 if endpoints.is_empty() {
59 return Ok(exchange);
60 }
61
62 if config.parallel {
63 process_parallel(exchange, endpoints).await
64 } else {
65 match &config.strategy {
66 LoadBalanceStrategy::RoundRobin => {
67 process_round_robin(exchange, endpoints, round_robin_index).await
68 }
69 LoadBalanceStrategy::Random => process_random(exchange, endpoints).await,
70 LoadBalanceStrategy::Weighted(weights) => {
71 process_weighted(exchange, endpoints, weights).await
72 }
73 LoadBalanceStrategy::Failover => {
74 process_failover(exchange, endpoints, failover_index).await
75 }
76 }
77 }
78 })
79 }
80}
81
82async fn process_round_robin(
83 exchange: Exchange,
84 endpoints: Vec<BoxProcessor>,
85 index: Arc<AtomicUsize>,
86) -> Result<Exchange, CamelError> {
87 let len = endpoints.len();
88 let idx = index.fetch_add(1, Ordering::SeqCst) % len;
89 let mut endpoint = endpoints[idx].clone();
90 endpoint.ready().await?.call(exchange).await
91}
92
93async fn process_random(
94 exchange: Exchange,
95 endpoints: Vec<BoxProcessor>,
96) -> Result<Exchange, CamelError> {
97 let len = endpoints.len();
98 let idx = rand::random::<usize>() % len;
99 let mut endpoint = endpoints[idx].clone();
100 endpoint.ready().await?.call(exchange).await
101}
102
103async fn process_weighted(
104 exchange: Exchange,
105 endpoints: Vec<BoxProcessor>,
106 weights: &[(String, u32)],
107) -> Result<Exchange, CamelError> {
108 if endpoints.is_empty() || weights.is_empty() {
109 return Ok(exchange);
110 }
111
112 let numeric_weights: Vec<u32> = weights.iter().map(|(_, w)| *w).collect();
113 let total: u32 = numeric_weights.iter().sum();
114
115 if total == 0 {
116 return Err(CamelError::ProcessorError(
117 "Weighted load balancer has zero total weight".to_string(),
118 ));
119 }
120
121 let mut r = rand::random::<u32>() % total;
122 let mut selected_idx = 0;
123 for (i, w) in numeric_weights.iter().enumerate() {
124 if r < *w {
125 selected_idx = i.min(endpoints.len() - 1);
126 break;
127 }
128 r -= w;
129 }
130
131 let mut endpoint = endpoints[selected_idx].clone();
132 endpoint.ready().await?.call(exchange).await
133}
134
135async fn process_failover(
136 exchange: Exchange,
137 endpoints: Vec<BoxProcessor>,
138 start_index: Arc<AtomicUsize>,
139) -> Result<Exchange, CamelError> {
140 let len = endpoints.len();
141 let start = start_index.load(Ordering::SeqCst);
142 let mut last_error = None;
143
144 for i in 0..len {
145 let idx = (start + i) % len;
146 let mut endpoint = endpoints[idx].clone();
147 match endpoint.ready().await?.call(exchange.clone()).await {
148 Ok(ex) => {
149 start_index.store((idx + 1) % len, Ordering::SeqCst);
150 return Ok(ex);
151 }
152 Err(e) => {
153 last_error = Some(e);
154 }
155 }
156 }
157
158 Err(last_error.unwrap_or_else(|| {
159 CamelError::ProcessorError("All endpoints failed in failover".to_string())
160 }))
161}
162
163async fn process_parallel(
164 exchange: Exchange,
165 endpoints: Vec<BoxProcessor>,
166) -> Result<Exchange, CamelError> {
167 use futures::future::join_all;
168
169 let total = endpoints.len();
170 let futures: Vec<_> = endpoints
171 .into_iter()
172 .enumerate()
173 .map(|(i, mut endpoint)| {
174 let mut ex = exchange.clone();
175 ex.set_property(CAMEL_MULTICAST_INDEX, Value::from(i as i64));
176 ex.set_property(CAMEL_MULTICAST_COMPLETE, Value::Bool(i == total - 1));
177 async move {
178 tower::ServiceExt::ready(&mut endpoint).await?;
179 endpoint.call(ex).await
180 }
181 })
182 .collect();
183
184 let results: Vec<Result<Exchange, CamelError>> = join_all(futures).await;
185
186 for result in &results {
187 if let Err(e) = result {
188 return Err(e.clone());
189 }
190 }
191
192 results.into_iter().last().unwrap_or(Ok(exchange))
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::Mutex;
    use tower::ServiceExt;

    /// Builds a pass-through processor together with a counter of its invocations.
    fn counting_processor() -> (BoxProcessor, Arc<AtomicUsize>) {
        let hits = Arc::new(AtomicUsize::new(0));
        let processor = {
            let hits = Arc::clone(&hits);
            BoxProcessor::from_fn(move |ex| {
                hits.fetch_add(1, Ordering::SeqCst);
                Box::pin(async move { Ok(ex) })
            })
        };
        (processor, hits)
    }

    /// Builds a processor that rejects every exchange with a `ProcessorError`.
    fn failing_processor() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("fail".into())) })
        })
    }

    /// Creates an exchange carrying `body`, waits for `svc` to be ready and calls it.
    async fn dispatch(
        svc: &mut LoadBalancerService,
        body: &'static str,
    ) -> Result<Exchange, CamelError> {
        let exchange = Exchange::new(Message::new(body));
        svc.ready().await.unwrap().call(exchange).await
    }

    /// Six requests over three endpoints must land twice on each.
    #[tokio::test]
    async fn test_round_robin_distribution() {
        let (p1, c1) = counting_processor();
        let (p2, c2) = counting_processor();
        let (p3, c3) = counting_processor();

        let mut svc =
            LoadBalancerService::new(vec![p1, p2, p3], LoadBalancerConfig::round_robin());

        for _ in 0..6 {
            dispatch(&mut svc, "test").await.unwrap();
        }

        for counter in [&c1, &c2, &c3] {
            assert_eq!(counter.load(Ordering::SeqCst), 2);
        }
    }

    /// A hundred random picks over two endpoints reach both a fair number of times.
    #[tokio::test]
    async fn test_random_distribution() {
        let (p1, c1) = counting_processor();
        let (p2, c2) = counting_processor();

        let mut svc = LoadBalancerService::new(vec![p1, p2], LoadBalancerConfig::random());

        for _ in 0..100 {
            dispatch(&mut svc, "test").await.unwrap();
        }

        let first = c1.load(Ordering::SeqCst);
        let second = c2.load(Ordering::SeqCst);
        assert_eq!(first + second, 100);
        assert!(first > 20);
        assert!(second > 20);
    }

    /// When the first endpoint fails, the second one handles the exchange.
    #[tokio::test]
    async fn test_failover_on_error() {
        let (success, count) = counting_processor();

        let mut svc = LoadBalancerService::new(
            vec![failing_processor(), success],
            LoadBalancerConfig::failover(),
        );

        let _result = dispatch(&mut svc, "test").await.unwrap();

        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    /// The endpoint tried after a failure still sees the original body.
    #[tokio::test]
    async fn test_failover_preserves_original_exchange() {
        let seen_body: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));

        let retry = {
            let recorder = Arc::clone(&seen_body);
            BoxProcessor::from_fn(move |ex: Exchange| {
                let seen = recorder.clone();
                Box::pin(async move {
                    if let Some(text) = ex.input.body.as_text() {
                        *seen.lock().unwrap() = Some(text.to_string());
                    }
                    Ok(ex)
                })
            })
        };

        let mut svc = LoadBalancerService::new(
            vec![failing_processor(), retry],
            LoadBalancerConfig::failover(),
        );

        dispatch(&mut svc, "original body").await.unwrap();

        assert_eq!(
            seen_body.lock().unwrap().as_deref(),
            Some("original body"),
            "retry endpoint must receive the original exchange body, not a blank one"
        );
    }

    /// With only failing endpoints the overall call reports an error.
    #[tokio::test]
    async fn test_failover_all_fail() {
        let failing = failing_processor();

        let mut svc = LoadBalancerService::new(
            vec![failing.clone(), failing],
            LoadBalancerConfig::failover(),
        );

        let result = dispatch(&mut svc, "test").await;

        assert!(result.is_err());
    }

    /// Parallel mode delivers the exchange to every endpoint exactly once.
    #[tokio::test]
    async fn test_parallel_sends_to_all() {
        let (p1, c1) = counting_processor();
        let (p2, c2) = counting_processor();
        let (p3, c3) = counting_processor();

        let config = LoadBalancerConfig::round_robin().parallel(true);
        let mut svc = LoadBalancerService::new(vec![p1, p2, p3], config);

        dispatch(&mut svc, "test").await.unwrap();

        for counter in [&c1, &c2, &c3] {
            assert_eq!(counter.load(Ordering::SeqCst), 1);
        }
    }

    /// A balancer without endpoints passes the exchange through successfully.
    #[tokio::test]
    async fn test_empty_endpoints() {
        let mut svc = LoadBalancerService::new(vec![], LoadBalancerConfig::round_robin());

        let result = dispatch(&mut svc, "test").await;

        assert!(result.is_ok());
    }
}