asyn_retry_policy/
lib.rs

1//! A small crate providing an async retry policy with exponential backoff and jitter.
2//!
3//! Example: programmatic and macro usage (predicate-aware)
4//!
5//! ```no_run
6//! use asyn_retry_policy::RetryPolicy;
7//! use std::time::Duration;
8//! use std::sync::{Arc, atomic::{AtomicU8, Ordering}};
9//!
10//! // Programmatic usage with a predicate that inspects the error type (`String` here)
11//! #[tokio::main]
12//! async fn main() {
13//!     // predicate gets an `&E`, so when `E = String` it's `&String`.
14//!     fn is_retryable(e: &String) -> bool { e == "temporary" }
15//!
16//!     let mut policy = RetryPolicy::default();
17//!     policy.attempts = 5;
18//!     policy.jitter = false;
19//!
20//!     let tries = Arc::new(AtomicU8::new(0));
21//!     let res = policy.retry(
22//!         {
23//!             let tries = tries.clone();
24//!             move || {
25//!                 let tries = tries.clone();
26//!                 async move {
27//!                     let prev = tries.fetch_add(1, Ordering::SeqCst);
28//!                     if prev < 2 { Err::<u8, _>(String::from("temporary")) } else { Ok(0u8) }
29//!                 }
30//!             }
31//!         },
32//!         is_retryable,
33//!     ).await;
34//!     assert!(res.is_ok());
35//! }
36//! ```
37//!
38//! Macro usage examples (predicate path and inline closure):
39//!
40//! ```no_run
41//! use asyn_retry_policy::retry;
42//! use std::sync::{Arc, atomic::{AtomicU8, Ordering}};
43//!
44//! fn should_retry(e: &String) -> bool { e == "tmp" }
45//!
46//! #[retry(attempts = 3, predicate = should_retry)]
47//! async fn my_endpoint(tries: Arc<AtomicU8>) -> Result<u8, String> {
48//!     let prev = tries.fetch_add(1, Ordering::SeqCst);
49//!     if prev < 2 { Err(String::from("tmp")) } else { Ok(7u8) }
50//! }
51//! ```
52//!
53//! Inline closure predicate example:
54//!
55//! ```no_run
56//! use asyn_retry_policy::retry;
57//! use std::sync::{Arc, atomic::{AtomicU8, Ordering}};
58//!
59//! #[retry(predicate = |e: &String| e == "tmp")]
60//! async fn my_endpoint_closure(tries: Arc<AtomicU8>) -> Result<u8, String> {
61//!     let prev = tries.fetch_add(1, Ordering::SeqCst);
62//!     if prev < 2 { Err(String::from("tmp")) } else { Ok(8u8) }
63//! }
64//! ```
65
66use rand::Rng;
67use rand::SeedableRng;
68use rand::rngs::SmallRng;
69use std::time::Duration;
70
71// Re-export the proc-macro so users can just write `#[retry]` or `#[retry(3)]` when depending on this crate
72pub use asyn_retry_policy_macro::retry;
73
/// Configuration for [`RetryPolicy::retry`]: how many times to try and how long to wait.
///
/// Waits start at `base_delay`, are multiplied by `backoff_factor` after every failed try,
/// and never exceed `max_delay`. Optional jitter randomizes each wait.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Upper bound on the total number of tries, the initial call included.
    pub attempts: usize,
    /// Wait used after the first failure; later waits are scaled from this.
    pub base_delay: Duration,
    /// Ceiling applied to every computed wait.
    pub max_delay: Duration,
    /// Growth factor applied to the wait after each failed try.
    pub backoff_factor: f64,
    /// When `true`, each wait is replaced by a random value in `0..=delay`
    /// (whole-millisecond granularity).
    pub jitter: bool,
    /// Seed for the jitter RNG; `Some(_)` makes jitter reproducible, e.g. in tests.
    pub rng_seed: Option<u64>,
}
90
91impl Default for RetryPolicy {
92    fn default() -> Self {
93        Self {
94            attempts: 3,
95            base_delay: Duration::from_millis(100),
96            max_delay: Duration::from_secs(5),
97            backoff_factor: 2.0,
98            jitter: true,
99            rng_seed: None,
100        }
101    }
102}
103
104impl RetryPolicy {
105    /// Compute the exponential backoff (without jitter) clamped by `max_delay`.
106    pub fn compute_backoff(&self, attempt: usize) -> Duration {
107        let exp = self.backoff_factor.powi((attempt - 1) as i32);
108        self.base_delay.mul_f64(exp).min(self.max_delay)
109    }
110
111    /// Retry an asynchronous operation described by `f` with this policy.
112    ///
113    /// `f` must return a `Result<T, E>`. The `should_retry` predicate receives a reference to the error
114    /// and returns whether the operation should be retried.
115    pub async fn retry<Fut, T, E, F, P>(&self, mut f: F, mut should_retry: P) -> Result<T, E>
116    where
117        F: FnMut() -> Fut,
118        Fut: std::future::Future<Output = Result<T, E>> + Send,
119        T: Send,
120        E: Send,
121        P: FnMut(&E) -> bool,
122    {
123        for attempt in 1..=self.attempts {
124            match f().await {
125                Ok(v) => return Ok(v),
126                Err(e) if attempt < self.attempts && should_retry(&e) => {
127                    // Calculate exponential backoff
128                    let mut delay = self.compute_backoff(attempt);
129
130                    // Apply jitter
131                    if self.jitter {
132                        let max_ms = delay.as_millis().max(1) as u64;
133                        let jitter_ms = if let Some(seed) = self.rng_seed {
134                            // deterministic per-attempt RNG to keep testability
135                            let mut rng = SmallRng::seed_from_u64(seed.wrapping_add(attempt as u64));
136                            rng.gen_range(0..=max_ms)
137                        } else {
138                            rand::thread_rng().gen_range(0..=max_ms)
139                        };
140                        delay = Duration::from_millis(jitter_ms);
141                    }
142
143                    tokio::time::sleep(delay).await;
144                    continue;
145                }
146                Err(e) => return Err(e),
147            }
148        }
149        unreachable!("loop returns or errors")
150    }
151}
152
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU8, Ordering};

    /// Millisecond-scale delays with jitter off, so retry tests are fast and deterministic
    /// instead of sleeping for hundreds of milliseconds of random backoff.
    fn fast_policy() -> RetryPolicy {
        RetryPolicy {
            base_delay: Duration::from_millis(1),
            max_delay: Duration::from_millis(5),
            jitter: false,
            ..RetryPolicy::default()
        }
    }

    #[tokio::test]
    async fn retries_and_succeeds() {
        let policy = fast_policy();
        let tries = Arc::new(AtomicU8::new(0));
        let res = policy
            .retry(
                {
                    let tries = tries.clone();
                    move || {
                        let tries = tries.clone();
                        async move {
                            let prev = tries.fetch_add(1, Ordering::SeqCst);
                            if prev < 2 {
                                Err("temporary")
                            } else {
                                Ok(42u8)
                            }
                        }
                    }
                },
                |_| true,
            )
            .await;
        assert_eq!(res.unwrap(), 42u8);
        assert_eq!(tries.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn stops_on_non_retryable_error() {
        let policy = fast_policy();
        let tries = Arc::new(AtomicU8::new(0));
        let res = policy
            .retry(
                {
                    let tries = tries.clone();
                    move || {
                        let tries = tries.clone();
                        async move {
                            tries.fetch_add(1, Ordering::SeqCst);
                            Err::<u8, _>("fatal")
                        }
                    }
                },
                |_e| false,
            )
            .await;
        assert_eq!(res.unwrap_err(), "fatal");
        assert_eq!(tries.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn gives_up_after_max_attempts() {
        let policy = fast_policy();
        let tries = Arc::new(AtomicU8::new(0));
        let res = policy
            .retry(
                {
                    let tries = tries.clone();
                    move || {
                        let tries = tries.clone();
                        async move {
                            tries.fetch_add(1, Ordering::SeqCst);
                            Err::<u8, _>("still failing")
                        }
                    }
                },
                |_| true,
            )
            .await;
        // The last error is surfaced once the attempt budget (3 by default) is spent.
        assert_eq!(res.unwrap_err(), "still failing");
        assert_eq!(tries.load(Ordering::SeqCst), 3);
    }

    #[test]
    fn backoff_grows_exponentially_and_is_clamped() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.compute_backoff(1), Duration::from_millis(100));
        assert_eq!(policy.compute_backoff(2), Duration::from_millis(200));
        assert_eq!(policy.compute_backoff(3), Duration::from_millis(400));
        // 100 ms * 2^9 = 51.2 s, clamped to the 5 s maximum.
        assert_eq!(policy.compute_backoff(10), Duration::from_secs(5));
    }
}