quickwit_aws/
retry.rs

1// Copyright (C) 2021 Quickwit, Inc.
2//
3// Quickwit is offered under the AGPL v3.0 and as commercial software.
4// For commercial licensing, contact us at hello@quickwit.io.
5//
6// AGPL:
7// This program is free software: you can redistribute it and/or modify
8// it under the terms of the GNU Affero General Public License as
9// published by the Free Software Foundation, either version 3 of the
10// License, or (at your option) any later version.
11//
12// This program is distributed in the hope that it will be useful,
13// but WITHOUT ANY WARRANTY; without even the implied warranty of
14// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
15// GNU Affero General Public License for more details.
16//
17// You should have received a copy of the GNU Affero General Public License
18// along with this program. If not, see <http://www.gnu.org/licenses/>.
19
20use std::fmt::Display;
21use std::time::Duration;
22
23use futures::Future;
24use rand::Rng;
25use tracing::{debug, warn};
26
/// Default value of `RetryParams::max_attempts`: the total number of calls `retry` makes,
/// the initial one included, before giving up.
const DEFAULT_MAX_RETRY_ATTEMPTS: usize = 30;

/// Default value of `RetryParams::base_delay`: 1 ms in test builds so that tests do not
/// sleep for long, 250 ms otherwise.
const DEFAULT_BASE_DELAY: Duration = if cfg!(test) {
    Duration::from_millis(1)
} else {
    Duration::from_millis(250)
};

/// Default value of `RetryParams::max_delay`: 1 ms in test builds, 20 s otherwise.
const DEFAULT_MAX_DELAY: Duration = if cfg!(test) {
    Duration::from_millis(1)
} else {
    Duration::from_millis(20_000)
};
30
/// Tells the retry machinery whether a value (typically an error) is worth another attempt.
///
/// The provided implementation reports every value as non-retryable; implementors override
/// [`Retryable::is_retryable`] to opt in.
pub trait Retryable {
    /// Returns `true` when the operation that produced this value may be attempted again.
    ///
    /// Defaults to `false`.
    fn is_retryable(&self) -> bool {
        false
    }
}
36
/// Wrapper that tags an error `E` with whether the failed operation should be retried.
#[derive(Debug, Eq, PartialEq)]
pub enum Retry<E> {
    /// The failure may go away on its own; reported as retryable.
    Transient(E),
    /// The failure will not go away by retrying; reported as non-retryable.
    Permanent(E),
}
42
43impl<E> Retry<E> {
44    pub fn into_inner(self) -> E {
45        match self {
46            Self::Transient(e) => e,
47            Self::Permanent(e) => e,
48        }
49    }
50}
51
52impl<E: Display> Display for Retry<E> {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match self {
55            Retry::Transient(e) => {
56                write!(f, "Transient({})", e)
57            }
58            Retry::Permanent(e) => {
59                write!(f, "Permanent({})", e)
60            }
61        }
62    }
63}
64
65impl<E> Retryable for Retry<E> {
66    fn is_retryable(&self) -> bool {
67        match self {
68            Retry::Transient(_) => true,
69            Retry::Permanent(_) => false,
70        }
71    }
72}
73
/// Tuning knobs for [`retry`].
///
/// After the `n`-th failed attempt, the backoff ceiling is
/// `min(base_delay * 2^n, max_delay)` and the actual sleep is drawn at random below it.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RetryParams {
    /// Delay that is doubled on every failed attempt to obtain the backoff ceiling.
    pub base_delay: Duration,
    /// Upper bound applied to the backoff ceiling.
    pub max_delay: Duration,
    /// Total number of attempts, the initial one included, before the last error is returned.
    pub max_attempts: usize,
}
80
81impl Default for RetryParams {
82    fn default() -> Self {
83        Self {
84            base_delay: DEFAULT_BASE_DELAY,
85            max_delay: DEFAULT_MAX_DELAY,
86            max_attempts: DEFAULT_MAX_RETRY_ATTEMPTS,
87        }
88    }
89}
90
91/// Retry with exponential backoff and full jitter. Implementation and default values originate from
92/// the Java SDK. See also: <https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/>.
93pub async fn retry<F, U, E, Fut>(retry_params: &RetryParams, f: F) -> Result<U, E>
94where
95    F: Fn() -> Fut,
96    Fut: Future<Output = Result<U, E>>,
97    E: Retryable + Display + 'static,
98{
99    let mut attempt_count = 0;
100
101    loop {
102        let response = f().await;
103
104        attempt_count += 1;
105
106        match response {
107            Ok(response) => {
108                return Ok(response);
109            }
110            Err(error) => {
111                if !error.is_retryable() {
112                    return Err(error);
113                }
114                if attempt_count >= retry_params.max_attempts {
115                    warn!(
116                        attempt_count = %attempt_count,
117                        "Request failed"
118                    );
119                    return Err(error);
120                }
121
122                let ceiling_ms = (retry_params.base_delay.as_millis() as u64
123                    * 2u64.pow(attempt_count as u32))
124                .min(retry_params.max_delay.as_millis() as u64);
125                let delay_ms = rand::thread_rng().gen_range(0..ceiling_ms);
126                debug!(
127                    attempt_count = %attempt_count,
128                    delay_ms = %delay_ms,
129                    error = %error,
130                    "Request failed, retrying"
131                );
132
133                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
134            }
135        }
136    }
137}
138
#[cfg(test)]
mod tests {
    use std::sync::RwLock;

    use futures::future::ready;

    use super::{retry, Retry, RetryParams};

    /// Drives `retry` with the default parameters; each attempt yields the next scripted
    /// outcome, in order. Panics if `retry` asks for more outcomes than were scripted.
    async fn simulate_retries<T>(values: Vec<Result<T, Retry<usize>>>) -> Result<T, Retry<usize>> {
        let params = RetryParams::default();
        let outcomes = RwLock::new(values.into_iter());
        let next_outcome = || {
            let mut remaining = outcomes.write().unwrap();
            ready(remaining.next().unwrap())
        };
        retry(&params, next_outcome).await
    }

    #[tokio::test]
    async fn test_retry_accepts_ok() {
        let outcome = simulate_retries(vec![Ok(())]).await;
        assert_eq!(outcome, Ok(()));
    }

    #[tokio::test]
    async fn test_retry_does_retry() {
        let script = vec![Err(Retry::Transient(1)), Ok(())];
        let outcome = simulate_retries(script).await;
        assert_eq!(outcome, Ok(()));
    }

    #[tokio::test]
    async fn test_retry_stops_retrying_on_non_retryable_error() {
        // The trailing `Ok` must never be consumed: a permanent error ends the loop at once.
        let script = vec![Err(Retry::Permanent(1)), Ok(())];
        let outcome = simulate_retries(script).await;
        assert_eq!(outcome, Err(Retry::Permanent(1)));
    }

    #[tokio::test]
    async fn test_retry_retries_up_at_most_attempts_times() {
        // 30 transient failures exhaust the default budget of 30 attempts, so the last
        // error (id 29) is returned and the trailing `Ok` is never reached.
        let mut script: Vec<Result<(), Retry<usize>>> = (0..30)
            .map(|failure_id| Err(Retry::Transient(failure_id)))
            .collect();
        script.push(Ok(()));

        let outcome = simulate_retries(script).await;
        assert_eq!(outcome, Err(Retry::Transient(29)));
    }

    #[tokio::test]
    async fn test_retry_retries_up_to_max_attempts_times() {
        // 29 transient failures leave exactly one attempt, which succeeds.
        let mut script: Vec<Result<(), Retry<usize>>> = (0..29)
            .map(|failure_id| Err(Retry::Transient(failure_id)))
            .collect();
        script.push(Ok(()));

        let outcome = simulate_retries(script).await;
        assert_eq!(outcome, Ok(()));
    }
}
196}