1use std::fmt::Display;
21use std::time::Duration;
22
23use futures::Future;
24use rand::Rng;
25use tracing::{debug, warn};
26
/// Default value for [`RetryParams::max_attempts`].
const DEFAULT_MAX_RETRY_ATTEMPTS: usize = 30;

/// Default value for [`RetryParams::base_delay`].
///
/// Shrunk to 1 ms when compiled with `cfg(test)` (presumably so that unit tests exercising the
/// retry loop do not spend real time sleeping).
const DEFAULT_BASE_DELAY: Duration = if cfg!(test) {
    Duration::from_millis(1)
} else {
    Duration::from_millis(250)
};

/// Default value for [`RetryParams::max_delay`].
///
/// Shrunk to 1 ms when compiled with `cfg(test)`, like [`DEFAULT_BASE_DELAY`].
const DEFAULT_MAX_DELAY: Duration = if cfg!(test) {
    Duration::from_millis(1)
} else {
    Duration::from_millis(20_000)
};
30
/// Implemented by error types that can tell whether the failed operation is worth retrying.
///
/// [`retry`] consults this trait after every failed attempt and stops immediately when it
/// returns `false`.
pub trait Retryable {
    /// Returns `true` if the operation that produced this error may be attempted again.
    ///
    /// The provided implementation is conservative and reports every error as non-retryable.
    fn is_retryable(&self) -> bool {
        false
    }
}
36
/// Wraps an error of type `E` together with a verdict on whether retrying can help.
#[derive(Debug, Eq, PartialEq)]
pub enum Retry<E> {
    /// The failure may go away on a later attempt; reported as retryable.
    Transient(E),
    /// The failure will not go away by retrying; reported as non-retryable.
    Permanent(E),
}
42
43impl<E> Retry<E> {
44 pub fn into_inner(self) -> E {
45 match self {
46 Self::Transient(e) => e,
47 Self::Permanent(e) => e,
48 }
49 }
50}
51
52impl<E: Display> Display for Retry<E> {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 Retry::Transient(e) => {
56 write!(f, "Transient({})", e)
57 }
58 Retry::Permanent(e) => {
59 write!(f, "Permanent({})", e)
60 }
61 }
62 }
63}
64
65impl<E> Retryable for Retry<E> {
66 fn is_retryable(&self) -> bool {
67 match self {
68 Retry::Transient(_) => true,
69 Retry::Permanent(_) => false,
70 }
71 }
72}
73
/// Tuning knobs for [`retry`].
///
/// Derives `Debug`, `PartialEq` and `Eq` in addition to `Clone` so that the parameters can be
/// logged and compared in assertions.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct RetryParams {
    /// Starting point of the exponential backoff: after the `n`-th failed attempt the delay is
    /// drawn below `base_delay * 2^n` (subject to `max_delay`).
    pub base_delay: Duration,
    /// Upper bound applied to the backoff ceiling, however many attempts have failed.
    pub max_delay: Duration,
    /// Total number of calls made before giving up, the initial call included.
    pub max_attempts: usize,
}
80
81impl Default for RetryParams {
82 fn default() -> Self {
83 Self {
84 base_delay: DEFAULT_BASE_DELAY,
85 max_delay: DEFAULT_MAX_DELAY,
86 max_attempts: DEFAULT_MAX_RETRY_ATTEMPTS,
87 }
88 }
89}
90
91pub async fn retry<F, U, E, Fut>(retry_params: &RetryParams, f: F) -> Result<U, E>
94where
95 F: Fn() -> Fut,
96 Fut: Future<Output = Result<U, E>>,
97 E: Retryable + Display + 'static,
98{
99 let mut attempt_count = 0;
100
101 loop {
102 let response = f().await;
103
104 attempt_count += 1;
105
106 match response {
107 Ok(response) => {
108 return Ok(response);
109 }
110 Err(error) => {
111 if !error.is_retryable() {
112 return Err(error);
113 }
114 if attempt_count >= retry_params.max_attempts {
115 warn!(
116 attempt_count = %attempt_count,
117 "Request failed"
118 );
119 return Err(error);
120 }
121
122 let ceiling_ms = (retry_params.base_delay.as_millis() as u64
123 * 2u64.pow(attempt_count as u32))
124 .min(retry_params.max_delay.as_millis() as u64);
125 let delay_ms = rand::thread_rng().gen_range(0..ceiling_ms);
126 debug!(
127 attempt_count = %attempt_count,
128 delay_ms = %delay_ms,
129 error = %error,
130 "Request failed, retrying"
131 );
132
133 tokio::time::sleep(Duration::from_millis(delay_ms)).await;
134 }
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use std::sync::Mutex;

    use futures::future::ready;

    use super::{retry, Retry, RetryParams};

    /// Runs `retry` with the default parameters against a scripted list of outcomes: each
    /// attempt consumes the next entry of `values`. Panics if the script runs out.
    async fn simulate_retries<T>(values: Vec<Result<T, Retry<usize>>>) -> Result<T, Retry<usize>> {
        let script = Mutex::new(values.into_iter());
        let params = RetryParams::default();
        retry(&params, || {
            let outcome = script.lock().unwrap().next().unwrap();
            ready(outcome)
        })
        .await
    }

    /// Builds a script of `failure_count` transient errors (numbered from 0) followed by one
    /// success.
    fn transient_failures_then_ok(failure_count: usize) -> Vec<Result<(), Retry<usize>>> {
        let mut script: Vec<_> = (0..failure_count)
            .map(|retry_id| Err(Retry::Transient(retry_id)))
            .collect();
        script.push(Ok(()));
        script
    }

    #[tokio::test]
    async fn test_retry_accepts_ok() {
        let outcome = simulate_retries(vec![Ok(())]).await;
        assert_eq!(outcome, Ok(()));
    }

    #[tokio::test]
    async fn test_retry_does_retry() {
        let outcome = simulate_retries(vec![Err(Retry::Transient(1)), Ok(())]).await;
        assert_eq!(outcome, Ok(()));
    }

    #[tokio::test]
    async fn test_retry_stops_retrying_on_non_retryable_error() {
        let outcome = simulate_retries(vec![Err(Retry::Permanent(1)), Ok(())]).await;
        assert_eq!(outcome, Err(Retry::Permanent(1)));
    }

    #[tokio::test]
    async fn test_retry_retries_up_at_most_attempts_times() {
        // 30 failures exhaust the default attempt budget before the trailing `Ok` is reached.
        let outcome = simulate_retries(transient_failures_then_ok(30)).await;
        assert_eq!(outcome, Err(Retry::Transient(29)));
    }

    #[tokio::test]
    async fn test_retry_retries_up_to_max_attempts_times() {
        // 29 failures leave exactly one attempt, which picks up the trailing `Ok`.
        let outcome = simulate_retries(transient_failures_then_ok(29)).await;
        assert_eq!(outcome, Ok(()));
    }
}