//! camel_processor/load_balancer.rs
//!
//! Load-balancing processor: distributes exchanges across endpoint
//! processors via round-robin, random, weighted, failover, or parallel
//! fan-out strategies.
1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::atomic::{AtomicUsize, Ordering};
5use std::task::{Context, Poll};
6
7use tower::Service;
8use tower::ServiceExt;
9
10use camel_api::{
11    BoxProcessor, CamelError, Exchange, LoadBalanceStrategy, LoadBalancerConfig, Value,
12};
13
14use crate::multicast::{CAMEL_MULTICAST_COMPLETE, CAMEL_MULTICAST_INDEX};
15
/// Tower `Service` that routes each incoming `Exchange` to one (or, in
/// parallel mode, all) of a fixed set of endpoint processors, according
/// to the strategy in `config`.
#[derive(Clone)]
pub struct LoadBalancerService {
    // Candidate processors the balancer dispatches to.
    endpoints: Vec<BoxProcessor>,
    // Strategy selection plus the `parallel` fan-out flag.
    config: LoadBalancerConfig,
    // Next-slot counter for round-robin; behind an Arc so clones of this
    // service keep advancing one shared rotation.
    round_robin_index: Arc<AtomicUsize>,
    // Position where the next failover scan starts (last known-good + 1).
    failover_index: Arc<AtomicUsize>,
}
23
24impl LoadBalancerService {
25    pub fn new(endpoints: Vec<BoxProcessor>, config: LoadBalancerConfig) -> Self {
26        Self {
27            endpoints,
28            config,
29            round_robin_index: Arc::new(AtomicUsize::new(0)),
30            failover_index: Arc::new(AtomicUsize::new(0)),
31        }
32    }
33}
34
impl Service<Exchange> for LoadBalancerService {
    type Response = Exchange;
    type Error = CamelError;
    type Future = Pin<Box<dyn Future<Output = Result<Exchange, CamelError>> + Send>>;

    /// Ready only when *every* endpoint reports ready; the first `Pending`
    /// or `Err` short-circuits.
    ///
    /// NOTE(review): returning `Pending` as soon as one endpoint is pending
    /// means later endpoints are not polled on that pass; this relies on the
    /// pending endpoint's waker to re-drive the whole check. Also, `call`
    /// below dispatches to *clones* of these endpoints — confirm that
    /// `BoxProcessor: Clone` preserves readiness across the clone.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        for endpoint in &mut self.endpoints {
            match endpoint.poll_ready(cx) {
                Poll::Pending => return Poll::Pending,
                Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
                Poll::Ready(Ok(())) => {}
            }
        }
        Poll::Ready(Ok(()))
    }

    /// Dispatches `exchange` per `config`: parallel fan-out when
    /// `config.parallel` is set, otherwise the configured strategy.
    /// An empty endpoint list passes the exchange through unchanged.
    fn call(&mut self, exchange: Exchange) -> Self::Future {
        // Clone the shared state into the future so the returned future is
        // 'static + Send and independent of `&mut self`.
        let endpoints = self.endpoints.clone();
        let config = self.config.clone();
        let round_robin_index = self.round_robin_index.clone();
        let failover_index = self.failover_index.clone();

        Box::pin(async move {
            if endpoints.is_empty() {
                return Ok(exchange);
            }

            if config.parallel {
                process_parallel(exchange, endpoints).await
            } else {
                match &config.strategy {
                    LoadBalanceStrategy::RoundRobin => {
                        process_round_robin(exchange, endpoints, round_robin_index).await
                    }
                    LoadBalanceStrategy::Random => process_random(exchange, endpoints).await,
                    LoadBalanceStrategy::Weighted(weights) => {
                        process_weighted(exchange, endpoints, weights).await
                    }
                    LoadBalanceStrategy::Failover => {
                        process_failover(exchange, endpoints, failover_index).await
                    }
                }
            }
        })
    }
}
81
82async fn process_round_robin(
83    exchange: Exchange,
84    endpoints: Vec<BoxProcessor>,
85    index: Arc<AtomicUsize>,
86) -> Result<Exchange, CamelError> {
87    let len = endpoints.len();
88    let idx = index.fetch_add(1, Ordering::SeqCst) % len;
89    let mut endpoint = endpoints[idx].clone();
90    endpoint.ready().await?.call(exchange).await
91}
92
93async fn process_random(
94    exchange: Exchange,
95    endpoints: Vec<BoxProcessor>,
96) -> Result<Exchange, CamelError> {
97    let len = endpoints.len();
98    let idx = rand::random::<usize>() % len;
99    let mut endpoint = endpoints[idx].clone();
100    endpoint.ready().await?.call(exchange).await
101}
102
103async fn process_weighted(
104    exchange: Exchange,
105    endpoints: Vec<BoxProcessor>,
106    weights: &[(String, u32)],
107) -> Result<Exchange, CamelError> {
108    if endpoints.is_empty() || weights.is_empty() {
109        return Ok(exchange);
110    }
111
112    let numeric_weights: Vec<u32> = weights.iter().map(|(_, w)| *w).collect();
113    let total: u32 = numeric_weights.iter().sum();
114
115    if total == 0 {
116        return Err(CamelError::ProcessorError(
117            "Weighted load balancer has zero total weight".to_string(),
118        ));
119    }
120
121    let mut r = rand::random::<u32>() % total;
122    let mut selected_idx = 0;
123    for (i, w) in numeric_weights.iter().enumerate() {
124        if r < *w {
125            selected_idx = i.min(endpoints.len() - 1);
126            break;
127        }
128        r -= w;
129    }
130
131    let mut endpoint = endpoints[selected_idx].clone();
132    endpoint.ready().await?.call(exchange).await
133}
134
135async fn process_failover(
136    exchange: Exchange,
137    endpoints: Vec<BoxProcessor>,
138    start_index: Arc<AtomicUsize>,
139) -> Result<Exchange, CamelError> {
140    let len = endpoints.len();
141    let start = start_index.load(Ordering::SeqCst);
142    let mut last_error = None;
143
144    for i in 0..len {
145        let idx = (start + i) % len;
146        let mut endpoint = endpoints[idx].clone();
147        match endpoint.ready().await?.call(exchange.clone()).await {
148            Ok(ex) => {
149                start_index.store((idx + 1) % len, Ordering::SeqCst);
150                return Ok(ex);
151            }
152            Err(e) => {
153                last_error = Some(e);
154            }
155        }
156    }
157
158    Err(last_error.unwrap_or_else(|| {
159        CamelError::ProcessorError("All endpoints failed in failover".to_string())
160    }))
161}
162
163async fn process_parallel(
164    exchange: Exchange,
165    endpoints: Vec<BoxProcessor>,
166) -> Result<Exchange, CamelError> {
167    use futures::future::join_all;
168
169    let total = endpoints.len();
170    let futures: Vec<_> = endpoints
171        .into_iter()
172        .enumerate()
173        .map(|(i, mut endpoint)| {
174            let mut ex = exchange.clone();
175            ex.set_property(CAMEL_MULTICAST_INDEX, Value::from(i as i64));
176            ex.set_property(CAMEL_MULTICAST_COMPLETE, Value::Bool(i == total - 1));
177            async move {
178                tower::ServiceExt::ready(&mut endpoint).await?;
179                endpoint.call(ex).await
180            }
181        })
182        .collect();
183
184    let results: Vec<Result<Exchange, CamelError>> = join_all(futures).await;
185
186    for result in &results {
187        if let Err(e) = result {
188            return Err(e.clone());
189        }
190    }
191
192    results.into_iter().last().unwrap_or(Ok(exchange))
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use camel_api::{BoxProcessorExt, Message};
    use std::sync::Mutex;
    use tower::ServiceExt;

    // Builds a pass-through processor plus a shared counter of how many
    // exchanges it has handled.
    fn counting_processor() -> (BoxProcessor, Arc<AtomicUsize>) {
        let count = Arc::new(AtomicUsize::new(0));
        let count_clone = count.clone();
        let processor = BoxProcessor::from_fn(move |ex| {
            count_clone.fetch_add(1, Ordering::SeqCst);
            Box::pin(async move { Ok(ex) })
        });
        (processor, count)
    }

    // Six sequential calls across three endpoints must land exactly twice
    // on each — i.e. strict rotation, no skew.
    #[tokio::test]
    async fn test_round_robin_distribution() {
        let (p1, c1) = counting_processor();
        let (p2, c2) = counting_processor();
        let (p3, c3) = counting_processor();

        let config = LoadBalancerConfig::round_robin();
        let mut svc = LoadBalancerService::new(vec![p1, p2, p3], config);

        for _ in 0..6 {
            let ex = Exchange::new(Message::new("test"));
            svc.ready().await.unwrap().call(ex).await.unwrap();
        }

        assert_eq!(c1.load(Ordering::SeqCst), 2);
        assert_eq!(c2.load(Ordering::SeqCst), 2);
        assert_eq!(c3.load(Ordering::SeqCst), 2);
    }

    // Random strategy: every call hits exactly one endpoint, and over 100
    // calls each of the two endpoints should get a non-trivial share.
    // NOTE: the >20 bounds are probabilistic; failure odds are astronomically
    // small for a fair split but not literally zero.
    #[tokio::test]
    async fn test_random_distribution() {
        let (p1, c1) = counting_processor();
        let (p2, c2) = counting_processor();

        let config = LoadBalancerConfig::random();
        let mut svc = LoadBalancerService::new(vec![p1, p2], config);

        for _ in 0..100 {
            let ex = Exchange::new(Message::new("test"));
            svc.ready().await.unwrap().call(ex).await.unwrap();
        }

        let total = c1.load(Ordering::SeqCst) + c2.load(Ordering::SeqCst);
        assert_eq!(total, 100);
        assert!(c1.load(Ordering::SeqCst) > 20);
        assert!(c2.load(Ordering::SeqCst) > 20);
    }

    // Failover: when the first endpoint errors, the second must be tried
    // (exactly once) and its success returned to the caller.
    #[tokio::test]
    async fn test_failover_on_error() {
        let failing = BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("fail".into())) })
        });
        let (success, count) = counting_processor();

        let config = LoadBalancerConfig::failover();
        let mut svc = LoadBalancerService::new(vec![failing, success], config);

        let ex = Exchange::new(Message::new("test"));
        let _result = svc.ready().await.unwrap().call(ex).await.unwrap();

        assert_eq!(count.load(Ordering::SeqCst), 1);
    }

    // Failover must retry with a clone of the ORIGINAL exchange, not a
    // blank/mutated one — guards the `exchange.clone()` in process_failover.
    #[tokio::test]
    async fn test_failover_preserves_original_exchange() {
        // Capture body seen by retry endpoint to verify it's the original
        let seen_body: Arc<Mutex<Option<String>>> = Arc::new(Mutex::new(None));
        let seen_body_clone = seen_body.clone();

        let failing = BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("fail".into())) })
        });

        let retry = BoxProcessor::from_fn(move |ex: Exchange| {
            let seen = seen_body_clone.clone();
            Box::pin(async move {
                if let Some(text) = ex.input.body.as_text() {
                    *seen.lock().unwrap() = Some(text.to_string());
                }
                Ok(ex)
            })
        });

        let config = LoadBalancerConfig::failover();
        let mut svc = LoadBalancerService::new(vec![failing, retry], config);

        let ex = Exchange::new(Message::new("original body"));
        svc.ready().await.unwrap().call(ex).await.unwrap();

        assert_eq!(
            seen_body.lock().unwrap().as_deref(),
            Some("original body"),
            "retry endpoint must receive the original exchange body, not a blank one"
        );
    }

    // When every endpoint fails, the caller must get an Err (the last
    // endpoint's error), not a silent success.
    #[tokio::test]
    async fn test_failover_all_fail() {
        let failing = BoxProcessor::from_fn(|_ex| {
            Box::pin(async { Err(CamelError::ProcessorError("fail".into())) })
        });

        let config = LoadBalancerConfig::failover();
        let mut svc = LoadBalancerService::new(vec![failing.clone(), failing], config);

        let ex = Exchange::new(Message::new("test"));
        let result = svc.ready().await.unwrap().call(ex).await;

        assert!(result.is_err());
    }

    // Parallel mode fans a single call out to ALL endpoints, once each.
    #[tokio::test]
    async fn test_parallel_sends_to_all() {
        let (p1, c1) = counting_processor();
        let (p2, c2) = counting_processor();
        let (p3, c3) = counting_processor();

        let config = LoadBalancerConfig::round_robin().parallel(true);
        let mut svc = LoadBalancerService::new(vec![p1, p2, p3], config);

        let ex = Exchange::new(Message::new("test"));
        svc.ready().await.unwrap().call(ex).await.unwrap();

        assert_eq!(c1.load(Ordering::SeqCst), 1);
        assert_eq!(c2.load(Ordering::SeqCst), 1);
        assert_eq!(c3.load(Ordering::SeqCst), 1);
    }

    // With no endpoints the balancer is a no-op pass-through, not an error.
    #[tokio::test]
    async fn test_empty_endpoints() {
        let config = LoadBalancerConfig::round_robin();
        let mut svc = LoadBalancerService::new(vec![], config);

        let ex = Exchange::new(Message::new("test"));
        let result = svc.ready().await.unwrap().call(ex).await;

        assert!(result.is_ok());
    }
}