reinfer_client/
retry.rs

1use http::StatusCode;
2use reqwest::{blocking::Response, Result};
3use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
4use std::thread::sleep;
5use std::time::Duration;
6
/// Strategy that decides which requests are eligible for retrying.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum RetryStrategy {
    /// The first request by the client will not be retried, but subsequent requests will.
    /// This allows fast failure if the client cannot reach the API endpoint at all, but
    /// helps to mitigate failure in long-running operations spanning multiple requests.
    Automatic,
    /// Always attempt to retry requests, including the very first one.
    Always,
}
17
/// Configuration for how the Reinfer client retries failed requests.
#[derive(Clone, Debug, PartialEq)]
pub struct RetryConfig {
    /// Strategy for deciding which requests may be retried.
    pub strategy: RetryStrategy,
    /// Maximum number of retries to attempt.
    pub max_retry_count: u8,
    /// Amount of time to wait before the first retry.
    pub base_wait: Duration,
    /// Factor by which the wait grows with each retry. The wait before retry N (counting
    /// from zero) is an exponential backoff using the formula
    /// `wait = base_wait * backoff_factor.powi(N)`.
    pub backoff_factor: f64,
}
31
/// Wraps request sending with the retry behaviour described by a [`RetryConfig`].
#[derive(Debug)]
pub(crate) struct Retrier {
    /// Retry policy: strategy, maximum retry count and backoff parameters.
    config: RetryConfig,
    /// `true` until the first call to `with_retries`, which flips it to `false`. Under
    /// [`RetryStrategy::Automatic`] that first call is sent without any retries.
    is_first_request: AtomicBool,
}
37
38impl Retrier {
39    pub fn new(config: RetryConfig) -> Self {
40        Self {
41            config,
42            is_first_request: AtomicBool::new(true),
43        }
44    }
45
46    fn should_retry(status: StatusCode) -> bool {
47        status.is_server_error() || status == StatusCode::TOO_MANY_REQUESTS
48    }
49
50    pub fn with_retries(&self, send_request: impl Fn() -> Result<Response>) -> Result<Response> {
51        if self.is_first_request.swap(false, SeqCst)
52            && self.config.strategy == RetryStrategy::Automatic
53        {
54            return send_request();
55        }
56
57        for i_retry in 0..self.config.max_retry_count {
58            macro_rules! warn_and_sleep {
59                ($src:expr) => {{
60                    let wait_factor = self.config.backoff_factor.powi(i_retry.into());
61                    let duration = self.config.base_wait.mul_f64(wait_factor);
62                    log::warn!("{} - retrying after {:?}.", $src, duration);
63                    sleep(duration)
64                }};
65            }
66
67            match send_request() {
68                Ok(response) if Self::should_retry(response.status()) => {
69                    warn_and_sleep!(format!("{} for {}", response.status(), response.url()))
70                }
71                Err(error) if error.is_timeout() || error.is_connect() || error.is_request() => {
72                    warn_and_sleep!(error)
73                }
74                // If anything else, just return it immediately
75                result => return result,
76            }
77        }
78
79        // On last retry don't handle the error, just propagate all errors.
80        send_request()
81    }
82}
83
#[cfg(test)]
mod tests {
    use super::{Retrier, RetryConfig, RetryStrategy};
    use mockito::{mock, server_address};
    use reqwest::blocking::{get, Client};
    use std::thread::sleep;
    use std::time::Duration;

    /// URL of the root path on the mockito test server.
    fn root_url() -> String {
        format!("http://{}", server_address())
    }

    /// Sends a blocking `GET /` to the mockito test server.
    fn get_root() -> reqwest::Result<reqwest::blocking::Response> {
        get(root_url())
    }

    /// Builds a retrier that never actually waits between attempts, so tests stay fast.
    fn retrier(strategy: RetryStrategy, max_retry_count: u8) -> Retrier {
        Retrier::new(RetryConfig {
            strategy,
            max_retry_count,
            base_wait: Duration::from_secs(0),
            backoff_factor: 0.0,
        })
    }

    /// Asserts that, for each retry limit N in `0..10`, a persistently failing endpoint
    /// (HTTP 500) is hit exactly N + 1 times and the final 500 response is returned.
    fn assert_retries_server_errors(handler: &mut Retrier) {
        for max_retry_count in 0..10_u8 {
            let err = mock("GET", "/")
                .with_status(500)
                .expect(usize::from(max_retry_count) + 1)
                .create();
            handler.config.max_retry_count = max_retry_count;
            assert_eq!(handler.with_retries(get_root).unwrap().status(), 500);
            err.assert();
        }
    }

    #[test]
    fn test_always_retry() {
        let mut handler = retrier(RetryStrategy::Always, 5);

        // Does not attempt to retry on success
        let ok = mock("GET", "/").expect(1).create();
        assert_eq!(handler.with_retries(get_root).unwrap().status(), 200);
        ok.assert();

        // Retries up to N times on server errors.
        assert_retries_server_errors(&mut handler);
    }

    #[test]
    fn test_automatic_retry() {
        let mut handler = retrier(RetryStrategy::Automatic, 5);

        // Does not attempt to retry on failure of first request
        let err = mock("GET", "/").with_status(500).expect(1).create();
        assert_eq!(handler.with_retries(get_root).unwrap().status(), 500);
        err.assert();

        // Does not attempt to retry on success
        let ok = mock("GET", "/").expect(1).create();
        assert_eq!(handler.with_retries(get_root).unwrap().status(), 200);
        ok.assert();

        // Retries up to N times on server errors for non-first-requests.
        assert_retries_server_errors(&mut handler);
    }

    #[test]
    fn test_timeout_retry() {
        let handler = retrier(RetryStrategy::Always, 1);

        // The mock delays its body for 0.2s while the client below allows only 0.1s, so
        // both the initial attempt and the single retry are expected to time out.
        let timeout = mock("GET", "/")
            .with_body_from_fn(|_| {
                sleep(Duration::from_secs_f64(0.2));
                Ok(())
            })
            .expect(2)
            .create();
        let client = Client::new();
        let result = handler.with_retries(|| {
            client
                .get(root_url())
                .timeout(Duration::from_secs_f64(0.1))
                .send()
                .and_then(|response| {
                    // This is a bit of a hack to force a timeout: reading the delayed
                    // body is what makes the timeout error surface.
                    let _ = response.text()?;
                    unreachable!("the delayed body should have exceeded the client timeout")
                })
        });
        assert!(result.unwrap_err().is_timeout());
        timeout.assert();
    }
}