use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use alloy_json_rpc::{RequestPacket, ResponsePacket, RpcError};
use alloy_transport::{TransportError, TransportErrorKind};
use tower::Layer;
use tracing::{debug, warn};
/// Number of retries made after the initial attempt when no override is given.
const DEFAULT_MAX_RETRIES: u32 = 3;
/// Backoff delay, in milliseconds, used before the first retry by default.
const DEFAULT_BASE_DELAY_MS: u64 = 100;
/// Default ceiling, in milliseconds, on any single backoff delay (30 seconds).
const DEFAULT_MAX_DELAY_MS: u64 = 30 * 1_000;
/// Tower [`Layer`] that wraps a transport service in a [`RetryService`], retrying
/// failed requests with capped exponential backoff.
///
/// The configuration lives behind an [`Arc`], so cloning the layer, or producing
/// many services from it, shares a single [`RetryConfig`].
#[derive(Clone, Debug)]
pub struct RetryLayer {
    /// Retry policy handed to every service this layer builds.
    config: Arc<RetryConfig>,
}
/// Tunable parameters for the retry loop run by [`RetryService`].
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Maximum number of retries after the first attempt; `0` disables retrying.
    pub max_retries: u32,
    /// Delay before the first retry; each later retry doubles the previous delay.
    pub base_delay: Duration,
    /// Ceiling applied to every computed backoff delay.
    pub max_delay: Duration,
}
impl Default for RetryConfig {
    /// Returns the module defaults: `DEFAULT_MAX_RETRIES` retries, a
    /// `DEFAULT_BASE_DELAY_MS` millisecond base delay and a `DEFAULT_MAX_DELAY_MS`
    /// millisecond cap.
    fn default() -> Self {
        let base_delay = Duration::from_millis(DEFAULT_BASE_DELAY_MS);
        let max_delay = Duration::from_millis(DEFAULT_MAX_DELAY_MS);
        Self {
            max_retries: DEFAULT_MAX_RETRIES,
            base_delay,
            max_delay,
        }
    }
}
impl RetryLayer {
    /// Creates a layer that uses [`RetryConfig::default`].
    pub fn new() -> Self {
        Self::builder().build()
    }

    /// Starts a [`RetryLayerBuilder`] seeded with the default configuration.
    pub fn builder() -> RetryLayerBuilder {
        RetryLayerBuilder::new()
    }

    /// Creates a layer that retries up to `max_retries` times while keeping the
    /// default base and maximum delays.
    pub fn with_max_retries(max_retries: u32) -> Self {
        Self::builder().max_retries(max_retries).build()
    }

    /// Preset with more retries and shorter delays than the default:
    /// 5 retries, a 50 ms base delay and a 10 s cap.
    pub fn aggressive() -> Self {
        Self::builder()
            .max_retries(5)
            .base_delay(Duration::from_millis(50))
            .max_delay(Duration::from_secs(10))
            .build()
    }

    /// Preset with the default retry count but longer delays:
    /// 3 retries, a 500 ms base delay and a 60 s cap.
    pub fn conservative() -> Self {
        Self::builder()
            .max_retries(3)
            .base_delay(Duration::from_millis(500))
            .max_delay(Duration::from_secs(60))
            .build()
    }
}
impl Default for RetryLayer {
fn default() -> Self {
Self::new()
}
}
impl<S> Layer<S> for RetryLayer {
type Service = RetryService<S>;
fn layer(&self, service: S) -> Self::Service {
RetryService {
service,
config: self.config.clone(),
}
}
}
/// Incremental builder for [`RetryLayer`]; any field left unset keeps its value
/// from [`RetryConfig::default`].
#[derive(Clone, Debug, Default)]
pub struct RetryLayerBuilder {
    /// Configuration accumulated so far.
    config: RetryConfig,
}
impl RetryLayerBuilder {
    /// Creates a builder seeded with [`RetryConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets how many times a failed request may be retried after the first attempt.
    ///
    /// `0` disables retrying entirely.
    #[must_use]
    pub fn max_retries(mut self, max_retries: u32) -> Self {
        self.config.max_retries = max_retries;
        self
    }

    /// Sets the delay before the first retry. Each later retry doubles the previous
    /// delay, up to the configured maximum.
    #[must_use]
    pub fn base_delay(mut self, delay: Duration) -> Self {
        self.config.base_delay = delay;
        self
    }

    /// Sets the ceiling applied to every backoff delay.
    #[must_use]
    pub fn max_delay(mut self, delay: Duration) -> Self {
        self.config.max_delay = delay;
        self
    }

    /// Finalizes the builder into a [`RetryLayer`].
    // `#[must_use]` on every consuming method: dropping the returned value discards
    // the configuration, which is always a caller bug.
    #[must_use]
    pub fn build(self) -> RetryLayer {
        RetryLayer {
            config: Arc::new(self.config),
        }
    }
}
/// Tower service produced by [`RetryLayer`].
///
/// It re-issues failed requests to the wrapped service according to a shared
/// [`RetryConfig`].
#[derive(Clone, Debug)]
pub struct RetryService<S> {
    /// Wrapped transport service that actually sends requests.
    service: S,
    /// Retry policy shared with the originating layer.
    config: Arc<RetryConfig>,
}
impl<S> tower::Service<RequestPacket> for RetryService<S>
where
S: tower::Service<RequestPacket, Response = ResponsePacket, Error = TransportError>
+ Clone
+ Send
+ 'static,
S::Future: Send,
{
type Response = ResponsePacket;
type Error = TransportError;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx)
}
fn call(&mut self, request: RequestPacket) -> Self::Future {
let service = self.service.clone();
let config = self.config.clone();
Box::pin(async move {
let mut attempt = 0u32;
loop {
let mut service_clone = service.clone();
match service_clone.call(request.clone()).await {
Ok(response) => {
if attempt > 0 {
debug!(attempt = attempt, "Request succeeded after retry");
}
return Ok(response);
}
Err(error) => {
if !is_retryable_error(&error) {
debug!(
error = %error,
"Non-retryable error, not retrying"
);
return Err(error);
}
if attempt >= config.max_retries {
warn!(
error = %error,
attempts = attempt + 1,
"Max retries exceeded"
);
return Err(error);
}
let delay = calculate_backoff(attempt, &config);
warn!(
error = %error,
attempt = attempt + 1,
max_retries = config.max_retries,
delay_ms = delay.as_millis(),
"Retryable error, backing off"
);
tokio::time::sleep(delay).await;
attempt += 1;
}
}
}
})
}
}
/// Computes the delay to wait after failed attempt number `attempt` (0-based).
///
/// The delay is `base_delay * 2^attempt`, capped at `max_delay`.
///
/// The arithmetic uses whole nanoseconds in a saturating `u128`, with these effects:
/// - Sub-millisecond delays are preserved exactly.
/// - Huge attempt counts or a `max_delay` of [`Duration::MAX`] cannot overflow or be
///   truncated; they clamp to `max_delay`.
/// - A zero `base_delay` always yields a zero delay.
fn calculate_backoff(attempt: u32, config: &RetryConfig) -> Duration {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    let cap = config.max_delay;
    // 2^attempt. Shifts of 128 or more are not representable, so saturate; any
    // non-zero base times that already exceeds every possible `Duration`.
    let multiplier = 1u128.checked_shl(attempt).unwrap_or(u128::MAX);
    let delay_nanos = config.base_delay.as_nanos().saturating_mul(multiplier);
    if delay_nanos >= cap.as_nanos() {
        return cap;
    }
    // `delay_nanos < cap.as_nanos()`, so the whole-second part is at most
    // `cap.as_secs()` (fits in `u64`) and the remainder is below one billion
    // (fits in `u32`). Neither cast can truncate.
    Duration::new(
        (delay_nanos / NANOS_PER_SEC) as u64,
        (delay_nanos % NANOS_PER_SEC) as u32,
    )
}
/// Decides whether a failed request should be retried.
///
/// - Deserialization failures and null responses are always retried.
/// - Transport failures defer to `is_transport_kind_retryable`.
/// - JSON-RPC error responses defer to the payload's own `is_retry_err`.
/// - Serialization failures, and any other variant, are never retried.
fn is_retryable_error(error: &TransportError) -> bool {
    match error {
        RpcError::DeserError { .. } | RpcError::NullResp => true,
        RpcError::Transport(kind) => is_transport_kind_retryable(kind),
        RpcError::ErrorResp(payload) => payload.is_retry_err(),
        // The request is resent unchanged, so a serialization failure would
        // presumably recur on every attempt.
        RpcError::SerError(_) => false,
        _ => false,
    }
}
fn is_transport_kind_retryable(kind: &TransportErrorKind) -> bool {
kind.is_retry_err()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `RetryConfig` from its three fields in one call.
    fn config(max_retries: u32, base_delay: Duration, max_delay: Duration) -> RetryConfig {
        RetryConfig {
            max_retries,
            base_delay,
            max_delay,
        }
    }

    /// Asserts that `layer` carries exactly the given retry configuration.
    #[track_caller]
    fn assert_layer_config(
        layer: &RetryLayer,
        max_retries: u32,
        base_delay: Duration,
        max_delay: Duration,
    ) {
        assert_eq!(layer.config.max_retries, max_retries);
        assert_eq!(layer.config.base_delay, base_delay);
        assert_eq!(layer.config.max_delay, max_delay);
    }

    #[test]
    fn test_retry_layer_default() {
        assert_layer_config(
            &RetryLayer::new(),
            DEFAULT_MAX_RETRIES,
            Duration::from_millis(DEFAULT_BASE_DELAY_MS),
            Duration::from_millis(DEFAULT_MAX_DELAY_MS),
        );
    }

    #[test]
    fn test_retry_layer_builder() {
        let built = RetryLayer::builder()
            .max_retries(5)
            .base_delay(Duration::from_millis(200))
            .max_delay(Duration::from_secs(60))
            .build();
        assert_layer_config(
            &built,
            5,
            Duration::from_millis(200),
            Duration::from_secs(60),
        );
    }

    #[test]
    fn test_retry_layer_with_max_retries() {
        let built = RetryLayer::with_max_retries(10);
        assert_eq!(built.config.max_retries, 10);
    }

    #[test]
    fn test_retry_layer_aggressive() {
        assert_layer_config(
            &RetryLayer::aggressive(),
            5,
            Duration::from_millis(50),
            Duration::from_secs(10),
        );
    }

    #[test]
    fn test_retry_layer_conservative() {
        assert_layer_config(
            &RetryLayer::conservative(),
            3,
            Duration::from_millis(500),
            Duration::from_secs(60),
        );
    }

    #[test]
    fn test_calculate_backoff() {
        let cfg = config(5, Duration::from_millis(100), Duration::from_secs(10));
        for (attempt, expected_ms) in [(0, 100), (1, 200), (2, 400), (3, 800)] {
            assert_eq!(
                calculate_backoff(attempt, &cfg),
                Duration::from_millis(expected_ms)
            );
        }
    }

    #[test]
    fn test_calculate_backoff_capped() {
        let cfg = config(10, Duration::from_millis(100), Duration::from_millis(500));
        for attempt in [3, 10] {
            assert_eq!(calculate_backoff(attempt, &cfg), Duration::from_millis(500));
        }
    }

    #[test]
    fn test_calculate_backoff_overflow_protection() {
        let cfg = config(100, Duration::from_secs(1), Duration::from_secs(60));
        assert_eq!(calculate_backoff(50, &cfg), Duration::from_secs(60));
    }
}