floxide_core/
retry.rs

//! Retry integration is handled via a wrapper node (`RetryNode`) that can enclose any `Node`.
//! This preserves the existing `Transition` API: no dedicated retry variant is needed.
//!
//! Users opt in by wrapping a node in `RetryNode`, which then applies its `RetryPolicy`.
5use crate::context::{Context, WorkflowCtx};
6use crate::error::FloxideError;
7use crate::node::Node;
8use crate::transition::Transition;
9use async_trait::async_trait;
10use std::time::Duration;
11
12/// Helper trait: run a delay respecting cancellation and timeouts if available.
13#[async_trait]
14pub trait RetryDelay {
15    /// Wait the given duration, returning an error if cancelled or timed out.
16    async fn wait(&self, dur: Duration) -> Result<(), FloxideError>;
17}
18
19#[async_trait]
20impl<S: Context> RetryDelay for WorkflowCtx<S> {
21    async fn wait(&self, dur: Duration) -> Result<(), FloxideError> {
22        self.run_future(async {
23            tokio::time::sleep(dur).await;
24            Ok(())
25        })
26        .await
27    }
28}
29
30#[async_trait]
31impl<S: Context> RetryDelay for S {
32    async fn wait(&self, dur: Duration) -> Result<(), FloxideError> {
33        tokio::time::sleep(dur).await;
34        Ok(())
35    }
36}
37
/// Strategy for computing backoff durations.
///
/// The attempt count used in the formulas below is 1-based. The type is a plain
/// value: it can be copied, compared for equality and used as a hash key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum BackoffStrategy {
    /// Delay = initial_backoff * attempt_count
    Linear,
    /// Delay = initial_backoff * 2^(attempt_count - 1)
    Exponential,
}
46
/// Which errors should be retried.
///
/// Despite its name this is not an error type: it is a selector that a
/// `RetryPolicy` matches against the kind of failure a node produced. The type
/// is a plain value: it can be copied, compared for equality and used as a hash key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RetryError {
    /// Retry on any error.
    All,
    /// Retry only on cancellation.
    Cancelled,
    /// Retry only on timeout.
    Timeout,
    /// Retry only on generic errors.
    Generic,
}
59
60/// Policy controlling retry behavior for nodes.
61#[derive(Clone, Debug)]
62pub struct RetryPolicy {
63    /// Maximum number of attempts (including the first).
64    pub max_attempts: usize,
65    /// Initial backoff duration between retries.
66    pub initial_backoff: Duration,
67    /// Maximum backoff duration allowed.
68    pub max_backoff: Duration,
69    /// Strategy to compute backoff durations.
70    pub strategy: BackoffStrategy,
71    /// Optional fixed jitter to add to each backoff.
72    pub jitter: Option<Duration>,
73    /// Error predicate controlling which errors to retry.
74    pub retry_error: RetryError,
75}
76
77impl RetryPolicy {
78    /// Construct a new RetryPolicy.
79    pub fn new(
80        max_attempts: usize,
81        initial_backoff: Duration,
82        max_backoff: Duration,
83        strategy: BackoffStrategy,
84        retry_error: RetryError,
85    ) -> Self {
86        RetryPolicy {
87            max_attempts,
88            initial_backoff,
89            max_backoff,
90            strategy,
91            jitter: None,
92            retry_error,
93        }
94    }
95
96    /// Specify a fixed jitter offset to add to each backoff.
97    pub fn with_jitter(mut self, jitter: Duration) -> Self {
98        self.jitter = Some(jitter);
99        self
100    }
101
102    /// Determine whether an error should be retried for the given attempt (1-based).
103    pub fn should_retry(&self, error: &FloxideError, attempt: usize) -> bool {
104        if attempt >= self.max_attempts {
105            return false;
106        }
107        match self.retry_error {
108            RetryError::All => true,
109            RetryError::Cancelled => matches!(error, FloxideError::Cancelled),
110            RetryError::Timeout => matches!(error, FloxideError::Timeout(_)),
111            RetryError::Generic => matches!(error, FloxideError::Generic(_)),
112        }
113    }
114
115    /// Compute the backoff duration before the next retry given the attempt count (1-based).
116    pub fn backoff_duration(&self, attempt: usize) -> Duration {
117        let base = match self.strategy {
118            BackoffStrategy::Linear => self.initial_backoff.saturating_mul(attempt as u32),
119            BackoffStrategy::Exponential => {
120                // Compute 2^(attempt-1) with shift, saturating at 32 bits
121                let exp = attempt.saturating_sub(1);
122                let factor = if exp < 32 { 1_u32 << exp } else { u32::MAX };
123                self.initial_backoff.saturating_mul(factor)
124            }
125        };
126        let capped = if base > self.max_backoff {
127            self.max_backoff
128        } else {
129            base
130        };
131        if let Some(j) = self.jitter {
132            capped.saturating_add(j)
133        } else {
134            capped
135        }
136    }
137}
138
139// Ergonomic retry helpers
140/// Wrap an existing node with retry behavior according to the given policy.
141///
142/// # Example
143///
144/// ```rust
145/// use floxide_core::*;
146/// use std::time::Duration;
147/// // Define a policy: up to 3 attempts, exponential backoff 100ms→200ms→400ms
148/// let policy = RetryPolicy::new(
149///     3,
150///     Duration::from_millis(100),
151///     Duration::from_secs(1),
152///     BackoffStrategy::Exponential,
153///     RetryError::All,
154/// );
155/// let my_node = FooNode::new();
156/// let retry_node = with_retry(my_node, policy.clone());
157/// ```
158///
159/// In future macro syntax you could write:
160///
161/// ```ignore
162/// #[node(retry = "my_policy")]
163/// foo: FooNode;
164/// ```
165pub fn with_retry<N>(node: N, policy: RetryPolicy) -> RetryNode<N> {
166    RetryNode::new(node, policy)
167}
168
169/// Wrapper node that applies a `RetryPolicy` on inner node failures.
170///
171/// Internally it will re-run the inner node up to `policy.max_attempts`,
172/// using backoff delays between attempts.
173#[derive(Clone, Debug)]
174pub struct RetryNode<N> {
175    /// Inner node to invoke.
176    pub inner: N,
177    /// Policy controlling retry attempts and backoff.
178    pub policy: RetryPolicy,
179}
180
181impl<N> RetryNode<N> {
182    /// Create a new retry wrapper around `inner` with the given `policy`.
183    pub fn new(inner: N, policy: RetryPolicy) -> Self {
184        RetryNode { inner, policy }
185    }
186}
187// RetryNode implements Node by looping on errors according to its policy.
188
189#[async_trait]
190impl<C, N> Node<C> for RetryNode<N>
191where
192    C: Context + RetryDelay,
193    N: Node<C> + Clone + Send + Sync + 'static,
194    N::Input: Clone + Send + 'static,
195    N::Output: Send + 'static,
196{
197    type Input = N::Input;
198    type Output = N::Output;
199
200    async fn process(
201        &self,
202        ctx: &C,
203        input: Self::Input,
204    ) -> Result<Transition<Self::Output>, FloxideError> {
205        let mut attempt = 1;
206        loop {
207            match self.inner.process(ctx, input.clone()).await {
208                Ok(Transition::NextAll(vs)) => return Ok(Transition::NextAll(vs)),
209                Ok(Transition::Next(out)) => return Ok(Transition::Next(out)),
210                Ok(Transition::Hold) => return Ok(Transition::Hold),
211                Ok(Transition::Abort(e)) | Err(e) => {
212                    // emit tracing event for retry evaluation
213                    tracing::debug!(attempt, error=%e, "RetryNode: caught error, evaluating retry policy");
214                    if self.policy.should_retry(&e, attempt) {
215                        let backoff = self.policy.backoff_duration(attempt);
216                        tracing::debug!(attempt, backoff=?backoff, "RetryNode: retrying after backoff");
217                        ctx.wait(backoff).await?;
218                        attempt += 1;
219                        continue;
220                    } else {
221                        tracing::warn!(attempt, error=%e, "RetryNode: aborting after reaching retry limit or non-retryable error");
222                        return Err(e);
223                    }
224                }
225            }
226        }
227    }
228}
229
// Unit tests exercising RetryPolicy's backoff arithmetic and retry predicate.
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Shorthand for a millisecond `Duration`.
    fn ms(millis: u64) -> Duration {
        Duration::from_millis(millis)
    }

    /// Checks `backoff_duration` against a table of `(attempt, expected milliseconds)` pairs.
    fn assert_backoffs(policy: &RetryPolicy, cases: &[(usize, u64)]) {
        for &(attempt, expected_ms) in cases {
            assert_eq!(policy.backoff_duration(attempt), ms(expected_ms));
        }
    }

    #[test]
    fn test_linear_backoff() {
        let policy = RetryPolicy::new(
            5,
            ms(100),
            ms(1000),
            BackoffStrategy::Linear,
            RetryError::All,
        );
        // The final case would be 2000ms uncapped and must be limited to max_backoff.
        assert_backoffs(&policy, &[(1, 100), (3, 300), (20, 1000)]);
    }

    #[test]
    fn test_exponential_backoff() {
        let policy = RetryPolicy::new(
            5,
            ms(50),
            ms(400),
            BackoffStrategy::Exponential,
            RetryError::All,
        );
        // Doubles from 50ms up to the 400ms cap, then stays there.
        assert_backoffs(
            &policy,
            &[(1, 50), (2, 100), (3, 200), (4, 400), (5, 400)],
        );
    }

    #[test]
    fn test_jitter_addition() {
        let policy = RetryPolicy::new(
            3,
            ms(100),
            ms(1000),
            BackoffStrategy::Linear,
            RetryError::All,
        )
        .with_jitter(ms(25));
        assert_eq!(policy.backoff_duration(2), ms(100 * 2 + 25));
    }

    #[test]
    fn test_retry_predicates() {
        let gen_err = FloxideError::Generic("oops".into());
        let cancel_err = FloxideError::Cancelled;
        let timeout_err = FloxideError::Timeout(Duration::from_secs(1));

        let generic_only = RetryPolicy::new(
            3,
            ms(10),
            ms(100),
            BackoffStrategy::Linear,
            RetryError::Generic,
        );
        // Only the generic error kind qualifies.
        assert!(generic_only.should_retry(&gen_err, 1));
        assert!(!generic_only.should_retry(&cancel_err, 1));
        assert!(!generic_only.should_retry(&timeout_err, 1));
        // No retry once the attempt limit is reached.
        assert!(!generic_only.should_retry(&gen_err, 3));

        // Same policy, but accepting every error kind.
        let retry_all = RetryPolicy {
            retry_error: RetryError::All,
            ..generic_only
        };
        assert!(retry_all.should_retry(&cancel_err, 2));
        assert!(retry_all.should_retry(&timeout_err, 2));
    }
}
307}