Skip to main content

cognis_core/wrappers/
retry.rs

1//! Retry wrapper with exponential backoff.
2
3use std::marker::PhantomData;
4use std::time::Duration;
5
6use async_trait::async_trait;
7
8use crate::runnable::{Runnable, RunnableConfig};
9use crate::{CognisError, Result};
10
/// Policy controlling retry timing.
///
/// Build one with [`RetryPolicy::new`] or [`Default::default`], then refine it
/// with the `with_*` methods. Each of those consumes the policy and returns the
/// updated value, so the result must be kept.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Maximum total attempts (including the first). `1` means no retries.
    pub max_attempts: u32,
    /// Initial delay before the first retry.
    pub initial_delay: Duration,
    /// Multiplier applied to the delay after every failed attempt.
    pub backoff_multiplier: f64,
    /// Cap for the per-attempt delay.
    pub max_delay: Duration,
}

impl Default for RetryPolicy {
    /// Three attempts, a 100 ms initial delay, a multiplier of 2 and a 30 s cap.
    fn default() -> Self {
        Self {
            max_attempts: 3,
            initial_delay: Duration::from_millis(100),
            backoff_multiplier: 2.0,
            max_delay: Duration::from_secs(30),
        }
    }
}

impl RetryPolicy {
    /// Build a policy with the given max attempts and otherwise default settings.
    #[must_use]
    pub fn new(max_attempts: u32) -> Self {
        Self {
            max_attempts,
            ..Default::default()
        }
    }

    /// Override the initial delay.
    #[must_use = "builder methods return the updated policy instead of mutating in place"]
    pub fn with_initial_delay(mut self, d: Duration) -> Self {
        self.initial_delay = d;
        self
    }

    /// Override the backoff multiplier.
    #[must_use = "builder methods return the updated policy instead of mutating in place"]
    pub fn with_backoff(mut self, factor: f64) -> Self {
        self.backoff_multiplier = factor;
        self
    }

    /// Override the per-attempt delay cap.
    #[must_use = "builder methods return the updated policy instead of mutating in place"]
    pub fn with_max_delay(mut self, d: Duration) -> Self {
        self.max_delay = d;
        self
    }
}
59
60/// Retries `inner.invoke` on retryable errors with exponential backoff.
61///
62/// Retryability is decided by [`CognisError::is_retryable`]. An error's
63/// own `retry_delay()` (e.g. a `RateLimited`'s `Retry-After`) takes
64/// precedence over the policy's computed delay for that attempt.
65pub struct Retry<R, I, O> {
66    inner: R,
67    policy: RetryPolicy,
68    _phantom: PhantomData<fn(I) -> O>,
69}
70
71impl<R, I, O> Retry<R, I, O>
72where
73    R: Runnable<I, O>,
74    I: Clone + Send + 'static,
75    O: Send + 'static,
76{
77    /// Wrap a runnable with the given retry policy.
78    pub fn new(inner: R, policy: RetryPolicy) -> Self {
79        Self {
80            inner,
81            policy,
82            _phantom: PhantomData,
83        }
84    }
85}
86
87#[async_trait]
88impl<R, I, O> Runnable<I, O> for Retry<R, I, O>
89where
90    R: Runnable<I, O>,
91    I: Clone + Send + 'static,
92    O: Send + 'static,
93{
94    async fn invoke(&self, input: I, config: RunnableConfig) -> Result<O> {
95        let mut delay = self.policy.initial_delay;
96        let mut last_err: Option<CognisError> = None;
97        for attempt in 0..self.policy.max_attempts {
98            match self.inner.invoke(input.clone(), config.clone()).await {
99                Ok(v) => return Ok(v),
100                Err(e) if !e.is_retryable() => return Err(e),
101                Err(e) => {
102                    let suggested = e.retry_delay().unwrap_or(delay);
103                    last_err = Some(e);
104                    if attempt + 1 >= self.policy.max_attempts {
105                        break;
106                    }
107                    let sleep_for = suggested.min(self.policy.max_delay);
108                    tokio::time::sleep(sleep_for).await;
109                    delay = Duration::from_secs_f64(
110                        (delay.as_secs_f64() * self.policy.backoff_multiplier)
111                            .min(self.policy.max_delay.as_secs_f64()),
112                    );
113                }
114            }
115        }
116        Err(last_err.unwrap_or_else(|| {
117            CognisError::Internal("retry exhausted with no error captured".into())
118        }))
119    }
120    fn name(&self) -> &str {
121        "Retry"
122    }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Arc;

    /// Fails with a `Network` 503 error on its first two calls, then echoes its input.
    struct FlakyTwice {
        attempts: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Runnable<u32, u32> for FlakyTwice {
        async fn invoke(&self, input: u32, _: RunnableConfig) -> Result<u32> {
            let n = self.attempts.fetch_add(1, Ordering::SeqCst);
            if n < 2 {
                Err(CognisError::Network {
                    status_code: Some(503),
                    message: "boom".into(),
                })
            } else {
                Ok(input)
            }
        }
    }

    /// Fails with a `Network` 503 error on every call, counting the calls.
    struct AlwaysUnavailable {
        attempts: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Runnable<u32, u32> for AlwaysUnavailable {
        async fn invoke(&self, _: u32, _: RunnableConfig) -> Result<u32> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(CognisError::Network {
                status_code: Some(503),
                message: "boom".into(),
            })
        }
    }

    /// Fails with `AuthenticationFailed` on every call, counting the calls.
    struct AlwaysAuth {
        attempts: Arc<AtomicU32>,
    }

    #[async_trait]
    impl Runnable<u32, u32> for AlwaysAuth {
        async fn invoke(&self, _: u32, _: RunnableConfig) -> Result<u32> {
            self.attempts.fetch_add(1, Ordering::SeqCst);
            Err(CognisError::AuthenticationFailed("bad key".into()))
        }
    }

    /// Policy with a 1 ms initial delay so the tests do not wait noticeably.
    fn fast_policy(max_attempts: u32) -> RetryPolicy {
        RetryPolicy::new(max_attempts).with_initial_delay(Duration::from_millis(1))
    }

    #[tokio::test]
    async fn retries_until_success() {
        let attempts = Arc::new(AtomicU32::new(0));
        let r = Retry::new(
            FlakyTwice {
                attempts: attempts.clone(),
            },
            fast_policy(5),
        );
        let out = r.invoke(7, RunnableConfig::default()).await.unwrap();
        assert_eq!(out, 7);
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }

    #[tokio::test]
    async fn non_retryable_short_circuits() {
        let attempts = Arc::new(AtomicU32::new(0));
        let r = Retry::new(
            AlwaysAuth {
                attempts: attempts.clone(),
            },
            fast_policy(5),
        );
        let err = r.invoke(0, RunnableConfig::default()).await.unwrap_err();
        assert!(matches!(err, CognisError::AuthenticationFailed(_)));
        // The variant alone cannot show that no retry happened; the call count can.
        assert_eq!(attempts.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn exhausted_attempts_return_last_error() {
        let attempts = Arc::new(AtomicU32::new(0));
        let r = Retry::new(
            AlwaysUnavailable {
                attempts: attempts.clone(),
            },
            fast_policy(3),
        );
        let err = r.invoke(0, RunnableConfig::default()).await.unwrap_err();
        assert!(matches!(err, CognisError::Network { .. }));
        assert_eq!(attempts.load(Ordering::SeqCst), 3);
    }
}