edb_rpc_proxy/
providers.rs

1// EDB - Ethereum Debugger
2// Copyright (C) 2024 Zhuo Zhang and Wuqi Zhang
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU Affero General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8//
9// This program is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU Affero General Public License for more details.
13//
14// You should have received a copy of the GNU Affero General Public License
15// along with this program. If not, see <https://www.gnu.org/licenses/>.
16
17//! Multi-provider RPC management with health checking and load balancing
18
19use eyre::Result;
20use rand::Rng;
21use serde::{Deserialize, Serialize};
22use std::collections::HashSet;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26use tokio::sync::RwLock;
27use tracing::{debug, info, warn};
28
/// Built-in Ethereum mainnet RPC endpoints.
///
/// All entries are free public endpoints listed on chainlist.org, kept sorted
/// by latency. Endpoints that are temporarily switched off stay in the list as
/// comments at their sorted position so they can be re-enabled later.
pub const DEFAULT_MAINNET_RPCS: &[&str] = &[
    "https://rpc.eth.gateway.fm",
    // Disabled (temporary publicnode issues): "https://ethereum-rpc.publicnode.com",
    "https://mainnet.gateway.tenderly.co",
    // Disabled (temporary flashbots issues): "https://rpc.flashbots.net/fast",
    // Disabled (temporary flashbots issues): "https://rpc.flashbots.net",
    "https://gateway.tenderly.co/public/mainnet",
    "https://eth-mainnet.public.blastapi.io",
    "https://ethereum-mainnet.gateway.tatum.io",
    "https://eth.api.onfinality.io/public",
    "https://eth.llamarpc.com",
    "https://api.zan.top/eth-mainnet",
    "https://eth.drpc.org",
    "https://ethereum.rpc.subquery.network/public",
];
46
/// Health and latency record kept for a single RPC endpoint.
#[derive(Clone, Debug)]
pub struct ProviderInfo {
    /// URL of the RPC endpoint; also the key used to look the provider up.
    pub url: String,
    /// `true` while the provider is considered usable for requests.
    pub is_healthy: bool,
    /// Moment of the most recent successful check or request
    /// (`None` until the provider has succeeded at least once).
    pub last_health_check: Option<Instant>,
    /// Latency, in milliseconds, of the most recent successful request.
    pub response_time_ms: Option<u64>,
    /// Failures seen in a row; set back to zero by any success.
    pub consecutive_failures: u32,
}
61
62/// Serializable version of ProviderInfo for API responses
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct ProviderInfoResponse {
65    /// The RPC endpoint URL
66    pub url: String,
67    /// Whether the provider is currently healthy
68    pub is_healthy: bool,
69    /// Seconds since the last health check
70    pub last_health_check_seconds_ago: Option<u64>,
71    /// Response time in milliseconds for the last successful request
72    pub response_time_ms: Option<u64>,
73    /// Number of consecutive failures (reset on success)
74    pub consecutive_failures: u32,
75}
76
77impl From<&ProviderInfo> for ProviderInfoResponse {
78    fn from(info: &ProviderInfo) -> Self {
79        Self {
80            url: info.url.clone(),
81            is_healthy: info.is_healthy,
82            last_health_check_seconds_ago: info.last_health_check.map(|t| t.elapsed().as_secs()),
83            response_time_ms: info.response_time_ms,
84            consecutive_failures: info.consecutive_failures,
85        }
86    }
87}
88
89/// Multi-provider manager with health checking and round-robin load balancing
90pub struct ProviderManager {
91    /// List of all providers (healthy and unhealthy)
92    providers: Arc<RwLock<Vec<ProviderInfo>>>,
93    /// Round-robin counter for load balancing
94    round_robin_counter: AtomicUsize,
95    /// HTTP client for health checks
96    client: reqwest::Client,
97    /// Maximum consecutive failures before marking unhealthy
98    max_failures: u32,
99}
100
/// Map a response time onto a performance tier.
///
/// Tiers are 200 ms wide: 0-199 ms is tier 1 (fastest), 200-399 ms is tier 2,
/// 400-599 ms is tier 3, and anything from 600 ms upward is tier 4 (slowest).
/// A lower tier number means better performance.
fn get_performance_tier(response_time_ms: u64) -> u8 {
    if response_time_ms < 200 {
        1
    } else if response_time_ms < 400 {
        2
    } else if response_time_ms < 600 {
        3
    } else {
        4
    }
}
111
/// Selection weight assigned to a performance tier.
///
/// A larger weight makes a provider proportionally more likely to be picked:
/// tier 1 weighs 100, tier 2 weighs 60, tier 3 weighs 30 and tier 4 weighs 10.
/// Any other tier value falls back to a weight of 1.
fn get_tier_weight(tier: u8) -> u32 {
    // Slot 0 holds the weight of tier 1, slot 3 the weight of tier 4.
    const TIER_WEIGHTS: [u32; 4] = [100, 60, 30, 10];
    // Weight used for tier values outside 1..=4.
    const FALLBACK_WEIGHT: u32 = 1;

    tier.checked_sub(1)
        .and_then(|slot| TIER_WEIGHTS.get(usize::from(slot)))
        .copied()
        .unwrap_or(FALLBACK_WEIGHT)
}
123
124impl ProviderManager {
125    /// Create a new provider manager with the given RPC URLs
126    pub async fn new(rpc_urls: Vec<String>, max_failures: u32) -> Result<Self> {
127        let client = reqwest::Client::builder().timeout(Duration::from_secs(5)).build()?;
128
129        let mut providers = Vec::new();
130
131        // Initialize providers and perform initial health check
132        for url in rpc_urls {
133            let mut provider = ProviderInfo {
134                url: url.clone(),
135                is_healthy: false,
136                last_health_check: None,
137                response_time_ms: None,
138                consecutive_failures: 0,
139            };
140
141            // Perform initial health check
142            if let Ok(response_time) = Self::check_provider_health(&client, &url).await {
143                provider.is_healthy = true;
144                provider.response_time_ms = Some(response_time);
145                provider.last_health_check = Some(Instant::now());
146                info!("Provider {} is healthy ({}ms)", url, response_time);
147            } else {
148                warn!("Provider {} is not responding during initialization", url);
149                provider.consecutive_failures = 1;
150            }
151
152            providers.push(provider);
153        }
154
155        // Ensure at least one provider is healthy
156        let healthy_count = providers.iter().filter(|p| p.is_healthy).count();
157        if healthy_count == 0 {
158            return Err(eyre::eyre!("No healthy RPC providers available"));
159        }
160
161        info!("Initialized with {} healthy providers out of {}", healthy_count, providers.len());
162
163        Ok(Self {
164            providers: Arc::new(RwLock::new(providers)),
165            round_robin_counter: AtomicUsize::new(0),
166            client,
167            max_failures,
168        })
169    }
170
171    /// Check the health of a specific provider
172    async fn check_provider_health(client: &reqwest::Client, url: &str) -> Result<u64> {
173        let start = Instant::now();
174
175        // Simple eth_blockNumber request to check if provider is responsive
176        let request = serde_json::json!({
177            "jsonrpc": "2.0",
178            "method": "eth_blockNumber",
179            "params": [],
180            "id": 1
181        });
182
183        let response = client
184            .post(url)
185            .header("Content-Type", "application/json")
186            .json(&request)
187            .send()
188            .await?;
189
190        let response_time = start.elapsed().as_millis() as u64;
191
192        // Check if we got a valid response
193        let json: serde_json::Value = response.json().await?;
194        if json.get("result").is_some() {
195            Ok(response_time)
196        } else {
197            Err(eyre::eyre!("Invalid response from provider"))
198        }
199    }
200
201    /// Get a weighted random provider that hasn't been tried yet
202    /// Only considers healthy providers not in the exclusion set
203    pub async fn get_weighted_provider_excluding(
204        &self,
205        tried_providers: &HashSet<String>,
206    ) -> Option<String> {
207        let providers = self.providers.read().await;
208        let available_providers: Vec<_> = providers
209            .iter()
210            .filter(|p| p.is_healthy && !tried_providers.contains(&p.url))
211            .collect();
212
213        if available_providers.is_empty() {
214            return None;
215        }
216
217        // If only one available provider, return it
218        if available_providers.len() == 1 {
219            return Some(available_providers[0].url.clone());
220        }
221
222        // Calculate weights for each available provider
223        let mut weighted_providers = Vec::new();
224        let mut total_weight = 0u32;
225
226        for provider in &available_providers {
227            // Use response time if available, otherwise assume medium performance
228            let response_time = provider.response_time_ms.unwrap_or(300); // Default to 300ms
229            let tier = get_performance_tier(response_time);
230            let weight = get_tier_weight(tier);
231
232            total_weight += weight;
233            weighted_providers.push((provider, weight));
234        }
235
236        // Generate random number for weighted selection
237        let mut rng = rand::thread_rng();
238        let random_weight = rng.gen_range(0..total_weight);
239
240        // Find the provider corresponding to the random weight
241        let mut current_weight = 0u32;
242        for (provider, weight) in weighted_providers {
243            current_weight += weight;
244            if random_weight < current_weight {
245                return Some(provider.url.clone());
246            }
247        }
248
249        // Fallback to first available provider (should not reach here)
250        Some(available_providers[0].url.clone())
251    }
252
253    /// Get a weighted random provider based on response time
254    /// Only considers healthy providers
255    /// DEPRECATED: Use get_weighted_provider_excluding instead
256    #[allow(dead_code)]
257    async fn get_weighted_provider(&self) -> Option<String> {
258        let providers = self.providers.read().await;
259        let healthy_providers: Vec<_> = providers.iter().filter(|p| p.is_healthy).collect();
260
261        if healthy_providers.is_empty() {
262            return None;
263        }
264
265        // If only one healthy provider, return it
266        if healthy_providers.len() == 1 {
267            return Some(healthy_providers[0].url.clone());
268        }
269
270        // Calculate weights for each healthy provider
271        let mut weighted_providers = Vec::new();
272        let mut total_weight = 0u32;
273
274        for provider in &healthy_providers {
275            // Use response time if available, otherwise assume medium performance
276            let response_time = provider.response_time_ms.unwrap_or(300); // Default to 300ms
277            let tier = get_performance_tier(response_time);
278            let weight = get_tier_weight(tier);
279
280            total_weight += weight;
281            weighted_providers.push((provider, weight));
282        }
283
284        // Generate random number for weighted selection
285        let mut rng = rand::thread_rng();
286        let random_weight = rng.gen_range(0..total_weight);
287
288        // Find the provider corresponding to the random weight
289        let mut current_weight = 0u32;
290        for (provider, weight) in weighted_providers {
291            current_weight += weight;
292            if random_weight < current_weight {
293                return Some(provider.url.clone());
294            }
295        }
296
297        // Fallback to first healthy provider (should not reach here)
298        Some(healthy_providers[0].url.clone())
299    }
300
301    /// Get the next healthy provider using round-robin
302    /// DEPRECATED: Use get_weighted_provider_excluding instead
303    #[allow(dead_code)]
304    pub async fn get_next_provider(&self) -> Option<String> {
305        let providers = self.providers.read().await;
306        let healthy_providers: Vec<_> = providers.iter().filter(|p| p.is_healthy).collect();
307
308        if healthy_providers.is_empty() {
309            return None;
310        }
311
312        // Round-robin selection
313        let index =
314            self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % healthy_providers.len();
315        Some(healthy_providers[index].url.clone())
316    }
317
318    /// Mark a provider as failed and update its health status
319    pub async fn mark_provider_failed(&self, url: &str) {
320        let mut providers = self.providers.write().await;
321
322        if let Some(provider) = providers.iter_mut().find(|p| p.url == url) {
323            provider.consecutive_failures += 1;
324
325            if provider.consecutive_failures >= self.max_failures {
326                provider.is_healthy = false;
327                debug!("Provider {} marked as unhealthy after {} failures", url, self.max_failures);
328            }
329        }
330    }
331
332    /// Mark a provider as successful and reset failure count
333    pub async fn mark_provider_success(&self, url: &str, response_time_ms: u64) {
334        let mut providers = self.providers.write().await;
335
336        if let Some(provider) = providers.iter_mut().find(|p| p.url == url) {
337            provider.consecutive_failures = 0;
338            provider.is_healthy = true;
339            provider.response_time_ms = Some(response_time_ms);
340            provider.last_health_check = Some(Instant::now());
341
342            debug!("Provider {} successful ({}ms)", url, response_time_ms);
343        }
344    }
345
346    /// Perform health checks on all providers
347    pub async fn health_check_all(&self) {
348        let providers_snapshot = {
349            let providers = self.providers.read().await;
350            providers.clone()
351        };
352
353        for provider in providers_snapshot {
354            // Check if provider needs health check (unhealthy or stale)
355            let needs_check = !provider.is_healthy
356                || provider.last_health_check.is_none_or(|t| t.elapsed() > Duration::from_secs(60));
357
358            if needs_check {
359                match Self::check_provider_health(&self.client, &provider.url).await {
360                    Ok(response_time) => {
361                        self.mark_provider_success(&provider.url, response_time).await;
362                        if !provider.is_healthy {
363                            debug!("Provider {} is now healthy", provider.url);
364                        }
365                    }
366                    Err(e) => {
367                        debug!("Health check failed for {}: {}", provider.url, e);
368                        self.mark_provider_failed(&provider.url).await;
369                    }
370                }
371            }
372        }
373    }
374
375    /// Get information about all providers (serializable version)
376    pub async fn get_providers_info(&self) -> Vec<ProviderInfoResponse> {
377        let providers = self.providers.read().await;
378        providers.iter().map(|p| p.into()).collect()
379    }
380
381    /// Get count of healthy providers
382    pub async fn healthy_provider_count(&self) -> usize {
383        let providers = self.providers.read().await;
384        providers.iter().filter(|p| p.is_healthy).count()
385    }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use tracing::{debug, info};
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Canned successful `eth_blockNumber` reply served by healthy mocks.
    fn block_number_response() -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({
            "jsonrpc": "2.0",
            "id": 1,
            "result": "0x1234567"
        }))
    }

    /// Make `server` answer every `POST /` with a successful reply.
    async fn mount_healthy_rpc(server: &MockServer) {
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(block_number_response())
            .mount(server)
            .await;
    }

    /// Start `count` mock servers that all pass the health check.
    async fn start_healthy_servers(count: usize) -> Vec<MockServer> {
        let mut servers = Vec::with_capacity(count);
        for _ in 0..count {
            let server = MockServer::start().await;
            mount_healthy_rpc(&server).await;
            servers.push(server);
        }
        servers
    }

    /// A provider that answers correctly starts healthy; one that returns
    /// HTTP 500 starts unhealthy.
    #[tokio::test]
    async fn test_provider_initialization() {
        edb_common::logging::ensure_test_logging(None);
        info!("Testing provider initialization with health checks");

        // Start mock servers
        let mock1 = MockServer::start().await;
        let mock2 = MockServer::start().await;

        // Setup successful response for mock1
        mount_healthy_rpc(&mock1).await;

        // Setup failed response for mock2
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&mock2)
            .await;

        let urls = vec![mock1.uri(), mock2.uri()];
        let manager = ProviderManager::new(urls, 3).await.unwrap();

        // Check that only mock1 is healthy
        assert_eq!(manager.healthy_provider_count().await, 1);

        let providers = manager.get_providers_info().await;
        assert_eq!(providers.len(), 2);
        assert!(providers[0].is_healthy);
        assert!(!providers[1].is_healthy);
    }

    /// Round-robin selection visits every healthy provider equally often.
    #[tokio::test]
    async fn test_round_robin_selection() {
        edb_common::logging::ensure_test_logging(None);
        info!("Testing round-robin provider selection");

        // Start 3 healthy mock servers
        let mocks = start_healthy_servers(3).await;

        let urls: Vec<String> = mocks.iter().map(|m| m.uri()).collect();
        let manager = ProviderManager::new(urls.clone(), 3).await.unwrap();

        // Get providers multiple times and verify round-robin
        let mut selections = Vec::new();
        for _ in 0..9 {
            selections.push(manager.get_next_provider().await.unwrap());
        }

        // Each provider should be selected 3 times
        for url in &urls {
            assert_eq!(selections.iter().filter(|s| *s == url).count(), 3);
        }
    }

    /// A provider turns unhealthy once it reaches `max_failures` failures.
    #[tokio::test]
    async fn test_provider_failure_handling() {
        edb_common::logging::ensure_test_logging(None);
        debug!("Testing provider failure detection and handling");

        let mock = MockServer::start().await;

        // Initially healthy; the manager must probe the server exactly once.
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(block_number_response())
            .expect(1)
            .mount(&mock)
            .await;

        let manager = ProviderManager::new(vec![mock.uri()], 2).await.unwrap();
        assert_eq!(manager.healthy_provider_count().await, 1);

        // Mark as failed twice (max_failures = 2)
        manager.mark_provider_failed(&mock.uri()).await;
        assert_eq!(manager.healthy_provider_count().await, 1); // Still healthy after 1 failure

        manager.mark_provider_failed(&mock.uri()).await;
        assert_eq!(manager.healthy_provider_count().await, 0); // Unhealthy after 2 failures
    }

    /// Weighted selection reaches every provider and favours the fastest one.
    #[tokio::test]
    async fn test_weighted_provider_selection() {
        edb_common::logging::ensure_test_logging(None);
        debug!("Testing weighted provider selection based on response time");

        // Start 3 healthy mock servers; their latencies are simulated below
        let mocks = start_healthy_servers(3).await;

        let urls: Vec<String> = mocks.iter().map(|m| m.uri()).collect();
        let manager = ProviderManager::new(urls.clone(), 3).await.unwrap();

        // Simulate different response times for providers
        manager.mark_provider_success(&urls[0], 50).await; // Fast: Tier 1 (100 weight)
        manager.mark_provider_success(&urls[1], 250).await; // Medium: Tier 2 (60 weight)
        manager.mark_provider_success(&urls[2], 500).await; // Slow: Tier 3 (30 weight)

        // Test weighted selection multiple times
        let mut selections = std::collections::HashMap::new();
        for _ in 0..100 {
            if let Some(provider) = manager.get_weighted_provider().await {
                *selections.entry(provider).or_insert(0) += 1;
            }
        }

        // Verify all providers were selected
        assert_eq!(selections.len(), 3);

        // Fast provider should be selected most often due to higher weight
        let fast_count = selections.get(&urls[0]).unwrap_or(&0);
        let medium_count = selections.get(&urls[1]).unwrap_or(&0);
        let slow_count = selections.get(&urls[2]).unwrap_or(&0);

        debug!(
            "Selection counts - Fast: {}, Medium: {}, Slow: {}",
            fast_count, medium_count, slow_count
        );

        // Fast provider should have more selections than slow provider
        assert!(
            fast_count > slow_count,
            "Fast provider should be selected more often than slow provider"
        );
    }
}