1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::task::{Context, Poll};
6
7use tower::Service;
8use tower::ServiceExt;
9
10use camel_api::{BoxProcessor, CamelError, Exchange, LoadBalanceStrategy, LoadBalancerConfig};
11
12#[derive(Clone)]
13pub struct LoadBalancerService {
14 endpoints: Vec<BoxProcessor>,
15 config: LoadBalancerConfig,
16 round_robin_index: Arc<AtomicUsize>,
17 failover_index: Arc<AtomicUsize>,
18}
19
20impl LoadBalancerService {
21 pub fn new(endpoints: Vec<BoxProcessor>, config: LoadBalancerConfig) -> Self {
22 Self {
23 endpoints,
24 config,
25 round_robin_index: Arc::new(AtomicUsize::new(0)),
26 failover_index: Arc::new(AtomicUsize::new(0)),
27 }
28 }
29}
30
31impl Service<Exchange> for LoadBalancerService {
32 type Response = Exchange;
33 type Error = CamelError;
34 type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
35
36 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
37 for endpoint in &mut self.endpoints {
38 match endpoint.poll_ready(cx) {
39 Poll::Pending => return Poll::Pending,
40 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
41 Poll::Ready(Ok(())) => {}
42 }
43 }
44 Poll::Ready(Ok(()))
45 }
46
47 fn call(&mut self, exchange: Exchange) -> Self::Future {
48 let endpoints = self.endpoints.clone();
49 let config = self.config.clone();
50 let round_robin_index = self.round_robin_index.clone();
51 let failover_index = self.failover_index.clone();
52
53 Box::pin(async move {
54 if endpoints.is_empty() {
55 return Ok(exchange);
56 }
57
58 match &config.strategy {
59 LoadBalanceStrategy::RoundRobin => {
60 process_round_robin(exchange, endpoints, round_robin_index).await
61 }
62 LoadBalanceStrategy::Random => process_random(exchange, endpoints).await,
63 LoadBalanceStrategy::Weighted(weights) => {
64 process_weighted(exchange, endpoints, weights).await
65 }
66 LoadBalanceStrategy::Failover => {
67 process_failover(exchange, endpoints, failover_index).await
68 }
69 }
70 })
71 }
72}
73
74async fn process_round_robin(
75 exchange: Exchange,
76 endpoints: Vec<BoxProcessor>,
77 index: Arc<AtomicUsize>,
78) -> Result<Exchange, CamelError> {
79 let len = endpoints.len();
80 let idx = index.fetch_add(1, Ordering::SeqCst) % len;
81 let mut endpoint = endpoints[idx].clone();
82 endpoint.ready().await?.call(exchange).await
83}
84
85async fn process_random(
86 exchange: Exchange,
87 endpoints: Vec<BoxProcessor>,
88) -> Result<Exchange, CamelError> {
89 let len = endpoints.len();
90 let idx = rand::random_range(0..len);
91 let mut endpoint = endpoints[idx].clone();
92 endpoint.ready().await?.call(exchange).await
93}
94
95async fn process_weighted(
96 exchange: Exchange,
97 endpoints: Vec<BoxProcessor>,
98 weights: &[(String, u32)],
99) -> Result<Exchange, CamelError> {
100 if endpoints.is_empty() || weights.is_empty() {
101 return Ok(exchange);
102 }
103
104 let numeric_weights: Vec<u32> = weights.iter().map(|(_, w)| *w).collect();
105 let total: u32 = numeric_weights.iter().sum();
106
107 if total == 0 {
108 return Err(CamelError::ProcessorError(
109 "Weighted load balancer has zero total weight".to_string(),
110 ));
111 }
112
113 let mut r = rand::random::<u32>() % total;
114 let mut selected_idx = 0;
115 for (i, w) in numeric_weights.iter().enumerate() {
116 if r < *w {
117 selected_idx = i.min(endpoints.len() - 1);
118 break;
119 }
120 r -= w;
121 }
122
123 let mut endpoint = endpoints[selected_idx].clone();
124 endpoint.ready().await?.call(exchange).await
125}
126
127async fn process_failover(
128 exchange: Exchange,
129 endpoints: Vec<BoxProcessor>,
130 start_index: Arc<AtomicUsize>,
131) -> Result<Exchange, CamelError> {
132 let len = endpoints.len();
133 let start = start_index.load(Ordering::SeqCst);
134 let mut last_error = None;
135
136 for i in 0..len {
137 let idx = (start + i) % len;
138 let mut endpoint = endpoints[idx].clone();
139 match endpoint.ready().await?.call(exchange.clone()).await {
140 Ok(ex) => {
141 start_index.store((idx + 1) % len, Ordering::SeqCst);
142 return Ok(ex);
143 }
144 Err(e) => {
145 last_error = Some(e);
146 }
147 }
148 }
149
150 Err(last_error.unwrap_or_else(|| {
151 CamelError::ProcessorError("All endpoints failed in failover".to_string())
152 }))
153}
154
155#[derive(Clone)]
167pub struct LoadBalanceSegment {
168 pub destinations: Vec<camel_api::OutcomeSegment>,
169 pub strategy: camel_api::LoadBalanceStrategy,
170 pub round_robin_index: Arc<AtomicUsize>,
172}
173
174impl camel_api::OutcomePipeline for LoadBalanceSegment {
175 fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
176 Box::new(self.clone())
177 }
178
179 fn run<'a>(
180 &'a mut self,
181 exchange: camel_api::Exchange,
182 ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
183 Box::pin(async move {
184 let len = self.destinations.len();
185 if len == 0 {
186 return camel_api::PipelineOutcome::Completed(exchange);
187 }
188
189 let start_idx = match &self.strategy {
190 camel_api::LoadBalanceStrategy::RoundRobin => {
191 self.round_robin_index.fetch_add(1, Ordering::SeqCst) % len
192 }
193 camel_api::LoadBalanceStrategy::Random => rand::random_range(0..len),
194 camel_api::LoadBalanceStrategy::Weighted(weights) => pick_weighted(weights, len),
195 camel_api::LoadBalanceStrategy::Failover => 0,
196 };
197
198 let mut idx = start_idx;
199 let mut last_err: Option<camel_api::CamelError> = None;
200 loop {
201 if idx >= len {
202 return camel_api::PipelineOutcome::Failed(last_err.unwrap_or_else(|| {
203 camel_api::CamelError::ProcessorError(
204 "load_balance: all destinations exhausted".to_string(),
205 )
206 }));
207 }
208 match self.destinations[idx].run(exchange.clone()).await {
209 camel_api::PipelineOutcome::Completed(ex) => {
210 return camel_api::PipelineOutcome::Completed(ex);
211 }
212 camel_api::PipelineOutcome::Stopped(ex) => {
213 return camel_api::PipelineOutcome::Stopped(ex);
214 }
215 camel_api::PipelineOutcome::Failed(err) => match self.strategy {
216 camel_api::LoadBalanceStrategy::Failover => {
217 last_err = Some(err);
218 idx += 1;
219 continue;
220 }
221 _ => return camel_api::PipelineOutcome::Failed(err),
222 },
223 }
224 }
225 })
226 }
227}
228
229fn pick_weighted(weights: &[(String, u32)], len: usize) -> usize {
231 if weights.is_empty() || len == 0 {
232 return 0;
233 }
234 let numeric_weights: Vec<u32> = weights.iter().map(|(_, w)| *w).collect();
235 let total: u32 = numeric_weights.iter().sum();
236 if total == 0 {
237 return 0;
238 }
239 let mut r = rand::random::<u32>() % total;
240 for (i, w) in numeric_weights.iter().enumerate() {
241 if r < *w {
242 return i.min(len - 1);
243 }
244 r -= w;
245 }
246 len - 1
247}
248
249#[cfg(test)]
250mod tests {
251 use super::*;
252 use camel_api::{BoxProcessorExt, Message};
253 use std::sync::Mutex;
254 use tower::ServiceExt;
255
256 fn counting_processor() -> (BoxProcessor, Arc<AtomicUsize>) {
257 let count = Arc::new(AtomicUsize::new(0));
258 let count_clone = count.clone();
259 let processor = BoxProcessor::from_fn(move |ex| {
260 count_clone.fetch_add(1, Ordering::SeqCst);
261 Box::pin(async move { Ok(ex) })
262 });
263 (processor, count)
264 }
265
266 #[tokio::test]
267 async fn test_round_robin_distribution() {
268 let (p1, c1) = counting_processor();
269 let (p2, c2) = counting_processor();
270 let (p3, c3) = counting_processor();
271
272 let config = LoadBalancerConfig::round_robin();
273 let mut svc = LoadBalancerService::new(vec![p1, p2, p3], config);
274
275 for _ in 0..6 {
276 let ex = Exchange::new(Message::new("test"));
277 svc.ready().await.unwrap().call(ex).await.unwrap();
278 }
279
280 assert_eq!(c1.load(Ordering::SeqCst), 2);
281 assert_eq!(c2.load(Ordering::SeqCst), 2);
282 assert_eq!(c3.load(Ordering::SeqCst), 2);
283 }
284
285 #[tokio::test]
286 async fn test_random_distribution() {
287 let (p1, c1) = counting_processor();
288 let (p2, c2) = counting_processor();
289
290 let config = LoadBalancerConfig::random();
291 let mut svc = LoadBalancerService::new(vec![p1, p2], config);
292
293 for _ in 0..100 {
294 let ex = Exchange::new(Message::new("test"));
295 svc.ready().await.unwrap().call(ex).await.unwrap();
296 }
297
298 let total = c1.load(Ordering::SeqCst) + c2.load(Ordering::SeqCst);
299 assert_eq!(total, 100);
300 assert!(c1.load(Ordering::SeqCst) > 20);
301 assert!(c2.load(Ordering::SeqCst) > 20);
302 }
303
304 #[tokio::test]
305 async fn test_failover_on_error() {
306 let failing = BoxProcessor::from_fn(|_ex| {
307 Box::pin(async { Err(CamelError::ProcessorError("fail".into())) })
308 });
309 let (success, count) = counting_processor();
310
311 let config = LoadBalancerConfig::failover();
312 let mut svc = LoadBalancerService::new(vec![failing, success], config);
313
314 let ex = Exchange::new(Message::new("test"));
315 let _result = svc.ready().await.unwrap().call(ex).await.unwrap();
316
317 assert_eq!(count.load(Ordering::SeqCst), 1);
318 }
319
320 #[tokio::test]
321 async fn test_failover_preserves_original_exchange() {
322 let seen_body: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
324 let seen_body_clone = seen_body.clone();
325
326 let failing = BoxProcessor::from_fn(|_ex| {
327 Box::pin(async { Err(CamelError::ProcessorError("fail".into())) })
328 });
329
330 let retry = BoxProcessor::from_fn(move |ex: Exchange| {
331 let seen = seen_body_clone.clone();
332 Box::pin(async move {
333 if let Some(text) = ex.input.body.as_text() {
334 *seen.lock().unwrap() = Some(text.to_string());
335 }
336 Ok(ex)
337 })
338 });
339
340 let config = LoadBalancerConfig::failover();
341 let mut svc = LoadBalancerService::new(vec![failing, retry], config);
342
343 let ex = Exchange::new(Message::new("original body"));
344 svc.ready().await.unwrap().call(ex).await.unwrap();
345
346 assert_eq!(
347 seen_body.lock().unwrap().as_deref(),
348 Some("original body"),
349 "retry endpoint must receive the original exchange body, not a blank one"
350 );
351 }
352
353 #[tokio::test]
354 async fn test_failover_all_fail() {
355 let failing = BoxProcessor::from_fn(|_ex| {
356 Box::pin(async { Err(CamelError::ProcessorError("fail".into())) })
357 });
358
359 let config = LoadBalancerConfig::failover();
360 let mut svc = LoadBalancerService::new(vec![failing.clone(), failing], config);
361
362 let ex = Exchange::new(Message::new("test"));
363 let result = svc.ready().await.unwrap().call(ex).await;
364
365 assert!(result.is_err());
366 }
367
368 #[tokio::test]
369 async fn test_empty_endpoints() {
370 let config = LoadBalancerConfig::round_robin();
371 let mut svc = LoadBalancerService::new(vec![], config);
372
373 let ex = Exchange::new(Message::new("test"));
374 let result = svc.ready().await.unwrap().call(ex).await;
375
376 assert!(result.is_ok());
377 }
378
379 struct StoppingBody;
383 impl camel_api::OutcomePipeline for StoppingBody {
384 fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
385 Box::new(StoppingBody)
386 }
387 fn run<'a>(
388 &'a mut self,
389 mut ex: Exchange,
390 ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
391 Box::pin(async move {
392 ex.input.body = camel_api::Body::Text("lb-stopped".to_string());
393 camel_api::PipelineOutcome::Stopped(ex)
394 })
395 }
396 }
397
398 struct RecordingBody(Arc<AtomicUsize>);
400 impl camel_api::OutcomePipeline for RecordingBody {
401 fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
402 Box::new(RecordingBody(Arc::clone(&self.0)))
403 }
404 fn run<'a>(
405 &'a mut self,
406 ex: Exchange,
407 ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
408 let count = Arc::clone(&self.0);
409 Box::pin(async move {
410 count.fetch_add(1, Ordering::SeqCst);
411 camel_api::PipelineOutcome::Completed(ex)
412 })
413 }
414 }
415
416 struct FailingBody;
418 impl camel_api::OutcomePipeline for FailingBody {
419 fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
420 Box::new(FailingBody)
421 }
422 fn run<'a>(
423 &'a mut self,
424 _ex: Exchange,
425 ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
426 Box::pin(async {
427 camel_api::PipelineOutcome::Failed(CamelError::ProcessorError(
428 "intentional fail".to_string(),
429 ))
430 })
431 }
432 }
433
434 struct RecoveringBody;
436 impl camel_api::OutcomePipeline for RecoveringBody {
437 fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
438 Box::new(RecoveringBody)
439 }
440 fn run<'a>(
441 &'a mut self,
442 mut ex: Exchange,
443 ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
444 Box::pin(async move {
445 ex.input.body = camel_api::Body::Text("recovered".to_string());
446 camel_api::PipelineOutcome::Completed(ex)
447 })
448 }
449 }
450
451 #[tokio::test]
454 async fn load_balance_child_stop_propagates() {
455 let count = Arc::new(AtomicUsize::new(0));
456 let mut seg = LoadBalanceSegment {
457 destinations: vec![
458 camel_api::OutcomeSegment::new(Box::new(StoppingBody)),
459 camel_api::OutcomeSegment::new(Box::new(RecordingBody(count.clone()))),
460 ],
461 strategy: camel_api::LoadBalanceStrategy::RoundRobin,
462 round_robin_index: Arc::new(AtomicUsize::new(0)),
463 };
464
465 let ex = Exchange::new(Message::new("trigger"));
466 let result = camel_api::OutcomePipeline::run(&mut seg, ex).await;
467
468 match result {
469 camel_api::PipelineOutcome::Stopped(ex) => {
470 assert_eq!(
471 ex.input.body.as_text(),
472 Some("lb-stopped"),
473 "Stopped exchange must preserve mutation"
474 );
475 }
476 other => panic!("expected PipelineOutcome::Stopped, got {other:?}"),
477 }
478 assert_eq!(
479 count.load(Ordering::SeqCst),
480 0,
481 "second destination must NOT be tried when first is Stopped"
482 );
483 }
484
485 #[tokio::test]
488 async fn load_balance_child_failure_retries_whole_step() {
489 let mut seg = LoadBalanceSegment {
490 destinations: vec![
491 camel_api::OutcomeSegment::new(Box::new(FailingBody)),
492 camel_api::OutcomeSegment::new(Box::new(RecoveringBody)),
493 ],
494 strategy: camel_api::LoadBalanceStrategy::Failover,
495 round_robin_index: Arc::new(AtomicUsize::new(0)),
496 };
497
498 let ex = Exchange::new(Message::new("trigger"));
499 let result = camel_api::OutcomePipeline::run(&mut seg, ex).await;
500
501 match result {
502 camel_api::PipelineOutcome::Completed(ex) => {
503 assert_eq!(
504 ex.input.body.as_text(),
505 Some("recovered"),
506 "failover must produce the second destination's output"
507 );
508 }
509 other => panic!("expected PipelineOutcome::Completed, got {other:?}"),
510 }
511 }
512
513 #[tokio::test]
516 async fn load_balance_strategy_selection_preserved() {
517 let c1 = Arc::new(AtomicUsize::new(0));
518 let c2 = Arc::new(AtomicUsize::new(0));
519 let c3 = Arc::new(AtomicUsize::new(0));
520
521 let mut seg = LoadBalanceSegment {
522 destinations: vec![
523 camel_api::OutcomeSegment::new(Box::new(RecordingBody(c1.clone()))),
524 camel_api::OutcomeSegment::new(Box::new(RecordingBody(c2.clone()))),
525 camel_api::OutcomeSegment::new(Box::new(RecordingBody(c3.clone()))),
526 ],
527 strategy: camel_api::LoadBalanceStrategy::RoundRobin,
528 round_robin_index: Arc::new(AtomicUsize::new(0)),
529 };
530
531 for _ in 0..3 {
532 let ex = Exchange::new(Message::new("test"));
533 let _result = camel_api::OutcomePipeline::run(&mut seg, ex).await;
534 }
535
536 assert_eq!(
537 c1.load(Ordering::SeqCst),
538 1,
539 "round-robin: dest 0 call count"
540 );
541 assert_eq!(
542 c2.load(Ordering::SeqCst),
543 1,
544 "round-robin: dest 1 call count"
545 );
546 assert_eq!(
547 c3.load(Ordering::SeqCst),
548 1,
549 "round-robin: dest 2 call count"
550 );
551 }
552
553 #[tokio::test]
556 async fn load_balance_segment_failover_exhaustion_preserves_last_error() {
557 let err1 = CamelError::ProcessorError("first-dest-failed".to_string());
558 let err2 = CamelError::ProcessorError("second-dest-failed".to_string());
559
560 struct FailWith(CamelError);
561 impl camel_api::OutcomePipeline for FailWith {
562 fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
563 Box::new(FailWith(self.0.clone()))
564 }
565 fn run<'a>(
566 &'a mut self,
567 _ex: Exchange,
568 ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
569 let e = self.0.clone();
570 Box::pin(async move { camel_api::PipelineOutcome::Failed(e) })
571 }
572 }
573
574 let mut seg = LoadBalanceSegment {
575 destinations: vec![
576 camel_api::OutcomeSegment::new(Box::new(FailWith(err1))),
577 camel_api::OutcomeSegment::new(Box::new(FailWith(err2.clone()))),
578 ],
579 strategy: camel_api::LoadBalanceStrategy::Failover,
580 round_robin_index: Arc::new(AtomicUsize::new(0)),
581 };
582
583 let ex = Exchange::new(Message::new("test"));
584 let result = camel_api::OutcomePipeline::run(&mut seg, ex).await;
585
586 match result {
587 camel_api::PipelineOutcome::Failed(err) => {
588 assert_eq!(
589 err.to_string(),
590 err2.to_string(),
591 "failover exhaustion must return the LAST destination error, not a generic message"
592 );
593 }
594 other => panic!(
595 "expected PipelineOutcome::Failed(last error), got {:?}",
596 other
597 ),
598 }
599 }
600}