use std::fmt::Display;
use std::time::Duration;
use reqwest::{Client, RequestBuilder, Response, StatusCode};
use tokio::time::sleep;
use tracing::{error, info, warn};
#[derive(Debug, Clone)]
pub struct RetryConfig {
max_retries: u32,
timeout: Duration,
base_backoff: Duration,
}
impl RetryConfig {
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(5);
const DEFAULT_MAX_RETRIES: u32 = 3;
const DEFAULT_BASE_BACKOFF: Duration = Duration::from_millis(500);
pub fn new(max_retries: u32, timeout: Duration, base_backoff: Duration) -> Self {
Self {
max_retries,
timeout,
base_backoff,
}
}
pub fn max_retries(&self) -> u32 {
self.max_retries
}
pub fn timeout(&self) -> Duration {
self.timeout
}
pub fn base_backoff(&self) -> Duration {
self.base_backoff
}
pub fn build_client(&self) -> Client {
Client::builder()
.timeout(self.timeout)
.build()
.expect("failed to build HTTP client")
}
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: Self::DEFAULT_MAX_RETRIES,
timeout: Self::DEFAULT_TIMEOUT,
base_backoff: Self::DEFAULT_BASE_BACKOFF,
}
}
}
/// Signature of a check deciding whether a received response counts as delivered.
pub type SuccessPredicate = fn(&Response) -> bool;

/// Treats any status in the 2xx range as a successful delivery.
pub fn is_success_2xx(response: &Response) -> bool {
    let status = response.status();
    status.is_success()
}

/// Treats only `202 Accepted` as a successful delivery; every other status,
/// including other 2xx codes such as `200 OK`, is rejected.
pub fn is_accepted_202(response: &Response) -> bool {
    StatusCode::ACCEPTED == response.status()
}
pub async fn deliver_with_retry(
config: &RetryConfig,
build_request: impl Fn() -> RequestBuilder,
is_success: SuccessPredicate,
subscriber_name: &str,
context: &(impl Display + ?Sized),
) {
for attempt in 0..config.max_retries {
let result = build_request().send().await;
match result {
Ok(resp) if is_success(&resp) => {
info!(
subscriber = subscriber_name,
context = %context,
"delivery succeeded"
);
return;
}
Ok(resp) => {
let status = resp.status();
log_retry_or_fail(
config,
attempt,
subscriber_name,
context,
&format!("HTTP {status}"),
);
}
Err(err) => {
log_retry_or_fail(config, attempt, subscriber_name, context, &err.to_string());
}
}
if attempt + 1 < config.max_retries {
let delay = config.base_backoff * 2u32.pow(attempt);
sleep(delay).await;
}
}
}
/// Logs a failed delivery attempt.
///
/// Emits a `warn!` event while attempts remain under `config.max_retries`, and an `error!`
/// event once the zero-based `attempt` was the last one allowed.
///
/// * `attempt` – zero-based index of the attempt that just failed.
/// * `err_msg` – human-readable failure reason (transport error or HTTP status).
fn log_retry_or_fail(
    config: &RetryConfig,
    attempt: u32,
    subscriber_name: &str,
    context: &(impl Display + ?Sized),
    err_msg: &str,
) {
    let attempts_made = attempt.saturating_add(1);
    // Saturating so an out-of-range `attempt` is reported as the final failure
    // instead of panicking on integer underflow.
    let remaining = config.max_retries.saturating_sub(attempts_made);
    if remaining > 0 {
        warn!(
            subscriber = subscriber_name,
            context = %context,
            attempt = attempts_made,
            remaining,
            error = %err_msg,
            "delivery failed, retrying"
        );
    } else {
        error!(
            subscriber = subscriber_name,
            context = %context,
            attempts = attempts_made,
            error = %err_msg,
            "delivery failed after all retries"
        );
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::http::Response as HttpResponse;
    use axum::http::StatusCode;
    use axum::routing::post;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};
    use tokio::net::TcpListener;

    /// Wraps an empty `http::Response` with the given status in a `reqwest::Response`.
    fn response_with_status(status: u16) -> Response {
        Response::from(HttpResponse::builder().status(status).body("").unwrap())
    }

    /// Serves `app` on an ephemeral localhost port and returns its base URL.
    async fn spawn_server(app: Router) -> String {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        format!("http://{addr}")
    }

    /// Router whose `POST /` handler increments `count` and answers with
    /// `status_for(zero_based_call_index)`.
    fn counting_app(count: Arc<AtomicU32>, status_for: fn(u32) -> StatusCode) -> Router {
        Router::new().route(
            "/",
            post(move || {
                let count = Arc::clone(&count);
                async move { status_for(count.fetch_add(1, Ordering::SeqCst)) }
            }),
        )
    }

    /// Runs `deliver_with_retry` against `url` with a fresh client built from `config`.
    async fn deliver_to(url: &str, config: &RetryConfig) {
        let client = config.build_client();
        deliver_with_retry(
            config,
            || client.post(url).body("{}"),
            is_success_2xx,
            "test",
            url,
        )
        .await;
    }

    #[test]
    fn default_config_values() {
        let config = RetryConfig::default();
        assert_eq!(config.max_retries(), 3);
        assert_eq!(config.timeout(), Duration::from_secs(5));
        assert_eq!(config.base_backoff(), Duration::from_millis(500));
    }

    #[test]
    fn custom_config_values() {
        let config = RetryConfig::new(5, Duration::from_secs(10), Duration::from_secs(1));
        assert_eq!(config.max_retries(), 5);
        assert_eq!(config.timeout(), Duration::from_secs(10));
        assert_eq!(config.base_backoff(), Duration::from_secs(1));
    }

    #[test]
    fn build_client_succeeds() {
        let config = RetryConfig::default();
        let _client = config.build_client();
    }

    #[test]
    fn is_success_2xx_predicate() {
        assert!(is_success_2xx(&response_with_status(200)));
    }

    #[test]
    fn is_success_2xx_rejects_4xx() {
        assert!(!is_success_2xx(&response_with_status(400)));
    }

    #[test]
    fn is_accepted_202_predicate() {
        assert!(is_accepted_202(&response_with_status(202)));
    }

    #[test]
    fn is_accepted_202_rejects_200() {
        assert!(!is_accepted_202(&response_with_status(200)));
    }

    #[tokio::test]
    async fn deliver_succeeds_on_first_try() {
        let call_count = Arc::new(AtomicU32::new(0));
        let url = spawn_server(counting_app(Arc::clone(&call_count), |_| StatusCode::OK)).await;
        deliver_to(&url, &RetryConfig::default()).await;
        // A successful first attempt must not trigger any retries.
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn deliver_retries_on_server_error() {
        let call_count = Arc::new(AtomicU32::new(0));
        let app = counting_app(Arc::clone(&call_count), |_| {
            StatusCode::INTERNAL_SERVER_ERROR
        });
        let url = spawn_server(app).await;
        let config = RetryConfig::new(3, Duration::from_secs(5), Duration::from_millis(10));
        deliver_to(&url, &config).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn deliver_stops_retrying_after_success() {
        let call_count = Arc::new(AtomicU32::new(0));
        // First call fails, every later call succeeds.
        let app = counting_app(Arc::clone(&call_count), |call| {
            if call == 0 {
                StatusCode::INTERNAL_SERVER_ERROR
            } else {
                StatusCode::OK
            }
        });
        let url = spawn_server(app).await;
        let config = RetryConfig::new(3, Duration::from_secs(5), Duration::from_millis(10));
        deliver_to(&url, &config).await;
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }
}