use crate::error::Error;
use reqwest::header::HeaderMap;
use reqwest::{RequestBuilder, Response, StatusCode};
use std::future::Future;
use std::time::Duration;
pub mod error;
pub use reqwest;
#[cfg(feature = "convenience")]
pub mod convenience;
/// Successful outcome of [`execute`].
///
/// It is built by the caller-supplied `check_done` callback. The value, status code and
/// headers are whatever that callback puts here, typically taken from the HTTP response it
/// inspected.
#[derive(Debug)]
pub struct RetryResult<T> {
    /// Value produced by `check_done` for the accepted attempt.
    pub result: T,
    /// HTTP status code reported together with `result`.
    pub status_code: StatusCode,
    /// HTTP headers reported together with `result`.
    pub headers: HeaderMap,
}
/// Decision returned (as `Err`) by the `check_done` callback of [`execute`] when an
/// attempt should not be accepted as the final result.
///
/// Derives the common traits so callers can log, compare and freely copy a decision.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RetryType {
    /// Give up immediately; `execute` returns `Error::Stop`.
    Stop,
    /// Try again after the exponential backoff derived from `execute`'s
    /// `retry_duration` plus the value returned by `get_jitter`.
    Retry,
    /// Try again after exactly the given delay; no backoff or jitter is applied.
    RetryAfter(Duration),
}
pub async fn execute<T, F, G, H, I, FutG, FutI>(
make_builder: F,
check_done: G,
try_count: u8,
retry_duration: Duration,
get_jitter: H,
sleeper: I,
) -> Result<RetryResult<T>, Error>
where
F: Fn(u8) -> RequestBuilder,
G: Fn(Result<Response, reqwest::Error>) -> FutG,
H: Fn() -> Duration,
I: Fn(Duration) -> FutI,
FutG: Future<Output = Result<RetryResult<T>, RetryType>> + Send + 'static,
FutI: Future<Output = ()> + Send + 'static,
{
if try_count == 0 {
return Err(Error::NoTry);
}
for i in 0..try_count {
let builder = make_builder(i);
let response = builder.send().await;
let next_retry_duration = match check_done(response).await {
Ok(result) => return Ok(result),
Err(retry_type) => {
match retry_type {
RetryType::Retry => {
calc_retry_duration(retry_duration, get_jitter(), i as u32)
}
RetryType::RetryAfter(target_duration) => {
target_duration
}
RetryType::Stop => {
return Err(Error::Stop);
}
}
}
};
if i < try_count - 1 && next_retry_duration > Duration::ZERO {
sleeper(next_retry_duration).await;
}
}
Err(Error::TryOver)
}
/// Computes the exponential-backoff delay before the next retry.
///
/// Returns `retry_duration * 2^try_count + jitter_duration`. Every step saturates:
/// * the multiplier is capped at `u32::MAX`;
/// * the resulting `Duration` is capped at `Duration::MAX`.
///
/// A large `try_count`, `retry_duration` or jitter therefore yields a very long delay. It
/// never panics on overflow or wraps around to a short delay.
fn calc_retry_duration(
    retry_duration: Duration,
    jitter_duration: Duration,
    try_count: u32,
) -> Duration {
    // 2^try_count overflows u32 for try_count >= 32. Saturating keeps the backoff
    // monotonically non-decreasing instead of truncating to zero.
    let multiplier = 2u32.saturating_pow(try_count);
    retry_duration
        .saturating_mul(multiplier)
        .saturating_add(jitter_duration)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a GET request to httpbin tagged with the attempt index.
    ///
    /// NOTE: `execute` always sends the request, so every test here needs network access
    /// to httpbin.org.
    fn make_builder_for_test(i: u8) -> RequestBuilder {
        reqwest::Client::new()
            .get("https://httpbin.org/get")
            .header("Try-Count", i.to_string())
    }

    #[tokio::test]
    async fn test_stop() {
        let result = execute(
            make_builder_for_test,
            |_| async move { Err::<RetryResult<serde_json::Value>, RetryType>(RetryType::Stop) },
            3,
            Duration::from_secs(1),
            || Duration::from_millis(100),
            |_| async move {},
        )
        .await;
        assert!(
            matches!(result, Err(Error::Stop)),
            "Test failed: Expected Stop error, got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_over_try() {
        let result = execute(
            make_builder_for_test,
            |_| async move { Err::<RetryResult<serde_json::Value>, RetryType>(RetryType::Retry) },
            4,
            Duration::from_secs(2),
            || Duration::from_millis(100),
            |duration| async move { println!("Sleeping for {:?}", duration) },
        )
        .await;
        assert!(
            matches!(result, Err(Error::TryOver)),
            "Test failed: Expected TryOver error, got {:?}",
            result
        );
    }

    #[tokio::test]
    async fn test_success() {
        let check_done = |response: Result<Response, _>| async move {
            let Ok(response) = response else {
                return Err(RetryType::Retry);
            };
            let status = response.status();
            if !status.is_success() {
                return Err(RetryType::Retry);
            }
            let headers = response.headers().clone();
            let Ok(json) = response.json::<serde_json::Value>().await else {
                return Err(RetryType::Retry);
            };
            Ok(RetryResult {
                result: json,
                status_code: status,
                headers,
            })
        };
        let res = match execute(
            make_builder_for_test,
            check_done,
            3,
            Duration::from_secs(1),
            || Duration::from_millis(100),
            |_| async move {},
        )
        .await
        {
            Ok(res) => res,
            Err(e) => panic!("Test failed: {:?}", e),
        };
        assert_eq!(res.status_code, StatusCode::OK);
    }
}