1use async_trait::async_trait;
2use std::fmt::{self, Debug, Formatter};
3use std::marker::PhantomData;
4use std::time::Duration;
5
6use crate::action::ActionType;
7use crate::error::FloxideError;
8use crate::node::{Node, NodeId, NodeOutcome};
9
10#[derive(Clone)]
12pub enum BackoffStrategy {
13 Constant(Duration),
15 Linear { base: Duration, increment: Duration },
17 Exponential { base: Duration, max: Duration },
19 Custom(CustomBackoff),
21}
22
/// Holder for a user-supplied `attempt -> delay` function, used by
/// [`BackoffStrategy::Custom`].
///
/// Caveat: `Clone` cannot copy a boxed `dyn Fn`, so a cloned `CustomBackoff`
/// computes `100ms * attempt` instead of calling the original function.
pub struct CustomBackoff {
    /// Maps the attempt number to the delay to wait.
    func: Box<dyn Fn(usize) -> Duration + Send + Sync>,
}

impl Clone for CustomBackoff {
    fn clone(&self) -> Self {
        // NOTE(review): this discards `self.func` and substitutes a fixed linear
        // schedule. Preserving the function requires storing it as
        // `Arc<dyn Fn(usize) -> Duration + Send + Sync>` and updating every
        // `CustomBackoff { func: Box::new(..) }` construction site.
        let fallback = |attempt: usize| Duration::from_millis(100 * attempt as u64);
        CustomBackoff {
            func: Box::new(fallback),
        }
    }
}

impl Debug for CustomBackoff {
    /// Closures have no `Debug` representation, so a fixed placeholder is printed.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("CustomBackoff");
        builder.field("func", &"<function>");
        builder.finish()
    }
}
45
46impl BackoffStrategy {
47 pub fn calculate_delay(&self, attempt: usize) -> Duration {
49 match self {
50 Self::Constant(duration) => *duration,
51 Self::Linear { base, increment } => *base + (*increment * attempt as u32),
52 Self::Exponential { base, max } => {
53 let calculated = *base * u32::pow(2, attempt as u32);
54 std::cmp::min(calculated, *max)
55 }
56 Self::Custom(custom) => (custom.func)(attempt),
57 }
58 }
59}
60
61impl Debug for BackoffStrategy {
62 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
63 match self {
64 Self::Constant(duration) => f.debug_tuple("Constant").field(duration).finish(),
65 Self::Linear { base, increment } => f
66 .debug_struct("Linear")
67 .field("base", base)
68 .field("increment", increment)
69 .finish(),
70 Self::Exponential { base, max } => f
71 .debug_struct("Exponential")
72 .field("base", base)
73 .field("max", max)
74 .finish(),
75 Self::Custom(custom) => f.debug_tuple("Custom").field(custom).finish(),
76 }
77 }
78}
79
80pub struct RetryNode<N, Context, A = crate::action::DefaultAction>
82where
83 N: Node<Context, A>,
84 Context: Send + Sync + 'static,
85 A: ActionType + Send + Sync + 'static,
86{
87 inner_node: N,
89 max_retries: usize,
91 backoff_strategy: BackoffStrategy,
93 _context: PhantomData<Context>,
95 _action: PhantomData<A>,
96}
97
98impl<N, Context, A> RetryNode<N, Context, A>
99where
100 N: Node<Context, A>,
101 Context: Send + Sync + 'static,
102 A: ActionType + Send + Sync + 'static,
103{
104 pub fn with_constant_backoff(inner_node: N, max_retries: usize, delay: Duration) -> Self {
106 Self {
107 inner_node,
108 max_retries,
109 backoff_strategy: BackoffStrategy::Constant(delay),
110 _context: PhantomData,
111 _action: PhantomData,
112 }
113 }
114
115 pub fn with_linear_backoff(
117 inner_node: N,
118 max_retries: usize,
119 base: Duration,
120 increment: Duration,
121 ) -> Self {
122 Self {
123 inner_node,
124 max_retries,
125 backoff_strategy: BackoffStrategy::Linear { base, increment },
126 _context: PhantomData,
127 _action: PhantomData,
128 }
129 }
130
131 pub fn with_exponential_backoff(
133 inner_node: N,
134 max_retries: usize,
135 base: Duration,
136 max: Duration,
137 ) -> Self {
138 Self {
139 inner_node,
140 max_retries,
141 backoff_strategy: BackoffStrategy::Exponential { base, max },
142 _context: PhantomData,
143 _action: PhantomData,
144 }
145 }
146
147 pub fn with_custom_backoff<F>(inner_node: N, max_retries: usize, f: F) -> Self
149 where
150 F: Fn(usize) -> Duration + Send + Sync + 'static,
151 {
152 Self {
153 inner_node,
154 max_retries,
155 backoff_strategy: BackoffStrategy::Custom(CustomBackoff { func: Box::new(f) }),
156 _context: PhantomData,
157 _action: PhantomData,
158 }
159 }
160}
161
162#[async_trait]
163impl<N, Context, A> Node<Context, A> for RetryNode<N, Context, A>
164where
165 N: Node<Context, A> + std::fmt::Debug + Send + Sync,
166 Context: std::fmt::Debug + Send + Sync + 'static,
167 A: crate::action::ActionType + Default + std::fmt::Debug + Send + Sync + 'static,
168 N::Output: Clone + Send + Sync + 'static,
169{
170 type Output = N::Output;
171
172 fn id(&self) -> NodeId {
173 self.inner_node.id()
174 }
175
176 async fn process(
177 &self,
178 ctx: &mut Context,
179 ) -> Result<NodeOutcome<Self::Output, A>, FloxideError> {
180 let mut attempt = 0;
181 loop {
182 attempt += 1;
183 match self.inner_node.process(ctx).await {
184 Ok(outcome) => {
185 tracing::debug!(
186 attempt = attempt,
187 node_id = %self.id(),
188 "Node completed successfully after {} attempts",
189 attempt
190 );
191 return Ok(outcome);
192 }
193 Err(err) => {
194 if attempt >= self.max_retries {
195 tracing::error!(
196 attempt = attempt,
197 max_retries = self.max_retries,
198 node_id = %self.id(),
199 error = %err,
200 "Maximum retry attempts reached, failing"
201 );
202 return Err(err);
203 }
204
205 let delay = self.backoff_strategy.calculate_delay(attempt);
206 tracing::warn!(
207 attempt = attempt,
208 node_id = %self.id(),
209 error = %err,
210 delay_ms = delay.as_millis(),
211 "Node execution failed, retrying after {:?}",
212 delay
213 );
214
215 #[cfg(feature = "async")]
216 {
217 tokio::time::sleep(delay).await;
218 }
219
220 #[cfg(not(feature = "async"))]
221 {
222 std::thread::sleep(delay);
224 }
225 }
226 }
227 }
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;

    use crate::DefaultAction;

    /// Shared test context.
    ///
    /// `counter` counts calls to `process`. `should_fail_until` is the number of leading
    /// calls that fail, for nodes that honor it.
    #[derive(Debug, Clone)]
    struct TestContext {
        counter: usize,
        should_fail_until: usize,
    }

    /// Node that always fails and records each attempt in `ctx.counter`.
    #[derive(Debug)]
    struct AlwaysFailNode {
        id: NodeId,
    }

    #[async_trait]
    impl Node<TestContext, DefaultAction> for AlwaysFailNode {
        type Output = String;

        fn id(&self) -> NodeId {
            self.id.clone()
        }

        async fn process(
            &self,
            ctx: &mut TestContext,
        ) -> Result<NodeOutcome<Self::Output, DefaultAction>, FloxideError> {
            ctx.counter += 1;
            Err(FloxideError::node_execution("test", "Always failing"))
        }
    }

    #[tokio::test]
    async fn test_retry_success_after_failures() {
        /// Node that fails while `ctx.counter <= ctx.should_fail_until`, then succeeds.
        #[derive(Debug)]
        struct TestNodeImpl {
            id: NodeId,
        }

        #[async_trait]
        impl Node<TestContext, DefaultAction> for TestNodeImpl {
            type Output = String;

            fn id(&self) -> NodeId {
                self.id.clone()
            }

            async fn process(
                &self,
                ctx: &mut TestContext,
            ) -> Result<NodeOutcome<Self::Output, DefaultAction>, FloxideError> {
                ctx.counter += 1;
                if ctx.counter <= ctx.should_fail_until {
                    Err(FloxideError::node_execution("test", "Simulated failure"))
                } else {
                    Ok(NodeOutcome::<String, DefaultAction>::Success(
                        "success".to_string(),
                    ))
                }
            }
        }

        let test_node = TestNodeImpl {
            id: "test-node".to_string(),
        };

        let retry_node = RetryNode::with_constant_backoff(test_node, 5, Duration::from_millis(10));

        let mut ctx = TestContext {
            counter: 0,
            should_fail_until: 2,
        };

        let result = retry_node.process(&mut ctx).await;
        assert!(result.is_ok());
        // Two failures followed by one success.
        assert_eq!(ctx.counter, 3);
    }

    #[tokio::test]
    async fn test_retry_exhausts_attempts() {
        let test_node = AlwaysFailNode {
            id: "always-fail".to_string(),
        };

        let retry_node = RetryNode::with_constant_backoff(test_node, 3, Duration::from_millis(10));

        let mut ctx = TestContext {
            counter: 0,
            should_fail_until: 999,
        };

        let result = retry_node.process(&mut ctx).await;
        assert!(result.is_err());
        // `max_retries` bounds the total number of attempts, including the first.
        assert_eq!(ctx.counter, 3);
    }

    #[tokio::test]
    async fn test_zero_max_retries_runs_once() {
        let test_node = AlwaysFailNode {
            id: "always-fail".to_string(),
        };

        let retry_node = RetryNode::with_constant_backoff(test_node, 0, Duration::from_millis(10));

        let mut ctx = TestContext {
            counter: 0,
            should_fail_until: 0,
        };

        let result = retry_node.process(&mut ctx).await;
        assert!(result.is_err());
        assert_eq!(ctx.counter, 1);
    }

    #[test]
    fn test_backoff_strategies() {
        let constant = BackoffStrategy::Constant(Duration::from_millis(100));
        assert_eq!(constant.calculate_delay(1), Duration::from_millis(100));
        assert_eq!(constant.calculate_delay(2), Duration::from_millis(100));

        let linear = BackoffStrategy::Linear {
            base: Duration::from_millis(100),
            increment: Duration::from_millis(50),
        };
        assert_eq!(linear.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(linear.calculate_delay(1), Duration::from_millis(150));
        assert_eq!(linear.calculate_delay(2), Duration::from_millis(200));

        let exponential = BackoffStrategy::Exponential {
            base: Duration::from_millis(100),
            max: Duration::from_millis(1000),
        };
        assert_eq!(exponential.calculate_delay(0), Duration::from_millis(100));
        assert_eq!(exponential.calculate_delay(1), Duration::from_millis(200));
        assert_eq!(exponential.calculate_delay(2), Duration::from_millis(400));
        assert_eq!(exponential.calculate_delay(3), Duration::from_millis(800));
        assert_eq!(exponential.calculate_delay(4), Duration::from_millis(1000));

        let custom = BackoffStrategy::Custom(CustomBackoff {
            func: Box::new(|attempt: usize| Duration::from_millis(attempt as u64 * 25)),
        });
        assert_eq!(custom.calculate_delay(1), Duration::from_millis(25));
        assert_eq!(custom.calculate_delay(2), Duration::from_millis(50));
        assert_eq!(custom.calculate_delay(10), Duration::from_millis(250));
    }

    #[test]
    fn test_backoff_debug_output() {
        let constant = BackoffStrategy::Constant(Duration::from_millis(100));
        assert_eq!(format!("{:?}", constant), "Constant(100ms)");

        let linear = BackoffStrategy::Linear {
            base: Duration::from_millis(100),
            increment: Duration::from_millis(50),
        };
        assert_eq!(
            format!("{:?}", linear),
            "Linear { base: 100ms, increment: 50ms }"
        );

        let exponential = BackoffStrategy::Exponential {
            base: Duration::from_millis(100),
            max: Duration::from_millis(1000),
        };
        assert_eq!(
            format!("{:?}", exponential),
            "Exponential { base: 100ms, max: 1s }"
        );

        let custom = BackoffStrategy::Custom(CustomBackoff {
            func: Box::new(|attempt: usize| Duration::from_millis(attempt as u64)),
        });
        assert_eq!(
            format!("{:?}", custom),
            "Custom(CustomBackoff { func: \"<function>\" })"
        );
    }
}