use std::time::Duration;
use http::{Request, Response, StatusCode};
use tower::{
BoxError,
retry::{
Policy,
backoff::{Backoff, ExponentialBackoff, ExponentialBackoffMaker, MakeBackoff},
},
util::rng::HasherRng,
};
use super::Body;
pub use tower::retry::backoff::InvalidBackoff;
/// Retry policy for HTTP requests that pairs an exponential backoff with a retry budget.
///
/// A response is retried only when its status is one of the retryable codes and fewer than
/// `max_retries` retries have already been granted by this policy value.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Backoff state that produces the delay awaited before each retry.
    backoff: ExponentialBackoff,
    /// Number of retries this policy value has granted so far.
    current_attempt: u32,
    /// Upper bound on the number of retries this policy value will grant.
    max_retries: u32,
}
impl RetryPolicy {
    /// Builds a policy that grants at most `max_retries` retries, sleeping between them
    /// according to an exponential backoff configured with `min_delay` and `max_delay`.
    ///
    /// The backoff is created with a jitter argument of `2.0` and a fresh [`HasherRng`].
    ///
    /// # Errors
    ///
    /// Propagates the [`InvalidBackoff`] returned by [`ExponentialBackoffMaker::new`] when it
    /// rejects the supplied parameters (for example `min_delay` greater than `max_delay`).
    pub fn new(
        min_delay: Duration,
        max_delay: Duration,
        max_retries: u32,
    ) -> Result<Self, InvalidBackoff> {
        // NOTE(review): tower appears to add jitter on top of the capped base delay, so a single
        // sleep may exceed `max_delay` with a jitter of 2.0 — confirm against the tower version
        // in use.
        let mut maker = ExponentialBackoffMaker::new(min_delay, max_delay, 2.0, HasherRng::new())?;

        let policy = Self {
            max_retries,
            current_attempt: 0,
            backoff: maker.make_backoff(),
        };
        Ok(policy)
    }

    /// Returns `true` for the response statuses this policy is willing to retry:
    /// 429 Too Many Requests, 503 Service Unavailable and 504 Gateway Timeout.
    fn is_retryable_status(status: StatusCode) -> bool {
        const RETRYABLE: [StatusCode; 3] = [
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ];
        RETRYABLE.contains(&status)
    }
}
impl Default for RetryPolicy {
    /// Creates a policy with a 500 ms minimum delay, a 5 s maximum delay and up to 3 retries.
    ///
    /// # Panics
    ///
    /// Panics if [`RetryPolicy::new`] rejects these fixed parameters.
    fn default() -> Self {
        let min_delay = Duration::from_millis(500);
        let max_delay = Duration::from_secs(5);
        let max_retries = 3;

        Self::new(min_delay, max_delay, max_retries)
            .expect("default RetryPolicy parameters are valid")
    }
}
impl<Res> Policy<Request<Body>, Response<Res>, BoxError> for RetryPolicy {
    /// Future awaited before the retried request is sent: the backoff sleep.
    type Future = tokio::time::Sleep;

    /// Decides whether the outcome in `result` should be retried.
    ///
    /// Returns the next backoff sleep when the result is a response with a retryable status and
    /// the retry budget is not exhausted; the attempt counter is advanced in that case.
    /// Returns `None` for every other response and for every `Err` result.
    fn retry(
        &mut self,
        _req: &mut Request<Body>,
        result: &mut Result<Response<Res>, BoxError>,
    ) -> Option<Self::Future> {
        // Errors are never retried: bail out unless a response was received.
        let status = result.as_ref().ok()?.status();

        let budget_left = self.current_attempt < self.max_retries;
        if !(Self::is_retryable_status(status) && budget_left) {
            return None;
        }

        self.current_attempt += 1;
        Some(self.backoff.next_backoff())
    }

    /// Produces a copy of `req` to be used for a retry.
    ///
    /// Returns `None` when the body's `try_clone` yields `None`. Otherwise the copy carries the
    /// same method, URI, version, headers and extensions as `req`.
    fn clone_request(&mut self, req: &Request<Body>) -> Option<Request<Body>> {
        let body = req.body().try_clone()?;

        let mut copy = Request::new(body);
        *copy.method_mut() = req.method().clone();
        *copy.uri_mut() = req.uri().clone();
        *copy.version_mut() = req.version();
        *copy.headers_mut() = req.headers().clone();
        *copy.extensions_mut() = req.extensions().clone();
        Some(copy)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default policy grants three retries.
    #[test]
    fn test_default_policy() {
        let RetryPolicy { max_retries, .. } = RetryPolicy::default();
        assert_eq!(max_retries, 3);
    }

    /// Only 429, 503 and 504 are classified as retryable.
    #[test]
    fn test_retryable_status() {
        let retryable = [
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ];
        for status in retryable {
            assert!(
                RetryPolicy::is_retryable_status(status),
                "{:?} should be retryable",
                status
            );
        }

        let not_retryable = [
            StatusCode::OK,
            StatusCode::BAD_REQUEST,
            StatusCode::INTERNAL_SERVER_ERROR,
            StatusCode::NOT_FOUND,
        ];
        for status in not_retryable {
            assert!(
                !RetryPolicy::is_retryable_status(status),
                "{:?} should not be retryable",
                status
            );
        }
    }

    /// A minimum delay larger than the maximum delay is rejected.
    #[test]
    fn test_invalid_backoff() {
        let min_delay = Duration::from_secs(10);
        let max_delay = Duration::from_secs(1);
        assert!(RetryPolicy::new(min_delay, max_delay, 3).is_err());
    }
}