use crate::config::{ExponentialBackoff, RetryConfig, RetryTrigger};
use crate::error::HttpError;
use crate::response::{ResponseBody, parse_retry_after};
use bytes::Bytes;
use http::{HeaderValue, Request, Response};
use http_body_util::{BodyExt, Full};
use rand::RngExt;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::Duration;
use tower::{Layer, Service, ServiceExt};
/// Request header carrying the retry number (`1` for the first retry, `2` for the
/// second, ...); it is not added to the initial attempt.
pub const RETRY_ATTEMPT_HEADER: &str = "X-Retry-Attempt";
/// Tower layer that wraps a service in a [`RetryService`].
///
/// Holds the retry policy plus an optional total-time budget, which the wrapped
/// service checks before each attempt and uses to shorten backoff sleeps.
#[derive(Clone)]
pub struct RetryLayer {
    /// Retry policy copied into every wrapped service.
    config: RetryConfig,
    /// Overall time budget across all attempts; `None` disables it.
    total_timeout: Option<Duration>,
}
impl RetryLayer {
    /// Creates a layer with `config` and no total-time budget.
    #[must_use]
    pub fn new(config: RetryConfig) -> Self {
        Self::with_total_timeout(config, None)
    }
    /// Creates a layer with `config` and an optional total-time budget.
    #[must_use]
    pub fn with_total_timeout(config: RetryConfig, total_timeout: Option<Duration>) -> Self {
        Self {
            config,
            total_timeout,
        }
    }
}
impl<S> Layer<S> for RetryLayer {
    type Service = RetryService<S>;
    /// Wraps `inner`, giving it its own copy of this layer's policy and time budget.
    fn layer(&self, inner: S) -> Self::Service {
        let Self {
            config,
            total_timeout,
        } = self;
        RetryService {
            inner,
            config: config.clone(),
            total_timeout: *total_timeout,
        }
    }
}
/// Service produced by [`RetryLayer`]: re-sends requests to `inner` according to
/// its retry policy, optionally within an overall time budget.
#[derive(Clone)]
pub struct RetryService<S> {
    /// Retry policy (retry limit, retry triggers, backoff, drain behaviour).
    config: RetryConfig,
    /// Overall time budget across all attempts; `None` disables it.
    total_timeout: Option<Duration>,
    /// Wrapped service that performs each individual attempt.
    inner: S,
}
impl<S> Service<Request<Full<Bytes>>> for RetryService<S>
where
    S: Service<Request<Full<Bytes>>, Response = Response<ResponseBody>, Error = HttpError>
        + Clone
        + Send
        + 'static,
    S::Future: Send,
{
    type Response = S::Response;
    type Error = HttpError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }
    /// Sends `req` through the inner service, retrying according to the retry policy.
    ///
    /// * A response whose status the policy accepts for retry is (optionally)
    ///   drained and the request re-sent after a backoff. `Retry-After` is honoured
    ///   unless `ignore_retry_after` is set, and is always capped at `backoff.max`.
    /// * An error is retried only when its trigger (transport error / timeout) is
    ///   accepted by the policy; any other error is returned immediately.
    /// * Once `max_retries` retries have been made, the last response (even a 5xx)
    ///   or error is returned unchanged.
    /// * With a total timeout, `HttpError::DeadlineExceeded` is returned when the
    ///   deadline has passed before an attempt starts or before a backoff sleep;
    ///   sleeps are shortened so they never run past the deadline. An attempt that
    ///   is already in flight is not interrupted by this layer.
    /// * Retries carry the `X-Retry-Attempt` header with the retry number.
    fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
        // Keep the instance `poll_ready` drove to readiness and leave a clone in its
        // place (the pattern recommended by tower). The ready instance is the one
        // that gets called: calling a clone instead would leave whatever
        // `poll_ready` reserved (e.g. a buffer slot or concurrency permit) stranded
        // for the whole retry sequence, and with a bound of 1 the clone could never
        // become ready at all.
        let replacement = self.inner.clone();
        let mut svc = std::mem::replace(&mut self.inner, replacement);
        let config = self.config.clone();
        let total_timeout = self.total_timeout;
        let (parts, body) = req.into_parts();
        let has_idempotency_key = config
            .idempotency_key_header
            .as_ref()
            .is_some_and(|name| parts.headers.contains_key(name));
        Box::pin(async move {
            let method = parts.method.clone();
            let url_host = parts
                .uri
                .authority()
                .map(ToString::to_string)
                .or_else(|| parts.uri.host().map(ToOwned::to_owned))
                .unwrap_or_else(|| "unknown".to_owned());
            let request_id = parts
                .headers
                .get("x-request-id")
                .or_else(|| parts.headers.get("x-correlation-id"))
                .and_then(|v| v.to_str().ok())
                .map(String::from);
            let deadline_info = total_timeout.map(|t| (tokio::time::Instant::now() + t, t));
            let mut attempt = 0usize;
            loop {
                if let Some((deadline, timeout_duration)) = deadline_info
                    && tokio::time::Instant::now() >= deadline
                {
                    return Err(HttpError::DeadlineExceeded(timeout_duration));
                }
                // `Parts::clone` carries method, URI, version, headers and extensions,
                // so every attempt is a faithful copy of the original request.
                let mut req = Request::from_parts(parts.clone(), body.clone());
                if attempt > 0 {
                    req.headers_mut()
                        .insert(RETRY_ATTEMPT_HEADER, HeaderValue::from(attempt));
                }
                // Resolves immediately on the first attempt (poll_ready already
                // succeeded); later attempts must re-acquire readiness before `call`.
                svc.ready().await?;
                let effective_backoff = match svc.call(req).await {
                    Ok(resp) => {
                        let status_code = resp.status().as_u16();
                        let trigger = RetryTrigger::Status(status_code);
                        if attempt >= config.max_retries
                            || !config.should_retry(trigger, &method, has_idempotency_key)
                        {
                            return Ok(resp);
                        }
                        let retry_after = parse_retry_after(resp.headers())
                            .map(|d| d.min(config.backoff.max));
                        let backoff_duration = match retry_after {
                            Some(delay) if !config.ignore_retry_after => delay,
                            _ => calculate_backoff(&config.backoff, attempt),
                        };
                        let drain_limit = config.retry_response_drain_limit;
                        if should_drain_before_retry(
                            resp.headers(),
                            config.skip_drain_on_retry,
                            drain_limit,
                        ) && let Err(e) = drain_response_body(resp, drain_limit).await
                        {
                            tracing::debug!(
                                error = %e,
                                "Failed to drain response body before retry; connection may not be reused"
                            );
                        }
                        let effective_backoff = clamp_to_deadline(backoff_duration, deadline_info)?;
                        tracing::debug!(
                            retry = attempt + 1,
                            max_retries = config.max_retries,
                            status = status_code,
                            trigger = ?trigger,
                            method = %method,
                            host = %url_host,
                            request_id = ?request_id,
                            backoff_ms = effective_backoff.as_millis(),
                            retry_after_used = retry_after.is_some() && !config.ignore_retry_after,
                            "Retrying request after status code"
                        );
                        effective_backoff
                    }
                    Err(err) => {
                        if attempt >= config.max_retries {
                            return Err(err);
                        }
                        let trigger = get_retry_trigger(&err);
                        if !config.should_retry(trigger, &method, has_idempotency_key) {
                            return Err(err);
                        }
                        let effective_backoff = clamp_to_deadline(
                            calculate_backoff(&config.backoff, attempt),
                            deadline_info,
                        )?;
                        tracing::debug!(
                            retry = attempt + 1,
                            max_retries = config.max_retries,
                            error = %err,
                            trigger = ?trigger,
                            method = %method,
                            host = %url_host,
                            request_id = ?request_id,
                            backoff_ms = effective_backoff.as_millis(),
                            "Retrying request after error"
                        );
                        effective_backoff
                    }
                };
                tokio::time::sleep(effective_backoff).await;
                attempt += 1;
            }
        })
    }
}
/// Decides whether a retryable response's body should be drained before retrying.
///
/// Draining is skipped when `skip_drain` is set, or when a parseable
/// `Content-Length` exceeds `drain_limit`; otherwise (including when the length
/// is absent or unparseable) the body is drained.
fn should_drain_before_retry(headers: &http::HeaderMap, skip_drain: bool, drain_limit: usize) -> bool {
    if skip_drain {
        tracing::trace!("Skipping drain: skip_drain_on_retry enabled");
        return false;
    }
    let content_length = headers
        .get(http::header::CONTENT_LENGTH)
        .and_then(|v| v.to_str().ok())
        .and_then(|s| s.parse::<u64>().ok());
    match content_length {
        Some(content_length)
            if content_length > u64::try_from(drain_limit).unwrap_or(u64::MAX) =>
        {
            tracing::debug!(
                content_length,
                drain_limit,
                "Skipping drain: Content-Length exceeds limit"
            );
            false
        }
        _ => true,
    }
}
/// Caps `backoff` at the time remaining before the total-timeout deadline.
///
/// `deadline_info` is `(deadline, configured total timeout)`; `None` means no
/// deadline and `backoff` is returned unchanged.
///
/// # Errors
/// Returns `HttpError::DeadlineExceeded` when the deadline has already passed.
fn clamp_to_deadline(
    backoff: Duration,
    deadline_info: Option<(tokio::time::Instant, Duration)>,
) -> Result<Duration, HttpError> {
    let Some((deadline, timeout_duration)) = deadline_info else {
        return Ok(backoff);
    };
    let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
    if remaining.is_zero() {
        return Err(HttpError::DeadlineExceeded(timeout_duration));
    }
    Ok(backoff.min(remaining))
}
/// Reads and discards the body of `response` until it ends or at least `limit`
/// bytes of data frames have been consumed (non-data frames are skipped).
///
/// The caller logs a failure here as "connection may not be reused", i.e. the
/// point of draining is to leave the connection in a reusable state.
///
/// # Errors
/// Returns `HttpError::Transport` if reading a body frame fails.
async fn drain_response_body(
    response: Response<ResponseBody>,
    limit: usize,
) -> Result<(), HttpError> {
    let mut body = std::pin::pin!(response.into_body());
    let mut consumed = 0usize;
    loop {
        let Some(next) = body.frame().await else {
            return Ok(());
        };
        let frame = next.map_err(HttpError::Transport)?;
        let Some(data) = frame.data_ref() else {
            continue;
        };
        consumed += data.len();
        if consumed >= limit {
            return Ok(());
        }
    }
}
/// Classifies an inner-service error for the retry policy.
///
/// Transport failures and timeouts map to their own triggers (whether they are
/// actually retried is decided by the policy); every other error kind is
/// reported as non-retryable.
fn get_retry_trigger(err: &HttpError) -> RetryTrigger {
    if let HttpError::Transport(_) = err {
        RetryTrigger::TransportError
    } else if let HttpError::Timeout(_) = err {
        RetryTrigger::Timeout
    } else {
        RetryTrigger::NonRetryable
    }
}
/// Returns the delay to wait before retry `attempt + 1` (`attempt` is 0-based).
///
/// The base delay is `backoff.initial * multiplier^attempt`, capped at
/// `backoff.max` and at one day. A negative or non-finite `multiplier` falls back
/// to `1.0`. A zero `initial` delay always yields zero, even when the exponential
/// term overflows. With `backoff.jitter`, up to 25% extra delay is added at random
/// and the result is capped again, so the returned value never exceeds the cap.
pub fn calculate_backoff(backoff: &ExponentialBackoff, attempt: usize) -> Duration {
    // Absolute ceiling (one day) on any single backoff, regardless of configuration.
    const MAX_BACKOFF_SECS: f64 = 86400.0;
    // `powi` takes an i32; saturating is harmless because the result is capped.
    let attempt_i32 = i32::try_from(attempt).unwrap_or(i32::MAX);
    let multiplier = if backoff.multiplier.is_finite() && backoff.multiplier >= 0.0 {
        backoff.multiplier
    } else {
        1.0
    };
    // `Duration::as_secs_f64` is always finite and non-negative, so the configured
    // durations only need the absolute ceiling applied.
    let initial_secs = backoff.initial.as_secs_f64();
    let max_secs = backoff.max.as_secs_f64().min(MAX_BACKOFF_SECS);
    let base_secs = initial_secs * multiplier.powi(attempt_i32);
    let clamped = if base_secs.is_nan() {
        // Only reachable as `0 * inf`: the exponential term overflowed but the
        // initial delay is zero, so the delay stays zero instead of jumping to
        // the cap.
        0.0
    } else {
        // Finite values are capped; `+inf` (overflow with a non-zero initial
        // delay) collapses to the cap as well.
        base_secs.min(max_secs).max(0.0)
    };
    let duration = Duration::from_secs_f64(clamped);
    let duration = if backoff.jitter {
        let mut rng = rand::rng();
        let jitter_factor = rng.random_range(0.0..=0.25);
        let jitter = duration.mul_f64(jitter_factor);
        duration + jitter
    } else {
        duration
    };
    let max_duration = Duration::from_secs_f64(max_secs);
    duration.min(max_duration)
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use super::*;
    use crate::config::IDEMPOTENCY_KEY_HEADER;
    use bytes::Bytes;
    use http::{Method, Request, Response, StatusCode};
    use http_body_util::Full;
    use std::sync::{Arc, Mutex};

    /// Outcome produced by one call to the scripted inner service.
    type MockResult = Result<Response<ResponseBody>, HttpError>;

    /// Boxes `data` into the response body type used by the client.
    fn make_response_body(data: &[u8]) -> ResponseBody {
        Full::new(Bytes::copy_from_slice(data))
            .map_err(|e| -> Box<dyn std::error::Error + Send + Sync> { Box::new(e) })
            .boxed()
    }

    /// A successful call returning `status` with `body`.
    fn respond(status: StatusCode, body: &[u8]) -> MockResult {
        Ok(Response::builder()
            .status(status)
            .body(make_response_body(body))
            .unwrap())
    }

    /// A 429 response carrying the given `Retry-After` value.
    fn rate_limited_with_retry_after(retry_after: &'static str) -> MockResult {
        Ok(Response::builder()
            .status(StatusCode::TOO_MANY_REQUESTS)
            .header(http::header::RETRY_AFTER, retry_after)
            .body(make_response_body(b"Rate limited"))
            .unwrap())
    }

    /// A transport-level failure (connection reset).
    fn connection_reset() -> MockResult {
        Err(HttpError::Transport(Box::new(std::io::Error::new(
            std::io::ErrorKind::ConnectionReset,
            "connection reset",
        ))))
    }

    /// Inner service whose outcome is chosen by `script`, which receives the
    /// 1-based call number. Every call's `X-Retry-Attempt` header is recorded.
    /// Clones share all counters, so the test keeps a handle after layering.
    #[derive(Clone)]
    struct ScriptedService {
        calls: Arc<Mutex<usize>>,
        retry_headers: Arc<Mutex<Vec<Option<String>>>>,
        script: Arc<dyn Fn(usize) -> MockResult + Send + Sync>,
    }

    impl ScriptedService {
        fn new(script: impl Fn(usize) -> MockResult + Send + Sync + 'static) -> Self {
            Self {
                calls: Arc::new(Mutex::new(0)),
                retry_headers: Arc::new(Mutex::new(Vec::new())),
                script: Arc::new(script),
            }
        }

        fn call_count(&self) -> usize {
            *self.calls.lock().unwrap()
        }
    }

    impl Service<Request<Full<Bytes>>> for ScriptedService {
        type Response = Response<ResponseBody>;
        type Error = HttpError;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<Full<Bytes>>) -> Self::Future {
            let retry_header = req
                .headers()
                .get(RETRY_ATTEMPT_HEADER)
                .map(|v| v.to_str().unwrap_or("invalid").to_owned());
            self.retry_headers.lock().unwrap().push(retry_header);
            let call_number = {
                let mut calls = self.calls.lock().unwrap();
                *calls += 1;
                *calls
            };
            let script = Arc::clone(&self.script);
            Box::pin(async move { (*script)(call_number) })
        }
    }

    /// A bodiless request to `http://example.com` with the given method.
    fn request(method: Method) -> Request<Full<Bytes>> {
        Request::builder()
            .method(method)
            .uri("http://example.com")
            .body(Full::new(Bytes::new()))
            .unwrap()
    }

    /// Default retry policy with the fast backoff preset.
    fn fast_config() -> RetryConfig {
        RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        }
    }

    /// Sends `req` through a fresh `RetryLayer` wrapping a clone of `service`.
    async fn send(
        service: &ScriptedService,
        config: RetryConfig,
        req: Request<Full<Bytes>>,
    ) -> MockResult {
        let mut retry_service = RetryLayer::new(config).layer(service.clone());
        retry_service.call(req).await
    }

    /// Backoff policy shared by the `calculate_backoff` tests.
    fn backoff_policy(multiplier: f64, jitter: bool) -> ExponentialBackoff {
        ExponentialBackoff {
            initial: Duration::from_millis(100),
            max: Duration::from_secs(10),
            multiplier,
            jitter,
        }
    }

    #[tokio::test]
    async fn test_retry_layer_successful_request() {
        let service = ScriptedService::new(|_| respond(StatusCode::OK, b""));
        let result = send(&service, RetryConfig::default(), request(Method::GET)).await;
        assert!(result.is_ok());
        assert_eq!(service.call_count(), 1);
    }

    #[tokio::test]
    async fn test_retry_layer_post_not_retried_on_5xx() {
        let service = ScriptedService::new(|_| {
            respond(StatusCode::INTERNAL_SERVER_ERROR, b"Internal Server Error")
        });
        let result = send(&service, fast_config(), request(Method::POST)).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(service.call_count(), 1);
    }

    #[tokio::test]
    async fn test_retry_layer_get_retried_on_5xx() {
        let service = ScriptedService::new(|call| {
            if call < 3 {
                respond(StatusCode::INTERNAL_SERVER_ERROR, b"Internal Server Error")
            } else {
                respond(StatusCode::OK, b"")
            }
        });
        let result = send(&service, fast_config(), request(Method::GET)).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), StatusCode::OK);
        assert_eq!(service.call_count(), 3);
    }

    #[tokio::test]
    async fn test_retry_layer_always_retries_429() {
        let service = ScriptedService::new(|call| {
            if call < 2 {
                respond(StatusCode::TOO_MANY_REQUESTS, b"Rate limited")
            } else {
                respond(StatusCode::OK, b"")
            }
        });
        let result = send(&service, fast_config(), request(Method::POST)).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), StatusCode::OK);
        assert_eq!(service.call_count(), 2);
    }

    #[tokio::test]
    async fn test_retry_layer_retries_transport_errors() {
        let service = ScriptedService::new(|call| {
            if call < 3 {
                connection_reset()
            } else {
                respond(StatusCode::OK, b"")
            }
        });
        let result = send(&service, fast_config(), request(Method::GET)).await;
        assert!(result.is_ok());
        assert_eq!(service.call_count(), 3);
    }

    #[tokio::test]
    async fn test_retry_layer_post_not_retried_on_transport_error() {
        let service = ScriptedService::new(|_| connection_reset());
        let result = send(&service, fast_config(), request(Method::POST)).await;
        assert!(result.is_err());
        assert_eq!(service.call_count(), 1);
    }

    #[tokio::test]
    async fn test_retry_layer_post_with_idempotency_key_retried() {
        let service = ScriptedService::new(|call| {
            if call < 3 {
                connection_reset()
            } else {
                respond(StatusCode::OK, b"")
            }
        });
        let req = Request::builder()
            .method(Method::POST)
            .uri("http://example.com")
            .header(IDEMPOTENCY_KEY_HEADER, "unique-key-123")
            .body(Full::new(Bytes::new()))
            .unwrap();
        let result = send(&service, fast_config(), req).await;
        assert!(result.is_ok());
        assert_eq!(service.call_count(), 3);
    }

    #[tokio::test]
    async fn test_retry_layer_does_not_retry_json_errors() {
        let service = ScriptedService::new(|_| {
            let err: serde_json::Error =
                serde_json::from_str::<serde_json::Value>("invalid").unwrap_err();
            Err(HttpError::Json(err))
        });
        let result = send(&service, RetryConfig::default(), request(Method::GET)).await;
        assert!(result.is_err());
        assert_eq!(service.call_count(), 1);
    }

    #[test]
    fn test_calculate_backoff_no_jitter() {
        let policy = backoff_policy(2.0, false);
        let expectations = [
            (0, Duration::from_millis(100)),
            (1, Duration::from_millis(200)),
            (2, Duration::from_millis(400)),
            (10, Duration::from_secs(10)),
        ];
        for (attempt, expected) in expectations {
            assert_eq!(calculate_backoff(&policy, attempt), expected);
        }
    }

    #[test]
    fn test_calculate_backoff_with_jitter() {
        let delay = calculate_backoff(&backoff_policy(2.0, true), 0);
        assert!(delay >= Duration::from_millis(100));
        assert!(delay <= Duration::from_millis(125));
    }

    #[test]
    fn test_calculate_backoff_with_nan_multiplier() {
        let policy = backoff_policy(f64::NAN, false);
        assert_eq!(calculate_backoff(&policy, 0), Duration::from_millis(100));
        assert_eq!(calculate_backoff(&policy, 1), Duration::from_millis(100));
    }

    #[test]
    fn test_calculate_backoff_with_infinity_multiplier() {
        let policy = backoff_policy(f64::INFINITY, false);
        assert_eq!(calculate_backoff(&policy, 0), Duration::from_millis(100));
    }

    #[test]
    fn test_calculate_backoff_with_negative_multiplier() {
        let policy = backoff_policy(-2.0, false);
        assert_eq!(calculate_backoff(&policy, 0), Duration::from_millis(100));
    }

    #[test]
    fn test_calculate_backoff_with_huge_attempt() {
        let policy = backoff_policy(2.0, false);
        assert_eq!(calculate_backoff(&policy, usize::MAX), Duration::from_secs(10));
    }

    #[tokio::test]
    async fn test_retry_layer_uses_retry_after_header() {
        let service = ScriptedService::new(|call| {
            if call < 2 {
                rate_limited_with_retry_after("0")
            } else {
                respond(StatusCode::OK, b"")
            }
        });
        let retry_config = RetryConfig {
            backoff: ExponentialBackoff {
                initial: Duration::from_secs(10),
                jitter: false,
                ..ExponentialBackoff::default()
            },
            ignore_retry_after: false,
            ..RetryConfig::default()
        };
        let start = std::time::Instant::now();
        let result = send(&service, retry_config, request(Method::POST)).await;
        let elapsed = start.elapsed();
        assert!(result.is_ok());
        assert_eq!(service.call_count(), 2);
        assert!(
            elapsed < Duration::from_secs(1),
            "Expected quick retry using Retry-After, but took {elapsed:?}",
        );
    }

    #[tokio::test]
    async fn test_retry_layer_ignores_retry_after_when_configured() {
        let service = ScriptedService::new(|call| {
            if call < 2 {
                rate_limited_with_retry_after("10")
            } else {
                respond(StatusCode::OK, b"")
            }
        });
        let retry_config = RetryConfig {
            backoff: ExponentialBackoff::fast(),
            ignore_retry_after: true,
            ..RetryConfig::default()
        };
        let start = std::time::Instant::now();
        let result = send(&service, retry_config, request(Method::POST)).await;
        let elapsed = start.elapsed();
        assert!(result.is_ok());
        assert_eq!(service.call_count(), 2);
        assert!(
            elapsed < Duration::from_secs(1),
            "Expected quick retry using backoff policy (1ms), but took {elapsed:?}",
        );
    }

    #[tokio::test]
    async fn test_retry_after_clamped_to_backoff_max() {
        let service = ScriptedService::new(|call| {
            if call < 2 {
                rate_limited_with_retry_after("3600")
            } else {
                respond(StatusCode::OK, b"")
            }
        });
        let retry_config = RetryConfig {
            backoff: ExponentialBackoff {
                initial: Duration::from_millis(1),
                max: Duration::from_millis(50),
                jitter: false,
                ..ExponentialBackoff::default()
            },
            ignore_retry_after: false,
            ..RetryConfig::default()
        };
        let start = std::time::Instant::now();
        let result = send(&service, retry_config, request(Method::POST)).await;
        let elapsed = start.elapsed();
        assert!(result.is_ok());
        assert_eq!(service.call_count(), 2);
        assert!(
            elapsed < Duration::from_secs(1),
            "Retry-After should be clamped to backoff.max (50ms), but took {elapsed:?}",
        );
    }

    #[tokio::test]
    async fn test_retry_attempt_header_added_on_retry() {
        let service = ScriptedService::new(|call| {
            if call < 3 {
                connection_reset()
            } else {
                respond(StatusCode::OK, b"")
            }
        });
        let result = send(&service, fast_config(), request(Method::GET)).await;
        assert!(result.is_ok());
        assert_eq!(service.call_count(), 3);
        let headers = service.retry_headers.lock().unwrap();
        assert_eq!(*headers, [None, Some("1".to_owned()), Some("2".to_owned())]);
    }

    #[tokio::test]
    async fn test_retry_layer_exhausted_returns_ok_with_status() {
        let service =
            ScriptedService::new(|_| respond(StatusCode::INTERNAL_SERVER_ERROR, b"error"));
        let retry_config = RetryConfig {
            max_retries: 2,
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let result = send(&service, retry_config, request(Method::GET)).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(service.call_count(), 3);
    }

    #[tokio::test]
    async fn test_retry_layer_non_retryable_status_passes_through() {
        let service = ScriptedService::new(|_| respond(StatusCode::NOT_FOUND, b"not found"));
        let retry_config = RetryConfig {
            max_retries: 3,
            backoff: ExponentialBackoff::fast(),
            ..RetryConfig::default()
        };
        let result = send(&service, retry_config, request(Method::GET)).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), StatusCode::NOT_FOUND);
        assert_eq!(service.call_count(), 1);
    }
}