Skip to main content

loadwise_core/strategy/
weighted_round_robin.rs

1use std::hash::{DefaultHasher, Hash, Hasher};
2use std::sync::Mutex;
3
4use super::{SelectionContext, Strategy};
5use crate::{Node, Weighted};
6
7/// Smooth weighted round-robin (Nginx-style).
8///
9/// Each node maintains a running `current_weight`. On each selection:
10/// 1. Add each node's effective weight to its current weight
11/// 2. Select the node with the highest current weight
12/// 3. Subtract total weight from the selected node's current weight
13///
14/// This produces an interleaved sequence that respects relative weights.
15///
16/// State is tracked by a fingerprint of the candidate set (IDs **and** weights).
17/// If the candidate list changes — nodes added/removed/reordered, or a node's
18/// weight changes — the internal state resets and re-converges within a few rounds.
19pub struct WeightedRoundRobin {
20    state: Mutex<WrrState>,
21}
22
23impl std::fmt::Debug for WeightedRoundRobin {
24    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25        f.debug_struct("WeightedRoundRobin").finish_non_exhaustive()
26    }
27}
28
/// Mutable bookkeeping stored inside [`WeightedRoundRobin`]'s mutex.
struct WrrState {
    /// Running ("current") weight per candidate, indexed like the candidate slice.
    weights: Vec<i64>,
    /// Fingerprint of the candidate list that `weights` was sized for.
    fingerprint: u64,
}
33
34impl WeightedRoundRobin {
35    pub fn new() -> Self {
36        Self {
37            state: Mutex::new(WrrState {
38                fingerprint: 0,
39                weights: Vec::new(),
40            }),
41        }
42    }
43}
44
45impl Default for WeightedRoundRobin {
46    fn default() -> Self {
47        Self::new()
48    }
49}
50
51/// WRR-specific fingerprint that includes node weights, so a weight change
52/// also triggers a state reset.
53fn wrr_fingerprint<N: Weighted + Node>(candidates: &[N]) -> u64 {
54    let mut hasher = DefaultHasher::new();
55    candidates.len().hash(&mut hasher);
56    for node in candidates {
57        node.id().hash(&mut hasher);
58        node.weight().hash(&mut hasher);
59    }
60    hasher.finish()
61}
62
63impl<N: Weighted + Node> Strategy<N> for WeightedRoundRobin {
64    fn select(&self, candidates: &[N], ctx: &SelectionContext) -> Option<usize> {
65        if candidates.is_empty() {
66            return None;
67        }
68
69        let fingerprint = wrr_fingerprint(candidates);
70        let mut state = self.state.lock().unwrap();
71
72        // Reset state when the candidate set changes
73        if state.fingerprint != fingerprint {
74            state.fingerprint = fingerprint;
75            state.weights = vec![0; candidates.len()];
76        }
77
78        let total_weight: i64 = candidates
79            .iter()
80            .enumerate()
81            .filter(|(i, _)| !ctx.is_excluded(*i))
82            .map(|(_, n)| n.weight() as i64)
83            .sum();
84        if total_weight == 0 {
85            return None;
86        }
87
88        let mut best_idx = None;
89        let mut best_weight = i64::MIN;
90
91        for (i, node) in candidates.iter().enumerate() {
92            if ctx.is_excluded(i) {
93                continue;
94            }
95            state.weights[i] += node.weight() as i64;
96            if state.weights[i] > best_weight {
97                best_weight = state.weights[i];
98                best_idx = Some(i);
99            }
100        }
101
102        if let Some(idx) = best_idx {
103            state.weights[idx] -= total_weight;
104        }
105        best_idx
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal test node: a static ID plus a weight.
    struct W {
        id: &'static str,
        weight: u32,
    }

    impl W {
        fn new(id: &'static str, weight: u32) -> Self {
            Self { id, weight }
        }
    }

    impl Node for W {
        type Id = &'static str;
        fn id(&self) -> &&'static str {
            &self.id
        }
    }

    impl Weighted for W {
        fn weight(&self) -> u32 {
            self.weight
        }
    }

    #[test]
    fn respects_weights() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 5), W::new("b", 1), W::new("c", 1)];
        let ctx = SelectionContext::default();

        let mut counts = [0u32; 3];
        for _ in 0..70 {
            let idx = wrr.select(&nodes, &ctx).unwrap();
            counts[idx] += 1;
        }

        // 5:1:1 ratio over 70 rounds = 50:10:10
        assert_eq!(counts[0], 50);
        assert_eq!(counts[1], 10);
        assert_eq!(counts[2], 10);
    }

    #[test]
    fn smooth_distribution() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("x", 2), W::new("y", 1)];
        let ctx = SelectionContext::default();

        let sequence: Vec<usize> = (0..6)
            .map(|_| wrr.select(&nodes, &ctx).unwrap())
            .collect();
        assert_eq!(sequence, vec![0, 1, 0, 0, 1, 0]);
    }

    #[test]
    fn skips_excluded() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 3), W::new("b", 1)];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        // Only node b (index 1) is eligible
        assert_eq!(wrr.select(&nodes, &ctx), Some(1));
    }

    #[test]
    fn all_excluded_returns_none() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 1), W::new("b", 1)];
        let ctx = SelectionContext::builder().exclude(vec![0, 1]).build();
        assert_eq!(wrr.select(&nodes, &ctx), None);
    }

    #[test]
    fn all_zero_weights_returns_none() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 0), W::new("b", 0)];
        let ctx = SelectionContext::default();
        assert_eq!(wrr.select(&nodes, &ctx), None);
    }

    #[test]
    fn single_node_always_selected() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("only", 3)];
        let ctx = SelectionContext::default();
        for _ in 0..5 {
            assert_eq!(wrr.select(&nodes, &ctx), Some(0));
        }
    }

    #[test]
    fn resets_on_candidate_change() {
        let wrr = WeightedRoundRobin::new();
        let ctx = SelectionContext::default();

        let nodes_v1 = [W::new("a", 2), W::new("b", 1)];
        let _ = wrr.select(&nodes_v1, &ctx);
        let _ = wrr.select(&nodes_v1, &ctx);

        // Change candidate set — state should reset, not corrupt
        let nodes_v2 = [W::new("b", 1), W::new("c", 3)];
        let idx = wrr.select(&nodes_v2, &ctx).unwrap();
        // After reset, node with weight 3 should win first round
        assert_eq!(idx, 1);
    }

    #[test]
    fn resets_on_weight_change() {
        let wrr = WeightedRoundRobin::new();
        let ctx = SelectionContext::default();

        // One round with a=2, b=1 picks a and leaves the running weights at [-1, 1].
        let before = [W::new("a", 2), W::new("b", 1)];
        assert_eq!(wrr.select(&before, &ctx), Some(0));

        // Same IDs, but b is re-weighted to 2. Stale running weights would become
        // [1, 3] and pick b; a fresh start gives [2, 2] and the tie goes to a.
        let after = [W::new("a", 2), W::new("b", 2)];
        assert_eq!(wrr.select(&after, &ctx), Some(0));
    }
}