use crate::client::backoff::{Backoff, BackoffConfig};
use crate::PutPayload;
use futures::future::BoxFuture;
use reqwest::header::LOCATION;
use reqwest::{Client, Request, Response, StatusCode};
use snafu::Error as SnafuError;
use snafu::Snafu;
use std::time::{Duration, Instant};
use tracing::info;
/// Errors produced while sending a request with retries.
#[derive(Debug, Snafu)]
pub enum Error {
    /// A redirect response was received that carried no `Location` header,
    /// so it could not be followed.
    #[snafu(display("Received redirect without LOCATION, this normally indicates an incorrectly configured region"))]
    BareRedirect,

    /// The server answered with a non-success status that is not retried.
    #[snafu(display("Client error with status {status}: {}", body.as_deref().unwrap_or("No Body")))]
    Client {
        /// The HTTP status of the response.
        status: StatusCode,
        /// The response body, when it was read and was non-empty.
        body: Option<String>,
    },

    /// A `reqwest` error that was either not retryable or persisted after
    /// the retry budget (count or time) was exhausted.
    #[snafu(display("Error after {retries} retries in {elapsed:?}, max_retries:{max_retries}, retry_timeout:{retry_timeout:?}, source:{source}"))]
    Reqwest {
        /// Number of retries performed before giving up.
        retries: usize,
        /// The configured maximum number of retries.
        max_retries: usize,
        /// Time elapsed since the first attempt when the error was produced.
        elapsed: Duration,
        /// The configured retry timeout.
        retry_timeout: Duration,
        /// The underlying `reqwest` error.
        source: reqwest::Error,
    },
}
impl Error {
pub fn status(&self) -> Option<StatusCode> {
match self {
Self::BareRedirect => None,
Self::Client { status, .. } => Some(*status),
Self::Reqwest { source, .. } => source.status(),
}
}
pub fn body(&self) -> Option<&str> {
match self {
Self::Client { body, .. } => body.as_deref(),
Self::BareRedirect => None,
Self::Reqwest { .. } => None,
}
}
pub fn error(self, store: &'static str, path: String) -> crate::Error {
match self.status() {
Some(StatusCode::NOT_FOUND) => crate::Error::NotFound {
path,
source: Box::new(self),
},
Some(StatusCode::NOT_MODIFIED) => crate::Error::NotModified {
path,
source: Box::new(self),
},
Some(StatusCode::PRECONDITION_FAILED) => crate::Error::Precondition {
path,
source: Box::new(self),
},
Some(StatusCode::CONFLICT) => crate::Error::AlreadyExists {
path,
source: Box::new(self),
},
_ => crate::Error::Generic {
store,
source: Box::new(self),
},
}
}
}
impl From<Error> for std::io::Error {
fn from(err: Error) -> Self {
use std::io::ErrorKind;
match &err {
Error::Client {
status: StatusCode::NOT_FOUND,
..
} => Self::new(ErrorKind::NotFound, err),
Error::Client {
status: StatusCode::BAD_REQUEST,
..
} => Self::new(ErrorKind::InvalidInput, err),
Error::Reqwest { source, .. } if source.is_timeout() => {
Self::new(ErrorKind::TimedOut, err)
}
Error::Reqwest { source, .. } if source.is_connect() => {
Self::new(ErrorKind::NotConnected, err)
}
_ => Self::new(ErrorKind::Other, err),
}
}
}
/// Result type for retry operations, defaulting the error to this module's [`Error`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
#[derive(Debug, Clone)]
pub struct RetryConfig {
pub backoff: BackoffConfig,
pub max_retries: usize,
pub retry_timeout: Duration,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
backoff: Default::default(),
max_retries: 10,
retry_timeout: Duration::from_secs(3 * 60),
}
}
}
/// A request bundled with the client and retry policy used to send it.
///
/// Created by [`RetryExt::retryable`]; adjust it with the builder-style
/// methods before sending it.
pub struct RetryableRequest {
    /// The client that executes each attempt.
    client: Client,
    /// The request template; a fresh clone of it is executed on every attempt.
    request: Request,
    /// Maximum number of retries before giving up.
    max_retries: usize,
    /// No further retries are made once this much time has elapsed since the
    /// first attempt.
    retry_timeout: Duration,
    /// Yields the delay to sleep before each retry.
    backoff: Backoff,
    /// When true, URLs are stripped from `reqwest` errors before they are
    /// logged or returned.
    sensitive: bool,
    /// Explicit idempotency override; `None` means "infer from the HTTP method".
    idempotent: Option<bool>,
    /// When set, replaces the request body on every attempt.
    payload: Option<PutPayload>,
}
impl RetryableRequest {
pub fn idempotent(self, idempotent: bool) -> Self {
Self {
idempotent: Some(idempotent),
..self
}
}
#[allow(unused)]
pub fn sensitive(self, sensitive: bool) -> Self {
Self { sensitive, ..self }
}
pub fn payload(self, payload: Option<PutPayload>) -> Self {
Self { payload, ..self }
}
pub async fn send(self) -> Result<Response> {
let max_retries = self.max_retries;
let retry_timeout = self.retry_timeout;
let mut retries = 0;
let now = Instant::now();
let mut backoff = self.backoff;
let is_idempotent = self
.idempotent
.unwrap_or_else(|| self.request.method().is_safe());
let sanitize_err = move |e: reqwest::Error| match self.sensitive {
true => e.without_url(),
false => e,
};
loop {
let mut request = self
.request
.try_clone()
.expect("request body must be cloneable");
if let Some(payload) = &self.payload {
*request.body_mut() = Some(payload.body());
}
match self.client.execute(request).await {
Ok(r) => match r.error_for_status_ref() {
Ok(_) if r.status().is_success() => return Ok(r),
Ok(r) if r.status() == StatusCode::NOT_MODIFIED => {
return Err(Error::Client {
body: None,
status: StatusCode::NOT_MODIFIED,
})
}
Ok(r) => {
let is_bare_redirect =
r.status().is_redirection() && !r.headers().contains_key(LOCATION);
return match is_bare_redirect {
true => Err(Error::BareRedirect),
false => Err(Error::Client {
body: None,
status: r.status(),
}),
};
}
Err(e) => {
let e = sanitize_err(e);
let status = r.status();
if retries == max_retries
|| now.elapsed() > retry_timeout
|| !status.is_server_error()
{
return Err(match status.is_client_error() {
true => match r.text().await {
Ok(body) => Error::Client {
body: Some(body).filter(|b| !b.is_empty()),
status,
},
Err(e) => Error::Reqwest {
retries,
max_retries,
elapsed: now.elapsed(),
retry_timeout,
source: e,
},
},
false => Error::Reqwest {
retries,
max_retries,
elapsed: now.elapsed(),
retry_timeout,
source: e,
},
});
}
let sleep = backoff.next();
retries += 1;
info!(
"Encountered server error, backing off for {} seconds, retry {} of {}: {}",
sleep.as_secs_f32(),
retries,
max_retries,
e,
);
tokio::time::sleep(sleep).await;
}
},
Err(e) => {
let e = sanitize_err(e);
let mut do_retry = false;
if e.is_connect()
|| e.is_body()
|| (e.is_request() && !e.is_timeout())
|| (is_idempotent && e.is_timeout())
{
do_retry = true
} else {
let mut source = e.source();
while let Some(e) = source {
if let Some(e) = e.downcast_ref::<hyper::Error>() {
do_retry = e.is_closed()
|| e.is_incomplete_message()
|| e.is_body_write_aborted()
|| (is_idempotent && e.is_timeout());
break;
}
if let Some(e) = e.downcast_ref::<std::io::Error>() {
if e.kind() == std::io::ErrorKind::TimedOut {
do_retry = is_idempotent;
} else {
do_retry = matches!(
e.kind(),
std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted
| std::io::ErrorKind::BrokenPipe
| std::io::ErrorKind::UnexpectedEof
);
}
break;
}
source = e.source();
}
}
if retries == max_retries || now.elapsed() > retry_timeout || !do_retry {
return Err(Error::Reqwest {
retries,
max_retries,
elapsed: now.elapsed(),
retry_timeout,
source: e,
});
}
let sleep = backoff.next();
retries += 1;
info!(
"Encountered transport error backing off for {} seconds, retry {} of {}: {}",
sleep.as_secs_f32(),
retries,
max_retries,
e,
);
tokio::time::sleep(sleep).await;
}
}
}
}
}
/// Extension trait adding retry support to request builders.
pub trait RetryExt {
    /// Builds the request into a [`RetryableRequest`] configured from `config`,
    /// which can be customised further before sending.
    fn retryable(self, config: &RetryConfig) -> RetryableRequest;

    /// Builds the request and sends it with retries configured from `config`.
    fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result<Response>>;
}
impl RetryExt for reqwest::RequestBuilder {
    /// Splits the builder into its client and request and wraps them in a
    /// [`RetryableRequest`] that uses the retry limits and backoff from `config`.
    ///
    /// Idempotency is left to be inferred from the HTTP method, the request is
    /// not marked sensitive, and no payload override is set.
    ///
    /// # Panics
    ///
    /// Panics if the builder does not produce a valid request.
    fn retryable(self, config: &RetryConfig) -> RetryableRequest {
        let (client, built) = self.build_split();
        RetryableRequest {
            client,
            request: built.expect("request must be valid"),
            max_retries: config.max_retries,
            retry_timeout: config.retry_timeout,
            backoff: Backoff::new(&config.backoff),
            idempotent: None,
            payload: None,
            sensitive: false,
        }
    }

    /// Shorthand for `self.retryable(config).send()`, returned as a boxed
    /// `'static` future.
    fn send_retry(self, config: &RetryConfig) -> BoxFuture<'static, Result<Response>> {
        Box::pin(self.retryable(config).send())
    }
}
#[cfg(test)]
mod tests {
    use crate::client::mock_server::MockServer;
    use crate::client::retry::{Error, RetryExt};
    use crate::RetryConfig;
    use hyper::header::LOCATION;
    use hyper::Response;
    use reqwest::{Client, Method, StatusCode};
    use std::time::Duration;

    // End-to-end exercise of the retry logic against a local `MockServer`.
    //
    // NOTE(review): the scenarios appear to rely on `MockServer` serving pushed
    // responses in FIFO order and answering `200 OK` once its queue is empty
    // (requests with nothing queued, and retries after a queued failure,
    // succeed). Confirm against `mock_server`. The scenarios share one mock, so
    // their order matters.
    #[tokio::test]
    async fn test_retry() {
        let mock = MockServer::new().await;

        // Two retries, with a retry timeout long enough never to trigger here.
        let retry = RetryConfig {
            backoff: Default::default(),
            max_retries: 2,
            retry_timeout: Duration::from_secs(1000),
        };

        // Short per-request timeout so the slow handlers further down time out.
        let client = Client::builder()
            .timeout(Duration::from_millis(100))
            .build()
            .unwrap();

        // GET is a safe method, so these requests are treated as idempotent.
        let do_request = || client.request(Method::GET, mock.url()).send_retry(&retry);

        // Nothing queued: plain success.
        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);

        // A 400 is not retried, and its body is surfaced in the error.
        mock.push(
            Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .body("cupcakes".to_string())
                .unwrap(),
        );

        let e = do_request().await.unwrap_err();
        assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST);
        assert_eq!(e.body(), Some("cupcakes"));
        assert_eq!(
            e.to_string(),
            "Client error with status 400 Bad Request: cupcakes"
        );

        // A 400 with an empty body reports no body.
        mock.push(
            Response::builder()
                .status(StatusCode::BAD_REQUEST)
                .body(String::new())
                .unwrap(),
        );

        let e = do_request().await.unwrap_err();
        assert_eq!(e.status().unwrap(), StatusCode::BAD_REQUEST);
        assert_eq!(e.body(), None);
        assert_eq!(
            e.to_string(),
            "Client error with status 400 Bad Request: No Body"
        );

        // A single 502 is retried and the retry succeeds.
        mock.push(
            Response::builder()
                .status(StatusCode::BAD_GATEWAY)
                .body(String::new())
                .unwrap(),
        );

        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);

        // 204 counts as success and is returned unchanged.
        mock.push(
            Response::builder()
                .status(StatusCode::NO_CONTENT)
                .body(String::new())
                .unwrap(),
        );

        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::NO_CONTENT);

        // A redirect with a Location header is followed.
        mock.push(
            Response::builder()
                .status(StatusCode::FOUND)
                .header(LOCATION, "/foo")
                .body(String::new())
                .unwrap(),
        );

        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.url().path(), "/foo");

        // Redirect again to a different path.
        mock.push(
            Response::builder()
                .status(StatusCode::FOUND)
                .header(LOCATION, "/bar")
                .body(String::new())
                .unwrap(),
        );

        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);
        assert_eq!(r.url().path(), "/bar");

        // A long redirect chain fails; presumably this exceeds the client's
        // default redirect limit.
        for _ in 0..10 {
            mock.push(
                Response::builder()
                    .status(StatusCode::FOUND)
                    .header(LOCATION, "/bar")
                    .body(String::new())
                    .unwrap(),
            );
        }

        let e = do_request().await.unwrap_err().to_string();
        assert!(e.contains("error following redirect for url"), "{}", e);

        // A redirect without a Location header cannot be followed.
        mock.push(
            Response::builder()
                .status(StatusCode::FOUND)
                .body(String::new())
                .unwrap(),
        );

        let e = do_request().await.unwrap_err();
        assert!(matches!(e, Error::BareRedirect));
        assert_eq!(e.to_string(), "Received redirect without LOCATION, this normally indicates an incorrectly configured region");

        // Persistent 502s exhaust the retry budget (1 attempt + max_retries).
        for _ in 0..=retry.max_retries {
            mock.push(
                Response::builder()
                    .status(StatusCode::BAD_GATEWAY)
                    .body("ignored".to_string())
                    .unwrap(),
            );
        }

        let e = do_request().await.unwrap_err().to_string();
        assert!(
            e.contains("Error after 2 retries in") &&
            e.contains("max_retries:2, retry_timeout:1000s, source:HTTP status server error (502 Bad Gateway) for url"),
            "{e}"
        );

        // A panicking handler yields a transport error that is retried;
        // presumably the connection is dropped without a response.
        mock.push_fn(|_| panic!());
        let r = do_request().await.unwrap();
        assert_eq!(r.status(), StatusCode::OK);

        // Persistent transport errors exhaust the retry budget.
        for _ in 0..=retry.max_retries {
            mock.push_fn(|_| panic!());
        }
        let e = do_request().await.unwrap_err().to_string();
        assert!(
            e.contains("Error after 2 retries in")
                && e.contains(
                    "max_retries:2, retry_timeout:1000s, source:error sending request for url"
                ),
            "{e}"
        );

        // A handler slower than the 100ms client timeout: GET is idempotent,
        // so the timeout is retried and the retry succeeds.
        mock.push_async_fn(|_| async move {
            tokio::time::sleep(Duration::from_secs(10)).await;
            panic!()
        });
        do_request().await.unwrap();

        // PUT is not safe, so a timeout is not retried.
        mock.push_async_fn(|_| async move {
            tokio::time::sleep(Duration::from_secs(10)).await;
            panic!()
        });
        let res = client.request(Method::PUT, mock.url()).send_retry(&retry);
        let e = res.await.unwrap_err().to_string();
        assert!(
            e.contains("Error after 0 retries in") && e.contains("error sending request for url"),
            "{e}"
        );

        // Without `sensitive`, the URL appears in the error.
        let url = format!("{}/SENSITIVE", mock.url());
        for _ in 0..=retry.max_retries {
            mock.push(
                Response::builder()
                    .status(StatusCode::BAD_GATEWAY)
                    .body("ignored".to_string())
                    .unwrap(),
            );
        }
        let res = client.request(Method::GET, url).send_retry(&retry).await;
        let err = res.unwrap_err().to_string();
        assert!(err.contains("SENSITIVE"), "{err}");

        // With `sensitive`, the URL is stripped from HTTP status errors.
        let url = format!("{}/SENSITIVE", mock.url());
        for _ in 0..=retry.max_retries {
            mock.push(
                Response::builder()
                    .status(StatusCode::BAD_GATEWAY)
                    .body("ignored".to_string())
                    .unwrap(),
            );
        }

        let req = client
            .request(Method::GET, &url)
            .retryable(&retry)
            .sensitive(true);
        let err = req.send().await.unwrap_err().to_string();
        assert!(!err.contains("SENSITIVE"), "{err}");

        // ...and from transport errors.
        for _ in 0..=retry.max_retries {
            mock.push_fn(|_| panic!());
        }

        let req = client
            .request(Method::GET, &url)
            .retryable(&retry)
            .sensitive(true);
        let err = req.send().await.unwrap_err().to_string();
        assert!(!err.contains("SENSITIVE"), "{err}");

        // Shut down mock server
        mock.shutdown().await
    }
}