1use http::StatusCode;
2use reqwest::{blocking::Response, Result};
3use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
4use std::thread::sleep;
5use std::time::Duration;
6
/// Decides which requests sent through a `Retrier` are eligible for retries.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryStrategy {
    /// The very first request sent through a retrier is attempted exactly once and its
    /// result returned as-is; every later request is retried per its `RetryConfig`.
    Automatic,
    /// Every request, including the first one, is retried per its `RetryConfig`.
    Always,
}
17
18#[derive(Clone, Debug, PartialEq)]
20pub struct RetryConfig {
21 pub strategy: RetryStrategy,
23 pub max_retry_count: u8,
25 pub base_wait: Duration,
27 pub backoff_factor: f64,
30}
31
32#[derive(Debug)]
33pub(crate) struct Retrier {
34 config: RetryConfig,
35 is_first_request: AtomicBool,
36}
37
38impl Retrier {
39 pub fn new(config: RetryConfig) -> Self {
40 Self {
41 config,
42 is_first_request: AtomicBool::new(true),
43 }
44 }
45
46 fn should_retry(status: StatusCode) -> bool {
47 status.is_server_error() || status == StatusCode::TOO_MANY_REQUESTS
48 }
49
50 pub fn with_retries(&self, send_request: impl Fn() -> Result<Response>) -> Result<Response> {
51 if self.is_first_request.swap(false, SeqCst)
52 && self.config.strategy == RetryStrategy::Automatic
53 {
54 return send_request();
55 }
56
57 for i_retry in 0..self.config.max_retry_count {
58 macro_rules! warn_and_sleep {
59 ($src:expr) => {{
60 let wait_factor = self.config.backoff_factor.powi(i_retry.into());
61 let duration = self.config.base_wait.mul_f64(wait_factor);
62 log::warn!("{} - retrying after {:?}.", $src, duration);
63 sleep(duration)
64 }};
65 }
66
67 match send_request() {
68 Ok(response) if Self::should_retry(response.status()) => {
69 warn_and_sleep!(format!("{} for {}", response.status(), response.url()))
70 }
71 Err(error) if error.is_timeout() || error.is_connect() || error.is_request() => {
72 warn_and_sleep!(error)
73 }
74 result => return result,
76 }
77 }
78
79 send_request()
81 }
82}
83
#[cfg(test)]
mod tests {
    use super::{Retrier, RetryConfig, RetryStrategy};
    use mockito::{mock, server_address};
    use reqwest::blocking::{get, Client};
    use std::thread::sleep;
    use std::time::Duration;

    /// Builds a retrier that never waits between attempts, keeping the tests fast.
    fn instant_retrier(strategy: RetryStrategy, max_retry_count: u8) -> Retrier {
        Retrier::new(RetryConfig {
            strategy,
            max_retry_count,
            base_wait: Duration::from_secs(0),
            backoff_factor: 0.0,
        })
    }

    /// Sends `GET /` to the mock server through `retrier` and returns the final status.
    fn root_status(retrier: &Retrier) -> u16 {
        retrier
            .with_retries(|| get(format!("http://{}", server_address())))
            .unwrap()
            .status()
            .as_u16()
    }

    /// For each retry budget `n` in `0..10`, an endpoint that always answers 500 must be
    /// hit exactly `n + 1` times, and the final 500 must be handed back to the caller.
    fn check_exhausts_retry_budget(retrier: &mut Retrier) {
        for budget in 0..10u8 {
            let failing = mock("GET", "/")
                .with_status(500)
                .expect(usize::from(budget) + 1)
                .create();
            retrier.config.max_retry_count = budget;
            assert_eq!(root_status(retrier), 500);
            failing.assert();
        }
    }

    #[test]
    fn test_always_retry() {
        let mut retrier = instant_retrier(RetryStrategy::Always, 5);

        let success = mock("GET", "/").expect(1).create();
        assert_eq!(root_status(&retrier), 200);
        success.assert();

        check_exhausts_retry_budget(&mut retrier);
    }

    #[test]
    fn test_automatic_retry() {
        let mut retrier = instant_retrier(RetryStrategy::Automatic, 5);

        // The first request through an Automatic retrier is never retried.
        let first_failure = mock("GET", "/").with_status(500).expect(1).create();
        assert_eq!(root_status(&retrier), 500);
        first_failure.assert();

        let success = mock("GET", "/").expect(1).create();
        assert_eq!(root_status(&retrier), 200);
        success.assert();

        // Later requests are retried just like with `RetryStrategy::Always`.
        check_exhausts_retry_budget(&mut retrier);
    }

    #[test]
    fn test_timeout_retry() {
        let retrier = instant_retrier(RetryStrategy::Always, 1);

        // The 0.2 s body delay exceeds the 0.1 s client timeout; with one retry allowed the
        // endpoint must be hit twice before the timeout error surfaces.
        let slow = mock("GET", "/")
            .with_body_from_fn(|_| {
                sleep(Duration::from_secs_f64(0.2));
                Ok(())
            })
            .expect(2)
            .create();
        let client = Client::new();
        let error = retrier
            .with_retries(|| {
                client
                    .get(format!("http://{}", server_address()))
                    .timeout(Duration::from_secs_f64(0.1))
                    .send()
                    .and_then(|response| {
                        let _ = response.text()?;
                        unreachable!()
                    })
            })
            .unwrap_err();
        assert!(error.is_timeout());
        slow.assert();
    }
}