asyn_retry_policy/lib.rs
1//! A small crate providing an async retry policy with exponential backoff and jitter.
2//!
//! Example: programmatic usage with a retry predicate (macro examples follow below).
4//!
5//! ```no_run
6//! use asyn_retry_policy::RetryPolicy;
7//! use std::time::Duration;
8//! use std::sync::{Arc, atomic::{AtomicU8, Ordering}};
9//!
10//! // Programmatic usage with a predicate that inspects the error type (`String` here)
11//! #[tokio::main]
12//! async fn main() {
13//! // predicate gets an `&E`, so when `E = String` it's `&String`.
14//! fn is_retryable(e: &String) -> bool { e == "temporary" }
15//!
16//! let mut policy = RetryPolicy::default();
17//! policy.attempts = 5;
18//! policy.jitter = false;
19//!
20//! let tries = Arc::new(AtomicU8::new(0));
21//! let res = policy.retry(
22//! {
23//! let tries = tries.clone();
24//! move || {
25//! let tries = tries.clone();
26//! async move {
27//! let prev = tries.fetch_add(1, Ordering::SeqCst);
28//! if prev < 2 { Err::<u8, _>(String::from("temporary")) } else { Ok(0u8) }
29//! }
30//! }
31//! },
32//! is_retryable,
33//! ).await;
34//! assert!(res.is_ok());
35//! }
36//! ```
37//!
38//! Macro usage examples (predicate path and inline closure):
39//!
40//! ```no_run
41//! use asyn_retry_policy::retry;
42//! use std::sync::{Arc, atomic::{AtomicU8, Ordering}};
43//!
44//! fn should_retry(e: &String) -> bool { e == "tmp" }
45//!
46//! #[retry(attempts = 3, predicate = should_retry)]
47//! async fn my_endpoint(tries: Arc<AtomicU8>) -> Result<u8, String> {
48//! let prev = tries.fetch_add(1, Ordering::SeqCst);
49//! if prev < 2 { Err(String::from("tmp")) } else { Ok(7u8) }
50//! }
51//! ```
52//!
53//! Inline closure predicate example:
54//!
55//! ```no_run
56//! use asyn_retry_policy::retry;
57//! use std::sync::{Arc, atomic::{AtomicU8, Ordering}};
58//!
59//! #[retry(predicate = |e: &String| e == "tmp")]
60//! async fn my_endpoint_closure(tries: Arc<AtomicU8>) -> Result<u8, String> {
61//! let prev = tries.fetch_add(1, Ordering::SeqCst);
62//! if prev < 2 { Err(String::from("tmp")) } else { Ok(8u8) }
63//! }
64//! ```
65
66use rand::Rng;
67use rand::SeedableRng;
68use rand::rngs::SmallRng;
69use std::time::Duration;
70
71// Re-export the proc-macro so users can just write `#[retry]` or `#[retry(3)]` when depending on this crate
72pub use asyn_retry_policy_macro::retry;
73
/// Settings that control how [`RetryPolicy::retry`] re-runs a failing async operation.
///
/// Between tries the policy waits an exponentially growing delay
/// (`base_delay * backoff_factor^(n - 1)` after the `n`-th failure), capped at `max_delay`,
/// and optionally randomised by jitter.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Total number of tries, counting the very first call (not only the retries).
    pub attempts: usize,
    /// Delay used after the first failed try; later delays are scaled up from it.
    pub base_delay: Duration,
    /// Upper bound on the computed backoff delay (the cap is applied before jitter).
    pub max_delay: Duration,
    /// Factor the delay is multiplied by for each additional failed try.
    pub backoff_factor: f64,
    /// When `true`, the backoff delay is replaced by a uniformly random whole number of
    /// milliseconds in `0..=delay` (at least `0..=1` ms).
    pub jitter: bool,
    /// Seed for the jitter RNG. `Some` makes jitter reproducible (handy in tests);
    /// `None` draws from a thread-local RNG.
    pub rng_seed: Option<u64>,
}
90
91impl Default for RetryPolicy {
92 fn default() -> Self {
93 Self {
94 attempts: 3,
95 base_delay: Duration::from_millis(100),
96 max_delay: Duration::from_secs(5),
97 backoff_factor: 2.0,
98 jitter: true,
99 rng_seed: None,
100 }
101 }
102}
103
104impl RetryPolicy {
105 /// Compute the exponential backoff (without jitter) clamped by `max_delay`.
106 pub fn compute_backoff(&self, attempt: usize) -> Duration {
107 let exp = self.backoff_factor.powi((attempt - 1) as i32);
108 self.base_delay.mul_f64(exp).min(self.max_delay)
109 }
110
111 /// Retry an asynchronous operation described by `f` with this policy.
112 ///
113 /// `f` must return a `Result<T, E>`. The `should_retry` predicate receives a reference to the error
114 /// and returns whether the operation should be retried.
115 pub async fn retry<Fut, T, E, F, P>(&self, mut f: F, mut should_retry: P) -> Result<T, E>
116 where
117 F: FnMut() -> Fut,
118 Fut: std::future::Future<Output = Result<T, E>> + Send,
119 T: Send,
120 E: Send,
121 P: FnMut(&E) -> bool,
122 {
123 for attempt in 1..=self.attempts {
124 match f().await {
125 Ok(v) => return Ok(v),
126 Err(e) if attempt < self.attempts && should_retry(&e) => {
127 // Calculate exponential backoff
128 let mut delay = self.compute_backoff(attempt);
129
130 // Apply jitter
131 if self.jitter {
132 let max_ms = delay.as_millis().max(1) as u64;
133 let jitter_ms = if let Some(seed) = self.rng_seed {
134 // deterministic per-attempt RNG to keep testability
135 let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(attempt as u64));
136 rng.gen_range(0..=max_ms)
137 } else {
138 rand::thread_rng().gen_range(0..=max_ms)
139 };
140 delay = Duration::from_millis(jitter_ms);
141 }
142
143 tokio::time::sleep(delay).await;
144 continue;
145 }
146 Err(e) => return Err(e),
147 }
148 }
149 unreachable!("loop returns or errors")
150 }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU8, Ordering};
    use std::sync::Arc;

    /// A deterministic, fast policy: no jitter and millisecond-scale delays, so the tests
    /// neither sleep for hundreds of milliseconds nor depend on RNG output.
    fn fast_policy(attempts: usize) -> RetryPolicy {
        RetryPolicy {
            attempts,
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(10),
            jitter: false,
            ..RetryPolicy::default()
        }
    }

    /// Operation that fails with `"temporary"` for its first `failures` calls and then
    /// returns `Ok(42)`. Every call increments `tries`.
    fn flaky(
        tries: Arc<AtomicU8>,
        failures: u8,
    ) -> impl FnMut() -> std::future::Ready<Result<u8, &'static str>> {
        move || {
            let prev = tries.fetch_add(1, Ordering::SeqCst);
            std::future::ready(if prev < failures {
                Err("temporary")
            } else {
                Ok(42)
            })
        }
    }

    #[tokio::test]
    async fn retries_and_succeeds() {
        let tries = Arc::new(AtomicU8::new(0));
        let res = fast_policy(3).retry(flaky(tries.clone(), 2), |_| true).await;
        assert_eq!(res, Ok(42));
        assert_eq!(tries.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn stops_on_non_retryable_error() {
        let tries = Arc::new(AtomicU8::new(0));
        let res = fast_policy(3)
            .retry(flaky(tries.clone(), u8::MAX), |_e| false)
            .await;
        assert_eq!(res, Err("temporary"));
        assert_eq!(tries.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn returns_last_error_when_attempts_exhausted() {
        let tries = Arc::new(AtomicU8::new(0));
        let res = fast_policy(3)
            .retry(flaky(tries.clone(), u8::MAX), |_| true)
            .await;
        assert_eq!(res, Err("temporary"));
        assert_eq!(tries.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn predicate_not_consulted_for_final_attempt() {
        let tries = Arc::new(AtomicU8::new(0));
        let mut predicate_calls = 0;
        let res = fast_policy(3)
            .retry(flaky(tries.clone(), u8::MAX), |_| {
                predicate_calls += 1;
                true
            })
            .await;
        assert!(res.is_err());
        assert_eq!(predicate_calls, 2);
    }

    #[test]
    fn compute_backoff_grows_and_clamps() {
        let policy = RetryPolicy {
            base_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(5),
            backoff_factor: 2.0,
            jitter: false,
            ..RetryPolicy::default()
        };
        assert_eq!(policy.compute_backoff(1), Duration::from_secs(1));
        assert_eq!(policy.compute_backoff(2), Duration::from_secs(2));
        assert_eq!(policy.compute_backoff(3), Duration::from_secs(4));
        assert_eq!(policy.compute_backoff(4), Duration::from_secs(5));
        assert_eq!(policy.compute_backoff(10), Duration::from_secs(5));
    }
}