reqwest_retry_after/
lib.rs

1//! # reqwest-retry-after
2//!
3//! `reqwest-retry-after` is a library that adds support for the `Retry-After` header
4//! in [`reqwest`], using [`reqwest_middleware`].
5//!
6//! ## Usage
7//!
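//! Pass [`RetryAfterMiddleware`] to [`ClientBuilder::with`](reqwest_middleware::ClientBuilder::with)
//! when building a [`ClientWithMiddleware`](reqwest_middleware::ClientWithMiddleware).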
9//!
10//! ```
11//! use reqwest_middleware::{ClientBuilder, ClientWithMiddleware};
12//! use reqwest_retry_after::RetryAfterMiddleware;
13//!
14//! let client = ClientBuilder::new(reqwest::Client::new())
15//!     .with(RetryAfterMiddleware::new())
16//!     .build();
17//! ```
18#![warn(missing_docs)]
19#![warn(rustdoc::missing_doc_code_examples)]
20
21use std::{
22    collections::HashMap,
23    time::{Duration, SystemTime},
24};
25
26use http::{header::RETRY_AFTER, Extensions};
27use reqwest::Url;
28use reqwest_middleware::{
29    reqwest::{Request, Response},
30    Middleware, Next, Result,
31};
32use time::{format_description::well_known::Rfc2822, OffsetDateTime};
33use tokio::sync::RwLock;
34
35/// The `RetryAfterMiddleware` is a [`Middleware`] that adds support for the `Retry-After`
36/// header in [`reqwest`].
37pub struct RetryAfterMiddleware {
38    retry_after: RwLock<HashMap<Url, SystemTime>>,
39}
40
41impl RetryAfterMiddleware {
42    /// Creates a new `RetryAfterMiddleware`.
43    pub fn new() -> Self {
44        Self {
45            retry_after: RwLock::new(HashMap::new()),
46        }
47    }
48}
49
50impl Default for RetryAfterMiddleware {
51    fn default() -> Self {
52        Self::new()
53    }
54}
55
56fn parse_retry_value(val: &str) -> Option<SystemTime> {
57    if let Ok(secs) = val.parse::<u64>() {
58        return Some(SystemTime::now() + Duration::from_secs(secs));
59    }
60    if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
61        return Some(date.into());
62    }
63    None
64}
65
66#[async_trait::async_trait]
67impl Middleware for RetryAfterMiddleware {
68    async fn handle(
69        &self,
70        req: Request,
71        extensions: &mut Extensions,
72        next: Next<'_>,
73    ) -> Result<Response> {
74        let url = req.url().clone();
75
76        if let Some(timestamp) = self.retry_after.read().await.get(&url) {
77            let now = SystemTime::now();
78
79            if let Ok(duration) = timestamp.duration_since(now) {
80                tokio::time::sleep(duration).await;
81            }
82        }
83
84        let res = next.run(req, extensions).await;
85
86        if let Ok(res) = &res {
87            match res.headers().get(RETRY_AFTER) {
88                Some(retry_after) => {
89                    if let Ok(val) = retry_after.to_str() {
90                        if let Some(timestamp) = parse_retry_value(val) {
91                            self.retry_after
92                                .write()
93                                .await
94                                .insert(url.clone(), timestamp);
95                        }
96                    }
97                }
98                _ => {
99                    self.retry_after.write().await.remove(&url);
100                }
101            }
102        }
103        res
104    }
105}
106
#[cfg(test)]
mod test {
    use std::{
        str::FromStr,
        sync::Arc,
        time::{Duration, SystemTime},
    };

    use httpmock::{Method::GET, MockServer};
    use reqwest::Url;
    use reqwest_middleware::ClientBuilder;
    use time::{format_description::well_known::Rfc2822, OffsetDateTime};

    use crate::{parse_retry_value, RetryAfterMiddleware};

    #[tokio::test]
    async fn test() {
        // create
        let test_duration = Duration::from_secs(2);
        let middleware = Arc::new(RetryAfterMiddleware::new());

        // build client with middleware
        let client = ClientBuilder::new(reqwest::Client::new())
            .with_arc(middleware.clone())
            .build();

        test_empty_retry_after(&middleware).await;

        // create mock server
        let server = MockServer::start();
        let pre_ra_mock = server.mock(|when, then| {
            when.method(GET).path("/").header("RA", "true");
            then.status(200)
                .header("Retry-After", test_duration.as_secs().to_string())
                .body("");
        });
        let post_ra_mock = server.mock(|when, then| {
            when.method(GET).path("/");
            then.status(200).body("");
        });
        let normal_mock = server.mock(|when, then| {
            when.method(GET).path("/normal");
            then.status(200).body("");
        });

        let url = Url::from_str(&server.url("/")).unwrap();

        // hit URL; get RA value and store it
        let pre_test = SystemTime::now();
        client
            .get(url.clone())
            .header("RA", "true")
            .send()
            .await
            .unwrap();
        pre_ra_mock.assert_async().await;
        test_valid_retry_after(&middleware, &url, pre_test, test_duration).await;

        // hit other URL, which should return instantly
        let normal = Url::from_str(&server.url("/normal")).unwrap();
        let before_normal = SystemTime::now();
        client.get(normal.clone()).send().await.unwrap();
        normal_mock.assert_async().await;
        assert!(
            SystemTime::now()
                .duration_since(before_normal)
                .unwrap()
                .as_secs_f64()
                <= 0.2
        );
        test_absent_retry_after(&middleware, &normal).await;

        // hit URL with stored RA
        client.get(url.clone()).send().await.unwrap();
        post_ra_mock.assert_async().await;

        // this should have (1) slept and (2) cleared the stored RA afterward
        let post_test = SystemTime::now();
        assert!(post_test.duration_since(pre_test).unwrap() >= test_duration);
        test_empty_retry_after(&middleware).await;
    }

    #[tokio::test]
    async fn test_rfc2822() {
        let test_duration = Duration::from_secs(2);

        // Build server and client
        let server = MockServer::start();
        let middleware = Arc::new(RetryAfterMiddleware::new());
        let client = ClientBuilder::new(reqwest::Client::new())
            .with_arc(middleware.clone())
            .build();

        // RFC 2822 dates have whole-second precision, so round "now" up to the next whole
        // second. The advertised deadline is then exactly `test_duration` after `begin`.
        let begin =
            OffsetDateTime::now_utc().replace_nanosecond(0).unwrap() + Duration::from_secs(1);
        let ra = begin + test_duration;

        let ra_mock = server.mock(|when, then| {
            when.method(GET).path("/").header("RA", "true");
            then.status(200)
                .header("Retry-After", ra.format(&Rfc2822).unwrap())
                .body("");
        });
        let no_ra_mock = server.mock(|when, then| {
            when.method(GET).path("/");
            then.status(200).body("");
        });

        // hit URL; store RA value
        let url = Url::from_str(&server.url("/")).unwrap();
        client
            .get(url.clone())
            .header("RA", "true")
            .send()
            .await
            .unwrap();
        // Measure against `begin`, not the current time. The request may already have crossed
        // into the next second, which would leave less than `test_duration` remaining and make
        // the assertion flaky.
        test_valid_retry_after(&middleware, &url, begin.into(), test_duration).await;
        ra_mock.assert_async().await;

        // hit URL with stored RA
        client.get(url.clone()).send().await.unwrap();
        no_ra_mock.assert_async().await;

        // this should have (1) slept and (2) cleared the stored RA afterward
        let duration = SystemTime::now().duration_since(begin.into()).unwrap();
        assert!(duration >= test_duration);
        test_empty_retry_after(&middleware).await;
    }

    #[test]
    fn test_parse_retry_value() {
        // Neither an integer nor a date.
        assert!(parse_retry_value("not a date").is_none());

        // Delay in seconds is relative to the time of parsing.
        let before = SystemTime::now();
        let delayed = parse_retry_value("120").unwrap();
        assert!(delayed.duration_since(before).unwrap() >= Duration::from_secs(120));

        // Absolute RFC 2822 date.
        assert_eq!(
            parse_retry_value("Sun, 06 Nov 1994 08:49:37 +0000"),
            Some(SystemTime::UNIX_EPOCH + Duration::from_secs(784_111_777))
        );
    }

    async fn test_valid_retry_after(
        middleware: &Arc<RetryAfterMiddleware>,
        url: &Url,
        now: SystemTime,
        test_duration: Duration,
    ) {
        let time = middleware
            .retry_after
            .read()
            .await
            .get(url)
            .cloned()
            .unwrap();
        let duration = time.duration_since(now).unwrap();
        assert!(duration >= test_duration);
    }

    async fn test_absent_retry_after(middleware: &Arc<RetryAfterMiddleware>, url: &Url) {
        assert!(middleware.retry_after.read().await.get(url).is_none());
    }

    async fn test_empty_retry_after(middleware: &Arc<RetryAfterMiddleware>) {
        assert!(middleware.retry_after.read().await.is_empty());
    }
}