use std::time::{Duration, Instant};
use http::{Request, Response, StatusCode};
use tower::{
BoxError,
retry::{
Policy,
backoff::{Backoff, ExponentialBackoff, ExponentialBackoffMaker, MakeBackoff},
},
util::rng::HasherRng,
};
use super::Body;
pub use tower::retry::backoff::InvalidBackoff;
/// A [`Policy`] that retries requests answered with `429`, `503` or `504`.
///
/// Waits between attempts follow a jittered exponential backoff. When the policy is
/// server-aware, a `Retry-After` header expressed in whole seconds may lengthen the wait.
///
/// The attempt counter lives inside the policy, so every clone counts its own retries.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Produces the delay to wait before each retry.
    backoff: ExponentialBackoff,
    /// Number of retries already scheduled by this policy instance.
    current_attempt: u32,
    /// Maximum value of `current_attempt`; once reached, no further retry is scheduled.
    max_retries: u32,
    /// Whether a numeric `Retry-After` response header may override the backoff delay.
    server_aware: bool,
}
impl RetryPolicy {
    /// Value passed as the third argument of [`ExponentialBackoffMaker::new`].
    ///
    /// That argument is tower's jitter ratio (how much random delay may be added on top of
    /// the base backoff), not a growth factor for the exponential sequence.
    // NOTE(review): a jitter ratio of 2.0 is unusually large and may originally have been
    // meant as a doubling factor — confirm the intended amount of randomization.
    const BACKOFF_JITTER: f64 = 2.0;

    /// Builds a retry policy.
    ///
    /// * `min_delay` / `max_delay` — bounds handed to tower's exponential backoff.
    /// * `max_retries` — how many retries (not counting the first attempt) are allowed.
    /// * `server_retry` — when `true`, a `Retry-After` header expressed in whole seconds
    ///   may lengthen the wait chosen by the backoff.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidBackoff`] when tower rejects the backoff parameters, for example
    /// when `max_delay` is shorter than `min_delay`.
    pub fn new(
        min_delay: Duration,
        max_delay: Duration,
        max_retries: u32,
        server_retry: bool,
    ) -> Result<Self, InvalidBackoff> {
        let backoff = ExponentialBackoffMaker::new(
            min_delay,
            max_delay,
            Self::BACKOFF_JITTER,
            HasherRng::new(),
        )?
        .make_backoff();
        Ok(Self {
            backoff,
            current_attempt: 0,
            max_retries,
            server_aware: server_retry,
        })
    }

    /// Preset for servers that send `Retry-After`.
    ///
    /// Backoff ranges from 5 ms to 1000 s, up to 15 retries are allowed, and server
    /// `Retry-After` hints are honored.
    pub fn server_retry() -> Self {
        Self::new(Duration::from_millis(5), Duration::from_secs(1000), 15, true)
            .expect("default server RetryPolicy parameters are valid")
    }

    /// Returns `true` for the statuses this policy retries: `429 Too Many Requests`,
    /// `503 Service Unavailable` and `504 Gateway Timeout`.
    fn is_retryable_status(status: StatusCode) -> bool {
        matches!(
            status,
            StatusCode::TOO_MANY_REQUESTS | StatusCode::SERVICE_UNAVAILABLE | StatusCode::GATEWAY_TIMEOUT
        )
    }
}
impl Default for RetryPolicy {
    /// Client-side policy: 500 ms to 5 s of backoff, 3 retries, and `Retry-After`
    /// headers ignored.
    fn default() -> Self {
        let min_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(5);
        let max_retries = 3;
        let honor_retry_after = false;
        Self::new(min_delay, max_delay, max_retries, honor_retry_after)
            .expect("default RetryPolicy parameters are valid")
    }
}
impl<Res> Policy<Request<Body>, Response<Res>, BoxError> for RetryPolicy {
type Future = tokio::time::Sleep;
fn retry(
&mut self,
_req: &mut Request<Body>,
result: &mut Result<Response<Res>, BoxError>,
) -> Option<Self::Future> {
match result {
Ok(response)
if Self::is_retryable_status(response.status())
&& self.current_attempt < self.max_retries =>
{
self.current_attempt += 1;
let backoff = self.backoff.next_backoff();
if self.server_aware
&& let Some(retry_after) = response.headers().get("Retry-After")
&& let Some(retry_after) = retry_after.to_str().ok()
&& let Some(retry_after) = retry_after.parse::<u64>().ok()
{
let server_delay = Duration::from_secs(retry_after);
let retry_after = Instant::now() + server_delay;
if backoff.deadline().le(&retry_after.into()) {
return Some(tokio::time::sleep(server_delay));
}
}
Some(backoff)
}
_ => None,
}
}
fn clone_request(&mut self, req: &Request<Body>) -> Option<Request<Body>> {
let body = req.body().try_clone()?;
let mut builder = Request::builder()
.method(req.method().clone())
.uri(req.uri().clone())
.version(req.version());
if let Some(headers) = builder.headers_mut() {
headers.extend(req.headers().clone());
}
builder.body(body).ok().map(|mut new_req| {
*new_req.extensions_mut() = req.extensions().clone();
new_req
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The `Default` preset allows three retries.
    #[test]
    fn test_default_policy() {
        let RetryPolicy { max_retries, .. } = RetryPolicy::default();
        assert_eq!(max_retries, 3);
    }

    /// Only 429, 503 and 504 are considered retryable.
    #[test]
    fn test_retryable_status() {
        let retryable = [
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ];
        let not_retryable = [
            StatusCode::OK,
            StatusCode::BAD_REQUEST,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::NOT_FOUND,
        ];
        for status in retryable {
            assert!(RetryPolicy::is_retryable_status(status));
        }
        for status in not_retryable {
            assert!(!RetryPolicy::is_retryable_status(status));
        }
    }

    /// A minimum delay larger than the maximum is rejected by the backoff builder.
    #[test]
    fn test_invalid_backoff() {
        let min_delay = Duration::from_secs(10);
        let max_delay = Duration::from_secs(1);
        assert!(RetryPolicy::new(min_delay, max_delay, 3, false).is_err());
    }

    /// The server preset is server-aware and allows fifteen retries.
    #[test]
    fn test_compile_server_retry() {
        let RetryPolicy {
            server_aware,
            max_retries,
            ..
        } = RetryPolicy::server_retry();
        assert!(server_aware);
        assert_eq!(max_retries, 15);
    }
}