1use async_trait::async_trait;
5use http::Extensions;
6use rand::{rng, Rng};
7use reqwest::{Request, Response, StatusCode};
8use reqwest_middleware::{Middleware, Next, Result};
9use std::time::Duration;
10
/// Middleware that re-sends failed requests with exponential backoff and jitter.
///
/// Instances are normally built through `CustomRetryMiddleware::new`, which
/// also bounds the retry count.
#[derive(Debug, Clone)]
pub struct CustomRetryMiddleware {
    /// Maximum number of retries performed after the initial attempt.
    max_retries: u32,
    /// Upper bound, in milliseconds, on the backoff delay before jitter is applied.
    max_delay_ms: u64,
    /// Backoff delay, in milliseconds, used for the first retry; doubled for each later retry.
    initial_delay_ms: u64,
}
17
18#[async_trait]
19impl Middleware for CustomRetryMiddleware {
20 async fn handle(
21 &self,
22 req: Request,
23 extensions: &mut Extensions,
24 next: Next<'_>,
25 ) -> Result<Response> {
26 self.execute_with_retry(req, next, extensions).await
27 }
28}
29
30impl CustomRetryMiddleware {
31 pub fn new(max_retries: u32, max_delay_ms: u64, initial_delay_ms: u64) -> Self {
33 Self {
34 max_retries: max_retries.min(10),
35 max_delay_ms,
36 initial_delay_ms,
37 }
38 }
39
40 async fn execute_with_retry<'a>(
41 &'a self,
42 req: Request,
43 next: Next<'a>,
44 ext: &'a mut Extensions,
45 ) -> Result<Response> {
46 let mut n_past_retries = 0;
47 let mut last_req_401 = false;
48 loop {
49 let duplicate_request = match req.try_clone() {
50 Some(x) => x,
51 None => return next.run(req, ext).await,
52 };
53
54 let result = next.clone().run(duplicate_request, ext).await;
55
56 break match Retryable::from_reqwest_response(&result) {
58 Some(retryable)
59 if (retryable == Retryable::Transient
60 || retryable == Retryable::Unauthorized && !last_req_401)
61 && n_past_retries < self.max_retries =>
62 {
63 last_req_401 = retryable == Retryable::Unauthorized;
64 let mut retry_delay = self.initial_delay_ms * 2u64.pow(n_past_retries);
67 if retry_delay > self.max_delay_ms {
68 retry_delay = self.max_delay_ms;
69 }
70 retry_delay = retry_delay / 4 * 3 + rng().random_range(0..=(retry_delay / 2));
72 futures_timer::Delay::new(Duration::from_millis(retry_delay)).await;
73 n_past_retries += 1;
74 continue;
75 }
76 Some(_) | None => result,
77 };
78 }
79 }
80}
81
/// Classification of a request outcome with respect to retrying.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Retryable {
    /// The failure is considered temporary; the retry middleware retries it.
    Transient,
    /// The failure is not retried.
    Fatal,
    /// The server answered `401 Unauthorized`; the retry middleware retries
    /// it unless the previous retried outcome was also unauthorized.
    Unauthorized,
}
91
92impl Retryable {
93 pub fn from_reqwest_response(
101 res: &reqwest_middleware::Result<reqwest::Response>,
102 ) -> Option<Self> {
103 match res {
104 Ok(success) => {
105 let status = success.status();
106 if status.is_success() {
107 None
108 } else if status == StatusCode::UNAUTHORIZED {
109 Some(Retryable::Unauthorized)
110 } else if status.is_server_error()
111 || status == StatusCode::REQUEST_TIMEOUT
112 || status == StatusCode::TOO_MANY_REQUESTS
113 || success
114 .headers()
115 .get("cdf-is-auto-retryable")
116 .and_then(|v| v.to_str().ok())
117 .is_some_and(|v| v == "true")
118 {
119 Some(Retryable::Transient)
120 } else {
121 Some(Retryable::Fatal)
122 }
123 }
124 Err(error) => match error {
125 reqwest_middleware::Error::Middleware(_) => Some(Retryable::Fatal),
126 reqwest_middleware::Error::Reqwest(error) => {
127 if error.is_timeout() || error.is_connect() {
128 Some(Retryable::Transient)
129 } else if error.is_body()
130 || error.is_decode()
131 || error.is_builder()
132 || error.is_redirect()
133 || error.is_request()
134 {
135 Some(Retryable::Fatal)
136 } else {
137 None
141 }
142 }
143 },
144 }
145 }
146}