1use crate::context::{Context, WorkflowCtx};
6use crate::error::FloxideError;
7use crate::node::Node;
8use crate::transition::Transition;
9use async_trait::async_trait;
10use std::time::Duration;
11
/// Abstraction over "sleep for a duration", used by `RetryNode` to wait out
/// the backoff between attempts.
#[async_trait]
pub trait RetryDelay {
    /// Suspend for `dur`. Implementations may return an error instead of
    /// completing the full sleep (e.g. when the surrounding workflow is
    /// cancelled — see the `WorkflowCtx` impl below).
    async fn wait(&self, dur: Duration) -> Result<(), FloxideError>;
}
18
// Delay implementation for workflow contexts: the sleep is routed through
// `run_future`, which presumably ties it into the workflow's cancellation /
// timeout handling so a cancelled workflow does not sit out the full backoff
// — NOTE(review): confirm against `WorkflowCtx::run_future`.
#[async_trait]
impl<S: Context> RetryDelay for WorkflowCtx<S> {
    async fn wait(&self, dur: Duration) -> Result<(), FloxideError> {
        self.run_future(async {
            tokio::time::sleep(dur).await;
            Ok(())
        })
        .await
    }
}
29
// Fallback delay for any plain `Context`: a bare tokio sleep that always
// runs to completion and returns `Ok(())`.
#[async_trait]
impl<S: Context> RetryDelay for S {
    async fn wait(&self, dur: Duration) -> Result<(), FloxideError> {
        tokio::time::sleep(dur).await;
        Ok(())
    }
}
37
/// How the retry backoff grows with the attempt number
/// (see `RetryPolicy::backoff_duration`).
#[derive(Clone, Copy, Debug)]
pub enum BackoffStrategy {
    /// Delay = `initial_backoff * attempt`.
    Linear,
    /// Delay = `initial_backoff * 2^(attempt - 1)`.
    Exponential,
}
46
/// Which error kinds a `RetryPolicy` treats as retryable
/// (see `RetryPolicy::should_retry`).
#[derive(Clone, Copy, Debug)]
pub enum RetryError {
    /// Retry on any error.
    All,
    /// Retry only `FloxideError::Cancelled`.
    Cancelled,
    /// Retry only `FloxideError::Timeout(_)`.
    Timeout,
    /// Retry only `FloxideError::Generic(_)`.
    Generic,
}
59
/// Configuration describing how a failing node is retried.
#[derive(Clone, Debug)]
pub struct RetryPolicy {
    /// Total attempt budget; `should_retry` refuses once `attempt >= max_attempts`.
    pub max_attempts: usize,
    /// Backoff for the first retry; later retries scale it per `strategy`.
    pub initial_backoff: Duration,
    /// Upper bound applied to the scaled backoff (before jitter is added).
    pub max_backoff: Duration,
    /// Growth curve for the backoff (linear or exponential).
    pub strategy: BackoffStrategy,
    /// Optional fixed extra delay added on top of the capped backoff;
    /// `None` means no jitter.
    pub jitter: Option<Duration>,
    /// Which error kinds are considered retryable.
    pub retry_error: RetryError,
}
76
77impl RetryPolicy {
78 pub fn new(
80 max_attempts: usize,
81 initial_backoff: Duration,
82 max_backoff: Duration,
83 strategy: BackoffStrategy,
84 retry_error: RetryError,
85 ) -> Self {
86 RetryPolicy {
87 max_attempts,
88 initial_backoff,
89 max_backoff,
90 strategy,
91 jitter: None,
92 retry_error,
93 }
94 }
95
96 pub fn with_jitter(mut self, jitter: Duration) -> Self {
98 self.jitter = Some(jitter);
99 self
100 }
101
102 pub fn should_retry(&self, error: &FloxideError, attempt: usize) -> bool {
104 if attempt >= self.max_attempts {
105 return false;
106 }
107 match self.retry_error {
108 RetryError::All => true,
109 RetryError::Cancelled => matches!(error, FloxideError::Cancelled),
110 RetryError::Timeout => matches!(error, FloxideError::Timeout(_)),
111 RetryError::Generic => matches!(error, FloxideError::Generic(_)),
112 }
113 }
114
115 pub fn backoff_duration(&self, attempt: usize) -> Duration {
117 let base = match self.strategy {
118 BackoffStrategy::Linear => self.initial_backoff.saturating_mul(attempt as u32),
119 BackoffStrategy::Exponential => {
120 let exp = attempt.saturating_sub(1);
122 let factor = if exp < 32 { 1_u32 << exp } else { u32::MAX };
123 self.initial_backoff.saturating_mul(factor)
124 }
125 };
126 let capped = if base > self.max_backoff {
127 self.max_backoff
128 } else {
129 base
130 };
131 if let Some(j) = self.jitter {
132 capped.saturating_add(j)
133 } else {
134 capped
135 }
136 }
137}
138
139pub fn with_retry<N>(node: N, policy: RetryPolicy) -> RetryNode<N> {
166 RetryNode::new(node, policy)
167}
168
/// Node decorator that re-runs its inner node on failure, following a
/// `RetryPolicy` for attempt limits and backoff.
#[derive(Clone, Debug)]
pub struct RetryNode<N> {
    /// The wrapped node whose `process` is retried.
    pub inner: N,
    /// When and how long to retry.
    pub policy: RetryPolicy,
}
180
181impl<N> RetryNode<N> {
182 pub fn new(inner: N, policy: RetryPolicy) -> Self {
184 RetryNode { inner, policy }
185 }
186}
187#[async_trait]
190impl<C, N> Node<C> for RetryNode<N>
191where
192 C: Context + RetryDelay,
193 N: Node<C> + Clone + Send + Sync + 'static,
194 N::Input: Clone + Send + 'static,
195 N::Output: Send + 'static,
196{
197 type Input = N::Input;
198 type Output = N::Output;
199
200 async fn process(
201 &self,
202 ctx: &C,
203 input: Self::Input,
204 ) -> Result<Transition<Self::Output>, FloxideError> {
205 let mut attempt = 1;
206 loop {
207 match self.inner.process(ctx, input.clone()).await {
208 Ok(Transition::NextAll(vs)) => return Ok(Transition::NextAll(vs)),
209 Ok(Transition::Next(out)) => return Ok(Transition::Next(out)),
210 Ok(Transition::Hold) => return Ok(Transition::Hold),
211 Ok(Transition::Abort(e)) | Err(e) => {
212 tracing::debug!(attempt, error=%e, "RetryNode: caught error, evaluating retry policy");
214 if self.policy.should_retry(&e, attempt) {
215 let backoff = self.policy.backoff_duration(attempt);
216 tracing::debug!(attempt, backoff=?backoff, "RetryNode: retrying after backoff");
217 ctx.wait(backoff).await?;
218 attempt += 1;
219 continue;
220 } else {
221 tracing::warn!(attempt, error=%e, "RetryNode: aborting after reaching retry limit or non-retryable error");
222 return Err(e);
223 }
224 }
225 }
226 }
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Shorthand: build a policy whose durations are given in milliseconds.
    fn policy_ms(
        attempts: usize,
        initial: u64,
        max: u64,
        strategy: BackoffStrategy,
        retry_error: RetryError,
    ) -> RetryPolicy {
        RetryPolicy::new(
            attempts,
            Duration::from_millis(initial),
            Duration::from_millis(max),
            strategy,
            retry_error,
        )
    }

    #[test]
    fn test_linear_backoff() {
        let p = policy_ms(5, 100, 1000, BackoffStrategy::Linear, RetryError::All);
        // Linear growth, capped at 1000ms.
        for (attempt, ms) in [(1_usize, 100_u64), (3, 300), (20, 1000)] {
            assert_eq!(p.backoff_duration(attempt), Duration::from_millis(ms));
        }
    }

    #[test]
    fn test_exponential_backoff() {
        let p = policy_ms(5, 50, 400, BackoffStrategy::Exponential, RetryError::All);
        // Doubles each attempt, capped at 400ms.
        for (attempt, ms) in [(1_usize, 50_u64), (2, 100), (3, 200), (4, 400), (5, 400)] {
            assert_eq!(p.backoff_duration(attempt), Duration::from_millis(ms));
        }
    }

    #[test]
    fn test_jitter_addition() {
        let p = policy_ms(3, 100, 1000, BackoffStrategy::Linear, RetryError::All)
            .with_jitter(Duration::from_millis(25));
        // Jitter lands on top of the (uncapped here) linear backoff.
        assert_eq!(p.backoff_duration(2), Duration::from_millis(225));
    }

    #[test]
    fn test_retry_predicates() {
        let mut p = policy_ms(3, 10, 100, BackoffStrategy::Linear, RetryError::Generic);
        let generic = FloxideError::Generic("oops".into());
        let cancelled = FloxideError::Cancelled;
        let timed_out = FloxideError::Timeout(Duration::from_secs(1));
        // Generic-only filter: other error kinds are never retried, and even
        // generic errors stop once the attempt budget is spent.
        assert!(p.should_retry(&generic, 1));
        assert!(!p.should_retry(&cancelled, 1));
        assert!(!p.should_retry(&timed_out, 1));
        assert!(!p.should_retry(&generic, 3));
        // Widening the filter to All makes every error kind retryable.
        p.retry_error = RetryError::All;
        assert!(p.should_retry(&cancelled, 2));
        assert!(p.should_retry(&timed_out, 2));
    }
}