use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use reqwest::StatusCode;
use crate::query::RetryPolicy;
/// The only response status that makes [`execute_retrying`] retry; every other
/// status is handed back to the caller unchanged.
const HTTP_TOO_MANY_REQUESTS: StatusCode = StatusCode::TOO_MANY_REQUESTS;
/// Sends `req`, retrying while the server answers `429 Too Many Requests`.
///
/// After the first attempt, up to `retry.max_retries` further attempts are made.
/// Between attempts the task sleeps for the delay computed by [`backoff_delay`],
/// which honours a numeric `Retry-After` header when one is present. Retrying
/// stops early, returning the latest 429 response, if the next sleep would push
/// the total elapsed time past `retry.deadline`.
///
/// A request whose body cannot be cloned (`try_clone` returns `None`, e.g. a
/// streaming body) is sent exactly once, without retries.
///
/// # Errors
///
/// Transport-level errors from [`reqwest::Client::execute`] are returned
/// immediately and are not retried.
pub(crate) async fn execute_retrying(
    client: &reqwest::Client,
    req: reqwest::Request,
    retry: &RetryPolicy,
) -> reqwest::Result<reqwest::Response> {
    let start = Instant::now();
    // Number of retries performed so far (0 on the initial attempt).
    let mut attempt: u32 = 0;
    loop {
        let Some(clone) = req.try_clone() else {
            // Cannot retry without a fresh copy of the body: send the original once.
            return client.execute(req).await;
        };
        let resp = client.execute(clone).await?;
        if resp.status() != HTTP_TOO_MANY_REQUESTS || attempt >= retry.max_retries {
            return Ok(resp);
        }
        attempt += 1;
        let delay = backoff_delay(retry, attempt, parse_retry_after(&resp));
        // Saturating: a huge Retry-After must trip the deadline check, not panic
        // with "overflow when adding durations".
        if start.elapsed().saturating_add(delay) > retry.deadline {
            return Ok(resp);
        }
        tokio::time::sleep(delay).await;
    }
}
/// Computes how long to wait before retry number `attempt` (1-based).
///
/// * If the server supplied a `Retry-After` duration `ra`, the delay is `ra`
///   plus `ra × retry.jitter × f`, where `f` is a fraction in `[0, 1)`. This
///   path is not capped by `retry.max_backoff`.
/// * Otherwise the delay is `retry.base_backoff × 2^(attempt − 1)`, scaled by
///   `1 + retry.jitter × f` and capped at `retry.max_backoff`.
///
/// All `Duration` arithmetic saturates instead of panicking:
/// * Products too large for a `Duration` become `Duration::MAX`. In the
///   exponential path this is then capped to `retry.max_backoff`.
/// * Negative or NaN products become zero.
pub(crate) fn backoff_delay(
    retry: &RetryPolicy,
    attempt: u32,
    retry_after: Option<Duration>,
) -> Duration {
    if let Some(ra) = retry_after {
        let extra = mul_f64_saturating(ra, retry.jitter * jitter_fraction());
        return ra.saturating_add(extra);
    }
    // `attempt` is 1-based, so the first retry waits `base_backoff`. Exponents
    // beyond i32::MAX can't be passed to `powi`; they overflow to +inf anyway.
    let exponent = attempt.saturating_sub(1).min(i32::MAX as u32) as i32;
    let base = mul_f64_saturating(retry.base_backoff, 2f64.powi(exponent));
    let with_jitter = mul_f64_saturating(base, 1.0 + retry.jitter * jitter_fraction());
    with_jitter.min(retry.max_backoff)
}

/// Non-panicking `Duration::mul_f64`. Results too large for a `Duration` saturate
/// to `Duration::MAX`; negative or NaN results become `Duration::ZERO`.
fn mul_f64_saturating(d: Duration, factor: f64) -> Duration {
    let secs = d.as_secs_f64() * factor;
    match Duration::try_from_secs_f64(secs) {
        Ok(scaled) => scaled,
        Err(_) if secs > 0.0 => Duration::MAX,
        Err(_) => Duration::ZERO,
    }
}
/// Returns a pseudo-random value in `[0, 1)` used to spread out retry delays.
///
/// This is not cryptographically secure; it only needs to decorrelate retries.
/// The value comes from hashing the current time with a freshly constructed
/// [`RandomState`], whose keys std documents as randomly initialised. The clock
/// alone is a poor entropy source: on coarse (e.g. microsecond-resolution)
/// clocks the low nanosecond digits never change.
fn jitter_fraction() -> f64 {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hasher};

    let mut hasher = RandomState::new().build_hasher();
    if let Ok(since_epoch) = SystemTime::now().duration_since(UNIX_EPOCH) {
        hasher.write_u128(since_epoch.as_nanos());
    }
    // Keep the top 53 bits so the quotient is exact in f64 and strictly below 1.0.
    (hasher.finish() >> 11) as f64 / (1u64 << 53) as f64
}
pub(crate) fn parse_retry_after(resp: &reqwest::Response) -> Option<Duration> {
resp.headers()
.get(reqwest::header::RETRY_AFTER)
.and_then(|v| v.to_str().ok())
.and_then(retry_after_secs)
}
/// Parses a `Retry-After` value expressed as a (possibly fractional) number of
/// seconds, ignoring surrounding whitespace.
///
/// Any syntax accepted by `f64::from_str` is allowed (e.g. `"2"`, `"1.5"`,
/// `"1e3"`). Values that are not finite, are negative, or are too large for a
/// `Duration` yield `None`, as does any non-numeric text.
pub(crate) fn retry_after_secs(value: &str) -> Option<Duration> {
    value
        .trim()
        .parse::<f64>()
        .ok()
        .and_then(|seconds| Duration::try_from_secs_f64(seconds).ok())
}
// Tests for the 429 retry loop. The async tests drive `execute_retrying` against a
// local wiremock `MockServer`; the last two are plain unit tests of the
// `Retry-After` parser and the backoff computation.
#[cfg(test)]
mod tests {
use super::*;
use serde_json::json;
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
/// Policy with millisecond-scale backoff, a generous deadline, and no jitter,
/// so retry tests finish quickly and their delays are exact.
fn fast_retry(max_retries: u32) -> RetryPolicy {
RetryPolicy {
max_retries,
base_backoff: Duration::from_millis(1),
max_backoff: Duration::from_millis(5),
deadline: Duration::from_secs(30),
jitter: 0.0,
}
}
/// Builds a POST to `url` with a small JSON body. The tests below rely on this
/// request being retryable (they observe multiple requests per call).
fn post_req(client: &reqwest::Client, url: &str) -> reqwest::Request {
client
.post(url)
.json(&json!({"k": "v"}))
.build()
.expect("request should build")
}
// Two 429s carrying `Retry-After: 0` (so no real waiting), then a 200: the
// caller should see the 200 after exactly 3 requests.
#[tokio::test]
async fn retries_two_429s_then_succeeds() {
let server = MockServer::start().await;
// Serves at most two requests; later requests are expected to reach the 200 mock below.
Mock::given(method("POST"))
.and(path("/thing"))
.respond_with(ResponseTemplate::new(429).insert_header("Retry-After", "0"))
.up_to_n_times(2)
.mount(&server)
.await;
Mock::given(method("POST"))
.and(path("/thing"))
.respond_with(ResponseTemplate::new(200).set_body_json(json!({"ok": true})))
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/thing", server.uri());
let resp = execute_retrying(&client, post_req(&client, &url), &fast_retry(5))
.await
.expect("should succeed after retries");
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(server.received_requests().await.unwrap().len(), 3);
}
// Always 429 without `Retry-After`: 1 initial attempt + 2 retries = 3 requests,
// after which the last 429 is returned as `Ok`, not as an error.
#[tokio::test]
async fn exhausts_after_max_retries() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/thing"))
.respond_with(ResponseTemplate::new(429))
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/thing", server.uri());
let resp = execute_retrying(&client, post_req(&client, &url), &fast_retry(2))
.await
.expect("should return the final 429, not a transport error");
assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
assert_eq!(server.received_requests().await.unwrap().len(), 3);
}
// `Retry-After: 100` (seconds) already exceeds the 10 ms deadline, so the first
// 429 is returned without sleeping or retrying, despite max_retries = 10.
#[tokio::test]
async fn deadline_stops_retries_before_max_retries() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/thing"))
.respond_with(ResponseTemplate::new(429).insert_header("Retry-After", "100"))
.mount(&server)
.await;
let retry = RetryPolicy {
max_retries: 10,
base_backoff: Duration::from_millis(1),
max_backoff: Duration::from_secs(1),
deadline: Duration::from_millis(10),
jitter: 0.0,
};
let client = reqwest::Client::new();
let url = format!("{}/thing", server.uri());
let resp = execute_retrying(&client, post_req(&client, &url), &retry)
.await
.expect("should return the 429 after the deadline stops retries");
assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS);
assert_eq!(server.received_requests().await.unwrap().len(), 1);
}
// Only 429 triggers retries: any other status is returned after one request.
#[tokio::test]
async fn non_429_is_returned_without_retry() {
let server = MockServer::start().await;
Mock::given(method("POST"))
.and(path("/thing"))
.respond_with(ResponseTemplate::new(400))
.mount(&server)
.await;
let client = reqwest::Client::new();
let url = format!("{}/thing", server.uri());
let resp = execute_retrying(&client, post_req(&client, &url), &fast_retry(5))
.await
.expect("should return the 400 without retrying");
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
assert_eq!(server.received_requests().await.unwrap().len(), 1);
}
// Numeric values (including fractions and surrounding whitespace) parse.
// Non-finite, negative, out-of-range (1e30 s), and non-numeric values are rejected.
#[test]
fn retry_after_secs_parses_and_rejects_malformed() {
assert_eq!(retry_after_secs("2"), Some(Duration::from_secs(2)));
assert_eq!(retry_after_secs(" 1.5 "), Some(Duration::from_secs_f64(1.5)));
assert_eq!(retry_after_secs("0"), Some(Duration::ZERO));
assert_eq!(retry_after_secs("inf"), None);
assert_eq!(retry_after_secs("nan"), None);
assert_eq!(retry_after_secs("1e30"), None);
assert_eq!(retry_after_secs("-5"), None);
assert_eq!(retry_after_secs("abc"), None);
assert_eq!(retry_after_secs(""), None);
}
// With jitter disabled, a Retry-After value is used verbatim. Without one, the
// delay starts at base_backoff and doubles per attempt (well below max_backoff here).
#[test]
fn backoff_honors_retry_after_and_is_exponential() {
let retry = RetryPolicy {
base_backoff: Duration::from_secs(1),
max_backoff: Duration::from_secs(100),
jitter: 0.0,
..RetryPolicy::default()
};
assert_eq!(
backoff_delay(&retry, 1, Some(Duration::from_secs(7))),
Duration::from_secs(7)
);
assert_eq!(backoff_delay(&retry, 1, None), Duration::from_secs(1));
assert_eq!(backoff_delay(&retry, 2, None), Duration::from_secs(2));
assert_eq!(backoff_delay(&retry, 3, None), Duration::from_secs(4));
}
}