use std::error::Error;
use std::time::{Duration, SystemTime, SystemTimeError};
use std::{io, iter};
use http::status::StatusCode;
use itertools::Itertools;
use reqwest::Response;
use reqwest_retry::policies::ExponentialBackoff;
use reqwest_retry::{
RetryPolicy, Retryable, RetryableStrategy, default_on_request_error, default_on_request_success,
};
use tracing::{debug, trace};
use url::Url;
use uv_redacted::DisplaySafeUrl;
use crate::WrappedReqwestError;
pub struct UvRetryableStrategy;
impl RetryableStrategy for UvRetryableStrategy {
fn handle(&self, res: &Result<Response, reqwest_middleware::Error>) -> Option<Retryable> {
let retryable = match res {
Ok(success) => default_on_request_success(success),
Err(err) => retryable_on_request_failure(err),
};
if retryable == Some(Retryable::Transient) {
match res {
Ok(response) => {
debug!("Transient request failure for: {}", response.url());
}
Err(err) => {
let context = iter::successors(err.source(), |&err| err.source())
.map(|err| format!(" Caused by: {err}"))
.join("\n");
debug!(
"Transient request failure for {}, retrying: {err}\n{context}",
err.url().map(Url::as_str).unwrap_or("unknown URL")
);
}
}
}
retryable
}
}
pub struct RetryState {
retry_policy: ExponentialBackoff,
start_time: SystemTime,
total_retries: u32,
url: DisplaySafeUrl,
}
impl RetryState {
pub fn start(retry_policy: ExponentialBackoff, url: impl Into<DisplaySafeUrl>) -> Self {
Self {
retry_policy,
start_time: SystemTime::now(),
total_retries: 0,
url: url.into(),
}
}
pub fn total_retries(&self) -> u32 {
self.total_retries
}
pub fn duration(&self) -> Result<Duration, SystemTimeError> {
self.start_time.elapsed()
}
#[must_use]
pub fn should_retry(
&mut self,
err: &(dyn Error + 'static),
error_retries: u32,
) -> Option<Duration> {
self.total_retries += error_retries;
match retryable_on_request_failure(err) {
Some(Retryable::Transient) => {
let now = SystemTime::now();
let retry_decision = self
.retry_policy
.should_retry(self.start_time, self.total_retries);
if let reqwest_retry::RetryDecision::Retry { execute_after } = retry_decision {
let duration = execute_after
.duration_since(now)
.unwrap_or_else(|_| Duration::default());
self.total_retries += 1;
return Some(duration);
}
None
}
Some(Retryable::Fatal) | None => None,
}
}
pub async fn sleep_backoff(&self, duration: Duration) {
debug!(
"Transient failure while handling response from {}; retrying after {:.1}s...",
self.url,
duration.as_secs_f32(),
);
tokio::time::sleep(duration).await;
}
}
pub fn retryable_on_request_failure(err: &(dyn Error + 'static)) -> Option<Retryable> {
if let Some((Some(status), Some(url))) = find_source::<WrappedReqwestError>(&err)
.map(|request_err| (request_err.status(), request_err.url()))
{
trace!(
"Considering retry of response HTTP {status} for {url}",
url = DisplaySafeUrl::from_url(url.clone())
);
} else {
trace!("Considering retry of error: {err:?}");
}
let mut has_known_error = false;
let mut current_source = Some(err);
while let Some(source) = current_source {
let reqwest_err = if let Some(reqwest_err) = source.downcast_ref::<reqwest::Error>() {
Some(reqwest_err)
} else if let Some(reqwest_err) = source
.downcast_ref::<WrappedReqwestError>()
.and_then(|err| err.inner())
{
Some(reqwest_err)
} else if let Some(reqwest_middleware::Error::Reqwest(reqwest_err)) =
source.downcast_ref::<reqwest_middleware::Error>()
{
Some(reqwest_err)
} else {
None
};
if let Some(reqwest_err) = reqwest_err {
has_known_error = true;
if default_on_request_error(reqwest_err) == Some(Retryable::Transient) {
trace!("Transient nested reqwest error");
return Some(Retryable::Transient);
}
if is_retryable_status_error(reqwest_err) {
trace!("Transient nested reqwest status code error");
return Some(Retryable::Transient);
}
trace!("Fatal nested reqwest error");
} else if source.downcast_ref::<h2::Error>().is_some() {
trace!("Transient nested h2 error");
return Some(Retryable::Transient);
} else if let Some(io_err) = source.downcast_ref::<io::Error>() {
has_known_error = true;
let retryable_io_err_kinds = [
io::ErrorKind::BrokenPipe,
io::ErrorKind::ConnectionAborted,
io::ErrorKind::ConnectionReset,
io::ErrorKind::InvalidData,
io::ErrorKind::TimedOut,
io::ErrorKind::UnexpectedEof,
];
if retryable_io_err_kinds.contains(&io_err.kind()) {
trace!("Transient IO error: `{}`", io_err.kind());
return Some(Retryable::Transient);
}
trace!(
"Fatal IO error `{}`, not a transient IO error kind",
io_err.kind()
);
}
current_source = source.source();
}
if !has_known_error {
trace!("Cannot retry error: neither an IO error nor a reqwest error");
}
None
}
pub trait RetriableError: std::error::Error + Sized + 'static {
fn should_try_next_url(&self) -> bool;
fn retries(&self) -> u32;
#[must_use]
fn into_retried(self, retries: u32, duration: Duration) -> Self;
}
fn is_retryable_status_error(reqwest_err: &reqwest::Error) -> bool {
let Some(status) = reqwest_err.status() else {
return false;
};
status.is_server_error()
|| status == StatusCode::REQUEST_TIMEOUT
|| status == StatusCode::TOO_MANY_REQUESTS
}
fn find_source<E: Error + 'static>(orig: &dyn Error) -> Option<&E> {
let mut cause = orig.source();
while let Some(err) = cause {
if let Some(typed) = err.downcast_ref() {
return Some(typed);
}
cause = err.source();
}
None
}
#[cfg(test)]
mod tests {
use super::*;
use anyhow::Result;
use insta::assert_debug_snapshot;
use reqwest::Client;
use reqwest_middleware::ClientWithMiddleware;
use wiremock::matchers::path;
use wiremock::{Mock, MockServer, ResponseTemplate};
use crate::{UvRetryableStrategy, retryable_on_request_failure};
#[tokio::test]
async fn retried_status_codes() -> Result<()> {
let server = MockServer::start().await;
let client = Client::default();
let middleware_client = ClientWithMiddleware::default();
let mut retried = Vec::new();
for status in 100..599 {
if StatusCode::from_u16(status)?.canonical_reason().is_none() && status != 420 {
continue;
}
Mock::given(path(format!("/{status}")))
.respond_with(ResponseTemplate::new(status))
.mount(&server)
.await;
let response = middleware_client
.get(format!("{}/{}", server.uri(), status))
.send()
.await;
let middleware_retry =
UvRetryableStrategy.handle(&response) == Some(Retryable::Transient);
let response = client
.get(format!("{}/{}", server.uri(), status))
.send()
.await?;
let uv_retry = match response.error_for_status() {
Ok(_) => false,
Err(err) => retryable_on_request_failure(&err) == Some(Retryable::Transient),
};
assert_eq!(middleware_retry, uv_retry);
if uv_retry {
retried.push(status);
}
}
assert_debug_snapshot!(retried, @"
[
100,
102,
103,
408,
429,
500,
501,
502,
503,
504,
505,
506,
507,
508,
510,
511,
]
");
Ok(())
}
}