1use eyre::Result;
20use rand::Rng;
21use serde::{Deserialize, Serialize};
22use std::collections::HashSet;
23use std::sync::atomic::{AtomicUsize, Ordering};
24use std::sync::Arc;
25use std::time::{Duration, Instant};
26use tokio::sync::RwLock;
27use tracing::{debug, info, warn};
28
/// A default set of Ethereum mainnet JSON-RPC endpoint URLs.
// NOTE(review): nothing in this file references this list; presumably callers
// pass it (as owned Strings) to `ProviderManager::new` when no RPC URLs are
// configured — confirm against the call sites.
pub const DEFAULT_MAINNET_RPCS: &[&str] = &[
    "https://rpc.eth.gateway.fm",
    "https://mainnet.gateway.tenderly.co",
    "https://gateway.tenderly.co/public/mainnet",
    "https://eth-mainnet.public.blastapi.io",
    "https://ethereum-mainnet.gateway.tatum.io",
    "https://eth.api.onfinality.io/public",
    "https://eth.llamarpc.com",
    "https://api.zan.top/eth-mainnet",
    "https://eth.drpc.org",
    "https://ethereum.rpc.subquery.network/public",
];
46
/// Runtime health and latency state tracked for a single RPC endpoint.
#[derive(Debug, Clone)]
pub struct ProviderInfo {
    /// Endpoint URL; the manager looks providers up by this value.
    pub url: String,
    /// Whether the endpoint is currently eligible for selection.
    pub is_healthy: bool,
    /// Time of the last successful probe or reported success; `None` until the
    /// first success (failures do not update it).
    pub last_health_check: Option<Instant>,
    /// Most recently recorded response time in milliseconds; `None` if never measured.
    pub response_time_ms: Option<u64>,
    /// Number of failures recorded since the last success.
    pub consecutive_failures: u32,
}
61
/// Serializable snapshot of a `ProviderInfo`, with the `Instant` of the last
/// successful check converted to "seconds ago".
// Field order here determines the key order of the serialized output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderInfoResponse {
    /// Endpoint URL.
    pub url: String,
    /// Whether the endpoint is currently eligible for selection.
    pub is_healthy: bool,
    /// Whole seconds elapsed since the last successful check, if any.
    pub last_health_check_seconds_ago: Option<u64>,
    /// Most recently recorded response time in milliseconds, if any.
    pub response_time_ms: Option<u64>,
    /// Number of failures recorded since the last success.
    pub consecutive_failures: u32,
}
76
77impl From<&ProviderInfo> for ProviderInfoResponse {
78 fn from(info: &ProviderInfo) -> Self {
79 Self {
80 url: info.url.clone(),
81 is_healthy: info.is_healthy,
82 last_health_check_seconds_ago: info.last_health_check.map(|t| t.elapsed().as_secs()),
83 response_time_ms: info.response_time_ms,
84 consecutive_failures: info.consecutive_failures,
85 }
86 }
87}
88
/// Tracks a pool of JSON-RPC endpoints and their health, and hands out endpoint
/// URLs using round-robin or latency-weighted random selection.
pub struct ProviderManager {
    /// Per-endpoint state, shared behind an async read/write lock.
    providers: Arc<RwLock<Vec<ProviderInfo>>>,
    /// Counter that drives round-robin selection in `get_next_provider`.
    round_robin_counter: AtomicUsize,
    /// HTTP client used for health probes (built with a 5 s timeout in `new`).
    client: reqwest::Client,
    /// Consecutive failures after which an endpoint is marked unhealthy.
    max_failures: u32,
}
100
/// Buckets a response time into a performance tier, from 1 (fastest) to 4.
///
/// Tiers are 200 ms wide: under 200 ms is tier 1, under 400 ms tier 2,
/// under 600 ms tier 3, and anything slower tier 4.
fn get_performance_tier(response_time_ms: u64) -> u8 {
    match response_time_ms {
        0..=199 => 1,
        200..=399 => 2,
        400..=599 => 3,
        _ => 4,
    }
}
111
/// Maps a performance tier (as produced by `get_performance_tier`) to a relative
/// selection weight. Tiers outside `1..=4` receive the minimum weight of 1.
fn get_tier_weight(tier: u8) -> u32 {
    // Index 0 holds the weight for tier 1, index 3 the weight for tier 4.
    const WEIGHT_BY_TIER: [u32; 4] = [100, 60, 30, 10];

    usize::from(tier)
        .checked_sub(1)
        .and_then(|index| WEIGHT_BY_TIER.get(index))
        .copied()
        .unwrap_or(1)
}
123
124impl ProviderManager {
125 pub async fn new(rpc_urls: Vec<String>, max_failures: u32) -> Result<Self> {
127 let client = reqwest::Client::builder().timeout(Duration::from_secs(5)).build()?;
128
129 let mut providers = Vec::new();
130
131 for url in rpc_urls {
133 let mut provider = ProviderInfo {
134 url: url.clone(),
135 is_healthy: false,
136 last_health_check: None,
137 response_time_ms: None,
138 consecutive_failures: 0,
139 };
140
141 if let Ok(response_time) = Self::check_provider_health(&client, &url).await {
143 provider.is_healthy = true;
144 provider.response_time_ms = Some(response_time);
145 provider.last_health_check = Some(Instant::now());
146 info!("Provider {} is healthy ({}ms)", url, response_time);
147 } else {
148 warn!("Provider {} is not responding during initialization", url);
149 provider.consecutive_failures = 1;
150 }
151
152 providers.push(provider);
153 }
154
155 let healthy_count = providers.iter().filter(|p| p.is_healthy).count();
157 if healthy_count == 0 {
158 return Err(eyre::eyre!("No healthy RPC providers available"));
159 }
160
161 info!("Initialized with {} healthy providers out of {}", healthy_count, providers.len());
162
163 Ok(Self {
164 providers: Arc::new(RwLock::new(providers)),
165 round_robin_counter: AtomicUsize::new(0),
166 client,
167 max_failures,
168 })
169 }
170
171 async fn check_provider_health(client: &reqwest::Client, url: &str) -> Result<u64> {
173 let start = Instant::now();
174
175 let request = serde_json::json!({
177 "jsonrpc": "2.0",
178 "method": "eth_blockNumber",
179 "params": [],
180 "id": 1
181 });
182
183 let response = client
184 .post(url)
185 .header("Content-Type", "application/json")
186 .json(&request)
187 .send()
188 .await?;
189
190 let response_time = start.elapsed().as_millis() as u64;
191
192 let json: serde_json::Value = response.json().await?;
194 if json.get("result").is_some() {
195 Ok(response_time)
196 } else {
197 Err(eyre::eyre!("Invalid response from provider"))
198 }
199 }
200
201 pub async fn get_weighted_provider_excluding(
204 &self,
205 tried_providers: &HashSet<String>,
206 ) -> Option<String> {
207 let providers = self.providers.read().await;
208 let available_providers: Vec<_> = providers
209 .iter()
210 .filter(|p| p.is_healthy && !tried_providers.contains(&p.url))
211 .collect();
212
213 if available_providers.is_empty() {
214 return None;
215 }
216
217 if available_providers.len() == 1 {
219 return Some(available_providers[0].url.clone());
220 }
221
222 let mut weighted_providers = Vec::new();
224 let mut total_weight = 0u32;
225
226 for provider in &available_providers {
227 let response_time = provider.response_time_ms.unwrap_or(300); let tier = get_performance_tier(response_time);
230 let weight = get_tier_weight(tier);
231
232 total_weight += weight;
233 weighted_providers.push((provider, weight));
234 }
235
236 let mut rng = rand::thread_rng();
238 let random_weight = rng.gen_range(0..total_weight);
239
240 let mut current_weight = 0u32;
242 for (provider, weight) in weighted_providers {
243 current_weight += weight;
244 if random_weight < current_weight {
245 return Some(provider.url.clone());
246 }
247 }
248
249 Some(available_providers[0].url.clone())
251 }
252
253 #[allow(dead_code)]
257 async fn get_weighted_provider(&self) -> Option<String> {
258 let providers = self.providers.read().await;
259 let healthy_providers: Vec<_> = providers.iter().filter(|p| p.is_healthy).collect();
260
261 if healthy_providers.is_empty() {
262 return None;
263 }
264
265 if healthy_providers.len() == 1 {
267 return Some(healthy_providers[0].url.clone());
268 }
269
270 let mut weighted_providers = Vec::new();
272 let mut total_weight = 0u32;
273
274 for provider in &healthy_providers {
275 let response_time = provider.response_time_ms.unwrap_or(300); let tier = get_performance_tier(response_time);
278 let weight = get_tier_weight(tier);
279
280 total_weight += weight;
281 weighted_providers.push((provider, weight));
282 }
283
284 let mut rng = rand::thread_rng();
286 let random_weight = rng.gen_range(0..total_weight);
287
288 let mut current_weight = 0u32;
290 for (provider, weight) in weighted_providers {
291 current_weight += weight;
292 if random_weight < current_weight {
293 return Some(provider.url.clone());
294 }
295 }
296
297 Some(healthy_providers[0].url.clone())
299 }
300
301 #[allow(dead_code)]
304 pub async fn get_next_provider(&self) -> Option<String> {
305 let providers = self.providers.read().await;
306 let healthy_providers: Vec<_> = providers.iter().filter(|p| p.is_healthy).collect();
307
308 if healthy_providers.is_empty() {
309 return None;
310 }
311
312 let index =
314 self.round_robin_counter.fetch_add(1, Ordering::Relaxed) % healthy_providers.len();
315 Some(healthy_providers[index].url.clone())
316 }
317
318 pub async fn mark_provider_failed(&self, url: &str) {
320 let mut providers = self.providers.write().await;
321
322 if let Some(provider) = providers.iter_mut().find(|p| p.url == url) {
323 provider.consecutive_failures += 1;
324
325 if provider.consecutive_failures >= self.max_failures {
326 provider.is_healthy = false;
327 debug!("Provider {} marked as unhealthy after {} failures", url, self.max_failures);
328 }
329 }
330 }
331
332 pub async fn mark_provider_success(&self, url: &str, response_time_ms: u64) {
334 let mut providers = self.providers.write().await;
335
336 if let Some(provider) = providers.iter_mut().find(|p| p.url == url) {
337 provider.consecutive_failures = 0;
338 provider.is_healthy = true;
339 provider.response_time_ms = Some(response_time_ms);
340 provider.last_health_check = Some(Instant::now());
341
342 debug!("Provider {} successful ({}ms)", url, response_time_ms);
343 }
344 }
345
346 pub async fn health_check_all(&self) {
348 let providers_snapshot = {
349 let providers = self.providers.read().await;
350 providers.clone()
351 };
352
353 for provider in providers_snapshot {
354 let needs_check = !provider.is_healthy
356 || provider.last_health_check.is_none_or(|t| t.elapsed() > Duration::from_secs(60));
357
358 if needs_check {
359 match Self::check_provider_health(&self.client, &provider.url).await {
360 Ok(response_time) => {
361 self.mark_provider_success(&provider.url, response_time).await;
362 if !provider.is_healthy {
363 debug!("Provider {} is now healthy", provider.url);
364 }
365 }
366 Err(e) => {
367 debug!("Health check failed for {}: {}", provider.url, e);
368 self.mark_provider_failed(&provider.url).await;
369 }
370 }
371 }
372 }
373 }
374
375 pub async fn get_providers_info(&self) -> Vec<ProviderInfoResponse> {
377 let providers = self.providers.read().await;
378 providers.iter().map(|p| p.into()).collect()
379 }
380
381 pub async fn healthy_provider_count(&self) -> usize {
383 let providers = self.providers.read().await;
384 providers.iter().filter(|p| p.is_healthy).count()
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use tracing::{debug, info};
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// A `POST /` mock that answers with a successful `eth_blockNumber` reply.
    fn block_number_mock() -> Mock {
        Mock::given(method("POST")).and(path("/")).respond_with(
            ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "jsonrpc": "2.0",
                "id": 1,
                "result": "0x1234567"
            })),
        )
    }

    /// Starts `count` mock servers, each serving `block_number_mock`.
    async fn healthy_servers(count: usize) -> Vec<MockServer> {
        let mut servers = Vec::with_capacity(count);
        for _ in 0..count {
            let server = MockServer::start().await;
            block_number_mock().mount(&server).await;
            servers.push(server);
        }
        servers
    }

    #[tokio::test]
    async fn test_provider_initialization() {
        edb_common::logging::ensure_test_logging(None);
        info!("Testing provider initialization with health checks");

        let healthy = MockServer::start().await;
        let failing = MockServer::start().await;

        block_number_mock().mount(&healthy).await;
        Mock::given(method("POST"))
            .and(path("/"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&failing)
            .await;

        let manager = ProviderManager::new(vec![healthy.uri(), failing.uri()], 3).await.unwrap();
        assert_eq!(manager.healthy_provider_count().await, 1);

        let info = manager.get_providers_info().await;
        assert_eq!(info.len(), 2);
        assert!(info[0].is_healthy);
        assert!(!info[1].is_healthy);
    }

    #[tokio::test]
    async fn test_round_robin_selection() {
        edb_common::logging::ensure_test_logging(None);
        info!("Testing round-robin provider selection");

        let servers = healthy_servers(3).await;
        let urls: Vec<String> = servers.iter().map(MockServer::uri).collect();
        let manager = ProviderManager::new(urls.clone(), 3).await.unwrap();

        // Three full rotations over three providers.
        let mut picks = Vec::with_capacity(9);
        for _ in 0..9 {
            picks.push(manager.get_next_provider().await.unwrap());
        }

        for url in &urls {
            let hits = picks.iter().filter(|pick| *pick == url).count();
            assert_eq!(hits, 3);
        }
    }

    #[tokio::test]
    async fn test_provider_failure_handling() {
        edb_common::logging::ensure_test_logging(None);
        debug!("Testing provider failure detection and handling");

        let server = MockServer::start().await;
        // Exactly one request is expected: the probe issued by `ProviderManager::new`.
        block_number_mock().expect(1).mount(&server).await;

        let url = server.uri();
        let manager = ProviderManager::new(vec![url.clone()], 2).await.unwrap();
        assert_eq!(manager.healthy_provider_count().await, 1);

        // One failure is below the threshold of 2...
        manager.mark_provider_failed(&url).await;
        assert_eq!(manager.healthy_provider_count().await, 1);

        // ...the second one reaches it.
        manager.mark_provider_failed(&url).await;
        assert_eq!(manager.healthy_provider_count().await, 0);
    }

    #[tokio::test]
    async fn test_weighted_provider_selection() {
        edb_common::logging::ensure_test_logging(None);
        debug!("Testing weighted provider selection based on response time");

        let servers = healthy_servers(3).await;
        let urls: Vec<String> = servers.iter().map(MockServer::uri).collect();
        let manager = ProviderManager::new(urls.clone(), 3).await.unwrap();

        // Pin the latencies into performance tiers 1, 2 and 3 respectively.
        for (url, latency_ms) in urls.iter().zip([50, 250, 500]) {
            manager.mark_provider_success(url, latency_ms).await;
        }

        let mut selections = std::collections::HashMap::new();
        for _ in 0..100 {
            if let Some(provider) = manager.get_weighted_provider().await {
                *selections.entry(provider).or_insert(0) += 1;
            }
        }

        assert_eq!(selections.len(), 3);

        let count_for = |url: &String| selections.get(url).copied().unwrap_or(0);
        let fast_count = count_for(&urls[0]);
        let medium_count = count_for(&urls[1]);
        let slow_count = count_for(&urls[2]);

        debug!(
            "Selection counts - Fast: {}, Medium: {}, Slow: {}",
            fast_count, medium_count, slow_count
        );

        assert!(
            fast_count > slow_count,
            "Fast provider should be selected more often than slow provider"
        );
    }
}