reqwest_retry_after/
lib.rs1#![warn(missing_docs)]
19#![warn(rustdoc::missing_doc_code_examples)]
20
21use std::{
22 collections::HashMap,
23 time::{Duration, SystemTime},
24};
25
26use http::{header::RETRY_AFTER, Extensions};
27use reqwest::Url;
28use reqwest_middleware::{
29 reqwest::{Request, Response},
30 Middleware, Next, Result,
31};
32use time::{format_description::well_known::Rfc2822, OffsetDateTime};
33use tokio::sync::RwLock;
34
35pub struct RetryAfterMiddleware {
38 retry_after: RwLock<HashMap<Url, SystemTime>>,
39}
40
41impl RetryAfterMiddleware {
42 pub fn new() -> Self {
44 Self {
45 retry_after: RwLock::new(HashMap::new()),
46 }
47 }
48}
49
50impl Default for RetryAfterMiddleware {
51 fn default() -> Self {
52 Self::new()
53 }
54}
55
56fn parse_retry_value(val: &str) -> Option<SystemTime> {
57 if let Ok(secs) = val.parse::<u64>() {
58 return Some(SystemTime::now() + Duration::from_secs(secs));
59 }
60 if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
61 return Some(date.into());
62 }
63 None
64}
65
66#[async_trait::async_trait]
67impl Middleware for RetryAfterMiddleware {
68 async fn handle(
69 &self,
70 req: Request,
71 extensions: &mut Extensions,
72 next: Next<'_>,
73 ) -> Result<Response> {
74 let url = req.url().clone();
75
76 if let Some(timestamp) = self.retry_after.read().await.get(&url) {
77 let now = SystemTime::now();
78
79 if let Ok(duration) = timestamp.duration_since(now) {
80 tokio::time::sleep(duration).await;
81 }
82 }
83
84 let res = next.run(req, extensions).await;
85
86 if let Ok(res) = &res {
87 match res.headers().get(RETRY_AFTER) {
88 Some(retry_after) => {
89 if let Ok(val) = retry_after.to_str() {
90 if let Some(timestamp) = parse_retry_value(val) {
91 self.retry_after
92 .write()
93 .await
94 .insert(url.clone(), timestamp);
95 }
96 }
97 }
98 _ => {
99 self.retry_after.write().await.remove(&url);
100 }
101 }
102 }
103 res
104 }
105}
106
#[cfg(test)]
mod test {
    use std::{
        str::FromStr,
        sync::Arc,
        time::{Duration, SystemTime},
    };

    use httpmock::{Method::GET, MockServer};
    use reqwest::Url;
    use reqwest_middleware::ClientBuilder;
    use time::{format_description::well_known::Rfc2822, OffsetDateTime};

    use crate::RetryAfterMiddleware;

    /// Delay-seconds form. This test checks that:
    /// * a `Retry-After: 2` response records a deadline for its URL;
    /// * requests to other URLs are not delayed;
    /// * the next request to the same URL waits for the deadline;
    /// * a response without the header clears the entry.
    #[tokio::test]
    async fn test() {
        let test_duration = Duration::from_secs(2);
        let middleware = Arc::new(RetryAfterMiddleware::new());

        let client = ClientBuilder::new(reqwest::Client::new())
            .with_arc(middleware.clone())
            .build();

        test_empty_retry_after(&middleware).await;

        let server = MockServer::start();
        let pre_ra_mock = server.mock(|when, then| {
            when.method(GET).path("/").header("RA", "true");
            then.status(200)
                .header("Retry-After", test_duration.as_secs().to_string())
                .body("");
        });
        let post_ra_mock = server.mock(|when, then| {
            when.method(GET).path("/");
            then.status(200).body("");
        });
        let normal_mock = server.mock(|when, then| {
            when.method(GET).path("/normal");
            then.status(200).body("");
        });

        let url = Url::from_str(&server.url("/")).unwrap();

        let pre_test = SystemTime::now();
        client
            .get(url.clone())
            .header("RA", "true")
            .send()
            .await
            .unwrap();
        pre_ra_mock.assert_async().await;
        test_valid_retry_after(&middleware, &url, pre_test, test_duration).await;

        // A different URL must not be held back by the deadline recorded for `url`.
        let normal = Url::from_str(&server.url("/normal")).unwrap();
        let before_normal = SystemTime::now();
        client.get(normal.clone()).send().await.unwrap();
        normal_mock.assert_async().await;
        assert!(
            SystemTime::now()
                .duration_since(before_normal)
                .unwrap()
                .as_secs_f64()
                <= 0.2
        );
        test_absent_retry_after(&middleware, &normal).await;

        // This request must wait for the deadline; its response has no
        // Retry-After header, so the entry is cleared afterwards.
        client.get(url.clone()).send().await.unwrap();
        post_ra_mock.assert_async().await;

        let post_test = SystemTime::now();
        assert!(post_test.duration_since(pre_test).unwrap() >= test_duration);
        test_empty_retry_after(&middleware).await;
    }

    /// Same flow as `test`, but the header carries an RFC 2822 date.
    #[tokio::test]
    async fn test_rfc2822() {
        let test_duration = Duration::from_secs(2);

        let server = MockServer::start();
        let middleware = Arc::new(RetryAfterMiddleware::new());
        let client = ClientBuilder::new(reqwest::Client::new())
            .with_arc(middleware.clone())
            .build();

        // RFC 2822 dates have whole-second resolution. Truncating nanoseconds
        // makes `ra` survive the round trip through the header exactly. The
        // extra second keeps `begin` after the real current time.
        let begin =
            OffsetDateTime::now_utc().replace_nanosecond(0).unwrap() + Duration::from_secs(1);
        let ra = begin + test_duration;

        let ra_mock = server.mock(|when, then| {
            when.method(GET).path("/").header("RA", "true");
            then.status(200)
                .header("Retry-After", ra.format(&Rfc2822).unwrap())
                .body("");
        });
        let no_ra_mock = server.mock(|when, then| {
            when.method(GET).path("/");
            then.status(200).body("");
        });

        let url = Url::from_str(&server.url("/")).unwrap();
        client
            .get(url.clone())
            .header("RA", "true")
            .send()
            .await
            .unwrap();
        // Measure from `begin` rather than `SystemTime::now()`: `ra - begin` is
        // exactly `test_duration`. `ra - now()` drops below it whenever the
        // request finishes late in the current wall-clock second, which made
        // this assertion flaky.
        test_valid_retry_after(&middleware, &url, begin.into(), test_duration).await;
        ra_mock.assert_async().await;

        client.get(url.clone()).send().await.unwrap();
        no_ra_mock.assert_async().await;

        let duration = SystemTime::now().duration_since(begin.into()).unwrap();
        assert!(duration >= test_duration);
        test_empty_retry_after(&middleware).await;
    }

    /// Asserts that a deadline is recorded for `url` and lies at least
    /// `test_duration` after `now`.
    async fn test_valid_retry_after(
        middleware: &Arc<RetryAfterMiddleware>,
        url: &Url,
        now: SystemTime,
        test_duration: Duration,
    ) {
        let time = middleware
            .retry_after
            .read()
            .await
            .get(url)
            .cloned()
            .unwrap();
        let duration = time.duration_since(now).unwrap();
        assert!(duration >= test_duration);
    }

    /// Asserts that no deadline is recorded for `url`.
    async fn test_absent_retry_after(middleware: &Arc<RetryAfterMiddleware>, url: &Url) {
        assert!(middleware.retry_after.read().await.get(url).is_none());
    }

    /// Asserts that no deadlines are recorded at all.
    async fn test_empty_retry_after(middleware: &Arc<RetryAfterMiddleware>) {
        assert!(middleware.retry_after.read().await.is_empty());
    }
}