use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
use bytes::Bytes;
use http::{Request, Response, StatusCode};
use http_body::Body as _;
use http_body::Frame;
use http_body_util::BodyExt;
use tower::Service;
use crate::http::{Body, BoxError, HttpService, full_body};
/// Default upper bound for the exponential backoff delay; used as the initial
/// `max_backoff` of every `Retry` constructor.
const DEFAULT_MAX_BACKOFF: Duration = Duration::from_secs(30);
/// Default length of the window over which the retry budget counts requests and retries.
const DEFAULT_BUDGET_WINDOW: Duration = Duration::from_secs(10);
/// Default number of retries per window that the budget always admits, regardless of ratio.
const DEFAULT_BUDGET_MIN_RETRIES: u32 = 30;
/// Counters backing the retry budget.
///
/// Requests and retries are tallied inside a fixed-length window; once the
/// window has elapsed both counters start over from zero.
struct BudgetState {
    /// Highest tolerated share of retries among all attempts (requests + retries).
    ratio: f64,
    /// Retries that are always admitted per window, whatever the ratio says.
    min_retries: u32,
    /// Length of one accounting window.
    window: Duration,
    /// Original requests seen in the current window.
    requests: u64,
    /// Retries admitted in the current window.
    retries: u64,
    /// Start of the current window.
    window_start: Instant,
}

impl BudgetState {
    /// Creates an empty budget whose first window starts now.
    fn new(ratio: f64, min_retries: u32, window: Duration) -> Self {
        let window_start = Instant::now();
        Self {
            ratio,
            min_retries,
            window,
            requests: 0,
            retries: 0,
            window_start,
        }
    }

    /// Zeroes both counters and opens a new window if the current one has run out.
    fn maybe_reset_window(&mut self) {
        if self.window_start.elapsed() < self.window {
            return;
        }
        self.requests = 0;
        self.retries = 0;
        self.window_start = Instant::now();
    }

    /// Counts one original (non-retry) request.
    fn record_request(&mut self) {
        self.maybe_reset_window();
        self.requests += 1;
    }

    /// Reports whether one more retry fits the budget.
    ///
    /// A retry is admitted while fewer than `min_retries` were made in this
    /// window, when nothing has been counted yet, or while the share of
    /// retries among all attempts is strictly below `ratio`.
    fn allows_retry(&mut self) -> bool {
        self.maybe_reset_window();
        let floor = self.min_retries as u64;
        if self.retries < floor {
            return true;
        }
        match self.requests + self.retries {
            0 => true,
            attempts => (self.retries as f64 / attempts as f64) < self.ratio,
        }
    }

    /// Counts one admitted retry. Does not roll the window over by itself.
    fn record_retry(&mut self) {
        self.retries += 1;
    }
}
/// Retry policy that inspects only the response head. It receives the response
/// parts and the zero-based attempt number; `Some(delay)` asks for a retry after `delay`.
type HeadersPolicy = Arc<dyn Fn(&http::response::Parts, u32) -> Option<Duration> + Send + Sync>;
/// Retry policy that inspects the fully buffered response. It receives the response
/// with its body collected into `Bytes` and the zero-based attempt number;
/// `Some(delay)` asks for a retry after `delay`.
type BodyPolicy = Arc<dyn Fn(&Response<Bytes>, u32) -> Option<Duration> + Send + Sync>;
/// The wrapped `HttpService` behind a `tower::buffer::Buffer`, which `RetryService`
/// clones for every call.
type BufferedHttpService =
tower::buffer::Buffer<Request<Body>, <HttpService as Service<Request<Body>>>::Future>;
/// The user-supplied retry policy, distinguished by how much of the response it needs.
enum PolicyKind {
// Decides from the response head alone; the response body is left unread.
Headers(HeadersPolicy),
// Decides from the whole response; the body is collected into memory first.
Body(BodyPolicy),
}
impl Clone for PolicyKind {
fn clone(&self) -> Self {
match self {
Self::Headers(f) => Self::Headers(f.clone()),
Self::Body(f) => Self::Body(f.clone()),
}
}
}
/// Configuration for the retry layer, built with the chained setter methods.
pub struct Retry {
// Response statuses that trigger a retry when no custom policy is installed.
statuses: Vec<StatusCode>,
// Number of retries allowed after the initial attempt.
max_retries: u32,
// Base delay of the exponential backoff.
backoff: Duration,
// Upper bound for the exponential backoff delay.
max_backoff: Duration,
// Optional custom policy; when present it replaces the status-based decision.
policy: Option<PolicyKind>,
// Largest request body, in bytes, that is recorded so it can be replayed on retry.
max_replay_body_bytes: usize,
// Retry budget ratio; `None` means no budget is enforced.
budget_ratio: Option<f64>,
// Length of the budget accounting window.
budget_window: Duration,
// Retries per window that the budget always admits.
budget_min_retries: u32,
// Budget counters; the same `Arc` is handed to every service created by `layer`.
budget_state: Option<Arc<Mutex<BudgetState>>>,
}
/// Field-wise clone. The budget counters live behind an `Arc`, so a clone keeps
/// sharing the same budget state as the original.
impl Clone for Retry {
    fn clone(&self) -> Self {
        // Destructuring makes the compiler flag any field added later but forgotten here.
        let Self {
            statuses,
            max_retries,
            backoff,
            max_backoff,
            policy,
            max_replay_body_bytes,
            budget_ratio,
            budget_window,
            budget_min_retries,
            budget_state,
        } = self;
        Self {
            statuses: statuses.clone(),
            max_retries: *max_retries,
            backoff: *backoff,
            max_backoff: *max_backoff,
            policy: policy.clone(),
            max_replay_body_bytes: *max_replay_body_bytes,
            budget_ratio: *budget_ratio,
            budget_window: *budget_window,
            budget_min_retries: *budget_min_retries,
            budget_state: budget_state.clone(),
        }
    }
}
impl Retry {
pub fn on_status<S: TryInto<StatusCode>>(status: S) -> Self
where
S::Error: std::fmt::Debug,
{
Self {
statuses: vec![status.try_into().expect("invalid status code")],
max_retries: 3,
backoff: Duration::from_secs(1),
max_backoff: DEFAULT_MAX_BACKOFF,
policy: None,
max_replay_body_bytes: 1024 * 1024,
budget_ratio: None,
budget_window: DEFAULT_BUDGET_WINDOW,
budget_min_retries: DEFAULT_BUDGET_MIN_RETRIES,
budget_state: None,
}
}
pub fn on_statuses<S>(statuses: impl IntoIterator<Item = S>) -> Self
where
S: TryInto<StatusCode>,
S::Error: std::fmt::Debug,
{
Self {
statuses: statuses
.into_iter()
.map(|s| s.try_into().expect("invalid status code"))
.collect(),
max_retries: 3,
backoff: Duration::from_secs(1),
max_backoff: DEFAULT_MAX_BACKOFF,
policy: None,
max_replay_body_bytes: 1024 * 1024,
budget_ratio: None,
budget_window: DEFAULT_BUDGET_WINDOW,
budget_min_retries: DEFAULT_BUDGET_MIN_RETRIES,
budget_state: None,
}
}
pub fn max_retries(mut self, n: u32) -> Self {
self.max_retries = n;
self
}
pub fn backoff(mut self, base: Duration) -> Self {
self.backoff = base;
self
}
pub fn max_backoff(mut self, max: Duration) -> Self {
self.max_backoff = max;
self
}
pub fn max_replay_body_bytes(mut self, bytes: usize) -> Self {
self.max_replay_body_bytes = bytes;
self
}
pub fn policy_headers<F>(mut self, f: F) -> Self
where
F: Fn(&http::response::Parts, u32) -> Option<Duration> + Send + Sync + 'static,
{
self.policy = Some(PolicyKind::Headers(Arc::new(f)));
self
}
pub fn policy<F>(mut self, f: F) -> Self
where
F: Fn(&Response<Bytes>, u32) -> Option<Duration> + Send + Sync + 'static,
{
self.policy = Some(PolicyKind::Body(Arc::new(f)));
self
}
pub fn budget(mut self, ratio: f64) -> Self {
self.budget_ratio = Some(ratio);
self.budget_state = Some(Arc::new(Mutex::new(BudgetState::new(
ratio,
self.budget_min_retries,
self.budget_window,
))));
self
}
pub fn budget_window(mut self, window: Duration) -> Self {
self.budget_window = window;
if let Some(ratio) = self.budget_ratio {
self.budget_state = Some(Arc::new(Mutex::new(BudgetState::new(
ratio,
self.budget_min_retries,
window,
))));
}
self
}
pub fn budget_min_retries(mut self, n: u32) -> Self {
self.budget_min_retries = n;
if let Some(ratio) = self.budget_ratio {
self.budget_state = Some(Arc::new(Mutex::new(BudgetState::new(
ratio,
n,
self.budget_window,
))));
}
self
}
}
/// Default configuration: retry on 429, 502, 503 and 504, at most three
/// retries, one-second base backoff, no custom policy, a 1 MiB replay limit
/// and no retry budget.
impl Default for Retry {
    fn default() -> Self {
        let statuses = Vec::from([
            StatusCode::TOO_MANY_REQUESTS,
            StatusCode::BAD_GATEWAY,
            StatusCode::SERVICE_UNAVAILABLE,
            StatusCode::GATEWAY_TIMEOUT,
        ]);
        Self {
            statuses,
            max_retries: 3,
            backoff: Duration::from_secs(1),
            max_backoff: DEFAULT_MAX_BACKOFF,
            policy: None,
            max_replay_body_bytes: 1024 * 1024,
            budget_ratio: None,
            budget_window: DEFAULT_BUDGET_WINDOW,
            budget_min_retries: DEFAULT_BUDGET_MIN_RETRIES,
            budget_state: None,
        }
    }
}
impl tower::Layer<HttpService> for Retry {
    type Service = RetryService;

    /// Wraps `inner` in a `tower::buffer::Buffer` with a capacity of 1024 and
    /// copies this configuration into a new `RetryService`. The budget state is
    /// shared (same `Arc`) with every other service built from this layer.
    fn layer(&self, inner: HttpService) -> Self::Service {
        let buffered = tower::buffer::Buffer::new(inner, 1024);
        let statuses = self.statuses.clone();
        let policy = self.policy.clone();
        let budget = self.budget_state.clone();
        RetryService {
            inner: buffered,
            statuses,
            max_retries: self.max_retries,
            backoff: self.backoff,
            max_backoff: self.max_backoff,
            policy,
            max_replay_body_bytes: self.max_replay_body_bytes,
            budget,
        }
    }
}
/// Service produced by the `Retry` layer; re-sends a request when the response
/// is judged retryable.
pub struct RetryService {
// Buffered inner service; cloned for every call.
inner: BufferedHttpService,
// Response statuses that trigger a retry when no custom policy is installed.
statuses: Vec<StatusCode>,
// Number of retries allowed after the initial attempt.
max_retries: u32,
// Base delay of the exponential backoff.
backoff: Duration,
// Upper bound for the exponential backoff delay.
max_backoff: Duration,
// Optional custom policy; when present it replaces the status-based decision.
policy: Option<PolicyKind>,
// Largest request body, in bytes, that is recorded for replay.
max_replay_body_bytes: usize,
// Shared retry budget; `None` means retries are not budgeted.
budget: Option<Arc<Mutex<BudgetState>>>,
}
impl Service<Request<Body>> for RetryService {
type Response = Response<Body>;
type Error = BoxError;
type Future = Pin<Box<dyn Future<Output = Result<Response<Body>, BoxError>> + Send>>;
fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
}
fn call(&mut self, req: Request<Body>) -> Self::Future {
let mut inner = self.inner.clone();
let statuses = self.statuses.clone();
let max_retries = self.max_retries;
let base_backoff = self.backoff;
let max_backoff = self.max_backoff;
let policy = self.policy.clone();
let max_replay_body_bytes = self.max_replay_body_bytes;
let budget = self.budget.clone();
Box::pin(async move {
let (parts, body) = req.into_parts();
let method = parts.method;
let uri = parts.uri;
let version = parts.version;
let headers = parts.headers;
let capture = Arc::new(Mutex::new(ReplayCapture::new(max_replay_body_bytes)));
let body_known_empty = body.size_hint().exact() == Some(0);
let mut first_body = Some(body);
let mut replay_bytes: Option<Bytes> = if body_known_empty {
Some(Bytes::new())
} else {
None
};
if let Some(ref budget) = budget {
budget.lock().unwrap().record_request();
}
for attempt in 0..=max_retries {
let mut builder = Request::builder()
.method(method.clone())
.uri(uri.clone())
.version(version);
*builder.headers_mut().unwrap() = headers.clone();
let req_body = if attempt == 0 {
let body = first_body.take().unwrap_or_else(crate::http::empty_body);
if body_known_empty {
body
} else {
RecordingBody::new(body, capture.clone()).boxed()
}
} else {
full_body(replay_bytes.clone().unwrap_or_default())
};
let req = builder.body(req_body).unwrap();
std::future::poll_fn(|cx| inner.poll_ready(cx)).await?;
let resp = inner.call(req).await?;
match &policy {
Some(PolicyKind::Body(f)) => {
let (resp_parts, resp_body) = resp.into_parts();
let resp_bytes = resp_body.collect().await?.to_bytes();
let buffered = Response::from_parts(resp_parts, resp_bytes);
if let Some(delay) = f(&buffered, attempt)
&& attempt < max_retries
{
if replay_bytes.is_none() {
replay_bytes = ReplayCapture::snapshot(&capture);
}
if replay_bytes.is_none() {
let (parts, bytes) = buffered.into_parts();
return Ok(Response::from_parts(parts, full_body(bytes)));
}
if let Some(ref budget) = budget {
let mut b = budget.lock().unwrap();
if !b.allows_retry() {
let (parts, bytes) = buffered.into_parts();
return Ok(Response::from_parts(parts, full_body(bytes)));
}
b.record_retry();
}
tracing::debug!(
status = %buffered.status(),
attempt = attempt + 1,
max = max_retries,
delay_ms = delay.as_millis() as u64,
"retrying request"
);
tokio::time::sleep(delay).await;
continue;
}
let (parts, bytes) = buffered.into_parts();
return Ok(Response::from_parts(parts, full_body(bytes)));
}
Some(PolicyKind::Headers(f)) => {
let (parts, body) = resp.into_parts();
if let Some(delay) = f(&parts, attempt)
&& attempt < max_retries
{
if replay_bytes.is_none() {
replay_bytes = ReplayCapture::snapshot(&capture);
}
if replay_bytes.is_none() {
return Ok(Response::from_parts(parts, body));
}
if let Some(ref budget) = budget {
let mut b = budget.lock().unwrap();
if !b.allows_retry() {
return Ok(Response::from_parts(parts, body));
}
b.record_retry();
}
tracing::debug!(
status = %parts.status,
attempt = attempt + 1,
max = max_retries,
delay_ms = delay.as_millis() as u64,
"retrying request"
);
tokio::time::sleep(delay).await;
continue;
}
return Ok(Response::from_parts(parts, body));
}
None => {
if attempt == max_retries || !statuses.contains(&resp.status()) {
return Ok(resp);
}
if replay_bytes.is_none() {
replay_bytes = ReplayCapture::snapshot(&capture);
}
if replay_bytes.is_none() {
return Ok(resp);
}
if let Some(ref budget) = budget {
let mut b = budget.lock().unwrap();
if !b.allows_retry() {
return Ok(resp);
}
b.record_retry();
}
let delay = retry_after_delay(&resp).unwrap_or_else(|| {
exponential_delay(base_backoff, max_backoff, attempt)
});
tracing::debug!(
status = %resp.status(),
attempt = attempt + 1,
max = max_retries,
delay_ms = delay.as_millis() as u64,
"retrying request"
);
tokio::time::sleep(delay).await;
}
}
}
unreachable!()
})
}
}
/// Picks a random delay for the given attempt using full jitter.
///
/// The upper bound is `base * 2^attempt`, capped at `max_backoff`; the result
/// is drawn uniformly from zero up to and including that bound.
///
/// The multiplier saturates instead of overflowing, so attempt numbers of 32
/// and above keep using the capped bound rather than panicking (debug builds)
/// or wrapping back to a small multiplier (release builds).
fn exponential_delay(base: Duration, max_backoff: Duration, attempt: u32) -> Duration {
    // `checked_shl` yields `None` once the shift amount reaches the bit width.
    let multiplier = 1u32.checked_shl(attempt).unwrap_or(u32::MAX);
    let ceiling = base.saturating_mul(multiplier).min(max_backoff);
    // A bound beyond `u64::MAX` nanoseconds is clamped rather than truncated.
    let ceiling_nanos = u64::try_from(ceiling.as_nanos()).unwrap_or(u64::MAX);
    Duration::from_nanos(rand::random_range(0..=ceiling_nanos))
}
/// Reads the `Retry-After` response header as a whole number of seconds.
///
/// Returns `None` when the header is missing, is not valid visible ASCII, or
/// does not parse as an unsigned integer.
fn retry_after_delay(resp: &Response<Body>) -> Option<Duration> {
    resp.headers()
        .get(http::header::RETRY_AFTER)
        .and_then(|raw| raw.to_str().ok())
        .and_then(|text| text.parse::<u64>().ok())
        .map(Duration::from_secs)
}
/// Accumulates the bytes of a request body while it is being sent, so the body
/// can be replayed on a retry.
struct ReplayCapture {
// Body bytes recorded so far; emptied when the limit is exceeded.
bytes: Vec<u8>,
// Largest total size that may be recorded.
max_bytes: usize,
// Set when the recording is unusable: the limit was exceeded or the body errored.
overflowed: bool,
// Set once the body has been observed to its end.
complete: bool,
}
impl ReplayCapture {
    /// Creates an empty capture that records at most `max_bytes` bytes.
    fn new(max_bytes: usize) -> Self {
        Self {
            bytes: Vec::new(),
            max_bytes,
            overflowed: false,
            complete: false,
        }
    }

    /// Appends `chunk` to the recording.
    ///
    /// If the chunk would push the total past `max_bytes`, the recording is
    /// discarded and marked overflowed; all later chunks are then ignored.
    fn record_chunk(&mut self, chunk: &[u8]) {
        if self.overflowed {
            return;
        }
        let fits = self.bytes.len().saturating_add(chunk.len()) <= self.max_bytes;
        if fits {
            self.bytes.extend_from_slice(chunk);
        } else {
            self.bytes.clear();
            self.overflowed = true;
        }
    }

    /// Returns a copy of the recorded body, or `None` unless the body was
    /// recorded to its end without overflowing.
    fn snapshot(capture: &Arc<Mutex<Self>>) -> Option<Bytes> {
        let guard = capture.lock().unwrap();
        let usable = guard.complete && !guard.overflowed;
        usable.then(|| Bytes::copy_from_slice(&guard.bytes))
    }
}
/// Request body wrapper that passes frames through unchanged while copying
/// their data into a shared [`ReplayCapture`].
struct RecordingBody {
    /// The body being forwarded.
    inner: Body,
    /// Where the forwarded data is recorded; shared with the retry loop.
    capture: Arc<Mutex<ReplayCapture>>,
}

impl RecordingBody {
    /// Wraps `inner` so that everything read from it is also recorded in `capture`.
    fn new(inner: Body, capture: Arc<Mutex<ReplayCapture>>) -> Self {
        RecordingBody { inner, capture }
    }
}
impl http_body::Body for RecordingBody {
    type Data = Bytes;
    type Error = BoxError;

    /// Forwards the next frame of the inner body, recording data frames.
    ///
    /// The capture is marked complete as soon as the inner body is known to
    /// have ended: either it yields `None`, or it reports `is_end_stream()`
    /// right after a frame. The second case matters for consumers that stop
    /// polling once they know the body is finished and therefore never see the
    /// final `None`; without it such a request could never be replayed.
    /// An error from the inner body marks the recording unusable.
    fn poll_frame(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Bytes>, Self::Error>>> {
        match Pin::new(&mut self.inner).poll_frame(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(Ok(frame))) => {
                // Ask the inner body first so the lock is held only briefly.
                let inner_ended = self.inner.is_end_stream();
                {
                    let mut capture = self.capture.lock().unwrap();
                    if let Some(data) = frame.data_ref() {
                        capture.record_chunk(data.as_ref());
                    }
                    if inner_ended {
                        capture.complete = true;
                    }
                }
                Poll::Ready(Some(Ok(frame)))
            }
            Poll::Ready(Some(Err(e))) => {
                // A failed body must never be replayed; reuse the overflow flag
                // to invalidate whatever was recorded so far.
                self.capture.lock().unwrap().overflowed = true;
                Poll::Ready(Some(Err(e)))
            }
            Poll::Ready(None) => {
                self.capture.lock().unwrap().complete = true;
                Poll::Ready(None)
            }
        }
    }

    /// Forwards the inner body's end-of-stream signal, marking the capture
    /// complete when the inner body has ended so that a consumer relying on
    /// this signal still leaves a replayable recording behind.
    fn is_end_stream(&self) -> bool {
        let ended = self.inner.is_end_stream();
        if ended {
            self.capture.lock().unwrap().complete = true;
        }
        ended
    }

    /// Forwards the inner body's size hint unchanged.
    fn size_hint(&self) -> http_body::SizeHint {
        self.inner.size_hint()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use http_body_util::BodyExt;

    /// A body that fails mid-stream must leave no replayable snapshot behind.
    #[tokio::test]
    async fn recording_body_marks_overflowed_on_stream_error() {
        let frames = vec![
            Ok(Frame::data(Bytes::from("partial"))),
            Err(Box::<dyn std::error::Error + Send + Sync>::from(
                "stream failed",
            )),
        ];
        let failing_body = http_body_util::StreamBody::new(futures_util::stream::iter(frames));
        let capture = Arc::new(Mutex::new(ReplayCapture::new(1024)));
        let mut recording = RecordingBody::new(failing_body.boxed(), capture.clone());

        let first = recording.frame().await.unwrap().unwrap();
        assert_eq!(first.into_data().unwrap(), "partial");

        let failure = recording.frame().await.unwrap().unwrap_err();
        assert_eq!(failure.to_string(), "stream failed");

        assert!(ReplayCapture::snapshot(&capture).is_none());
    }

    /// 3 retries out of 13 attempts is below a ratio of 0.5.
    #[test]
    fn budget_allows_when_under_ratio() {
        let mut state = BudgetState::new(0.5, 0, Duration::from_secs(60));
        state.requests = 10;
        state.retries = 3;
        assert!(state.allows_retry());
    }

    /// 3 retries out of 13 attempts is above a ratio of 0.2.
    #[test]
    fn budget_blocks_when_over_ratio() {
        let mut state = BudgetState::new(0.2, 0, Duration::from_secs(60));
        state.requests = 10;
        state.retries = 3;
        assert!(!state.allows_retry());
    }

    /// The minimum-retries floor admits retries even with a zero ratio.
    #[test]
    fn budget_floor_allows_retries() {
        let mut state = BudgetState::new(0.0, 10, Duration::from_secs(60));
        state.requests = 100;
        state.retries = 0;
        assert!(state.allows_retry());
    }

    /// Once the window has elapsed the counters start over and retries are admitted again.
    #[test]
    fn budget_window_reset() {
        let mut state = BudgetState::new(0.2, 0, Duration::from_millis(1));
        state.requests = 10;
        state.retries = 10;
        assert!(!state.allows_retry());
        std::thread::sleep(Duration::from_millis(2));
        assert!(state.allows_retry());
    }

    /// Requests and retries are tallied in separate counters.
    #[test]
    fn budget_record_request_and_retry() {
        let mut state = BudgetState::new(0.5, 0, Duration::from_secs(60));
        state.record_request();
        state.record_request();
        assert_eq!(state.requests, 2);
        assert_eq!(state.retries, 0);
        state.record_retry();
        assert_eq!(state.retries, 1);
    }
}