Skip to main content

reinfer_client/
retry.rs

1use http::StatusCode;
2use reqwest::{blocking::Response, Result};
3use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
4use std::thread::sleep;
5use std::time::Duration;
6
/// Strategy deciding which requests are eligible to be retried.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RetryStrategy {
    /// Never retry the client's first request, but retry every request after it.
    ///
    /// A client that cannot reach the API endpoint at all fails fast, while long-running
    /// operations spanning many requests are protected against intermittent failures.
    Automatic,
    /// Retry every request, including the first one.
    Always,
}
17
18/// Configuration for the Reinfer client if retrying timeouts.
19#[derive(Clone, Debug, PartialEq)]
20pub struct RetryConfig {
21    /// Strategy for when to retry after a timeout
22    pub strategy: RetryStrategy,
23    /// Maximum number of retries to attempt.
24    pub max_retry_count: u8,
25    /// Amount of time to wait for first retry.
26    pub base_wait: Duration,
27    /// Amount of time to scale retry waits. The wait before retry N is an exponential backoff
28    /// using the formula `wait = retry_wait * (backoff_factor * N)`.
29    pub backoff_factor: f64,
30}
31
32#[derive(Debug)]
33pub(crate) struct Retrier {
34    config: RetryConfig,
35    is_first_request: AtomicBool,
36}
37
38impl Retrier {
39    pub fn new(config: RetryConfig) -> Self {
40        Self {
41            config,
42            is_first_request: AtomicBool::new(true),
43        }
44    }
45
46    fn should_retry(status: StatusCode) -> bool {
47        status.is_server_error()
48            || status == StatusCode::TOO_MANY_REQUESTS
49            || status == StatusCode::CONFLICT
50    }
51
52    pub fn with_retries(&self, send_request: impl Fn() -> Result<Response>) -> Result<Response> {
53        if self.is_first_request.swap(false, SeqCst)
54            && self.config.strategy == RetryStrategy::Automatic
55        {
56            return send_request();
57        }
58
59        for i_retry in 0..self.config.max_retry_count {
60            macro_rules! warn_and_sleep {
61                ($src:expr) => {{
62                    let wait_factor = self.config.backoff_factor.powi(i_retry.into());
63                    let duration = self.config.base_wait.mul_f64(wait_factor);
64                    log::warn!("{} - retrying after {:?}.", $src, duration);
65                    sleep(duration)
66                }};
67            }
68
69            match send_request() {
70                Ok(response) if Self::should_retry(response.status()) => {
71                    warn_and_sleep!(format!("{} for {}", response.status(), response.url()))
72                }
73                Err(error) if error.is_timeout() || error.is_connect() || error.is_request() => {
74                    warn_and_sleep!(error)
75                }
76                // If anything else, just return it immediately
77                result => return result,
78            }
79        }
80
81        // On last retry don't handle the error, just propagate all errors.
82        send_request()
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::{Retrier, RetryConfig, RetryStrategy};
    use mockito::{mock, server_address};
    use reqwest::blocking::{get, Client, Response};
    use std::thread::sleep;
    use std::time::Duration;

    /// URL of the root path on the mockito server.
    fn root_url() -> String {
        format!("http://{}", server_address())
    }

    /// Sends a plain `GET` for the root path; the request retried by most tests.
    fn get_root() -> reqwest::Result<Response> {
        get(root_url())
    }

    #[test]
    fn test_always_retry() {
        let mut handler = Retrier::new(RetryConfig {
            strategy: RetryStrategy::Always,
            max_retry_count: 5,
            base_wait: Duration::from_secs(0),
            backoff_factor: 0.0,
        });

        // Does not attempt to retry on success.
        let ok = mock("GET", "/").expect(1).create();
        assert_eq!(handler.with_retries(get_root).unwrap().status().as_u16(), 200);
        ok.assert();

        // With `max_retry_count = n`, a persistent server error is sent n + 1 times (the
        // initial attempt plus n retries) before the final 500 is returned.
        for i_retry in 0..10 {
            let err = mock("GET", "/")
                .with_status(500)
                .expect((i_retry + 1).into())
                .create();
            handler.config.max_retry_count = i_retry;
            assert_eq!(handler.with_retries(get_root).unwrap().status().as_u16(), 500);
            err.assert();
        }
    }

    #[test]
    fn test_automatic_retry() {
        let mut handler = Retrier::new(RetryConfig {
            strategy: RetryStrategy::Automatic,
            max_retry_count: 5,
            base_wait: Duration::from_secs(0),
            backoff_factor: 0.0,
        });

        // Does not attempt to retry on failure of the first request.
        let err = mock("GET", "/").with_status(500).expect(1).create();
        assert_eq!(handler.with_retries(get_root).unwrap().status().as_u16(), 500);
        err.assert();

        // Does not attempt to retry on success.
        let ok = mock("GET", "/").expect(1).create();
        assert_eq!(handler.with_retries(get_root).unwrap().status().as_u16(), 200);
        ok.assert();

        // For requests after the first, a persistent server error is sent n + 1 times when
        // `max_retry_count = n`.
        for i_retry in 0..10 {
            let err = mock("GET", "/")
                .with_status(500)
                .expect((i_retry + 1).into())
                .create();
            handler.config.max_retry_count = i_retry;
            assert_eq!(handler.with_retries(get_root).unwrap().status().as_u16(), 500);
            err.assert();
        }
    }

    #[test]
    fn test_timeout_retry() {
        let handler = Retrier::new(RetryConfig {
            strategy: RetryStrategy::Always,
            max_retry_count: 1,
            base_wait: Duration::from_secs(0),
            backoff_factor: 0.0,
        });

        // The mock stalls for 0.2s while writing the body, longer than the 0.1s request
        // timeout, so both the initial attempt and the single retry time out.
        let timeout = mock("GET", "/")
            .with_body_from_fn(|_| {
                sleep(Duration::from_secs_f64(0.2));
                Ok(())
            })
            .expect(2)
            .create();
        let client = Client::new();
        let error = handler
            .with_retries(|| {
                client
                    .get(root_url())
                    .timeout(Duration::from_secs_f64(0.1))
                    .send()
                    .and_then(|response| {
                        // Hack to force a timeout: reading the stalled body should exceed the
                        // request timeout (presumably `send` already returned once headers
                        // arrived), so this read is expected to fail.
                        let _ = response.text()?;
                        unreachable!("reading the stalled body should have timed out")
                    })
            })
            .unwrap_err();
        assert!(error.is_timeout());
        timeout.assert();
    }
}