cognite/
retry.rs

// This file is adapted from reqwest-retry, which was a bit too opinionated for our use.
// https://github.com/TrueLayer/reqwest-middleware
3
4use async_trait::async_trait;
5use http::Extensions;
6use rand::{rng, Rng};
7use reqwest::{Request, Response, StatusCode};
8use reqwest_middleware::{Middleware, Next, Result};
9use std::time::Duration;
10
/// Middleware for retrying requests.
///
/// Retryable failures are resent after an exponentially growing, jittered delay.
#[derive(Debug, Clone)]
pub struct CustomRetryMiddleware {
    /// Maximum number of retries after the initial attempt (clamped to 10 by `new`).
    max_retries: u32,
    /// Upper bound, in milliseconds, on the exponential backoff delay (applied before jitter).
    max_delay_ms: u64,
    /// Backoff delay, in milliseconds, before the first retry; doubles on each further retry.
    initial_delay_ms: u64,
}
17
18#[async_trait]
19impl Middleware for CustomRetryMiddleware {
20    async fn handle(
21        &self,
22        req: Request,
23        extensions: &mut Extensions,
24        next: Next<'_>,
25    ) -> Result<Response> {
26        self.execute_with_retry(req, next, extensions).await
27    }
28}
29
30impl CustomRetryMiddleware {
31    /// Create a new retry middleware instance.
32    pub fn new(max_retries: u32, max_delay_ms: u64, initial_delay_ms: u64) -> Self {
33        Self {
34            max_retries: max_retries.min(10),
35            max_delay_ms,
36            initial_delay_ms,
37        }
38    }
39
40    async fn execute_with_retry<'a>(
41        &'a self,
42        req: Request,
43        next: Next<'a>,
44        ext: &'a mut Extensions,
45    ) -> Result<Response> {
46        let mut n_past_retries = 0;
47        let mut last_req_401 = false;
48        loop {
49            let duplicate_request = match req.try_clone() {
50                Some(x) => x,
51                None => return next.run(req, ext).await,
52            };
53
54            let result = next.clone().run(duplicate_request, ext).await;
55
56            // Check if the error can be retried.
57            break match Retryable::from_reqwest_response(&result) {
58                Some(retryable)
59                    if (retryable == Retryable::Transient
60                        || retryable == Retryable::Unauthorized && !last_req_401)
61                        && n_past_retries < self.max_retries =>
62                {
63                    last_req_401 = retryable == Retryable::Unauthorized;
64                    // If the response failed and the error type was transient
65                    // we can safely try to retry the request.
66                    let mut retry_delay = self.initial_delay_ms * 2u64.pow(n_past_retries);
67                    if retry_delay > self.max_delay_ms {
68                        retry_delay = self.max_delay_ms;
69                    }
70                    // Jitter so we land between initial * 2 ** attempt * 3/4 and initial * 2 ** attempt * 5/4
71                    retry_delay = retry_delay / 4 * 3 + rng().random_range(0..=(retry_delay / 2));
72                    futures_timer::Delay::new(Duration::from_millis(retry_delay)).await;
73                    n_past_retries += 1;
74                    continue;
75                }
76                Some(_) | None => result,
77            };
78        }
79    }
80}
81
/// Classification of a request outcome, used to decide whether it should be retried.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Retryable {
    /// The failure was due to something that might resolve in the future.
    Transient,
    /// Unresolvable error.
    Fatal,
    /// Unauthorized. This is _maybe_ resolvable, if the last request wasn't also a 401.
    Unauthorized,
}
91
92impl Retryable {
93    /// Try to map a `reqwest` response into `Retryable`.
94    ///
95    /// Returns `None` if the response object does not contain any errors.
96    ///
97    /// # Arguments
98    ///
99    /// * `res` - Request response.
100    pub fn from_reqwest_response(
101        res: &reqwest_middleware::Result<reqwest::Response>,
102    ) -> Option<Self> {
103        match res {
104            Ok(success) => {
105                let status = success.status();
106                if status.is_success() {
107                    None
108                } else if status == StatusCode::UNAUTHORIZED {
109                    Some(Retryable::Unauthorized)
110                } else if status.is_server_error()
111                    || status == StatusCode::REQUEST_TIMEOUT
112                    || status == StatusCode::TOO_MANY_REQUESTS
113                    || success
114                        .headers()
115                        .get("cdf-is-auto-retryable")
116                        .and_then(|v| v.to_str().ok())
117                        .is_some_and(|v| v == "true")
118                {
119                    Some(Retryable::Transient)
120                } else {
121                    Some(Retryable::Fatal)
122                }
123            }
124            Err(error) => match error {
125                reqwest_middleware::Error::Middleware(_) => Some(Retryable::Fatal),
126                reqwest_middleware::Error::Reqwest(error) => {
127                    if error.is_timeout() || error.is_connect() {
128                        Some(Retryable::Transient)
129                    } else if error.is_body()
130                        || error.is_decode()
131                        || error.is_builder()
132                        || error.is_redirect()
133                        || error.is_request()
134                    {
135                        Some(Retryable::Fatal)
136                    } else {
137                        // We omit checking if error.is_status() since we check that already.
138                        // However, if Response::error_for_status is used the status will still
139                        // remain in the response object.
140                        None
141                    }
142                }
143            },
144        }
145    }
146}