1use std::sync::Arc;
2
3use serde::Serialize;
4
5use crate::error::{AddContext, GRError};
6use crate::http::throttle::ThrottleStrategy;
7use crate::io::{HttpRunner, RateLimitHeader};
8use crate::log_error;
9use crate::{error, log_info, Result};
10use crate::{http::Request, io::HttpResponse, time::Seconds};
11
12pub struct Backoff<'a, R> {
15 runner: &'a Arc<R>,
16 max_retries: u32,
17 num_retries: u32,
18 rate_limit_header: RateLimitHeader,
19 default_delay_wait: Seconds,
20 now: fn() -> Seconds,
21 backoff_strategy: Box<dyn BackOffStrategy>,
22 throttler: Box<dyn ThrottleStrategy>,
23}
24
25impl<'a, R> Backoff<'a, R> {
26 pub fn new(
27 runner: &'a Arc<R>,
28 max_retries: u32,
29 default_delay_wait: u64,
30 now: fn() -> Seconds,
31 strategy: Box<dyn BackOffStrategy>,
32 throttler_strategy: Box<dyn ThrottleStrategy>,
33 ) -> Self {
34 Backoff {
35 runner,
36 max_retries,
37 num_retries: 0,
38 rate_limit_header: RateLimitHeader::default(),
39 default_delay_wait: Seconds::new(default_delay_wait),
40 now,
41 backoff_strategy: strategy,
42 throttler: throttler_strategy,
43 }
44 }
45
46 fn log_backoff_enabled(&self) {
47 log_info!("Backoff enabled with {} max retries", self.max_retries);
48 }
49}
50
51impl<R: HttpRunner<Response = HttpResponse>> Backoff<'_, R> {
52 pub fn retry_on_error<T: Serialize>(
53 &mut self,
54 request: &mut Request<T>,
55 ) -> Result<HttpResponse> {
56 loop {
57 if self.num_retries > 0 {
58 log_info!(
59 "Retrying request {} out of {}",
60 self.num_retries,
61 self.max_retries
62 );
63 }
64 match self.runner.run(request) {
65 Ok(response) => return Ok(response),
66 Err(err) => {
67 if self.max_retries == 0 {
68 return Err(err);
69 }
70 log_error!("Error: {}", err);
71 match err.downcast_ref::<error::GRError>() {
73 Some(error::GRError::RateLimitExceeded(headers)) => {
74 self.rate_limit_header = *headers;
75 self.num_retries += 1;
76 if self.num_retries <= self.max_retries {
77 let now = (self.now)();
78 let mut base_wait_time = if self.rate_limit_header.reset > now {
79 self.rate_limit_header.reset - now
80 } else {
81 self.default_delay_wait
82 };
83 if self.rate_limit_header.retry_after > Seconds::new(0) {
84 base_wait_time = self.rate_limit_header.retry_after;
85 }
86 self.log_backoff_enabled();
87 self.throttler.throttle_for(
88 self.backoff_strategy
89 .wait_time(base_wait_time, self.num_retries)
90 .into(),
91 );
92 continue;
93 }
94 }
95 Some(
96 error::GRError::HttpTransportError(_)
97 | error::GRError::RemoteServerError(_),
98 ) => {
99 self.num_retries += 1;
100 if self.num_retries <= self.max_retries {
101 self.log_backoff_enabled();
102 self.throttler.throttle_for(
103 self.backoff_strategy
104 .wait_time(self.default_delay_wait, self.num_retries)
105 .into(),
106 );
107 continue;
108 }
109 }
110 _ => {
111 return Err(err);
112 }
113 }
114 return Err(GRError::ExponentialBackoffMaxRetriesReached(format!(
115 "Retried the request {} times",
116 self.max_retries
117 )))
118 .err_context(err);
119 }
120 };
121 }
122 }
123}
124
/// Computes how long to wait before retrying a failed request.
pub trait BackOffStrategy {
    /// Returns the wait to apply before the next attempt.
    ///
    /// `base_wait` is the wait chosen by the caller for this failure. `num_retries` is
    /// the retry number about to be performed; `Backoff::retry_on_error` passes it
    /// after incrementing its counter, so the first retry sees `1`.
    fn wait_time(&self, base_wait: Seconds, num_retries: u32) -> Seconds;
}
128
/// Backoff strategy that adds `2^num_retries` seconds to the base wait.
pub struct Exponential;
130
131impl BackOffStrategy for Exponential {
132 fn wait_time(&self, base_wait: Seconds, num_retries: u32) -> Seconds {
133 log_info!("Exponential backoff strategy enabled");
134 let wait_time = base_wait + 2u64.pow(num_retries).into();
135 log_info!("Waiting for {} seconds", wait_time);
136 wait_time
137 }
138}
139
#[cfg(test)]
mod tests {
    use std::rc::Rc;

    use crate::{
        http::{self, Headers, Resource},
        io::FlowControlHeaders,
        test::utils::{MockRunner, MockThrottler},
        time::Milliseconds,
    };

    use super::*;

    // NOTE: `MockRunner` appears to serve responses from the back of the vector, so the
    // last entry is returned first. The expected waits below are computed in that order.
    // With `Exponential`, each retry waits `base + 2^retry` seconds.

    /// Builds a 429 response with rate-limit headers and matching flow control headers.
    fn ratelimited_with_headers(remaining: u32, reset: u64, retry_after: u64) -> HttpResponse {
        let mut headers = Headers::new();
        headers.set("x-ratelimit-remaining".to_string(), remaining.to_string());
        headers.set("x-ratelimit-reset".to_string(), reset.to_string());
        headers.set("retry-after".to_string(), retry_after.to_string());
        let rate_limit_header =
            RateLimitHeader::new(remaining, Seconds::new(reset), Seconds::new(retry_after));
        let flow_control_headers =
            FlowControlHeaders::new(Rc::new(None), Rc::new(Some(rate_limit_header)));
        HttpResponse::builder()
            .status(429)
            .headers(headers)
            .flow_control_headers(flow_control_headers)
            .build()
            .unwrap()
    }

    /// Builds a 429 response without any rate-limit headers.
    fn ratelimited_with_no_headers() -> HttpResponse {
        HttpResponse::builder().status(429).build().unwrap()
    }

    /// Builds a successful 200 response.
    fn response_ok() -> HttpResponse {
        HttpResponse::builder().status(200).build().unwrap()
    }

    /// Builds a 500 response.
    fn response_server_error() -> HttpResponse {
        HttpResponse::builder().status(500).build().unwrap()
    }

    /// Builds a response with status -1, assumed to make the mock runner report a
    /// transport error (see `test_retries_on_transport_error`).
    fn response_transport_error() -> HttpResponse {
        HttpResponse::builder().status(-1).build().unwrap()
    }

    /// Fixed clock so that reset-based waits are deterministic.
    fn now_mock() -> Seconds {
        Seconds::new(1712814151)
    }

    /// Builds the GET request shared by every test.
    fn get_request() -> Request<()> {
        Request::builder()
            .resource(Resource::new("http://localhost", None))
            .method(http::Method::GET)
            .build()
            .unwrap()
    }

    #[test]
    fn test_exponential_backoff_retries_and_succeeds() {
        let reset = now_mock() + Seconds::new(60);
        let responses = vec![
            response_ok(),
            ratelimited_with_no_headers(),
            ratelimited_with_headers(10, *reset, 60),
        ];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 3, 60, now_mock, strategy, bthrottler);
        backoff.retry_on_error(&mut request).unwrap();
        assert_eq!(2, *throttler.throttled());
        // retry_after 60 + 2^1 = 62s, then default 60 + 2^2 = 64s.
        assert_eq!(
            Milliseconds::new(126000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_exponential_backoff_retries_and_fails_after_max_retries_reached() {
        let reset = now_mock() + Seconds::new(60);
        let responses = vec![
            response_ok(),
            ratelimited_with_no_headers(),
            ratelimited_with_headers(10, *reset, 60),
        ];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 1, 60, now_mock, strategy, bthrottler);
        match backoff.retry_on_error(&mut request) {
            Ok(_) => panic!("Expected max retries reached error"),
            Err(err) => match err.downcast_ref::<error::GRError>() {
                Some(error::GRError::ExponentialBackoffMaxRetriesReached(_)) => {
                    assert_eq!(1, *throttler.throttled());
                    // retry_after 60 + 2^1 = 62s for the only allowed retry.
                    assert_eq!(
                        Milliseconds::new(62000),
                        *throttler.milliseconds_throttled()
                    );
                }
                _ => panic!("Expected max retries reached error"),
            },
        }
    }

    #[test]
    fn test_if_max_retries_is_zero_tries_once() {
        let responses = vec![response_ok()];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 0, 60, now_mock, strategy, bthrottler);
        backoff.retry_on_error(&mut request).unwrap();
        assert_eq!(0, *throttler.throttled());
    }

    #[test]
    fn test_if_max_retries_is_zero_tries_once_and_fails() {
        let responses = vec![ratelimited_with_no_headers()];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 0, 60, now_mock, strategy, bthrottler);
        match backoff.retry_on_error(&mut request) {
            Ok(_) => panic!("Expected rate limit exceeded error"),
            Err(err) => match err.downcast_ref::<error::GRError>() {
                Some(error::GRError::RateLimitExceeded(_)) => {}
                _ => panic!("Expected rate limit exceeded error"),
            },
        }
        assert_eq!(0, *client.throttled());
        // With retries disabled the backoff throttler must never be invoked.
        assert_eq!(0, *throttler.throttled());
    }

    #[test]
    fn test_time_to_reset_is_zero() {
        let responses = vec![
            response_ok(),
            ratelimited_with_no_headers(),
            ratelimited_with_headers(10, 0, 0),
        ];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 3, 60, now_mock, strategy, bthrottler);
        backoff.retry_on_error(&mut request).unwrap();
        assert_eq!(2, *throttler.throttled());
        // Reset is in the past, so the default 60s is used: 60 + 2^1 = 62s,
        // then 60 + 2^2 = 64s.
        assert_eq!(
            Milliseconds::new(126000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_retry_after_used_if_provided() {
        let reset = now_mock() + Seconds::new(120);
        let responses = vec![
            response_ok(),
            ratelimited_with_headers(10, 0, 65),
            ratelimited_with_headers(10, *reset, 61),
        ];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 3, 60, now_mock, strategy, bthrottler);
        backoff.retry_on_error(&mut request).unwrap();
        assert_eq!(2, *throttler.throttled());
        // retry_after overrides reset: 61 + 2^1 = 63s, then 65 + 2^2 = 69s.
        assert_eq!(
            Milliseconds::new(132000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_reset_time_future_and_no_retry_after() {
        let reset_first = now_mock() + Seconds::new(120);
        let reset_second = now_mock() + Seconds::new(61);
        let responses = vec![
            response_ok(),
            ratelimited_with_headers(10, *reset_second, 0),
            ratelimited_with_headers(10, *reset_first, 0),
        ];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 3, 60, now_mock, strategy, bthrottler);
        backoff.retry_on_error(&mut request).unwrap();
        assert_eq!(2, *throttler.throttled());
        // Time to reset is used: 120 + 2^1 = 122s, then 61 + 2^2 = 65s.
        assert_eq!(
            Milliseconds::new(187000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_retries_on_server_500_error() {
        let responses = vec![response_ok(), response_server_error()];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 1, 60, now_mock, strategy, bthrottler);
        backoff.retry_on_error(&mut request).unwrap();
        assert_eq!(1, *throttler.throttled());
        // Default 60 + 2^1 = 62s.
        assert_eq!(
            Milliseconds::new(62000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_retries_on_server_500_error_and_fails_after_max_retries_reached() {
        let responses = vec![response_server_error(), response_server_error()];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 1, 60, now_mock, strategy, bthrottler);
        match backoff.retry_on_error(&mut request) {
            Ok(_) => panic!("Expected max retries reached error"),
            Err(err) => match err.downcast_ref::<error::GRError>() {
                Some(error::GRError::ExponentialBackoffMaxRetriesReached(_)) => {
                    assert_eq!(1, *throttler.throttled());
                    assert_eq!(
                        Milliseconds::new(62000),
                        *throttler.milliseconds_throttled()
                    );
                }
                _ => panic!("Expected max retries reached error"),
            },
        }
    }

    #[test]
    fn test_retries_on_transport_error() {
        let responses = vec![response_ok(), response_transport_error()];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 1, 60, now_mock, strategy, bthrottler);
        backoff.retry_on_error(&mut request).unwrap();
        assert_eq!(1, *throttler.throttled());
        // Default 60 + 2^1 = 62s.
        assert_eq!(
            Milliseconds::new(62000),
            *throttler.milliseconds_throttled()
        );
    }

    #[test]
    fn test_retries_on_transport_error_and_fails_after_max_retries_reached() {
        let responses = vec![response_transport_error(), response_transport_error()];
        let client = Arc::new(MockRunner::new(responses));
        let mut request = get_request();
        let strategy = Box::new(Exponential);
        let throttler = Rc::new(MockThrottler::new(None));
        let bthrottler: Box<dyn ThrottleStrategy> = Box::new(Rc::clone(&throttler));
        let mut backoff = Backoff::new(&client, 1, 60, now_mock, strategy, bthrottler);
        match backoff.retry_on_error(&mut request) {
            Ok(_) => panic!("Expected max retries reached error"),
            Err(err) => match err.downcast_ref::<error::GRError>() {
                Some(error::GRError::ExponentialBackoffMaxRetriesReached(_)) => {
                    assert_eq!(1, *throttler.throttled());
                    assert_eq!(
                        Milliseconds::new(62000),
                        *throttler.milliseconds_throttled()
                    );
                }
                _ => panic!("Expected max retries reached error"),
            },
        }
    }
}