Skip to main content

gatel_core/proxy/
lb.rs

1use std::collections::hash_map::DefaultHasher;
2use std::hash::{Hash, Hasher};
3use std::net::SocketAddr;
4use std::sync::Mutex;
5use std::sync::atomic::{AtomicUsize, Ordering};
6
7use http::HeaderMap;
8
9use super::upstream::UpstreamPool;
10
11// ---------------------------------------------------------------------------
12// Context passed to load balancers that need request-level information
13// ---------------------------------------------------------------------------
14
15/// Contextual information about the current request, made available to
16/// load-balancing strategies that need more than the pool itself.
17pub struct LbContext {
18    pub client_addr: SocketAddr,
19    pub uri: String,
20    pub headers: HeaderMap,
21}
22
23// ---------------------------------------------------------------------------
24// Trait
25// ---------------------------------------------------------------------------
26
27/// Selects an upstream backend index from the pool.
28///
29/// Implementations must be `Send + Sync` so they can be shared across tasks.
30pub trait LoadBalancer: Send + Sync {
31    /// Choose a backend index.  Returns `None` when no backend is available.
32    fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize>;
33}
34
35// ---------------------------------------------------------------------------
36// Round-Robin
37// ---------------------------------------------------------------------------
38
/// Round-robin load balancer that cycles through backends in index order.
///
/// The rotation cursor is an atomic counter, so one instance can be shared by
/// many tasks without a lock.
#[derive(Default)]
pub struct RoundRobinLb {
    /// Wrapping probe counter; reduced modulo the pool size at selection time.
    counter: AtomicUsize,
}

impl RoundRobinLb {
    /// Creates a balancer whose first probe is backend 0.
    pub fn new() -> Self {
        Self::default()
    }
}
57
58impl LoadBalancer for RoundRobinLb {
59    fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
60        let n = pool.len();
61        if n == 0 {
62            return None;
63        }
64        // Try up to `n` times to find a healthy backend.
65        for _ in 0..n {
66            let idx = self.counter.fetch_add(1, Ordering::Relaxed) % n;
67            if pool.is_healthy(idx) {
68                return Some(idx);
69            }
70        }
71        None
72    }
73}
74
75// ---------------------------------------------------------------------------
76// Random
77// ---------------------------------------------------------------------------
78
/// Load balancer that picks one of the currently healthy backends at random.
///
/// Stateless: it holds no data of its own.
#[derive(Default)]
pub struct RandomLb;

impl RandomLb {
    /// Creates a new random balancer.
    pub fn new() -> Self {
        Self
    }
}
93
94impl LoadBalancer for RandomLb {
95    fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
96        use rand::prelude::IndexedRandom;
97
98        let healthy_indices: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
99        if healthy_indices.is_empty() {
100            return None;
101        }
102        let mut rng = rand::rng();
103        healthy_indices.choose(&mut rng).copied()
104    }
105}
106
107// ---------------------------------------------------------------------------
108// Weighted Round-Robin (smooth, nginx-style)
109// ---------------------------------------------------------------------------
110
/// Smooth weighted round-robin, following the algorithm used by nginx.
///
/// Each backend has:
///   - `effective_weight`: the configured weight. nginx lowers this on
///     transient errors and lets it recover, but nothing in this module
///     adjusts it after construction, so it always equals the configured
///     weight.
///   - `current_weight`: on every selection, each healthy backend's
///     `current_weight` grows by its `effective_weight`. The backend with the
///     highest value is chosen and then has the sum of the healthy backends'
///     effective weights subtracted from it.
///
/// Weights are matched to backends by position. NOTE(review): the slice given
/// to [`WeightedRoundRobinLb::new`] is presumably built in the same order as
/// the pool's backends; confirm at the caller.
pub struct WeightedRoundRobinLb {
    state: Mutex<Vec<WrrEntry>>,
}

/// Per-backend bookkeeping for [`WeightedRoundRobinLb`].
struct WrrEntry {
    /// Amount added to `current_weight` each selection round.
    effective_weight: i64,
    /// Running score; the healthy backend with the highest score wins a round.
    current_weight: i64,
}

impl WeightedRoundRobinLb {
    /// Builds the balancer from per-backend weights: `weights[i]` is the weight
    /// of backend `i`. All running scores start at zero.
    pub fn new(weights: &[u32]) -> Self {
        let state = weights
            .iter()
            .map(|&w| WrrEntry {
                // Lossless widening; `i64` leaves headroom for the signed
                // running score.
                effective_weight: i64::from(w),
                current_weight: 0,
            })
            .collect();
        Self {
            state: Mutex::new(state),
        }
    }
}
141
142impl LoadBalancer for WeightedRoundRobinLb {
143    fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
144        let mut entries = self.state.lock().ok()?;
145        if entries.is_empty() {
146            return None;
147        }
148
149        let mut total: i64 = 0;
150        let mut best_idx: Option<usize> = None;
151        let mut best_weight: i64 = i64::MIN;
152
153        for (i, entry) in entries.iter_mut().enumerate() {
154            if !pool.is_healthy(i) {
155                continue;
156            }
157            entry.current_weight += entry.effective_weight;
158            total += entry.effective_weight;
159
160            if entry.current_weight > best_weight {
161                best_weight = entry.current_weight;
162                best_idx = Some(i);
163            }
164        }
165
166        if let Some(idx) = best_idx {
167            entries[idx].current_weight -= total;
168        }
169
170        best_idx
171    }
172}
173
174// ---------------------------------------------------------------------------
175// IP-Hash
176// ---------------------------------------------------------------------------
177
/// Session-affinity balancer keyed on the client's IP address.
///
/// The client IP (port excluded) is hashed and mapped onto the list of
/// currently healthy backends by modulo. A client therefore sticks to one
/// backend while the healthy set is unchanged. When the healthy set changes
/// size, most clients are remapped, because this is plain modulo hashing
/// rather than consistent hashing.
#[derive(Default)]
pub struct IpHashLb;

impl IpHashLb {
    /// Creates a new (stateless) IP-hash balancer.
    pub fn new() -> Self {
        Self
    }
}
192
193impl LoadBalancer for IpHashLb {
194    fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize> {
195        let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
196        if healthy.is_empty() {
197            return None;
198        }
199        let hash = hash_value(&ctx.client_addr.ip().to_string());
200        Some(healthy[hash as usize % healthy.len()])
201    }
202}
203
204// ---------------------------------------------------------------------------
205// Least Connections
206// ---------------------------------------------------------------------------
207
/// Least-connections balancer: routes to the healthy backend that currently
/// reports the fewest active connections.
#[derive(Default)]
pub struct LeastConnLb;

impl LeastConnLb {
    /// Creates a new (stateless) least-connections balancer.
    pub fn new() -> Self {
        Self
    }
}
222
223impl LoadBalancer for LeastConnLb {
224    fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
225        let mut best_idx: Option<usize> = None;
226        let mut best_count = usize::MAX;
227
228        for i in 0..pool.len() {
229            if !pool.is_healthy(i) {
230                continue;
231            }
232            let count = pool.conn_count(i);
233            if count < best_count {
234                best_count = count;
235                best_idx = Some(i);
236            }
237        }
238
239        best_idx
240    }
241}
242
243// ---------------------------------------------------------------------------
244// URI-Hash
245// ---------------------------------------------------------------------------
246
/// Balancer keyed on the request URI string, so repeated requests for the same
/// resource hit the same backend while the healthy set is unchanged (useful
/// for cache locality).
///
/// NOTE(review): the hashed value is [`LbContext::uri`] verbatim; whether it
/// is only the path or also carries the query string depends on how the
/// caller fills it in. Confirm at the construction site.
#[derive(Default)]
pub struct UriHashLb;

impl UriHashLb {
    /// Creates a new (stateless) URI-hash balancer.
    pub fn new() -> Self {
        Self
    }
}
262
263impl LoadBalancer for UriHashLb {
264    fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize> {
265        let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
266        if healthy.is_empty() {
267            return None;
268        }
269        let hash = hash_value(&ctx.uri);
270        Some(healthy[hash as usize % healthy.len()])
271    }
272}
273
274// ---------------------------------------------------------------------------
275// Header-Hash
276// ---------------------------------------------------------------------------
277
/// Session-affinity balancer keyed on the value of one request header.
pub struct HeaderHashLb {
    /// Name of the header whose value is hashed.
    header_name: String,
}

impl HeaderHashLb {
    /// Creates a balancer that hashes the header called `header_name`.
    pub fn new(header_name: String) -> Self {
        HeaderHashLb { header_name }
    }
}
288
289impl LoadBalancer for HeaderHashLb {
290    fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize> {
291        let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
292        if healthy.is_empty() {
293            return None;
294        }
295
296        let value = ctx
297            .headers
298            .get(&self.header_name)
299            .and_then(|v| v.to_str().ok())
300            .unwrap_or("");
301
302        let hash = hash_value(value);
303        Some(healthy[hash as usize % healthy.len()])
304    }
305}
306
307// ---------------------------------------------------------------------------
308// Cookie-Hash
309// ---------------------------------------------------------------------------
310
/// Session-affinity balancer keyed on the value of one cookie.
pub struct CookieHashLb {
    /// Name of the cookie whose value is hashed.
    cookie_name: String,
}

impl CookieHashLb {
    /// Creates a balancer that hashes the cookie called `cookie_name`.
    pub fn new(cookie_name: String) -> Self {
        CookieHashLb { cookie_name }
    }
}
321
322impl LoadBalancer for CookieHashLb {
323    fn select(&self, pool: &UpstreamPool, ctx: &LbContext) -> Option<usize> {
324        let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
325        if healthy.is_empty() {
326            return None;
327        }
328
329        let cookie_value = extract_cookie(&ctx.headers, &self.cookie_name).unwrap_or_default();
330        let hash = hash_value(&cookie_value);
331        Some(healthy[hash as usize % healthy.len()])
332    }
333}
334
335// ---------------------------------------------------------------------------
336// First (always pick the first healthy backend)
337// ---------------------------------------------------------------------------
338
/// Active/standby failover: always routes to the lowest-indexed healthy
/// backend. Later backends only receive traffic while every earlier one is
/// unhealthy.
#[derive(Default)]
pub struct FirstLb;

impl FirstLb {
    /// Creates a new (stateless) first-healthy balancer.
    pub fn new() -> Self {
        Self
    }
}
353
354impl LoadBalancer for FirstLb {
355    fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
356        (0..pool.len()).find(|&i| pool.is_healthy(i))
357    }
358}
359
360// ---------------------------------------------------------------------------
361// Two Random Choices
362// ---------------------------------------------------------------------------
363
/// Two Random Choices ("power of two choices") load balancer.
///
/// Samples two distinct healthy backends at random and picks the one with
/// fewer active connections; on a tie, the first sampled backend wins. Only
/// two connection counts are read per selection, whereas [`LeastConnLb`] reads
/// the count of every healthy backend.
///
/// Building the candidate list still scans every backend's health and
/// allocates a vector, so selection is O(n) in the pool size, not O(1).
///
/// - 0 healthy backends → `None`
/// - 1 healthy backend  → use it directly
/// - 2+ healthy backends → pick 2 at random, choose the one with fewer connections
#[derive(Default)]
pub struct TwoRandomChoicesLb;

impl TwoRandomChoicesLb {
    /// Creates a new (stateless) two-random-choices balancer.
    pub fn new() -> Self {
        Self
    }
}
386
387impl LoadBalancer for TwoRandomChoicesLb {
388    fn select(&self, pool: &UpstreamPool, _ctx: &LbContext) -> Option<usize> {
389        use rand::prelude::IndexedRandom;
390
391        let healthy: Vec<usize> = (0..pool.len()).filter(|&i| pool.is_healthy(i)).collect();
392        match healthy.len() {
393            0 => None,
394            1 => Some(healthy[0]),
395            _ => {
396                let mut rng = rand::rng();
397                // Sample two distinct candidates.
398                let candidates: Vec<usize> = healthy.sample(&mut rng, 2).copied().collect();
399                let a = candidates[0];
400                let b = candidates[1];
401                // Pick the backend with fewer active connections.
402                if pool.conn_count(a) <= pool.conn_count(b) {
403                    Some(a)
404                } else {
405                    Some(b)
406                }
407            }
408        }
409    }
410}
411
412// ---------------------------------------------------------------------------
413// Helpers
414// ---------------------------------------------------------------------------
415
/// Hashes `value` with the standard library's [`DefaultHasher`].
///
/// Every `DefaultHasher::new()` instance starts from the same state, so the
/// result is stable for a given binary. That is what the hash-based balancers
/// rely on for affinity.
///
/// The std docs do not guarantee the algorithm across Rust releases.
/// Backend mappings may therefore shift after a toolchain upgrade, and proxies
/// built with different toolchains may disagree.
fn hash_value(value: &str) -> u64 {
    let mut state = DefaultHasher::new();
    Hash::hash(value, &mut state);
    state.finish()
}
422
423/// Extract a cookie value from the `Cookie` header(s).
424fn extract_cookie(headers: &HeaderMap, name: &str) -> Option<String> {
425    for value in headers.get_all(http::header::COOKIE) {
426        let Ok(cookie_str) = value.to_str() else {
427            continue;
428        };
429        for pair in cookie_str.split(';') {
430            let pair = pair.trim();
431            if let Some((k, v)) = pair.split_once('=')
432                && k.trim() == name
433            {
434                return Some(v.trim().to_string());
435            }
436        }
437    }
438    None
439}