1use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
13use std::sync::Mutex;
14use std::time::Duration;
15
/// Policy used by `SelectionState::select` to pick one provider among the
/// currently allowed ones.
///
/// The derived `Default` is [`SelectionStrategy::RoundRobin`].
#[derive(Clone, Debug, Default)]
pub enum SelectionStrategy {
    /// Rotate through the allowed providers in index order, one step per call.
    #[default]
    RoundRobin,
    /// Always pick the allowed provider with the lowest index.
    Priority,
    /// Spread calls in proportion to `weights[i]`. Providers past the end of
    /// `weights` count as weight 1; a weight of 0 means the provider is never
    /// picked by this strategy.
    WeightedRoundRobin { weights: Vec<u32> },
    /// Pick the allowed provider with the lowest smoothed latency. Providers with
    /// no recorded sample are treated as 1 µs, so they are tried ahead of
    /// measured ones.
    LatencyBased,
    /// Hash `key` to a preferred provider so the same key keeps landing on the
    /// same provider; when that provider is not allowed, move forward (wrapping)
    /// to the next allowed index.
    Sticky { key: String },
}
31
/// Bookkeeping shared by every selection strategy.
///
/// Counters are atomics and the latency table sits behind a `Mutex`, so the
/// selection and latency-recording methods only need `&self`.
pub struct SelectionState {
    /// Call counter for round-robin; reduced modulo the provider count.
    rr_cursor: AtomicUsize,
    /// Call counter for weighted round-robin; reduced modulo the total weight.
    wrr_counter: AtomicU64,
    /// Smoothed latency per provider, in microseconds; 0 means "no sample yet".
    latencies: Mutex<Vec<u64>>,
    // Not read by any code in this file (weighted selection uses `wrr_counter`);
    // kept because the constructor still initialises it.
    // NOTE(review): candidate for removal together with its initialiser in `new`.
    #[allow(dead_code)]
    wrr_cursor: AtomicUsize,
}
43
44impl SelectionState {
45 pub fn new(provider_count: usize) -> Self {
47 Self {
48 rr_cursor: AtomicUsize::new(0),
49 latencies: Mutex::new(vec![0; provider_count]),
50 wrr_cursor: AtomicUsize::new(0),
51 wrr_counter: AtomicU64::new(0),
52 }
53 }
54
55 pub fn select(&self, strategy: &SelectionStrategy, allowed: &[bool]) -> Option<usize> {
60 let count = allowed.len();
61 if count == 0 {
62 return None;
63 }
64
65 match strategy {
66 SelectionStrategy::RoundRobin => self.select_round_robin(allowed),
67 SelectionStrategy::Priority => self.select_priority(allowed),
68 SelectionStrategy::WeightedRoundRobin { weights } => {
69 self.select_weighted(allowed, weights)
70 }
71 SelectionStrategy::LatencyBased => self.select_latency(allowed),
72 SelectionStrategy::Sticky { key } => self.select_sticky(allowed, key),
73 }
74 }
75
76 pub fn record_latency(&self, index: usize, latency: Duration) {
78 let mut latencies = self.latencies.lock().unwrap();
79 if index < latencies.len() {
80 let sample = latency.as_micros() as u64;
82 let old = latencies[index];
83 if old == 0 {
84 latencies[index] = sample;
85 } else {
86 latencies[index] = (sample * 3 + old * 7) / 10;
87 }
88 }
89 }
90
91 fn select_round_robin(&self, allowed: &[bool]) -> Option<usize> {
94 let count = allowed.len();
95 let start = self.rr_cursor.fetch_add(1, Ordering::Relaxed) % count;
96 for i in 0..count {
97 let idx = (start + i) % count;
98 if allowed[idx] {
99 return Some(idx);
100 }
101 }
102 None
103 }
104
105 fn select_priority(&self, allowed: &[bool]) -> Option<usize> {
106 for (idx, &ok) in allowed.iter().enumerate() {
108 if ok {
109 return Some(idx);
110 }
111 }
112 None
113 }
114
115 fn select_weighted(&self, allowed: &[bool], weights: &[u32]) -> Option<usize> {
116 let effective: Vec<u32> = allowed
118 .iter()
119 .enumerate()
120 .map(|(i, &ok)| {
121 if ok {
122 weights.get(i).copied().unwrap_or(1)
123 } else {
124 0
125 }
126 })
127 .collect();
128
129 let total: u64 = effective.iter().map(|&w| w as u64).sum();
130 if total == 0 {
131 return None;
132 }
133
134 let counter = self.wrr_counter.fetch_add(1, Ordering::Relaxed);
135 let target = counter % total;
136
137 let mut cumulative = 0u64;
138 for (idx, &w) in effective.iter().enumerate() {
139 cumulative += w as u64;
140 if target < cumulative {
141 return Some(idx);
142 }
143 }
144 allowed.iter().position(|&ok| ok)
146 }
147
148 fn select_latency(&self, allowed: &[bool]) -> Option<usize> {
149 let latencies = self.latencies.lock().unwrap();
150 let mut best_idx = None;
151 let mut best_latency = u64::MAX;
152
153 for (idx, &ok) in allowed.iter().enumerate() {
154 if !ok {
155 continue;
156 }
157 let lat = latencies.get(idx).copied().unwrap_or(0);
158 let effective = if lat == 0 { 1 } else { lat };
160 if effective < best_latency {
161 best_latency = effective;
162 best_idx = Some(idx);
163 }
164 }
165 best_idx
166 }
167
168 fn select_sticky(&self, allowed: &[bool], key: &str) -> Option<usize> {
169 let count = allowed.len();
170 if count == 0 {
171 return None;
172 }
173
174 let mut hasher = DefaultHasher::new();
176 key.hash(&mut hasher);
177 let hash = hasher.finish() as usize;
178 let preferred = hash % count;
179
180 if allowed[preferred] {
182 return Some(preferred);
183 }
184
185 for i in 1..count {
187 let idx = (preferred + i) % count;
188 if allowed[idx] {
189 return Some(idx);
190 }
191 }
192 None
193 }
194}
195
#[cfg(test)]
mod tests {
    // Unit tests for the selection strategies and latency tracking.
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn round_robin_basic() {
        let state = SelectionState::new(3);
        let allowed = [true, true, true];

        let a = state
            .select(&SelectionStrategy::RoundRobin, &allowed)
            .unwrap();
        let b = state
            .select(&SelectionStrategy::RoundRobin, &allowed)
            .unwrap();
        let c = state
            .select(&SelectionStrategy::RoundRobin, &allowed)
            .unwrap();

        // Three calls over three healthy providers must visit each one exactly once.
        assert_ne!(a, b);
        assert_ne!(b, c);
        assert_ne!(a, c);
    }

    #[test]
    fn round_robin_skips_disallowed() {
        let state = SelectionState::new(3);
        let allowed = [true, false, true];

        let mut selected = HashSet::new();
        for _ in 0..10 {
            let idx = state
                .select(&SelectionStrategy::RoundRobin, &allowed)
                .unwrap();
            assert_ne!(idx, 1, "should never select disallowed provider");
            selected.insert(idx);
        }
        // Both allowed providers must receive traffic, not just one of them.
        assert_eq!(selected, HashSet::from([0, 2]));
    }

    #[test]
    fn priority_selects_first_allowed() {
        let state = SelectionState::new(3);

        let allowed1 = [true, true, true];
        assert_eq!(
            state.select(&SelectionStrategy::Priority, &allowed1),
            Some(0)
        );

        let allowed2 = [false, true, true];
        assert_eq!(
            state.select(&SelectionStrategy::Priority, &allowed2),
            Some(1)
        );

        let allowed3 = [false, false, true];
        assert_eq!(
            state.select(&SelectionStrategy::Priority, &allowed3),
            Some(2)
        );
    }

    #[test]
    fn priority_none_when_all_down() {
        let state = SelectionState::new(3);
        let allowed = [false, false, false];
        assert_eq!(state.select(&SelectionStrategy::Priority, &allowed), None);
    }

    #[test]
    fn weighted_round_robin() {
        let state = SelectionState::new(3);
        let strategy = SelectionStrategy::WeightedRoundRobin {
            weights: vec![3, 1, 1],
        };
        let allowed = [true, true, true];

        let mut counts = [0u32; 3];
        for _ in 0..500 {
            let idx = state.select(&strategy, &allowed).unwrap();
            counts[idx] += 1;
        }

        assert!(
            counts[0] > counts[1],
            "weighted provider should get more traffic"
        );
        assert!(
            counts[0] > counts[2],
            "weighted provider should get more traffic"
        );
    }

    #[test]
    fn latency_based_selects_fastest() {
        let state = SelectionState::new(3);
        let allowed = [true, true, true];

        state.record_latency(0, Duration::from_millis(100));
        state.record_latency(1, Duration::from_millis(10));
        state.record_latency(2, Duration::from_millis(50));

        let idx = state
            .select(&SelectionStrategy::LatencyBased, &allowed)
            .unwrap();
        assert_eq!(idx, 1, "should select fastest provider");
    }

    #[test]
    fn latency_based_skips_disallowed() {
        let state = SelectionState::new(3);
        let allowed = [true, false, true];

        state.record_latency(0, Duration::from_millis(100));
        state.record_latency(1, Duration::from_millis(1));
        state.record_latency(2, Duration::from_millis(50));

        let idx = state
            .select(&SelectionStrategy::LatencyBased, &allowed)
            .unwrap();
        assert_eq!(idx, 2, "should select fastest ALLOWED provider");
    }

    #[test]
    fn sticky_consistent_hashing() {
        let state = SelectionState::new(3);
        let allowed = [true, true, true];
        let strategy = SelectionStrategy::Sticky {
            key: "0xAlice".to_string(),
        };

        let idx1 = state.select(&strategy, &allowed).unwrap();
        let idx2 = state.select(&strategy, &allowed).unwrap();
        let idx3 = state.select(&strategy, &allowed).unwrap();

        assert_eq!(idx1, idx2);
        assert_eq!(idx2, idx3);
    }

    #[test]
    fn sticky_different_keys() {
        let state = SelectionState::new(100);
        let allowed = vec![true; 100];

        let s1 = SelectionStrategy::Sticky {
            key: "0xAlice".to_string(),
        };
        let s2 = SelectionStrategy::Sticky {
            key: "0xBob".to_string(),
        };

        let idx1 = state.select(&s1, &allowed).unwrap();
        let idx2 = state.select(&s2, &allowed).unwrap();

        // Different keys may legitimately collide, so only check that both are in range.
        assert!(idx1 < 100);
        assert!(idx2 < 100);
    }

    #[test]
    fn sticky_fallback_when_preferred_down() {
        let state = SelectionState::new(3);
        let allowed_all = [true, true, true];
        let strategy = SelectionStrategy::Sticky {
            key: "test".to_string(),
        };

        let preferred = state.select(&strategy, &allowed_all).unwrap();

        let mut allowed_partial = [true, true, true];
        allowed_partial[preferred] = false;

        let fallback = state.select(&strategy, &allowed_partial).unwrap();
        assert_ne!(fallback, preferred);
    }

    #[test]
    fn latency_ema_smoothing() {
        let state = SelectionState::new(1);

        state.record_latency(0, Duration::from_millis(100));
        state.record_latency(0, Duration::from_millis(200));

        // With the 30 % / 70 % blend this lands at 130_000 µs; only the
        // smoothing property (strictly between the two samples) is asserted.
        let latencies = state.latencies.lock().unwrap();
        let lat_us = latencies[0];
        assert!(
            lat_us > 100_000 && lat_us < 200_000,
            "EMA should smooth: {lat_us}"
        );
    }
}