1use http::StatusCode;
2use reqwest::{blocking::Response, Result};
3use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
4use std::thread::sleep;
5use std::time::Duration;
6
/// Policy deciding which requests made through a `Retrier` may be retried.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryStrategy {
    /// The very first request sent through a `Retrier` is attempted exactly once,
    /// with no retries; every later request is retried according to the config.
    Automatic,
    /// Every request, including the first, is retried according to the config.
    Always,
}
17
18#[derive(Clone, Debug, PartialEq)]
20pub struct RetryConfig {
21 pub strategy: RetryStrategy,
23 pub max_retry_count: u8,
25 pub base_wait: Duration,
27 pub backoff_factor: f64,
30}
31
32#[derive(Debug)]
33pub(crate) struct Retrier {
34 config: RetryConfig,
35 is_first_request: AtomicBool,
36}
37
38impl Retrier {
39 pub fn new(config: RetryConfig) -> Self {
40 Self {
41 config,
42 is_first_request: AtomicBool::new(true),
43 }
44 }
45
46 fn should_retry(status: StatusCode) -> bool {
47 status.is_server_error()
48 || status == StatusCode::TOO_MANY_REQUESTS
49 || status == StatusCode::CONFLICT
50 }
51
52 pub fn with_retries(&self, send_request: impl Fn() -> Result<Response>) -> Result<Response> {
53 if self.is_first_request.swap(false, SeqCst)
54 && self.config.strategy == RetryStrategy::Automatic
55 {
56 return send_request();
57 }
58
59 for i_retry in 0..self.config.max_retry_count {
60 macro_rules! warn_and_sleep {
61 ($src:expr) => {{
62 let wait_factor = self.config.backoff_factor.powi(i_retry.into());
63 let duration = self.config.base_wait.mul_f64(wait_factor);
64 log::warn!("{} - retrying after {:?}.", $src, duration);
65 sleep(duration)
66 }};
67 }
68
69 match send_request() {
70 Ok(response) if Self::should_retry(response.status()) => {
71 warn_and_sleep!(format!("{} for {}", response.status(), response.url()))
72 }
73 Err(error) if error.is_timeout() || error.is_connect() || error.is_request() => {
74 warn_and_sleep!(error)
75 }
76 result => return result,
78 }
79 }
80
81 send_request()
83 }
84}
85
// Tests for `Retrier`, driven against mockito's shared mock HTTP server.
//
// NOTE(review): every test registers mocks on the same global mockito server
// for `GET /`. Which mock answers a request depends on which mock bindings are
// still alive and on the order they were created. Keep the binding lifetimes
// as they are, and confirm the routing rules against the mockito version in
// use before restructuring.
#[cfg(test)]
mod tests {
    use super::{Retrier, RetryConfig, RetryStrategy};
    use mockito::{mock, server_address};
    use reqwest::blocking::{get, Client};
    use std::thread::sleep;
    use std::time::Duration;

    // `Always` retries even the first request; the attempt count is
    // `max_retry_count + 1` for a persistently failing endpoint.
    #[test]
    fn test_always_retry() {
        // Zero base wait keeps the retry sleeps instantaneous.
        let mut handler = Retrier::new(RetryConfig {
            strategy: RetryStrategy::Always,
            max_retry_count: 5,
            base_wait: Duration::from_secs(0),
            backoff_factor: 0.0,
        });

        // A successful response is returned after a single call.
        let ok = mock("GET", "/").expect(1).create();
        assert!(
            handler
                .with_retries(|| get(format!("http://{}", server_address())))
                .unwrap()
                .status()
                == 200
        );
        ok.assert();

        // A persistent 500 is attempted `i_retry + 1` times (i_retry retries plus
        // the final attempt), and the final 500 response is returned as Ok.
        // NOTE(review): `ok` is still alive here, so this relies on mockito
        // serving the newer `err` mock — confirm.
        for i_retry in 0..10 {
            let err = mock("GET", "/")
                .with_status(500)
                .expect((i_retry + 1).into())
                .create();
            handler.config.max_retry_count = i_retry;
            assert!(
                handler
                    .with_retries(|| get(format!("http://{}", server_address())))
                    .unwrap()
                    .status()
                    == 500
            );
            err.assert();
        }
    }

    // `Automatic` skips retries on the first call only; later calls behave like
    // `Always`.
    #[test]
    fn test_automatic_retry() {
        let mut handler = Retrier::new(RetryConfig {
            strategy: RetryStrategy::Automatic,
            max_retry_count: 5,
            base_wait: Duration::from_secs(0),
            backoff_factor: 0.0,
        });

        // First call: a retryable 500 is still sent exactly once, despite
        // `max_retry_count: 5`.
        let err = mock("GET", "/").with_status(500).expect(1).create();
        assert!(
            handler
                .with_retries(|| get(format!("http://{}", server_address())))
                .unwrap()
                .status()
                == 500
        );
        err.assert();

        // Second call: success after a single attempt.
        let ok = mock("GET", "/").expect(1).create();
        assert!(
            handler
                .with_retries(|| get(format!("http://{}", server_address())))
                .unwrap()
                .status()
                == 200
        );
        ok.assert();

        // Subsequent calls retry: `i_retry + 1` attempts for a persistent 500.
        for i_retry in 0..10 {
            let err = mock("GET", "/")
                .with_status(500)
                .expect((i_retry + 1).into())
                .create();
            handler.config.max_retry_count = i_retry;
            assert!(
                handler
                    .with_retries(|| get(format!("http://{}", server_address())))
                    .unwrap()
                    .status()
                    == 500
            );
            err.assert();
        }
    }

    // Timeout errors are retried, and the last attempt's timeout error is
    // returned once retries are exhausted (1 retry => 2 calls).
    #[test]
    fn test_timeout_retry() {
        let handler = Retrier::new(RetryConfig {
            strategy: RetryStrategy::Always,
            max_retry_count: 1,
            base_wait: Duration::from_secs(0),
            backoff_factor: 0.0,
        });

        // The mock body takes 0.2s to produce, longer than the 0.1s client
        // timeout below.
        let timeout = mock("GET", "/")
            .with_body_from_fn(|_| {
                sleep(Duration::from_secs_f64(0.2));
                Ok(())
            })
            .expect(2)
            .create();
        let client = Client::new();
        // The body is read inside the retried closure so that a timeout hit while
        // reading it becomes the closure's error. Presumably `send()` succeeds
        // once headers arrive and `text()` is what times out — TODO confirm.
        // `unreachable!()` documents that the read is never expected to succeed.
        assert!(handler
            .with_retries(|| client
                .get(format!("http://{}", server_address()))
                .timeout(Duration::from_secs_f64(0.1))
                .send()
                .and_then(|r| {
                    let _ = r.text()?;
                    unreachable!()
                }))
            .unwrap_err()
            .is_timeout());
        timeout.assert();
    }
}
213}