Skip to main content

chainrpc_core/
selection.rs

1//! Provider selection strategies for the pool.
2//!
3//! Strategies:
4//! - RoundRobin — distribute evenly across healthy providers
5//! - Priority — try in priority order (lowest number = highest priority)
6//! - WeightedRoundRobin — distribute proportionally by weight
7//! - LatencyBased — route to the fastest responding provider
8//! - Sticky — same provider for same key (e.g. address for nonce management)
9
10use std::collections::hash_map::DefaultHasher;
11use std::hash::{Hash, Hasher};
12use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
13use std::sync::Mutex;
14use std::time::Duration;
15
/// A selection strategy decides which provider index to use next.
///
/// Providers are identified by their position in the `allowed` slice passed
/// to `SelectionState::select`; every strategy skips providers whose slot in
/// that slice is `false`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SelectionStrategy {
    /// Rotate through the providers in index order, skipping any that are
    /// currently not allowed.
    #[default]
    RoundRobin,
    /// Always pick the first allowed provider in index order, so index 0 has
    /// the highest priority.
    Priority,
    /// Weighted round-robin — a provider with a higher weight receives
    /// proportionally more traffic.
    WeightedRoundRobin {
        /// `weights[i]` is the weight of provider `i`. Providers without an
        /// entry get weight 1; a weight of 0 means the provider is never
        /// selected by this strategy.
        weights: Vec<u32>,
    },
    /// Route to the allowed provider with the lowest smoothed latency.
    /// Providers without any recorded sample are treated as very fast so
    /// they get a chance to be measured.
    LatencyBased,
    /// Stick to the same provider for a given key (e.g. sender address),
    /// falling back to the next allowed provider when the preferred one is
    /// unavailable.
    Sticky {
        /// Value that is hashed to choose the preferred provider.
        key: String,
    },
}
31
/// Mutable bookkeeping shared by the stateful selection strategies.
///
/// Every field uses interior mutability (atomics or a mutex), so the state
/// is updated through `&self`.
#[derive(Debug)]
pub struct SelectionState {
    /// Round-robin cursor; incremented once per round-robin selection.
    rr_cursor: AtomicUsize,
    /// Per-provider smoothed latency in microseconds for LatencyBased.
    /// A value of `0` means "no sample recorded yet".
    latencies: Mutex<Vec<u64>>,
    /// Weighted round-robin cursor. It is initialised but not read by the
    /// selection code in this file, hence the lint exemption below.
    #[allow(dead_code)]
    wrr_cursor: AtomicUsize,
    /// Monotonic request counter that drives weighted round-robin.
    wrr_counter: AtomicU64,
}
43
44impl SelectionState {
45    /// Create state for a given number of providers.
46    pub fn new(provider_count: usize) -> Self {
47        Self {
48            rr_cursor: AtomicUsize::new(0),
49            latencies: Mutex::new(vec![0; provider_count]),
50            wrr_cursor: AtomicUsize::new(0),
51            wrr_counter: AtomicU64::new(0),
52        }
53    }
54
55    /// Select the next provider index.
56    ///
57    /// `allowed` is a slice of booleans indicating which providers are available
58    /// (circuit breaker allows requests).
59    pub fn select(&self, strategy: &SelectionStrategy, allowed: &[bool]) -> Option<usize> {
60        let count = allowed.len();
61        if count == 0 {
62            return None;
63        }
64
65        match strategy {
66            SelectionStrategy::RoundRobin => self.select_round_robin(allowed),
67            SelectionStrategy::Priority => self.select_priority(allowed),
68            SelectionStrategy::WeightedRoundRobin { weights } => {
69                self.select_weighted(allowed, weights)
70            }
71            SelectionStrategy::LatencyBased => self.select_latency(allowed),
72            SelectionStrategy::Sticky { key } => self.select_sticky(allowed, key),
73        }
74    }
75
76    /// Record observed latency for a provider (used by LatencyBased strategy).
77    pub fn record_latency(&self, index: usize, latency: Duration) {
78        let mut latencies = self.latencies.lock().unwrap();
79        if index < latencies.len() {
80            // Exponential moving average: new = 0.3 * sample + 0.7 * old
81            let sample = latency.as_micros() as u64;
82            let old = latencies[index];
83            if old == 0 {
84                latencies[index] = sample;
85            } else {
86                latencies[index] = (sample * 3 + old * 7) / 10;
87            }
88        }
89    }
90
91    // -- strategy implementations -------------------------------------------
92
93    fn select_round_robin(&self, allowed: &[bool]) -> Option<usize> {
94        let count = allowed.len();
95        let start = self.rr_cursor.fetch_add(1, Ordering::Relaxed) % count;
96        for i in 0..count {
97            let idx = (start + i) % count;
98            if allowed[idx] {
99                return Some(idx);
100            }
101        }
102        None
103    }
104
105    fn select_priority(&self, allowed: &[bool]) -> Option<usize> {
106        // Try in order — first allowed provider wins.
107        for (idx, &ok) in allowed.iter().enumerate() {
108            if ok {
109                return Some(idx);
110            }
111        }
112        None
113    }
114
115    fn select_weighted(&self, allowed: &[bool], weights: &[u32]) -> Option<usize> {
116        // Build effective weights (zero for disallowed)
117        let effective: Vec<u32> = allowed
118            .iter()
119            .enumerate()
120            .map(|(i, &ok)| {
121                if ok {
122                    weights.get(i).copied().unwrap_or(1)
123                } else {
124                    0
125                }
126            })
127            .collect();
128
129        let total: u64 = effective.iter().map(|&w| w as u64).sum();
130        if total == 0 {
131            return None;
132        }
133
134        let counter = self.wrr_counter.fetch_add(1, Ordering::Relaxed);
135        let target = counter % total;
136
137        let mut cumulative = 0u64;
138        for (idx, &w) in effective.iter().enumerate() {
139            cumulative += w as u64;
140            if target < cumulative {
141                return Some(idx);
142            }
143        }
144        // Fallback (shouldn't reach)
145        allowed.iter().position(|&ok| ok)
146    }
147
148    fn select_latency(&self, allowed: &[bool]) -> Option<usize> {
149        let latencies = self.latencies.lock().unwrap();
150        let mut best_idx = None;
151        let mut best_latency = u64::MAX;
152
153        for (idx, &ok) in allowed.iter().enumerate() {
154            if !ok {
155                continue;
156            }
157            let lat = latencies.get(idx).copied().unwrap_or(0);
158            // Treat 0 (no data) as very fast — give new providers a chance
159            let effective = if lat == 0 { 1 } else { lat };
160            if effective < best_latency {
161                best_latency = effective;
162                best_idx = Some(idx);
163            }
164        }
165        best_idx
166    }
167
168    fn select_sticky(&self, allowed: &[bool], key: &str) -> Option<usize> {
169        let count = allowed.len();
170        if count == 0 {
171            return None;
172        }
173
174        // Hash the key to pick a consistent provider
175        let mut hasher = DefaultHasher::new();
176        key.hash(&mut hasher);
177        let hash = hasher.finish() as usize;
178        let preferred = hash % count;
179
180        // If preferred is allowed, use it
181        if allowed[preferred] {
182            return Some(preferred);
183        }
184
185        // Otherwise, fall back to the next allowed provider
186        for i in 1..count {
187            let idx = (preferred + i) % count;
188            if allowed[idx] {
189                return Some(idx);
190            }
191        }
192        None
193    }
194}
195
#[cfg(test)]
mod tests {
    //! Unit tests for the selection strategies. The expected values rely on
    //! a fresh `SelectionState` starting all of its cursors at zero.

    use super::*;

    #[test]
    fn empty_allowed_returns_none() {
        let state = SelectionState::new(0);
        let strategies = [
            SelectionStrategy::RoundRobin,
            SelectionStrategy::Priority,
            SelectionStrategy::WeightedRoundRobin { weights: vec![] },
            SelectionStrategy::LatencyBased,
            SelectionStrategy::Sticky {
                key: "k".to_string(),
            },
        ];
        for strategy in &strategies {
            assert_eq!(state.select(strategy, &[]), None);
        }
    }

    #[test]
    fn round_robin_basic() {
        let state = SelectionState::new(3);
        let allowed = [true, true, true];

        let picks: Vec<usize> = (0..4)
            .map(|_| {
                state
                    .select(&SelectionStrategy::RoundRobin, &allowed)
                    .unwrap()
            })
            .collect();

        // Should cycle through all 3 and then wrap around.
        assert_eq!(picks, vec![0, 1, 2, 0]);
    }

    #[test]
    fn round_robin_skips_disallowed() {
        let state = SelectionState::new(3);
        let allowed = [true, false, true];

        let mut selected = std::collections::HashSet::new();
        for _ in 0..10 {
            let idx = state
                .select(&SelectionStrategy::RoundRobin, &allowed)
                .unwrap();
            selected.insert(idx);
            assert_ne!(idx, 1, "should never select disallowed provider");
        }

        // Both remaining providers must actually receive traffic.
        let expected: std::collections::HashSet<usize> = [0, 2].iter().copied().collect();
        assert_eq!(selected, expected);
    }

    #[test]
    fn round_robin_none_when_all_down() {
        let state = SelectionState::new(2);
        assert_eq!(
            state.select(&SelectionStrategy::RoundRobin, &[false, false]),
            None
        );
    }

    #[test]
    fn priority_selects_first_allowed() {
        let state = SelectionState::new(3);

        let allowed1 = [true, true, true];
        assert_eq!(
            state.select(&SelectionStrategy::Priority, &allowed1),
            Some(0)
        );

        let allowed2 = [false, true, true];
        assert_eq!(
            state.select(&SelectionStrategy::Priority, &allowed2),
            Some(1)
        );

        let allowed3 = [false, false, true];
        assert_eq!(
            state.select(&SelectionStrategy::Priority, &allowed3),
            Some(2)
        );
    }

    #[test]
    fn priority_none_when_all_down() {
        let state = SelectionState::new(3);
        let allowed = [false, false, false];
        assert_eq!(state.select(&SelectionStrategy::Priority, &allowed), None);
    }

    #[test]
    fn weighted_round_robin() {
        let state = SelectionState::new(3);
        let strategy = SelectionStrategy::WeightedRoundRobin {
            weights: vec![3, 1, 1],
        };
        let allowed = [true, true, true];

        let mut counts = [0u32; 3];
        for _ in 0..500 {
            let idx = state.select(&strategy, &allowed).unwrap();
            counts[idx] += 1;
        }

        // 500 requests over a total weight of 5 split exactly 3:1:1.
        assert_eq!(counts, [300, 100, 100]);
    }

    #[test]
    fn weighted_skips_disallowed_and_zero_weight() {
        let state = SelectionState::new(3);
        let strategy = SelectionStrategy::WeightedRoundRobin {
            weights: vec![3, 0, 1],
        };

        // Provider 0 is down and provider 1 has weight 0: only 2 remains.
        for _ in 0..10 {
            assert_eq!(state.select(&strategy, &[false, true, true]), Some(2));
        }

        // Nothing with a positive weight is allowed.
        assert_eq!(state.select(&strategy, &[false, true, false]), None);
    }

    #[test]
    fn weighted_missing_weights_default_to_one() {
        let state = SelectionState::new(2);
        let strategy = SelectionStrategy::WeightedRoundRobin { weights: vec![] };
        let allowed = [true, true];

        let mut counts = [0u32; 2];
        for _ in 0..10 {
            counts[state.select(&strategy, &allowed).unwrap()] += 1;
        }
        assert_eq!(counts, [5, 5]);
    }

    #[test]
    fn latency_based_selects_fastest() {
        let state = SelectionState::new(3);
        let allowed = [true, true, true];

        // Record latencies
        state.record_latency(0, Duration::from_millis(100));
        state.record_latency(1, Duration::from_millis(10)); // fastest
        state.record_latency(2, Duration::from_millis(50));

        let idx = state
            .select(&SelectionStrategy::LatencyBased, &allowed)
            .unwrap();
        assert_eq!(idx, 1, "should select fastest provider");
    }

    #[test]
    fn latency_based_skips_disallowed() {
        let state = SelectionState::new(3);
        let allowed = [true, false, true]; // provider 1 disallowed

        state.record_latency(0, Duration::from_millis(100));
        state.record_latency(1, Duration::from_millis(1)); // fastest but disallowed
        state.record_latency(2, Duration::from_millis(50));

        let idx = state
            .select(&SelectionStrategy::LatencyBased, &allowed)
            .unwrap();
        assert_eq!(idx, 2, "should select fastest ALLOWED provider");
    }

    #[test]
    fn latency_based_prefers_unsampled_provider() {
        let state = SelectionState::new(2);
        state.record_latency(0, Duration::from_millis(10));

        // Provider 1 has no data yet and is therefore treated as very fast.
        assert_eq!(
            state.select(&SelectionStrategy::LatencyBased, &[true, true]),
            Some(1)
        );
    }

    #[test]
    fn record_latency_ignores_out_of_range_index() {
        let state = SelectionState::new(1);
        state.record_latency(5, Duration::from_millis(10));

        let latencies = state.latencies.lock().unwrap();
        assert_eq!(*latencies, vec![0]);
    }

    #[test]
    fn sticky_consistent_hashing() {
        let state = SelectionState::new(3);
        let allowed = [true, true, true];
        let strategy = SelectionStrategy::Sticky {
            key: "0xAlice".to_string(),
        };

        let idx1 = state.select(&strategy, &allowed).unwrap();
        let idx2 = state.select(&strategy, &allowed).unwrap();
        let idx3 = state.select(&strategy, &allowed).unwrap();

        // Same key should always select same provider
        assert_eq!(idx1, idx2);
        assert_eq!(idx2, idx3);
    }

    #[test]
    fn sticky_different_keys() {
        let state = SelectionState::new(100);
        let allowed = vec![true; 100];

        let s1 = SelectionStrategy::Sticky {
            key: "0xAlice".to_string(),
        };
        let s2 = SelectionStrategy::Sticky {
            key: "0xBob".to_string(),
        };

        let idx1 = state.select(&s1, &allowed).unwrap();
        let idx2 = state.select(&s2, &allowed).unwrap();

        // Different keys usually land on different providers, but a collision
        // is possible, so only the validity of both indices is asserted.
        assert!(idx1 < 100);
        assert!(idx2 < 100);
    }

    #[test]
    fn sticky_fallback_when_preferred_down() {
        let state = SelectionState::new(3);
        let allowed_all = [true, true, true];
        let strategy = SelectionStrategy::Sticky {
            key: "test".to_string(),
        };

        let preferred = state.select(&strategy, &allowed_all).unwrap();

        // Mark preferred as down
        let mut allowed_partial = [true, true, true];
        allowed_partial[preferred] = false;

        let fallback = state.select(&strategy, &allowed_partial).unwrap();
        assert_ne!(fallback, preferred);
        assert!(allowed_partial[fallback], "fallback must be allowed");
        // The fallback is the next provider after the preferred one.
        assert_eq!(fallback, (preferred + 1) % 3);
    }

    #[test]
    fn sticky_none_when_all_down() {
        let state = SelectionState::new(3);
        let strategy = SelectionStrategy::Sticky {
            key: "test".to_string(),
        };
        assert_eq!(state.select(&strategy, &[false, false, false]), None);
    }

    #[test]
    fn latency_ema_smoothing() {
        let state = SelectionState::new(1);

        // Record a few samples
        state.record_latency(0, Duration::from_millis(100));
        state.record_latency(0, Duration::from_millis(200)); // EMA: 0.3*200 + 0.7*100 = 130ms

        let latencies = state.latencies.lock().unwrap();
        // Should be smoothed, not just the latest sample
        assert_eq!(latencies[0], 130_000);
    }
}