1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::net::SocketAddr;
4use std::sync::Mutex;
5use std::sync::atomic::{AtomicUsize, Ordering};
6
7use http::HeaderMap;
8
9use super::upstream::UpstreamPool;
10
11pub struct LbContext {
18 pub client_addr: SocketAddr,
19 pub uri: String,
20 pub headers: HeaderMap,
21}
22
23pub trait LoadBalancer: Send + Sync {
31 fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize>;
33}
34
/// Round-robin balancer driven by a shared atomic counter.
#[derive(Default)]
pub struct RoundRobinLb {
    /// Shared rotation counter; starts at zero.
    counter: AtomicUsize,
}

impl RoundRobinLb {
    /// Creates a balancer whose counter starts at zero.
    pub fn new() -> Self {
        Self::default()
    }
}
57
58impl LoadBalancer for RoundRobinLb {
59 fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
60 let n = pool.len();
61 if n == 0 {
62 return None;
63 }
64 for _ in 0..n {
66 let idx = self.counter.fetch_add(1, Ordering::Relaxed) % n;
67 if pool.is_healthy(idx) {
68 return Some(idx);
69 }
70 }
71 None
72 }
73}
74
/// Stateless balancer that picks a random healthy upstream per request.
#[derive(Default)]
pub struct RandomLb;

impl RandomLb {
    /// Creates the balancer; it carries no state.
    pub fn new() -> Self {
        Self::default()
    }
}
93
94impl LoadBalancer for RandomLb {
95 fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
96 use rand::prelude::IndexedRandom;
97
98 let healthy_indices: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
99 if healthy_indices.is_empty() {
100 return None;
101 }
102 let mut rng = rand::rng();
103 healthy_indices.choose(&mut rng).copied()
104 }
105}
106
/// Weighted round-robin balancer.
///
/// Keeps one scoring entry per configured weight behind a mutex.
pub struct WeightedRoundRobinLb {
    /// One entry per weight passed to [`WeightedRoundRobinLb::new`], in order.
    state: Mutex<Vec<WrrEntry>>,
}

/// Per-upstream bookkeeping for the weighted rotation.
struct WrrEntry {
    /// Configured weight, widened from `u32`.
    effective_weight: i64,
    /// Running score; starts at zero.
    current_weight: i64,
}

impl WeightedRoundRobinLb {
    /// Builds the balancer from per-upstream weights.
    ///
    /// Every entry starts with a score of zero. Position `i` in `weights` is
    /// presumably meant to correspond to upstream `i` of the pool — confirm
    /// with the caller.
    pub fn new(weights: &[u32]) -> Self {
        let mut entries = Vec::with_capacity(weights.len());
        for &weight in weights {
            entries.push(WrrEntry {
                effective_weight: i64::from(weight),
                current_weight: 0,
            });
        }
        Self {
            state: Mutex::new(entries),
        }
    }
}
141
142impl LoadBalancer for WeightedRoundRobinLb {
143 fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
144 let mut entries = self.state.lock().ok()?;
145 if entries.is_empty() {
146 return None;
147 }
148
149 let mut total: i64 = 0;
150 let mut best_idx: Option<usize> = None;
151 let mut best_weight: i64 = i64::MIN;
152
153 for (i, entry) in entries.iter_mut().enumerate() {
154 if !pool.is_healthy(i) {
155 continue;
156 }
157 entry.current_weight += entry.effective_weight;
158 total += entry.effective_weight;
159
160 if entry.current_weight > best_weight {
161 best_weight = entry.current_weight;
162 best_idx = Some(i);
163 }
164 }
165
166 if let Some(idx) = best_idx {
167 entries[idx].current_weight -= total;
168 }
169
170 best_idx
171 }
172}
173
/// Stateless balancer that hashes the client IP onto the healthy upstreams.
///
/// A given client keeps the same upstream while the healthy set is unchanged.
#[derive(Default)]
pub struct IpHashLb;

impl IpHashLb {
    /// Creates the balancer; it carries no state.
    pub fn new() -> Self {
        Self::default()
    }
}
192
193impl LoadBalancer for IpHashLb {
194 fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize> {
195 let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
196 if healthy.is_empty() {
197 return None;
198 }
199 let hash = hash_value(&ctx.client_addr.ip().to_string());
200 Some(healthy[hash as usize % healthy.len()])
201 }
202}
203
/// Stateless balancer that prefers the healthy upstream with the fewest
/// connections.
#[derive(Default)]
pub struct LeastConnLb;

impl LeastConnLb {
    /// Creates the balancer; it carries no state.
    pub fn new() -> Self {
        Self::default()
    }
}
222
223impl LoadBalancer for LeastConnLb {
224 fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
225 let mut best_idx: Option<usize> = None;
226 let mut best_count = usize::MAX;
227
228 for i in 0..pool.len() {
229 if !pool.is_healthy(i) {
230 continue;
231 }
232 let count = pool.conn_count(i);
233 if count < best_count {
234 best_count = count;
235 best_idx = Some(i);
236 }
237 }
238
239 best_idx
240 }
241}
242
/// Stateless balancer that hashes the request URI onto the healthy upstreams.
///
/// Identical URIs keep the same upstream while the healthy set is unchanged.
#[derive(Default)]
pub struct UriHashLb;

impl UriHashLb {
    /// Creates the balancer; it carries no state.
    pub fn new() -> Self {
        Self::default()
    }
}
262
263impl LoadBalancer for UriHashLb {
264 fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize> {
265 let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
266 if healthy.is_empty() {
267 return None;
268 }
269 let hash = hash_value(&ctx.uri);
270 Some(healthy[hash as usize % healthy.len()])
271 }
272}
273
/// Balancer that hashes the value of one configured request header onto the
/// healthy upstreams.
pub struct HeaderHashLb {
    /// Name of the request header whose value is hashed.
    header_name: String,
}

impl HeaderHashLb {
    /// Creates a balancer keyed on the header named `header_name`.
    pub fn new(header_name: String) -> Self {
        Self { header_name }
    }
}
288
289impl LoadBalancer for HeaderHashLb {
290 fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize> {
291 let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
292 if healthy.is_empty() {
293 return None;
294 }
295
296 let value = ctx
297 .headers
298 .get(&self.header_name)
299 .and_then(|v| v.to_str().ok())
300 .unwrap_or("");
301
302 let hash = hash_value(value);
303 Some(healthy[hash as usize % healthy.len()])
304 }
305}
306
/// Balancer that hashes the value of one named cookie onto the healthy
/// upstreams.
pub struct CookieHashLb {
    /// Name of the cookie whose value is hashed (matched case-sensitively).
    cookie_name: String,
}

impl CookieHashLb {
    /// Creates a balancer keyed on the cookie named `cookie_name`.
    pub fn new(cookie_name: String) -> Self {
        Self { cookie_name }
    }
}
321
322impl LoadBalancer for CookieHashLb {
323 fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize> {
324 let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
325 if healthy.is_empty() {
326 return None;
327 }
328
329 let cookie_value = extract_cookie(&ctx.headers, &self.cookie_name).unwrap_or_default();
330 let hash = hash_value(&cookie_value);
331 Some(healthy[hash as usize % healthy.len()])
332 }
333}
334
/// Stateless balancer that always prefers the lowest-index healthy upstream.
#[derive(Default)]
pub struct FirstLb;

impl FirstLb {
    /// Creates the balancer; it carries no state.
    pub fn new() -> Self {
        Self::default()
    }
}
353
354impl LoadBalancer for FirstLb {
355 fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
356 (0..pool.len()).find(|&i| pool.is_healthy(i))
357 }
358}
359
/// Stateless "power of two choices" balancer.
///
/// Samples two distinct healthy upstreams at random and keeps the one with
/// fewer connections.
#[derive(Default)]
pub struct TwoRandomChoicesLb;

impl TwoRandomChoicesLb {
    /// Creates the balancer; it carries no state.
    pub fn new() -> Self {
        Self::default()
    }
}
386
387impl LoadBalancer for TwoRandomChoicesLb {
388 fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
389 use rand::prelude::IndexedRandom;
390
391 let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
392 match healthy.len() {
393 0 => None,
394 1 => Some(healthy[0]),
395 _ => {
396 let mut rng = rand::rng();
397 let candidates: Vec<usize> = healthy.sample(&mut rng, 2).copied().collect();
399 let a = candidates[0];
400 let b = candidates[1];
401 if pool.conn_count(a) <= pool.conn_count(b) {
403 Some(a)
404 } else {
405 Some(b)
406 }
407 }
408 }
409 }
410}
411
/// Hashes `value` with the standard library's `DefaultHasher`.
///
/// Hashers from `DefaultHasher::new()` all start from the same state, so the
/// result is deterministic within one build.
///
/// NOTE(review): std documents the algorithm as unspecified and subject to
/// change between Rust releases. Proxies built with different toolchains may
/// therefore route the same key to different upstreams.
fn hash_value(value: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
422
423fn extract_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
425 for value in headers.get_all(http::header::COOKIE) {
426 let Ok(cookie_str) = value.to_str() else {
427 continue;
428 };
429 for pair in cookie_str.split(';') {
430 let pair = pair.trim();
431 if let Some((k, v)) = pair.split_once('=')
432 && k.trim() == name
433 {
434 return Some(v.trim().to_string());
435 }
436 }
437 }
438 None
439}