rama_http/layer/retry/
mod.rs

1//! Middleware for retrying "failed" requests.
2
3use crate::Request;
4use crate::dep::http_body::Body as HttpBody;
5use crate::dep::http_body_util::BodyExt;
6use rama_core::error::BoxError;
7use rama_core::{Context, Service};
8use rama_utils::macros::define_inner_service_accessors;
9
10mod layer;
11mod policy;
12
13mod body;
14#[doc(inline)]
15pub use body::RetryBody;
16
17pub mod managed;
18pub use managed::ManagedPolicy;
19
20#[cfg(test)]
21mod tests;
22
23pub use self::layer::RetryLayer;
24pub use self::policy::{Policy, PolicyResult};
25
/// Middleware that sends a request to its inner service again whenever the
/// configured [`Policy`] considers the outcome a "failed" one.
///
/// What counts as "failed" is decided entirely by the [`Policy`].
pub struct Retry<P, S> {
    /// Decides whether a given outcome warrants another attempt.
    policy: P,
    /// The wrapped service that handles every attempt.
    inner: S,
}
33
34impl<P, S> std::fmt::Debug for Retry<P, S>
35where
36    P: std::fmt::Debug,
37    S: std::fmt::Debug,
38{
39    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
40        f.debug_struct("Retry")
41            .field("policy", &self.policy)
42            .field("inner", &self.inner)
43            .finish()
44    }
45}
46
47impl<P, S> Clone for Retry<P, S>
48where
49    P: Clone,
50    S: Clone,
51{
52    fn clone(&self) -> Self {
53        Retry {
54            policy: self.policy.clone(),
55            inner: self.inner.clone(),
56        }
57    }
58}
59
60// ===== impl Retry =====
61
62impl<P, S> Retry<P, S> {
63    /// Retry the inner service depending on this [`Policy`].
64    pub const fn new(policy: P, service: S) -> Self {
65        Retry {
66            policy,
67            inner: service,
68        }
69    }
70
71    define_inner_service_accessors!();
72}
73
74#[derive(Debug)]
75/// Error type for [`Retry`]
76pub struct RetryError {
77    kind: RetryErrorKind,
78    inner: Option<BoxError>,
79}
80
81#[derive(Debug)]
82enum RetryErrorKind {
83    BodyConsume,
84    Service,
85}
86
87impl std::fmt::Display for RetryError {
88    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
89        match &self.inner {
90            Some(inner) => write!(f, "{}: {}", self.kind, inner),
91            None => write!(f, "{}", self.kind),
92        }
93    }
94}
95
96impl std::fmt::Display for RetryErrorKind {
97    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
98        match self {
99            RetryErrorKind::BodyConsume => write!(f, "failed to consume body"),
100            RetryErrorKind::Service => write!(f, "service error"),
101        }
102    }
103}
104
105impl std::error::Error for RetryError {
106    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
107        self.inner.as_ref().and_then(|e| e.source())
108    }
109}
110
111impl<P, S, State, Body> Service<State, Request<Body>> for Retry<P, S>
112where
113    P: Policy<State, S::Response, S::Error>,
114    S: Service<State, Request<RetryBody>, Error: Into<BoxError>>,
115    State: Clone + Send + Sync + 'static,
116    Body: HttpBody<Data: Send + 'static, Error: Into<BoxError>> + Send + 'static,
117{
118    type Response = S::Response;
119    type Error = RetryError;
120
121    async fn serve(
122        &self,
123        ctx: Context<State>,
124        request: Request<Body>,
125    ) -> Result<Self::Response, Self::Error> {
126        let mut ctx = ctx;
127
128        // consume body so we can clone the request if desired
129        let (parts, body) = request.into_parts();
130        let body = body.collect().await.map_err(|e| RetryError {
131            kind: RetryErrorKind::BodyConsume,
132            inner: Some(e.into()),
133        })?;
134        let body = RetryBody::new(body.to_bytes());
135        let mut request = Request::from_parts(parts, body);
136
137        let mut cloned = self.policy.clone_input(&ctx, &request);
138
139        loop {
140            let resp = self.inner.serve(ctx, request).await;
141            match cloned.take() {
142                Some((cloned_ctx, cloned_req)) => {
143                    let (cloned_ctx, cloned_req) =
144                        match self.policy.retry(cloned_ctx, cloned_req, resp).await {
145                            PolicyResult::Abort(result) => {
146                                return result.map_err(|e| RetryError {
147                                    kind: RetryErrorKind::Service,
148                                    inner: Some(e.into()),
149                                });
150                            }
151                            PolicyResult::Retry { ctx, req } => (ctx, req),
152                        };
153
154                    cloned = self.policy.clone_input(&cloned_ctx, &cloned_req);
155                    ctx = cloned_ctx;
156                    request = cloned_req;
157                }
158                // no clone was made, so no possibility to retry
159                None => {
160                    return resp.map_err(|e| RetryError {
161                        kind: RetryErrorKind::Service,
162                        inner: Some(e.into()),
163                    });
164                }
165            }
166        }
167    }
168}
169
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        BodyExtractExt, Response, StatusCode, layer::retry::managed::DoNotRetry,
        service::web::response::IntoResponse,
    };
    use rama_core::{Context, Layer, service::service_fn};
    use rama_utils::{backoff::ExponentialBackoff, rng::HasherRng};
    use std::{
        sync::{
            Arc,
            atomic::{AtomicUsize, Ordering},
        },
        time::Duration,
    };

    /// Drives [`Retry`] with a [`ManagedPolicy`] through four cases:
    /// - a plain success;
    /// - a 5xx response;
    /// - an inner-service error;
    /// - the [`DoNotRetry`] opt-out.
    #[tokio::test]
    async fn test_service_with_managed_retry() {
        let backoff = ExponentialBackoff::new(
            Duration::from_millis(1),
            Duration::from_millis(5),
            0.1,
            HasherRng::default,
        )
        .unwrap();

        /// Test state; counts how often the retry callback requested a retry.
        #[derive(Debug, Clone)]
        struct State {
            retry_counter: Arc<AtomicUsize>,
        }

        /// Retry callback: retries on 5xx responses and on any service error,
        /// bumping the counter each time it does so.
        async fn retry<E>(
            ctx: Context<State>,
            result: Result<Response, E>,
        ) -> (Context<State>, Result<Response, E>, bool) {
            if ctx.contains::<DoNotRetry>() {
                panic!("unexpected retry: should be disabled");
            }

            let should_retry = match &result {
                Ok(res) => res.status().is_server_error(),
                Err(_) => true,
            };
            if should_retry {
                // Same ordering everywhere this counter is touched.
                ctx.state().retry_counter.fetch_add(1, Ordering::SeqCst);
            }
            (ctx, result, should_retry)
        }

        let retry_policy = ManagedPolicy::new(retry).with_backoff(backoff);

        let service = RetryLayer::new(retry_policy).into_layer(service_fn(
            async |_ctx, req: Request<RetryBody>| {
                let txt = req.try_into_string().await.unwrap();
                match txt.as_str() {
                    "internal" => Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
                    "error" => Err(rama_core::error::BoxError::from("custom error")),
                    _ => Ok(txt.into_response()),
                }
            },
        ));

        fn request(s: &'static str) -> Request {
            Request::builder().body(s.into()).unwrap()
        }

        fn ctx() -> Context<State> {
            Context::with_state(State {
                retry_counter: Arc::new(AtomicUsize::new(0)),
            })
        }

        fn ctx_do_not_retry() -> Context<State> {
            let mut ctx = ctx();
            ctx.insert(DoNotRetry::default());
            ctx
        }

        /// Asserts the retry counter is non-zero iff a retry was expected.
        fn assert_retry_count(state: &State, retried: bool, msg: &str) {
            let count = state.retry_counter.load(Ordering::SeqCst);
            if retried {
                assert!(count > 0, "{msg}");
            } else {
                assert_eq!(count, 0, "{msg}");
            }
        }

        async fn assert_serve_ok<E: std::fmt::Debug>(
            msg: &'static str,
            input: &'static str,
            output: &'static str,
            ctx: Context<State>,
            retried: bool,
            service: &impl Service<State, Request, Response = Response, Error = E>,
        ) {
            let state = ctx.state_clone();

            let res = service.serve(ctx, request(input)).await.unwrap();
            let body = res.try_into_string().await.unwrap();
            assert_eq!(body, output, "{msg}");

            assert_retry_count(&state, retried, msg);
        }

        async fn assert_serve_err<E: std::fmt::Debug>(
            msg: &'static str,
            input: &'static str,
            ctx: Context<State>,
            retried: bool,
            service: &impl Service<State, Request, Response = Response, Error = E>,
        ) {
            let state = ctx.state_clone();

            let res = service.serve(ctx, request(input)).await;
            assert!(res.is_err(), "{msg}");

            assert_retry_count(&state, retried, msg);
        }

        assert_serve_ok(
            "ok response should be aborted as response without retry",
            "hello",
            "hello",
            ctx(),
            false,
            &service,
        )
        .await;
        assert_serve_ok(
            "internal will trigger 500 with a retry",
            "internal",
            "",
            ctx(),
            true,
            &service,
        )
        .await;
        assert_serve_err(
            "error will trigger an actual non-http error with a retry",
            "error",
            ctx(),
            true,
            &service,
        )
        .await;

        assert_serve_ok(
            "normally internal will trigger a 500 with retry, but using DoNotRetry will disable retrying",
            "internal",
            "",
            ctx_do_not_retry(),
            false,
            &service,
        )
        .await;
    }
}