axum_reverse_proxy/
retry.rs1use axum::body::Body;
2use bytes::{Bytes, BytesMut};
3use http::StatusCode;
4use http_body::{Body as HttpBody, Frame, SizeHint};
5use http_body_util::BodyExt;
6use std::convert::Infallible;
7use std::sync::{Arc, Mutex};
8use std::time::Duration;
9use std::{
10 future::Future,
11 pin::Pin,
12 task::{Context, Poll},
13};
14use tower::{Layer, Service};
15
/// Configuration for the [`Retry`] middleware: how many times the inner
/// service may be called for one request and how long to wait between calls.
#[derive(Clone, Debug)]
pub struct RetryLayer {
    /// Total number of calls made to the inner service for one request,
    /// counting the first one. The constructors clamp this to at least 1.
    attempts: usize,
    /// Pause inserted before each retry.
    delay: Duration,
}
21
22impl RetryLayer {
23 pub fn new(attempts: usize) -> Self {
24 let attempts = attempts.max(1);
25 Self {
26 attempts,
27 delay: Duration::from_millis(500),
28 }
29 }
30
31 pub fn with_delay(attempts: usize, delay: Duration) -> Self {
32 let attempts = attempts.max(1);
33 Self { attempts, delay }
34 }
35}
36
/// Service produced by [`RetryLayer`]: forwards requests to `inner` and
/// re-issues them when `inner` answers `502 Bad Gateway`.
#[derive(Clone, Debug)]
pub struct Retry<S> {
    /// The wrapped service.
    inner: S,
    /// Maximum number of calls to `inner` per request, including the first.
    attempts: usize,
    /// Pause inserted before each retry.
    delay: Duration,
}
43
44struct BufferedBody<B> {
45 body: B,
46 buf: Arc<Mutex<BytesMut>>,
47}
48
49impl<B> BufferedBody<B> {
50 fn new(body: B, buf: Arc<Mutex<BytesMut>>) -> Self {
51 Self { body, buf }
52 }
53}
54
55impl<B> HttpBody for BufferedBody<B>
56where
57 B: HttpBody<Data = Bytes> + Unpin,
58{
59 type Data = Bytes;
60 type Error = B::Error;
61
62 fn poll_frame(
63 mut self: Pin<&mut Self>,
64 cx: &mut Context<'_>,
65 ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
66 match Pin::new(&mut self.body).poll_frame(cx) {
67 Poll::Ready(Some(Ok(frame))) => {
68 if let Some(data) = frame.data_ref() {
69 self.buf.lock().unwrap().extend_from_slice(data);
70 }
71 Poll::Ready(Some(Ok(frame)))
72 }
73 other => other,
74 }
75 }
76
77 fn is_end_stream(&self) -> bool {
78 self.body.is_end_stream()
79 }
80
81 fn size_hint(&self) -> SizeHint {
82 self.body.size_hint()
83 }
84}
85
86impl<S> Layer<S> for RetryLayer {
87 type Service = Retry<S>;
88
89 fn layer(&self, inner: S) -> Self::Service {
90 Retry {
91 inner,
92 attempts: self.attempts,
93 delay: self.delay,
94 }
95 }
96}
97
98impl<S> Service<axum::http::Request<Body>> for Retry<S>
99where
100 S: Service<
101 axum::http::Request<Body>,
102 Response = axum::http::Response<Body>,
103 Error = Infallible,
104 > + Clone
105 + Send
106 + 'static,
107 S::Future: Send + 'static,
108{
109 type Response = axum::http::Response<Body>;
110 type Error = Infallible;
111 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
112
113 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
114 self.inner.poll_ready(cx)
115 }
116
117 fn call(&mut self, req: axum::http::Request<Body>) -> Self::Future {
118 let mut inner = self.inner.clone();
119 let attempts = self.attempts;
120 let delay = self.delay;
121 Box::pin(async move {
122 let (parts, body) = req.into_parts();
123 let buf = Arc::new(Mutex::new(BytesMut::new()));
124 let wrapped = BufferedBody::new(body, buf.clone());
125 let attempt_req = axum::http::Request::from_parts(
126 parts.clone(),
127 Body::from_stream(wrapped.into_data_stream()),
128 );
129 let mut res = inner.call(attempt_req).await?;
130 if res.status() != StatusCode::BAD_GATEWAY || attempts == 1 {
131 return Ok(res);
132 }
133
134 for attempt in 1..attempts {
135 tokio::time::sleep(delay).await;
136 let bytes = buf.lock().unwrap().clone().freeze();
137 let req = axum::http::Request::from_parts(parts.clone(), Body::from(bytes.clone()));
138 res = inner.call(req).await?;
139 if res.status() != StatusCode::BAD_GATEWAY || attempt == attempts - 1 {
140 return Ok(res);
141 }
142 }
143 unreachable!();
144 })
145 }
146}