use std::{borrow::Cow, fmt::Debug, future::Future, time::Duration};
use backon::{ExponentialBuilder, Retryable};
use serde_json::value::RawValue;
use tokio::time::timeout;
use super::errors::{FailoverError, is_retryable_error};
use alloy::{
consensus::{BlockHeader, TrieAccount},
eips::{BlockId, BlockNumberOrTag, eip1559::Eip1559Estimation},
network::{BlockResponse, Ethereum, Network},
primitives::{
Address, B256, BlockHash, BlockNumber, Bytes, StorageKey, StorageValue, TxHash, U256,
},
providers::{
PendingTransactionBuilder, Provider, RootProvider,
utils::{
EIP1559_FEE_ESTIMATION_PAST_BLOCKS, EIP1559_FEE_ESTIMATION_REWARD_PERCENTILE,
Eip1559Estimator,
},
},
rpc::{
client::PollerBuilder,
json_rpc::{RpcRecv, RpcSend},
types::{
AccessListResult, AccountInfo, Bundle, EIP1186AccountProofResponse, EthCallResponse,
FeeHistory, FillTransaction, Filter, Log, SyncStatus,
erc4337::TransactionConditional,
simulate::{SimulatePayload, SimulatedBlock},
},
},
transports::{RpcError, TransportErrorKind},
};
use crate::{
Error, block_not_found_doc,
robust_provider::{RobustSubscription, subscription::SubscriptionBackend},
};
/// RPC provider wrapper that adds per-provider retries, a call timeout and failover.
///
/// Operations run against `primary_provider` first. If that fails (after its retries, or
/// because the call timeout elapsed), each entry of `fallback_providers` is tried in order.
#[derive(Clone, Debug)]
pub struct RobustProvider<N: Network = Ethereum> {
/// Provider that every operation tries first.
pub(crate) primary_provider: RootProvider<N>,
/// Providers tried in order after the primary provider fails.
pub(crate) fallback_providers: Vec<RootProvider<N>>,
/// Deadline for one provider attempt. It covers the whole retry loop for that provider, not
/// each individual retry.
pub(crate) call_timeout: Duration,
/// NOTE(review): not read in this file; presumably a timeout used by the subscription
/// machinery — confirm its exact semantics in the subscription module.
pub(crate) subscription_timeout: Duration,
/// Maximum number of retries per provider after the first attempt (so a provider is called at
/// most `max_retries + 1` times).
pub(crate) max_retries: usize,
/// Initial delay of the exponential backoff between retries.
pub(crate) min_delay: Duration,
/// NOTE(review): not read in this file; presumably how often a subscription tries to reconnect
/// to a preferred provider — confirm in the subscription module.
pub(crate) reconnect_interval: Duration,
/// Channel capacity used for block subscriptions and HTTP block pollers.
pub(crate) subscription_buffer_capacity: usize,
/// Poll interval used when block subscriptions are served by HTTP polling.
#[cfg(feature = "http-subscription")]
pub(crate) poll_interval: Duration,
/// Whether providers without a pubsub frontend may serve block subscriptions via HTTP polling.
#[cfg(feature = "http-subscription")]
pub(crate) allow_http_subscriptions: bool,
}
impl<N: Network> RobustProvider<N> {
/// Returns the primary provider, which every operation tries first.
#[must_use]
pub fn primary(&self) -> &RootProvider<N> {
    &self.primary_provider
}

/// Returns the fallback providers, in the order they are tried once the primary fails.
#[must_use]
pub fn fallback_providers(&self) -> &[RootProvider<N>] {
    self.fallback_providers.as_slice()
}

/// Returns the deadline applied to one provider attempt, including all of its retries.
#[must_use]
pub fn call_timeout(&self) -> Duration {
    self.call_timeout
}

/// Returns how many times a failed call is retried on one provider before failing over.
#[must_use]
pub fn max_retries(&self) -> usize {
    self.max_retries
}

/// Returns the initial delay of the exponential retry backoff.
#[must_use]
pub fn min_delay(&self) -> Duration {
    self.min_delay
}
// ---- Accounts, chain identity, call execution and gas estimation ----
//
// NOTE(review): each `robust_rpc!` invocation presumably expands to a method of the same name
// that forwards to the alloy `Provider` method of that name through
// `try_operation_with_failover`, inheriting its retry/timeout/fallback behaviour. `doc_args`
// appears to supply per-argument documentation text. Confirm both against the `robust_rpc!`
// definition.
//
// NOTE(review): `@clone [..]` appears to list owned arguments that must be cloned for every
// attempt, because the operation closure can run several times (retries and fallbacks).
robust_rpc!(fn get_accounts() -> Vec<Address>);
robust_rpc!(fn get_blob_base_fee() -> u128);
robust_rpc!(
doc_args = [(tx, "The transaction request to simulate.")]
@clone [tx]
fn call(tx: N::TransactionRequest) -> Bytes
);
robust_rpc!(
doc_args = [(bundles, "A slice of transaction bundles to execute.")]
fn call_many(bundles: &[Bundle]) -> Vec<Vec<EthCallResponse>>
);
robust_rpc!(fn get_chain_id() -> u64);
robust_rpc!(fn get_net_version() -> u64);
robust_rpc!(
doc_args = [(data, "The data to hash.")]
fn get_sha3(data: &[u8]) -> B256
);
robust_rpc!(
doc_args = [(request, "The transaction request to create an access list for.")]
fn create_access_list(request: &N::TransactionRequest) -> AccessListResult
);
robust_rpc!(
doc_args = [(tx, "The transaction request to estimate gas for.")]
@clone [tx]
fn estimate_gas(tx: N::TransactionRequest) -> u64
);
// Default-estimator EIP-1559 fee estimation; see also `estimate_eip1559_fees_with` below,
// which takes a custom estimator.
robust_rpc!(
fn estimate_eip1559_fees() -> Eip1559Estimation
);
// ---- Fee market and account state ----
robust_rpc!(
doc_args = [
(block_count, "The number of blocks to include in the fee history."),
(last_block, "The last block to include in the fee history."),
(reward_percentiles, "A list of percentiles to compute reward values for.")
]
fn get_fee_history(block_count: u64, last_block: BlockNumberOrTag, reward_percentiles: &[f64]) -> FeeHistory
);
robust_rpc!(fn get_gas_price() -> u128);
robust_rpc!(fn get_max_priority_fee_per_gas() -> u128);
robust_rpc!(
doc_args = [(address, "The address for which to get the account info.")]
fn get_account_info(address: Address) -> AccountInfo
);
robust_rpc!(
doc_args = [(address, "The address to get the account for.")]
fn get_account(address: Address) -> TrieAccount
);
robust_rpc!(
doc_args = [(address, "The address to get the balance for.")]
fn get_balance(address: Address) -> U256
);
// ---- Block queries ----
//
// NOTE(review): the `; or BlockNotFound` suffix together with
// `doc_include_error = [block_not_found_doc!()]` presumably makes the generated method return
// `Error::BlockNotFound` when the node answers with no result. `estimate_eip1559_fees_with`
// below matches on that variant after calling `get_block_by_number`. Confirm the exact mapping
// against the `robust_rpc!` definition.
robust_rpc!(
doc_include_error = [block_not_found_doc!()]
doc_args = [(number, "The block number or tag.")]
fn get_block_by_number(number: BlockNumberOrTag) -> N::BlockResponse; or BlockNotFound
);
robust_rpc!(
doc_include_error = [block_not_found_doc!()]
doc_args = [(id, "The block identifier.")]
fn get_block(id: BlockId) -> N::BlockResponse; or BlockNotFound
);
robust_rpc!(fn get_block_number() -> BlockNumber);
robust_rpc!(
doc_include_error = [block_not_found_doc!()]
doc_args = [(block_id, "The block identifier to fetch the block number for.")]
fn get_block_number_by_id(block_id: BlockId) -> BlockNumber; or BlockNotFound
);
robust_rpc!(
doc_include_error = [block_not_found_doc!()]
doc_args = [(hash, "The block hash.")]
fn get_block_by_hash(hash: BlockHash) -> N::BlockResponse; or BlockNotFound
);
robust_rpc!(
doc_include_error = [block_not_found_doc!()]
doc_args = [(block, "The block identifier (hash, number, or tag).")]
fn get_block_receipts(block: BlockId) -> Vec<N::ReceiptResponse>; or BlockNotFound
);
robust_rpc!(
doc_include_error = [block_not_found_doc!()]
doc_args = [(hash, "The block hash.")]
fn get_block_transaction_count_by_hash(hash: BlockHash) -> u64; or BlockNotFound
);
robust_rpc!(
doc_include_error = [block_not_found_doc!()]
doc_args = [(block_number, "The block number or tag.")]
fn get_block_transaction_count_by_number(block_number: BlockNumberOrTag) -> u64; or BlockNotFound
);
// ---- Logs, code and filters ----
//
// NOTE(review): filter IDs are created on a specific node. If a filter is created on one
// provider and a later `get_filter_*` / `uninstall_filter` call fails over to another, the ID
// may be unknown there. Confirm whether callers need to pin filters to one provider.
robust_rpc!(
doc_args = [(filter, "The log filter.")]
fn get_logs(filter: &Filter) -> Vec<Log>
);
robust_rpc!(
doc_args = [(address, "The address to get the code for.")]
fn get_code_at(address: Address) -> Bytes
);
robust_rpc!(
doc_args = [(filter_id, "The filter ID to fetch logs for.")]
fn get_filter_logs(filter_id: U256) -> Vec<Log>
);
// Generic over the change item type `R` (anything implementing `RpcRecv`).
robust_rpc!(
doc_args = [(filter_id, "The filter ID to get changes for.")]
fn get_filter_changes<R: RpcRecv>(filter_id: U256) -> Vec<R>
);
robust_rpc!(
doc_args = [(filter, "The filter to create.")]
fn new_filter(filter: &Filter) -> U256
);
robust_rpc!(fn new_block_filter() -> U256);
robust_rpc!(
doc_args = [(full, "Whether to include full transaction objects.")]
fn new_pending_transactions_filter(full: bool) -> U256
);
// ---- Signing, filling, proofs and storage ----
robust_rpc!(
doc_args = [
(tx, "The transaction request to sign.")
]
@clone [tx]
fn sign_transaction(tx: N::TransactionRequest) -> Bytes
);
// Only available when the network's transaction envelope can be received over RPC.
robust_rpc!(
doc_args = [
(tx, "The transaction request to fill.")
]
@clone [tx]
fn fill_transaction(tx: N::TransactionRequest) -> FillTransaction<N::TxEnvelope>
where N::TxEnvelope: RpcRecv
);
robust_rpc!(
doc_args = [
(address, "The address of the account."),
(keys, "A vector of storage keys to include in the proof.")
]
@clone [keys]
fn get_proof(address: Address, keys: Vec<StorageKey>) -> EIP1186AccountProofResponse
);
robust_rpc!(
doc_args = [
(address, "The address of the storage."),
(key, "The position in the storage.")
]
fn get_storage_at(address: Address, key: U256) -> StorageValue
);
// ---- Transaction, receipt and uncle lookups; simulation ----
//
// These return `Option<..>` rather than an error when the node has no matching item.
robust_rpc!(
doc_args = [
(block_hash, "The hash of the block."),
(index, "The transaction index position.")
]
fn get_transaction_by_block_hash_and_index(block_hash: B256, index: usize) -> Option<N::TransactionResponse>
);
robust_rpc!(
doc_args = [
(block_number, "The block number or tag."),
(index, "The transaction index position.")
]
fn get_transaction_by_block_number_and_index(block_number: BlockNumberOrTag, index: usize) -> Option<N::TransactionResponse>
);
robust_rpc!(
doc_args = [(hash, "The transaction hash.")]
fn get_transaction_by_hash(hash: TxHash) -> Option<N::TransactionResponse>
);
robust_rpc!(
doc_args = [(hash, "The transaction hash.")]
fn get_raw_transaction_by_hash(hash: TxHash) -> Option<Bytes>
);
robust_rpc!(
doc_args = [(address, "The address to get the transaction count for.")]
fn get_transaction_count(address: Address) -> u64
);
robust_rpc!(
doc_args = [(hash, "The transaction hash.")]
fn get_transaction_receipt(hash: TxHash) -> Option<N::ReceiptResponse>
);
robust_rpc!(
doc_args = [(block, "The block identifier (hash or number).")]
fn get_uncle_count(block: BlockId) -> u64
);
robust_rpc!(
doc_args = [
(tag, "The block identifier (hash or number)."),
(idx, "The uncle index position.")
]
fn get_uncle(tag: BlockId, idx: u64) -> Option<N::BlockResponse>
);
robust_rpc!(
doc_args = [(request, "The simulation request")]
fn simulate(request: &SimulatePayload) -> Vec<SimulatedBlock<N::BlockResponse>>
);
// ---- Node status, filter teardown and transaction submission ----
//
// NOTE(review): if the send methods below use the same retry/failover path as the read
// methods, a send whose response is lost (timeout or transport error) may be submitted again.
// For pre-signed raw bytes that re-broadcasts an identical transaction. For `send_transaction`
// / `send_transaction_sync`, which pass an unsigned request, confirm a node cannot sign and
// submit it twice with different nonces.
robust_rpc!(fn syncing() -> SyncStatus);
robust_rpc!(
doc_args = [(id, "The filter ID to uninstall.")]
fn uninstall_filter(id: U256) -> bool
);
robust_rpc!(
doc_args = [(encoded_tx, "The RLP-encoded signed transaction bytes")]
fn send_raw_transaction(encoded_tx: &[u8]) -> PendingTransactionBuilder<N>
);
robust_rpc!(
doc_args = [(encoded_tx, "The RLP-encoded signed transaction bytes")]
fn send_raw_transaction_sync(encoded_tx: &[u8]) -> N::ReceiptResponse
);
robust_rpc!(
doc_args = [
(encoded_tx, "The RLP-encoded signed transaction bytes"),
(conditional, "The transaction conditional to apply")
]
@clone [conditional]
fn send_raw_transaction_conditional(encoded_tx: &[u8], conditional: TransactionConditional) -> PendingTransactionBuilder<N>
);
robust_rpc!(
doc_args = [(tx, "The transaction request to send")]
@clone [tx]
fn send_transaction(tx: N::TransactionRequest) -> PendingTransactionBuilder<N>
);
robust_rpc!(
doc_args = [
(tx, "The signed transaction envelope to send.")
]
@clone [tx]
fn send_tx_envelope(tx: N::TxEnvelope) -> PendingTransactionBuilder<N>
where N::TxEnvelope: Clone
);
robust_rpc!(
doc_args = [(tx, "The transaction request to send synchronously")]
@clone [tx]
fn send_transaction_sync(tx: N::TransactionRequest) -> N::ReceiptResponse
);
robust_rpc!(
doc_args = [
(sender, "The sender address"),
(nonce, "The nonce of the transaction")
]
fn get_transaction_by_sender_nonce(sender: Address, nonce: u64) -> Option<N::TransactionResponse>
);
robust_rpc!(
doc_args = [
(block_hash, "The hash of the block"),
(index, "The transaction index position")
]
fn get_raw_transaction_by_block_hash_and_index(block_hash: B256, index: usize) -> Option<Bytes>
);
robust_rpc!(
doc_args = [
(block_number, "The block number or tag"),
(index, "The transaction index position")
]
fn get_raw_transaction_by_block_number_and_index(block_number: BlockNumberOrTag, index: usize) -> Option<Bytes>
);
// `doc_alias` presumably adds a rustdoc search alias for the JSON-RPC method name.
robust_rpc!(
doc_alias = "web3_client_version"
fn get_client_version() -> String
);
pub async fn estimate_eip1559_fees_with(
&self,
estimator: Eip1559Estimator,
) -> Result<Eip1559Estimation, Error> {
let fee_history = self
.get_fee_history(
EIP1559_FEE_ESTIMATION_PAST_BLOCKS,
BlockNumberOrTag::Latest,
&[EIP1559_FEE_ESTIMATION_REWARD_PERCENTILE],
)
.await?;
let base_fee_per_gas = match fee_history.latest_block_base_fee() {
Some(base_fee) if base_fee != 0 => base_fee,
_ => {
self.get_block_by_number(BlockNumberOrTag::Latest)
.await
.map_err(|e| {
if matches!(e, Error::BlockNotFound) {
RpcError::NullResp.into()
} else {
e
}
})?
.header()
.as_ref()
.base_fee_per_gas()
.ok_or(RpcError::UnsupportedFeature("eip1559"))?
.into()
}
};
Ok(estimator.estimate(base_fee_per_gas, &fee_history.reward.unwrap_or_default()))
}
// ---- Raw JSON-RPC escape hatches ----
//
// `raw_request` sends an arbitrary method with typed params/result. `raw_request_dyn` works on
// pre-serialized `RawValue` params and returns the raw result. The owned arguments are listed
// in `@clone`, presumably so they can be reused on every retry/failover attempt (`RpcSend`
// values are `Clone`).
robust_rpc!(
@clone [method, params]
fn raw_request<P, R>(method: Cow<'static, str>, params: P) -> R
where P: RpcSend, R: RpcRecv
);
robust_rpc!(
@clone [method]
fn raw_request_dyn(method: Cow<'static, str>, params: &RawValue) -> Box<RawValue>
);
/// Subscribes to new blocks, failing over across providers until one accepts the subscription.
///
/// Providers are tried through `try_operation_with_failover` (primary first, then fallbacks).
/// With the `http-subscription` feature enabled and `allow_http_subscriptions` set, a provider
/// that has no pubsub frontend is served by an HTTP poller (`watch_blocks`) configured with
/// `poll_interval` and `subscription_buffer_capacity`. Otherwise the provider's native
/// `subscribe_blocks` is used, with a channel of `subscription_buffer_capacity`.
///
/// # Errors
///
/// Returns the error from the last provider tried when no provider could be subscribed to, for
/// example `TransportErrorKind::PubsubUnavailable` when none of them supports pubsub.
pub async fn subscribe_blocks(&self) -> Result<RobustSubscription<N>, Error> {
    let backend: SubscriptionBackend<N> = self
        .try_operation_with_failover(move |provider| async move {
            #[cfg(feature = "http-subscription")]
            {
                let lacks_pubsub = provider.client().pubsub_frontend().is_none();
                if lacks_pubsub && self.allow_http_subscriptions {
                    let poller = provider.watch_blocks().await;
                    return poller.map(|builder| {
                        let configured = builder
                            .with_poll_interval(self.poll_interval)
                            .with_channel_size(self.subscription_buffer_capacity);
                        configured.into()
                    });
                }
            }
            let pending = provider
                .subscribe_blocks()
                .channel_size(self.subscription_buffer_capacity);
            pending.await.map(Into::<SubscriptionBackend<N>>::into)
        })
        .await?;

    Ok(RobustSubscription::new(backend, self.clone()))
}
// Returns a poller builder that yields batches of new block hashes. NOTE(review): presumably
// created on whichever provider first succeeds via failover — confirm against `robust_rpc!`.
robust_rpc!(fn watch_blocks() -> PollerBuilder<(U256,), Vec<BlockHash>>);
/// Runs `operation` on the primary provider, then on each fallback provider until one succeeds.
///
/// Every provider attempt applies the configured retry policy and call timeout, so a provider
/// is only abandoned after its retries are exhausted or its timeout elapses.
///
/// # Errors
///
/// Returns the error from the last provider attempted when every provider fails.
pub async fn try_operation_with_failover<T: Debug, F, Fut>(
    &self,
    operation: F,
) -> Result<T, FailoverError>
where
    F: Fn(RootProvider<N>) -> Fut,
    Fut: Future<Output = Result<T, RpcError<TransportErrorKind>>>,
{
    let primary_error = match self.try_provider_with_timeout(self.primary(), &operation).await {
        Ok(value) => return Ok(value),
        Err(err) => err,
    };

    // Start from the first fallback; the index of the one that succeeded is not needed here.
    let (value, _fallback_idx) =
        self.try_fallback_providers_from(&operation, primary_error, 0).await?;
    Ok(value)
}
pub(crate) async fn try_fallback_providers_from<T: Debug, F, Fut>(
&self,
operation: F,
mut last_error: FailoverError,
start_index: usize,
) -> Result<(T, usize), FailoverError>
where
F: Fn(RootProvider<N>) -> Fut,
Fut: Future<Output = Result<T, RpcError<TransportErrorKind>>>,
{
let fallback_providers = self.fallback_providers();
debug!(
start_index = start_index,
total_fallbacks = fallback_providers.len(),
"Primary provider failed, attempting fallback providers"
);
let fallback_iter = fallback_providers.iter().enumerate().skip(start_index);
for (fallback_idx, provider) in fallback_iter {
trace!(
fallback_index = fallback_idx,
total_fallbacks = fallback_providers.len(),
"Attempting fallback provider"
);
match self.try_provider_with_timeout(provider, &operation).await {
Ok(value) => {
info!(
fallback_index = fallback_idx,
total_fallbacks = fallback_providers.len(),
"Switched to fallback provider"
);
return Ok((value, fallback_idx));
}
Err(e) => {
warn!(
fallback_index = fallback_idx,
error = %e,
"Fallback provider failed"
);
last_error = e;
}
}
}
error!(attempted_providers = fallback_providers.len() + 1, "All providers exhausted");
Err(last_error)
}
/// Runs `operation` against one provider with exponential-backoff retries, bounded by the
/// call timeout.
///
/// A failed attempt is retried up to `max_retries` times, starting from a `min_delay` backoff.
/// The `call_timeout` bounds the entire retry loop, not each individual attempt.
///
/// An error is considered retryable when it is:
/// * an error response that `is_retry_err` flags, or that `is_retryable_error` accepts by
///   code and message;
/// * a transport error whose `is_retry_err` returns `true`;
/// * any other `RpcError` variant.
///
/// # Errors
///
/// Returns a `FailoverError` converted from the elapsed timeout when `call_timeout` passes.
/// Otherwise it is converted from the last `RpcError`, once retries are exhausted or a
/// non-retryable error occurs.
pub(crate) fn try_provider_with_timeout<T, F, Fut>(
    &self,
    provider: &RootProvider<N>,
    operation: F,
) -> impl Future<Output = Result<T, FailoverError>>
where
    F: Fn(RootProvider<N>) -> Fut,
    Fut: Future<Output = Result<T, RpcError<TransportErrorKind>>>,
{
    let backoff = ExponentialBuilder::default()
        .with_max_times(self.max_retries())
        .with_min_delay(self.min_delay());
    let deadline = self.call_timeout();
    let target = provider.clone();

    async move {
        // Each attempt receives a fresh handle to the same provider.
        let attempt = || operation(target.clone());
        let retrying = attempt
            .retry(backoff)
            .when(|err| match err {
                RpcError::ErrorResp(payload) => {
                    payload.is_retry_err()
                        || is_retryable_error(payload.code, payload.message.as_ref())
                }
                RpcError::Transport(kind) => kind.is_retry_err(),
                _ => true,
            })
            .sleep(tokio::time::sleep);

        let outcome = timeout(deadline, retrying).await.map_err(FailoverError::from)?;
        outcome.map_err(FailoverError::from)
    }
}
}
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};

    use alloy::{
        node_bindings::Anvil,
        providers::{ProviderBuilder, WsConnect},
    };
    use tokio::time::sleep;

    use crate::{
        DEFAULT_RECONNECT_INTERVAL, DEFAULT_SUBSCRIPTION_BUFFER_CAPACITY,
        DEFAULT_SUBSCRIPTION_TIMEOUT, Error, RobustProvider, RobustProviderBuilder,
    };

    use super::*;

    /// Builds a provider with no fallbacks whose primary points at an HTTP URL that the
    /// operations under test never contact (they ignore the provider argument).
    fn test_provider(timeout: u64, max_retries: usize, min_delay: u64) -> RobustProvider {
        RobustProvider {
            primary_provider: RootProvider::new_http("http://localhost:8545".parse().unwrap()),
            fallback_providers: Vec::new(),
            call_timeout: Duration::from_millis(timeout),
            subscription_timeout: DEFAULT_SUBSCRIPTION_TIMEOUT,
            max_retries,
            min_delay: Duration::from_millis(min_delay),
            reconnect_interval: DEFAULT_RECONNECT_INTERVAL,
            subscription_buffer_capacity: DEFAULT_SUBSCRIPTION_BUFFER_CAPACITY,
            #[cfg(feature = "http-subscription")]
            poll_interval: crate::DEFAULT_POLL_INTERVAL,
            #[cfg(feature = "http-subscription")]
            allow_http_subscriptions: false,
        }
    }

    #[tokio::test]
    async fn test_retry_with_timeout_succeeds_on_first_attempt() {
        let provider = test_provider(100, 3, 10);
        let calls = AtomicUsize::new(0);

        let result = provider
            .try_operation_with_failover(|_| async {
                let attempt = calls.fetch_add(1, Ordering::SeqCst) + 1;
                Ok(attempt)
            })
            .await;

        assert!(matches!(result, Ok(1)));
    }

    #[tokio::test]
    async fn test_retry_with_timeout_retries_on_error() {
        let provider = test_provider(100, 3, 10);
        let calls = AtomicUsize::new(0);

        let result = provider
            .try_operation_with_failover(|_| async {
                let attempt = calls.fetch_add(1, Ordering::SeqCst) + 1;
                if attempt == 3 {
                    Ok(attempt)
                } else {
                    Err(TransportErrorKind::custom_str("429 Too Many Requests"))
                }
            })
            .await;

        assert!(matches!(result, Ok(3)));
    }

    #[tokio::test]
    async fn test_retry_with_timeout_fails_after_max_retries() {
        let provider = test_provider(100, 2, 10);
        let calls = AtomicUsize::new(0);

        let result: Result<(), FailoverError> = provider
            .try_operation_with_failover(|_| async {
                calls.fetch_add(1, Ordering::SeqCst);
                Err(TransportErrorKind::custom_str("429 Too Many Requests"))
            })
            .await;

        assert!(matches!(result, Err(FailoverError::RpcError(_))));
        // One initial attempt plus `max_retries` retries.
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_with_timeout_respects_call_timeout() {
        let timeout_ms = 50;
        let provider = test_provider(timeout_ms, 10, 1);

        let result = provider
            .try_operation_with_failover(move |_provider| async move {
                // Outlive the call timeout so the deadline, not the operation, decides.
                sleep(Duration::from_millis(timeout_ms + 10)).await;
                Ok(42)
            })
            .await;

        assert!(matches!(result, Err(FailoverError::Timeout)));
    }

    #[tokio::test]
    async fn test_subscribe_fails_when_all_providers_lack_pubsub() -> anyhow::Result<()> {
        let anvil = Anvil::new().try_spawn()?;
        let http = ProviderBuilder::new().connect_http(anvil.endpoint_url());
        let robust = RobustProviderBuilder::new(http.clone())
            .fallback(http)
            .call_timeout(Duration::from_secs(5))
            .min_delay(Duration::from_millis(100))
            .build()
            .await?;

        let err = robust.subscribe_blocks().await.unwrap_err();

        assert!(
            matches!(
                err,
                Error::RpcError(RpcError::Transport(TransportErrorKind::PubsubUnavailable))
            ),
            "expected TransportErrorKind::PubsubUnavailable error type, got: {err:?}"
        );
        Ok(())
    }

    #[tokio::test]
    async fn test_subscribe_succeeds_if_primary_provider_lacks_pubsub_but_fallback_supports_it()
    -> anyhow::Result<()> {
        let anvil = Anvil::new().try_spawn()?;
        let http = ProviderBuilder::new().connect_http(anvil.endpoint_url());
        let ws = ProviderBuilder::new()
            .connect_ws(WsConnect::new(anvil.ws_endpoint_url().as_str()))
            .await?;
        let robust = RobustProviderBuilder::fragile(http)
            .fallback(ws)
            .call_timeout(Duration::from_secs(5))
            .build()
            .await?;

        let outcome = robust.subscribe_blocks().await;

        assert!(outcome.is_ok());
        Ok(())
    }
}