floxide_core/
retry.rs

use async_trait::async_trait;
use std::fmt::{self, Debug, Formatter};
use std::marker::PhantomData;
use std::sync::Arc;
use std::time::Duration;

use crate::action::ActionType;
use crate::error::FloxideError;
use crate::node::{Node, NodeId, NodeOutcome};

/// Backoff strategy for retries
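///
/// A minimal construction sketch; the durations below are illustrative values,
/// not defaults shipped with the crate:
///
/// ```ignore
/// use std::time::Duration;
///
/// // Wait 100ms, then 150ms, then 200ms, ... between attempts
/// let linear = BackoffStrategy::Linear {
///     base: Duration::from_millis(100),
///     increment: Duration::from_millis(50),
/// };
/// assert_eq!(linear.calculate_delay(2), Duration::from_millis(200));
/// ```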
#[derive(Clone)]
pub enum BackoffStrategy {
    /// Constant time between retries
    Constant(Duration),
    /// Linear increase in time between retries
    Linear { base: Duration, increment: Duration },
    /// Exponential increase in time between retries (base * 2^attempt)
    Exponential { base: Duration, max: Duration },
    /// Custom backoff strategy implemented as a function
    Custom(CustomBackoff),
}

/// A wrapper for custom backoff functions that can be cloned and debugged
pub struct CustomBackoff {
    func: Arc<dyn Fn(usize) -> Duration + Send + Sync>,
}

impl Clone for CustomBackoff {
    fn clone(&self) -> Self {
        // The function is stored behind an Arc, so clones share the same
        // strategy rather than silently replacing it with a default
        Self {
            func: Arc::clone(&self.func),
        }
    }
}

impl Debug for CustomBackoff {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.debug_struct("CustomBackoff")
            .field("func", &"<function>")
            .finish()
    }
}

impl BackoffStrategy {
    /// Calculate the delay for a given attempt
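    ///
    /// For example, exponential backoff doubles the base delay on each attempt and
    /// caps the result at `max` (a sketch; the import path for this type is assumed):
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let strategy = BackoffStrategy::Exponential {
    ///     base: Duration::from_millis(100),
    ///     max: Duration::from_secs(1),
    /// };
    /// assert_eq!(strategy.calculate_delay(3), Duration::from_millis(800));
    /// assert_eq!(strategy.calculate_delay(10), Duration::from_secs(1)); // capped at `max`
    /// ```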
    pub fn calculate_delay(&self, attempt: usize) -> Duration {
        match self {
            Self::Constant(duration) => *duration,
            Self::Linear { base, increment } => *base + (*increment * attempt as u32),
            Self::Exponential { base, max } => {
                // Saturate instead of overflowing for very large attempt counts
                let factor = 2u32.checked_pow(attempt as u32).unwrap_or(u32::MAX);
                let calculated = base.saturating_mul(factor);
                std::cmp::min(calculated, *max)
            }
            Self::Custom(custom) => (custom.func)(attempt),
        }
    }
}

impl Debug for BackoffStrategy {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            Self::Constant(duration) => f.debug_tuple("Constant").field(duration).finish(),
            Self::Linear { base, increment } => f
                .debug_struct("Linear")
                .field("base", base)
                .field("increment", increment)
                .finish(),
            Self::Exponential { base, max } => f
                .debug_struct("Exponential")
                .field("base", base)
                .field("max", max)
                .finish(),
            Self::Custom(custom) => f.debug_tuple("Custom").field(custom).finish(),
        }
    }
}

/// A node that retries another node if it fails
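///
/// A minimal usage sketch; `MyNode`, `MyContext`, and `ctx` are hypothetical
/// placeholders for a node/context pair defined elsewhere:
///
/// ```ignore
/// // Make at most 5 attempts, roughly doubling the delay between them, capped at 5s
/// let retry = RetryNode::with_exponential_backoff(
///     MyNode::new(),
///     5,
///     Duration::from_millis(100),
///     Duration::from_secs(5),
/// );
/// let outcome = retry.process(&mut ctx).await?;
/// ```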
pub struct RetryNode<N, Context, A = crate::action::DefaultAction>
where
    N: Node<Context, A>,
    Context: Send + Sync + 'static,
    A: ActionType + Send + Sync + 'static,
{
    /// The node to retry
    inner_node: N,
    /// Maximum number of attempts before giving up (the first attempt counts toward this limit)
    max_retries: usize,
    /// Backoff strategy
    backoff_strategy: BackoffStrategy,
    /// Type markers
    _context: PhantomData<Context>,
    _action: PhantomData<A>,
}

impl<N, Context, A> RetryNode<N, Context, A>
where
    N: Node<Context, A>,
    Context: Send + Sync + 'static,
    A: ActionType + Send + Sync + 'static,
{
    /// Create a new retry node with a constant backoff
    pub fn with_constant_backoff(inner_node: N, max_retries: usize, delay: Duration) -> Self {
        Self {
            inner_node,
            max_retries,
            backoff_strategy: BackoffStrategy::Constant(delay),
            _context: PhantomData,
            _action: PhantomData,
        }
    }

    /// Create a new retry node with linear backoff
    pub fn with_linear_backoff(
        inner_node: N,
        max_retries: usize,
        base: Duration,
        increment: Duration,
    ) -> Self {
        Self {
            inner_node,
            max_retries,
            backoff_strategy: BackoffStrategy::Linear { base, increment },
            _context: PhantomData,
            _action: PhantomData,
        }
    }

    /// Create a new retry node with exponential backoff
    pub fn with_exponential_backoff(
        inner_node: N,
        max_retries: usize,
        base: Duration,
        max: Duration,
    ) -> Self {
        Self {
            inner_node,
            max_retries,
            backoff_strategy: BackoffStrategy::Exponential { base, max },
            _context: PhantomData,
            _action: PhantomData,
        }
    }

    /// Create a new retry node with a custom backoff strategy
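    ///
    /// A sketch of supplying your own delay function; `inner_node` stands in for any
    /// node implementation:
    ///
    /// ```ignore
    /// // 50ms, 100ms, 150ms, ... between attempts
    /// let retry = RetryNode::with_custom_backoff(inner_node, 4, |attempt| {
    ///     Duration::from_millis(50 * attempt as u64)
    /// });
    /// ```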
    pub fn with_custom_backoff<F>(inner_node: N, max_retries: usize, f: F) -> Self
    where
        F: Fn(usize) -> Duration + Send + Sync + 'static,
    {
        Self {
            inner_node,
            max_retries,
            backoff_strategy: BackoffStrategy::Custom(CustomBackoff { func: Arc::new(f) }),
            _context: PhantomData,
            _action: PhantomData,
        }
    }
}

#[async_trait]
impl<N, Context, A> Node<Context, A> for RetryNode<N, Context, A>
where
    N: Node<Context, A> + std::fmt::Debug + Send + Sync,
    Context: std::fmt::Debug + Send + Sync + 'static,
    A: crate::action::ActionType + Default + std::fmt::Debug + Send + Sync + 'static,
    N::Output: Clone + Send + Sync + 'static,
{
    type Output = N::Output;

    fn id(&self) -> NodeId {
        self.inner_node.id()
    }

    async fn process(
        &self,
        ctx: &mut Context,
    ) -> Result<NodeOutcome<Self::Output, A>, FloxideError> {
        let mut attempt = 0;
        loop {
            attempt += 1;
            match self.inner_node.process(ctx).await {
                Ok(outcome) => {
                    tracing::debug!(
                        attempt = attempt,
                        node_id = %self.id(),
                        "Node completed successfully after {} attempts",
                        attempt
                    );
                    return Ok(outcome);
                }
                Err(err) => {
                    if attempt >= self.max_retries {
                        tracing::error!(
                            attempt = attempt,
                            max_retries = self.max_retries,
                            node_id = %self.id(),
                            error = %err,
                            "Maximum retry attempts reached, failing"
                        );
                        return Err(err);
                    }

                    let delay = self.backoff_strategy.calculate_delay(attempt);
                    tracing::warn!(
                        attempt = attempt,
                        node_id = %self.id(),
                        error = %err,
                        delay_ms = delay.as_millis(),
                        "Node execution failed, retrying after {:?}",
                        delay
                    );

                    #[cfg(feature = "async")]
                    {
                        tokio::time::sleep(delay).await;
                    }

                    #[cfg(not(feature = "async"))]
                    {
                        // For compatibility with sync execution, we just delay using std
                        std::thread::sleep(delay);
                    }
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::DefaultAction;

    #[derive(Debug, Clone)]
    struct TestContext {
        counter: usize,
        should_fail_until: usize,
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        // Create a custom node implementation for testing retry logic
        #[derive(Debug)]
        struct TestNodeImpl {
            id: NodeId,
        }

        #[async_trait]
        impl Node<TestContext, DefaultAction> for TestNodeImpl {
            type Output = String;

            fn id(&self) -> NodeId {
                self.id.clone()
            }

            async fn process(
                &self,
                ctx: &mut TestContext,
            ) -> Result<NodeOutcome<Self::Output, DefaultAction>, FloxideError> {
                ctx.counter += 1;
                if ctx.counter <= ctx.should_fail_until {
                    Err(FloxideError::node_execution("test", "Simulated failure"))
                } else {
                    Ok(NodeOutcome::<String, DefaultAction>::Success(
                        "success".to_string(),
                    ))
                }
            }
        }

        let test_node = TestNodeImpl {
            id: "test-node".to_string(),
        };

        let retry_node = RetryNode::with_constant_backoff(
            test_node,
            5,
            Duration::from_millis(10), // short delay for tests
        );

        let mut ctx = TestContext {
            counter: 0,
            should_fail_until: 2, // fail first 2 attempts
        };

        let result = retry_node.process(&mut ctx).await;
        assert!(result.is_ok());
        assert_eq!(ctx.counter, 3); // should have run 3 times
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts() {
        // Create a custom node implementation that always fails
        #[derive(Debug)]
        struct AlwaysFailNode {
            id: NodeId,
        }

        #[async_trait]
        impl Node<TestContext, DefaultAction> for AlwaysFailNode {
            type Output = String;

            fn id(&self) -> NodeId {
                self.id.clone()
            }

            async fn process(
                &self,
                _ctx: &mut TestContext,
            ) -> Result<NodeOutcome<Self::Output, DefaultAction>, FloxideError> {
                Err(FloxideError::node_execution("test", "Always failing"))
            }
        }

        let test_node = AlwaysFailNode {
            id: "always-fail".to_string(),
        };

        let retry_node = RetryNode::with_constant_backoff(test_node, 3, Duration::from_millis(10));

        let mut ctx = TestContext {
            counter: 0,
            should_fail_until: 999, // always fail
        };

        let result = retry_node.process(&mut ctx).await;
        assert!(result.is_err());
        // Should have attempted 3 times (max_retries)
    }

    #[tokio::test]
    async fn test_backoff_strategies() {
        // Test constant backoff
        let constant = BackoffStrategy::Constant(Duration::from_millis(100));
        assert_eq!(constant.calculate_delay(1), Duration::from_millis(100));
        assert_eq!(constant.calculate_delay(2), Duration::from_millis(100));

        // Test linear backoff
        let linear = BackoffStrategy::Linear {
            base: Duration::from_millis(100),
            increment: Duration::from_millis(50),
        };
        assert_eq!(linear.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(linear.calculate_delay(1), Duration::from_millis(150));
        assert_eq!(linear.calculate_delay(2), Duration::from_millis(200));

        // Test exponential backoff
        let exponential = BackoffStrategy::Exponential {
            base: Duration::from_millis(100),
            max: Duration::from_millis(1000),
        };
        assert_eq!(exponential.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(exponential.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(exponential.calculate_delay(2), Duration::from_millis(400));
        assert_eq!(exponential.calculate_delay(3), Duration::from_millis(800));
        assert_eq!(exponential.calculate_delay(4), Duration::from_millis(1000)); // Max reached

        // Test custom backoff
        let custom = BackoffStrategy::Custom(CustomBackoff {
            func: Arc::new(|attempt| Duration::from_millis(attempt as u64 * 25)),
        });
        assert_eq!(custom.calculate_delay(1), Duration::from_millis(25));
        assert_eq!(custom.calculate_delay(2), Duration::from_millis(50));
        assert_eq!(custom.calculate_delay(10), Duration::from_millis(250));
    }
}