use std::error::Error as StdError;
use std::fmt;
use std::io;
use std::ops::ControlFlow;
use std::time::Duration;
/// Attempt budget and exponential-backoff schedule used by [`retry_sync`] and
/// [`retry_async`].
#[derive(Debug, Clone, Copy)]
pub struct RetryPolicy {
    /// Total number of attempts, including the first one. The retry helpers in this
    /// module treat `0` as `1`.
    pub max_attempts: u32,
    /// Delay before the second attempt; each later attempt doubles it.
    pub base_delay: Duration,
    /// Upper bound applied to every computed delay.
    pub max_delay: Duration,
}
impl RetryPolicy {
    /// Policy for uploads: up to 10 attempts, backoff starting at 50 ms and capped at 30 s.
    pub const UPLOAD: RetryPolicy = RetryPolicy {
        max_attempts: 10,
        base_delay: Duration::from_millis(50),
        max_delay: Duration::from_secs(30),
    };
    /// Returns the delay to wait before making attempt number `next_attempt` (1-based).
    ///
    /// The delay is `base_delay * 2^(next_attempt - 2)`, capped at `max_delay`. Attempts
    /// `0`, `1` and `2` all use an exponent of zero, so they get `base_delay` (capped).
    ///
    /// The arithmetic is done in nanoseconds, so sub-millisecond base delays are honoured
    /// instead of being truncated to zero. It also saturates rather than overflowing for
    /// very large attempt numbers.
    pub fn delay_for(&self, next_attempt: u32) -> Duration {
        const NANOS_PER_SEC: u128 = 1_000_000_000;
        let exp = next_attempt.saturating_sub(2);
        // Shifting by 128 or more bits is out of range; saturate the multiplier instead.
        let mult = 1u128.checked_shl(exp).unwrap_or(u128::MAX);
        let scaled = self.base_delay.as_nanos().saturating_mul(mult);
        if scaled >= self.max_delay.as_nanos() {
            return self.max_delay;
        }
        // `scaled` is strictly below `max_delay`, so its whole-second part fits in a u64
        // and the remainder is always below one second.
        Duration::new(
            (scaled / NANOS_PER_SEC) as u64,
            (scaled % NANOS_PER_SEC) as u32,
        )
    }
}
/// Calls `op` until it succeeds, signals a permanent failure, or the attempt budget in
/// `policy` is used up.
///
/// `op` receives the 1-based attempt number and reports its outcome as:
/// - `Ok(value)`: stop and return `Ok(value)`;
/// - `Err(ControlFlow::Break(e))`: stop immediately and return `Err(e)`;
/// - `Err(ControlFlow::Continue(e))`: try again, unless this was the last allowed
///   attempt, in which case `Err(e)` is returned.
///
/// A `policy.max_attempts` of `0` is treated as `1`. Before every attempt after the first,
/// the current thread sleeps for `policy.delay_for(attempt)`.
pub fn retry_sync<T, E, F>(policy: &RetryPolicy, mut op: F) -> Result<T, E>
where
    F: FnMut(u32) -> Result<T, ControlFlow<E, E>>,
{
    let budget = policy.max_attempts.max(1);
    let mut attempt: u32 = 0;
    loop {
        attempt += 1;
        if attempt > 1 {
            std::thread::sleep(policy.delay_for(attempt));
        }
        let transient = match op(attempt) {
            Ok(value) => return Ok(value),
            Err(ControlFlow::Break(fatal)) => return Err(fatal),
            Err(ControlFlow::Continue(err)) => err,
        };
        if attempt >= budget {
            return Err(transient);
        }
    }
}
/// Async counterpart of [`retry_sync`]: awaits `op` until it succeeds, signals a
/// permanent failure, or the attempt budget in `policy` is used up.
///
/// `op` receives the 1-based attempt number and returns a future resolving to:
/// - `Ok(value)`: stop and return `Ok(value)`;
/// - `Err(ControlFlow::Break(e))`: stop immediately and return `Err(e)`;
/// - `Err(ControlFlow::Continue(e))`: try again, unless this was the last allowed
///   attempt, in which case `Err(e)` is returned.
///
/// A `policy.max_attempts` of `0` is treated as `1`. Before every attempt after the
/// first, this awaits `tokio::time::sleep(policy.delay_for(attempt))`.
pub async fn retry_async<T, E, F, Fut>(policy: &RetryPolicy, mut op: F) -> Result<T, E>
where
    F: FnMut(u32) -> Fut,
    Fut: std::future::Future<Output = Result<T, ControlFlow<E, E>>>,
{
    let budget = policy.max_attempts.max(1);
    let mut attempt: u32 = 0;
    loop {
        attempt += 1;
        if attempt > 1 {
            tokio::time::sleep(policy.delay_for(attempt)).await;
        }
        let transient = match op(attempt).await {
            Ok(value) => return Ok(value),
            Err(ControlFlow::Break(fatal)) => return Err(fatal),
            Err(ControlFlow::Continue(err)) => err,
        };
        if attempt >= budget {
            return Err(transient);
        }
    }
}
/// Which response statuses [`retry_http_blocking`] and [`retry_http_async`] accept as
/// success.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SuccessClass {
    /// Only 2xx statuses count as success.
    Strict,
    /// Both 2xx and 3xx statuses count as success.
    AllowRedirects,
}
pub fn retry_http_blocking<F, M>(
label: &str,
policy: &RetryPolicy,
success_class: SuccessClass,
mut send: F,
error_msg: M,
) -> anyhow::Result<(reqwest::StatusCode, String)>
where
F: FnMut(u32) -> Result<reqwest::blocking::Response, reqwest::Error>,
M: Fn(reqwest::StatusCode, &str) -> String,
{
use anyhow::Context as _;
retry_sync(policy, |attempt| {
match send(attempt) {
Ok(resp) => {
let status = resp.status();
let succeeded = match success_class {
SuccessClass::Strict => status.is_success(),
SuccessClass::AllowRedirects => status.is_success() || status.is_redirection(),
};
let body = resp
.text()
.unwrap_or_else(|e| format!("<failed to read body: {e}>"));
if succeeded {
Ok((status, body))
} else {
let msg = error_msg(status, &body);
let inner = anyhow::anyhow!("{msg}");
let wrapped = anyhow::Error::new(HttpError::new(
std::io::Error::other(inner.to_string()),
status.as_u16(),
))
.context(inner);
if is_retriable(wrapped.as_ref()) {
Err(ControlFlow::Continue(wrapped))
} else {
Err(ControlFlow::Break(wrapped))
}
}
}
Err(e) => {
let err = anyhow::Error::new(HttpError::from_response(e, None))
.context(format!("{label}: HTTP transport error"));
if is_retriable(err.as_ref()) {
Err(ControlFlow::Continue(err))
} else {
Err(ControlFlow::Break(err))
}
}
}
})
.with_context(|| format!("{label}: exhausted retry attempts"))
}
pub async fn retry_http_async<F, Fut, M>(
label: &str,
policy: &RetryPolicy,
success_class: SuccessClass,
mut send: F,
error_msg: M,
) -> anyhow::Result<reqwest::Response>
where
F: FnMut(u32) -> Fut,
Fut: std::future::Future<Output = Result<reqwest::Response, reqwest::Error>>,
M: Fn(reqwest::StatusCode, &str) -> String,
{
use anyhow::Context as _;
retry_async(policy, |attempt| {
let fut = send(attempt);
let error_msg = &error_msg;
async move {
match fut.await {
Ok(resp) => {
let status = resp.status();
let succeeded = match success_class {
SuccessClass::Strict => status.is_success(),
SuccessClass::AllowRedirects => {
status.is_success() || status.is_redirection()
}
};
if succeeded {
Ok(resp)
} else {
let body = resp
.text()
.await
.unwrap_or_else(|e| format!("<failed to read body: {e}>"));
let msg = error_msg(status, &body);
let inner = anyhow::anyhow!("{msg}");
let wrapped = anyhow::Error::new(HttpError::new(
std::io::Error::other(inner.to_string()),
status.as_u16(),
))
.context(inner);
if is_retriable(wrapped.as_ref()) {
Err(ControlFlow::Continue(wrapped))
} else {
Err(ControlFlow::Break(wrapped))
}
}
}
Err(e) => {
let err = anyhow::Error::new(HttpError::from_response(e, None))
.context(format!("{label}: HTTP transport error"));
if is_retriable(err.as_ref()) {
Err(ControlFlow::Continue(err))
} else {
Err(ControlFlow::Break(err))
}
}
}
}
})
.await
.with_context(|| format!("{label}: exhausted retry attempts"))
}
/// Classifies the outcome of a blocking request, for use inside a [`retry_sync`]
/// operation.
///
/// - 2xx and 3xx responses are returned as `Ok`.
/// - 5xx and 429 Too Many Requests become `ControlFlow::Continue` (retry). This matches
///   the status rules applied by [`is_retriable`].
/// - Every other status becomes `ControlFlow::Break` (give up).
/// - Transport errors are retried, except request-builder and redirect-policy errors.
///   Those would fail the same way on every attempt, so they `Break`.
///
/// Status errors have the form `HTTP <code> <reason>`.
pub fn classify_http_sync(
    result: reqwest::Result<reqwest::blocking::Response>,
) -> Result<reqwest::blocking::Response, ControlFlow<anyhow::Error, anyhow::Error>> {
    use anyhow::anyhow;
    let resp = match result {
        Ok(resp) => resp,
        // Deterministic failures: retrying cannot change the outcome.
        Err(e) if e.is_builder() || e.is_redirect() => {
            return Err(ControlFlow::Break(anyhow!(e)));
        }
        Err(e) => return Err(ControlFlow::Continue(anyhow!(e))),
    };
    let status = resp.status();
    if status.is_success() || status.is_redirection() {
        return Ok(resp);
    }
    let retriable =
        status.is_server_error() || status == reqwest::StatusCode::TOO_MANY_REQUESTS;
    let fallback_reason = if status.is_server_error() {
        "server error"
    } else {
        "client error"
    };
    let err = anyhow!(
        "HTTP {} {}",
        status.as_u16(),
        status.canonical_reason().unwrap_or(fallback_reason)
    );
    if retriable {
        Err(ControlFlow::Continue(err))
    } else {
        Err(ControlFlow::Break(err))
    }
}
/// An error tagged with the HTTP status code it was produced under.
///
/// A `status` of `0` means no response status was available; this is what
/// [`HttpError::from_response`] records when given `None`. [`is_retriable`] treats
/// `status >= 500` and `status == 429` as retriable.
///
/// `Display` forwards to the wrapped error, and that same error is also returned from
/// `source()`. Formatters that print the whole cause chain (such as anyhow's `{:#}`)
/// therefore show the wrapped message twice.
#[derive(Debug)]
pub struct HttpError {
    source: Box<dyn StdError + Send + Sync + 'static>,
    pub status: u16,
}
impl HttpError {
    /// Wraps `source`, recording `status` alongside it.
    pub fn new<E>(source: E, status: u16) -> Self
    where
        E: StdError + Send + Sync + 'static,
    {
        let boxed: Box<dyn StdError + Send + Sync + 'static> = Box::new(source);
        HttpError {
            source: boxed,
            status,
        }
    }
    /// Wraps `err`, taking the status from `resp` when present and using `0` otherwise.
    pub fn from_response<E>(err: E, resp: Option<&reqwest::Response>) -> Self
    where
        E: StdError + Send + Sync + 'static,
    {
        let status = match resp {
            Some(response) => response.status().as_u16(),
            None => 0,
        };
        HttpError::new(err, status)
    }
}
impl fmt::Display for HttpError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let wrapped: &(dyn StdError + Send + Sync + 'static) = &*self.source;
        fmt::Display::fmt(wrapped, f)
    }
}
impl StdError for HttpError {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let wrapped: &(dyn StdError + 'static) = &*self.source;
        Some(wrapped)
    }
}
/// Marker wrapper that makes [`is_retriable`] report `true` for the wrapped error,
/// whatever that error is.
///
/// `Display` forwards to the wrapped error, which is also exposed via `source()`.
#[derive(Debug)]
pub struct Retriable(Box<dyn StdError + Send + Sync + 'static>);
impl Retriable {
    /// Wraps `source` so that it is classified as retriable.
    pub fn new<E>(source: E) -> Self
    where
        E: StdError + Send + Sync + 'static,
    {
        let boxed: Box<dyn StdError + Send + Sync + 'static> = Box::new(source);
        Retriable(boxed)
    }
}
impl fmt::Display for Retriable {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Retriable(inner) = self;
        fmt::Display::fmt(&**inner, f)
    }
}
impl StdError for Retriable {
    fn source(&self) -> Option<&(dyn StdError + 'static)> {
        let Retriable(inner) = self;
        Some(&**inner)
    }
}
/// Reports whether `err`, or any error in its `source()` chain, looks like a transient
/// network failure.
///
/// A link in the chain counts as a network failure when any of these holds:
/// - It is an [`io::Error`] of kind `UnexpectedEof`, `TimedOut`, `ConnectionRefused`,
///   `ConnectionReset`, `ConnectionAborted` or `BrokenPipe`.
/// - It is an [`io::Error`] whose message is exactly `eof` or `unexpected eof`
///   (case-insensitive).
/// - Its `Display` text, lowercased, contains any entry of `NETWORK_ERROR_NEEDLES`.
pub fn is_network_error(err: &(dyn StdError + 'static)) -> bool {
    std::iter::successors(Some(err), |&link| link.source()).any(|link| {
        let io_match = link.downcast_ref::<io::Error>().is_some_and(|io_err| {
            let kind_match = matches!(
                io_err.kind(),
                io::ErrorKind::UnexpectedEof
                    | io::ErrorKind::TimedOut
                    | io::ErrorKind::ConnectionRefused
                    | io::ErrorKind::ConnectionReset
                    | io::ErrorKind::ConnectionAborted
                    | io::ErrorKind::BrokenPipe
            );
            kind_match
                || matches!(
                    io_err.to_string().to_lowercase().as_str(),
                    "eof" | "unexpected eof"
                )
        });
        io_match || {
            let text = link.to_string().to_lowercase();
            NETWORK_ERROR_NEEDLES
                .iter()
                .any(|needle| text.contains(needle))
        }
    })
}
/// Substrings that mark an error message as a transient network failure.
///
/// `is_network_error` lowercases each error's `Display` text and checks whether it
/// `contains` any of these, so every entry must be lowercase to ever match.
///
/// NOTE(review): several entries ("i/o timeout", "timeout awaiting response headers",
/// "context deadline exceeded") look like Go net/http error texts, and
/// "an existing connection was forcibly closed" / "no such host is known" look like
/// Windows socket error texts. Confirm which peers/platforms each entry targets.
const NETWORK_ERROR_NEEDLES: &[&str] = &[
"connection reset",
"network is unreachable",
"connection closed",
"connection refused",
"tls handshake timeout",
"i/o timeout",
"broken pipe",
"timeout awaiting response headers",
"context deadline exceeded",
"operation timed out",
"the network connection was aborted",
"an existing connection was forcibly closed",
// DNS resolution failures (the test suite exercises reqwest/hyper-style "dns error: ..." text).
"dns error",
"failed to lookup address",
"no such host is known",
];
/// Reports whether `err` should be retried.
///
/// Walks `err` and its `source()` chain, and returns `true` as soon as a link is either:
/// - a [`Retriable`] wrapper, or
/// - an [`HttpError`] whose status is `429` or `>= 500`.
///
/// If no link matches, it falls back to [`is_network_error`] on `err`.
///
/// Because the walk starts at `err`, callers holding an `anyhow::Error` should pass
/// `err.as_ref()` rather than `err.root_cause()`. The root cause lies past any
/// `HttpError` in the chain, so its status would never be seen.
pub fn is_retriable(err: &(dyn StdError + 'static)) -> bool {
    let flagged = std::iter::successors(Some(err), |&link| link.source()).any(|link| {
        if link.is::<Retriable>() {
            return true;
        }
        match link.downcast_ref::<HttpError>() {
            Some(http) => http.status == 429 || http.status >= 500,
            None => false,
        }
    });
    flagged || is_network_error(err)
}
pub fn is_retriable_opt(err: Option<&(dyn StdError + 'static)>) -> bool {
err.is_some_and(is_retriable)
}
#[cfg(test)]
mod tests {
    // Unit tests for the retry loops and error classifiers, plus integration tests for
    // the HTTP helpers against a local canned-response TCP server.
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// Retry policy with millisecond-scale delays so retry tests finish quickly.
    fn fast_policy() -> RetryPolicy {
        RetryPolicy {
            max_attempts: 4,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
        }
    }

    #[test]
    fn delay_progression_caps_at_max() {
        let p = RetryPolicy {
            max_attempts: 10,
            base_delay: Duration::from_millis(100),
            max_delay: Duration::from_millis(500),
        };
        assert_eq!(p.delay_for(2), Duration::from_millis(100));
        assert_eq!(p.delay_for(3), Duration::from_millis(200));
        assert_eq!(p.delay_for(4), Duration::from_millis(400));
        assert_eq!(p.delay_for(5), Duration::from_millis(500));
        assert_eq!(p.delay_for(8), Duration::from_millis(500));
    }

    #[test]
    fn sync_succeeds_on_first_attempt() {
        let calls = AtomicU32::new(0);
        let result: Result<&str, ()> = retry_sync(&fast_policy(), |_| {
            calls.fetch_add(1, Ordering::SeqCst);
            Ok("ok")
        });
        assert_eq!(result, Ok("ok"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn sync_retries_until_success() {
        let calls = AtomicU32::new(0);
        let result: Result<u32, &str> = retry_sync(&fast_policy(), |attempt| {
            calls.fetch_add(1, Ordering::SeqCst);
            if attempt < 3 {
                Err(ControlFlow::Continue("transient"))
            } else {
                Ok(attempt)
            }
        });
        assert_eq!(result, Ok(3));
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn sync_break_stops_immediately() {
        let calls = AtomicU32::new(0);
        let result: Result<(), &str> = retry_sync(&fast_policy(), |_| {
            calls.fetch_add(1, Ordering::SeqCst);
            Err(ControlFlow::Break("fatal"))
        });
        assert_eq!(result, Err("fatal"));
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn sync_returns_last_error_after_exhaustion() {
        let calls = AtomicU32::new(0);
        let result: Result<(), String> = retry_sync(&fast_policy(), |attempt| {
            calls.fetch_add(1, Ordering::SeqCst);
            Err(ControlFlow::Continue(format!("fail {attempt}")))
        });
        assert_eq!(result, Err("fail 4".to_string()));
        assert_eq!(calls.load(Ordering::SeqCst), 4);
    }

    #[tokio::test]
    async fn async_retries_until_success() {
        let calls = std::sync::Arc::new(AtomicU32::new(0));
        let calls_inner = calls.clone();
        let result: Result<u32, &str> = retry_async(&fast_policy(), move |attempt| {
            let c = calls_inner.clone();
            async move {
                c.fetch_add(1, Ordering::SeqCst);
                if attempt < 2 {
                    Err(ControlFlow::Continue("transient"))
                } else {
                    Ok(attempt)
                }
            }
        })
        .await;
        assert_eq!(result, Ok(2));
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    /// Minimal error with a static message and no source.
    #[derive(Debug)]
    struct StrErr(&'static str);
    impl fmt::Display for StrErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(self.0)
        }
    }
    impl StdError for StrErr {}

    /// Minimal error with an owned message and no source.
    #[derive(Debug)]
    struct OwnedErr(String);
    impl fmt::Display for OwnedErr {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str(&self.0)
        }
    }
    impl StdError for OwnedErr {}

    #[test]
    fn network_error_substrings_match() {
        for s in [
            "connection reset by peer",
            "network is unreachable",
            "connection closed unexpectedly",
            "connection refused",
            "tls handshake timeout",
            "i/o timeout",
            "CONNECTION RESET",
            "TLS Handshake Timeout",
            "write: broken pipe",
            "net/http: timeout awaiting response headers",
            "context deadline exceeded",
            "client error (Connect): dns error: failed to lookup address information: Name or service not known",
            "dns error: nodename nor servname provided, or not known",
            "dns error: No such host is known. (os error 11001)",
        ] {
            let e = OwnedErr(s.to_string());
            assert!(is_network_error(&e), "expected network error: {s:?}");
        }
    }

    #[test]
    fn network_error_io_eof_kinds() {
        let e = io::Error::from(io::ErrorKind::UnexpectedEof);
        assert!(is_network_error(&e));
        let e2 = io::Error::other("EOF");
        assert!(is_network_error(&e2));
    }

    #[test]
    fn is_network_error_classifies_io_timedout() {
        let e = io::Error::from(io::ErrorKind::TimedOut);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    #[test]
    fn is_network_error_classifies_io_connection_refused() {
        let e = io::Error::from(io::ErrorKind::ConnectionRefused);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    #[test]
    fn is_network_error_classifies_io_connection_reset() {
        let e = io::Error::from(io::ErrorKind::ConnectionReset);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    #[test]
    fn is_network_error_classifies_io_connection_aborted() {
        let e = io::Error::from(io::ErrorKind::ConnectionAborted);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    #[test]
    fn is_network_error_classifies_io_broken_pipe() {
        let e = io::Error::from(io::ErrorKind::BrokenPipe);
        assert!(is_network_error(&e));
        assert!(is_retriable(&e));
    }

    #[test]
    fn is_network_error_classifies_operation_timed_out_substring() {
        let other_kind = io::Error::other("operation timed out");
        assert!(is_network_error(&other_kind));
        assert!(is_retriable(&other_kind));
        let kind_only = io::Error::from(io::ErrorKind::TimedOut);
        assert!(is_network_error(&kind_only));
        assert!(is_retriable(&kind_only));
    }

    #[test]
    fn network_error_wrapped_unexpected_eof() {
        #[derive(Debug)]
        struct Wrap(io::Error);
        impl fmt::Display for Wrap {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(f, "read failed")
            }
        }
        impl StdError for Wrap {
            fn source(&self) -> Option<&(dyn StdError + 'static)> {
                Some(&self.0)
            }
        }
        let inner = io::Error::from(io::ErrorKind::UnexpectedEof);
        let outer = Wrap(inner);
        assert!(is_network_error(&outer));
    }

    #[test]
    fn network_error_non_network_strings_reject() {
        for s in [
            "file not found",
            "permission denied",
            "dial tcp: lookup example.com: no such host",
            "",
        ] {
            let e = OwnedErr(s.to_string());
            assert!(!is_network_error(&e), "expected NOT network error: {s:?}");
        }
    }

    #[test]
    fn retriable_opt_nil_passthrough() {
        assert!(!is_retriable_opt(None));
    }

    #[test]
    fn http_error_500_retriable() {
        let e = HttpError::new(StrErr("internal server error"), 500);
        assert!(is_retriable(&e));
    }

    #[test]
    fn http_error_502_503_retriable() {
        for s in [502u16, 503] {
            let e = HttpError::new(StrErr("bad gateway"), s);
            assert!(is_retriable(&e), "status {s} should be retriable");
        }
    }

    #[test]
    fn http_error_429_retriable() {
        let e = HttpError::new(StrErr("rate limited"), 429);
        assert!(is_retriable(&e));
    }

    #[test]
    fn http_error_4xx_not_retriable() {
        for s in [400u16, 401, 403, 404, 422] {
            let e = HttpError::new(StrErr("client err"), s);
            assert!(!is_retriable(&e), "status {s} should NOT be retriable");
        }
    }

    #[test]
    fn http_error_zero_status_routes_via_message() {
        let net = HttpError::new(StrErr("connection reset"), 0);
        assert!(is_retriable(&net));
        let non_net = HttpError::new(StrErr("dial failed"), 0);
        assert!(!is_retriable(&non_net));
    }

    #[test]
    fn http_error_unwrap_chain_visible() {
        let inner = StrErr("inner");
        let e = HttpError::new(inner, 503);
        assert!(e.source().is_some());
    }

    #[test]
    fn from_response_nil_resp_yields_status_zero() {
        let inner = io::Error::other("connect: dial tcp");
        let e = HttpError::from_response(inner, None);
        assert_eq!(e.status, 0);
    }

    #[test]
    fn from_response_unwrap_chain_visible() {
        let inner = io::Error::other("connection reset by peer");
        let e = HttpError::from_response(inner, None);
        assert!(
            e.source().is_some(),
            "inner error must be reachable via source()"
        );
        assert!(is_retriable(&e));
    }

    #[test]
    fn retriable_wrapper_is_retriable() {
        let e = Retriable::new(StrErr("retry me"));
        assert!(is_retriable(&e));
    }

    #[test]
    fn retriable_wrapper_overrides_4xx() {
        let inner = HttpError::new(StrErr("exists"), 422);
        let outer = Retriable::new(inner);
        assert!(is_retriable(&outer));
    }

    #[test]
    fn retriable_wrapper_unwrap_chain_visible() {
        let inner = StrErr("inner");
        let e = Retriable::new(inner);
        assert!(e.source().is_some());
    }

    #[test]
    fn plain_error_not_retriable() {
        let e = StrErr("something");
        assert!(!is_retriable(&e));
    }

    #[test]
    fn anyhow_error_threadable() {
        let e: anyhow::Error = anyhow::anyhow!("connection refused");
        assert!(is_retriable(e.as_ref()));
        let e2: anyhow::Error = anyhow::anyhow!("permission denied");
        assert!(!is_retriable(e2.as_ref()));
    }

    #[test]
    fn is_retriable_chain_walks_to_http_error() {
        let inner = HttpError::new(StrErr("bad gateway"), 503);
        let wrapped: anyhow::Error = anyhow::Error::new(inner).context("publish failed");
        assert!(is_retriable(wrapped.as_ref()));
    }

    #[test]
    fn classifier_5xx_via_anyhow_chain_uses_as_ref() {
        let wrapped: anyhow::Error =
            anyhow::Error::new(HttpError::new(std::io::Error::other("503"), 503))
                .context("publish");
        assert!(
            is_retriable(wrapped.as_ref()),
            "5xx HttpError reached via as_ref() must classify retriable"
        );
    }

    // Drift guard: classifying from `root_cause()` skips the HttpError link entirely,
    // so only the `as_ref()` entry point sees the status code.
    #[test]
    fn classifier_root_cause_walks_past_http_error_drift_guard() {
        let wrapped: anyhow::Error =
            anyhow::Error::new(HttpError::new(std::io::Error::other("503"), 503))
                .context("publish");
        assert!(
            !is_retriable(wrapped.root_cause()),
            "root_cause() walks past HttpError; 5xx must NOT be detected via the leaf"
        );
    }

    #[test]
    fn classifier_429_via_anyhow_chain_uses_as_ref() {
        let wrapped: anyhow::Error =
            anyhow::Error::new(HttpError::new(std::io::Error::other("429"), 429))
                .context("publish");
        assert!(is_retriable(wrapped.as_ref()));
        assert!(!is_retriable(wrapped.root_cause()));
    }

    /// Starts a background thread that serves `responses` in order, one canned raw HTTP
    /// response per accepted connection, then stops accepting.
    ///
    /// Returns the listening address and a counter of accepted connections.
    fn spawn_oneshot_http_responder(
        responses: Vec<&'static str>,
    ) -> (std::net::SocketAddr, std::sync::Arc<AtomicU32>) {
        use std::io::{Read, Write};
        use std::net::TcpListener;
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
        let addr = listener.local_addr().expect("local_addr");
        let counter = std::sync::Arc::new(AtomicU32::new(0));
        let counter_inner = counter.clone();
        std::thread::spawn(move || {
            for resp in responses {
                let (mut stream, _) = match listener.accept() {
                    Ok(pair) => pair,
                    Err(_) => return,
                };
                counter_inner.fetch_add(1, Ordering::SeqCst);
                // Best-effort drain of the request; its content is not inspected.
                let mut buf = [0u8; 8192];
                let _ = stream.set_read_timeout(Some(Duration::from_millis(500)));
                let _ = stream.read(&mut buf);
                let _ = stream.write_all(resp.as_bytes());
                let _ = stream.flush();
                let _ = stream.shutdown(std::net::Shutdown::Both);
            }
        });
        (addr, counter)
    }

    #[test]
    fn retry_http_blocking_success_returns_first_attempt() {
        let (addr, calls) =
            spawn_oneshot_http_responder(vec!["HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"]);
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |_, _| String::from("should not be called on success"),
        );
        let (status, body) = result.expect("success");
        assert_eq!(status.as_u16(), 200);
        assert_eq!(body, "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1, "single attempt");
    }

    #[test]
    fn retry_http_blocking_retries_5xx_then_succeeds() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 503 Service Unavailable\r\nContent-Length: 0\r\n\r\n",
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
        ]);
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("{status}: {body}"),
        );
        let (status, _) = result.expect("eventually succeeds");
        assert_eq!(status.as_u16(), 200);
        assert_eq!(calls.load(Ordering::SeqCst), 2, "one retry then success");
    }

    #[test]
    fn retry_http_blocking_4xx_fast_fails_no_retry() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 404 Not Found\r\nContent-Length: 9\r\n\r\nnot found",
        ]);
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 5,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "myscope",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("custom error: {status} body={body}"),
        );
        let err = result.expect_err("4xx must fast-fail");
        let chain = format!("{err:#}");
        assert!(
            chain.contains("custom error"),
            "error formatter must be invoked on non-success; got: {chain}"
        );
        assert!(chain.contains("404"), "status must be in chain: {chain}");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "4xx must NOT retry (only one connection accepted)"
        );
    }

    #[test]
    fn retry_http_blocking_redirect_class_alters_success_predicate() {
        let (addr, _calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 307 Temporary Redirect\r\nLocation: /next\r\nContent-Length: 0\r\n\r\n",
        ]);
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_secs(2))
            .redirect(reqwest::redirect::Policy::none())
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "test",
            &policy,
            SuccessClass::AllowRedirects,
            |_| client.get(format!("http://{addr}/")).send(),
            |_, _| String::from("should not be called on 3xx with AllowRedirects"),
        );
        let (status, _) = result.expect("3xx is success under AllowRedirects");
        assert_eq!(status.as_u16(), 307);
    }

    #[tokio::test]
    async fn retry_http_async_success_returns_first_attempt() {
        let (addr, calls) =
            spawn_oneshot_http_responder(vec!["HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok"]);
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |_, _| String::from("should not be called on success"),
        )
        .await;
        let resp = result.expect("success");
        assert_eq!(resp.status().as_u16(), 200);
        let body = resp.text().await.expect("body");
        assert_eq!(body, "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1, "single attempt");
    }

    #[tokio::test]
    async fn retry_http_async_retries_5xx_then_succeeds() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 503 Service Unavailable\r\nContent-Length: 0\r\n\r\n",
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
        ]);
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("{status}: {body}"),
        )
        .await;
        let resp = result.expect("eventually succeeds");
        assert_eq!(resp.status().as_u16(), 200);
        assert_eq!(calls.load(Ordering::SeqCst), 2, "one retry then success");
    }

    #[tokio::test]
    async fn retry_http_async_4xx_fast_fails_no_retry() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 404 Not Found\r\nContent-Length: 9\r\n\r\nnot found",
        ]);
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 5,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "myscope",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("custom error: {status} body={body}"),
        )
        .await;
        let err = result.expect_err("4xx must fast-fail");
        let chain = format!("{err:#}");
        assert!(
            chain.contains("custom error"),
            "error formatter must be invoked on non-success; got: {chain}"
        );
        assert!(chain.contains("404"), "status must be in chain: {chain}");
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "4xx must NOT retry (only one connection accepted)"
        );
    }

    #[tokio::test]
    async fn retry_http_async_429_retries_then_succeeds() {
        let (addr, calls) = spawn_oneshot_http_responder(vec![
            "HTTP/1.1 429 Too Many Requests\r\nContent-Length: 0\r\n\r\n",
            "HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nok",
        ]);
        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(2))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "test",
            &policy,
            SuccessClass::Strict,
            |_| client.get(format!("http://{addr}/")).send(),
            |status, body| format!("{status}: {body}"),
        )
        .await;
        let resp = result.expect("429 retried then success");
        assert_eq!(resp.status().as_u16(), 200);
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    // The `.invalid` TLD is reserved by RFC 2606 and never resolves, so every request to
    // this URL fails at DNS resolution (a transport error).
    const TRANSPORT_FAIL_URL: &str = "http://nonexistent.invalid/";

    #[test]
    fn retry_http_blocking_transport_error_retries_then_fails() {
        let attempts = std::sync::Arc::new(AtomicU32::new(0));
        let attempts_inner = attempts.clone();
        let client = reqwest::blocking::Client::builder()
            .timeout(Duration::from_millis(500))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_blocking(
            "test-transport",
            &policy,
            SuccessClass::Strict,
            |_| {
                attempts_inner.fetch_add(1, Ordering::SeqCst);
                client.get(TRANSPORT_FAIL_URL).send()
            },
            |_, _| String::from("non-success branch should not be reached"),
        );
        let err = result.expect_err("transport error must surface as Err");
        let chain = format!("{err:#}");
        assert!(
            attempts.load(Ordering::SeqCst) > 1,
            "transport error must be retried; got {} attempts; chain={chain}",
            attempts.load(Ordering::SeqCst)
        );
        assert!(
            chain.contains("test-transport"),
            "label must surface in error chain; got: {chain}"
        );
    }

    #[tokio::test]
    async fn retry_http_async_transport_error_retries_then_fails() {
        let attempts = std::sync::Arc::new(AtomicU32::new(0));
        let attempts_inner = attempts.clone();
        let client = reqwest::Client::builder()
            .timeout(Duration::from_millis(500))
            .build()
            .expect("client");
        let policy = RetryPolicy {
            max_attempts: 3,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(2),
        };
        let result = retry_http_async(
            "test-transport-async",
            &policy,
            SuccessClass::Strict,
            |_| {
                attempts_inner.fetch_add(1, Ordering::SeqCst);
                client.get(TRANSPORT_FAIL_URL).send()
            },
            |_, _| String::from("non-success branch should not be reached"),
        )
        .await;
        let err = result.expect_err("transport error must surface as Err");
        assert!(
            attempts.load(Ordering::SeqCst) > 1,
            "transport error must be retried; got {} attempts",
            attempts.load(Ordering::SeqCst)
        );
        let chain = format!("{err:#}");
        assert!(
            chain.contains("test-transport-async"),
            "label must surface in error chain; got: {chain}"
        );
    }
}