use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
use bytes::Bytes;
use reqwest::{Error as ReqwestError, Response, StatusCode};
use tokio::sync::Mutex;
use tokio_retry::RetryIf;
use tokio_retry::strategy::{ExponentialBackoff, jitter};
use tracing::{error, info};
use xet_runtime::core::xet_config;
use super::adaptive_concurrency::ConnectionPermit;
use crate::common::http_client::request_id_from_response;
use crate::error::{ClientError, Result};
/// Classification of a failed request attempt, used by the retry loop to decide whether to try again.
#[derive(Debug)]
pub enum RetryableReqwestError {
/// The failure is final: the retry loop stops and the wrapped error is returned to the caller.
FatalError(ClientError),
/// The failure may be transient: the attempt is repeated while the retry strategy still has delays left.
RetryableError(ClientError),
}
/// Connection-permit state carried by a `RetryWrapper` (stored behind a mutex).
struct ConnectionPermitInfo {
/// The permit itself; it is taken (leaving `None`) once a final completion has been reported on it.
permit: Option<ConnectionPermit>,
/// Transfer size reported on success; when `None`, the response's content length is reported instead.
transfer_size_if_known: Option<u64>,
}
/// Builder-style wrapper that runs an HTTP request with retries, logging, and optional
/// connection-permit reporting.
pub struct RetryWrapper {
/// Number of delays taken from the backoff strategy, i.e. the number of retries after the first attempt.
max_attempts: usize,
/// Base value handed to the exponential backoff strategy.
base_delay: Duration,
/// When set, a 429 (Too Many Requests) response is treated as fatal instead of being retried.
no_retry_on_429: bool,
/// When set, a 403 (Forbidden) response is retried instead of being treated as fatal.
retry_on_403: bool,
/// When set, a 416 (Range Not Satisfiable) response is logged at info level (it is still fatal).
expected_416: bool,
/// When set, every failure is logged at info level rather than error level.
log_errors_as_info: bool,
/// Name of the API being called; used only in log messages.
api_tag: &'static str,
/// Optional connection permit that is notified about attempts and their outcome.
connection_permit: Option<Mutex<ConnectionPermitInfo>>,
}
impl RetryWrapper {
/// Creates a wrapper for the API identified by `api_tag`.
///
/// The retry count and base delay are read from `xet_config().client`; every optional
/// behavior flag starts disabled and no connection permit is attached.
pub fn new(api_tag: &'static str) -> Self {
    let max_attempts = xet_config().client.retry_max_attempts;
    let base_delay = xet_config().client.retry_base_delay;

    Self {
        api_tag,
        max_attempts,
        base_delay,
        no_retry_on_429: false,
        retry_on_403: false,
        expected_416: false,
        log_errors_as_info: false,
        connection_permit: None,
    }
}
pub fn with_max_attempts(mut self, attempts: usize) -> Self {
self.max_attempts = attempts;
self
}
pub fn with_base_delay(mut self, delay: Duration) -> Self {
self.base_delay = delay;
self
}
pub fn with_429_no_retry(mut self) -> Self {
self.no_retry_on_429 = true;
self
}
pub fn with_retry_on_403(mut self) -> Self {
self.retry_on_403 = true;
self
}
pub fn with_expected_416(mut self) -> Self {
self.expected_416 = true;
self
}
pub fn log_errors_as_info(mut self) -> Self {
self.log_errors_as_info = true;
self
}
/// Attaches a connection permit that is notified about attempts and their outcome.
///
/// `transfer_size_if_known` is the size reported on success; when it is `None` the response's
/// content length is reported instead.
pub fn with_connection_permit(mut self, permit: ConnectionPermit, transfer_size_if_known: Option<u64>) -> Self {
    let permit_info = ConnectionPermitInfo {
        permit: Some(permit),
        transfer_size_if_known,
    };
    self.connection_permit = Some(Mutex::new(permit_info));
    self
}
fn process_error_response(&self, try_idx: usize, err: reqwest_middleware::Error) -> RetryableReqwestError {
let api = &self.api_tag;
let process_error = |txt, log_as_info, err: reqwest_middleware::Error| {
let msg = {
if try_idx > 0 {
format!("{txt}: {api} api call failed (retry {try_idx}): {err}")
} else {
format!("{txt}: {api} api call failed: {err}")
}
};
info!(api = self.api_tag, "Connection attempt {}/{}", try_idx + 1, self.max_attempts);
if self.log_errors_as_info || log_as_info {
info!("{msg}");
} else {
error!("{msg}");
}
ClientError::from(err)
};
match on_request_failure(&err) {
Some(Retryable::Fatal) => {
let cas_err = process_error("Fatal Client Error", false, err);
RetryableReqwestError::FatalError(cas_err)
},
Some(Retryable::Transient) => {
let cas_err = process_error("Retryable Client Error", true, err);
RetryableReqwestError::RetryableError(cas_err)
},
None => {
let cas_err = process_error("Unknown Client Error", true, err);
RetryableReqwestError::FatalError(cas_err)
},
}
}
/// Inspects the status of a received response and decides whether it is accepted, retried, or fatal.
///
/// `try_idx` is the zero-based index of the attempt that produced `resp`.
///
/// Returns `Ok(resp)` when `error_for_status` accepts the response. Otherwise the status is mapped to:
/// * `RetryableError` for statuses classified as transient by [`default_on_request_success`]
///   (except 429 when `no_retry_on_429` is set, and 501), and for 403 when `retry_on_403` is set;
/// * `FatalError` for everything else.
///
/// Retryable and expected outcomes are logged at info level; other failures at error level unless
/// `log_errors_as_info` is set.
fn process_ok_response(
    &self,
    try_idx: usize,
    resp: Response,
) -> std::result::Result<Response, RetryableReqwestError> {
    let request_id = request_id_from_response(&resp).to_owned();
    let retry_str = match try_idx {
        0 => String::default(),
        _ => format!(", retry {try_idx}"),
    };
    let api = &self.api_tag;

    // Classify from the status before `error_for_status` consumes the response.
    let retriability = default_on_request_success(&resp);

    let err = match resp.error_for_status() {
        Ok(accepted) => {
            info!("Request Success: {api} api call succeeded (request id {request_id}{retry_str}).");
            return Ok(accepted);
        },
        Err(err) => err,
    };

    let status = err.status();
    let (context, log_as_info, should_retry) = match retriability {
        Some(Retryable::Fatal) => {
            if status == Some(StatusCode::FORBIDDEN) && self.retry_on_403 {
                ("Retry on 403 (Forbidden) enabled", true, true)
            } else if status == Some(StatusCode::RANGE_NOT_SATISFIABLE) && self.expected_416 {
                ("Reached end of reconstruction 416 (Range Not Satisfiable)", true, false)
            } else {
                ("Fatal Error", false, false)
            }
        },
        Some(Retryable::Transient) => {
            if status == Some(StatusCode::TOO_MANY_REQUESTS) && self.no_retry_on_429 {
                ("Too Many Requests (retry on 429 disabled)", false, false)
            } else if status == Some(StatusCode::NOT_IMPLEMENTED) {
                // 501 is a server error, but repeating the request is not retried here.
                ("Not Implemented", true, false)
            } else {
                ("Retryable Error", true, true)
            }
        },
        None => ("Unknown Error", false, false),
    };

    if self.log_errors_as_info || log_as_info {
        info!("{context}: {api:?} api call failed (request id {request_id}{retry_str}): {err}");
    } else {
        error!("{context}: {api:?} api call failed (request id {request_id}{retry_str}): {err}");
    }

    let client_err = ClientError::from(err);
    Err(if should_retry {
        RetryableReqwestError::RetryableError(client_err)
    } else {
        RetryableReqwestError::FatalError(client_err)
    })
}
/// Sends the request produced by `make_request`, retrying retryable failures, and returns the value
/// produced by `process_fn` from the accepted response.
///
/// Each attempt calls `make_request`, classifies the outcome with `process_error_response` or
/// `process_ok_response`, and, for an accepted response, awaits `process_fn` on it. An attempt that ends in
/// `RetryableReqwestError::RetryableError` is repeated after the next delay of the backoff strategy; a
/// `FatalError`, or running out of delays, ends the call with the wrapped `ClientError`.
///
/// If a connection permit is attached, it is notified before every attempt, after every retryable
/// failure, and exactly once with the final outcome.
pub async fn run_and_process<T, ReqFut, ReqFn, ProcFut, ProcFn>(
self,
make_request: ReqFn,
process_fn: ProcFn,
) -> Result<T>
where
ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
ProcFn: Fn(Response) -> ProcFut + Send + 'static,
ProcFut: Future<Output = std::result::Result<T, RetryableReqwestError>> + 'static,
{
// At most `max_attempts` jittered delays are produced, so the request is sent up to `max_attempts + 1` times.
// NOTE(review): tokio-retry documents `ExponentialBackoff::from_millis(base)` as raising `base` to the n-th
// power, so delays grow as base_delay^n (in ms) rather than doubling — confirm this is intended for large
// base delays.
let strategy = ExponentialBackoff::from_millis(self.base_delay.as_millis().min(u64::MAX as u128) as u64)
.map(jitter)
.take(self.max_attempts);
info!(
max_attempts = self.max_attempts,
base_delay=?self.base_delay,
no_retry_on_429=self.no_retry_on_429,
"Retry strategy",
);
// Everything the attempt closure needs is bundled in one `Arc` so each attempt can cheaply clone a handle to it.
let self_ = Arc::new(self);
let try_count = AtomicUsize::new(0);
let retry_info = Arc::new((make_request, process_fn, try_count, self_.clone()));
let result = RetryIf::spawn(
strategy,
move || {
let retry_info = retry_info.clone();
async move {
let (make_request, process_fn, try_count, self_) = retry_info.as_ref();
// Tell the permit (if one is attached and still held) that a transfer is about to start.
if let Some(p) = &self_.connection_permit
&& let Some(p) = p.lock().await.permit.as_ref()
{
p.transfer_starting().await;
}
let resp_result = make_request().await;
// `fetch_add` returns the previous value, so `try_idx` is the zero-based index of this attempt.
let try_idx = try_count.fetch_add(1, Ordering::Relaxed);
let checked_result = match resp_result {
Err(e) => Err(self_.process_error_response(try_idx, e)),
Ok(resp) => self_.process_ok_response(try_idx, resp),
};
// Read the content length before `process_fn` consumes the response; 0 when it is not known.
let (reply_bytes, processing_result) = match checked_result {
Ok(ok_response) => {
let reply_bytes = ok_response.content_length().unwrap_or(0);
let prosess_fn_result = process_fn(ok_response).await;
(reply_bytes, prosess_fn_result)
},
Err(e) => (0, Err(e)),
};
if let Some(permit_holder) = &self_.connection_permit {
let mut permit_info = permit_holder.lock().await;
match &processing_result {
// Success: take the permit and report the transferred size (the known size, else the content length).
Ok(_) => {
if let Some(permit) = permit_info.permit.take() {
permit
.report_completion(
permit_info.transfer_size_if_known.unwrap_or(reply_bytes),
true,
)
.await;
}
},
// Fatal failure: no further attempts follow, so take the permit and report a failed completion.
Err(RetryableReqwestError::FatalError(_)) => {
if let Some(permit) = permit_info.permit.take() {
permit.report_completion(0, false).await;
}
},
// Retryable failure: keep the permit for the next attempt and only report the failure.
Err(RetryableReqwestError::RetryableError(_)) => {
if let Some(permit) = permit_info.permit.as_ref() {
permit.report_retryable_failure().await;
}
},
}
}
processing_result
}
},
// Only retryable errors trigger another attempt.
|err: &RetryableReqwestError| matches!(err, RetryableReqwestError::RetryableError(_)),
)
.await;
match result {
Ok(r) => Ok(r),
// The permit was already taken and reported inside the attempt that produced the fatal error.
Err(RetryableReqwestError::FatalError(e)) => {
Err(e)
},
// The strategy ran out of delays while the last attempt was still retryable: the permit is still held,
// so report the failed completion here.
Err(RetryableReqwestError::RetryableError(e)) => {
if let Some(permit_holder) = &self_.connection_permit {
let mut permit_info = permit_holder.lock().await;
if let Some(permit) = permit_info.permit.take() {
permit.report_completion(0, false).await;
}
}
if self_.log_errors_as_info {
info!("No more retries; aborting: {e}");
} else {
error!("No more retries; aborting: {e}");
}
Err(e)
},
}
}
pub async fn run_and_extract_json<JsonDest, ReqFn, ReqFut>(self, make_request: ReqFn) -> Result<JsonDest>
where
JsonDest: for<'de> serde::Deserialize<'de>,
ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
{
self.run_and_process(make_request, |resp: Response| {
async move {
let r: std::result::Result<JsonDest, reqwest::Error> = resp.json().await;
match r {
Ok(v) => Ok(v),
Err(e) => {
#[cfg(not(target_arch = "wasm32"))]
let is_connect = e.is_connect();
#[cfg(target_arch = "wasm32")]
let is_connect = false;
if is_connect || e.is_decode() || e.is_body() || e.is_timeout() {
Err(RetryableReqwestError::RetryableError(e.into()))
} else {
Err(RetryableReqwestError::FatalError(e.into()))
}
},
}
}
})
.await
}
pub async fn run_and_extract_bytes<ReqFut, ReqFn>(self, make_request: ReqFn) -> Result<Bytes>
where
ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
{
self.run_and_process(make_request, |resp: Response| {
async move {
let r: std::result::Result<Bytes, reqwest::Error> = resp.bytes().await;
match r {
Ok(v) => Ok(v),
Err(e) => {
#[cfg(not(target_arch = "wasm32"))]
let is_connect = e.is_connect();
#[cfg(target_arch = "wasm32")]
let is_connect = false;
if is_connect || e.is_decode() || e.is_body() || e.is_timeout() {
Err(RetryableReqwestError::RetryableError(e.into()))
} else {
Err(RetryableReqwestError::FatalError(e.into()))
}
},
}
}
})
.await
}
/// Runs the request with retries and hands the accepted response to the caller-supplied `parse`.
///
/// `parse` decides for itself whether a failure is retryable or fatal through the
/// `RetryableReqwestError` variant it returns.
pub async fn run_and_extract_custom<ReqFn, ReqFut, Parse, ParseFut, Dest>(
    self,
    make_request: ReqFn,
    parse: Parse,
) -> Result<Dest>
where
    ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
    ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
    Parse: Fn(Response) -> ParseFut + Send + Sync + 'static,
    ParseFut: std::future::Future<Output = std::result::Result<Dest, RetryableReqwestError>> + 'static,
{
    let parsed = self.run_and_process(make_request, parse).await;
    parsed
}
/// Runs the request with retries and returns the accepted response itself, body unread.
pub async fn run<ReqFut, ReqFn>(self, make_request: ReqFn) -> Result<Response>
where
    ReqFn: Fn() -> ReqFut + Send + Sync + 'static,
    ReqFut: std::future::Future<Output = std::result::Result<Response, reqwest_middleware::Error>> + 'static,
{
    // No post-processing: the accepted response is passed straight through.
    let pass_through = |response: Response| async move { Ok(response) };
    self.run_and_process(make_request, pass_through).await
}
}
/// Retry classification of a request outcome.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Retryable {
/// The outcome is final; repeating the request is not attempted.
Fatal,
/// The outcome may be temporary; the request may be repeated.
Transient,
}
/// Classifies a received response by its HTTP status.
///
/// Returns `Some(Retryable::Transient)` for server errors, 408 (Request Timeout) and
/// 429 (Too Many Requests); `None` for success statuses; `Some(Retryable::Fatal)` for everything else.
pub fn default_on_request_success(success: &Response) -> Option<Retryable> {
    match success.status() {
        StatusCode::REQUEST_TIMEOUT | StatusCode::TOO_MANY_REQUESTS => Some(Retryable::Transient),
        code if code.is_server_error() => Some(Retryable::Transient),
        code if code.is_success() => None,
        _ => Some(Retryable::Fatal),
    }
}
pub fn on_request_failure(error: &reqwest_middleware::Error) -> Option<Retryable> {
let reqwest_middleware::Error::Reqwest(error) = error else {
return Some(Retryable::Fatal);
};
#[cfg(not(target_arch = "wasm32"))]
let is_connect = error.is_connect();
#[cfg(target_arch = "wasm32")]
let is_connect = false;
if error.is_timeout() || is_connect {
Some(Retryable::Transient)
} else if error.is_body() || error.is_decode() || error.is_builder() || error.is_redirect() {
Some(Retryable::Fatal)
} else if error.is_request() {
#[cfg(not(target_arch = "wasm32"))]
if let Some(hyper_error) = get_source_error_type::<hyper::Error>(&error) {
let is_io_error = get_source_error_type::<std::io::Error>(hyper_error).is_some();
if hyper_error.is_incomplete_message() || hyper_error.is_canceled() || is_io_error {
Some(Retryable::Transient)
} else {
Some(Retryable::Fatal)
}
} else {
Some(Retryable::Fatal)
}
#[cfg(target_arch = "wasm32")]
Some(Retryable::Fatal)
} else {
None
}
}
/// Returns the first error of type `T` found in the `source()` chain of `err`.
///
/// The search starts at `err.source()`, so `err` itself is never returned even if it is a `T`.
#[cfg(not(target_arch = "wasm32"))]
fn get_source_error_type<T: std::error::Error + 'static>(err: &dyn std::error::Error) -> Option<&T> {
    std::iter::successors(err.source(), |cause| (*cause).source())
        .find_map(|cause| cause.downcast_ref::<T>())
}
#[cfg(test)]
mod tests {
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};
use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
use serde::{Deserialize, Serialize};
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
use super::*;
/// Builds the wrapper used by the tests: 5 ms base delay and `max_attempts` of 3.
fn connection_wrapper(api: &'static str) -> RetryWrapper {
    let short_delay = Duration::from_millis(5);
    let wrapper = RetryWrapper::new(api);
    wrapper.with_base_delay(short_delay).with_max_attempts(3)
}
/// Builds a middleware client around a default `reqwest::Client`, with no middleware added.
fn make_client() -> ClientWithMiddleware {
    let inner = reqwest::Client::new();
    ClientBuilder::new(inner).build()
}
/// A 200 response is accepted after exactly one request.
async fn check_success_first_try(server: &MockServer) {
    let _guard = Mock::given(method("GET"))
        .and(path("/success"))
        .respond_with(ResponseTemplate::new(200))
        .expect(1)
        .mount_as_scoped(server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);
    let base_url = server.uri();

    let outcome = connection_wrapper("check_success_first_try")
        .run(move || {
            let target = format!("{}/success", base_url);
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await;

    assert!(outcome.is_ok());
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}
/// Two 500 responses followed by a 200: the third request succeeds and its body is returned.
async fn check_retry_then_success(server: &MockServer) {
    // First two requests fail with a server error...
    let _failing = Mock::given(method("GET"))
        .and(path("/flaky"))
        .respond_with(ResponseTemplate::new(500))
        .up_to_n_times(2)
        .mount_as_scoped(server)
        .await;
    // ...after which the endpoint recovers.
    let _recovered = Mock::given(method("GET"))
        .and(path("/flaky"))
        .respond_with(ResponseTemplate::new(200).set_body_string("Recovered"))
        .mount_as_scoped(server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);
    let base_url = server.uri();

    let outcome = connection_wrapper("check_retry_then_success")
        .run(move || {
            let target = format!("{}/flaky", base_url);
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(target).send()
        })
        .await;

    assert!(outcome.is_ok());
    let body = outcome.unwrap().bytes().await.unwrap();
    assert_eq!(&body[..], b"Recovered");
    assert_eq!(calls.load(Ordering::SeqCst), 3);
}
/// A permanently failing endpoint is hit four times (first attempt plus three retries) before giving up.
async fn check_retry_limit_exceeded(server: &MockServer) {
    let _guard = Mock::given(method("GET"))
        .and(path("/fail"))
        .respond_with(ResponseTemplate::new(500))
        .expect(4)
        .mount_as_scoped(server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);
    let base_url = server.uri();

    let outcome = connection_wrapper("check_retry_limit_exceeded")
        .with_max_attempts(3)
        .run(move || {
            let target = format!("{}/fail", base_url);
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await;

    assert!(outcome.is_err());
    assert_eq!(calls.load(Ordering::SeqCst), 4);
}
/// A 400 response is fatal: the request is sent once and the call fails.
async fn check_non_retryable_status(server: &MockServer) {
    let _guard = Mock::given(method("GET"))
        .and(path("/bad_request"))
        .respond_with(ResponseTemplate::new(400))
        .expect(1)
        .mount_as_scoped(server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);
    let base_url = server.uri();

    let outcome = connection_wrapper("check_non_retryable_status")
        .run(move || {
            let target = format!("{}/bad_request", base_url);
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await;

    assert!(outcome.is_err());
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}
/// By default a 429 response is retried until the retry budget is exhausted (four requests in total).
async fn check_429_retry_if_specified(server: &MockServer) {
    let _guard = Mock::given(method("GET"))
        .and(path("/rate_limit"))
        .respond_with(ResponseTemplate::new(429))
        .expect(4)
        .mount_as_scoped(server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);
    let base_url = server.uri();

    let outcome = connection_wrapper("check_429_retry_if_specified")
        .with_max_attempts(3)
        .run(move || {
            let target = format!("{}/rate_limit", base_url);
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await;

    assert!(outcome.is_err());
    assert_eq!(calls.load(Ordering::SeqCst), 4);
}
/// With `with_429_no_retry`, a 429 response is fatal: one request, then failure.
async fn check_429_no_retry(server: &MockServer) {
    let _guard = Mock::given(method("GET"))
        .and(path("/rate_limit_no_retry"))
        .respond_with(ResponseTemplate::new(429))
        .expect(1)
        .mount_as_scoped(server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);
    let base_url = server.uri();

    let outcome = connection_wrapper("check_429_no_retry")
        .with_max_attempts(3)
        .with_429_no_retry()
        .run(move || {
            let target = format!("{}/rate_limit_no_retry", base_url);
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await;

    assert!(outcome.is_err());
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}
/// Small payload used by the JSON tests; served by the mock server and compared after deserialization.
#[derive(Serialize, Deserialize, PartialEq, Debug)]
struct JsonData {
text: String,
number: u64,
}
/// A JSON body served by the mock server deserializes back to the original value in one request.
async fn check_json_reserialization(server: &MockServer) {
    let expected = JsonData {
        text: "test".into(),
        number: 42,
    };

    let _guard = Mock::given(method("GET"))
        .and(path("/json"))
        .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_json(&expected))
        .expect(1)
        .mount_as_scoped(server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);
    let base_url = server.uri();

    let received: JsonData = connection_wrapper("check_json_reserialization")
        .run_and_extract_json(move || {
            let target = format!("{}/json", base_url);
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await
        .unwrap();

    assert_eq!(received, expected);
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}
/// A truncated JSON body on the first request is retried; the second, complete body is returned.
async fn check_json_unexpected_eof_retry(server: &MockServer) {
    let expected = JsonData {
        text: "test".into(),
        number: 42,
    };
    let json_data = serde_json::to_string(&expected).unwrap();
    // Drop the last five bytes so the first response cannot be parsed.
    let truncated = &json_data[..json_data.len() - 5];

    let _truncated_once = Mock::given(method("GET"))
        .and(path("/json_flaky"))
        .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_string(truncated))
        .up_to_n_times(1)
        .mount_as_scoped(server)
        .await;
    let _complete = Mock::given(method("GET"))
        .and(path("/json_flaky"))
        .respond_with(ResponseTemplate::new(StatusCode::OK).set_body_string(&json_data))
        .expect(1)
        .mount_as_scoped(server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);
    let base_url = server.uri();

    let received: JsonData = connection_wrapper("check_json_unexpected_eof_retry")
        .run_and_extract_json(move || {
            let target = format!("{}/json_flaky", base_url);
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await
        .unwrap();

    assert_eq!(received, expected);
    assert_eq!(calls.load(Ordering::SeqCst), 2);
}
/// Runs every `check_*` scenario in sequence against one shared mock server.
#[tokio::test]
async fn test_retry_wrapper() {
    let mock_server = MockServer::start().await;
    let shared = &mock_server;

    // Status-code handling.
    check_success_first_try(shared).await;
    check_retry_then_success(shared).await;
    check_retry_limit_exceeded(shared).await;
    check_non_retryable_status(shared).await;
    check_429_retry_if_specified(shared).await;
    check_429_no_retry(shared).await;

    // JSON extraction.
    check_json_reserialization(shared).await;
    check_json_unexpected_eof_retry(shared).await;
}
/// Without `with_retry_on_403`, a 403 response is fatal: one request, then failure.
#[tokio::test]
async fn test_403_no_retry_by_default() {
    let mock_server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/forbidden"))
        .respond_with(ResponseTemplate::new(403))
        .expect(1)
        .mount(&mock_server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);

    // The closure takes ownership of the mock server and asks it for its URI on every call.
    let outcome = connection_wrapper("test_403_no_retry_by_default")
        .with_max_attempts(3)
        .run(move || {
            let target = format!("{}/forbidden", mock_server.uri());
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await;

    assert!(outcome.is_err());
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}
/// With `with_retry_on_403`, a persistent 403 is retried until the budget is exhausted (four requests).
#[tokio::test]
async fn test_403_retry_when_enabled() {
    let mock_server = MockServer::start().await;
    Mock::given(method("GET"))
        .and(path("/forbidden"))
        .respond_with(ResponseTemplate::new(403))
        .expect(4)
        .mount(&mock_server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);

    // The closure takes ownership of the mock server and asks it for its URI on every call.
    let outcome = connection_wrapper("test_403_retry_when_enabled")
        .with_max_attempts(3)
        .with_retry_on_403()
        .run(move || {
            let target = format!("{}/forbidden", mock_server.uri());
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await;

    assert!(outcome.is_err());
    assert_eq!(calls.load(Ordering::SeqCst), 4);
}
/// With `with_retry_on_403`, two 403 responses followed by a 200 succeed on the third request.
#[tokio::test]
async fn test_403_retry_then_success() {
    let mock_server = MockServer::start().await;
    // The first two requests are rejected...
    Mock::given(method("GET"))
        .and(path("/forbidden_then_ok"))
        .respond_with(ResponseTemplate::new(403))
        .up_to_n_times(2)
        .mount(&mock_server)
        .await;
    // ...and every later one is accepted.
    Mock::given(method("GET"))
        .and(path("/forbidden_then_ok"))
        .respond_with(ResponseTemplate::new(200).set_body_string("Success"))
        .mount(&mock_server)
        .await;

    let http = make_client();
    let calls = Arc::new(AtomicU32::new(0));
    let calls_in_request = Arc::clone(&calls);

    // The closure takes ownership of the mock server and asks it for its URI on every call.
    let outcome = connection_wrapper("test_403_retry_then_success")
        .with_max_attempts(3)
        .with_retry_on_403()
        .run(move || {
            let target = format!("{}/forbidden_then_ok", mock_server.uri());
            calls_in_request.fetch_add(1, Ordering::Relaxed);
            http.clone().get(&target).send()
        })
        .await;

    assert!(outcome.is_ok());
    let body = outcome.unwrap().bytes().await.unwrap();
    assert_eq!(&body[..], b"Success");
    assert_eq!(calls.load(Ordering::SeqCst), 3);
}
}