Skip to main content

throttle_net/
retry.rs

//! Retry policy: drive a fallible async operation with a [`Backoff`].
2
3use crate::backoff::Backoff;
4
/// Attempt ceiling a [`Retry`] starts with when none is configured; the
/// initial call counts as the first attempt.
const DEFAULT_MAX_ATTEMPTS: u32 = 5;
7
/// The verdict a classifier returns for an error produced by a retried
/// operation.
///
/// [`Retry::run`] hands each classified error to a caller-supplied closure,
/// which answers with one of these actions. The enum is `#[non_exhaustive]`
/// so that further actions can be added without breaking callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum RetryAction {
    /// Try again after the delay computed by the policy's backoff.
    Retry,
    /// Try again after the given delay — typically a server-supplied
    /// `Retry-After`. The delay replaces the computed backoff when
    /// [`Retry::respect_retry_after`] is enabled; otherwise the computed
    /// backoff delay is used.
    RetryAfter(core::time::Duration),
    /// Do not retry; hand the error back to the caller.
    GiveUp,
}
23
24/// A retry policy: a [`Backoff`], an attempt ceiling, and whether to honor a
25/// server's `Retry-After`.
26///
27/// The policy is independent of the limiters — retry any fallible async
28/// operation, or wrap a limiter's `acquire` call. The error is classified per
29/// attempt by a closure you supply, so retry works with any error type; for
30/// errors that implement [`error_forge::ForgeError`], the
31/// [`retry_if_retryable`] helper classifies by the error's own retryability.
32///
33/// # Examples
34///
35/// ```
36/// # async fn run() {
37/// use throttle_net::{Backoff, Retry, RetryAction};
38///
39/// let retry = Retry::new(Backoff::default()).max_attempts(4);
40///
41/// let result: Result<u32, &str> = retry
42///     .run(
43///         || async { Err("transient") },
44///         |_err| RetryAction::Retry,
45///     )
46///     .await;
47/// assert_eq!(result, Err("transient")); // gave up after 4 attempts
48/// # }
49/// ```
50#[derive(Debug, Clone, Copy)]
51pub struct Retry {
52    backoff: Backoff,
53    max_attempts: u32,
54    respect_retry_after: bool,
55}
56
57impl Retry {
58    /// Creates a retry policy with the given backoff, a default of five
59    /// attempts, and `Retry-After` honored.
60    #[must_use]
61    pub fn new(backoff: Backoff) -> Self {
62        Self {
63            backoff,
64            max_attempts: DEFAULT_MAX_ATTEMPTS,
65            respect_retry_after: true,
66        }
67    }
68
69    /// Sets the maximum number of attempts (including the first). A value of `1`
70    /// disables retrying; `0` is treated as `1`.
71    #[must_use]
72    pub fn max_attempts(mut self, attempts: u32) -> Self {
73        self.max_attempts = attempts.max(1);
74        self
75    }
76
77    /// Sets whether a [`RetryAction::RetryAfter`] delay overrides the computed
78    /// backoff. On by default.
79    #[must_use]
80    pub fn respect_retry_after(mut self, yes: bool) -> Self {
81        self.respect_retry_after = yes;
82        self
83    }
84
85    /// The configured attempt ceiling.
86    #[must_use]
87    pub const fn attempts(&self) -> u32 {
88        self.max_attempts
89    }
90
91    /// The configured backoff policy.
92    #[must_use]
93    pub const fn backoff(&self) -> &Backoff {
94        &self.backoff
95    }
96}
97
#[cfg(feature = "runtime")]
#[cfg_attr(docsrs, doc(cfg(feature = "runtime")))]
impl Retry {
    /// Runs `operation`, retrying on failure per the policy until it succeeds,
    /// the classifier says to stop, or the attempt ceiling is reached.
    ///
    /// `operation` is called once per attempt, and a success is returned
    /// immediately. After a failure with attempts still remaining, `classify`
    /// inspects the error and picks a [`RetryAction`]:
    ///
    /// - [`RetryAction::Retry`] sleeps for the next backoff delay, then retries.
    /// - [`RetryAction::RetryAfter`] sleeps for the supplied delay when
    ///   [`Retry::respect_retry_after`] is enabled, and for the next backoff
    ///   delay otherwise. The backoff is advanced either way.
    /// - [`RetryAction::GiveUp`] returns the error at once.
    ///
    /// `classify` may be a stateful (`FnMut`) closure. It is not consulted for
    /// the error of the final permitted attempt, because no retry could follow.
    ///
    /// # Errors
    ///
    /// Returns the most recent error from `operation` when the classifier
    /// gives up or the attempt ceiling is reached.
    ///
    /// # Examples
    ///
    /// Honor a `Retry-After` delay the server sent (for example one obtained
    /// from [`parse_retry_after`](crate::parse_retry_after)):
    ///
    /// ```
    /// # async fn run() {
    /// use std::time::Duration;
    /// use throttle_net::{Backoff, Retry, RetryAction};
    ///
    /// struct Rejected { retry_after: Option<Duration> }
    ///
    /// let retry = Retry::new(Backoff::default()).respect_retry_after(true);
    /// let result: Result<(), &str> = retry
    ///     .run(
    ///         || async { Err::<(), _>(Rejected { retry_after: Some(Duration::from_millis(10)) }) },
    ///         |err: &Rejected| match err.retry_after {
    ///             Some(after) => RetryAction::RetryAfter(after),
    ///             None => RetryAction::Retry,
    ///         },
    ///     )
    ///     .await
    ///     .map(|_| ())
    ///     .map_err(|_| "exhausted");
    /// assert_eq!(result, Err("exhausted"));
    /// # }
    /// ```
    ///
    /// A classifier may keep state, for example to count the failures it saw:
    ///
    /// ```
    /// # async fn run() {
    /// use throttle_net::{Backoff, Retry, RetryAction};
    ///
    /// let retry = Retry::new(Backoff::default()).max_attempts(3);
    /// let mut failures = 0u32;
    /// let result: Result<(), &str> = retry
    ///     .run(
    ///         || async { Err("transient") },
    ///         |_err| {
    ///             failures += 1;
    ///             RetryAction::Retry
    ///         },
    ///     )
    ///     .await;
    /// assert_eq!(result, Err("transient"));
    /// assert_eq!(failures, 2); // the final attempt's error is not classified
    /// # }
    /// ```
    pub async fn run<F, Fut, T, E, C>(&self, mut operation: F, mut classify: C) -> Result<T, E>
    where
        F: FnMut() -> Fut,
        Fut: core::future::Future<Output = Result<T, E>>,
        C: FnMut(&E) -> RetryAction,
    {
        let mut delays = self.backoff.iter();
        let mut attempt = 1u32;
        loop {
            let error = match operation().await {
                Ok(value) => return Ok(value),
                Err(error) => error,
            };
            // Out of attempts: surface the last error without classifying it.
            if attempt >= self.max_attempts {
                return Err(error);
            }
            let delay = match classify(&error) {
                RetryAction::GiveUp => return Err(error),
                RetryAction::Retry => delays.next_delay(),
                RetryAction::RetryAfter(after) => {
                    // Always advance the backoff so its state (and
                    // decorrelated jitter) keeps progressing, even when
                    // the server's hint overrides the chosen delay.
                    let computed = delays.next_delay();
                    if self.respect_retry_after {
                        after
                    } else {
                        computed
                    }
                }
            };
            crate::rt::sleep(delay).await;
            attempt += 1;
        }
    }
}
173
174/// Classifies an [`error_forge::ForgeError`] by its own retryability: retry when
175/// [`is_retryable`](error_forge::ForgeError::is_retryable) is `true`, otherwise
176/// give up. A convenient `classify` argument for [`Retry::run`].
177///
178/// # Examples
179///
180/// ```
181/// # async fn run() {
182/// use throttle_net::{Backoff, Retry, retry_if_retryable};
183/// # use error_forge::AppError;
184///
185/// let retry = Retry::new(Backoff::default());
186/// let result = retry
187///     .run(
188///         || async { Err::<(), _>(AppError::network("api.example", None)) },
189///         retry_if_retryable,
190///     )
191///     .await;
192/// assert!(result.is_err());
193/// # }
194/// ```
195#[must_use]
196pub fn retry_if_retryable<E: error_forge::ForgeError>(error: &E) -> RetryAction {
197    if error.is_retryable() {
198        RetryAction::Retry
199    } else {
200        RetryAction::GiveUp
201    }
202}
203
#[cfg(test)]
mod tests {
    use super::{Retry, RetryAction, retry_if_retryable};
    use crate::backoff::Backoff;
    use core::time::Duration;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A policy with a short constant backoff, keeping delays deterministic so
    /// tests on the paused clock stay exact.
    fn fast_policy() -> Retry {
        let backoff = Backoff::constant(Duration::from_millis(10));
        Retry::new(backoff)
    }

    #[test]
    fn test_max_attempts_floor_is_one() {
        let floored = fast_policy().max_attempts(0);
        assert_eq!(floored.attempts(), 1);

        let explicit = fast_policy().max_attempts(7);
        assert_eq!(explicit.attempts(), 7);
    }

    #[tokio::test(start_paused = true)]
    async fn test_succeeds_after_transient_failures() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let policy = fast_policy().max_attempts(5);

        // Fails on the first two calls, then succeeds with the call number.
        let operation = move || {
            let counter = Arc::clone(&counter);
            async move {
                let n = counter.fetch_add(1, Ordering::Relaxed) + 1;
                if n >= 3 { Ok(n) } else { Err("transient") }
            }
        };

        let result: Result<u32, &str> = policy.run(operation, |_| RetryAction::Retry).await;
        assert_eq!(result, Ok(3));
        assert_eq!(calls.load(Ordering::Relaxed), 3);
    }

    #[tokio::test(start_paused = true)]
    async fn test_gives_up_after_max_attempts() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let policy = fast_policy().max_attempts(3);

        // Counts the call, then fails every time.
        let operation = move || {
            let counter = Arc::clone(&counter);
            async move {
                let _ = counter.fetch_add(1, Ordering::Relaxed);
                Err::<(), _>("always")
            }
        };

        let result: Result<(), &str> = policy.run(operation, |_| RetryAction::Retry).await;
        assert_eq!(result, Err("always"));
        assert_eq!(
            calls.load(Ordering::Relaxed),
            3,
            "operation runs exactly max_attempts times"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn test_give_up_classification_stops_immediately() {
        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let policy = fast_policy().max_attempts(10);

        // Counts the call, then fails every time.
        let operation = move || {
            let counter = Arc::clone(&counter);
            async move {
                let _ = counter.fetch_add(1, Ordering::Relaxed);
                Err::<(), _>("fatal")
            }
        };

        let result: Result<(), &str> = policy.run(operation, |_| RetryAction::GiveUp).await;
        assert_eq!(result, Err("fatal"));
        assert_eq!(
            calls.load(Ordering::Relaxed),
            1,
            "GiveUp stops after the first attempt"
        );
    }

    // The next two tests check the *exact* elapsed time on tokio's paused
    // virtual clock. That needs the tokio timer: with a real (smol) timer the
    // wait would be real, so the tests are tokio-only. The honor/ignore
    // decision being tested does not itself depend on the runtime.
    #[cfg(feature = "tokio")]
    #[tokio::test(start_paused = true)]
    async fn test_retry_after_is_honored_when_enabled() {
        let started = tokio::time::Instant::now();
        let one_second = Backoff::constant(Duration::from_secs(1));
        let policy = Retry::new(one_second)
            .max_attempts(2)
            .respect_retry_after(true);

        let _: Result<(), &str> = policy
            .run(
                || async { Err("rejected") },
                |_| RetryAction::RetryAfter(Duration::from_secs(30)),
            )
            .await;

        // A single retry that waited the 30s Retry-After, not the 1s backoff.
        assert_eq!(started.elapsed(), Duration::from_secs(30));
    }

    #[cfg(feature = "tokio")]
    #[tokio::test(start_paused = true)]
    async fn test_retry_after_is_ignored_when_disabled() {
        let started = tokio::time::Instant::now();
        let one_second = Backoff::constant(Duration::from_secs(1));
        let policy = Retry::new(one_second)
            .max_attempts(2)
            .respect_retry_after(false);

        let _: Result<(), &str> = policy
            .run(
                || async { Err("rejected") },
                |_| RetryAction::RetryAfter(Duration::from_secs(30)),
            )
            .await;

        // The 30s hint is dropped in favor of the 1s computed backoff.
        assert_eq!(started.elapsed(), Duration::from_secs(1));
    }

    #[tokio::test(start_paused = true)]
    async fn test_retry_if_retryable_helper() {
        use error_forge::AppError;

        let calls = Arc::new(AtomicU32::new(0));
        let counter = Arc::clone(&calls);
        let policy = fast_policy().max_attempts(5);

        // Always fails with a config error, which the helper must not retry.
        let operation = move || {
            let counter = Arc::clone(&counter);
            async move {
                let _ = counter.fetch_add(1, Ordering::Relaxed);
                Err::<(), _>(AppError::config("bad"))
            }
        };

        let result: Result<(), AppError> = policy.run(operation, retry_if_retryable).await;
        assert!(result.is_err());
        assert_eq!(
            calls.load(Ordering::Relaxed),
            1,
            "config errors are not retryable"
        );
    }
}