use super::{common::JsonRpcError, http::ClientError};
use crate::{errors::ProviderError, JsonRpcClient};
use async_trait::async_trait;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{
fmt::Debug,
sync::atomic::{AtomicU32, Ordering},
time::Duration,
};
use thiserror::Error;
use tracing::trace;
/// Decides whether a transport error should be retried and, optionally, how
/// long to back off before the next attempt.
///
/// Implementors inspect the client-specific error type `E` produced by the
/// wrapped [`JsonRpcClient`].
pub trait RetryPolicy<E>: Send + Sync + Debug {
    /// Returns `true` if the request that produced `error` should be retried.
    fn should_retry(&self, error: &E) -> bool;
    /// Returns a backoff duration if the error itself carries a hint
    /// (e.g. a server-supplied "retry after" value); `None` otherwise.
    fn backoff_hint(&self, error: &E) -> Option<Duration>;
}
/// A [`JsonRpcClient`] wrapper that retries requests which failed with a
/// retryable error (as judged by its [`RetryPolicy`]), applying a backoff
/// that accounts for how many requests are currently in flight.
#[derive(Debug)]
pub struct RetryClient<T>
where
    T: JsonRpcClient,
    T::Error: crate::RpcError + Sync + Send + 'static,
{
    // The wrapped transport that actually performs the requests.
    inner: T,
    // Number of requests currently in the retry loop; used to stagger
    // backoff between concurrent callers.
    requests_enqueued: AtomicU32,
    /// The policy that decides if a request should be retried.
    policy: Box<dyn RetryPolicy<T::Error>>,
    /// How many connectivity/timeout errors are tolerated before giving up.
    timeout_retries: u32,
    /// How many rate-limit errors are tolerated before giving up.
    rate_limit_retries: u32,
    /// Base delay before the first retry.
    initial_backoff: Duration,
    // Assumed compute-units-per-second budget (provider-specific quota)
    // used to estimate how long queued requests must wait.
    compute_units_per_second: u64,
}
impl<T> RetryClient<T>
where
    T: JsonRpcClient,
    T::Error: Sync + Send + 'static,
{
    /// Creates a new `RetryClient` with default builder settings, overriding
    /// only the rate-limit retry count and the initial backoff (milliseconds).
    pub fn new(
        inner: T,
        policy: Box<dyn RetryPolicy<T::Error>>,
        max_retry: u32,
        initial_backoff: u64,
    ) -> Self {
        let backoff = Duration::from_millis(initial_backoff);
        RetryClientBuilder::default()
            .rate_limit_retries(max_retry)
            .initial_backoff(backoff)
            .build(inner, policy)
    }

    /// Sets the compute-units-per-second budget used when estimating backoff.
    pub fn set_compute_units(&mut self, cpus: u64) -> &mut Self {
        self.compute_units_per_second = cpus;
        self
    }
}
/// Builder for a [`RetryClient`], configuring retry counts, the initial
/// backoff and the compute-unit budget used to throttle concurrent requests.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct RetryClientBuilder {
    // Maximum retries for spurious connectivity/timeout errors.
    timeout_retries: u32,
    // Maximum retries for rate-limited responses.
    rate_limit_retries: u32,
    // Base delay before the first retry.
    initial_backoff: Duration,
    // Assumed compute-units-per-second quota of the RPC provider.
    compute_units_per_second: u64,
}
impl RetryClientBuilder {
    /// Sets how many connectivity/timeout errors are tolerated before giving up.
    pub fn timeout_retries(self, timeout_retries: u32) -> Self {
        Self { timeout_retries, ..self }
    }

    /// Sets how many rate-limit errors are tolerated before giving up.
    pub fn rate_limit_retries(self, rate_limit_retries: u32) -> Self {
        Self { rate_limit_retries, ..self }
    }

    /// Sets the assumed compute-units-per-second quota of the provider.
    pub fn compute_units_per_second(self, compute_units_per_second: u64) -> Self {
        Self { compute_units_per_second, ..self }
    }

    /// Sets the base delay applied before the first retry.
    pub fn initial_backoff(self, initial_backoff: Duration) -> Self {
        Self { initial_backoff, ..self }
    }

    /// Consumes the builder and wraps `client` in a [`RetryClient`] governed
    /// by `policy`.
    pub fn build<T>(self, client: T, policy: Box<dyn RetryPolicy<T::Error>>) -> RetryClient<T>
    where
        T: JsonRpcClient,
        T::Error: Sync + Send + 'static,
    {
        RetryClient {
            inner: client,
            requests_enqueued: AtomicU32::new(0),
            policy,
            timeout_retries: self.timeout_retries,
            rate_limit_retries: self.rate_limit_retries,
            initial_backoff: self.initial_backoff,
            compute_units_per_second: self.compute_units_per_second,
        }
    }
}
impl Default for RetryClientBuilder {
fn default() -> Self {
Self {
timeout_retries: 3,
rate_limit_retries: 10,
initial_backoff: Duration::from_millis(1000),
compute_units_per_second: 330,
}
}
}
/// Error thrown by a [`RetryClient`]:
/// 1. the inner client returned an error we do not want to recover from,
/// 2. the configured rate-limit retries were exhausted, or
/// 3. the request params failed to serialize.
///
/// NOTE: no `#[error(...)]` display attributes are used here on purpose —
/// `Display` is implemented manually below (via the `Debug` representation).
/// thiserror only skips generating `Display` when the derive carries no
/// display attributes at all; attributing only some variants (as previously
/// written) fails to compile with "missing #[error(...)] display attribute",
/// and attributing all of them would conflict with the manual `Display` impl.
#[derive(Error, Debug)]
pub enum RetryClientError {
    /// Non-retryable error bubbled up from the inner client.
    ProviderError(ProviderError),
    /// All configured rate-limit retries were exhausted.
    TimeoutError,
    /// The request params could not be serialized to a JSON value.
    SerdeJson(serde_json::Error),
}
impl crate::RpcError for RetryClientError {
    /// Surfaces the JSON-RPC error response carried by a wrapped provider
    /// error, if any; the other variants carry none.
    fn as_error_response(&self) -> Option<&super::JsonRpcError> {
        match self {
            RetryClientError::ProviderError(err) => err.as_error_response(),
            _ => None,
        }
    }

    /// Surfaces the underlying serde error, either held directly or nested
    /// inside a wrapped provider error.
    fn as_serde_error(&self) -> Option<&serde_json::Error> {
        if let RetryClientError::SerdeJson(err) = self {
            Some(err)
        } else if let RetryClientError::ProviderError(err) = self {
            err.as_serde_error()
        } else {
            None
        }
    }
}
impl std::fmt::Display for RetryClientError {
    /// Renders the error via its `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", self)
    }
}
impl From<RetryClientError> for ProviderError {
    /// Converts back into a [`ProviderError`]: unwraps wrapped variants and
    /// boxes the timeout case as an opaque JSON-RPC client error.
    fn from(src: RetryClientError) -> Self {
        match src {
            RetryClientError::ProviderError(err) => err,
            RetryClientError::SerdeJson(err) => err.into(),
            err @ RetryClientError::TimeoutError => {
                ProviderError::JsonRpcClientError(Box::new(err))
            }
        }
    }
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl<T> JsonRpcClient for RetryClient<T>
where
    T: JsonRpcClient + 'static,
    T::Error: Sync + Send + 'static,
{
    type Error = RetryClientError;

    /// Sends the request through the inner client, retrying rate-limited
    /// responses (with backoff) and, up to `timeout_retries` times, errors
    /// that look like spurious connectivity failures.
    async fn request<A, R>(&self, method: &str, params: A) -> Result<R, Self::Error>
    where
        A: Debug + Serialize + Send + Sync,
        R: DeserializeOwned + Send,
    {
        // Zero-sized params (e.g. `()`) are forwarded as-is so the inner
        // transport serializes them however it expects; everything else is
        // serialized once up front so retries don't re-serialize.
        enum RetryParams<Params> {
            Value(Params),
            Zst(()),
        }
        let params = if std::mem::size_of::<A>() == 0 {
            RetryParams::Zst(())
        } else {
            let params = serde_json::to_value(params).map_err(RetryClientError::SerdeJson)?;
            RetryParams::Value(params)
        };
        // Snapshot of how many requests were already queued when this one
        // arrived; used to stagger backoff between concurrent callers.
        let ahead_in_queue = self.requests_enqueued.fetch_add(1, Ordering::SeqCst) as u64;
        let mut rate_limit_retry_number: u32 = 0;
        let mut timeout_retries: u32 = 0;
        loop {
            let err;
            {
                let resp = match params {
                    RetryParams::Value(ref params) => self.inner.request(method, params).await,
                    RetryParams::Zst(unit) => self.inner.request(method, unit).await,
                };
                match resp {
                    Ok(ret) => {
                        self.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                        return Ok(ret)
                    }
                    Err(err_) => err = err_,
                }
            }
            let should_retry = self.policy.should_retry(&err);
            if should_retry {
                rate_limit_retry_number += 1;
                if rate_limit_retry_number > self.rate_limit_retries {
                    trace!("request timed out after {} retries", self.rate_limit_retries);
                    // Fix: this exit path previously leaked the queue
                    // counter; decrement it so future backoff estimates
                    // are not permanently inflated.
                    self.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                    return Err(RetryClientError::TimeoutError)
                }
                let current_queued_requests = self.requests_enqueued.load(Ordering::SeqCst) as u64;
                // Start from the policy's hint, or the configured base backoff.
                let mut next_backoff = self.policy.backoff_hint(&err).unwrap_or_else(|| {
                    Duration::from_millis(self.initial_backoff.as_millis() as u64)
                });
                // Assumed average compute-unit cost of a single request.
                const AVG_COST: u64 = 17u64;
                let seconds_to_wait_for_compute_budget = compute_unit_offset_in_secs(
                    AVG_COST,
                    self.compute_units_per_second,
                    current_queued_requests,
                    ahead_in_queue,
                );
                next_backoff += Duration::from_secs(seconds_to_wait_for_compute_budget);
                trace!("retrying and backing off for {:?}", next_backoff);
                #[cfg(target_arch = "wasm32")]
                futures_timer::Delay::new(next_backoff).await;
                #[cfg(not(target_arch = "wasm32"))]
                tokio::time::sleep(next_backoff).await;
            } else {
                let err: ProviderError = err.into();
                // Non-retryable per policy, but possibly a transient network
                // failure — retry those a limited number of times too.
                if timeout_retries < self.timeout_retries && maybe_connectivity(&err) {
                    timeout_retries += 1;
                    trace!(err = ?err, "retrying due to spurious network");
                    continue
                }
                trace!(err = ?err, "should not retry");
                self.requests_enqueued.fetch_sub(1, Ordering::SeqCst);
                return Err(RetryClientError::ProviderError(err))
            }
        }
    }
}
/// A [`RetryPolicy`] for HTTP transports that retries on HTTP 429 responses
/// and on JSON-RPC error codes/messages known to indicate rate limiting.
#[derive(Debug, Default)]
pub struct HttpRateLimitRetryPolicy;
impl RetryPolicy<ClientError> for HttpRateLimitRetryPolicy {
    /// Retries on HTTP 429 and on JSON-RPC errors whose code or message
    /// indicates rate limiting (including rate-limit bodies that failed to
    /// deserialize as a proper response).
    fn should_retry(&self, error: &ClientError) -> bool {
        // A JSON-RPC error counts as rate limiting when its code or message
        // matches one of the known provider patterns.
        fn should_retry_json_rpc_error(err: &JsonRpcError) -> bool {
            let JsonRpcError { code, message, .. } = err;
            match *code {
                // Alchemy-style HTTP status mirrored into the error code.
                429 => true,
                // Infura "limit exceeded".
                -32005 => true,
                // Alchemy per-IP rate limit.
                -32016 if message.contains("rate limit") => true,
                // Fall back to known rate-limit messages.
                _ => matches!(
                    message.as_str(),
                    "header not found" |
                        "daily request count exceeded, request rate limited"
                ),
            }
        }
        match error {
            ClientError::ReqwestError(err) => {
                err.status() == Some(http::StatusCode::TOO_MANY_REQUESTS)
            }
            ClientError::JsonRpcError(err) => should_retry_json_rpc_error(err),
            // Some providers return rate-limit errors with a null/omitted id,
            // which fails response deserialization; re-parse the raw text.
            ClientError::SerdeJson { text, .. } => {
                #[derive(Deserialize)]
                struct Resp {
                    error: JsonRpcError,
                }
                serde_json::from_str::<Resp>(text)
                    .map(|resp| should_retry_json_rpc_error(&resp.error))
                    .unwrap_or(false)
            }
        }
    }

    /// Extracts a server-supplied backoff from the error payload at
    /// `data.rate.backoff_seconds`, rounding fractional seconds up.
    fn backoff_hint(&self, error: &ClientError) -> Option<Duration> {
        match error {
            ClientError::JsonRpcError(JsonRpcError { data: Some(data), .. }) => {
                let backoff_seconds = &data["rate"]["backoff_seconds"];
                backoff_seconds.as_u64().map(Duration::from_secs).or_else(|| {
                    backoff_seconds
                        .as_f64()
                        .map(|seconds| Duration::from_secs(seconds as u64 + 1))
                })
            }
            _ => None,
        }
    }
}
/// Calculates how many seconds a request should be delayed so the
/// compute-unit budget is not exceeded, given the average per-request cost
/// (in compute units), the per-second budget, the number of requests
/// currently queued, and how many of those were ahead of this one.
///
/// Returns `0` when the current queue fits into one second's budget;
/// otherwise the queue position (capped by `ahead_in_queue`) divided by the
/// per-second request capacity.
fn compute_unit_offset_in_secs(
    avg_cost: u64,
    compute_units_per_second: u64,
    current_queued_requests: u64,
    ahead_in_queue: u64,
) -> u64 {
    // Requests affordable per second. Clamp to at least 1: the previous
    // `saturating_div` still panics on a zero divisor, so `avg_cost == 0`
    // or a budget smaller than one request's cost would crash here.
    let request_capacity_per_second =
        compute_units_per_second.checked_div(avg_cost).unwrap_or(0).max(1);
    if current_queued_requests > request_capacity_per_second {
        current_queued_requests.min(ahead_in_queue) / request_capacity_per_second
    } else {
        0
    }
}
/// Heuristically checks whether a provider error looks like a transient
/// network problem: an HTTP timeout, a connect failure (non-wasm), or a
/// 5xx server response.
fn maybe_connectivity(err: &ProviderError) -> bool {
    let reqwest_err = match err {
        ProviderError::HTTPError(e) => e,
        _ => return false,
    };
    if reqwest_err.is_timeout() {
        return true
    }
    // `is_connect` is unavailable on the wasm reqwest backend.
    #[cfg(not(target_arch = "wasm32"))]
    if reqwest_err.is_connect() {
        return true
    }
    // Treat any 5xx status as a (possibly transient) server-side failure.
    matches!(reqwest_err.status(), Some(status) if status.is_server_error())
}
#[cfg(test)]
mod tests {
    use super::*;
    // Mirror the constants used by the request loop so offsets match runtime behavior.
    const AVG_COST: u64 = 17u64;
    const COMPUTE_UNITS: u64 = 330u64;
    // Convenience wrapper with the constants above pre-applied.
    fn compute_offset(current_queued_requests: u64, ahead_in_queue: u64) -> u64 {
        compute_unit_offset_in_secs(
            AVG_COST,
            COMPUTE_UNITS,
            current_queued_requests,
            ahead_in_queue,
        )
    }
    // Queues that fit within the per-second capacity (330/17 = 19) wait 0 seconds.
    #[test]
    fn can_measure_unit_offset_single_request() {
        let current_queued_requests = 1;
        let ahead_in_queue = 0;
        let to_wait = compute_offset(current_queued_requests, ahead_in_queue);
        assert_eq!(to_wait, 0);
        let current_queued_requests = 19;
        let ahead_in_queue = 18;
        let to_wait = compute_offset(current_queued_requests, ahead_in_queue);
        assert_eq!(to_wait, 0);
    }
    // One request over budget → one second of extra wait.
    #[test]
    fn can_measure_unit_offset_1x_over_budget() {
        let current_queued_requests = 20;
        let ahead_in_queue = 19;
        let to_wait = compute_offset(current_queued_requests, ahead_in_queue);
        assert_eq!(to_wait, 1);
    }
    // Roughly twice the budget; the second case is capped by queue position.
    #[test]
    fn can_measure_unit_offset_2x_over_budget() {
        let current_queued_requests = 49;
        let ahead_in_queue = 48;
        let to_wait = compute_offset(current_queued_requests, ahead_in_queue);
        assert_eq!(to_wait, 2);
        let current_queued_requests = 49;
        let ahead_in_queue = 20;
        let to_wait = compute_offset(current_queued_requests, ahead_in_queue);
        assert_eq!(to_wait, 1);
    }
    // backoff_hint reads `data.rate.backoff_seconds`; non-object data yields None.
    #[test]
    fn can_extract_backoff() {
        let resp = r#"{"rate": {"allowed_rps": 1, "backoff_seconds": 30, "current_rps": 1.1}, "see": "https://infura.io/dashboard"}"#;
        let err = ClientError::JsonRpcError(JsonRpcError {
            code: 0,
            message: "daily request count exceeded, request rate limited".to_string(),
            data: Some(serde_json::from_str(resp).unwrap()),
        });
        let backoff = HttpRateLimitRetryPolicy.backoff_hint(&err).unwrap();
        assert_eq!(backoff, Duration::from_secs(30));
        let err = ClientError::JsonRpcError(JsonRpcError {
            code: 0,
            message: "daily request count exceeded, request rate limited".to_string(),
            data: Some(serde_json::Value::String("blocked".to_string())),
        });
        let backoff = HttpRateLimitRetryPolicy.backoff_hint(&err);
        assert!(backoff.is_none());
    }
    // -32016 with a "rate limit"-style message must be retried.
    #[test]
    fn test_alchemy_ip_rate_limit() {
        let s = "{\"code\":-32016,\"message\":\"Your IP has exceeded its requests per second capacity. To increase your rate limits, please sign up for a free Alchemy account at https://www.alchemy.com/optimism.\"}";
        let err: JsonRpcError = serde_json::from_str(s).unwrap();
        let err = ClientError::JsonRpcError(err);
        let should_retry = HttpRateLimitRetryPolicy.should_retry(&err);
        assert!(should_retry);
    }
    // Rate-limit bodies that failed response deserialization (null id) are
    // re-parsed from the raw text and still retried.
    #[test]
    fn test_rate_limit_omitted_id() {
        let s = r#"{"jsonrpc":"2.0","error":{"code":-32016,"message":"Your IP has exceeded its requests per second capacity. To increase your rate limits, please sign up for a free Alchemy account at https://www.alchemy.com/optimism."},"id":null}"#;
        let err = ClientError::SerdeJson {
            err: serde::de::Error::custom("unexpected notification over HTTP transport"),
            text: s.to_string(),
        };
        let should_retry = HttpRateLimitRetryPolicy.should_retry(&err);
        assert!(should_retry);
    }
}