rama_http/layer/retry/
managed.rs

//! Managed retry [`Policy`].
//!
//! See [`ManagedPolicy`] for more details.
//!
//! [`Policy`]: super::Policy
6
7use super::{Policy, PolicyResult, RetryBody};
8use crate::{Request, Response};
9use rama_core::Context;
10use rama_utils::backoff::Backoff;
11
/// Marker [`Extensions`] value that can be inserted into the [`Context`]
/// of a [`Request`] to opt that request out of being retried.
///
/// The marker only has an effect when the [`ManagedPolicy`] is used.
///
/// [`Extensions`]: rama_core::context::Extensions
#[derive(Debug, Clone, Default)]
#[non_exhaustive]
pub struct DoNotRetry;
21
22/// A managed retry [`Policy`],
23/// which allows for an easier interface to configure retrying requests.
24///
25/// [`DoNotRetry`] can be added to the [`Context`] of a [`Request`]
26/// to signal that the request should not be retried, regardless
27/// of the retry functionality defined.
28pub struct ManagedPolicy<B = Undefined, C = Undefined, R = Undefined> {
29    backoff: B,
30    clone: C,
31    retry: R,
32}
33
34impl<B, C, R, State, Response, Error> Policy<State, Response, Error> for ManagedPolicy<B, C, R>
35where
36    B: Backoff,
37    C: CloneInput<State>,
38    R: RetryRule<State, Response, Error>,
39    State: Clone + Send + Sync + 'static,
40    Response: Send + 'static,
41    Error: Send + 'static,
42{
43    async fn retry(
44        &self,
45        ctx: Context<State>,
46        req: Request<RetryBody>,
47        result: Result<Response, Error>,
48    ) -> PolicyResult<State, Response, Error> {
49        if ctx.get::<DoNotRetry>().is_some() {
50            // Custom extension to signal that the request should not be retried.
51            return PolicyResult::Abort(result);
52        }
53
54        let (ctx, result, retry) = self.retry.retry(ctx, result).await;
55        if retry && self.backoff.next_backoff().await {
56            PolicyResult::Retry { ctx, req }
57        } else {
58            self.backoff.reset().await;
59            PolicyResult::Abort(result)
60        }
61    }
62
63    fn clone_input(
64        &self,
65        ctx: &Context<State>,
66        req: &Request<RetryBody>,
67    ) -> Option<(Context<State>, Request<RetryBody>)> {
68        if ctx.get::<DoNotRetry>().is_some() {
69            None
70        } else {
71            self.clone.clone_input(ctx, req)
72        }
73    }
74}
75
76impl<B, C, R> std::fmt::Debug for ManagedPolicy<B, C, R>
77where
78    B: std::fmt::Debug,
79    C: std::fmt::Debug,
80    R: std::fmt::Debug,
81{
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        f.debug_struct("ManagedPolicy")
84            .field("backoff", &self.backoff)
85            .field("clone", &self.clone)
86            .field("retry", &self.retry)
87            .finish()
88    }
89}
90
91impl<B, C, R> Clone for ManagedPolicy<B, C, R>
92where
93    B: Clone,
94    C: Clone,
95    R: Clone,
96{
97    fn clone(&self) -> Self {
98        ManagedPolicy {
99            backoff: self.backoff.clone(),
100            clone: self.clone.clone(),
101            retry: self.retry.clone(),
102        }
103    }
104}
105
106impl Default for ManagedPolicy<Undefined, Undefined, Undefined> {
107    fn default() -> Self {
108        ManagedPolicy {
109            backoff: Undefined,
110            clone: Undefined,
111            retry: Undefined,
112        }
113    }
114}
115
116impl<F> ManagedPolicy<Undefined, Undefined, F> {
117    /// Create a new [`ManagedPolicy`] which uses the provided
118    /// function to determine if a request should be retried.
119    ///
120    /// The default cloning is used and no backoff is applied.
121    #[inline]
122    pub fn new(retry: F) -> Self {
123        ManagedPolicy::default().with_retry(retry)
124    }
125}
126
127impl<C, R> ManagedPolicy<Undefined, C, R> {
128    /// add a backoff to this [`ManagedPolicy`].
129    pub fn with_backoff<B>(self, backoff: B) -> ManagedPolicy<B, C, R> {
130        ManagedPolicy {
131            backoff,
132            clone: self.clone,
133            retry: self.retry,
134        }
135    }
136}
137
138impl<B, R> ManagedPolicy<B, Undefined, R> {
139    /// add a cloning function to this [`ManagedPolicy`].
140    /// to determine if a request should be cloned
141    pub fn with_clone<C>(self, clone: C) -> ManagedPolicy<B, C, R> {
142        ManagedPolicy {
143            backoff: self.backoff,
144            clone,
145            retry: self.retry,
146        }
147    }
148}
149
150impl<B, C> ManagedPolicy<B, C, Undefined> {
151    /// add a retry function to this [`ManagedPolicy`].
152    /// to determine if a request should be retried.
153    pub fn with_retry<R>(self, retry: R) -> ManagedPolicy<B, C, R> {
154        ManagedPolicy {
155            backoff: self.backoff,
156            clone: self.clone,
157            retry,
158        }
159    }
160}
161
162/// A trait that is used to umbrella-cover all possible
163/// implementation kinds for the retry rule functionality.
164pub trait RetryRule<S, R, E>: private::Sealed<(S, R, E)> + Send + Sync + 'static {
165    /// Check if the given result should be retried.
166    fn retry(
167        &self,
168        ctx: Context<S>,
169        result: Result<R, E>,
170    ) -> impl Future<Output = (Context<S>, Result<R, E>, bool)> + Send + '_;
171}
172
173impl<S, Body, E> RetryRule<S, Response<Body>, E> for Undefined
174where
175    S: Clone + Send + Sync + 'static,
176    E: std::fmt::Debug + Send + Sync + 'static,
177    Body: Send + 'static,
178{
179    async fn retry(
180        &self,
181        ctx: Context<S>,
182        result: Result<Response<Body>, E>,
183    ) -> (Context<S>, Result<Response<Body>, E>, bool) {
184        match &result {
185            Ok(response) => {
186                let status = response.status();
187                if status.is_server_error() {
188                    tracing::debug!(
189                        "retrying server error http status code: {status} ({})",
190                        status.as_u16()
191                    );
192                    (ctx, result, true)
193                } else {
194                    (ctx, result, false)
195                }
196            }
197            Err(error) => {
198                tracing::debug!("retrying error: {:?}", error);
199                (ctx, result, true)
200            }
201        }
202    }
203}
204
205impl<F, Fut, S, R, E> RetryRule<S, R, E> for F
206where
207    F: Fn(Context<S>, Result<R, E>) -> Fut + Send + Sync + 'static,
208    Fut: Future<Output = (Context<S>, Result<R, E>, bool)> + Send + 'static,
209    S: Clone + Send + Sync + 'static,
210    R: Send + 'static,
211    E: Send + Sync + 'static,
212{
213    async fn retry(
214        &self,
215        ctx: Context<S>,
216        result: Result<R, E>,
217    ) -> (Context<S>, Result<R, E>, bool) {
218        self(ctx, result).await
219    }
220}
221
222/// A trait that is used to umbrella-cover all possible
223/// implementation kinds for the cloning functionality.
224pub trait CloneInput<S>: private::Sealed<(S,)> + Send + Sync + 'static {
225    /// Clone the input request if necessary.
226    ///
227    /// See [`Policy::clone_input`] for more details.
228    ///
229    /// [`Policy::clone_input`]: super::Policy::clone_input
230    fn clone_input(
231        &self,
232        ctx: &Context<S>,
233        req: &Request<RetryBody>,
234    ) -> Option<(Context<S>, Request<RetryBody>)>;
235}
236
237impl<S: Clone> CloneInput<S> for Undefined {
238    fn clone_input(
239        &self,
240        ctx: &Context<S>,
241        req: &Request<RetryBody>,
242    ) -> Option<(Context<S>, Request<RetryBody>)> {
243        Some((ctx.clone(), req.clone()))
244    }
245}
246
247impl<F, S> CloneInput<S> for F
248where
249    F: Fn(&Context<S>, &Request<RetryBody>) -> Option<(Context<S>, Request<RetryBody>)>
250        + Send
251        + Sync
252        + 'static,
253{
254    fn clone_input(
255        &self,
256        ctx: &Context<S>,
257        req: &Request<RetryBody>,
258    ) -> Option<(Context<S>, Request<RetryBody>)> {
259        self(ctx, req)
260    }
261}
262
/// Placeholder type standing in for a [`ManagedPolicy`] component
/// that the user did not provide a specific type for.
///
/// It is the default for every type parameter of [`ManagedPolicy`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Undefined;
269
270impl std::fmt::Display for Undefined {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        write!(f, "Undefined")
273    }
274}
275
276impl Backoff for Undefined {
277    async fn next_backoff(&self) -> bool {
278        true
279    }
280
281    async fn reset(&self) {}
282}
283
284mod private {
285    use super::*;
286
287    pub trait Sealed<S> {}
288
289    impl<S> Sealed<S> for Undefined {}
290    impl<F, S> Sealed<(S,)> for F where
291        F: Fn(&Context<S>, &Request<RetryBody>) -> Option<(Context<S>, Request<RetryBody>)>
292            + Send
293            + Sync
294            + 'static
295    {
296    }
297    impl<F, Fut, S, R, E> Sealed<(S, R, E)> for F
298    where
299        F: Fn(Context<S>, Result<R, E>) -> Fut + Send + Sync + 'static,
300        Fut: Future<Output = (Context<S>, Result<R, E>, bool)> + Send + 'static,
301    {
302    }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{StatusCode, service::web::response::IntoResponse};
    use rama_utils::{backoff::ExponentialBackoff, rng::HasherRng};
    use std::time::Duration;

    /// Build the plain `GET http://example.com` request used by every test.
    fn example_request() -> Request<RetryBody> {
        Request::builder()
            .method("GET")
            .uri("http://example.com")
            .body(RetryBody::empty())
            .unwrap()
    }

    /// Custom clone function that never clones the input.
    fn never_clone<S>(
        _: &Context<S>,
        _: &Request<RetryBody>,
    ) -> Option<(Context<S>, Request<RetryBody>)> {
        None
    }

    /// Custom retry rule that retries errors only, regardless of the response.
    async fn retry_on_error<S, R, E>(
        ctx: Context<S>,
        result: Result<R, E>,
    ) -> (Context<S>, Result<R, E>, bool) {
        let failed = result.is_err();
        (ctx, result, failed)
    }

    fn assert_clone_input_none(
        ctx: &Context<()>,
        req: &Request<RetryBody>,
        policy: &impl Policy<(), Response, ()>,
    ) {
        let cloned = policy.clone_input(ctx, req);
        assert!(cloned.is_none());
    }

    fn assert_clone_input_some(
        ctx: &Context<()>,
        req: &Request<RetryBody>,
        policy: &impl Policy<(), Response, ()>,
    ) {
        let cloned = policy.clone_input(ctx, req);
        assert!(cloned.is_some());
    }

    async fn assert_retry(
        ctx: Context<()>,
        req: Request<RetryBody>,
        result: Result<Response, ()>,
        policy: &impl Policy<(), Response, ()>,
    ) {
        let outcome = policy.retry(ctx, req, result).await;
        assert!(
            matches!(outcome, PolicyResult::Retry { .. }),
            "expected retry"
        );
    }

    async fn assert_abort(
        ctx: Context<()>,
        req: Request<RetryBody>,
        result: Result<Response, ()>,
        policy: &impl Policy<(), Response, ()>,
    ) {
        let outcome = policy.retry(ctx, req, result).await;
        assert!(matches!(outcome, PolicyResult::Abort(_)), "expected abort");
    }

    #[tokio::test]
    async fn managed_policy_default() {
        let request = example_request();
        let policy = ManagedPolicy::default();

        assert_clone_input_some(&Context::default(), &request, &policy);

        // HTTP Ok is not retried
        assert_abort(
            Context::default(),
            request.clone(),
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;

        // HTTP InternalServerError is retried
        assert_retry(
            Context::default(),
            request.clone(),
            Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
            &policy,
        )
        .await;

        // any error case is retried as well
        assert_retry(Context::default(), request, Err(()), &policy).await;
    }

    #[tokio::test]
    async fn managed_policy_default_do_not_retry() {
        let req = example_request();
        let policy = ManagedPolicy::default();

        let mut ctx = Context::default();
        ctx.insert(DoNotRetry);

        assert_clone_input_none(&ctx, &req, &policy);

        // HTTP Ok is (of course) not retried
        assert_abort(
            ctx.clone(),
            req.clone(),
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;

        // HTTP InternalServerError is not retried either
        assert_abort(
            ctx.clone(),
            req.clone(),
            Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
            &policy,
        )
        .await;

        // nor is any error case
        assert_abort(ctx, req, Err(()), &policy).await;
    }

    #[tokio::test]
    async fn test_policy_custom_clone_fn() {
        let req = example_request();
        let policy = ManagedPolicy::default().with_clone(never_clone);

        assert_clone_input_none(&Context::default(), &req, &policy);

        // the retry rule should still be the default one
        assert_abort(
            Context::default(),
            req,
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;
    }

    #[tokio::test]
    async fn test_policy_custom_retry_fn() {
        let req = example_request();
        let policy = ManagedPolicy::new(retry_on_error);

        // the default clone should be used
        assert_clone_input_some(&Context::default(), &req, &policy);

        // the retry rule should be the custom one
        assert_abort(
            Context::default(),
            req.clone(),
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;
        assert_abort(
            Context::default(),
            req.clone(),
            Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
            &policy,
        )
        .await;
        assert_retry(Context::default(), req, Err(()), &policy).await;
    }

    #[tokio::test]
    async fn test_policy_fully_custom() {
        let req = example_request();

        let backoff = ExponentialBackoff::new(
            Duration::from_millis(1),
            Duration::from_millis(5),
            0.1,
            HasherRng::default,
        )
        .unwrap();

        let policy = ManagedPolicy::default()
            .with_backoff(backoff)
            .with_clone(never_clone)
            .with_retry(retry_on_error);

        assert_clone_input_none(&Context::default(), &req, &policy);

        // the retry rule should be the custom one
        assert_abort(
            Context::default(),
            req.clone(),
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;
        assert_abort(
            Context::default(),
            req.clone(),
            Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
            &policy,
        )
        .await;
        assert_retry(Context::default(), req, Err(()), &policy).await;
    }
}