use crate::context::{Context, WorkflowCtx};
use crate::error::FloxideError;
use crate::node::Node;
use crate::transition::Transition;
use async_trait::async_trait;
use std::time::Duration;
/// Something that can asynchronously pause for a given duration.
///
/// `RetryNode` calls [`RetryDelay::wait`] between attempts, so the context
/// type it runs under decides how the pause is actually carried out.
#[async_trait]
pub trait RetryDelay {
    /// Waits for `dur`.
    ///
    /// # Errors
    ///
    /// Implementations may return a [`FloxideError`] instead of completing
    /// the wait; `RetryNode` propagates such an error to its caller.
    async fn wait(&self, dur: Duration) -> Result<(), FloxideError>;
}
/// Delay implementation for the workflow context wrapper.
///
/// The sleep is not awaited directly; it is handed to
/// `WorkflowCtx::run_future` and that call's result is returned as-is.
// NOTE(review): presumably `run_future` is what applies the workflow's
// cancellation/timeout handling to the sleep — confirm against its definition.
#[async_trait]
impl<S> RetryDelay for WorkflowCtx<S>
where
    S: Context,
{
    async fn wait(&self, dur: Duration) -> Result<(), FloxideError> {
        let sleep_then_ok = async {
            tokio::time::sleep(dur).await;
            Ok(())
        };
        self.run_future(sleep_then_ok).await
    }
}
/// Delay implementation for any bare [`Context`] value.
///
/// This variant simply sleeps on the Tokio timer for the full duration and
/// then reports success; it never produces an error itself.
#[async_trait]
impl<S> RetryDelay for S
where
    S: Context,
{
    async fn wait(&self, dur: Duration) -> Result<(), FloxideError> {
        tokio::time::sleep(dur).await;
        Ok(())
    }
}
/// How the delay between retry attempts grows with the attempt number.
///
/// Consumed by `RetryPolicy::backoff_duration`, which also applies the
/// policy's `max_backoff` cap to the value described here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BackoffStrategy {
    /// Delay is `initial_backoff * attempt`.
    Linear,
    /// Delay is `initial_backoff * 2^(attempt - 1)`.
    Exponential,
}
/// Selects which kinds of [`FloxideError`] a retry policy treats as retryable.
///
/// Despite the name this is a filter, not an error value; it is matched
/// against the actual error in `RetryPolicy::should_retry`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RetryError {
    /// Retry on every error.
    All,
    /// Retry only on `FloxideError::Cancelled`.
    Cancelled,
    /// Retry only on `FloxideError::Timeout(_)`.
    Timeout,
    /// Retry only on `FloxideError::Generic(_)`.
    Generic,
}
/// Configuration describing when and how a failed node is retried.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Attempt-number limit: `should_retry` refuses once `attempt >= max_attempts`.
    pub max_attempts: usize,
    /// Base delay that the backoff strategy scales.
    pub initial_backoff: Duration,
    /// Upper bound applied to the strategy-scaled delay (before jitter).
    pub max_backoff: Duration,
    /// Growth curve used to scale `initial_backoff`.
    pub strategy: BackoffStrategy,
    /// Optional fixed extra delay added after the `max_backoff` cap is applied.
    pub jitter: Option<Duration>,
    /// Which error kinds are eligible for a retry.
    pub retry_error: RetryError,
}
impl RetryPolicy {
    /// Creates a policy with no jitter.
    ///
    /// * `max_attempts` - attempt-number limit checked by [`Self::should_retry`].
    /// * `initial_backoff` - base delay scaled by `strategy`.
    /// * `max_backoff` - cap applied to the scaled delay.
    /// * `strategy` - how the delay grows with the attempt number.
    /// * `retry_error` - which error kinds are retryable.
    pub fn new(
        max_attempts: usize,
        initial_backoff: Duration,
        max_backoff: Duration,
        strategy: BackoffStrategy,
        retry_error: RetryError,
    ) -> Self {
        Self {
            max_attempts,
            initial_backoff,
            max_backoff,
            strategy,
            jitter: None,
            retry_error,
        }
    }

    /// Returns the policy with a fixed `jitter` added to every backoff.
    pub fn with_jitter(mut self, jitter: Duration) -> Self {
        self.jitter = Some(jitter);
        self
    }

    /// Decides whether another attempt should follow a failure.
    ///
    /// Returns `false` once `attempt >= max_attempts`; otherwise returns
    /// `true` only if `error` matches the configured `retry_error` filter.
    pub fn should_retry(&self, error: &FloxideError, attempt: usize) -> bool {
        if attempt >= self.max_attempts {
            return false;
        }
        match self.retry_error {
            RetryError::All => true,
            RetryError::Cancelled => matches!(error, FloxideError::Cancelled),
            RetryError::Timeout => matches!(error, FloxideError::Timeout(_)),
            RetryError::Generic => matches!(error, FloxideError::Generic(_)),
        }
    }

    /// Computes the delay to wait after the given attempt number has failed.
    ///
    /// * Linear: `initial_backoff * attempt`.
    /// * Exponential: `initial_backoff * 2^(attempt - 1)`; attempt `0` is
    ///   treated like attempt `1`.
    ///
    /// All arithmetic saturates. The scaled delay is capped at `max_backoff`,
    /// and the optional jitter is added *after* the cap, so the result can
    /// exceed `max_backoff` by up to the jitter amount.
    pub fn backoff_duration(&self, attempt: usize) -> Duration {
        let factor: u32 = match self.strategy {
            // Clamp before narrowing: a plain `attempt as u32` wraps on 64-bit
            // targets (e.g. 2^32 becomes 0), which would produce a zero delay.
            // After the clamp the cast is lossless.
            BackoffStrategy::Linear => attempt.min(u32::MAX as usize) as u32,
            BackoffStrategy::Exponential => {
                let exp = attempt.saturating_sub(1);
                // Shifting a u32 by 32 or more would overflow; saturate instead.
                if exp < 32 {
                    1_u32 << exp
                } else {
                    u32::MAX
                }
            }
        };
        let capped = self
            .initial_backoff
            .saturating_mul(factor)
            .min(self.max_backoff);
        match self.jitter {
            Some(extra) => capped.saturating_add(extra),
            None => capped,
        }
    }
}
/// Wraps `node` so that it is retried according to `policy`.
///
/// Convenience function equivalent to calling `RetryNode::new(node, policy)`.
pub fn with_retry<N>(node: N, policy: RetryPolicy) -> RetryNode<N> {
    RetryNode::new(node, policy)
}
/// A node wrapper that pairs an inner node with the policy used to retry it.
#[derive(Debug, Clone)]
pub struct RetryNode<N> {
    /// The wrapped node whose `process` call is retried.
    pub inner: N,
    /// The retry policy consulted after each failure.
    pub policy: RetryPolicy,
}
impl<N> RetryNode<N> {
pub fn new(inner: N, policy: RetryPolicy) -> Self {
RetryNode { inner, policy }
}
}
#[async_trait]
impl<C, N> Node<C> for RetryNode<N>
where
C: Context + RetryDelay,
N: Node<C> + Clone + Send + Sync + 'static,
N::Input: Clone + Send + 'static,
N::Output: Send + 'static,
{
type Input = N::Input;
type Output = N::Output;
async fn process(
&self,
ctx: &C,
input: Self::Input,
) -> Result<Transition<Self::Output>, FloxideError> {
let mut attempt = 1;
loop {
match self.inner.process(ctx, input.clone()).await {
Ok(Transition::NextAll(vs)) => return Ok(Transition::NextAll(vs)),
Ok(Transition::Next(out)) => return Ok(Transition::Next(out)),
Ok(Transition::Hold) => return Ok(Transition::Hold),
Ok(Transition::Abort(e)) | Err(e) => {
tracing::debug!(attempt, error=%e, "RetryNode: caught error, evaluating retry policy");
if self.policy.should_retry(&e, attempt) {
let backoff = self.policy.backoff_duration(attempt);
tracing::debug!(attempt, backoff=?backoff, "RetryNode: retrying after backoff");
ctx.wait(backoff).await?;
attempt += 1;
continue;
} else {
tracing::warn!(attempt, error=%e, "RetryNode: aborting after reaching retry limit or non-retryable error");
return Err(e);
}
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Shorthand for a millisecond `Duration` in the assertions below.
    fn ms(n: u64) -> Duration {
        Duration::from_millis(n)
    }

    /// Linear backoff scales with the attempt number and is capped at `max_backoff`.
    #[test]
    fn test_linear_backoff() {
        let policy = RetryPolicy::new(5, ms(100), ms(1000), BackoffStrategy::Linear, RetryError::All);
        for (attempt, expected) in [(1, 100), (3, 300), (20, 1000)] {
            assert_eq!(policy.backoff_duration(attempt), ms(expected));
        }
    }

    /// Exponential backoff doubles per attempt until it reaches `max_backoff`.
    #[test]
    fn test_exponential_backoff() {
        let policy = RetryPolicy::new(
            5,
            ms(50),
            ms(400),
            BackoffStrategy::Exponential,
            RetryError::All,
        );
        for (attempt, expected) in [(1, 50), (2, 100), (3, 200), (4, 400), (5, 400)] {
            assert_eq!(policy.backoff_duration(attempt), ms(expected));
        }
    }

    /// Jitter is added on top of the computed backoff.
    #[test]
    fn test_jitter_addition() {
        let policy = RetryPolicy::new(3, ms(100), ms(1000), BackoffStrategy::Linear, RetryError::All)
            .with_jitter(ms(25));
        assert_eq!(policy.backoff_duration(2), ms(100 * 2 + 25));
    }

    /// `should_retry` honours both the error filter and the attempt limit.
    #[test]
    fn test_retry_predicates() {
        let generic_only =
            RetryPolicy::new(3, ms(10), ms(100), BackoffStrategy::Linear, RetryError::Generic);
        let gen_err = FloxideError::Generic("oops".into());
        let cancel_err = FloxideError::Cancelled;
        let timeout_err = FloxideError::Timeout(Duration::from_secs(1));

        // Only the generic error passes the filter, and only below the limit.
        assert!(generic_only.should_retry(&gen_err, 1));
        assert!(!generic_only.should_retry(&cancel_err, 1));
        assert!(!generic_only.should_retry(&timeout_err, 1));
        assert!(!generic_only.should_retry(&gen_err, 3));

        // Same policy, but accepting every error kind.
        let retry_all = RetryPolicy {
            retry_error: RetryError::All,
            ..generic_only
        };
        assert!(retry_all.should_retry(&cancel_err, 2));
        assert!(retry_all.should_retry(&timeout_err, 2));
    }
}