// grapsus_proxy/upstream/maglev.rs

//! Maglev consistent hashing load balancer.
//!
//! Implements Google's Maglev algorithm for consistent hashing with minimal
//! disruption when backends are added or removed. Uses a permutation-based
//! lookup table for O(1) selection.
//!
//! Reference: <https://research.google/pubs/pub44824/>

9use async_trait::async_trait;
10use std::collections::HashMap;
11use std::hash::{Hash, Hasher};
12use std::sync::Arc;
13use tokio::sync::RwLock;
14use tracing::{debug, trace, warn};
15use xxhash_rust::xxh3::xxh3_64;
16
17use grapsus_common::errors::{GrapsusError, GrapsusResult};
18
19use super::{LoadBalancer, RequestContext, TargetSelection, UpstreamTarget};
20
/// Configuration for Maglev consistent hashing.
#[derive(Debug, Clone)]
pub struct MaglevConfig {
    /// Size of the lookup table. Must be prime so that every backend's
    /// permutation stride is coprime with the table size and therefore
    /// visits every slot (default: 65537).
    pub table_size: usize,
    /// How the hash key is extracted from an incoming request.
    pub key_source: MaglevKeySource,
}
29
30impl Default for MaglevConfig {
31    fn default() -> Self {
32        Self {
33            // 65537 is a prime number commonly used in Maglev
34            table_size: 65537,
35            key_source: MaglevKeySource::ClientIp,
36        }
37    }
38}
39
/// Source for extracting the hash key from requests.
///
/// When the chosen attribute is absent from the request context, the
/// balancer hashes a fixed fallback key instead (see `extract_key`).
#[derive(Debug, Clone)]
pub enum MaglevKeySource {
    /// Use the client IP address (default).
    ClientIp,
    /// Use the value of the named header.
    Header(String),
    /// Use the value of the named cookie (parsed from the `cookie` header).
    Cookie(String),
    /// Use the request path.
    Path,
}
52
/// Maglev consistent hashing load balancer.
pub struct MaglevBalancer {
    /// Original (full) target list; lookup-table entries index into this.
    targets: Vec<UpstreamTarget>,
    /// Lookup table mapping hash slot -> index into `targets`
    /// (`None` only when there are no targets).
    lookup_table: Arc<RwLock<Vec<Option<usize>>>>,
    /// Health status keyed by target address (`full_address()`).
    health_status: Arc<RwLock<HashMap<String, bool>>>,
    /// Configuration (table size, key source).
    config: MaglevConfig,
    /// Table generation counter, bumped on each rebuild (for cache invalidation).
    generation: Arc<RwLock<u64>>,
}
66
67impl MaglevBalancer {
68    /// Create a new Maglev balancer
69    pub fn new(targets: Vec<UpstreamTarget>, config: MaglevConfig) -> Self {
70        let mut health_status = HashMap::new();
71        for target in &targets {
72            health_status.insert(target.full_address(), true);
73        }
74
75        let table_size = config.table_size;
76        let balancer = Self {
77            targets,
78            lookup_table: Arc::new(RwLock::new(vec![None; table_size])),
79            health_status: Arc::new(RwLock::new(health_status)),
80            config,
81            generation: Arc::new(RwLock::new(0)),
82        };
83
84        // Build initial lookup table synchronously in a blocking manner
85        // This is fine since we're in construction
86        let targets_clone = balancer.targets.clone();
87        let table_size = balancer.config.table_size;
88        let table = Self::build_lookup_table(&targets_clone, table_size);
89
90        // We need to set the table - use try_write since we just created this
91        if let Ok(mut lookup) = balancer.lookup_table.try_write() {
92            *lookup = table;
93        }
94
95        balancer
96    }
97
98    /// Build the Maglev lookup table using permutation sequences
99    fn build_lookup_table(targets: &[UpstreamTarget], table_size: usize) -> Vec<Option<usize>> {
100        if targets.is_empty() {
101            return vec![None; table_size];
102        }
103
104        let n = targets.len();
105        let m = table_size;
106
107        // Generate permutation for each backend
108        let permutations: Vec<Vec<usize>> = targets
109            .iter()
110            .map(|target| Self::generate_permutation(&target.full_address(), m))
111            .collect();
112
113        // Build lookup table using round-robin across permutations
114        let mut table = vec![None; m];
115        let mut next = vec![0usize; n]; // Next index in each backend's permutation
116        let mut filled = 0;
117
118        while filled < m {
119            for i in 0..n {
120                // Find next empty slot for backend i
121                loop {
122                    let c = permutations[i][next[i]];
123                    next[i] += 1;
124
125                    if table[c].is_none() {
126                        table[c] = Some(i);
127                        filled += 1;
128                        break;
129                    }
130
131                    // Safety check to prevent infinite loop
132                    if next[i] >= m {
133                        next[i] = 0;
134                        break;
135                    }
136                }
137
138                if filled >= m {
139                    break;
140                }
141            }
142        }
143
144        table
145    }
146
147    /// Generate permutation sequence for a backend
148    fn generate_permutation(name: &str, table_size: usize) -> Vec<usize> {
149        let m = table_size;
150
151        // Use two independent hash functions
152        let h1 = xxh3_64(name.as_bytes()) as usize;
153        let h2 = {
154            let mut hasher = std::collections::hash_map::DefaultHasher::new();
155            name.hash(&mut hasher);
156            hasher.finish() as usize
157        };
158
159        // offset and skip for this backend
160        let offset = h1 % m;
161        let skip = (h2 % (m - 1)) + 1; // skip must be non-zero and < m
162
163        // Generate permutation
164        (0..m).map(|i| (offset + i * skip) % m).collect()
165    }
166
167    /// Rebuild lookup table with only healthy targets
168    async fn rebuild_table_for_healthy(&self) {
169        let health = self.health_status.read().await;
170        let healthy_targets: Vec<_> = self
171            .targets
172            .iter()
173            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
174            .cloned()
175            .collect();
176        drop(health);
177
178        if healthy_targets.is_empty() {
179            // Keep existing table to allow fallback
180            return;
181        }
182
183        let table = Self::build_lookup_table(&healthy_targets, self.config.table_size);
184
185        let mut lookup = self.lookup_table.write().await;
186        *lookup = table;
187
188        let mut gen = self.generation.write().await;
189        *gen += 1;
190
191        debug!(
192            healthy_count = healthy_targets.len(),
193            total_count = self.targets.len(),
194            generation = *gen,
195            "Maglev lookup table rebuilt"
196        );
197    }
198
199    /// Extract hash key from request context
200    fn extract_key(&self, context: Option<&RequestContext>) -> String {
201        match &self.config.key_source {
202            MaglevKeySource::ClientIp => context
203                .and_then(|c| c.client_ip.map(|ip| ip.ip().to_string()))
204                .unwrap_or_else(|| "default".to_string()),
205            MaglevKeySource::Header(name) => context
206                .and_then(|c| c.headers.get(name).cloned())
207                .unwrap_or_else(|| "default".to_string()),
208            MaglevKeySource::Cookie(name) => context
209                .and_then(|c| {
210                    c.headers.get("cookie").and_then(|cookies| {
211                        cookies.split(';').find_map(|cookie| {
212                            let (key, value) = cookie.trim().split_once('=')?;
213                            if key == name {
214                                Some(value.to_string())
215                            } else {
216                                None
217                            }
218                        })
219                    })
220                })
221                .unwrap_or_else(|| "default".to_string()),
222            MaglevKeySource::Path => context
223                .map(|c| c.path.clone())
224                .unwrap_or_else(|| "/".to_string()),
225        }
226    }
227
228    /// Get healthy targets for fallback selection
229    async fn get_healthy_targets(&self) -> Vec<&UpstreamTarget> {
230        let health = self.health_status.read().await;
231        self.targets
232            .iter()
233            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
234            .collect()
235    }
236}
237
238#[async_trait]
239impl LoadBalancer for MaglevBalancer {
240    async fn select(&self, context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
241        trace!(
242            total_targets = self.targets.len(),
243            algorithm = "maglev",
244            "Selecting upstream target"
245        );
246
247        // Get healthy targets
248        let health = self.health_status.read().await;
249        let healthy_targets: Vec<_> = self
250            .targets
251            .iter()
252            .enumerate()
253            .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
254            .collect();
255        drop(health);
256
257        if healthy_targets.is_empty() {
258            warn!(
259                total_targets = self.targets.len(),
260                algorithm = "maglev",
261                "No healthy upstream targets available"
262            );
263            return Err(GrapsusError::NoHealthyUpstream);
264        }
265
266        // Extract key and compute hash
267        let key = self.extract_key(context);
268        let hash = xxh3_64(key.as_bytes()) as usize;
269        let table_index = hash % self.config.table_size;
270
271        // Look up in table
272        let lookup = self.lookup_table.read().await;
273        let target_index = lookup[table_index];
274        drop(lookup);
275
276        // Get the target
277        let target = if let Some(idx) = target_index {
278            // Verify the target is still healthy
279            if idx < self.targets.len() {
280                let t = &self.targets[idx];
281                let health = self.health_status.read().await;
282                if *health.get(&t.full_address()).unwrap_or(&true) {
283                    t
284                } else {
285                    // Target unhealthy, fall back to first healthy target
286                    healthy_targets
287                        .first()
288                        .map(|(_, t)| *t)
289                        .ok_or(GrapsusError::NoHealthyUpstream)?
290                }
291            } else {
292                // Index out of bounds, fall back
293                healthy_targets
294                    .first()
295                    .map(|(_, t)| *t)
296                    .ok_or(GrapsusError::NoHealthyUpstream)?
297            }
298        } else {
299            // No entry in table, fall back to first healthy
300            healthy_targets
301                .first()
302                .map(|(_, t)| *t)
303                .ok_or(GrapsusError::NoHealthyUpstream)?
304        };
305
306        trace!(
307            selected_target = %target.full_address(),
308            hash_key = %key,
309            table_index = table_index,
310            healthy_count = healthy_targets.len(),
311            algorithm = "maglev",
312            "Selected target via Maglev consistent hashing"
313        );
314
315        Ok(TargetSelection {
316            address: target.full_address(),
317            weight: target.weight,
318            metadata: HashMap::new(),
319        })
320    }
321
322    async fn report_health(&self, address: &str, healthy: bool) {
323        let prev_health = {
324            let health = self.health_status.read().await;
325            *health.get(address).unwrap_or(&true)
326        };
327
328        if prev_health != healthy {
329            trace!(
330                target = %address,
331                healthy = healthy,
332                algorithm = "maglev",
333                "Target health changed, rebuilding lookup table"
334            );
335
336            self.health_status
337                .write()
338                .await
339                .insert(address.to_string(), healthy);
340
341            // Rebuild table when health changes
342            self.rebuild_table_for_healthy().await;
343        } else {
344            self.health_status
345                .write()
346                .await
347                .insert(address.to_string(), healthy);
348        }
349    }
350
351    async fn healthy_targets(&self) -> Vec<String> {
352        self.health_status
353            .read()
354            .await
355            .iter()
356            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
357            .collect()
358    }
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Build `count` dummy targets named `backend-0..backend-{count-1}`,
    /// all on port 8080 with weight 100.
    fn make_targets(count: usize) -> Vec<UpstreamTarget> {
        (0..count)
            .map(|i| UpstreamTarget::new(format!("backend-{}", i), 8080, 100))
            .collect()
    }

    /// The built table must be completely filled and roughly balanced.
    #[test]
    fn test_build_lookup_table() {
        let targets = make_targets(3);
        let table = MaglevBalancer::build_lookup_table(&targets, 65537);

        // All slots should be filled
        assert!(table.iter().all(|entry| entry.is_some()));

        // Distribution should be roughly even: count slots per backend.
        let mut counts = vec![0usize; 3];
        for idx in table.iter().flatten() {
            counts[*idx] += 1;
        }

        // Each backend should get roughly 1/3 of the slots (within 10%).
        let expected = 65537 / 3;
        for count in counts {
            assert!(
                (count as i64 - expected as i64).abs() < (expected as i64 / 10),
                "Uneven distribution: {} vs expected ~{}",
                count,
                expected
            );
        }
    }

    /// Permutations for distinct backends must differ and each must visit
    /// every slot exactly once (65537 is prime, so any stride works).
    #[test]
    fn test_permutation_generation() {
        let perm1 = MaglevBalancer::generate_permutation("backend-1", 65537);
        let perm2 = MaglevBalancer::generate_permutation("backend-2", 65537);

        // Permutations should be different
        assert_ne!(perm1[0..100], perm2[0..100]);

        // Each permutation should cover all indices
        let mut seen = vec![false; 65537];
        for &idx in &perm1 {
            seen[idx] = true;
        }
        assert!(seen.iter().all(|&s| s));
    }

    /// The same request context must always map to the same target.
    #[tokio::test]
    async fn test_consistent_selection() {
        let targets = make_targets(5);
        let balancer = MaglevBalancer::new(targets, MaglevConfig::default());

        let context = RequestContext {
            client_ip: Some("192.168.1.100:12345".parse().unwrap()),
            headers: HashMap::new(),
            path: "/api/test".to_string(),
            method: "GET".to_string(),
        };

        // Same context should always select same target
        let selection1 = balancer.select(Some(&context)).await.unwrap();
        let selection2 = balancer.select(Some(&context)).await.unwrap();
        let selection3 = balancer.select(Some(&context)).await.unwrap();

        assert_eq!(selection1.address, selection2.address);
        assert_eq!(selection2.address, selection3.address);
    }

    /// Removing one backend should remap only a minority of keys — the
    /// core Maglev property. Thresholds below are deliberately loose
    /// statistical bounds, not exact expectations.
    #[tokio::test]
    async fn test_minimal_disruption() {
        // Test that removing a backend only affects keys mapped to that backend
        let targets = make_targets(5);
        let balancer = MaglevBalancer::new(targets.clone(), MaglevConfig::default());

        // Record selections for many keys (keyed by client IP by default).
        let mut original_selections = HashMap::new();
        for i in 0..1000 {
            let context = RequestContext {
                client_ip: Some(format!("192.168.1.{}:12345", i % 256).parse().unwrap()),
                headers: HashMap::new(),
                path: format!("/api/test/{}", i),
                method: "GET".to_string(),
            };
            let selection = balancer.select(Some(&context)).await.unwrap();
            original_selections.insert(i, selection.address);
        }

        // Mark one backend as unhealthy (triggers a table rebuild).
        balancer.report_health("backend-2:8080", false).await;

        // Check how many selections changed
        let mut changed = 0;
        for i in 0..1000 {
            let context = RequestContext {
                client_ip: Some(format!("192.168.1.{}:12345", i % 256).parse().unwrap()),
                headers: HashMap::new(),
                path: format!("/api/test/{}", i),
                method: "GET".to_string(),
            };
            let selection = balancer.select(Some(&context)).await.unwrap();
            if selection.address != original_selections[&i] {
                changed += 1;
            }
        }

        // When a backend is removed, ideally only ~20% should change.
        // Our current implementation rebuilds the table which causes more
        // disruption, but it should still be less than replacing all keys.
        // With 5 backends -> 4 backends, worst case is 100% change.
        // We expect significantly less than that.
        assert!(
            changed < 800,
            "Too many selections changed: {} (expected less than 800 for 1/5 backend removal)",
            changed
        );

        // And verify that at least some selections are stable
        assert!(
            changed < 1000 - 100,
            "Too few stable selections: only {} unchanged",
            1000 - changed
        );
    }
}