use std::time::{Duration, SystemTime};
use crate::retryable_strategy::RetryableStrategy;
use crate::{retryable::Retryable, retryable_strategy::DefaultRetryableStrategy, RetryError};
use anyhow::anyhow;
use http::Extensions;
use reqwest::{Request, Response};
use reqwest_middleware::{Error, Middleware, Next, Result};
use retry_policies::RetryPolicy;
// Emits a `tracing` event whose level is only known at runtime.
//
// The first argument is an expression evaluating to a `tracing::Level`; the
// remaining tokens are forwarded untouched to the level-specific `tracing`
// macro (`trace!`, `debug!`, `info!`, `warn!` or `error!`) selected by that
// value. Only compiled when the `tracing` feature is enabled.
#[doc(hidden)]
#[cfg(feature = "tracing")]
macro_rules! log_retry {
    ($lvl:expr, $($fmt_args:tt)*) => {{
        match $lvl {
            ::tracing::Level::ERROR => ::tracing::error!($($fmt_args)*),
            ::tracing::Level::WARN => ::tracing::warn!($($fmt_args)*),
            ::tracing::Level::INFO => ::tracing::info!($($fmt_args)*),
            ::tracing::Level::DEBUG => ::tracing::debug!($($fmt_args)*),
            ::tracing::Level::TRACE => ::tracing::trace!($($fmt_args)*),
        }
    }};
}
/// Middleware that re-sends a request when its outcome is classified as
/// transient.
///
/// * `T` is the [`RetryPolicy`] consulted to decide whether, and when, another
///   attempt should be made.
/// * `R` is the [`RetryableStrategy`] used to classify the result of each
///   attempt; it defaults to [`DefaultRetryableStrategy`].
pub struct RetryTransientMiddleware<T, R = DefaultRetryableStrategy>
where
    T: RetryPolicy + Send + Sync + 'static,
    R: RetryableStrategy + Send + Sync + 'static,
{
    /// Decides whether a transient failure is retried and when the next
    /// attempt may start.
    retry_policy: T,
    /// Classifies the result of an attempt; only results it reports as
    /// transient are candidates for a retry.
    retryable_strategy: R,
    /// Level at which each retry is logged (only with the `tracing` feature).
    #[cfg(feature = "tracing")]
    retry_log_level: tracing::Level,
}
impl<T: RetryPolicy + Send + Sync> RetryTransientMiddleware<T, DefaultRetryableStrategy> {
    /// Builds a middleware that pairs `retry_policy` with
    /// [`DefaultRetryableStrategy`].
    pub fn new_with_policy(retry_policy: T) -> Self {
        Self::new_with_policy_and_strategy(retry_policy, DefaultRetryableStrategy)
    }
}

impl<T, R> RetryTransientMiddleware<T, R>
where
    T: RetryPolicy + Send + Sync,
    R: RetryableStrategy + Send + Sync,
{
    /// Builds a middleware from an explicit `retry_policy` and
    /// `retryable_strategy`.
    ///
    /// With the `tracing` feature enabled, retries are logged at
    /// [`tracing::Level::WARN`] unless changed via
    /// [`with_retry_log_level`](Self::with_retry_log_level).
    pub fn new_with_policy_and_strategy(retry_policy: T, retryable_strategy: R) -> Self {
        Self {
            retry_policy,
            retryable_strategy,
            #[cfg(feature = "tracing")]
            retry_log_level: tracing::Level::WARN,
        }
    }

    /// Sets the level at which retry attempts are logged and returns the
    /// updated middleware.
    ///
    /// Defined for every strategy type `R` (not only the default one) so that
    /// middlewares built with
    /// [`new_with_policy_and_strategy`](Self::new_with_policy_and_strategy)
    /// can also adjust their log level.
    #[cfg(feature = "tracing")]
    pub fn with_retry_log_level(mut self, level: tracing::Level) -> Self {
        self.retry_log_level = level;
        self
    }
}
// `Middleware` entry point: every request handed to this middleware is driven
// through the retry loop in `execute_with_retry`. On `wasm32` the generated
// future is not required to be `Send`.
#[cfg_attr(target_arch = "wasm32", async_trait::async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait::async_trait)]
impl<T: RetryPolicy + Send + Sync, R: RetryableStrategy + Send + Sync + 'static> Middleware
    for RetryTransientMiddleware<T, R>
{
    /// Forwards `req` to the remaining middleware chain (`next`), retrying it
    /// according to the configured policy and strategy.
    async fn handle(
        &self,
        req: Request,
        extensions: &mut Extensions,
        next: Next<'_>,
    ) -> Result<Response> {
        // Note the argument order: the retry helper takes `next` before the
        // extensions.
        self.execute_with_retry(req, next, extensions).await
    }
}
impl<T, R> RetryTransientMiddleware<T, R>
where
T: RetryPolicy + Send + Sync,
R: RetryableStrategy + Send + Sync,
{
async fn execute_with_retry<'a>(
&'a self,
req: Request,
next: Next<'a>,
ext: &'a mut Extensions,
) -> Result<Response> {
let mut n_past_retries = 0;
let start_time = SystemTime::now();
loop {
let duplicate_request = req.try_clone().ok_or_else(|| {
Error::Middleware(anyhow!(
"Request object is not cloneable. Are you passing a streaming body?"
.to_string()
))
})?;
let result = next.clone().run(duplicate_request, ext).await;
if let Some(Retryable::Transient) = self.retryable_strategy.handle(&result) {
let retry_decision = self.retry_policy.should_retry(start_time, n_past_retries);
if let retry_policies::RetryDecision::Retry { execute_after } = retry_decision {
let duration = execute_after
.duration_since(SystemTime::now())
.unwrap_or_else(|_| Duration::default());
#[cfg(feature = "tracing")]
log_retry!(
self.retry_log_level,
"Retry attempt #{}. Sleeping {:?} before the next attempt",
n_past_retries,
duration
);
#[cfg(not(target_arch = "wasm32"))]
tokio::time::sleep(duration).await;
#[cfg(target_arch = "wasm32")]
wasmtimer::tokio::sleep(duration).await;
n_past_retries += 1;
continue;
}
};
break if n_past_retries > 0 {
result.map_err(|err| {
Error::Middleware(
RetryError::WithRetries {
retries: n_past_retries,
err,
}
.into(),
)
})
} else {
result.map_err(|err| Error::Middleware(RetryError::Error(err).into()))
};
}
}
}