1use std::{fmt::Debug, future::Future, marker::PhantomData, sync::Arc, time::Duration};
2
3use alloy::{
4 eips::{BlockId, BlockNumberOrTag},
5 network::{Ethereum, Network},
6 providers::{
7 DynProvider, Provider, RootProvider,
8 fillers::{FillProvider, TxFiller},
9 layers::{CacheProvider, CallBatchProvider},
10 },
11 pubsub::Subscription,
12 rpc::types::{Filter, Log},
13 transports::{RpcError, TransportErrorKind, http::reqwest::Url},
14};
15use backon::{ExponentialBuilder, Retryable};
16use thiserror::Error;
17use tokio::time::{error as TokioError, timeout};
18use tracing::{error, info};
19
20#[derive(Error, Debug, Clone)]
21pub enum Error {
22 #[error("Operation timed out")]
23 Timeout,
24 #[error("RPC call failed after exhausting all retry attempts: {0}")]
25 RpcError(Arc<RpcError<TransportErrorKind>>),
26 #[error("Block not found, Block Id: {0}")]
27 BlockNotFound(BlockId),
28}
29
30impl From<RpcError<TransportErrorKind>> for Error {
31 fn from(err: RpcError<TransportErrorKind>) -> Self {
32 Error::RpcError(Arc::new(err))
33 }
34}
35
36impl From<TokioError::Elapsed> for Error {
37 fn from(_: TokioError::Elapsed) -> Self {
38 Error::Timeout
39 }
40}
41
42pub trait IntoProvider<N: Network = Ethereum> {
43 fn into_provider(
44 self,
45 ) -> impl std::future::Future<Output = Result<impl Provider<N>, Error>> + Send;
46}
47
48impl<N: Network> IntoProvider<N> for RobustProvider<N> {
49 async fn into_provider(self) -> Result<impl Provider<N>, Error> {
50 Ok(self.primary().to_owned())
51 }
52}
53
54impl<N: Network> IntoProvider<N> for RootProvider<N> {
55 async fn into_provider(self) -> Result<impl Provider<N>, Error> {
56 Ok(self)
57 }
58}
59
60impl<N: Network> IntoProvider<N> for &str {
61 async fn into_provider(self) -> Result<impl Provider<N>, Error> {
62 Ok(RootProvider::connect(self).await?)
63 }
64}
65
66impl<N: Network> IntoProvider<N> for Url {
67 async fn into_provider(self) -> Result<impl Provider<N>, Error> {
68 Ok(RootProvider::connect(self.as_str()).await?)
69 }
70}
71
72impl<F, P, N> IntoProvider<N> for FillProvider<F, P, N>
73where
74 F: TxFiller<N>,
75 P: Provider<N>,
76 N: Network,
77{
78 async fn into_provider(self) -> Result<impl Provider<N>, Error> {
79 Ok(self)
80 }
81}
82
83impl<P, N> IntoProvider<N> for CacheProvider<P, N>
84where
85 P: Provider<N>,
86 N: Network,
87{
88 async fn into_provider(self) -> Result<impl Provider<N>, Error> {
89 Ok(self)
90 }
91}
92
93impl<N> IntoProvider<N> for DynProvider<N>
94where
95 N: Network,
96{
97 async fn into_provider(self) -> Result<impl Provider<N>, Error> {
98 Ok(self)
99 }
100}
101
102impl<P, N> IntoProvider<N> for CallBatchProvider<P, N>
103where
104 P: Provider<N> + 'static,
105 N: Network,
106{
107 async fn into_provider(self) -> Result<impl Provider<N>, Error> {
108 Ok(self)
109 }
110}
111
112pub trait IntoRobustProvider<N: Network = Ethereum> {
113 fn into_robust_provider(
114 self,
115 ) -> impl std::future::Future<Output = Result<RobustProvider<N>, Error>> + Send;
116}
117
118impl<N: Network, P: IntoProvider<N> + Send> IntoRobustProvider<N> for P {
119 async fn into_robust_provider(self) -> Result<RobustProvider<N>, Error> {
120 RobustProviderBuilder::new(self).build().await
121 }
122}
123
/// Timeout that [`RobustProviderBuilder::new`] starts with: 60 seconds.
pub const DEFAULT_MAX_TIMEOUT: Duration = Duration::from_secs(60);
/// Retry count that [`RobustProviderBuilder::new`] starts with: 3.
pub const DEFAULT_MAX_RETRIES: usize = 3;
/// Minimum retry delay that [`RobustProviderBuilder::new`] starts with: 1 second.
pub const DEFAULT_MIN_DELAY: Duration = Duration::from_secs(1);
131
132#[derive(Clone)]
133pub struct RobustProviderBuilder<N: Network, P: IntoProvider<N>> {
134 providers: Vec<P>,
135 max_timeout: Duration,
136 max_retries: usize,
137 min_delay: Duration,
138 _network: PhantomData<N>,
139}
140
141impl<N: Network, P: IntoProvider<N>> RobustProviderBuilder<N, P> {
142 #[must_use]
146 pub fn new(provider: P) -> Self {
147 Self {
148 providers: vec![provider],
149 max_timeout: DEFAULT_MAX_TIMEOUT,
150 max_retries: DEFAULT_MAX_RETRIES,
151 min_delay: DEFAULT_MIN_DELAY,
152 _network: PhantomData,
153 }
154 }
155
156 #[must_use]
160 pub fn fragile(provider: P) -> Self {
161 Self::new(provider).max_retries(0).min_delay(Duration::ZERO)
162 }
163
164 #[must_use]
168 pub fn fallback(mut self, provider: P) -> Self {
169 self.providers.push(provider);
170 self
171 }
172
173 #[must_use]
175 pub fn max_timeout(mut self, timeout: Duration) -> Self {
176 self.max_timeout = timeout;
177 self
178 }
179
180 #[must_use]
182 pub fn max_retries(mut self, max_retries: usize) -> Self {
183 self.max_retries = max_retries;
184 self
185 }
186
187 #[must_use]
189 pub fn min_delay(mut self, min_delay: Duration) -> Self {
190 self.min_delay = min_delay;
191 self
192 }
193
194 pub async fn build(self) -> Result<RobustProvider<N>, Error> {
202 let mut providers = vec![];
203 for p in self.providers {
204 providers.push(p.into_provider().await?.root().to_owned());
205 }
206 Ok(RobustProvider {
207 providers,
208 max_timeout: self.max_timeout,
209 max_retries: self.max_retries,
210 min_delay: self.min_delay,
211 })
212 }
213}
214
215#[derive(Clone)]
221pub struct RobustProvider<N: Network = Ethereum> {
222 providers: Vec<RootProvider<N>>,
223 max_timeout: Duration,
224 max_retries: usize,
225 min_delay: Duration,
226}
227
228impl<N: Network> RobustProvider<N> {
229 #[must_use]
235 pub fn primary(&self) -> &RootProvider<N> {
236 self.providers.first().expect("providers vector should never be empty")
238 }
239
240 pub async fn get_block_by_number(
248 &self,
249 number: BlockNumberOrTag,
250 ) -> Result<N::BlockResponse, Error> {
251 info!("eth_getBlockByNumber called");
252 let result = self
253 .retry_with_total_timeout(
254 move |provider| async move { provider.get_block_by_number(number).await },
255 false,
256 )
257 .await;
258 if let Err(e) = &result {
259 error!(error = %e, "eth_getByBlockNumber failed");
260 }
261
262 result?.ok_or_else(|| Error::BlockNotFound(number.into()))
263 }
264
265 pub async fn get_block_number(&self) -> Result<u64, Error> {
273 info!("eth_getBlockNumber called");
274 let result = self
275 .retry_with_total_timeout(
276 move |provider| async move { provider.get_block_number().await },
277 false,
278 )
279 .await;
280 if let Err(e) = &result {
281 error!(error = %e, "eth_getBlockNumber failed");
282 }
283 result
284 }
285
286 pub async fn get_block_by_hash(
294 &self,
295 hash: alloy::primitives::BlockHash,
296 ) -> Result<N::BlockResponse, Error> {
297 info!("eth_getBlockByHash called");
298 let result = self
299 .retry_with_total_timeout(
300 move |provider| async move { provider.get_block_by_hash(hash).await },
301 false,
302 )
303 .await;
304 if let Err(e) = &result {
305 error!(error = %e, "eth_getBlockByHash failed");
306 }
307
308 result?.ok_or_else(|| Error::BlockNotFound(hash.into()))
309 }
310
311 pub async fn get_logs(&self, filter: &Filter) -> Result<Vec<Log>, Error> {
319 info!("eth_getLogs called");
320 let result = self
321 .retry_with_total_timeout(
322 move |provider| async move { provider.get_logs(filter).await },
323 false,
324 )
325 .await;
326 if let Err(e) = &result {
327 error!(error = %e, "eth_getLogs failed");
328 }
329 result
330 }
331
332 pub async fn subscribe_blocks(&self) -> Result<Subscription<N::HeaderResponse>, Error> {
341 info!("eth_subscribe called");
342 let result = self
343 .retry_with_total_timeout(
344 move |provider| async move { provider.subscribe_blocks().await },
345 true,
346 )
347 .await;
348 if let Err(e) = &result {
349 error!(error = %e, "eth_subscribe failed");
350 }
351 result
352 }
353
354 async fn retry_with_total_timeout<T: Debug, F, Fut>(
374 &self,
375 operation: F,
376 require_pubsub: bool,
377 ) -> Result<T, Error>
378 where
379 F: Fn(RootProvider<N>) -> Fut,
380 Fut: Future<Output = Result<T, RpcError<TransportErrorKind>>>,
381 {
382 let mut providers = self.providers.iter();
383 let primary = providers.next().expect("should have primary provider");
384
385 let result = self.try_provider_with_timeout(primary, &operation).await;
386
387 if result.is_ok() {
388 return result;
389 }
390
391 let mut last_error = result.unwrap_err();
392
393 let num_providers = self.providers.len();
394 if num_providers > 1 {
395 info!("Primary provider failed, trying fallback provider(s)");
396 }
397
398 for (idx, provider) in providers.enumerate() {
400 let fallback_num = idx + 1;
401 if require_pubsub && !Self::supports_pubsub(provider) {
402 info!("Fallback provider {} doesn't support pubsub, skipping", fallback_num);
403 continue;
404 }
405 info!("Attempting fallback provider {}/{}", fallback_num, num_providers - 1);
406
407 match self.try_provider_with_timeout(provider, &operation).await {
408 Ok(value) => {
409 info!(provider_num = fallback_num, "Fallback provider succeeded");
410 return Ok(value);
411 }
412 Err(e) => {
413 error!(provider_num = fallback_num, err = %e, "Fallback provider failed");
414 last_error = e;
415 }
416 }
417 }
418
419 error!("All providers failed or timed out - returning the last providers attempt's error");
421 Err(last_error)
422 }
423
424 async fn try_provider_with_timeout<T, F, Fut>(
426 &self,
427 provider: &RootProvider<N>,
428 operation: F,
429 ) -> Result<T, Error>
430 where
431 F: Fn(RootProvider<N>) -> Fut,
432 Fut: Future<Output = Result<T, RpcError<TransportErrorKind>>>,
433 {
434 let retry_strategy = ExponentialBuilder::default()
435 .with_max_times(self.max_retries)
436 .with_min_delay(self.min_delay);
437
438 timeout(
439 self.max_timeout,
440 (|| operation(provider.clone()))
441 .retry(retry_strategy)
442 .notify(|err: &RpcError<TransportErrorKind>, dur: Duration| {
443 info!(error = %err, "RPC error retrying after {:?}", dur);
444 })
445 .sleep(tokio::time::sleep),
446 )
447 .await
448 .map_err(Error::from)?
449 .map_err(Error::from)
450 }
451
452 fn supports_pubsub(provider: &RootProvider<N>) -> bool {
454 provider.client().pubsub_frontend().is_some()
455 }
456}
457
#[cfg(test)]
mod tests {
    use super::*;
    use alloy::{
        consensus::BlockHeader,
        providers::{ProviderBuilder, WsConnect, ext::AnvilApi},
    };
    use alloy_node_bindings::Anvil;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::time::sleep;

    /// Builds a single-provider [`RobustProvider`] pointing at a local HTTP endpoint.
    /// `timeout` and `min_delay` are given in milliseconds.
    fn test_provider(timeout: u64, max_retries: usize, min_delay: u64) -> RobustProvider {
        let max_timeout = Duration::from_millis(timeout);
        let min_delay = Duration::from_millis(min_delay);
        RobustProvider {
            providers: vec![RootProvider::new_http("http://localhost:8545".parse().unwrap())],
            max_timeout,
            max_retries,
            min_delay,
        }
    }

    /// An operation that succeeds immediately is invoked exactly once.
    #[tokio::test]
    async fn test_retry_with_timeout_succeeds_on_first_attempt() {
        let provider = test_provider(100, 3, 10);
        let attempts = AtomicUsize::new(0);

        let outcome = provider
            .retry_with_total_timeout(
                |_| async {
                    // `fetch_add` returns the previous value, so add one for this attempt.
                    let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
                    Ok(attempt)
                },
                false,
            )
            .await;

        assert!(matches!(outcome, Ok(1)));
    }

    /// An operation that fails twice and then succeeds is retried until the third attempt.
    #[tokio::test]
    async fn test_retry_with_timeout_retries_on_error() {
        let provider = test_provider(100, 3, 10);
        let attempts = AtomicUsize::new(0);

        let outcome = provider
            .retry_with_total_timeout(
                |_| async {
                    let attempt = attempts.fetch_add(1, Ordering::SeqCst) + 1;
                    match attempt {
                        3 => Ok(attempt),
                        _ => Err(TransportErrorKind::BackendGone.into()),
                    }
                },
                false,
            )
            .await;

        assert!(matches!(outcome, Ok(3)));
    }

    /// With `max_retries = 2` a permanently failing operation runs three times in total.
    #[tokio::test]
    async fn test_retry_with_timeout_fails_after_max_retries() {
        let provider = test_provider(100, 2, 10);
        let attempts = AtomicUsize::new(0);

        let outcome: Result<(), Error> = provider
            .retry_with_total_timeout(
                |_| async {
                    attempts.fetch_add(1, Ordering::SeqCst);
                    Err(TransportErrorKind::BackendGone.into())
                },
                false,
            )
            .await;

        assert!(matches!(outcome, Err(Error::RpcError(_))));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    /// An operation that outlasts `max_timeout` is reported as a timeout.
    #[tokio::test]
    async fn test_retry_with_timeout_respects_max_timeout() {
        let max_timeout = 50;
        let provider = test_provider(max_timeout, 10, 1);
        let overrun = Duration::from_millis(max_timeout + 10);

        let outcome = provider
            .retry_with_total_timeout(
                move |_provider| async move {
                    sleep(overrun).await;
                    Ok(42)
                },
                false,
            )
            .await;

        assert!(matches!(outcome, Err(Error::Timeout)));
    }

    /// When the primary's node is gone, the subscription is served by the fallback.
    #[tokio::test]
    async fn test_subscribe_fails_causes_backup_to_be_used() -> anyhow::Result<()> {
        let primary_node = Anvil::new().try_spawn()?;
        let primary_ws =
            ProviderBuilder::new().connect(primary_node.ws_endpoint_url().as_str()).await?;

        let backup_node = Anvil::new().try_spawn()?;
        let backup_ws =
            ProviderBuilder::new().connect(backup_node.ws_endpoint_url().as_str()).await?;

        let robust = RobustProviderBuilder::fragile(primary_ws.clone())
            .fallback(backup_ws.clone())
            .max_timeout(Duration::from_secs(1))
            .build()
            .await?;

        // Dropping the node handle takes the primary's backend away.
        drop(primary_node);

        let mut subscription = robust.subscribe_blocks().await?;

        backup_ws.anvil_mine(Some(2), None).await?;

        assert_eq!(1, subscription.recv().await?.number());
        assert_eq!(2, subscription.recv().await?.number());
        assert!(subscription.is_empty());

        Ok(())
    }

    /// With only HTTP providers, subscribing fails with `PubsubUnavailable`.
    #[tokio::test]
    async fn test_subscribe_fails_when_all_providers_lack_pubsub() -> anyhow::Result<()> {
        let node = Anvil::new().try_spawn()?;
        let http = ProviderBuilder::new().connect_http(node.endpoint_url());

        let robust = RobustProviderBuilder::new(http.clone())
            .fallback(http)
            .max_timeout(Duration::from_secs(5))
            .min_delay(Duration::from_millis(100))
            .build()
            .await?;

        let failure = robust.subscribe_blocks().await.unwrap_err();

        match failure {
            Error::RpcError(e) => {
                assert!(matches!(
                    e.as_ref(),
                    RpcError::Transport(TransportErrorKind::PubsubUnavailable)
                ));
            }
            other => panic!("Expected PubsubUnavailable error type, got: {other:?}"),
        }

        Ok(())
    }

    /// An HTTP primary cannot subscribe, but a WebSocket fallback can.
    #[tokio::test]
    async fn test_subscribe_succeeds_if_primary_provider_lacks_pubsub_but_fallback_supports_it()
    -> anyhow::Result<()> {
        let node = Anvil::new().try_spawn()?;

        let http = ProviderBuilder::new().connect_http(node.endpoint_url());
        let ws = ProviderBuilder::new()
            .connect_ws(WsConnect::new(node.ws_endpoint_url().as_str()))
            .await?;

        let robust = RobustProviderBuilder::fragile(http)
            .fallback(ws)
            .max_timeout(Duration::from_secs(5))
            .build()
            .await?;

        let outcome = robust.subscribe_blocks().await;
        assert!(outcome.is_ok());

        Ok(())
    }

    /// A dead WebSocket primary plus an HTTP-only fallback surfaces the primary's error,
    /// because the fallback is skipped for lacking pubsub.
    #[tokio::test]
    async fn test_ws_fails_http_fallback_returns_primary_error() -> anyhow::Result<()> {
        let ws_node = Anvil::new().try_spawn()?;
        let ws = ProviderBuilder::new().connect(ws_node.ws_endpoint_url().as_str()).await?;

        let http_node = Anvil::new().try_spawn()?;
        let http = ProviderBuilder::new().connect_http(http_node.endpoint_url());

        let robust = RobustProviderBuilder::fragile(ws.clone())
            .fallback(http)
            .max_timeout(Duration::from_millis(500))
            .build()
            .await?;

        // Dropping the node handle takes the primary's backend away.
        drop(ws_node);

        let failure = robust.subscribe_blocks().await.unwrap_err();

        match failure {
            // Either a timeout or a gone backend is accepted here.
            Error::Timeout => {}
            Error::RpcError(e) => {
                assert!(matches!(
                    e.as_ref(),
                    RpcError::Transport(TransportErrorKind::BackendGone)
                ));
            }
            Error::BlockNotFound(id) => panic!("Unexpected error type: BlockNotFound({id})"),
        }

        Ok(())
    }
}