loadwise_core/strategy/
weighted_round_robin.rs1use std::hash::{DefaultHasher, Hash, Hasher};
2use std::sync::Mutex;
3
4use super::{SelectionContext, Strategy};
5use crate::{Node, Weighted};
6
7pub struct WeightedRoundRobin {
20 state: Mutex<WrrState>,
21}
22
23impl std::fmt::Debug for WeightedRoundRobin {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 f.debug_struct("WeightedRoundRobin").finish_non_exhaustive()
26 }
27}
28
/// Mutable scheduling state guarded by the strategy's mutex.
struct WrrState {
    /// Fingerprint of the candidate list that `weights` was built for; when a
    /// later call sees a different fingerprint the counters are discarded.
    fingerprint: u64,
    /// Running ("current") weight for each candidate, indexed like the
    /// candidate slice.
    weights: Vec<i64>,
}
33
34impl WeightedRoundRobin {
35 pub fn new() -> Self {
36 Self {
37 state: Mutex::new(WrrState {
38 fingerprint: 0,
39 weights: Vec::new(),
40 }),
41 }
42 }
43}
44
45impl Default for WeightedRoundRobin {
46 fn default() -> Self {
47 Self::new()
48 }
49}
50
51fn wrr_fingerprint<N: Weighted + Node>(candidates: &[N]) -> u64 {
54 let mut hasher = DefaultHasher::new();
55 candidates.len().hash(&mut hasher);
56 for node in candidates {
57 node.id().hash(&mut hasher);
58 node.weight().hash(&mut hasher);
59 }
60 hasher.finish()
61}
62
63impl<N: Weighted + Node> Strategy<N> for WeightedRoundRobin {
64 fn select(&self, candidates: &[N], ctx: &SelectionContext) -> Option<usize> {
65 if candidates.is_empty() {
66 return None;
67 }
68
69 let fingerprint = wrr_fingerprint(candidates);
70 let mut state = self.state.lock().unwrap();
71
72 if state.fingerprint != fingerprint {
74 state.fingerprint = fingerprint;
75 state.weights = vec![0; candidates.len()];
76 }
77
78 let total_weight: i64 = candidates
79 .iter()
80 .enumerate()
81 .filter(|(i, _)| !ctx.is_excluded(*i))
82 .map(|(_, n)| n.weight() as i64)
83 .sum();
84 if total_weight == 0 {
85 return None;
86 }
87
88 let mut best_idx = None;
89 let mut best_weight = i64::MIN;
90
91 for (i, node) in candidates.iter().enumerate() {
92 if ctx.is_excluded(i) {
93 continue;
94 }
95 state.weights[i] += node.weight() as i64;
96 if state.weights[i] > best_weight {
97 best_weight = state.weights[i];
98 best_idx = Some(i);
99 }
100 }
101
102 if let Some(idx) = best_idx {
103 state.weights[idx] -= total_weight;
104 }
105 best_idx
106 }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal weighted node used to drive the strategy in tests.
    struct W {
        id: &'static str,
        weight: u32,
    }

    impl W {
        fn new(id: &'static str, weight: u32) -> Self {
            Self { id, weight }
        }
    }

    impl Node for W {
        type Id = &'static str;
        fn id(&self) -> &&'static str {
            &self.id
        }
    }

    impl Weighted for W {
        fn weight(&self) -> u32 {
            self.weight
        }
    }

    #[test]
    fn respects_weights() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 5), W::new("b", 1), W::new("c", 1)];
        let ctx = SelectionContext::default();

        // 70 picks = 10 full cycles of the total weight 7.
        let mut counts = [0u32; 3];
        for _ in 0..70 {
            let idx = wrr.select(&nodes, &ctx).unwrap();
            counts[idx] += 1;
        }

        assert_eq!(counts[0], 50);
        assert_eq!(counts[1], 10);
        assert_eq!(counts[2], 10);
    }

    #[test]
    fn smooth_distribution() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("x", 2), W::new("y", 1)];
        let ctx = SelectionContext::default();

        // The heavier node is interleaved with the lighter one, not bunched.
        let sequence: Vec<usize> = (0..6)
            .map(|_| wrr.select(&nodes, &ctx).unwrap())
            .collect();
        assert_eq!(sequence, vec![0, 1, 0, 0, 1, 0]);
    }

    #[test]
    fn skips_excluded() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 3), W::new("b", 1)];
        let ctx = SelectionContext::builder().exclude(vec![0]).build();
        assert_eq!(wrr.select(&nodes, &ctx), Some(1));
    }

    #[test]
    fn all_excluded_returns_none() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 1), W::new("b", 1)];
        let ctx = SelectionContext::builder().exclude(vec![0, 1]).build();
        assert_eq!(wrr.select(&nodes, &ctx), None);
    }

    #[test]
    fn empty_candidates_returns_none() {
        let wrr = WeightedRoundRobin::new();
        let nodes: &[W] = &[];
        assert_eq!(wrr.select(nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn zero_total_weight_returns_none() {
        let wrr = WeightedRoundRobin::new();
        let nodes = [W::new("a", 0), W::new("b", 0)];
        assert_eq!(wrr.select(&nodes, &SelectionContext::default()), None);
    }

    #[test]
    fn resets_on_candidate_change() {
        let wrr = WeightedRoundRobin::new();
        let ctx = SelectionContext::default();

        let nodes_v1 = [W::new("a", 2), W::new("b", 1)];
        let _ = wrr.select(&nodes_v1, &ctx);
        let _ = wrr.select(&nodes_v1, &ctx);

        // A fresh candidate list starts from zeroed counters, so the
        // heaviest node ("c", weight 3) wins the first pick.
        let nodes_v2 = [W::new("b", 1), W::new("c", 3)];
        let idx = wrr.select(&nodes_v2, &ctx).unwrap();
        assert_eq!(idx, 1);
    }

    #[test]
    fn resets_on_weight_change() {
        let wrr = WeightedRoundRobin::new();
        let ctx = SelectionContext::default();

        // The first pick is a tie, which goes to index 0, leaving a = -1, b = +1.
        let even = [W::new("a", 1), W::new("b", 1)];
        assert_eq!(wrr.select(&even, &ctx), Some(0));

        // Same ids with new weights: counters restart at zero, so "a" (2)
        // beats "b" (1). Carrying the old counters over would pick "b".
        let skewed = [W::new("a", 2), W::new("b", 1)];
        assert_eq!(wrr.select(&skewed, &ctx), Some(0));
    }
}