1use super::{Policy, PolicyResult, RetryBody};
8use crate::{Request, Response};
9use rama_core::Context;
10use rama_utils::backoff::Backoff;
11
/// Marker extension that opts a single request out of retries.
///
/// When a [`DoNotRetry`] value is present in the request [`Context`], a
/// [`ManagedPolicy`] refuses to copy the input for another attempt and aborts
/// with the current result, regardless of what its retry rule would decide.
///
/// The type is `#[non_exhaustive]`, so code outside this crate creates it with
/// [`DoNotRetry::default`].
#[non_exhaustive]
#[derive(Default, Clone, Debug)]
pub struct DoNotRetry;
21
/// A retry [`Policy`] assembled from three independently configurable parts.
///
/// - `B`: the [`Backoff`] asked for permission (`next_backoff`) before each
///   retry, and `reset` when the retry rule or backoff ends the sequence.
/// - `C`: the [`CloneInput`] strategy that copies the context and request so
///   they can be used for another attempt.
/// - `R`: the [`RetryRule`] deciding whether a completed attempt is retried.
///
/// Every part defaults to [`Undefined`], which supplies the built-in behaviour:
/// - the backoff grants every attempt immediately, with no limit;
/// - the clone strategy does a plain clone of the context and request;
/// - the HTTP retry rule retries 5xx responses and every `Err`.
///
/// Inserting [`DoNotRetry`] into the request [`Context`] disables retries for
/// that request, regardless of the configured parts.
pub struct ManagedPolicy<B = Undefined, C = Undefined, R = Undefined> {
    // Consulted before each retry; reset when the policy aborts after the rule ran.
    backoff: B,
    // Produces the (context, request) copy used for the next attempt.
    clone: C,
    // Decides whether a completed attempt should be retried.
    retry: R,
}
33
34impl<B, C, R, State, Response, Error> Policy<State, Response, Error> for ManagedPolicy<B, C, R>
35where
36 B: Backoff,
37 C: CloneInput<State>,
38 R: RetryRule<State, Response, Error>,
39 State: Clone + Send + Sync + 'static,
40 Response: Send + 'static,
41 Error: Send + 'static,
42{
43 async fn retry(
44 &self,
45 ctx: Context<State>,
46 req: Request<RetryBody>,
47 result: Result<Response, Error>,
48 ) -> PolicyResult<State, Response, Error> {
49 if ctx.get::<DoNotRetry>().is_some() {
50 return PolicyResult::Abort(result);
52 }
53
54 let (ctx, result, retry) = self.retry.retry(ctx, result).await;
55 if retry && self.backoff.next_backoff().await {
56 PolicyResult::Retry { ctx, req }
57 } else {
58 self.backoff.reset().await;
59 PolicyResult::Abort(result)
60 }
61 }
62
63 fn clone_input(
64 &self,
65 ctx: &Context<State>,
66 req: &Request<RetryBody>,
67 ) -> Option<(Context<State>, Request<RetryBody>)> {
68 if ctx.get::<DoNotRetry>().is_some() {
69 None
70 } else {
71 self.clone.clone_input(ctx, req)
72 }
73 }
74}
75
76impl<B, C, R> std::fmt::Debug for ManagedPolicy<B, C, R>
77where
78 B: std::fmt::Debug,
79 C: std::fmt::Debug,
80 R: std::fmt::Debug,
81{
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 f.debug_struct("ManagedPolicy")
84 .field("backoff", &self.backoff)
85 .field("clone", &self.clone)
86 .field("retry", &self.retry)
87 .finish()
88 }
89}
90
91impl<B, C, R> Clone for ManagedPolicy<B, C, R>
92where
93 B: Clone,
94 C: Clone,
95 R: Clone,
96{
97 fn clone(&self) -> Self {
98 ManagedPolicy {
99 backoff: self.backoff.clone(),
100 clone: self.clone.clone(),
101 retry: self.retry.clone(),
102 }
103 }
104}
105
106impl Default for ManagedPolicy<Undefined, Undefined, Undefined> {
107 fn default() -> Self {
108 ManagedPolicy {
109 backoff: Undefined,
110 clone: Undefined,
111 retry: Undefined,
112 }
113 }
114}
115
116impl<F> ManagedPolicy<Undefined, Undefined, F> {
117 #[inline]
122 pub fn new(retry: F) -> Self {
123 ManagedPolicy::default().with_retry(retry)
124 }
125}
126
127impl<C, R> ManagedPolicy<Undefined, C, R> {
128 pub fn with_backoff<B>(self, backoff: B) -> ManagedPolicy<B, C, R> {
130 ManagedPolicy {
131 backoff,
132 clone: self.clone,
133 retry: self.retry,
134 }
135 }
136}
137
138impl<B, R> ManagedPolicy<B, Undefined, R> {
139 pub fn with_clone<C>(self, clone: C) -> ManagedPolicy<B, C, R> {
142 ManagedPolicy {
143 backoff: self.backoff,
144 clone,
145 retry: self.retry,
146 }
147 }
148}
149
150impl<B, C> ManagedPolicy<B, C, Undefined> {
151 pub fn with_retry<R>(self, retry: R) -> ManagedPolicy<B, C, R> {
154 ManagedPolicy {
155 backoff: self.backoff,
156 clone: self.clone,
157 retry,
158 }
159 }
160}
161
162pub trait RetryRule<S, R, E>: private::Sealed<(S, R, E)> + Send + Sync + 'static {
165 fn retry(
167 &self,
168 ctx: Context<S>,
169 result: Result<R, E>,
170 ) -> impl Future<Output = (Context<S>, Result<R, E>, bool)> + Send + '_;
171}
172
173impl<S, Body, E> RetryRule<S, Response<Body>, E> for Undefined
174where
175 S: Clone + Send + Sync + 'static,
176 E: std::fmt::Debug + Send + Sync + 'static,
177 Body: Send + 'static,
178{
179 async fn retry(
180 &self,
181 ctx: Context<S>,
182 result: Result<Response<Body>, E>,
183 ) -> (Context<S>, Result<Response<Body>, E>, bool) {
184 match &result {
185 Ok(response) => {
186 let status = response.status();
187 if status.is_server_error() {
188 tracing::debug!(
189 "retrying server error http status code: {status} ({})",
190 status.as_u16()
191 );
192 (ctx, result, true)
193 } else {
194 (ctx, result, false)
195 }
196 }
197 Err(error) => {
198 tracing::debug!("retrying error: {:?}", error);
199 (ctx, result, true)
200 }
201 }
202 }
203}
204
205impl<F, Fut, S, R, E> RetryRule<S, R, E> for F
206where
207 F: Fn(Context<S>, Result<R, E>) -> Fut + Send + Sync + 'static,
208 Fut: Future<Output = (Context<S>, Result<R, E>, bool)> + Send + 'static,
209 S: Clone + Send + Sync + 'static,
210 R: Send + 'static,
211 E: Send + Sync + 'static,
212{
213 async fn retry(
214 &self,
215 ctx: Context<S>,
216 result: Result<R, E>,
217 ) -> (Context<S>, Result<R, E>, bool) {
218 self(ctx, result).await
219 }
220}
221
222pub trait CloneInput<S>: private::Sealed<(S,)> + Send + Sync + 'static {
225 fn clone_input(
231 &self,
232 ctx: &Context<S>,
233 req: &Request<RetryBody>,
234 ) -> Option<(Context<S>, Request<RetryBody>)>;
235}
236
237impl<S: Clone> CloneInput<S> for Undefined {
238 fn clone_input(
239 &self,
240 ctx: &Context<S>,
241 req: &Request<RetryBody>,
242 ) -> Option<(Context<S>, Request<RetryBody>)> {
243 Some((ctx.clone(), req.clone()))
244 }
245}
246
247impl<F, S> CloneInput<S> for F
248where
249 F: Fn(&Context<S>, &Request<RetryBody>) -> Option<(Context<S>, Request<RetryBody>)>
250 + Send
251 + Sync
252 + 'static,
253{
254 fn clone_input(
255 &self,
256 ctx: &Context<S>,
257 req: &Request<RetryBody>,
258 ) -> Option<(Context<S>, Request<RetryBody>)> {
259 self(ctx, req)
260 }
261}
262
/// Placeholder for a [`ManagedPolicy`] part that was not configured.
///
/// It provides the default behaviour for each slot:
/// - as a [`Backoff`] it grants every attempt immediately;
/// - as a [`CloneInput`] it clones the context and request;
/// - as a [`RetryRule`] for HTTP responses it retries 5xx responses and every
///   `Err`.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Undefined;

impl std::fmt::Display for Undefined {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Undefined")
    }
}
275
276impl Backoff for Undefined {
277 async fn next_backoff(&self) -> bool {
278 true
279 }
280
281 async fn reset(&self) {}
282}
283
284mod private {
285 use super::*;
286
287 pub trait Sealed<S> {}
288
289 impl<S> Sealed<S> for Undefined {}
290 impl<F, S> Sealed<(S,)> for F where
291 F: Fn(&Context<S>, &Request<RetryBody>) -> Option<(Context<S>, Request<RetryBody>)>
292 + Send
293 + Sync
294 + 'static
295 {
296 }
297 impl<F, Fut, S, R, E> Sealed<(S, R, E)> for F
298 where
299 F: Fn(Context<S>, Result<R, E>) -> Fut + Send + Sync + 'static,
300 Fut: Future<Output = (Context<S>, Result<R, E>, bool)> + Send + 'static,
301 {
302 }
303}
304
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{StatusCode, service::web::response::IntoResponse};
    use rama_utils::{backoff::ExponentialBackoff, rng::HasherRng};
    use std::time::Duration;

    /// Builds the `GET http://example.com` request (empty body) shared by all tests.
    fn test_request() -> Request<RetryBody> {
        Request::builder()
            .method("GET")
            .uri("http://example.com")
            .body(RetryBody::empty())
            .unwrap()
    }

    /// Clone strategy that never produces a copy of the input.
    fn never_clone<S>(
        _: &Context<S>,
        _: &Request<RetryBody>,
    ) -> Option<(Context<S>, Request<RetryBody>)> {
        None
    }

    /// Retry rule that asks for a retry on every `Err` and never on `Ok`.
    async fn retry_on_err<S, R, E>(
        ctx: Context<S>,
        result: Result<R, E>,
    ) -> (Context<S>, Result<R, E>, bool) {
        let retry = result.is_err();
        (ctx, result, retry)
    }

    fn assert_clone_input_none(
        ctx: &Context<()>,
        req: &Request<RetryBody>,
        policy: &impl Policy<(), Response, ()>,
    ) {
        assert!(policy.clone_input(ctx, req).is_none());
    }

    fn assert_clone_input_some(
        ctx: &Context<()>,
        req: &Request<RetryBody>,
        policy: &impl Policy<(), Response, ()>,
    ) {
        assert!(policy.clone_input(ctx, req).is_some());
    }

    async fn assert_retry(
        ctx: Context<()>,
        req: Request<RetryBody>,
        result: Result<Response, ()>,
        policy: &impl Policy<(), Response, ()>,
    ) {
        match policy.retry(ctx, req, result).await {
            PolicyResult::Retry { .. } => (),
            PolicyResult::Abort(_) => panic!("expected retry"),
        };
    }

    async fn assert_abort(
        ctx: Context<()>,
        req: Request<RetryBody>,
        result: Result<Response, ()>,
        policy: &impl Policy<(), Response, ()>,
    ) {
        match policy.retry(ctx, req, result).await {
            PolicyResult::Retry { .. } => panic!("expected abort"),
            PolicyResult::Abort(_) => (),
        };
    }

    #[tokio::test]
    async fn managed_policy_default() {
        let request = test_request();
        let policy = ManagedPolicy::default();

        assert_clone_input_some(&Context::default(), &request, &policy);

        assert_abort(
            Context::default(),
            request.clone(),
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;

        assert_retry(
            Context::default(),
            request.clone(),
            Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
            &policy,
        )
        .await;

        assert_retry(Context::default(), request, Err(()), &policy).await;
    }

    #[tokio::test]
    async fn managed_policy_default_do_not_retry() {
        let req = test_request();
        let policy = ManagedPolicy::default();

        let mut ctx = Context::default();
        ctx.insert(DoNotRetry);

        assert_clone_input_none(&ctx, &req, &policy);

        assert_abort(
            ctx.clone(),
            req.clone(),
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;

        assert_abort(
            ctx.clone(),
            req.clone(),
            Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
            &policy,
        )
        .await;

        assert_abort(ctx, req, Err(()), &policy).await;
    }

    #[tokio::test]
    async fn test_policy_custom_clone_fn() {
        let req = test_request();
        let policy = ManagedPolicy::default().with_clone(never_clone);

        assert_clone_input_none(&Context::default(), &req, &policy);

        assert_abort(
            Context::default(),
            req,
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;
    }

    #[tokio::test]
    async fn test_policy_custom_retry_fn() {
        let req = test_request();
        let policy = ManagedPolicy::new(retry_on_err);

        assert_clone_input_some(&Context::default(), &req, &policy);

        assert_abort(
            Context::default(),
            req.clone(),
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;
        assert_abort(
            Context::default(),
            req.clone(),
            Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
            &policy,
        )
        .await;
        assert_retry(Context::default(), req, Err(()), &policy).await;
    }

    #[tokio::test]
    async fn test_policy_custom_retry_fn_do_not_retry() {
        let req = test_request();
        let policy = ManagedPolicy::new(retry_on_err);

        let mut ctx = Context::default();
        ctx.insert(DoNotRetry);

        // The marker must win even though the custom rule would retry an `Err`.
        assert_clone_input_none(&ctx, &req, &policy);
        assert_abort(ctx, req, Err(()), &policy).await;
    }

    #[tokio::test]
    async fn test_policy_fully_custom() {
        let req = test_request();

        let backoff = ExponentialBackoff::new(
            Duration::from_millis(1),
            Duration::from_millis(5),
            0.1,
            HasherRng::default,
        )
        .unwrap();

        let policy = ManagedPolicy::default()
            .with_backoff(backoff)
            .with_clone(never_clone)
            .with_retry(retry_on_err);

        assert_clone_input_none(&Context::default(), &req, &policy);

        assert_abort(
            Context::default(),
            req.clone(),
            Ok(StatusCode::OK.into_response()),
            &policy,
        )
        .await;
        assert_abort(
            Context::default(),
            req.clone(),
            Ok(StatusCode::INTERNAL_SERVER_ERROR.into_response()),
            &policy,
        )
        .await;
        assert_retry(Context::default(), req, Err(()), &policy).await;
    }
}