use std::panic::RefUnwindSafe;
use std::time::Duration;
use futures::StreamExt;
use pact_models::interaction::Interaction;
use reqwest::RequestBuilder;
use tokio::time::sleep;
use tracing::{trace, warn};
/// Returns `true` when a response carrying `status` is worth sending again.
///
/// Every 5xx status qualifies, as do the two client errors that signal a
/// transient condition: `429 Too Many Requests` and `408 Request Timeout`.
fn is_retryable(status: reqwest::StatusCode) -> bool {
    const TRANSIENT_CLIENT_ERRORS: [reqwest::StatusCode; 2] = [
        reqwest::StatusCode::TOO_MANY_REQUESTS,
        reqwest::StatusCode::REQUEST_TIMEOUT,
    ];
    if status.is_server_error() {
        return true;
    }
    TRANSIENT_CLIENT_ERRORS.contains(&status)
}
/// Computes how long to wait before the next retry.
///
/// * `429 Too Many Requests` with a `Retry-After` value: wait that value, rounded
///   *up* to whole seconds, plus 20% headroom capped at 60 extra seconds. Rounding
///   up matters because an HTTP-date `Retry-After` yields a fractional duration,
///   and truncating it would retry before the server's deadline.
/// * Anything else: exponential backoff of `500ms * 2^(attempt - 1)`, i.e. 500ms,
///   1s, 2s, … An `attempt` of 0 is treated like 1.
///
/// All arithmetic saturates, so extreme `attempt` or `Retry-After` values give a
/// very long delay. They never overflow (a panic in debug builds, a wrapped-around
/// short delay in release builds).
pub(crate) fn compute_retry_delay(
    status: reqwest::StatusCode,
    retry_after: Option<std::time::Duration>,
    attempt: u32,
) -> Duration {
    if status == reqwest::StatusCode::TOO_MANY_REQUESTS {
        if let Some(base) = retry_after {
            // Round a fractional delay up so we never retry early.
            let secs = base
                .as_secs()
                .saturating_add(u64::from(base.subsec_nanos() > 0));
            // 20% headroom, but never more than one extra minute.
            let extra = (secs / 5).min(60);
            return Duration::from_secs(secs.saturating_add(extra));
        }
    }
    let multiplier = 2_u64
        .checked_pow(attempt.saturating_sub(1))
        .unwrap_or(u64::MAX);
    Duration::from_millis(multiplier.saturating_mul(500))
}
/// Reads the `Retry-After` header of `response` as a delay.
///
/// Accepts either a non-negative integer number of seconds (surrounding
/// whitespace ignored) or an HTTP-date. A date already in the past maps to a zero
/// delay. Returns `None` when the header is absent, is not valid visible ASCII,
/// or matches neither form.
fn parse_retry_after(response: &reqwest::Response) -> Option<std::time::Duration> {
    let raw = response
        .headers()
        .get(reqwest::header::RETRY_AFTER)
        .and_then(|value| value.to_str().ok())?;

    match raw.trim().parse::<u64>() {
        Ok(seconds) => Some(std::time::Duration::from_secs(seconds)),
        Err(_) => httpdate::parse_http_date(raw).ok().map(|when| {
            when.duration_since(std::time::SystemTime::now())
                .unwrap_or_default()
        }),
    }
}
pub(crate) async fn with_retries(retries: u8, request: RequestBuilder) -> Result<reqwest::Response, reqwest::Error> {
match &request.try_clone() {
None => {
warn!("with_retries: Could not retry the request as it is not cloneable");
request.send().await
}
Some(_) => {
if retries == 0 {
return request.send().await;
}
futures::stream::iter((1..=retries).step_by(1))
.fold((None::<Result<reqwest::Response, reqwest::Error>>, request.try_clone()), |(response, request), attempt| {
async move {
match request {
Some(request_builder) => match response {
None => {
let next = request_builder.try_clone();
(Some(request_builder.send().await), next)
},
Some(response) => {
trace!("with_retries: attempt {}/{} is {:?}", attempt, retries, response);
match response {
Ok(ref res) => if is_retryable(res.status()) {
match request_builder.try_clone() {
None => (Some(response), None),
Some(rb) => {
let delay = compute_retry_delay(res.status(), parse_retry_after(res), attempt as u32);
sleep(delay).await;
(Some(request_builder.send().await), Some(rb))
}
}
} else {
(Some(response), None)
},
Err(ref err) => if err.is_status() {
if err.status().map(is_retryable).unwrap_or(false) {
match request_builder.try_clone() {
None => (Some(response), None),
Some(rb) => {
sleep(Duration::from_millis(10_u64.pow(attempt as u32))).await;
(Some(request_builder.send().await), Some(rb))
}
}
} else {
(Some(response), None)
}
} else {
(Some(response), None)
}
}
}
}
None => (response, None)
}
}
}).await.0.unwrap()
}
}
}
/// Converts `interaction` into an owned V4 interaction boxed as
/// `Send + Sync + RefUnwindSafe`.
///
/// The conversions are tried in this order: V4 synchronous message, V4
/// asynchronous message, V4 HTTP.
///
/// # Panics
///
/// Panics if none of the three conversions succeeds.
pub(crate) fn as_safe_ref(interaction: &dyn Interaction) -> Box<dyn Interaction + Send + Sync + RefUnwindSafe> {
    if let Some(message) = interaction.as_v4_sync_message() {
        return Box::new(message);
    }
    if let Some(message) = interaction.as_v4_async_message() {
        return Box::new(message);
    }
    let http = interaction
        .as_v4_http()
        .expect("as_safe_ref: interaction is not a V4 sync message, async message, or HTTP interaction");
    Box::new(http)
}
#[cfg(test)]
mod tests {
    // Unit tests for the retry helpers, plus end-to-end tests of `with_retries`
    // against a throw-away local axum server.
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::time::{Duration, Instant};
    use axum::response::IntoResponse;
    use axum::routing::get;
    use axum::Router;
    use tokio::net::TcpListener;
    use super::{compute_retry_delay, is_retryable, with_retries};

    #[test]
    fn compute_retry_delay_429_without_retry_after_uses_exponential_backoff() {
        let delay = compute_retry_delay(reqwest::StatusCode::TOO_MANY_REQUESTS, None, 2);
        assert_eq!(delay, Duration::from_millis(1000));
    }

    #[test]
    fn compute_retry_delay_starts_at_500ms_on_first_attempt() {
        let delay = compute_retry_delay(reqwest::StatusCode::INTERNAL_SERVER_ERROR, None, 1);
        assert_eq!(delay, Duration::from_millis(500));
    }

    #[test]
    fn compute_retry_delay_429_with_retry_after_10_adds_20_percent() {
        let delay = compute_retry_delay(reqwest::StatusCode::TOO_MANY_REQUESTS, Some(Duration::from_secs(10)), 1);
        assert_eq!(delay, Duration::from_secs(12));
    }

    #[test]
    fn compute_retry_delay_429_with_retry_after_400_caps_extra_at_60() {
        let delay = compute_retry_delay(reqwest::StatusCode::TOO_MANY_REQUESTS, Some(Duration::from_secs(400)), 1);
        assert_eq!(delay, Duration::from_secs(460));
    }

    #[test]
    fn compute_retry_delay_429_with_retry_after_300_boundary_case() {
        let delay = compute_retry_delay(reqwest::StatusCode::TOO_MANY_REQUESTS, Some(Duration::from_secs(300)), 1);
        assert_eq!(delay, Duration::from_secs(360));
    }

    #[test]
    fn compute_retry_delay_5xx_ignores_retry_after_uses_exponential_backoff() {
        let delay = compute_retry_delay(reqwest::StatusCode::INTERNAL_SERVER_ERROR, Some(Duration::from_secs(99999)), 3);
        assert_eq!(delay, Duration::from_millis(2000));
    }

    #[test]
    fn is_retryable_returns_true_for_500() {
        assert!(is_retryable(reqwest::StatusCode::INTERNAL_SERVER_ERROR));
    }

    #[test]
    fn is_retryable_returns_true_for_429() {
        assert!(is_retryable(reqwest::StatusCode::TOO_MANY_REQUESTS));
    }

    #[test]
    fn is_retryable_returns_true_for_408() {
        assert!(is_retryable(reqwest::StatusCode::REQUEST_TIMEOUT));
    }

    #[test]
    fn is_retryable_returns_false_for_404() {
        assert!(!is_retryable(reqwest::StatusCode::NOT_FOUND));
    }

    #[test]
    fn is_retryable_returns_false_for_200() {
        assert!(!is_retryable(reqwest::StatusCode::OK));
    }

    /// Starts an axum server on an ephemeral `127.0.0.1` port. Its `GET /` handler
    /// calls `handler` with the zero-based index of the incoming request.
    /// Returns the server's base URL and the shared request counter.
    async fn spawn_server<F>(handler: F) -> (String, Arc<AtomicUsize>)
    where
        F: Fn(usize) -> axum::response::Response + Send + Sync + 'static,
    {
        let counter = Arc::new(AtomicUsize::new(0));
        let counter_clone = counter.clone();
        let handler = Arc::new(handler);
        type HandlerState = (Arc<AtomicUsize>, Arc<dyn Fn(usize) -> axum::response::Response + Send + Sync>);
        let app = Router::new()
            .route(
                "/",
                get(|axum::extract::State(state): axum::extract::State<HandlerState>| async move {
                    let (ctr, h) = state;
                    let n = ctr.fetch_add(1, Ordering::SeqCst);
                    h(n)
                }),
            )
            .with_state((counter_clone, handler as Arc<dyn Fn(usize) -> axum::response::Response + Send + Sync>));
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            axum::serve(listener, app).await.unwrap();
        });
        (format!("http://{}/", addr), counter)
    }

    #[tokio::test]
    async fn with_retries_retries_429_responses() {
        let (url, call_count) = spawn_server(|n| {
            if n < 2 {
                axum::http::StatusCode::TOO_MANY_REQUESTS.into_response()
            } else {
                axum::http::StatusCode::OK.into_response()
            }
        })
        .await;
        let client = reqwest::Client::new();
        let request = client.get(&url);
        let result = with_retries(3, request).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), 200);
        assert_eq!(call_count.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn with_retries_does_not_retry_404_responses() {
        let (url, call_count) = spawn_server(|_| axum::http::StatusCode::NOT_FOUND.into_response()).await;
        let client = reqwest::Client::new();
        let request = client.get(&url);
        let result = with_retries(3, request).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), 404);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn with_retries_zero_retries_sends_exactly_once_without_panic() {
        let (url, call_count) =
            spawn_server(|_| axum::http::StatusCode::TOO_MANY_REQUESTS.into_response()).await;
        let client = reqwest::Client::new();
        let request = client.get(&url);
        let result = with_retries(0, request).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), 429);
        assert_eq!(call_count.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn http_date_retry_after_is_parsed() {
        let (url, call_count) = spawn_server(|n| {
            if n < 1 {
                (
                    axum::http::StatusCode::TOO_MANY_REQUESTS,
                    [(axum::http::header::RETRY_AFTER, "Thu, 01 Jan 1970 00:00:00 GMT")],
                )
                    .into_response()
            } else {
                axum::http::StatusCode::OK.into_response()
            }
        })
        .await;
        let client = reqwest::Client::new();
        let request = client.get(&url);
        let started = Instant::now();
        let result = with_retries(3, request).await;
        let elapsed = started.elapsed();
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status(), 200);
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
        // A Retry-After date in the past must be honoured as a zero delay. If the
        // header were not parsed, the exponential fallback (at least 500ms) would
        // apply instead, so a fast retry proves the HTTP-date path was taken.
        assert!(
            elapsed < Duration::from_millis(400),
            "retry waited {elapsed:?}; the HTTP-date Retry-After was not honoured"
        );
    }
}