axum_reverse_proxy/
retry.rs

1use axum::body::Body;
2use bytes::{Bytes, BytesMut};
3use http::StatusCode;
4use http_body::{Body as HttpBody, Frame, SizeHint};
5use http_body_util::BodyExt;
6use std::convert::Infallible;
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9use std::{
10    future::Future,
11    pin::Pin,
12    task::{Context, Poll},
13};
14use tower::{Layer, Service};
15
/// Tower layer that wraps a service in [`Retry`], re-sending a request when the
/// inner service answers `502 Bad Gateway`.
///
/// `attempts` counts the total number of calls made for one request, including
/// the first; it is clamped to at least 1 (meaning "never retry").
#[derive(Clone, Debug)]
pub struct RetryLayer {
    attempts: usize,
    delay: Duration,
}

impl RetryLayer {
    /// Pause between attempts used by [`RetryLayer::new`].
    const DEFAULT_DELAY: Duration = Duration::from_millis(500);

    /// Creates a layer that makes at most `attempts` calls per request
    /// (minimum 1), pausing 500 ms between consecutive calls.
    pub fn new(attempts: usize) -> Self {
        Self::with_delay(attempts, Self::DEFAULT_DELAY)
    }

    /// Creates a layer that makes at most `attempts` calls per request
    /// (minimum 1), pausing `delay` between consecutive calls.
    pub fn with_delay(attempts: usize, delay: Duration) -> Self {
        Self {
            attempts: attempts.max(1),
            delay,
        }
    }
}
36
/// Service produced by [`RetryLayer`]: forwards each request to `inner` and
/// re-sends it when the response is `502 Bad Gateway`.
#[derive(Clone, Debug)]
pub struct Retry<S> {
    /// The wrapped service.
    inner: S,
    /// Total number of calls allowed per request, including the first one.
    attempts: usize,
    /// Pause between consecutive calls.
    delay: Duration,
}
43
44struct BufferedBody<B> {
45    body: B,
46    buf: Arc<Mutex<BytesMut>>,
47}
48
49impl<B> BufferedBody<B> {
50    fn new(body: B, buf: Arc<Mutex<BytesMut>>) -> Self {
51        Self { body, buf }
52    }
53}
54
55impl<B> HttpBody for BufferedBody<B>
56where
57    B: HttpBody<Data = Bytes> + Unpin,
58{
59    type Data = Bytes;
60    type Error = B::Error;
61
62    fn poll_frame(
63        mut self: Pin<&mut Self>,
64        cx: &mut Context<'_>,
65    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
66        match Pin::new(&mut self.body).poll_frame(cx) {
67            Poll::Ready(Some(Ok(frame))) => {
68                if let Some(data) = frame.data_ref() {
69                    self.buf.lock().unwrap().extend_from_slice(data);
70                }
71                Poll::Ready(Some(Ok(frame)))
72            }
73            other => other,
74        }
75    }
76
77    fn is_end_stream(&self) -> bool {
78        self.body.is_end_stream()
79    }
80
81    fn size_hint(&self) -> SizeHint {
82        self.body.size_hint()
83    }
84}
85
86impl<S> Layer<S> for RetryLayer {
87    type Service = Retry<S>;
88
89    fn layer(&self, inner: S) -> Self::Service {
90        Retry {
91            inner,
92            attempts: self.attempts,
93            delay: self.delay,
94        }
95    }
96}
97
98impl<S> Service<axum::http::Request<Body>> for Retry<S>
99where
100    S: Service<
101            axum::http::Request<Body>,
102            Response = axum::http::Response<Body>,
103            Error = Infallible,
104        > + Clone
105        + Send
106        + 'static,
107    S::Future: Send + 'static,
108{
109    type Response = axum::http::Response<Body>;
110    type Error = Infallible;
111    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
112
113    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114        self.inner.poll_ready(cx)
115    }
116
117    fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
118        let mut inner = self.inner.clone();
119        let attempts = self.attempts;
120        let delay = self.delay;
121        Box::pin(async move {
122            let (parts, body) = req.into_parts();
123            let buf = Arc::new(Mutex::new(BytesMut::new()));
124            let wrapped = BufferedBody::new(body, buf.clone());
125            let attempt_req = axum::http::Request::from_parts(
126                parts.clone(),
127                Body::from_stream(wrapped.into_data_stream()),
128            );
129            let mut res = inner.call(attempt_req).await?;
130            if res.status() != StatusCode::BAD_GATEWAY || attempts == 1 {
131                return Ok(res);
132            }
133
134            for attempt in 1..attempts {
135                tokio::time::sleep(delay).await;
136                let bytes = buf.lock().unwrap().clone().freeze();
137                let req = axum::http::Request::from_parts(parts.clone(), Body::from(bytes.clone()));
138                res = inner.call(req).await?;
139                if res.status() != StatusCode::BAD_GATEWAY || attempt == attempts - 1 {
140                    return Ok(res);
141                }
142            }
143            unreachable!();
144        })
145    }
146}