use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context as TaskContext, Poll};
use std::time::Duration;
use tower::{Layer, Service};
/// Predicate deciding whether a successfully returned HTTP response should still
/// be retried by [`RetryService`]; held in an `Arc` so layers and services share it.
pub type RetryPredicate = Arc<dyn Fn(&http::Response<bytes::Bytes>) -> bool + Send + Sync>;
/// Tuning knobs for the free-standing [`retry`] helper.
#[derive(Debug, Clone)]
pub struct RetryConfig {
    /// Number of retries after the first attempt; [`retry`] calls the operation
    /// at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Delay before the first retry (what `delay(1)` returns, subject to `max_delay`).
    pub initial_delay: Duration,
    /// Upper bound on the un-jittered delay returned by [`RetryConfig::delay`].
    pub max_delay: Duration,
    /// Factor applied to the delay for each further retry.
    pub backoff_multiplier: f64,
    /// Fraction of the computed delay used as the bound for the random extra
    /// wait that [`retry`] adds on top of it.
    pub jitter: f64,
}

impl Default for RetryConfig {
    /// Three retries starting at 100 ms, doubling up to 30 s, with a 10 % jitter bound.
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_delay: Duration::from_millis(100),
            max_delay: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            jitter: 0.1,
        }
    }
}

impl RetryConfig {
    /// Default configuration with a custom number of retries.
    pub fn exponential_backoff(max_retries: u32) -> Self {
        Self {
            max_retries,
            ..Default::default()
        }
    }

    /// Returns the un-jittered wait before retry number `attempt` (1-based).
    ///
    /// The delay is `initial_delay * backoff_multiplier^(attempt - 1)`, capped at
    /// `max_delay`; `attempt == 0` is treated like `attempt == 1`. The product is
    /// computed in nanoseconds so sub-millisecond delays are kept, and
    /// non-finite or out-of-range intermediates saturate instead of wrapping.
    pub fn delay(&self, attempt: u32) -> Duration {
        // Clamp before casting: a plain `u32 as i32` turns huge attempt numbers
        // into negative exponents, which would shrink the delay instead of capping it.
        let exponent = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
        let factor = self.backoff_multiplier.powi(exponent);
        let nanos = self.initial_delay.as_nanos() as f64 * factor;
        // Float-to-int `as` saturates: NaN/negative -> 0, +inf/too large -> u64::MAX.
        Duration::from_nanos(nanos as u64).min(self.max_delay)
    }
}
/// Result alias for operations driven by [`retry`].
pub type RetryResult<T, E> = Result<T, RetryError<E>>;

/// Failure outcome of a retried operation.
#[derive(Debug, Clone)]
pub enum RetryError<E> {
    /// All permitted attempts failed.
    Exhausted {
        /// Error returned by the final attempt.
        last_error: E,
        /// Attempt count recorded by whoever built this error; [`retry`] stores
        /// `config.max_retries` here.
        attempts: u32,
    },
    /// Cancellation marker; not constructed by any code in this module.
    Cancelled,
    /// A wrapped error without attempt information; not constructed by any
    /// code in this module.
    Error(E),
}

impl<E> RetryError<E> {
    /// Borrows the carried error, or `None` for [`RetryError::Cancelled`].
    pub fn last_error(&self) -> Option<&E> {
        match self {
            Self::Exhausted { last_error: inner, .. } | Self::Error(inner) => Some(inner),
            Self::Cancelled => None,
        }
    }
}

impl<E: std::fmt::Display> std::fmt::Display for RetryError<E> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Exhausted { attempts, .. } => {
                write!(f, "Retry exhausted after {} attempts", attempts)
            }
            Self::Error(inner) => write!(f, "Retry error: {}", inner),
            Self::Cancelled => f.write_str("Retry cancelled"),
        }
    }
}

impl<E: std::fmt::Debug + std::fmt::Display> std::error::Error for RetryError<E> {}
/// Runs `operation` until it succeeds, retrying failures with exponential backoff.
///
/// The operation is called at most `config.max_retries + 1` times. After each
/// failure except the last, the task sleeps (`tokio::time::sleep`) for
/// [`RetryConfig::delay`] of the upcoming retry plus a random extra wait drawn
/// uniformly from `[0, config.jitter × delay)` milliseconds, truncated to whole
/// milliseconds. Jitter is only ever added, never subtracted. A zero, negative or
/// non-finite jitter bound adds no extra wait.
///
/// # Errors
///
/// Returns [`RetryError::Exhausted`] holding the error of the final attempt.
/// Its `attempts` field is `config.max_retries` (the number of retries), which
/// is one less than the number of calls actually made.
pub async fn retry<F, Fut, T, E>(config: RetryConfig, mut operation: F) -> RetryResult<T, E>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<T, E>>,
{
    use rand::RngExt;
    let mut retries_used: u32 = 0;
    loop {
        match operation().await {
            Ok(value) => return Ok(value),
            Err(last_error) if retries_used >= config.max_retries => {
                return Err(RetryError::Exhausted {
                    last_error,
                    attempts: config.max_retries,
                });
            }
            // Intermediate errors are discarded; only the final one is reported.
            Err(_) => {}
        }
        retries_used += 1;
        let delay = config.delay(retries_used);
        let jitter_range = delay.as_millis() as f64 * config.jitter;
        // `random_range` panics on an empty range, which is what a zero bound
        // (zero delay or zero jitter) or a negative/NaN/infinite bound produces.
        let jitter_millis = if jitter_range.is_finite() && jitter_range > 0.0 {
            rand::rng().random_range(0.0..jitter_range)
        } else {
            0.0
        };
        let extra = Duration::from_millis(jitter_millis as u64);
        tokio::time::sleep(delay.saturating_add(extra)).await;
    }
}
/// Attempt limit and backoff schedule used by [`RetryService`].
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum number of calls to the inner service per request, including the
    /// first one (a value of 0 behaves like 1).
    pub max_attempts: u32,
    /// Backoff before the first retry; doubled for each further retry.
    pub initial_backoff: Duration,
    /// Cap on the doubled backoff; jitter, when enabled, is added on top of it.
    pub max_backoff: Duration,
    /// When `true`, adds a random extra `0..=max(backoff / 4, 1)` milliseconds.
    pub jitter: bool,
}

impl Default for RetryPolicy {
    /// Three attempts, 100 ms initial backoff, 30 s cap, jitter enabled.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(30),
            jitter: true,
        }
    }
}

impl RetryPolicy {
    /// Returns the wait before retry number `attempt_index + 1` (`attempt_index` is 0-based).
    ///
    /// Computes `initial_backoff * 2^attempt_index` in whole milliseconds, caps it
    /// at `max_backoff`, then optionally adds jitter. All arithmetic saturates, and
    /// durations whose millisecond count exceeds `u64::MAX` are clamped rather than
    /// truncated.
    pub fn backoff(&self, attempt_index: u32) -> Duration {
        // `as_millis()` is u128; a bare `as u64` would wrap (e.g. 2^64 ms -> 0 ms).
        let to_ms = |d: Duration| d.as_millis().min(u128::from(u64::MAX)) as u64;
        let factor = 2u64.saturating_pow(attempt_index);
        let capped_ms = to_ms(self.initial_backoff)
            .saturating_mul(factor)
            .min(to_ms(self.max_backoff));
        let jitter_ms = if self.jitter {
            use rand::RngExt;
            rand::rng().random_range(0..=(capped_ms / 4).max(1))
        } else {
            0
        };
        Duration::from_millis(capped_ms.saturating_add(jitter_ms))
    }
}
/// Default retry predicate: retries every 5xx response and every 4xx response
/// except 400, 401, 403 and 404; anything below 400 is never retried.
fn default_should_retry<B>(resp: &http::Response<B>) -> bool {
    match resp.status().as_u16() {
        400 | 401 | 403 | 404 => false,
        400..=499 => true,
        code => code >= 500,
    }
}
/// [`Layer`] that wraps an inner service in a [`RetryService`].
///
/// The retry predicate is shared through an `Arc`, so every produced service
/// uses the same predicate instance.
#[derive(Clone)]
pub struct RetryLayer {
    policy: RetryPolicy,
    should_retry: RetryPredicate,
}

// The predicate is an opaque closure without a `Debug` impl, so only the
// policy is printed.
impl std::fmt::Debug for RetryLayer {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let policy = &self.policy;
        f.debug_struct("RetryLayer").field("policy", policy).finish()
    }
}

impl RetryLayer {
    /// Creates a layer using `policy` and the default status-code predicate
    /// (retry 5xx and 4xx other than 400/401/403/404).
    pub fn new(policy: RetryPolicy) -> Self {
        let should_retry: RetryPredicate = Arc::new(default_should_retry);
        Self {
            policy,
            should_retry,
        }
    }

    /// Replaces the predicate that decides whether a response returned
    /// successfully by the inner service should be retried anyway.
    pub fn with_predicate<F>(self, f: F) -> Self
    where
        F: Fn(&http::Response<bytes::Bytes>) -> bool + Send + Sync + 'static,
    {
        let should_retry: RetryPredicate = Arc::new(f);
        Self {
            should_retry,
            ..self
        }
    }
}

impl<S> Layer<S> for RetryLayer {
    type Service = RetryService<S>;

    /// Wraps `inner`, copying the policy and sharing the predicate.
    fn layer(&self, inner: S) -> Self::Service {
        let policy = self.policy.clone();
        let should_retry = Arc::clone(&self.should_retry);
        RetryService {
            inner,
            policy,
            should_retry,
        }
    }
}
/// Tower service produced by [`RetryLayer`] that re-sends requests to `inner`
/// according to `policy` and `should_retry`.
#[derive(Clone)]
pub struct RetryService<S> {
    // Wrapped service that performs each attempt.
    inner: S,
    // Attempt limit and backoff schedule.
    policy: RetryPolicy,
    // Decides whether a successfully returned response should still be retried.
    should_retry: RetryPredicate,
}
impl<S, B> Service<http::Request<B>> for RetryService<S>
where
    S: Service<http::Request<B>> + Clone + Send + 'static,
    S::Response: Into<http::Response<bytes::Bytes>> + Send + 'static,
    S::Error: Send + 'static,
    S::Future: Send + 'static,
    B: Clone + Send + 'static,
{
    type Response = http::Response<bytes::Bytes>;
    type Error = S::Error;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Sends `req`, re-sending a fresh copy after a backoff whenever the inner
    /// service returns an error or the predicate flags the response, for at most
    /// `policy.max_attempts` calls in total (at least one).
    ///
    /// When attempts run out, the last error is returned, or the last response
    /// if the final call succeeded but was flagged for retry.
    fn call(&mut self, req: http::Request<B>) -> Self::Future {
        let policy = self.policy.clone();
        let should_retry = Arc::clone(&self.should_retry);
        // Take the instance readied by `poll_ready` for this request and leave an
        // unreadied clone behind for the next `poll_ready`/`call` pair.
        let replacement = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, replacement);
        Box::pin(async move {
            let mut attempt = 0u32;
            loop {
                if attempt > 0 {
                    // The previous `call` consumed the readiness; the `Service`
                    // contract requires `poll_ready` to succeed before every `call`.
                    if let Err(e) = std::future::poll_fn(|cx| inner.poll_ready(cx)).await {
                        return Err(e);
                    }
                }
                attempt += 1;
                match inner.call(clone_request_for_attempt(&req)).await {
                    Err(e) if attempt >= policy.max_attempts => return Err(e),
                    Err(_) => {}
                    Ok(resp) => {
                        let resp: http::Response<bytes::Bytes> = resp.into();
                        // The predicate is evaluated for every successful response.
                        if !(should_retry)(&resp) || attempt >= policy.max_attempts {
                            return Ok(resp);
                        }
                    }
                }
                tokio::time::sleep(policy.backoff(attempt - 1)).await;
            }
        })
    }
}

/// Builds a fresh copy of `req` for one attempt.
///
/// Method, URI, version, every header value (repeated headers included) and the
/// body are copied. Extensions are not copied.
fn clone_request_for_attempt<B: Clone>(req: &http::Request<B>) -> http::Request<B> {
    let mut copy = http::Request::new(req.body().clone());
    *copy.method_mut() = req.method().clone();
    *copy.uri_mut() = req.uri().clone();
    *copy.version_mut() = req.version();
    // Cloning the whole map keeps multi-valued headers; `HeaderMap::insert` would
    // keep only the last value of each name.
    *copy.headers_mut() = req.headers().clone();
    copy
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tower::{ServiceBuilder, ServiceExt};

    /// Builds a `GET` request with an empty body for the layer tests.
    fn make_request(path: &str) -> http::Request<String> {
        http::Request::builder()
            .method("GET")
            .uri(path)
            .body(String::new())
            .unwrap()
    }

    /// Builds a response with the given status and an empty body.
    fn response(status: u16) -> http::Response<bytes::Bytes> {
        http::Response::builder()
            .status(status)
            .body(bytes::Bytes::new())
            .unwrap()
    }

    /// Policy with millisecond backoffs and no jitter, so layer tests are fast
    /// and deterministic.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy {
            max_attempts,
            initial_backoff: Duration::from_millis(1),
            max_backoff: Duration::from_millis(10),
            jitter: false,
        }
    }

    #[test]
    fn test_retry_config_default() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries, 3);
        assert_eq!(config.initial_delay, Duration::from_millis(100));
        assert_eq!(config.backoff_multiplier, 2.0);
    }

    #[test]
    fn test_retry_config_delay() {
        let config = RetryConfig {
            initial_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_delay: Duration::from_secs(10),
            ..Default::default()
        };
        assert_eq!(config.delay(1), Duration::from_millis(100));
        assert_eq!(config.delay(2), Duration::from_millis(200));
        assert_eq!(config.delay(3), Duration::from_millis(400));
        assert_eq!(config.delay(4), Duration::from_millis(800));
    }

    #[tokio::test]
    async fn test_retry_success_first_attempt() {
        let config = RetryConfig::default();
        let result: Result<i32, RetryError<&str>> = retry(config, || async { Ok(42) }).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_retry_success_after_retries() {
        // Paused time auto-advances through the backoff sleeps.
        tokio::time::pause();
        let config = RetryConfig {
            max_retries: 3,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let attempts = AtomicUsize::new(0);
        let result = retry(config, || {
            let count = attempts.fetch_add(1, Ordering::SeqCst);
            async move {
                if count >= 2 {
                    Ok("success")
                } else {
                    Err("fail")
                }
            }
        })
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "success");
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_exhausted() {
        tokio::time::pause();
        let config = RetryConfig {
            max_retries: 2,
            initial_delay: Duration::from_millis(10),
            ..Default::default()
        };
        let result = retry(config, || async { Err::<&str, _>("always fails") }).await;
        match result {
            Err(RetryError::Exhausted {
                last_error,
                attempts,
            }) => {
                assert_eq!(last_error, "always fails");
                assert_eq!(attempts, 2);
            }
            _ => panic!("Expected Exhausted error"),
        }
    }

    #[test]
    fn test_retry_policy_backoff_doubles() {
        let policy = RetryPolicy {
            max_attempts: 5,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(60),
            jitter: false,
        };
        assert_eq!(policy.backoff(0), Duration::from_millis(100));
        assert_eq!(policy.backoff(1), Duration::from_millis(200));
        assert_eq!(policy.backoff(2), Duration::from_millis(400));
        assert_eq!(policy.backoff(3), Duration::from_millis(800));
    }

    #[test]
    fn test_retry_policy_backoff_capped() {
        let policy = RetryPolicy {
            max_attempts: 10,
            initial_backoff: Duration::from_millis(1000),
            max_backoff: Duration::from_millis(2000),
            jitter: false,
        };
        assert_eq!(policy.backoff(0), Duration::from_millis(1000));
        assert_eq!(policy.backoff(1), Duration::from_millis(2000));
        assert_eq!(policy.backoff(2), Duration::from_millis(2000));
    }

    #[tokio::test]
    async fn test_retry_layer_succeeds_first_try() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let inner = tower::service_fn(move |_req: http::Request<String>| {
            cc.fetch_add(1, Ordering::SeqCst);
            async { Ok::<_, std::convert::Infallible>(response(200)) }
        });
        let mut svc = ServiceBuilder::new()
            .layer(RetryLayer::new(fast_policy(3)))
            .service(inner);
        let resp = svc
            .ready()
            .await
            .unwrap()
            .call(make_request("/"))
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_layer_retries_on_error() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let inner = tower::service_fn(move |_req: http::Request<String>| {
            let n = cc.fetch_add(1, Ordering::SeqCst);
            async move {
                if n < 2 {
                    Err("transient")
                } else {
                    Ok(response(200))
                }
            }
        });
        tokio::time::pause();
        let mut svc = ServiceBuilder::new()
            .layer(RetryLayer::new(fast_policy(5)))
            .service(inner);
        let resp = svc
            .ready()
            .await
            .unwrap()
            .call(make_request("/"))
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_layer_exhausts_attempts() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let inner = tower::service_fn(move |_req: http::Request<String>| {
            cc.fetch_add(1, Ordering::SeqCst);
            async { Err::<http::Response<bytes::Bytes>, _>("always") }
        });
        tokio::time::pause();
        let mut svc = ServiceBuilder::new()
            .layer(RetryLayer::new(fast_policy(3)))
            .service(inner);
        let result = svc.ready().await.unwrap().call(make_request("/")).await;
        assert!(result.is_err());
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_layer_5xx_triggers_retry() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let inner = tower::service_fn(move |_req: http::Request<String>| {
            let n = cc.fetch_add(1, Ordering::SeqCst);
            async move {
                let status = if n < 2 { 503 } else { 200 };
                Ok::<_, std::convert::Infallible>(response(status))
            }
        });
        tokio::time::pause();
        let mut svc = ServiceBuilder::new()
            .layer(RetryLayer::new(fast_policy(5)))
            .service(inner);
        let resp = svc
            .ready()
            .await
            .unwrap()
            .call(make_request("/"))
            .await
            .unwrap();
        assert_eq!(resp.status(), 200);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_layer_does_not_retry_404() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let inner = tower::service_fn(move |_req: http::Request<String>| {
            cc.fetch_add(1, Ordering::SeqCst);
            async { Ok::<_, std::convert::Infallible>(response(404)) }
        });
        let mut svc = ServiceBuilder::new()
            .layer(RetryLayer::new(fast_policy(5)))
            .service(inner);
        let resp = svc
            .ready()
            .await
            .unwrap()
            .call(make_request("/"))
            .await
            .unwrap();
        assert_eq!(resp.status(), 404);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_retry_layer_returns_last_response_when_exhausted() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let inner = tower::service_fn(move |_req: http::Request<String>| {
            cc.fetch_add(1, Ordering::SeqCst);
            async { Ok::<_, std::convert::Infallible>(response(503)) }
        });
        tokio::time::pause();
        let mut svc = ServiceBuilder::new()
            .layer(RetryLayer::new(fast_policy(3)))
            .service(inner);
        let resp = svc
            .ready()
            .await
            .unwrap()
            .call(make_request("/"))
            .await
            .unwrap();
        assert_eq!(resp.status(), 503);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_retry_layer_custom_predicate() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let cc = Arc::clone(&call_count);
        let inner = tower::service_fn(move |_req: http::Request<String>| {
            cc.fetch_add(1, Ordering::SeqCst);
            async { Ok::<_, std::convert::Infallible>(response(503)) }
        });
        // Only 429 is retryable under this predicate, so the 503 is returned as-is.
        let layer = RetryLayer::new(fast_policy(5)).with_predicate(|resp| resp.status() == 429);
        let mut svc = ServiceBuilder::new().layer(layer).service(inner);
        let resp = svc
            .ready()
            .await
            .unwrap()
            .call(make_request("/"))
            .await
            .unwrap();
        assert_eq!(resp.status(), 503);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }
}