1use std::error::Error as StdError;
14use std::iter;
15use std::thread;
16use std::time::Duration;
17
18use bytes::Bytes;
19use http::Response;
20use url::Url;
21
22use derive_builder::Builder;
23use thiserror::Error;
24
25use crate::api;
26
27#[derive(Debug, Builder, Clone)]
29pub struct Backoff {
30 #[builder(default = "5")]
34 limit: usize,
35 #[builder(default = "Duration::from_secs(1)")]
39 init: Duration,
40 #[builder(default = "2.0")]
44 scale: f64,
45}
46
47fn should_backoff<E>(err: &api::ApiError<E>) -> bool
48where
49 E: StdError + Send + Sync + 'static,
50{
51 if let api::ApiError::GitlabService {
52 status, ..
53 } = err
54 {
55 status.is_server_error()
56 } else {
57 false
58 }
59}
60
61impl Backoff {
62 pub fn builder() -> BackoffBuilder {
64 BackoffBuilder::default()
65 }
66
67 fn retry<F, E>(&self, mut tryf: F) -> Result<Response<Bytes>, api::ApiError<Error<E>>>
68 where
69 F: FnMut() -> Result<Response<Bytes>, api::ApiError<E>>,
70 E: StdError + Send + Sync + 'static,
71 {
72 iter::repeat(())
73 .take(self.limit)
74 .scan(self.init, |timeout, _| {
75 match tryf() {
76 Ok(rsp) => {
77 if rsp.status().is_server_error() {
78 thread::sleep(*timeout);
79 *timeout = timeout.mul_f64(self.scale);
80 Some(None)
81 } else {
82 Some(Some(Ok(rsp)))
83 }
84 },
85 Err(err) => {
86 if should_backoff(&err) {
87 thread::sleep(*timeout);
88 *timeout = timeout.mul_f64(self.scale);
89 Some(None)
90 } else {
91 Some(Some(Err(err.map_client(Error::inner))))
92 }
93 },
94 }
95 })
96 .flatten()
97 .next()
98 .unwrap_or_else(|| Err(api::ApiError::client(Error::backoff())))
99 }
100}
101
102impl Default for Backoff {
103 fn default() -> Self {
104 Self::builder().build().unwrap()
105 }
106}
107
108#[derive(Debug, Error)]
110#[non_exhaustive]
111pub enum Error<E>
112where
113 E: StdError + Send + Sync + 'static,
114{
115 #[error("exponential backoff expired")]
117 Backoff {},
118 #[error("{}", source)]
120 Inner {
121 #[from]
123 source: E,
124 },
125}
126
127impl<E> Error<E>
128where
129 E: StdError + Send + Sync + 'static,
130{
131 fn backoff() -> Self {
132 Self::Backoff {}
133 }
134
135 fn inner(source: E) -> Self {
136 Self::Inner {
137 source,
138 }
139 }
140}
141
142pub struct Client<C> {
151 client: C,
152 backoff: Backoff,
153}
154
155impl<C> Client<C> {
156 pub fn new(client: C, backoff: Backoff) -> Self {
158 Self {
159 client,
160 backoff,
161 }
162 }
163}
164
165impl<C> api::RestClient for Client<C>
166where
167 C: api::RestClient,
168{
169 type Error = Error<C::Error>;
170
171 fn rest_endpoint(&self, endpoint: &str) -> Result<Url, api::ApiError<Self::Error>> {
172 self.client
173 .rest_endpoint(endpoint)
174 .map_err(|e| e.map_client(Error::inner))
175 }
176
177 fn instance_endpoint(&self, endpoint: &str) -> Result<Url, api::ApiError<Self::Error>> {
178 self.client
179 .instance_endpoint(endpoint)
180 .map_err(|e| e.map_client(Error::inner))
181 }
182}
183
184impl<C> api::Client for Client<C>
185where
186 C: api::Client,
187{
188 fn rest(
189 &self,
190 request: http::request::Builder,
191 body: Vec<u8>,
192 ) -> Result<Response<Bytes>, api::ApiError<Self::Error>> {
193 self.backoff.retry(|| {
194 let mut builder = http::request::Request::builder();
195 if let Some(method) = request.method_ref() {
196 builder = builder.method(method);
197 }
198 if let Some(uri) = request.uri_ref() {
199 builder = builder.uri(uri);
200 }
201 if let Some(version) = request.version_ref() {
202 builder = builder.version(*version);
203 }
204 if let Some(headers) = request.headers_ref() {
205 for (key, value) in headers.iter() {
206 builder = builder.header(key, value);
207 }
208 }
209 self.client.rest(builder, body.clone())
213 })
214 }
215}
216
#[cfg(test)]
mod test {
    use std::time::Duration;

    use http::{Response, StatusCode};
    use serde::Deserialize;
    use serde_json::json;
    use thiserror::Error;

    use crate::api::endpoint_prelude::*;
    use crate::api::{self, retry, ApiError, Query};
    use crate::test::client::{ExpectedUrl, SingleTestClient};

    /// Initial delay for tests that go through the retry path. The production default of
    /// one second would make these tests sleep for many seconds in total.
    const TEST_INIT: Duration = Duration::from_millis(1);

    #[derive(Debug, Error)]
    #[error("bogus")]
    struct BogusError {}

    #[test]
    fn backoff_first_success() {
        let backoff = retry::Backoff::default();
        let mut call_count = 0;
        let body: &'static [u8] = b"";
        backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                Ok(Response::builder()
                    .status(StatusCode::OK)
                    .body(body.into())
                    .unwrap())
            })
            .unwrap();
        assert_eq!(call_count, 1);
    }

    #[test]
    fn backoff_second_success() {
        let backoff = retry::Backoff::builder().init(TEST_INIT).build().unwrap();
        let mut call_count = 0;
        let mut did_err = false;
        let body: &'static [u8] = b"";
        backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                if did_err {
                    Ok(Response::builder()
                        .status(StatusCode::OK)
                        .body(body.into())
                        .unwrap())
                } else {
                    did_err = true;
                    Ok(Response::builder()
                        .status(StatusCode::SERVICE_UNAVAILABLE)
                        .body(body.into())
                        .unwrap())
                }
            })
            .unwrap();
        assert_eq!(call_count, 2);
    }

    #[test]
    fn backoff_second_success_gitlab_service_err() {
        let backoff = retry::Backoff::builder().init(TEST_INIT).build().unwrap();
        let mut call_count = 0;
        let mut did_err = false;
        let body: &'static [u8] = b"";
        backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                if did_err {
                    Ok(Response::builder()
                        .status(StatusCode::OK)
                        .body(body.into())
                        .unwrap())
                } else {
                    did_err = true;
                    Err(api::ApiError::GitlabService {
                        status: StatusCode::INTERNAL_SERVER_ERROR,
                        data: Vec::default(),
                    })
                }
            })
            .unwrap();
        assert_eq!(call_count, 2);
    }

    #[test]
    fn backoff_no_success() {
        let backoff = retry::Backoff::builder()
            .limit(3)
            .init(TEST_INIT)
            .build()
            .unwrap();
        let mut call_count = 0;
        let body: &'static [u8] = b"";
        let err = backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                Ok(Response::builder()
                    .status(StatusCode::SERVICE_UNAVAILABLE)
                    .body(body.into())
                    .unwrap())
            })
            .unwrap_err();
        assert_eq!(call_count, backoff.limit);
        if let api::ApiError::Client {
            source: retry::Error::Backoff {},
        } = err
        {
        } else {
            panic!("unexpected error: {}", err);
        }
    }

    #[test]
    fn backoff_no_success_gitlab_service_err() {
        let backoff = retry::Backoff::builder()
            .limit(3)
            .init(TEST_INIT)
            .build()
            .unwrap();
        let mut call_count = 0;
        let err = backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                Err(api::ApiError::GitlabService {
                    status: StatusCode::INTERNAL_SERVER_ERROR,
                    data: Vec::default(),
                })
            })
            .unwrap_err();
        assert_eq!(call_count, backoff.limit);
        if let api::ApiError::Client {
            source: retry::Error::Backoff {},
        } = err
        {
        } else {
            panic!("unexpected error: {}", err);
        }
    }

    struct Dummy;

    impl Endpoint for Dummy {
        fn method(&self) -> Method {
            Method::GET
        }

        fn endpoint(&self) -> Cow<'static, str> {
            "dummy".into()
        }
    }

    #[derive(Debug, Deserialize)]
    struct DummyResult {
        value: u8,
    }

    #[test]
    fn retry_client_ok() {
        let endpoint = ExpectedUrl::builder().endpoint("dummy").build().unwrap();
        let client = SingleTestClient::new_json(
            endpoint,
            &json!({
                "value": 0,
            }),
        );
        let backoff = retry::Backoff::default();
        let client = retry::Client::new(client, backoff);

        let res: DummyResult = Dummy.query(&client).unwrap();
        assert_eq!(res.value, 0);
    }

    #[test]
    fn retry_client_backoff_err() {
        let endpoint = ExpectedUrl::builder()
            .endpoint("dummy")
            .status(StatusCode::NOT_FOUND)
            .build()
            .unwrap();
        let client = SingleTestClient::new_json(
            endpoint,
            &json!({
                "message": "dummy error message",
            }),
        );
        let backoff = retry::Backoff::default();
        let client = retry::Client::new(client, backoff);

        let res: Result<DummyResult, _> = Dummy.query(&client);
        let err = res.unwrap_err();
        if let ApiError::GitlabWithStatus {
            status,
            msg,
        } = err
        {
            assert_eq!(status, StatusCode::NOT_FOUND);
            assert_eq!(msg, "dummy error message");
        } else {
            panic!("unexpected error: {}", err);
        }
    }

    #[test]
    fn retry_client_other_err() {
        let endpoint = ExpectedUrl::builder()
            .endpoint("dummy")
            .status(StatusCode::IM_A_TEAPOT)
            .build()
            .unwrap();
        let return_obj = json!({
            "blah": "dummy error message",
        });
        let client = SingleTestClient::new_json(endpoint, &return_obj);
        let backoff = retry::Backoff::default();
        let client = retry::Client::new(client, backoff);

        let res: Result<DummyResult, _> = Dummy.query(&client);
        let err = res.unwrap_err();
        if let ApiError::GitlabUnrecognizedWithStatus {
            status,
            obj,
        } = err
        {
            assert_eq!(status, StatusCode::IM_A_TEAPOT);
            assert_eq!(obj, return_obj);
        } else {
            panic!("unexpected error: {}", err);
        }
    }

    #[test]
    fn retry_client_retry_timeout() {
        let endpoint = ExpectedUrl::builder()
            .endpoint("dummy")
            .status(StatusCode::SERVICE_UNAVAILABLE)
            .build()
            .unwrap();
        let client = SingleTestClient::new_raw(endpoint, "");
        let backoff = retry::Backoff::builder()
            .limit(3)
            .init(TEST_INIT)
            .build()
            .unwrap();
        let client = retry::Client::new(client, backoff);

        let res: Result<DummyResult, _> = Dummy.query(&client);
        let err = res.unwrap_err();
        if let ApiError::Client {
            source: retry::Error::Backoff {},
        } = err
        {
        } else {
            panic!("unexpected error: {}", err);
        }
    }
}