mx_core/resilience/
retry.rs1use std::future::Future;
4use std::time::Duration;
5
6use rand::Rng;
7use tokio::time::sleep;
8
/// Cap on the un-jittered exponential delay used by `RetryPolicy::default`.
const DEFAULT_MAX_DELAY: Duration = Duration::from_millis(30_000);
/// Jitter fraction used by `RetryPolicy::default` (delay scaled by up to +/-25%).
const DEFAULT_JITTER_FACTOR: f64 = 1.0 / 4.0;
11
/// Configuration for retrying a fallible async operation.
///
/// Construct it with [`RetryPolicy::new`] (fixed delay),
/// [`RetryPolicy::with_exponential_backoff`], or `Default::default()`, then run work
/// through [`RetryPolicy::execute`].
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Number of times the operation may be invoked, counting the first call.
    /// Should be at least 1.
    pub max_attempts: u32,
    /// Starting delay for exponential backoff; the constant delay in fixed mode.
    pub base_delay: Duration,
    /// Upper bound applied to the exponential delay before jitter is added.
    pub max_delay: Duration,
    /// Fraction by which the exponential delay is randomised. Expected to lie in
    /// `0.0..=1.0`; only `with_exponential_backoff` clamps it, and fixed-delay
    /// policies ignore it.
    pub jitter_factor: f64,
    /// `true` selects exponential backoff; `false` waits `base_delay` every time.
    use_exponential_backoff: bool,
    /// Optional filter over the error's `Display` text; `None` retries every error.
    retryable_predicate: Option<fn(&str) -> bool>,
}
22
23impl RetryPolicy {
24 pub fn new(max_attempts: u32, delay_ms: u64) -> Self {
26 Self {
27 max_attempts,
28 base_delay: Duration::from_millis(delay_ms),
29 max_delay: Duration::from_millis(delay_ms),
30 jitter_factor: 0.0,
31 use_exponential_backoff: false,
32 retryable_predicate: None,
33 }
34 }
35
36 pub fn with_exponential_backoff(
38 max_attempts: u32,
39 base_delay: Duration,
40 max_delay: Duration,
41 jitter_factor: f64,
42 ) -> Self {
43 Self {
44 max_attempts,
45 base_delay,
46 max_delay,
47 jitter_factor: jitter_factor.clamp(0.0, 1.0),
48 use_exponential_backoff: true,
49 retryable_predicate: None,
50 }
51 }
52
53 pub fn with_retryable_predicate(mut self, predicate: fn(&str) -> bool) -> Self {
55 self.retryable_predicate = Some(predicate);
56 self
57 }
58
59 pub fn next_delay(&self, attempt: u32) -> Duration {
68 if !self.use_exponential_backoff {
69 return self.base_delay;
70 }
71
72 let base_ms = self.base_delay.as_millis() as u64;
73 let max_ms = self.max_delay.as_millis() as u64;
74
75 let delay_ms = if attempt >= 64 {
76 max_ms
77 } else {
78 let multiplier = 1u64.checked_shl(attempt).unwrap_or(u64::MAX);
79 base_ms.saturating_mul(multiplier).min(max_ms)
80 };
81
82 if self.jitter_factor > 0.0 {
83 self.apply_jitter(Duration::from_millis(delay_ms))
84 } else {
85 Duration::from_millis(delay_ms)
86 }
87 }
88
89 fn apply_jitter(&self, delay: Duration) -> Duration {
90 let mut rng = rand::rng();
91 let jitter_range = self.jitter_factor * 2.0;
92 let jitter_offset = rng.random::<f64>() * jitter_range - self.jitter_factor;
93 let factor = 1.0 + jitter_offset;
94 let delay_ms = delay.as_millis() as f64;
95 let jittered_ms = (delay_ms * factor).max(1.0) as u64;
96 Duration::from_millis(jittered_ms)
97 }
98
99 fn is_retryable(&self, error_msg: &str) -> bool {
100 match self.retryable_predicate {
101 Some(pred) => pred(error_msg),
102 None => true,
103 }
104 }
105
106 pub async fn execute<F, Fut, T, E>(&self, mut operation: F) -> Result<T, E>
108 where
109 F: FnMut() -> Fut,
110 Fut: Future<Output = Result<T, E>>,
111 E: std::fmt::Display,
112 {
113 let mut last_error: Option<E> = None;
114
115 for attempt in 0..self.max_attempts {
116 match operation().await {
117 Ok(result) => return Ok(result),
118 Err(e) => {
119 let error_msg = e.to_string();
120 if !self.is_retryable(&error_msg) {
121 return Err(e);
122 }
123 last_error = Some(e);
124 if attempt < self.max_attempts - 1 {
125 let delay = self.next_delay(attempt);
126 sleep(delay).await;
127 }
128 }
129 }
130 }
131
132 Err(last_error.expect("at least one attempt must have been made"))
133 }
134}
135
136impl Default for RetryPolicy {
137 fn default() -> Self {
138 Self::with_exponential_backoff(
139 3,
140 Duration::from_millis(200),
141 DEFAULT_MAX_DELAY,
142 DEFAULT_JITTER_FACTOR,
143 )
144 }
145}
146
#[cfg(test)]
mod tests {
    use super::*;

    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    #[test]
    fn test_fixed_delay() {
        let policy = RetryPolicy::new(3, 200);
        assert_eq!(policy.next_delay(0), Duration::from_millis(200));
        assert_eq!(policy.next_delay(1), Duration::from_millis(200));
        assert_eq!(policy.next_delay(2), Duration::from_millis(200));
    }

    #[test]
    fn test_exponential_no_jitter() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_millis(100),
            Duration::from_secs(30),
            0.0,
        );
        assert_eq!(policy.next_delay(0), Duration::from_millis(100));
        assert_eq!(policy.next_delay(1), Duration::from_millis(200));
        assert_eq!(policy.next_delay(2), Duration::from_millis(400));
    }

    #[test]
    fn test_capped_at_max_delay() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_millis(100),
            Duration::from_secs(1),
            0.0,
        );
        assert_eq!(policy.next_delay(20), Duration::from_secs(1));
    }

    #[test]
    fn test_overflow_protection() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_secs(1),
            Duration::from_secs(3600),
            0.0,
        );
        let delay = policy.next_delay(100);
        assert!(delay <= Duration::from_secs(3600));
    }

    #[test]
    fn test_jitter_stays_within_bounds() {
        let policy = RetryPolicy::with_exponential_backoff(
            5,
            Duration::from_millis(100),
            Duration::from_secs(30),
            0.25,
        );
        for _ in 0..50 {
            let delay = policy.next_delay(0);
            assert!(delay >= Duration::from_millis(75), "{delay:?} below jitter range");
            assert!(delay <= Duration::from_millis(125), "{delay:?} above jitter range");
        }
    }

    #[tokio::test]
    async fn test_execute_success() {
        let policy = RetryPolicy::new(3, 10);
        let result: Result<i32, String> = policy.execute(|| async { Ok(42) }).await;
        assert_eq!(result.unwrap(), 42);
    }

    #[tokio::test]
    async fn test_execute_retries_then_succeeds() {
        let attempts = Arc::new(AtomicU32::new(0));
        let attempts_clone = Arc::clone(&attempts);

        let policy = RetryPolicy::new(3, 1);
        let result: Result<i32, String> = policy
            .execute(|| {
                let a = Arc::clone(&attempts_clone);
                async move {
                    let count = a.fetch_add(1, Ordering::SeqCst);
                    if count < 2 {
                        Err("transient".to_string())
                    } else {
                        Ok(42)
                    }
                }
            })
            .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_execute_returns_last_error_when_exhausted() {
        let calls = Arc::new(AtomicU32::new(0));
        let calls_clone = Arc::clone(&calls);

        let policy = RetryPolicy::new(3, 1);
        let result: Result<i32, String> = policy
            .execute(|| {
                let c = Arc::clone(&calls_clone);
                async move {
                    let n = c.fetch_add(1, Ordering::SeqCst);
                    Err(format!("failure {n}"))
                }
            })
            .await;

        assert_eq!(result.unwrap_err(), "failure 2");
        assert_eq!(calls.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn test_execute_stops_on_non_retryable_error() {
        fn is_transient(msg: &str) -> bool {
            msg == "transient"
        }

        let calls = Arc::new(AtomicU32::new(0));
        let calls_clone = Arc::clone(&calls);

        let policy = RetryPolicy::new(5, 1).with_retryable_predicate(is_transient);
        let result: Result<i32, String> = policy
            .execute(|| {
                let c = Arc::clone(&calls_clone);
                async move {
                    c.fetch_add(1, Ordering::SeqCst);
                    Err("fatal".to_string())
                }
            })
            .await;

        assert_eq!(result.unwrap_err(), "fatal");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }
}