oxify_connect_llm/
load_balancer.rs

1//! Load balancer for distributing requests across multiple LLM providers.
2//!
3//! Supports multiple load balancing strategies:
4//! - **Round Robin**: Distributes requests evenly across providers
5//! - **Random**: Randomly selects a provider for each request
6//! - **Weighted**: Distributes based on configured weights
7//!
8//! # Example
9//!
10//! ```rust,no_run
11//! use oxify_connect_llm::{LoadBalancer, LoadBalancingStrategy, LlmProvider};
12//!
13//! # async fn example() -> Result<(), Box<dyn std::error::Error>> {
14//! # let provider1: Box<dyn LlmProvider> = todo!();
15//! # let provider2: Box<dyn LlmProvider> = todo!();
16//! # let provider3: Box<dyn LlmProvider> = todo!();
17//! // Create load balancer with multiple providers
18//! let lb = LoadBalancer::new(vec![provider1, provider2, provider3])
19//!     .with_strategy(LoadBalancingStrategy::RoundRobin);
20//!
21//! // Requests will be distributed across all providers
22//! # Ok(())
23//! # }
24//! ```
25
26use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
27use async_trait::async_trait;
28use std::sync::atomic::{AtomicUsize, Ordering};
29use std::sync::Arc;
30
/// Policy used by the load balancer to pick a provider for each request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LoadBalancingStrategy {
    /// Cycle through the providers in order, one request each.
    RoundRobin,
    /// Pick a provider pseudo-randomly for every request.
    Random,
    /// Pick providers in proportion to their configured weights
    /// (the weights must have been supplied for this to be meaningful).
    Weighted,
}
41
42/// Load balancer that distributes requests across multiple providers
43pub struct LoadBalancer {
44    providers: Vec<ProviderWithWeight>,
45    strategy: LoadBalancingStrategy,
46    counter: Arc<AtomicUsize>,
47}
48
49struct ProviderWithWeight {
50    provider: Box<dyn LlmProvider>,
51    weight: u32,
52}
53
54impl LoadBalancer {
55    /// Create a new load balancer with default round-robin strategy
56    pub fn new(providers: Vec<Box<dyn LlmProvider>>) -> Self {
57        if providers.is_empty() {
58            panic!("LoadBalancer requires at least one provider");
59        }
60
61        let providers_with_weight = providers
62            .into_iter()
63            .map(|p| ProviderWithWeight {
64                provider: p,
65                weight: 1,
66            })
67            .collect();
68
69        Self {
70            providers: providers_with_weight,
71            strategy: LoadBalancingStrategy::RoundRobin,
72            counter: Arc::new(AtomicUsize::new(0)),
73        }
74    }
75
76    /// Create a load balancer with weighted providers
77    pub fn with_weights(providers: Vec<(Box<dyn LlmProvider>, u32)>) -> Self {
78        if providers.is_empty() {
79            panic!("LoadBalancer requires at least one provider");
80        }
81
82        let providers_with_weight = providers
83            .into_iter()
84            .map(|(p, w)| ProviderWithWeight {
85                provider: p,
86                weight: w,
87            })
88            .collect();
89
90        Self {
91            providers: providers_with_weight,
92            strategy: LoadBalancingStrategy::Weighted,
93            counter: Arc::new(AtomicUsize::new(0)),
94        }
95    }
96
97    /// Set the load balancing strategy
98    pub fn with_strategy(mut self, strategy: LoadBalancingStrategy) -> Self {
99        self.strategy = strategy;
100        self
101    }
102
103    /// Get the number of providers
104    pub fn provider_count(&self) -> usize {
105        self.providers.len()
106    }
107
108    /// Get load balancer statistics
109    pub fn get_stats(&self) -> LoadBalancerStats {
110        let total_weight: u32 = self.providers.iter().map(|p| p.weight).sum();
111        LoadBalancerStats {
112            provider_count: self.providers.len(),
113            strategy: self.strategy,
114            total_weight,
115            request_count: self.counter.load(Ordering::SeqCst),
116        }
117    }
118
119    /// Select a provider based on the load balancing strategy
120    fn select_provider(&self) -> &dyn LlmProvider {
121        match self.strategy {
122            LoadBalancingStrategy::RoundRobin => {
123                let index = self.counter.fetch_add(1, Ordering::SeqCst);
124                &*self.providers[index % self.providers.len()].provider
125            }
126            LoadBalancingStrategy::Random => {
127                // Use counter with a multiplier for pseudo-random distribution
128                let index = self.counter.fetch_add(1, Ordering::SeqCst);
129                // Use a large prime multiplier to get better distribution
130                let pseudo_random = index.wrapping_mul(2654435761);
131                &*self.providers[pseudo_random % self.providers.len()].provider
132            }
133            LoadBalancingStrategy::Weighted => {
134                let total_weight: u32 = self.providers.iter().map(|p| p.weight).sum();
135                let counter = self.counter.fetch_add(1, Ordering::SeqCst);
136                // Use modulo for weighted selection
137                let mut target_weight = (counter as u32) % total_weight;
138
139                for provider in &self.providers {
140                    if target_weight < provider.weight {
141                        return &*provider.provider;
142                    }
143                    target_weight -= provider.weight;
144                }
145
146                // Fallback (should never happen)
147                &*self.providers[0].provider
148            }
149        }
150    }
151}
152
153/// Load balancer statistics
154#[derive(Debug, Clone)]
155pub struct LoadBalancerStats {
156    /// Number of providers in the pool
157    pub provider_count: usize,
158    /// Current load balancing strategy
159    pub strategy: LoadBalancingStrategy,
160    /// Total weight (for weighted strategy)
161    pub total_weight: u32,
162    /// Total number of requests processed
163    pub request_count: usize,
164}
165
166#[async_trait]
167impl LlmProvider for LoadBalancer {
168    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
169        let provider = self.select_provider();
170        provider.complete(request).await
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Provider stub that always succeeds and counts how often it was called.
    struct MockProvider {
        id: u32,
        call_count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            Ok(LlmResponse {
                content: format!("Response from provider {}", self.id),
                model: "mock".to_string(),
                usage: Some(Usage {
                    prompt_tokens: 10,
                    completion_tokens: 20,
                    total_tokens: 30,
                }),
                tool_calls: Vec::new(),
            })
        }
    }

    /// Builds a boxed mock provider plus a handle to its call counter.
    fn mock_provider(id: u32) -> (Box<dyn LlmProvider>, Arc<AtomicU32>) {
        let call_count = Arc::new(AtomicU32::new(0));
        let provider = MockProvider {
            id,
            call_count: Arc::clone(&call_count),
        };
        (Box::new(provider), call_count)
    }

    /// Minimal request used by every test.
    fn test_request() -> LlmRequest {
        LlmRequest {
            prompt: "test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        }
    }

    /// Sends `count` requests through `lb`, failing the test on any error
    /// instead of silently discarding the result.
    async fn send_requests(lb: &LoadBalancer, count: usize) {
        for _ in 0..count {
            assert!(lb.complete(test_request()).await.is_ok());
        }
    }

    #[tokio::test]
    async fn test_load_balancer_round_robin() {
        let (provider1, count1) = mock_provider(1);
        let (provider2, count2) = mock_provider(2);
        let (provider3, count3) = mock_provider(3);

        let lb = LoadBalancer::new(vec![provider1, provider2, provider3])
            .with_strategy(LoadBalancingStrategy::RoundRobin);

        // Make 9 requests - should be evenly distributed
        send_requests(&lb, 9).await;

        // Each provider should have received 3 requests
        assert_eq!(count1.load(Ordering::SeqCst), 3);
        assert_eq!(count2.load(Ordering::SeqCst), 3);
        assert_eq!(count3.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_load_balancer_random() {
        let (provider1, count1) = mock_provider(1);
        let (provider2, count2) = mock_provider(2);

        let lb = LoadBalancer::new(vec![provider1, provider2])
            .with_strategy(LoadBalancingStrategy::Random);

        // Make many requests - both providers should receive some
        send_requests(&lb, 100).await;

        let count1_val = count1.load(Ordering::SeqCst);
        let count2_val = count2.load(Ordering::SeqCst);
        assert_eq!(count1_val + count2_val, 100);

        // Neither provider may be starved
        assert!(count1_val > 0);
        assert!(count2_val > 0);
    }

    #[tokio::test]
    async fn test_load_balancer_weighted() {
        let (provider1, count1) = mock_provider(1);
        let (provider2, count2) = mock_provider(2);

        // Provider 1 has weight 3, provider 2 has weight 1
        // So provider 1 should get ~75% of requests
        let lb = LoadBalancer::with_weights(vec![(provider1, 3), (provider2, 1)]);

        // Make many requests
        send_requests(&lb, 1000).await;

        let count1_val = count1.load(Ordering::SeqCst);
        let count2_val = count2.load(Ordering::SeqCst);
        assert_eq!(count1_val + count2_val, 1000);

        // Provider 1 should get roughly 75% of requests (allow some variance)
        assert!(
            count1_val > 650 && count1_val < 850,
            "count1: {}",
            count1_val
        );
        assert!(
            count2_val > 150 && count2_val < 350,
            "count2: {}",
            count2_val
        );
    }

    #[tokio::test]
    async fn test_load_balancer_stats() {
        let (provider1, _) = mock_provider(1);
        let (provider2, _) = mock_provider(2);

        let lb = LoadBalancer::new(vec![provider1, provider2])
            .with_strategy(LoadBalancingStrategy::RoundRobin);

        let stats = lb.get_stats();
        assert_eq!(stats.provider_count, 2);
        assert_eq!(stats.strategy, LoadBalancingStrategy::RoundRobin);
        assert_eq!(stats.total_weight, 2); // Both providers have weight 1

        // Make some requests
        send_requests(&lb, 10).await;

        let stats = lb.get_stats();
        assert_eq!(stats.request_count, 10);
    }

    #[test]
    fn test_load_balancer_provider_count() {
        let (provider1, _) = mock_provider(1);
        let (provider2, _) = mock_provider(2);
        let (provider3, _) = mock_provider(3);

        let lb = LoadBalancer::new(vec![provider1, provider2, provider3]);

        assert_eq!(lb.provider_count(), 3);
    }
}