use std::future::Future;
use std::pin::Pin;
use std::time::Duration;
use tower::retry::Policy;
use crate::error::BbmError;
/// Retry budget used by `RetryPolicy::default()`.
pub const DEFAULT_MAX_ATTEMPTS: u32 = 3;
/// Retry policy for `tower`'s retry middleware.
///
/// Errors classified as transient by `RetryPolicy::is_retryable` are retried with an
/// exponentially growing delay (see `RetryPolicy::backoff_duration`) until the retry
/// budget in `max_attempts` is exhausted.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Retry budget: `retry` stops scheduling retries once `current_attempt` reaches this
    /// value. It therefore bounds the number of *retries*, not counting the initial call.
    pub max_attempts: u32,
    /// Number of retries this policy has scheduled so far; `new` initializes it to 0 and
    /// `retry` increments it before computing the next backoff.
    pub current_attempt: u32,
}
impl RetryPolicy {
    /// Creates a policy that permits up to `max_attempts` retries and has not retried
    /// anything yet (`current_attempt == 0`).
    pub fn new(max_attempts: u32) -> Self {
        Self {
            max_attempts,
            current_attempt: 0,
        }
    }

    /// Returns the delay to wait before the retry numbered `current_attempt`.
    ///
    /// The delay is `200ms * 2^(current_attempt - 1)`, with the exponent floored at 0.
    /// That gives 200ms for attempts 0 and 1, 400ms for attempt 2, 800ms for attempt 3,
    /// and so on.
    ///
    /// The arithmetic is done in `u64` milliseconds with saturating operations. As a
    /// result the delay never decreases as `current_attempt` grows, and this method never
    /// panics. Absurdly large attempt counts saturate at `u64::MAX` milliseconds.
    pub fn backoff_duration(&self) -> Duration {
        const BASE_DELAY_MS: u64 = 200;
        let exponent = self.current_attempt.saturating_sub(1);
        // Stay in u64: `2^exponent` no longer fits in a u32 once the exponent reaches 32.
        // Narrowing it with `as u32` would wrap the factor (2^32 -> 0) and collapse the
        // delay to zero.
        let factor = 2u64.saturating_pow(exponent);
        Duration::from_millis(BASE_DELAY_MS.saturating_mul(factor))
    }

    /// Reports whether `err` is worth retrying.
    ///
    /// These errors are retryable:
    /// - HTTP errors that are timeouts or connection failures.
    /// - HTTP errors carrying a 5xx status.
    /// - Any I/O error.
    ///
    /// Everything else is treated as permanent. That covers HTTP errors with no status or
    /// a non-5xx status, API errors, JSON errors and failed tests.
    pub fn is_retryable(err: &BbmError) -> bool {
        match err {
            BbmError::Http(e) => {
                if e.is_timeout() || e.is_connect() {
                    return true;
                }
                if let Some(status) = e.status() {
                    return status.is_server_error();
                }
                false
            }
            BbmError::Io(_) => true,
            BbmError::Api(_) | BbmError::Json(_) | BbmError::TestFailed(_) => false,
        }
    }
}
impl Default for RetryPolicy {
fn default() -> Self {
Self::new(DEFAULT_MAX_ATTEMPTS)
}
}
/// Hooks `RetryPolicy` into `tower`'s retry middleware.
///
/// Failed calls whose error is transient (per `RetryPolicy::is_retryable`) are retried
/// after sleeping for `RetryPolicy::backoff_duration`. Requests must be `Clone` so that
/// `clone_request` can always hand back a copy to replay.
impl<Req, Res> Policy<Req, Res, BbmError> for RetryPolicy
where
    Req: Clone,
{
    type Future = Pin<Box<dyn Future<Output = ()> + Send>>;

    /// Decides whether the outcome in `result` should be retried.
    ///
    /// Returns `None` in three cases:
    /// - the call succeeded;
    /// - the error is not retryable;
    /// - `current_attempt` has already reached `max_attempts`.
    ///
    /// Otherwise it bumps `current_attempt`, logs a warning, and returns a future that
    /// sleeps for the backoff delay.
    fn retry(
        &mut self,
        _req: &mut Req,
        result: &mut Result<Res, BbmError>,
    ) -> Option<Self::Future> {
        let err = match result {
            Ok(_) => return None,
            Err(err) => err,
        };

        // Stop on permanent errors, or once the retry budget is spent.
        if !Self::is_retryable(err) || self.current_attempt >= self.max_attempts {
            return None;
        }

        self.current_attempt += 1;
        let delay = self.backoff_duration();
        log::warn!(
            "Retryable error (attempt {}/{}), retrying after {:?}: {}",
            self.current_attempt,
            self.max_attempts,
            delay,
            err,
        );

        // The sleep is created lazily on first poll of the returned future.
        let wait = async move {
            tokio::time::sleep(delay).await;
        };
        Some(Box::pin(wait))
    }

    /// Always keeps a copy of the request so it can be replayed on retry.
    fn clone_request(&mut self, req: &Req) -> Option<Req> {
        Some(Req::clone(req))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn retryable_errors() {
        let json_err: BbmError = serde_json::from_str::<serde_json::Value>("not json")
            .unwrap_err()
            .into();
        let invalid_url_err = reqwest::Client::new()
            .get("not a url")
            .build()
            .unwrap_err();

        // (error, expected `is_retryable` verdict)
        let cases: Vec<(BbmError, bool)> = vec![
            (
                BbmError::Io(std::io::Error::new(
                    std::io::ErrorKind::ConnectionReset,
                    "connection reset",
                )),
                true,
            ),
            (BbmError::Api("bad request".into()), false),
            (BbmError::TestFailed("test failed".into()), false),
            (json_err, false),
            // A request-building error is not a timeout, not a connect error and has no
            // status, so it must not be retried.
            (BbmError::Http(invalid_url_err), false),
        ];

        for (idx, (err, expected)) in cases.iter().enumerate() {
            assert_eq!(RetryPolicy::is_retryable(err), *expected, "case #{}", idx);
        }
    }

    #[test]
    fn backoff_durations() {
        // (current_attempt, expected delay in milliseconds)
        let expectations = [(0u32, 200u64), (1, 200), (2, 400), (3, 800)];
        for &(attempt, millis) in expectations.iter() {
            let policy = RetryPolicy {
                max_attempts: 3,
                current_attempt: attempt,
            };
            assert_eq!(
                policy.backoff_duration(),
                Duration::from_millis(millis),
                "attempt {}",
                attempt
            );
        }
    }
}