1use crate::Request;
4use crate::dep::http_body::Body as HttpBody;
5use crate::dep::http_body_util::BodyExt;
6use rama_core::error::BoxError;
7use rama_core::{Context, Service};
8use rama_utils::macros::define_inner_service_accessors;
9
10mod layer;
11mod policy;
12
13mod body;
14#[doc(inline)]
15pub use body::RetryBody;
16
17pub mod managed;
18pub use managed::ManagedPolicy;
19
20#[cfg(test)]
21mod tests;
22
23pub use self::layer::RetryLayer;
24pub use self::policy::{Policy, PolicyResult};
25
/// Middleware that retries requests to an inner service according to a [`Policy`].
///
/// The incoming request body is fully buffered into a [`RetryBody`] before the
/// first attempt, so that the request can be copied (via [`Policy::clone_input`])
/// and re-sent to the inner service for each retry.
pub struct Retry<P, S> {
    /// Policy consulted after each attempt to decide whether to retry.
    policy: P,
    /// The wrapped service that actually handles each attempt.
    inner: S,
}
33
34impl<P, S> std::fmt::Debug for Retry<P, S>
35where
36 P: std::fmt::Debug,
37 S: std::fmt::Debug,
38{
39 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40 f.debug_struct("Retry")
41 .field("policy", &self.policy)
42 .field("inner", &self.inner)
43 .finish()
44 }
45}
46
47impl<P, S> Clone for Retry<P, S>
48where
49 P: Clone,
50 S: Clone,
51{
52 fn clone(&self) -> Self {
53 Retry {
54 policy: self.policy.clone(),
55 inner: self.inner.clone(),
56 }
57 }
58}
59
60impl<P, S> Retry<P, S> {
63 pub const fn new(policy: P, service: S) -> Self {
65 Retry {
66 policy,
67 inner: service,
68 }
69 }
70
71 define_inner_service_accessors!();
72}
73
/// Error returned by the [`Retry`] service.
///
/// Records the stage that failed (buffering the request body, or the inner
/// service) together with the underlying error, when one is available.
#[derive(Debug)]
pub struct RetryError {
    /// The stage at which the failure occurred.
    kind: RetryErrorKind,
    /// The underlying error, if any.
    inner: Option<BoxError>,
}
80
/// The stage of [`Retry`] processing at which a [`RetryError`] was produced.
#[derive(Debug)]
enum RetryErrorKind {
    /// Collecting (buffering) the incoming request body failed.
    BodyConsume,
    /// The inner service's error was returned without being retried further.
    Service,
}
86
87impl std::fmt::Display for RetryError {
88 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89 match &self.inner {
90 Some(inner) => write!(f, "{}: {}", self.kind, inner),
91 None => write!(f, "{}", self.kind),
92 }
93 }
94}
95
96impl std::fmt::Display for RetryErrorKind {
97 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98 match self {
99 RetryErrorKind::BodyConsume => write!(f, "failed to consume body"),
100 RetryErrorKind::Service => write!(f, "service error"),
101 }
102 }
103}
104
105impl std::error::Error for RetryError {
106 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
107 self.inner.as_ref().and_then(|e| e.source())
108 }
109}
110
111impl<P, S, State, Body> Service<State, Request<Body>> for Retry<P, S>
112where
113 P: Policy<State, S::Response, S::Error>,
114 S: Service<State, Request<RetryBody>, Error: Into<BoxError>>,
115 State: Clone + Send + Sync + 'static,
116 Body: HttpBody<Data: Send + 'static, Error: Into<BoxError>> + Send + 'static,
117{
118 type Response = S::Response;
119 type Error = RetryError;
120
121 async fn serve(
122 &self,
123 ctx: Context<State>,
124 request: Request<Body>,
125 ) -> Result<Self::Response, Self::Error> {
126 let mut ctx = ctx;
127
128 let (parts, body) = request.into_parts();
130 let body = body.collect().await.map_err(|e| RetryError {
131 kind: RetryErrorKind::BodyConsume,
132 inner: Some(e.into()),
133 })?;
134 let body = RetryBody::new(body.to_bytes());
135 let mut request = Request::from_parts(parts, body);
136
137 let mut cloned = self.policy.clone_input(&ctx, &request);
138
139 loop {
140 let resp = self.inner.serve(ctx, request).await;
141 match cloned.take() {
142 Some((cloned_ctx, cloned_req)) => {
143 let (cloned_ctx, cloned_req) =
144 match self.policy.retry(cloned_ctx, cloned_req, resp).await {
145 PolicyResult::Abort(result) => {
146 return result.map_err(|e| RetryError {
147 kind: RetryErrorKind::Service,
148 inner: Some(e.into()),
149 });
150 }
151 PolicyResult::Retry { ctx, req } => (ctx, req),
152 };
153
154 cloned = self.policy.clone_input(&cloned_ctx, &cloned_req);
155 ctx = cloned_ctx;
156 request = cloned_req;
157 }
158 None => {
160 return resp.map_err(|e| RetryError {
161 kind: RetryErrorKind::Service,
162 inner: Some(e.into()),
163 });
164 }
165 }
166 }
167 }
168}
169
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        BodyExtractExt, Response, StatusCode, layer::retry::managed::DoNotRetry,
        service::web::response::IntoResponse,
    };
    use rama_core::{Context, Layer, service::service_fn};
    use rama_utils::{backoff::ExponentialBackoff, rng::HasherRng};
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    #[tokio::test]
    async fn test_service_with_managed_retry() {
        let backoff = ExponentialBackoff::new(
            Duration::from_millis(1),
            Duration::from_millis(5),
            0.1,
            HasherRng::default,
        )
        .unwrap();

        /// Per-request state; counts how many times the policy decided to retry.
        #[derive(Debug, Clone)]
        struct State {
            retry_counter: Arc<AtomicUsize>,
        }

        /// Retry decision callback: retries on 5xx responses and on service errors.
        ///
        /// Panics if called for a request carrying [`DoNotRetry`]. This test
        /// expects that marker to disable retrying entirely.
        async fn retry<E>(
            ctx: Context<State>,
            result: Result<Response, E>,
        ) -> (Context<State>, Result<Response, E>, bool) {
            if ctx.contains::<DoNotRetry>() {
                panic!("unexpected retry: should be disabled");
            }

            let should_retry = match &result {
                Ok(res) => res.status().is_server_error(),
                Err(_) => true,
            };
            if should_retry {
                // One ordering everywhere (previously a mix of AcqRel and SeqCst).
                ctx.state().retry_counter.fetch_add(1, Ordering::SeqCst);
            }
            (ctx, result, should_retry)
        }

        /// Asserts the policy retried at least once (`retried == true`) or never.
        fn assert_retry_counter(state: &State, retried: bool, msg: &str) {
            let count = state.retry_counter.load(Ordering::SeqCst);
            if retried {
                assert!(count > 0, "{msg} (expected a retry, counter = {count})");
            } else {
                assert_eq!(count, 0, "{msg} (expected no retry)");
            }
        }

        let retry_policy = ManagedPolicy::new(retry).with_backoff(backoff);

        let service = RetryLayer::new(retry_policy).into_layer(service_fn(
            async |_ctx, req: Request<RetryBody>| {
                let txt = req.try_into_string().await.unwrap();
                match txt.as_str() {
                    "internal" => Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
                    "error" => Err(rama_core::error::BoxError::from("custom error")),
                    _ => Ok(txt.into_response()),
                }
            },
        ));

        fn request(s: &'static str) -> Request {
            Request::builder().body(s.into()).unwrap()
        }

        fn ctx() -> Context<State> {
            Context::with_state(State {
                retry_counter: Arc::new(AtomicUsize::new(0)),
            })
        }

        fn ctx_do_not_retry() -> Context<State> {
            let mut ctx = ctx();
            ctx.insert(DoNotRetry::default());
            ctx
        }

        /// Serves `input`, expects an `Ok` response whose body equals `output`,
        /// and checks whether a retry happened.
        async fn assert_serve_ok<E: std::fmt::Debug>(
            msg: &'static str,
            input: &'static str,
            output: &'static str,
            ctx: Context<State>,
            retried: bool,
            service: &impl Service<State, Request, Response = Response, Error = E>,
        ) {
            let state = ctx.state_clone();

            let res = service.serve(ctx, request(input)).await.unwrap();
            let body = res.try_into_string().await.unwrap();
            assert_eq!(body, output, "{msg}");
            assert_retry_counter(&state, retried, msg);
        }

        /// Serves `input`, expects an `Err`, and checks whether a retry happened.
        async fn assert_serve_err<E: std::fmt::Debug>(
            msg: &'static str,
            input: &'static str,
            ctx: Context<State>,
            retried: bool,
            service: &impl Service<State, Request, Response = Response, Error = E>,
        ) {
            let state = ctx.state_clone();

            let res = service.serve(ctx, request(input)).await;
            assert!(res.is_err(), "{msg}");
            assert_retry_counter(&state, retried, msg);
        }

        assert_serve_ok(
            "ok response should be aborted as response without retry",
            "hello",
            "hello",
            ctx(),
            false,
            &service,
        )
        .await;
        assert_serve_ok(
            "internal will trigger 500 with a retry",
            "internal",
            "",
            ctx(),
            true,
            &service,
        )
        .await;
        assert_serve_err(
            "error will trigger an actual non-http error with a retry",
            "error",
            ctx(),
            true,
            &service,
        )
        .await;

        assert_serve_ok(
            "normally internal will trigger a 500 with retry, but using DoNotRetry will disable retrying",
            "internal",
            "",
            ctx_do_not_retry(),
            false,
            &service,
        )
        .await;
    }
}