use std::time::Duration;
use anyhow::anyhow;
use axum::http::Extensions;
use reqwest::{Client, Request, Response, StatusCode};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware, Error, Middleware, Next, Result};
use reqwest_retry::{
DefaultRetryableStrategy, RetryTransientMiddleware, RetryableStrategy,
policies::ExponentialBackoff,
};
use tracing::debug;
/// Wraps `base` in a middleware stack that handles rate limiting and transient failures.
///
/// Two layers are registered, in this order:
///
/// 1. `RetryAfterMiddleware`, which reacts to `429 Too Many Requests`.
/// 2. A `RetryTransientMiddleware` using exponential backoff with delays bounded between
///    1 and 60 seconds and at most `max_retries` retries. It is configured with
///    `No429Strategy`, so 429 responses are not classified by this layer.
///
/// # Arguments
///
/// * `base` - the plain `reqwest` client that performs the actual requests.
/// * `max_retries` - retry budget handed to the exponential backoff policy.
pub fn build_rate_limited_client(base: Client, max_retries: u32) -> ClientWithMiddleware {
    let min_delay = Duration::from_secs(1);
    let max_delay = Duration::from_secs(60);

    let policy = ExponentialBackoff::builder()
        .retry_bounds(min_delay, max_delay)
        .build_with_max_retries(max_retries);
    let transient_retry =
        RetryTransientMiddleware::new_with_policy_and_strategy(policy, No429Strategy);

    // Registration order is significant: the 429 handler goes in before the transient layer.
    ClientBuilder::new(base)
        .with(RetryAfterMiddleware)
        .with(transient_retry)
        .build()
}
/// Middleware that reacts to `429 Too Many Requests` by sleeping and retrying once.
///
/// The wait is taken from the `Retry-After` response header when it holds a whole number
/// of seconds. A missing header, or any value that does not parse as `u64` (including the
/// HTTP-date form of `Retry-After`), falls back to a 1 second wait.
///
/// Only a single retry is made: whatever the second attempt returns, including another
/// 429, is handed back to the caller unchanged.
///
/// NOTE(review): the server-supplied wait is not capped, so a very large `Retry-After`
/// stalls the caller for that long. Consider an upper bound if the upstream is untrusted.
struct RetryAfterMiddleware;

#[async_trait::async_trait]
impl Middleware for RetryAfterMiddleware {
    /// Runs the request and, on a 429 response, sleeps and runs it one more time.
    ///
    /// # Errors
    ///
    /// Returns `Error::Middleware` when the request cannot be cloned (e.g. a streaming
    /// body), and propagates any error produced by the rest of the middleware chain.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<Response> {
        // `run` consumes the request, so the first attempt is sent on a clone and the
        // original is kept back for the possible retry.
        let first_attempt = req.try_clone().ok_or_else(|| {
            Error::Middleware(anyhow!(
                "Request object is not clonable. Are you passing a streaming body?"
            ))
        })?;

        let response = next.clone().run(first_attempt, extensions).await?;
        if response.status() != StatusCode::TOO_MANY_REQUESTS {
            return Ok(response);
        }

        let wait_secs = response
            .headers()
            .get("Retry-After")
            .and_then(|value| value.to_str().ok())
            .and_then(|text| text.parse::<u64>().ok())
            .unwrap_or(1);

        // The 429 response is no longer needed; release it now instead of keeping it
        // alive across the sleep and the retry.
        drop(response);

        debug!(wait_secs, "429 Too Many Requests — sleeping before retry");
        tokio::time::sleep(Duration::from_secs(wait_secs)).await;

        // There is exactly one retry, so the owned request can be moved into it; a second
        // clone would only copy the request again.
        next.run(req, extensions).await
    }
}
/// Retry strategy that declines to classify `429 Too Many Requests` responses and defers
/// every other outcome to `DefaultRetryableStrategy`.
///
/// NOTE(review): returning `None` is presumably what makes the transient retry layer hand
/// a 429 back untouched so that `RetryAfterMiddleware` can deal with it; confirm against
/// the reqwest-retry documentation.
struct No429Strategy;

impl RetryableStrategy for No429Strategy {
    /// Returns `None` for a successful exchange whose status is 429; otherwise returns
    /// whatever the default strategy decides for `res`.
    fn handle(
        &self,
        res: &std::result::Result<reqwest::Response, reqwest_middleware::Error>,
    ) -> Option<reqwest_retry::Retryable> {
        let rate_limited =
            matches!(res, Ok(resp) if resp.status() == StatusCode::TOO_MANY_REQUESTS);
        if rate_limited {
            return None;
        }
        DefaultRetryableStrategy.handle(res)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{
        Arc,
        atomic::{AtomicU32, Ordering},
    };
    use std::time::Instant;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// The first request gets a 429 with `Retry-After: 1`, every later one a 200. The
    /// client must end up with the 200 after at least one second and exactly two attempts.
    #[tokio::test]
    async fn honours_retry_after_seconds() {
        let server = MockServer::start().await;
        let attempts = Arc::new(AtomicU32::new(0));
        let seen = Arc::clone(&attempts);

        Mock::given(method("GET"))
            .and(path("/x"))
            .respond_with(move |_: &wiremock::Request| {
                // `fetch_add` yields the previous value, so 0 identifies the first request.
                match seen.fetch_add(1, Ordering::SeqCst) {
                    0 => ResponseTemplate::new(429).insert_header("Retry-After", "1"),
                    _ => ResponseTemplate::new(200).set_body_string("ok"),
                }
            })
            .mount(&server)
            .await;

        let client = build_rate_limited_client(reqwest::Client::new(), 3);
        let url = format!("{}/x", server.uri());

        let start = Instant::now();
        let resp = client.get(url).send().await.unwrap();
        let elapsed = start.elapsed();

        assert_eq!(resp.status(), 200);
        assert!(
            elapsed >= Duration::from_secs(1),
            "expected ≥1 s elapsed, got {elapsed:?}"
        );
        assert_eq!(attempts.load(Ordering::SeqCst), 2, "expected 2 attempts");
    }

    /// Every request gets a 503. The client must hand back the 503 after the initial
    /// attempt plus `max_retries` retries.
    #[tokio::test]
    async fn fails_after_max_retries_on_5xx() {
        let server = MockServer::start().await;
        let attempts = Arc::new(AtomicU32::new(0));
        let seen = Arc::clone(&attempts);

        Mock::given(method("GET"))
            .and(path("/y"))
            .respond_with(move |_: &wiremock::Request| {
                seen.fetch_add(1, Ordering::SeqCst);
                ResponseTemplate::new(503)
            })
            .mount(&server)
            .await;

        let max_retries: u32 = 2;
        let expected_attempts = max_retries + 1;
        let client = build_rate_limited_client(reqwest::Client::new(), max_retries);
        let url = format!("{}/y", server.uri());

        let resp = client.get(url).send().await.unwrap();
        assert_eq!(resp.status(), 503);

        let total = attempts.load(Ordering::SeqCst);
        assert_eq!(
            total, expected_attempts,
            "expected {} attempts, got {total}",
            expected_attempts
        );
    }
}