gitlab/api/
retry.rs

// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Retry client wrapper
//!
//! This module provides a `Client` implementation which can wrap other `api::Client`
//! implementations in order to retry requests with an exponential backoff. Only service errors
//! (those in the `5xx` range) are retried; all others are passed through as final statuses.
12
13use std::error::Error as StdError;
14use std::iter;
15use std::thread;
16use std::time::Duration;
17
18use bytes::Bytes;
19use http::Response;
20use url::Url;
21
22use derive_builder::Builder;
23use thiserror::Error;
24
25use crate::api;
26
27/// Parameters for retrying queries with an exponential backoff.
28#[derive(Debug, Builder, Clone)]
29pub struct Backoff {
30    /// The maximum number of times to backoff.
31    ///
32    /// Defaults to `5`.
33    #[builder(default = "5")]
34    limit: usize,
35    /// How long to wait after the first failure.
36    ///
37    /// Defaults to 1 second.
38    #[builder(default = "Duration::from_secs(1)")]
39    init: Duration,
40    /// The scale parameter for timeouts after each subsequent failure.
41    ///
42    /// Defaults to `2.0`.
43    #[builder(default = "2.0")]
44    scale: f64,
45}
46
47fn should_backoff<E>(err: &api::ApiError<E>) -> bool
48where
49    E: StdError + Send + Sync + 'static,
50{
51    if let api::ApiError::GitlabService {
52        status, ..
53    } = err
54    {
55        status.is_server_error()
56    } else {
57        false
58    }
59}
60
61impl Backoff {
62    /// Create a builder for retry backoff parameters.
63    pub fn builder() -> BackoffBuilder {
64        BackoffBuilder::default()
65    }
66
67    fn retry<F, E>(&self, mut tryf: F) -> Result<Response<Bytes>, api::ApiError<Error<E>>>
68    where
69        F: FnMut() -> Result<Response<Bytes>, api::ApiError<E>>,
70        E: StdError + Send + Sync + 'static,
71    {
72        iter::repeat(())
73            .take(self.limit)
74            .scan(self.init, |timeout, _| {
75                match tryf() {
76                    Ok(rsp) => {
77                        if rsp.status().is_server_error() {
78                            thread::sleep(*timeout);
79                            *timeout = timeout.mul_f64(self.scale);
80                            Some(None)
81                        } else {
82                            Some(Some(Ok(rsp)))
83                        }
84                    },
85                    Err(err) => {
86                        if should_backoff(&err) {
87                            thread::sleep(*timeout);
88                            *timeout = timeout.mul_f64(self.scale);
89                            Some(None)
90                        } else {
91                            Some(Some(Err(err.map_client(Error::inner))))
92                        }
93                    },
94                }
95            })
96            .flatten()
97            .next()
98            .unwrap_or_else(|| Err(api::ApiError::client(Error::backoff())))
99    }
100}
101
102impl Default for Backoff {
103    fn default() -> Self {
104        Self::builder().build().unwrap()
105    }
106}
107
108/// An error from a client even after retrying multiple times.
109#[derive(Debug, Error)]
110#[non_exhaustive]
111pub enum Error<E>
112where
113    E: StdError + Send + Sync + 'static,
114{
115    /// The request failed after multiple attempts.
116    #[error("exponential backoff expired")]
117    Backoff {},
118    /// An error occurred within the client.
119    #[error("{}", source)]
120    Inner {
121        /// The source of the error.
122        #[from]
123        source: E,
124    },
125}
126
127impl<E> Error<E>
128where
129    E: StdError + Send + Sync + 'static,
130{
131    fn backoff() -> Self {
132        Self::Backoff {}
133    }
134
135    fn inner(source: E) -> Self {
136        Self::Inner {
137            source,
138        }
139    }
140}
141
142/// A wrapper around a client to perform exponential backoff while retrying errors.
143///
144/// ## Notes
145///
146/// Currently, the wrapping is not 100% compatible, however the gaps should not be common. Of note
147/// is that the HTTP version is 1.1 since there is not a way to query the version from an existing
148/// builder. Also, all requested extensions are ignored since they cannot be cloned reliably. In
149/// the future, requests with extensions will be passed through, but without any backoff support.
150pub struct Client<C> {
151    client: C,
152    backoff: Backoff,
153}
154
155impl<C> Client<C> {
156    /// Create a client which retries in the face of service errors with an exponential backoff.
157    pub fn new(client: C, backoff: Backoff) -> Self {
158        Self {
159            client,
160            backoff,
161        }
162    }
163}
164
165impl<C> api::RestClient for Client<C>
166where
167    C: api::RestClient,
168{
169    type Error = Error<C::Error>;
170
171    fn rest_endpoint(&self, endpoint: &str) -> Result<Url, api::ApiError<Self::Error>> {
172        self.client
173            .rest_endpoint(endpoint)
174            .map_err(|e| e.map_client(Error::inner))
175    }
176
177    fn instance_endpoint(&self, endpoint: &str) -> Result<Url, api::ApiError<Self::Error>> {
178        self.client
179            .instance_endpoint(endpoint)
180            .map_err(|e| e.map_client(Error::inner))
181    }
182}
183
184impl<C> api::Client for Client<C>
185where
186    C: api::Client,
187{
188    fn rest(
189        &self,
190        request: http::request::Builder,
191        body: Vec<u8>,
192    ) -> Result<Response<Bytes>, api::ApiError<Self::Error>> {
193        self.backoff.retry(|| {
194            let mut builder = http::request::Request::builder();
195            if let Some(method) = request.method_ref() {
196                builder = builder.method(method);
197            }
198            if let Some(uri) = request.uri_ref() {
199                builder = builder.uri(uri);
200            }
201            if let Some(version) = request.version_ref() {
202                builder = builder.version(*version);
203            }
204            if let Some(headers) = request.headers_ref() {
205                for (key, value) in headers.iter() {
206                    builder = builder.header(key, value);
207                }
208            }
209            // Ignore extensions for now. Can be handled once this is released:
210            // https://github.com/hyperium/http/pull/497
211
212            self.client.rest(builder, body.clone())
213        })
214    }
215}
216
#[cfg(test)]
mod test {
    use std::time::Duration;

    use http::{Response, StatusCode};
    use serde::Deserialize;
    use serde_json::json;
    use thiserror::Error;

    use crate::api::endpoint_prelude::*;
    use crate::api::{self, retry, ApiError, Query};
    use crate::test::client::{ExpectedUrl, SingleTestClient};

    #[derive(Debug, Error)]
    #[error("bogus")]
    struct BogusError {}

    /// Backoff parameters allowing `limit` attempts with a 1 millisecond initial delay.
    ///
    /// With the default 1 second initial delay, every retrying test would sleep for whole
    /// seconds; the retry logic under test is the same with a shorter delay.
    fn quick_backoff(limit: usize) -> retry::Backoff {
        retry::Backoff::builder()
            .limit(limit)
            .init(Duration::from_millis(1))
            .build()
            .unwrap()
    }

    #[test]
    fn backoff_first_success() {
        let backoff = retry::Backoff::default();
        let mut call_count = 0;
        let body: &'static [u8] = b"";
        backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                Ok(Response::builder()
                    .status(StatusCode::OK)
                    .body(body.into())
                    .unwrap())
            })
            .unwrap();
        assert_eq!(call_count, 1);
    }

    #[test]
    fn backoff_second_success() {
        let backoff = quick_backoff(5);
        let mut call_count = 0;
        let mut did_err = false;
        let body: &'static [u8] = b"";
        backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                if did_err {
                    Ok(Response::builder()
                        .status(StatusCode::OK)
                        .body(body.into())
                        .unwrap())
                } else {
                    did_err = true;
                    Ok(Response::builder()
                        .status(StatusCode::SERVICE_UNAVAILABLE)
                        .body(body.into())
                        .unwrap())
                }
            })
            .unwrap();
        assert_eq!(call_count, 2);
    }

    #[test]
    fn backoff_second_success_gitlab_service_err() {
        let backoff = quick_backoff(5);
        let mut call_count = 0;
        let mut did_err = false;
        let body: &'static [u8] = b"";
        backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                if did_err {
                    Ok(Response::builder()
                        .status(StatusCode::OK)
                        .body(body.into())
                        .unwrap())
                } else {
                    did_err = true;
                    Err(api::ApiError::GitlabService {
                        status: StatusCode::INTERNAL_SERVER_ERROR,
                        data: Vec::default(),
                    })
                }
            })
            .unwrap();
        assert_eq!(call_count, 2);
    }

    #[test]
    fn backoff_no_success() {
        let backoff = quick_backoff(3);
        let mut call_count = 0;
        let body: &'static [u8] = b"";
        let err = backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                Ok(Response::builder()
                    .status(StatusCode::SERVICE_UNAVAILABLE)
                    .body(body.into())
                    .unwrap())
            })
            .unwrap_err();
        assert_eq!(call_count, backoff.limit);
        if let api::ApiError::Client {
            source: retry::Error::Backoff {},
        } = err
        {
        } else {
            panic!("unexpected error: {}", err);
        }
    }

    #[test]
    fn backoff_no_success_gitlab_service_err() {
        let backoff = quick_backoff(3);
        let mut call_count = 0;
        let err = backoff
            .retry::<_, BogusError>(|| {
                call_count += 1;
                Err(api::ApiError::GitlabService {
                    status: StatusCode::INTERNAL_SERVER_ERROR,
                    data: Vec::default(),
                })
            })
            .unwrap_err();
        assert_eq!(call_count, backoff.limit);
        if let api::ApiError::Client {
            source: retry::Error::Backoff {},
        } = err
        {
        } else {
            panic!("unexpected error: {}", err);
        }
    }

    struct Dummy;

    impl Endpoint for Dummy {
        fn method(&self) -> Method {
            Method::GET
        }

        fn endpoint(&self) -> Cow<'static, str> {
            "dummy".into()
        }
    }

    #[derive(Debug, Deserialize)]
    struct DummyResult {
        value: u8,
    }

    #[test]
    fn retry_client_ok() {
        let endpoint = ExpectedUrl::builder().endpoint("dummy").build().unwrap();
        let client = SingleTestClient::new_json(
            endpoint,
            &json!({
                "value": 0,
            }),
        );
        let backoff = retry::Backoff::default();
        let client = retry::Client::new(client, backoff);

        let res: DummyResult = Dummy.query(&client).unwrap();
        assert_eq!(res.value, 0);
    }

    #[test]
    fn retry_client_backoff_err() {
        let endpoint = ExpectedUrl::builder()
            .endpoint("dummy")
            .status(StatusCode::NOT_FOUND)
            .build()
            .unwrap();
        let client = SingleTestClient::new_json(
            endpoint,
            &json!({
                "message": "dummy error message",
            }),
        );
        let backoff = retry::Backoff::default();
        let client = retry::Client::new(client, backoff);

        let res: Result<DummyResult, _> = Dummy.query(&client);
        let err = res.unwrap_err();
        if let ApiError::GitlabWithStatus {
            status,
            msg,
        } = err
        {
            assert_eq!(status, StatusCode::NOT_FOUND);
            assert_eq!(msg, "dummy error message");
        } else {
            panic!("unexpected error: {}", err);
        }
    }

    #[test]
    fn retry_client_other_err() {
        let endpoint = ExpectedUrl::builder()
            .endpoint("dummy")
            .status(StatusCode::IM_A_TEAPOT)
            .build()
            .unwrap();
        let return_obj = json!({
            "blah": "dummy error message",
        });
        let client = SingleTestClient::new_json(endpoint, &return_obj);
        let backoff = retry::Backoff::default();
        let client = retry::Client::new(client, backoff);

        let res: Result<DummyResult, _> = Dummy.query(&client);
        let err = res.unwrap_err();
        if let ApiError::GitlabUnrecognizedWithStatus {
            status,
            obj,
        } = err
        {
            assert_eq!(status, StatusCode::IM_A_TEAPOT);
            assert_eq!(obj, return_obj);
        } else {
            panic!("unexpected error: {}", err);
        }
    }

    #[test]
    fn retry_client_retry_timeout() {
        let endpoint = ExpectedUrl::builder()
            .endpoint("dummy")
            .status(StatusCode::SERVICE_UNAVAILABLE)
            .build()
            .unwrap();
        let client = SingleTestClient::new_raw(endpoint, "");
        let backoff = quick_backoff(3);
        let client = retry::Client::new(client, backoff);

        let res: Result<DummyResult, _> = Dummy.query(&client);
        let err = res.unwrap_err();
        if let ApiError::Client {
            source: retry::Error::Backoff {},
        } = err
        {
            // expected
        } else {
            panic!("unexpected error: {}", err);
        }
    }
}