use crate::http::HttpErrorResponse;
use backoff::{backoff::Backoff, future::retry_notify, ExponentialBackoff, Notify};
use bytes::Bytes;
use futures::Future;
use http::HeaderMap;
use reqwest::StatusCode;
use std::{error::Error as StdError, time::Duration};
use tracing::{debug, warn};
/// Walks the `source()` chain of `original_error` and returns the first cause that is a
/// [`std::io::Error`], or `None` if no cause in the chain has that type.
///
/// Only the causes are inspected; `original_error` itself is never downcast.
fn find_io_error(original_error: &reqwest::Error) -> Option<&std::io::Error> {
    std::iter::successors(original_error.source(), |&cause| cause.source())
        .find_map(|cause| cause.downcast_ref::<std::io::Error>())
}
/// Returns the default [`ExponentialBackoff`] policy used for retrying HTTP requests.
///
/// The policy is configured as follows:
/// - Delays start at 1 second and are multiplied by 2 after each retry.
/// - Each delay is capped at 30 seconds.
/// - Retrying stops once 10 minutes have elapsed.
///
/// All other parameters, such as the randomization factor, keep the `backoff` crate's defaults.
pub fn http_request_exponential_backoff() -> ExponentialBackoff {
    let initial_interval = Duration::from_secs(1);
    ExponentialBackoff {
        // `Default::default()` seeds `current_interval` from the crate's own default initial
        // interval, not the value chosen here. Keep it in sync so the first delay is correct
        // even if the policy is consumed before `reset()` is called.
        current_interval: initial_interval,
        initial_interval,
        max_interval: Duration::from_secs(30),
        multiplier: 2.0,
        max_elapsed_time: Some(Duration::from_secs(600)),
        ..Default::default()
    }
}
/// An HTTP response whose headers and complete body have been read into memory.
///
/// In this module it is built by `retry_http_request_notify`, only for responses whose status
/// is neither a client (4xx) nor a server (5xx) error.
#[derive(Clone, Debug)]
pub struct HttpResponse {
    /// Status code returned by the server.
    status: StatusCode,
    /// Response headers.
    headers: HeaderMap,
    /// The fully buffered response body.
    body: Bytes,
}

impl HttpResponse {
    /// Returns the response's status code.
    pub fn status(&self) -> StatusCode {
        let Self { status, .. } = self;
        *status
    }

    /// Returns a reference to the response headers.
    pub fn headers(&self) -> &HeaderMap {
        let Self { headers, .. } = self;
        headers
    }

    /// Returns a reference to the buffered response body.
    pub fn body(&self) -> &Bytes {
        let Self { body, .. } = self;
        body
    }
}
/// A [`Notify`] implementation that silently discards every retry notification.
pub struct NoopNotify;

impl<E> Notify<E> for NoopNotify {
    fn notify(&mut self, _error: E, _next_delay: Duration) {
        // Intentionally empty: used by callers that don't observe individual retries.
    }
}
/// Sends the HTTP request produced by `request_fn`, retrying according to `backoff`.
///
/// This is [`retry_http_request_notify`] with a notifier that ignores retry events. See that
/// function for which failures are retried and for the shape of the error value.
// The nested `Result` error type is large; it is kept by value to match the public API.
#[allow(clippy::result_large_err)]
pub async fn retry_http_request<ResultFuture>(
    backoff: impl Backoff,
    request_fn: impl Fn() -> ResultFuture,
) -> Result<HttpResponse, Result<HttpErrorResponse, reqwest::Error>>
where
    ResultFuture: Future<Output = Result<reqwest::Response, reqwest::Error>>,
{
    let silent_notifier = NoopNotify;
    retry_http_request_notify(backoff, silent_notifier, request_fn).await
}
#[allow(clippy::result_large_err)]
pub async fn retry_http_request_notify<ResultFuture>(
backoff: impl Backoff,
notify: impl Notify<Result<HttpErrorResponse, reqwest::Error>>,
request_fn: impl Fn() -> ResultFuture,
) -> Result<HttpResponse, Result<HttpErrorResponse, reqwest::Error>>
where
ResultFuture: Future<Output = Result<reqwest::Response, reqwest::Error>>,
{
fn check_reqwest_result<T>(
rslt: Result<T, reqwest::Error>,
) -> Result<T, backoff::Error<Result<HttpErrorResponse, reqwest::Error>>> {
rslt.map_err(|err| {
if err.is_timeout() || err.is_connect() {
warn!(?err, "Encountered retryable network error");
return backoff::Error::transient(Err(err));
}
if let Some(io_error) = find_io_error(&err) {
if let std::io::ErrorKind::ConnectionRefused
| std::io::ErrorKind::ConnectionReset
| std::io::ErrorKind::ConnectionAborted = io_error.kind()
{
warn!(?err, "Encountered retryable network error");
return backoff::Error::transient(Err(err));
}
}
debug!("Encountered non-retryable network error");
backoff::Error::permanent(Err(err))
})
}
retry_notify(
backoff,
|| async {
let response = check_reqwest_result(request_fn().await)?;
let status = response.status();
if status.is_server_error() || status.is_client_error() {
if is_retryable_http_status(status) {
warn!(?response, "Encountered retryable HTTP error");
return Err(backoff::Error::transient(Ok(
HttpErrorResponse::from_response(response).await,
)));
} else {
warn!(?response, "Encountered non-retryable HTTP error");
return Err(backoff::Error::permanent(Ok(
HttpErrorResponse::from_response(response).await,
)));
}
}
let headers = response.headers().clone();
let body = check_reqwest_result(response.bytes().await)?;
Ok(HttpResponse {
status,
headers,
body,
})
},
notify,
)
.await
}
/// Reports whether an HTTP status code is worth retrying.
///
/// `429 Too Many Requests` is retryable, as is every 5xx status except `501 Not Implemented`.
/// All other statuses are not retryable.
pub fn is_retryable_http_status(status: StatusCode) -> bool {
    if status == StatusCode::TOO_MANY_REQUESTS {
        return true;
    }
    // The server explicitly says it will never support this request, so retrying is pointless.
    status != StatusCode::NOT_IMPLEMENTED && status.is_server_error()
}
pub fn is_retryable_http_client_error(error: &reqwest::Error) -> bool {
error.is_timeout() || error.is_connect() || error.is_request() || error.is_body()
}
/// Backoff helpers for tests that exercise retrying code paths without real delays.
#[cfg(feature = "test-util")]
#[cfg_attr(docsrs, doc(cfg(feature = "test-util")))]
pub mod test_util {
    use backoff::{backoff::Backoff, ExponentialBackoff};
    use std::time::Duration;

    /// Returns an [`ExponentialBackoff`] with nanosecond-scale delays, for tests.
    ///
    /// The policy starts at 1ns, doubles on each retry, caps each delay at 30ns, and gives up
    /// after 100ms have elapsed. All other parameters keep the `backoff` crate's defaults.
    pub fn test_http_request_exponential_backoff() -> ExponentialBackoff {
        let initial_interval = Duration::from_nanos(1);
        ExponentialBackoff {
            // `Default::default()` seeds `current_interval` from the crate's own default initial
            // interval. Keep it in sync so the first delay is tiny even before `reset()` runs.
            current_interval: initial_interval,
            initial_interval,
            max_interval: Duration::from_nanos(30),
            multiplier: 2.0,
            max_elapsed_time: Some(Duration::from_millis(100)),
            ..Default::default()
        }
    }

    /// A [`Backoff`] that permits at most `max_retries` retries, each with zero delay.
    #[derive(Clone, Debug)]
    pub struct LimitedRetryer {
        /// Number of retries handed out since construction or the last `reset`.
        retries: u64,
        /// Maximum number of retries before `next_backoff` returns `None`.
        max_retries: u64,
    }

    impl LimitedRetryer {
        /// Creates a retryer that allows up to `max_retries` retries.
        pub fn new(max_retries: u64) -> Self {
            Self {
                retries: 0,
                max_retries,
            }
        }
    }

    impl Backoff for LimitedRetryer {
        fn next_backoff(&mut self) -> Option<Duration> {
            if self.retries >= self.max_retries {
                // Retry budget exhausted: signal the caller to stop.
                return None;
            }
            self.retries += 1;
            Some(Duration::ZERO)
        }

        fn reset(&mut self) {
            self.retries = 0
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        retries::{retry_http_request, retry_http_request_notify, test_util::LimitedRetryer},
        test_util::install_test_trace_subscriber,
    };
    use backoff::Notify;
    use reqwest::StatusCode;
    use std::time::Duration;
    use tokio::net::TcpListener;
    use url::Url;

    /// Counts how many times the retry loop invokes `notify`.
    #[derive(Default)]
    struct NotifyCounter {
        count: u64,
    }

    // Implemented for `&mut NotifyCounter` so a test can lend the counter to the retry helper
    // and still read `count` afterwards.
    impl<E> Notify<E> for &mut NotifyCounter {
        fn notify(&mut self, _: E, _: Duration) {
            self.count += 1;
        }
    }

    /// A 404 is not retryable: exactly one request is made, no retry is reported, and the
    /// error comes back as `Err(Ok(HttpErrorResponse))`.
    #[tokio::test]
    async fn http_retry_client_error() {
        install_test_trace_subscriber();
        let mut server = mockito::Server::new_async().await;
        let mock_404 = server
            .mock("GET", "/")
            .with_status(StatusCode::NOT_FOUND.as_u16().into())
            .with_header("some-header", "some-value")
            .with_body("some-body")
            .expect(1)
            .create_async()
            .await;
        let http_client = reqwest::Client::builder().build().unwrap();
        let mut notify = NotifyCounter::default();
        let response = retry_http_request_notify(LimitedRetryer::new(10), &mut notify, || async {
            http_client.get(server.url()).send().await
        })
        .await
        .unwrap_err()
        .unwrap();
        assert_eq!(notify.count, 0);
        assert_eq!(response.status(), StatusCode::NOT_FOUND);
        mock_404.assert_async().await;
    }

    /// A 500 is retryable: with a budget of 10 retries the notifier fires 10 times and the
    /// final 500 response is returned as the error.
    #[tokio::test]
    async fn http_retry_server_error() {
        install_test_trace_subscriber();
        let mut server = mockito::Server::new_async().await;
        let mock_500 = server
            .mock("GET", "/")
            .with_status(StatusCode::INTERNAL_SERVER_ERROR.as_u16().into())
            .with_header("some-header", "some-value")
            .with_body("some-body")
            .expect_at_least(2)
            .create_async()
            .await;
        let http_client = reqwest::Client::builder().build().unwrap();
        let mut notify = NotifyCounter::default();
        let response = retry_http_request_notify(LimitedRetryer::new(10), &mut notify, || async {
            http_client.get(server.url()).send().await
        })
        .await
        .unwrap_err()
        .unwrap();
        assert_eq!(notify.count, 10);
        assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR);
        mock_500.assert_async().await;
    }

    /// 501 is the one 5xx status excluded from retries: one request, no notifications.
    #[tokio::test]
    async fn http_retry_server_error_unimplemented() {
        install_test_trace_subscriber();
        let mut server = mockito::Server::new_async().await;
        let mock_501 = server
            .mock("GET", "/")
            .with_status(StatusCode::NOT_IMPLEMENTED.as_u16().into())
            .expect(1)
            .create_async()
            .await;
        let http_client = reqwest::Client::builder().build().unwrap();
        let mut notify = NotifyCounter::default();
        let response = retry_http_request_notify(LimitedRetryer::new(10), &mut notify, || async {
            http_client.get(server.url()).send().await
        })
        .await
        .unwrap_err()
        .unwrap();
        assert_eq!(notify.count, 0);
        assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
        mock_501.assert_async().await;
    }

    /// Two 500s followed by a 200: the call succeeds after two reported retries.
    // NOTE(review): relies on mockito routing requests to `mock_500` until its "at least 2"
    // expectation is met and only then to `mock_200`. Confirm against the mockito version in use.
    #[tokio::test]
    async fn http_retry_server_error_eventually_succeeds() {
        install_test_trace_subscriber();
        let mut server = mockito::Server::new_async().await;
        let mock_500 = server
            .mock("GET", "/")
            .with_status(500)
            .expect_at_least(2)
            .create_async()
            .await;
        let mock_200 = server
            .mock("GET", "/")
            .with_status(200)
            .expect(1)
            .create_async()
            .await;
        let http_client = reqwest::Client::builder().build().unwrap();
        let mut notify = NotifyCounter::default();
        retry_http_request_notify(LimitedRetryer::new(10), &mut notify, || async {
            http_client.get(server.url()).send().await
        })
        .await
        .unwrap();
        assert_eq!(notify.count, 2);
        mock_200.assert_async().await;
        mock_500.assert_async().await;
    }

    /// A listener that accepts connections but never replies, combined with a 1ns client
    /// timeout, yields a reqwest timeout error. `LimitedRetryer::new(0)` disables retries.
    #[tokio::test]
    async fn http_retry_timeout() {
        install_test_trace_subscriber();
        let tcp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_port = tcp_listener.local_addr().unwrap().port();
        let listener_task = tokio::spawn(async move {
            loop {
                // Hold the accepted socket open without responding.
                let (_socket, _) = tcp_listener.accept().await.unwrap();
                tokio::time::sleep(Duration::from_secs(10)).await;
            }
        });
        let url = Url::parse(&format!("http://127.0.0.1:{bound_port}")).unwrap();
        let http_client = reqwest::Client::builder()
            .timeout(Duration::from_nanos(1))
            .build()
            .unwrap();
        let err = retry_http_request(LimitedRetryer::new(0), || async {
            http_client.get(url.clone()).send().await
        })
        .await
        .unwrap_err()
        .unwrap_err();
        assert!(err.is_timeout(), "error = {err}");
        listener_task.abort();
        assert!(listener_task.await.unwrap_err().is_cancelled());
    }

    /// The server reads one byte of the request and then drops the socket, so the client
    /// receives a transport error (`Err(Err(_))`) rather than an HTTP response.
    // NOTE(review): presumably dropping a socket with unread data makes the OS send a reset;
    // that is platform-dependent, and this test only checks that some transport error occurs.
    #[tokio::test]
    async fn http_retry_connection_reset() {
        install_test_trace_subscriber();
        let tcp_listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let bound_port = tcp_listener.local_addr().unwrap().port();
        let listener_task = tokio::spawn(async move {
            loop {
                let (socket, _) = tcp_listener.accept().await.unwrap();
                // Wait until at least one byte of the request has arrived.
                loop {
                    socket.readable().await.unwrap();
                    let mut buf = [0u8; 1];
                    match socket.try_read(&mut buf) {
                        Ok(1) => break,
                        Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                            continue;
                        }
                        val => panic!("unexpected result from try_read {val:?}"),
                    }
                }
                drop(socket);
            }
        });
        let url = Url::parse(&format!("http://127.0.0.1:{bound_port}")).unwrap();
        let http_client = reqwest::Client::builder().build().unwrap();
        retry_http_request(LimitedRetryer::new(0), || async {
            http_client.get(url.clone()).send().await
        })
        .await
        .unwrap_err()
        .unwrap_err();
        listener_task.abort();
        assert!(listener_task.await.unwrap_err().is_cancelled());
    }
}