use reqwest::{blocking::Response, Result};
use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
use std::thread::sleep;
use std::time::Duration;
/// Decides which requests issued through a `Retrier` are eligible for retries.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryStrategy {
    /// The very first request made through a given `Retrier` is sent exactly
    /// once with no retries; every later request is retried.
    Automatic,
    /// Every request, including the first one, is retried.
    Always,
}
/// Retry policy used by a `Retrier`.
#[derive(Debug, Clone, PartialEq)]
pub struct RetryConfig {
    /// Which requests may be retried at all; see [`RetryStrategy`].
    pub strategy: RetryStrategy,
    /// Number of retries allowed after the initial attempt, so a request is
    /// sent at most `max_retry_count + 1` times.
    pub max_retry_count: u8,
    /// Delay before the first retry.
    pub base_wait: Duration,
    /// Per-retry multiplier of the delay: retry `i` (counting from 0) waits
    /// `base_wait * backoff_factor^i`.
    pub backoff_factor: f64,
}
/// Runs request-sending closures under the retry policy of a [`RetryConfig`].
#[derive(Debug)]
pub(crate) struct Retrier {
    /// Policy applied to every call of `with_retries`.
    config: RetryConfig,
    /// `true` until the first call of `with_retries` clears it; its value
    /// only changes behavior under `RetryStrategy::Automatic`.
    is_first_request: AtomicBool,
}
impl Retrier {
pub fn new(config: RetryConfig) -> Self {
Self {
config,
is_first_request: AtomicBool::new(true),
}
}
pub fn with_retries(&self, send_request: impl Fn() -> Result<Response>) -> Result<Response> {
if self.is_first_request.swap(false, SeqCst)
&& self.config.strategy == RetryStrategy::Automatic
{
return send_request();
}
for i_retry in 0..self.config.max_retry_count {
macro_rules! warn_and_sleep {
($src:expr) => {{
let wait_factor = self.config.backoff_factor.powi(i_retry.into());
let duration = self.config.base_wait.mul_f64(wait_factor);
log::warn!("{} - retrying after {:?}.", $src, duration);
sleep(duration)
}};
};
match send_request() {
Ok(response) if response.status().is_server_error() => {
warn_and_sleep!(format!("{} for {}", response.status(), response.url()))
}
Err(error) if error.is_timeout() => warn_and_sleep!(error),
result => return result,
}
}
send_request()
}
}
#[cfg(test)]
mod tests {
    use super::{Retrier, RetryConfig, RetryStrategy};
    use mockito::{mock, server_address};
    use reqwest::blocking::{get, Client};
    use std::thread::sleep;
    use std::time::Duration;

    /// URL of the mock server's root path, which every mock below matches.
    fn root_url() -> String {
        format!("http://{}", server_address())
    }

    /// Builds a retrier with zero wait between retries so tests run fast.
    fn instant_retrier(strategy: RetryStrategy, max_retry_count: u8) -> Retrier {
        Retrier::new(RetryConfig {
            strategy,
            max_retry_count,
            base_wait: Duration::from_secs(0),
            backoff_factor: 0.0,
        })
    }

    /// For each `max_retry_count` in `0..10`, checks that a persistent 500 is
    /// requested exactly `max_retry_count + 1` times and then returned.
    ///
    /// Must be called after the retrier's first request has been made, so
    /// that `Automatic` retriers also retry.
    fn assert_retries_server_errors(handler: &mut Retrier) {
        for i_retry in 0..10 {
            let err = mock("GET", "/")
                .with_status(500)
                .expect((i_retry + 1).into())
                .create();
            handler.config.max_retry_count = i_retry;
            assert_eq!(
                handler.with_retries(|| get(&root_url())).unwrap().status(),
                500
            );
            err.assert();
        }
    }

    #[test]
    fn test_always_retry() {
        let mut handler = instant_retrier(RetryStrategy::Always, 5);
        let ok = mock("GET", "/").expect(1).create();
        assert_eq!(
            handler.with_retries(|| get(&root_url())).unwrap().status(),
            200
        );
        ok.assert();
        assert_retries_server_errors(&mut handler);
    }

    #[test]
    fn test_automatic_retry() {
        let mut handler = instant_retrier(RetryStrategy::Automatic, 5);
        // The first request through an `Automatic` retrier is sent only once,
        // even though it fails with a server error.
        let err = mock("GET", "/").with_status(500).expect(1).create();
        assert_eq!(
            handler.with_retries(|| get(&root_url())).unwrap().status(),
            500
        );
        err.assert();
        let ok = mock("GET", "/").expect(1).create();
        assert_eq!(
            handler.with_retries(|| get(&root_url())).unwrap().status(),
            200
        );
        ok.assert();
        assert_retries_server_errors(&mut handler);
    }

    #[test]
    fn test_timeout_retry() {
        let handler = instant_retrier(RetryStrategy::Always, 1);
        // The body is delayed past the 0.1 s client timeout, so every attempt
        // times out; with one retry the mock is hit twice.
        let timeout = mock("GET", "/")
            .with_body_from_fn(|_| {
                sleep(Duration::from_secs_f64(0.2));
                Ok(())
            })
            .expect(2)
            .create();
        let client = Client::new();
        let error = handler
            .with_retries(|| {
                client
                    .get(&root_url())
                    .timeout(Duration::from_secs_f64(0.1))
                    .send()
                    .and_then(|response| {
                        // Reading the body is what times out; reaching the
                        // next line would mean the mock did not delay it.
                        let _ = response.text()?;
                        unreachable!()
                    })
            })
            .unwrap_err();
        assert!(error.is_timeout());
        timeout.assert();
    }
}