1use crate::{LlmError, LlmProvider, LlmRequest, LlmResponse};
27use async_trait::async_trait;
28use std::sync::atomic::{AtomicUsize, Ordering};
29use std::sync::Arc;
30
/// Policy a [`LoadBalancer`] uses to decide which provider serves the next request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LoadBalancingStrategy {
    /// Cycle through the providers in the order they were registered.
    RoundRobin,
    /// Pick providers from a deterministic pseudo-random sequence derived from
    /// the balancer's request counter.
    Random,
    /// Pick providers in proportion to their configured weights.
    Weighted,
}
41
42pub struct LoadBalancer {
44 providers: Vec<ProviderWithWeight>,
45 strategy: LoadBalancingStrategy,
46 counter: Arc<AtomicUsize>,
47}
48
49struct ProviderWithWeight {
50 provider: Box<dyn LlmProvider>,
51 weight: u32,
52}
53
54impl LoadBalancer {
55 pub fn new(providers: Vec<Box<dyn LlmProvider>>) -> Self {
57 if providers.is_empty() {
58 panic!("LoadBalancer requires at least one provider");
59 }
60
61 let providers_with_weight = providers
62 .into_iter()
63 .map(|p| ProviderWithWeight {
64 provider: p,
65 weight: 1,
66 })
67 .collect();
68
69 Self {
70 providers: providers_with_weight,
71 strategy: LoadBalancingStrategy::RoundRobin,
72 counter: Arc::new(AtomicUsize::new(0)),
73 }
74 }
75
76 pub fn with_weights(providers: Vec<(Box<dyn LlmProvider>, u32)>) -> Self {
78 if providers.is_empty() {
79 panic!("LoadBalancer requires at least one provider");
80 }
81
82 let providers_with_weight = providers
83 .into_iter()
84 .map(|(p, w)| ProviderWithWeight {
85 provider: p,
86 weight: w,
87 })
88 .collect();
89
90 Self {
91 providers: providers_with_weight,
92 strategy: LoadBalancingStrategy::Weighted,
93 counter: Arc::new(AtomicUsize::new(0)),
94 }
95 }
96
97 pub fn with_strategy(mut self, strategy: LoadBalancingStrategy) -> Self {
99 self.strategy = strategy;
100 self
101 }
102
103 pub fn provider_count(&self) -> usize {
105 self.providers.len()
106 }
107
108 pub fn get_stats(&self) -> LoadBalancerStats {
110 let total_weight: u32 = self.providers.iter().map(|p| p.weight).sum();
111 LoadBalancerStats {
112 provider_count: self.providers.len(),
113 strategy: self.strategy,
114 total_weight,
115 request_count: self.counter.load(Ordering::SeqCst),
116 }
117 }
118
119 fn select_provider(&self) -> &dyn LlmProvider {
121 match self.strategy {
122 LoadBalancingStrategy::RoundRobin => {
123 let index = self.counter.fetch_add(1, Ordering::SeqCst);
124 &*self.providers[index % self.providers.len()].provider
125 }
126 LoadBalancingStrategy::Random => {
127 let index = self.counter.fetch_add(1, Ordering::SeqCst);
129 let pseudo_random = index.wrapping_mul(2654435761);
131 &*self.providers[pseudo_random % self.providers.len()].provider
132 }
133 LoadBalancingStrategy::Weighted => {
134 let total_weight: u32 = self.providers.iter().map(|p| p.weight).sum();
135 let counter = self.counter.fetch_add(1, Ordering::SeqCst);
136 let mut target_weight = (counter as u32) % total_weight;
138
139 for provider in &self.providers {
140 if target_weight < provider.weight {
141 return &*provider.provider;
142 }
143 target_weight -= provider.weight;
144 }
145
146 &*self.providers[0].provider
148 }
149 }
150 }
151}
152
153#[derive(Debug, Clone)]
155pub struct LoadBalancerStats {
156 pub provider_count: usize,
158 pub strategy: LoadBalancingStrategy,
160 pub total_weight: u32,
162 pub request_count: usize,
164}
165
166#[async_trait]
167impl LlmProvider for LoadBalancer {
168 async fn complete(&self, request: LlmRequest) -> Result<LlmResponse, LlmError> {
169 let provider = self.select_provider();
170 provider.complete(request).await
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Usage;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Test double that counts its calls and tags responses with its id.
    struct MockProvider {
        id: u32,
        call_count: Arc<AtomicU32>,
    }

    #[async_trait]
    impl LlmProvider for MockProvider {
        async fn complete(&self, _request: LlmRequest) -> Result<LlmResponse, LlmError> {
            self.call_count.fetch_add(1, Ordering::SeqCst);
            let usage = Usage {
                prompt_tokens: 10,
                completion_tokens: 20,
                total_tokens: 30,
            };
            Ok(LlmResponse {
                content: format!("Response from provider {}", self.id),
                model: "mock".to_string(),
                usage: Some(usage),
                tool_calls: Vec::new(),
            })
        }
    }

    /// Builds a mock provider together with a handle to its call counter.
    fn mock(id: u32) -> (MockProvider, Arc<AtomicU32>) {
        let calls = Arc::new(AtomicU32::new(0));
        let provider = MockProvider {
            id,
            call_count: Arc::clone(&calls),
        };
        (provider, calls)
    }

    /// Minimal request; its contents are irrelevant to load balancing.
    fn sample_request() -> LlmRequest {
        LlmRequest {
            prompt: "test".to_string(),
            system_prompt: None,
            temperature: None,
            max_tokens: None,
            tools: Vec::new(),
            images: Vec::new(),
        }
    }

    /// Sends `n` requests through `lb`, discarding the responses.
    async fn send_requests(lb: &LoadBalancer, n: usize) {
        for _ in 0..n {
            let _ = lb.complete(sample_request()).await;
        }
    }

    /// Current value of a mock's call counter.
    fn calls(counter: &AtomicU32) -> u32 {
        counter.load(Ordering::SeqCst)
    }

    #[tokio::test]
    async fn test_load_balancer_round_robin() {
        let (p1, c1) = mock(1);
        let (p2, c2) = mock(2);
        let (p3, c3) = mock(3);
        let lb = LoadBalancer::new(vec![Box::new(p1), Box::new(p2), Box::new(p3)])
            .with_strategy(LoadBalancingStrategy::RoundRobin);

        send_requests(&lb, 9).await;

        // Nine requests over three providers: each one serves exactly three.
        for counter in [&c1, &c2, &c3] {
            assert_eq!(calls(counter), 3);
        }
    }

    #[tokio::test]
    async fn test_load_balancer_random() {
        let (p1, c1) = mock(1);
        let (p2, c2) = mock(2);
        let lb = LoadBalancer::new(vec![Box::new(p1), Box::new(p2)])
            .with_strategy(LoadBalancingStrategy::Random);

        send_requests(&lb, 100).await;

        let (n1, n2) = (calls(&c1), calls(&c2));
        assert_eq!(n1 + n2, 100);
        // Both providers should see at least some traffic.
        assert!(n1 > 0);
        assert!(n2 > 0);
    }

    #[tokio::test]
    async fn test_load_balancer_weighted() {
        let (p1, c1) = mock(1);
        let (p2, c2) = mock(2);
        let lb = LoadBalancer::with_weights(vec![(Box::new(p1), 3), (Box::new(p2), 1)]);

        send_requests(&lb, 1000).await;

        let (n1, n2) = (calls(&c1), calls(&c2));
        assert_eq!(n1 + n2, 1000);
        // Weights 3:1 imply roughly a 750/250 split; allow some slack.
        assert!((651..850).contains(&n1), "count1: {}", n1);
        assert!((151..350).contains(&n2), "count2: {}", n2);
    }

    #[tokio::test]
    async fn test_load_balancer_stats() {
        let (p1, _) = mock(1);
        let (p2, _) = mock(2);
        let lb = LoadBalancer::new(vec![Box::new(p1), Box::new(p2)])
            .with_strategy(LoadBalancingStrategy::RoundRobin);

        let before = lb.get_stats();
        assert_eq!(before.provider_count, 2);
        assert_eq!(before.strategy, LoadBalancingStrategy::RoundRobin);
        assert_eq!(before.total_weight, 2);

        send_requests(&lb, 10).await;

        assert_eq!(lb.get_stats().request_count, 10);
    }

    #[test]
    fn test_load_balancer_provider_count() {
        let providers: Vec<Box<dyn LlmProvider>> = (1..=3)
            .map(|id| Box::new(mock(id).0) as Box<dyn LlmProvider>)
            .collect();
        let lb = LoadBalancer::new(providers);

        assert_eq!(lb.provider_count(), 3);
    }
}