Skip to main content

camel_processor/
load_balancer.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::task::{Context, Poll};
6
7use tower::Service;
8use tower::ServiceExt;
9
10use camel_api::{BoxProcessor, CamelError, Exchange, LoadBalanceStrategy, LoadBalancerConfig};
11
12#[derive(Clone)]
13pub struct LoadBalancerService {
14    endpoints: Vec<BoxProcessor>,
15    config: LoadBalancerConfig,
16    round_robin_index: Arc<AtomicUsize>,
17    failover_index: Arc<AtomicUsize>,
18}
19
20impl LoadBalancerService {
21    pub fn new(endpoints: Vec<BoxProcessor>, config: LoadBalancerConfig) -> Self {
22        Self {
23            endpoints,
24            config,
25            round_robin_index: Arc::new(AtomicUsize::new(0)),
26            failover_index: Arc::new(AtomicUsize::new(0)),
27        }
28    }
29}
30
31impl Service<Exchange> for LoadBalancerService {
32    type Response = Exchange;
33    type Error = CamelError;
34    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;
35
36    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
37        for endpoint in &mut self.endpoints {
38            match endpoint.poll_ready(cx) {
39                Poll::Pending => return Poll::Pending,
40                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
41                Poll::Ready(Ok(())) => {}
42            }
43        }
44        Poll::Ready(Ok(()))
45    }
46
47    fn call(&mut self, exchange: Exchange) -> Self::Future {
48        let endpoints = self.endpoints.clone();
49        let config = self.config.clone();
50        let round_robin_index = self.round_robin_index.clone();
51        let failover_index = self.failover_index.clone();
52
53        Box::pin(async move {
54            if endpoints.is_empty() {
55                return Ok(exchange);
56            }
57
58            match &config.strategy {
59                LoadBalanceStrategy::RoundRobin => {
60                    process_round_robin(exchange, endpoints, round_robin_index).await
61                }
62                LoadBalanceStrategy::Random => process_random(exchange, endpoints).await,
63                LoadBalanceStrategy::Weighted(weights) => {
64                    process_weighted(exchange, endpoints, weights).await
65                }
66                LoadBalanceStrategy::Failover => {
67                    process_failover(exchange, endpoints, failover_index).await
68                }
69            }
70        })
71    }
72}
73
74async fn process_round_robin(
75    exchange: Exchange,
76    endpoints: Vec<BoxProcessor>,
77    index: Arc<AtomicUsize>,
78) -> Result<Exchange, CamelError> {
79    let len = endpoints.len();
80    let idx = index.fetch_add(1, Ordering::SeqCst) % len;
81    let mut endpoint = endpoints[idx].clone();
82    endpoint.ready().await?.call(exchange).await
83}
84
85async fn process_random(
86    exchange: Exchange,
87    endpoints: Vec<BoxProcessor>,
88) -> Result<Exchange, CamelError> {
89    let len = endpoints.len();
90    let idx = rand::random_range(0..len);
91    let mut endpoint = endpoints[idx].clone();
92    endpoint.ready().await?.call(exchange).await
93}
94
95async fn process_weighted(
96    exchange: Exchange,
97    endpoints: Vec<BoxProcessor>,
98    weights: &[(String, u32)],
99) -> Result<Exchange, CamelError> {
100    if endpoints.is_empty() || weights.is_empty() {
101        return Ok(exchange);
102    }
103
104    let numeric_weights: Vec<u32> = weights.iter().map(|(_, w)| *w).collect();
105    let total: u32 = numeric_weights.iter().sum();
106
107    if total == 0 {
108        return Err(CamelError::ProcessorError(
109            "Weighted load balancer has zero total weight".to_string(),
110        ));
111    }
112
113    let mut r = rand::random::<u32>() % total;
114    let mut selected_idx = 0;
115    for (i, w) in numeric_weights.iter().enumerate() {
116        if r < *w {
117            selected_idx = i.min(endpoints.len() - 1);
118            break;
119        }
120        r -= w;
121    }
122
123    let mut endpoint = endpoints[selected_idx].clone();
124    endpoint.ready().await?.call(exchange).await
125}
126
127async fn process_failover(
128    exchange: Exchange,
129    endpoints: Vec<BoxProcessor>,
130    start_index: Arc<AtomicUsize>,
131) -> Result<Exchange, CamelError> {
132    let len = endpoints.len();
133    let start = start_index.load(Ordering::SeqCst);
134    let mut last_error = None;
135
136    for i in 0..len {
137        let idx = (start + i) % len;
138        let mut endpoint = endpoints[idx].clone();
139        match endpoint.ready().await?.call(exchange.clone()).await {
140            Ok(ex) => {
141                start_index.store((idx + 1) % len, Ordering::SeqCst);
142                return Ok(ex);
143            }
144            Err(e) => {
145                last_error = Some(e);
146            }
147        }
148    }
149
150    Err(last_error.unwrap_or_else(|| {
151        CamelError::ProcessorError("All endpoints failed in failover".to_string())
152    }))
153}
154
155// ── LoadBalanceSegment (ADR-0025 OutcomePipeline) ────────────────────────
156
157/// Outcome-aware LoadBalance segment. Holds N destinations + a strategy.
158/// On each call: strategy picks ONE destination (round-robin / failover /
159/// random / weighted), runs it. If chosen destination returns Completed,
160/// return Completed. If Stopped: return Stopped immediately (no failover —
161/// Stop is successful control flow). If Failed: strategy decides (failover
162/// retries next dest, others return Failed).
163///
164/// This differs from Multicast (which runs all branches) — LoadBalance picks
165/// exactly one. The parallel cancellation logic from T13/T15 does NOT apply.
166#[derive(Clone)]
167pub struct LoadBalanceSegment {
168    pub destinations: Vec<camel_api::OutcomeSegment>,
169    pub strategy: camel_api::LoadBalanceStrategy,
170    /// Shared round-robin index for interior mutability across cloned segments.
171    pub round_robin_index: Arc<AtomicUsize>,
172}
173
174impl camel_api::OutcomePipeline for LoadBalanceSegment {
175    fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
176        Box::new(self.clone())
177    }
178
179    fn run<'a>(
180        &'a mut self,
181        exchange: camel_api::Exchange,
182    ) -> Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>> {
183        Box::pin(async move {
184            let len = self.destinations.len();
185            if len == 0 {
186                return camel_api::PipelineOutcome::Completed(exchange);
187            }
188
189            let start_idx = match &self.strategy {
190                camel_api::LoadBalanceStrategy::RoundRobin => {
191                    self.round_robin_index.fetch_add(1, Ordering::SeqCst) % len
192                }
193                camel_api::LoadBalanceStrategy::Random => rand::random_range(0..len),
194                camel_api::LoadBalanceStrategy::Weighted(weights) => pick_weighted(weights, len),
195                camel_api::LoadBalanceStrategy::Failover => 0,
196            };
197
198            let mut idx = start_idx;
199            let mut last_err: Option<camel_api::CamelError> = None;
200            loop {
201                if idx >= len {
202                    return camel_api::PipelineOutcome::Failed(last_err.unwrap_or_else(|| {
203                        camel_api::CamelError::ProcessorError(
204                            "load_balance: all destinations exhausted".to_string(),
205                        )
206                    }));
207                }
208                match self.destinations[idx].run(exchange.clone()).await {
209                    camel_api::PipelineOutcome::Completed(ex) => {
210                        return camel_api::PipelineOutcome::Completed(ex);
211                    }
212                    camel_api::PipelineOutcome::Stopped(ex) => {
213                        return camel_api::PipelineOutcome::Stopped(ex);
214                    }
215                    camel_api::PipelineOutcome::Failed(err) => match self.strategy {
216                        camel_api::LoadBalanceStrategy::Failover => {
217                            last_err = Some(err);
218                            idx += 1;
219                            continue;
220                        }
221                        _ => return camel_api::PipelineOutcome::Failed(err),
222                    },
223                }
224            }
225        })
226    }
227}
228
229/// Pick a destination index using weighted random selection.
230fn pick_weighted(weights: &[(String, u32)], len: usize) -> usize {
231    if weights.is_empty() || len == 0 {
232        return 0;
233    }
234    let numeric_weights: Vec<u32> = weights.iter().map(|(_, w)| *w).collect();
235    let total: u32 = numeric_weights.iter().sum();
236    if total == 0 {
237        return 0;
238    }
239    let mut r = rand::random::<u32>() % total;
240    for (i, w) in numeric_weights.iter().enumerate() {
241        if r < *w {
242            return i.min(len - 1);
243        }
244        r -= w;
245    }
246    len - 1
247}
248
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::Mutex;
    use tower::ServiceExt;

    /// Boxed future returned by the test `OutcomePipeline` implementations.
    type OutcomeFuture<'a> =
        Pin<Box<dyn Future<Output = camel_api::PipelineOutcome> + Send + 'a>>;

    /// Builds a fresh exchange whose input body is `body`.
    fn exchange(body: &'static str) -> Exchange {
        Exchange::new(Message::new(body))
    }

    /// Echo processor that counts its invocations. Returns the processor
    /// together with its shared counter.
    fn counting_processor() -> (BoxProcessor, Arc<AtomicUsize>) {
        let calls = Arc::new(AtomicUsize::new(0));
        let tally = Arc::clone(&calls);
        let processor = BoxProcessor::from_fn(move |ex| {
            tally.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Ok(ex) })
        });
        (processor, calls)
    }

    /// Processor that rejects every exchange with `ProcessorError("fail")`.
    fn failing_processor() -> BoxProcessor {
        BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("fail".into())) })
        })
    }

    /// Waits until `svc` is ready, then pushes one exchange carrying `body`
    /// through it.
    async fn send(
        svc: &mut LoadBalancerService,
        body: &'static str,
    ) -> Result<Exchange, CamelError> {
        svc.ready().await.unwrap().call(exchange(body)).await
    }

    #[tokio::test]
    async fn test_round_robin_distribution() {
        let (a, a_calls) = counting_processor();
        let (b, b_calls) = counting_processor();
        let (c, c_calls) = counting_processor();
        let mut svc =
            LoadBalancerService::new(vec![a, b, c], LoadBalancerConfig::round_robin());

        for _ in 0..6 {
            send(&mut svc, "test").await.unwrap();
        }

        // Six exchanges over three endpoints: exactly two each.
        for calls in [&a_calls, &b_calls, &c_calls] {
            assert_eq!(calls.load(Ordering::SeqCst), 2);
        }
    }

    #[tokio::test]
    async fn test_random_distribution() {
        let (a, a_calls) = counting_processor();
        let (b, b_calls) = counting_processor();
        let mut svc = LoadBalancerService::new(vec![a, b], LoadBalancerConfig::random());

        for _ in 0..100 {
            send(&mut svc, "test").await.unwrap();
        }

        let hits_a = a_calls.load(Ordering::SeqCst);
        let hits_b = b_calls.load(Ordering::SeqCst);
        assert_eq!(hits_a + hits_b, 100);
        // Loose bounds: a fair split of 100 draws essentially never drops to 20.
        assert!(hits_a > 20);
        assert!(hits_b > 20);
    }

    #[tokio::test]
    async fn test_failover_on_error() {
        let (healthy, healthy_calls) = counting_processor();
        let mut svc = LoadBalancerService::new(
            vec![failing_processor(), healthy],
            LoadBalancerConfig::failover(),
        );

        send(&mut svc, "test").await.unwrap();

        assert_eq!(healthy_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_failover_preserves_original_exchange() {
        // Records the body the retry endpoint receives.
        let observed: Arc<Mutex<Option<String>>> = Arc::default();

        let retry = {
            let observed = Arc::clone(&observed);
            BoxProcessor::from_fn(move |ex: Exchange| {
                let observed = observed.clone();
                Box::pin(async move {
                    if let Some(text) = ex.input.body.as_text() {
                        *observed.lock().unwrap() = Some(text.to_string());
                    }
                    Ok(ex)
                })
            })
        };

        let mut svc = LoadBalancerService::new(
            vec![failing_processor(), retry],
            LoadBalancerConfig::failover(),
        );
        send(&mut svc, "original body").await.unwrap();

        assert_eq!(
            observed.lock().unwrap().as_deref(),
            Some("original body"),
            "retry endpoint must receive the original exchange body, not a blank one"
        );
    }

    #[tokio::test]
    async fn test_failover_all_fail() {
        let mut svc = LoadBalancerService::new(
            vec![failing_processor(), failing_processor()],
            LoadBalancerConfig::failover(),
        );

        assert!(send(&mut svc, "test").await.is_err());
    }

    #[tokio::test]
    async fn test_empty_endpoints() {
        let mut svc = LoadBalancerService::new(vec![], LoadBalancerConfig::round_robin());

        assert!(send(&mut svc, "test").await.is_ok());
    }

    // ── LoadBalanceSegment tests (ADR-0025 OutcomePipeline parity) ───

    /// Builds a segment over `destinations` with a fresh round-robin cursor.
    fn segment(
        destinations: Vec<camel_api::OutcomeSegment>,
        strategy: LoadBalanceStrategy,
    ) -> LoadBalanceSegment {
        LoadBalanceSegment {
            destinations,
            strategy,
            round_robin_index: Arc::new(AtomicUsize::new(0)),
        }
    }

    /// Destination that rewrites the body to "lb-stopped" and returns Stopped.
    struct StoppingBody;
    impl camel_api::OutcomePipeline for StoppingBody {
        fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
            Box::new(Self)
        }
        fn run<'a>(&'a mut self, mut ex: Exchange) -> OutcomeFuture<'a> {
            Box::pin(async move {
                ex.input.body = camel_api::Body::Text("lb-stopped".to_string());
                camel_api::PipelineOutcome::Stopped(ex)
            })
        }
    }

    /// Destination that completes unchanged and bumps a shared counter.
    struct RecordingBody(Arc<AtomicUsize>);
    impl camel_api::OutcomePipeline for RecordingBody {
        fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
            Box::new(Self(Arc::clone(&self.0)))
        }
        fn run<'a>(&'a mut self, ex: Exchange) -> OutcomeFuture<'a> {
            let calls = Arc::clone(&self.0);
            Box::pin(async move {
                calls.fetch_add(1, Ordering::SeqCst);
                camel_api::PipelineOutcome::Completed(ex)
            })
        }
    }

    /// Destination that always fails with `ProcessorError("intentional fail")`.
    struct FailingBody;
    impl camel_api::OutcomePipeline for FailingBody {
        fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
            Box::new(Self)
        }
        fn run<'a>(&'a mut self, _ex: Exchange) -> OutcomeFuture<'a> {
            Box::pin(async {
                camel_api::PipelineOutcome::Failed(CamelError::ProcessorError(
                    "intentional fail".to_string(),
                ))
            })
        }
    }

    /// Destination that rewrites the body to "recovered" and completes.
    struct RecoveringBody;
    impl camel_api::OutcomePipeline for RecoveringBody {
        fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
            Box::new(Self)
        }
        fn run<'a>(&'a mut self, mut ex: Exchange) -> OutcomeFuture<'a> {
            Box::pin(async move {
                ex.input.body = camel_api::Body::Text("recovered".to_string());
                camel_api::PipelineOutcome::Completed(ex)
            })
        }
    }

    /// A Stop from the chosen destination propagates immediately. The first
    /// destination mutates and stops; the second must NOT be tried.
    #[tokio::test]
    async fn load_balance_child_stop_propagates() {
        let second_calls = Arc::new(AtomicUsize::new(0));
        let mut seg = segment(
            vec![
                camel_api::OutcomeSegment::new(Box::new(StoppingBody)),
                camel_api::OutcomeSegment::new(Box::new(RecordingBody(Arc::clone(
                    &second_calls,
                )))),
            ],
            LoadBalanceStrategy::RoundRobin,
        );

        let outcome = camel_api::OutcomePipeline::run(&mut seg, exchange("trigger")).await;

        match outcome {
            camel_api::PipelineOutcome::Stopped(ex) => assert_eq!(
                ex.input.body.as_text(),
                Some("lb-stopped"),
                "Stopped exchange must preserve mutation"
            ),
            other => panic!("expected PipelineOutcome::Stopped, got {other:?}"),
        }
        assert_eq!(
            second_calls.load(Ordering::SeqCst),
            0,
            "second destination must NOT be tried when first is Stopped"
        );
    }

    /// Failover moves past a failing destination to one that succeeds.
    #[tokio::test]
    async fn load_balance_child_failure_retries_whole_step() {
        let mut seg = segment(
            vec![
                camel_api::OutcomeSegment::new(Box::new(FailingBody)),
                camel_api::OutcomeSegment::new(Box::new(RecoveringBody)),
            ],
            LoadBalanceStrategy::Failover,
        );

        let outcome = camel_api::OutcomePipeline::run(&mut seg, exchange("trigger")).await;

        match outcome {
            camel_api::PipelineOutcome::Completed(ex) => assert_eq!(
                ex.input.body.as_text(),
                Some("recovered"),
                "failover must produce the second destination's output"
            ),
            other => panic!("expected PipelineOutcome::Completed, got {other:?}"),
        }
    }

    /// Round-robin: three sequential runs hit each destination exactly once.
    #[tokio::test]
    async fn load_balance_strategy_selection_preserved() {
        let counters: Vec<Arc<AtomicUsize>> =
            (0..3).map(|_| Arc::new(AtomicUsize::new(0))).collect();
        let mut seg = segment(
            counters
                .iter()
                .map(|calls| {
                    camel_api::OutcomeSegment::new(Box::new(RecordingBody(Arc::clone(calls))))
                })
                .collect(),
            LoadBalanceStrategy::RoundRobin,
        );

        for _ in 0..3 {
            let _ = camel_api::OutcomePipeline::run(&mut seg, exchange("test")).await;
        }

        for (dest, calls) in counters.iter().enumerate() {
            assert_eq!(
                calls.load(Ordering::SeqCst),
                1,
                "round-robin: dest {dest} call count"
            );
        }
    }

    /// Failover exhaustion reports the LAST destination's error rather than
    /// the generic "all destinations exhausted" message.
    #[tokio::test]
    async fn load_balance_segment_failover_exhaustion_preserves_last_error() {
        /// Destination that always fails with a fixed error.
        struct FailWith(CamelError);
        impl camel_api::OutcomePipeline for FailWith {
            fn clone_box(&self) -> Box<dyn camel_api::OutcomePipeline> {
                Box::new(Self(self.0.clone()))
            }
            fn run<'a>(&'a mut self, _ex: Exchange) -> OutcomeFuture<'a> {
                let err = self.0.clone();
                Box::pin(async move { camel_api::PipelineOutcome::Failed(err) })
            }
        }

        let first = CamelError::ProcessorError("first-dest-failed".to_string());
        let last = CamelError::ProcessorError("second-dest-failed".to_string());
        let mut seg = segment(
            vec![
                camel_api::OutcomeSegment::new(Box::new(FailWith(first))),
                camel_api::OutcomeSegment::new(Box::new(FailWith(last.clone()))),
            ],
            LoadBalanceStrategy::Failover,
        );

        match camel_api::OutcomePipeline::run(&mut seg, exchange("test")).await {
            camel_api::PipelineOutcome::Failed(err) => assert_eq!(
                err.to_string(),
                last.to_string(),
                "failover exhaustion must return the LAST destination error, not a generic message"
            ),
            other => panic!("expected PipelineOutcome::Failed(last error), got {other:?}"),
        }
    }
}