gr/
backoff.rs

1use std::sync::Arc;
2
3use serde::Serialize;
4
5use crate::error::{AddContext, GRError};
6use crate::http::throttle::ThrottleStrategy;
7use crate::io::{HttpRunner, RateLimitHeader};
8use crate::log_error;
9use crate::{error, log_info, Result};
10use crate::{http::Request, io::HttpResponse, time::Seconds};
11
12/// ExponentialBackoff wraps an HttpRunner and retries requests with an
13/// exponential backoff retry mechanism.
14pub struct Backoff<'a, R> {
15    runner: &'a Arc<R>,
16    max_retries: u32,
17    num_retries: u32,
18    rate_limit_header: RateLimitHeader,
19    default_delay_wait: Seconds,
20    now: fn() -> Seconds,
21    backoff_strategy: Box<dyn BackOffStrategy>,
22    throttler: Box<dyn ThrottleStrategy>,
23}
24
25impl<'a, R> Backoff<'a, R> {
26    pub fn new(
27        runner: &'a Arc<R>,
28        max_retries: u32,
29        default_delay_wait: u64,
30        now: fn() -> Seconds,
31        strategy: Box<dyn BackOffStrategy>,
32        throttler_strategy: Box<dyn ThrottleStrategy>,
33    ) -> Self {
34        Backoff {
35            runner,
36            max_retries,
37            num_retries: 0,
38            rate_limit_header: RateLimitHeader::default(),
39            default_delay_wait: Seconds::new(default_delay_wait),
40            now,
41            backoff_strategy: strategy,
42            throttler: throttler_strategy,
43        }
44    }
45
46    fn log_backoff_enabled(&self) {
47        log_info!("Backoff enabled with {} max retries", self.max_retries);
48    }
49}
50
51impl<R: HttpRunner<Response = HttpResponse>> Backoff<'_, R> {
52    pub fn retry_on_error<T: Serialize>(
53        &mut self,
54        request: &mut Request<T>,
55    ) -> Result<HttpResponse> {
56        loop {
57            if self.num_retries > 0 {
58                log_info!(
59                    "Retrying request {} out of {}",
60                    self.num_retries,
61                    self.max_retries
62                );
63            }
64            match self.runner.run(request) {
65                Ok(response) => return Ok(response),
66                Err(err) => {
67                    if self.max_retries == 0 {
68                        return Err(err);
69                    }
70                    log_error!("Error: {}", err);
71                    // https://docs.github.com/en/rest/using-the-rest-api/rate-limits-for-the-rest-api?apiVersion=2022-11-28#exceeding-the-rate-limit
72                    match err.downcast_ref::<error::GRError>() {
73                        Some(error::GRError::RateLimitExceeded(headers)) => {
74                            self.rate_limit_header = *headers;
75                            self.num_retries += 1;
76                            if self.num_retries <= self.max_retries {
77                                let now = (self.now)();
78                                let mut base_wait_time = if self.rate_limit_header.reset > now {
79                                    self.rate_limit_header.reset - now
80                                } else {
81                                    self.default_delay_wait
82                                };
83                                if self.rate_limit_header.retry_after > Seconds::new(0) {
84                                    base_wait_time = self.rate_limit_header.retry_after;
85                                }
86                                self.log_backoff_enabled();
87                                self.throttler.throttle_for(
88                                    self.backoff_strategy
89                                        .wait_time(base_wait_time, self.num_retries)
90                                        .into(),
91                                );
92                                continue;
93                            }
94                        }
95                        Some(
96                            error::GRError::HttpTransportError(_)
97                            | error::GRError::RemoteServerError(_),
98                        ) => {
99                            self.num_retries += 1;
100                            if self.num_retries <= self.max_retries {
101                                self.log_backoff_enabled();
102                                self.throttler.throttle_for(
103                                    self.backoff_strategy
104                                        .wait_time(self.default_delay_wait, self.num_retries)
105                                        .into(),
106                                );
107                                continue;
108                            }
109                        }
110                        _ => {
111                            return Err(err);
112                        }
113                    }
114                    return Err(GRError::ExponentialBackoffMaxRetriesReached(format!(
115                        "Retried the request {} times",
116                        self.max_retries
117                    )))
118                    .err_context(err);
119                }
120            };
121        }
122    }
123}
124
125pub trait BackOffStrategy {
126    fn wait_time(&self, base_wait: Seconds, num_retries: u32) -> Seconds;
127}
128
129pub struct Exponential;
130
131impl BackOffStrategy for Exponential {
132    fn wait_time(&self, base_wait: Seconds, num_retries: u32) -> Seconds {
133        log_info!("Exponential backoff strategy enabled");
134        let wait_time = base_wait + 2u64.pow(num_retries).into();
135        log_info!("Waiting for {} seconds", wait_time);
136        wait_time
137    }
138}
139
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use crate::{
        http::{self, Headers, Resource},
        io::FlowControlHeaders,
        test::utils::{MockRunner, MockThrottler},
        time::Milliseconds,
    };

    use super::*;

    /// Default delay, in seconds, given to every `Backoff` under test.
    const DEFAULT_DELAY_WAIT: u64 = 60;

    /// Rate limited (429) response carrying the given rate limit headers.
    fn ratelimited_with_headers(remaining: u32, reset: u64, retry_after: u64) -> HttpResponse {
        let mut headers = Headers::new();
        headers.set("x-ratelimit-remaining".to_string(), remaining.to_string());
        headers.set("x-ratelimit-reset".to_string(), reset.to_string());
        headers.set("retry-after".to_string(), retry_after.to_string());
        let rate_limit_header =
            RateLimitHeader::new(remaining, Seconds::new(reset), Seconds::new(retry_after));
        let flow_control_headers =
            FlowControlHeaders::new(Rc::new(None), Rc::new(Some(rate_limit_header)));
        HttpResponse::builder()
            .status(429)
            .headers(headers)
            .flow_control_headers(flow_control_headers)
            .build()
            .unwrap()
    }

    /// Rate limited (429) response without any rate limit information.
    fn ratelimited_with_no_headers() -> HttpResponse {
        HttpResponse::builder().status(429).build().unwrap()
    }

    /// Successful (200) response.
    fn response_ok() -> HttpResponse {
        HttpResponse::builder().status(200).build().unwrap()
    }

    /// Remote server error (500) response.
    fn response_server_error() -> HttpResponse {
        HttpResponse::builder().status(500).build().unwrap()
    }

    fn response_transport_error() -> HttpResponse {
        // Could be a timeout, connection error, etc.
        // For testing purposes, the status code of -1 simulates a transport
        // error from the mock http runner.
        // TODO: Should move to enums instead at some point.
        HttpResponse::builder().status(-1).build().unwrap()
    }

    /// Fixed clock so wait times computed from reset times are deterministic.
    fn now_mock() -> Seconds {
        Seconds::new(1712814151)
    }

    /// Runs a GET request through a `Backoff` configured with the exponential
    /// strategy, the mocked clock and a mock throttler.
    ///
    /// Returns the outcome of `retry_on_error` together with the throttler so
    /// tests can inspect how many times and for how long it throttled.
    ///
    /// Judging by the expected wait times asserted in the tests, the mock
    /// runner hands out `responses` starting from the end of the vector.
    fn run_backoff(
        responses: Vec<HttpResponse>,
        max_retries: u32,
    ) -> (crate::Result<HttpResponse>, Rc<MockThrottler>) {
        let client = Arc::new(MockRunner::new(responses));
        let mut request: Request<()> = Request::builder()
            .resource(Resource::new("http://localhost", None))
            .method(http::Method::GET)
            .build()
            .unwrap();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(
            &client,
            max_retries,
            DEFAULT_DELAY_WAIT,
            now_mock,
            strategy,
            bthrottler,
        );
        let result = backoff.retry_on_error(&mut request);
        (result, throttler)
    }

    /// Panics unless `result` is a max retries reached error.
    fn assert_max_retries_reached(result: crate::Result<HttpResponse>) {
        let err = match result {
            Ok(_) => panic!("Expected max retries reached error"),
            Err(err) => err,
        };
        assert!(
            matches!(
                err.downcast_ref::<error::GRError>(),
                Some(error::GRError::ExponentialBackoffMaxRetriesReached(_))
            ),
            "Expected max retries reached error"
        );
    }

    #[test]
    fn test_exponential_backoff_retries_and_succeeds() {
        let reset = now_mock() + Seconds::new(60);
        let responses = vec![
            response_ok(),
            ratelimited_with_no_headers(),
            ratelimited_with_headers(10, *reset, 60),
        ];
        let (result, throttler) = run_backoff(responses, 3);
        result.unwrap();
        assert_eq!(2, *throttler.throttled());
    }

    #[test]
    fn test_exponential_backoff_retries_and_fails_after_max_retries_reached() {
        let reset = now_mock() + Seconds::new(60);
        let responses = vec![
            response_ok(),
            ratelimited_with_no_headers(),
            ratelimited_with_headers(10, *reset, 60),
        ];
        let (result, throttler) = run_backoff(responses, 1);
        assert_max_retries_reached(result);
        assert_eq!(1, *throttler.throttled());
        assert_eq!(
            Milliseconds::new(62000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_if_max_retries_is_zero_tries_once() {
        let (result, throttler) = run_backoff(vec![response_ok()], 0);
        result.unwrap();
        assert_eq!(0, *throttler.throttled());
    }

    #[test]
    fn test_if_max_retries_is_zero_tries_once_and_fails() {
        let (result, throttler) = run_backoff(vec![ratelimited_with_no_headers()], 0);
        match result {
            Ok(_) => panic!("Expected rate limit exceeded error"),
            Err(err) => match err.downcast_ref::<error::GRError>() {
                Some(error::GRError::RateLimitExceeded(_)) => {}
                _ => panic!("Expected rate limit exceeded error"),
            },
        }
        // Retries are disabled, so the throttler handed to the backoff must
        // never have been asked to wait.
        assert_eq!(0, *throttler.throttled());
    }

    #[test]
    fn test_time_to_reset_is_zero() {
        let responses = vec![
            response_ok(),
            ratelimited_with_no_headers(),
            ratelimited_with_headers(10, 0, 0),
        ];
        let (result, throttler) = run_backoff(responses, 3);
        result.unwrap();
        assert_eq!(2, *throttler.throttled());
        // 60 secs base wait, 1st retry 2^1 = 2 => 62000 milliseconds
        // 60 secs base wait, 2nd retry 2^2 = 4 => 64000 milliseconds
        // Total wait 126000
        assert_eq!(
            Milliseconds::new(126000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_retry_after_used_if_provided() {
        let reset = now_mock() + Seconds::new(120);
        let responses = vec![
            response_ok(),
            ratelimited_with_headers(10, 0, 65),
            ratelimited_with_headers(10, *reset, 61),
        ];
        let (result, throttler) = run_backoff(responses, 3);
        result.unwrap();
        assert_eq!(2, *throttler.throttled());
        // 61 secs base wait, 1st retry 2^1 = 2 => 63000 milliseconds
        // 65 secs base wait, 2nd retry 2^2 = 4 => 69000 milliseconds
        // Total wait 132000
        assert_eq!(
            Milliseconds::new(132000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_reset_time_future_and_no_retry_after() {
        let reset_first = now_mock() + Seconds::new(120);
        let reset_second = now_mock() + Seconds::new(61);
        let responses = vec![
            response_ok(),
            ratelimited_with_headers(10, *reset_second, 0),
            ratelimited_with_headers(10, *reset_first, 0),
        ];
        let (result, throttler) = run_backoff(responses, 3);
        result.unwrap();
        assert_eq!(2, *throttler.throttled());
        // 120 secs base wait, 1st retry 2^1 = 2 => 122000 milliseconds
        // 61 secs base wait, 2nd retry 2^2 = 4 => 65000 milliseconds
        // Total wait 187000
        assert_eq!(
            Milliseconds::new(187000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_retries_on_server_500_error() {
        let responses = vec![response_ok(), response_server_error()];
        let (result, throttler) = run_backoff(responses, 1);
        result.unwrap();
        assert_eq!(1, *throttler.throttled());
        // Success on 2nd retry. Wait time of 1min + 2^1 = 2 => 62000 milliseconds
        assert_eq!(
            Milliseconds::new(62000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_retries_on_server_500_error_and_fails_after_max_retries_reached() {
        let responses = vec![response_server_error(), response_server_error()];
        let (result, throttler) = run_backoff(responses, 1);
        assert_max_retries_reached(result);
        assert_eq!(1, *throttler.throttled());
        assert_eq!(
            Milliseconds::new(62000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_retries_on_transport_error() {
        let responses = vec![response_ok(), response_transport_error()];
        let (result, throttler) = run_backoff(responses, 1);
        result.unwrap();
        assert_eq!(1, *throttler.throttled());
        // Success on 2nd retry. Wait time of 1min + 2^1 = 2 => 62000 milliseconds
        assert_eq!(
            Milliseconds::new(62000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_retries_on_transport_error_and_fails_after_max_retries_reached() {
        let responses = vec![response_transport_error(), response_transport_error()];
        let (result, throttler) = run_backoff(responses, 1);
        assert_max_retries_reached(result);
        assert_eq!(1, *throttler.throttled());
        assert_eq!(
            Milliseconds::new(62000),
            *throttler.milliseconds_throttled()
        );
    }
}