use serde::{de::DeserializeOwned, Deserialize, Serialize};
/// Value a step reports after its forward action ran without error.
///
/// This is the success type of `Step::execute`. The code that consumes it is
/// not part of this file, so the per-variant notes below are inferred from the
/// variant names and should be confirmed against the executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum StepOutcome {
    /// Presumably: the step is done and the caller may move on.
    Continue,
    /// Presumably: the caller should suspend processing after this step.
    Pause,
}
/// Value a step reports after its compensation action ran without error.
///
/// This is the success type of `Step::compensate`. The code that consumes it
/// is not part of this file, so the per-variant notes below are inferred from
/// the variant names and should be confirmed against the executor.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum CompensationOutcome {
    /// Presumably: compensation for this step has finished.
    Completed,
    /// Presumably: the caller should suspend compensation after this step.
    Pause,
}
/// Describes whether, and how often, a failed step may be attempted again.
///
/// The default policy is [`RetryPolicy::NoRetry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
pub enum RetryPolicy {
    /// Never retry; `should_retry` is always `false` and the backoff is `0`.
    #[default]
    NoRetry,
    /// Retry while the attempt number stays below `max_attempts`.
    Retry {
        /// Exclusive upper bound checked by `should_retry`
        /// (`attempt < max_attempts`).
        max_attempts: u8,
        /// Delay value returned by `backoff_ms`; judging by the name it is in
        /// milliseconds. How it is applied between attempts is decided by the
        /// caller, which is not shown in this file.
        backoff_ms: u64,
    },
}
impl RetryPolicy {
    /// Creates a [`RetryPolicy::Retry`] with the given attempt bound and a
    /// backoff of `0`.
    #[must_use]
    pub const fn retries(max_attempts: u8) -> Self {
        Self::retries_with_backoff(max_attempts, 0)
    }

    /// Creates a [`RetryPolicy::Retry`] with the given attempt bound and
    /// backoff value.
    #[must_use]
    pub const fn retries_with_backoff(max_attempts: u8, backoff_ms: u64) -> Self {
        Self::Retry {
            max_attempts,
            backoff_ms,
        }
    }

    /// Reports whether another attempt is allowed for attempt number `attempt`.
    ///
    /// Returns `false` for [`RetryPolicy::NoRetry`]. For
    /// [`RetryPolicy::Retry`] the comparison is strict: the result is `true`
    /// only while `attempt < max_attempts`, so `retries(0)` never retries.
    ///
    /// This is a `const fn`, like the constructors, so a policy built in a
    /// const context can also be queried there.
    #[must_use]
    pub const fn should_retry(&self, attempt: u8) -> bool {
        match *self {
            Self::NoRetry => false,
            Self::Retry { max_attempts, .. } => attempt < max_attempts,
        }
    }

    /// Returns the configured backoff value, or `0` for
    /// [`RetryPolicy::NoRetry`].
    #[must_use]
    pub const fn backoff_ms(&self) -> u64 {
        match *self {
            Self::NoRetry => 0,
            Self::Retry { backoff_ms, .. } => backoff_ms,
        }
    }
}
/// A unit of work with a forward action and a compensating action.
///
/// None of the functions take a `self` receiver, so an implementor acts as a
/// type-level marker: callers name the type and invoke `S::execute(..)` /
/// `S::compensate(..)` directly, passing the context and the step's input.
///
/// * `Ctx` is the mutable context handed to both actions.
/// * `Err` is the error type both actions may return.
#[async_trait::async_trait]
pub trait Step<Ctx, Err>: Send + Sync + 'static
where
    Ctx: Send + Sync,
    Err: Send + Sync,
{
    /// Data the step operates on; it must be cloneable, serializable and
    /// deserializable into an owned value.
    type Input: Clone + Serialize + DeserializeOwned + Send + Sync + 'static;

    /// Forward action: receives the context and the input and yields a
    /// [`StepOutcome`] or an `Err`.
    async fn execute(ctx: &mut Ctx, input: &Self::Input) -> Result<StepOutcome, Err>;

    /// Compensating action: receives the same context and input and yields a
    /// [`CompensationOutcome`] or an `Err`. Presumably intended to undo what
    /// `execute` did; confirm against the code that drives the steps.
    async fn compensate(ctx: &mut Ctx, input: &Self::Input) -> Result<CompensationOutcome, Err>;

    /// Retry policy for this step type; [`RetryPolicy::NoRetry`] unless an
    /// implementor overrides it.
    fn retry_policy() -> RetryPolicy {
        RetryPolicy::NoRetry
    }
}
/// Pairs the input of a step with the step type `S` that knows how to process
/// it.
///
/// Only `input` is stored; `S`, `Ctx` and `Err` are recorded purely at the
/// type level through `PhantomData`, and that marker field is skipped by
/// serde, so the serialized form contains just `input`.
///
/// `Clone` and `Debug` are implemented by hand rather than derived. The
/// derives would add `Clone`/`Debug` bounds on `S`, `Ctx` and `Err` even
/// though no value of those types is stored. That would make the wrapper
/// non-cloneable and non-printable for typical context and error types.
#[derive(Serialize, Deserialize)]
pub struct StepWrapper<S: Step<Ctx, Err>, Ctx, Err>
where
    Ctx: Send + Sync,
    Err: Send + Sync,
{
    /// The step's input value.
    input: S::Input,
    /// Zero-sized marker tying the wrapper to `S`, `Ctx` and `Err`.
    #[serde(skip)]
    _marker: std::marker::PhantomData<(S, Ctx, Err)>,
}

impl<S, Ctx, Err> Clone for StepWrapper<S, Ctx, Err>
where
    S: Step<Ctx, Err>,
    Ctx: Send + Sync,
    Err: Send + Sync,
{
    /// Clones the stored input; `S::Input: Clone` is guaranteed by [`Step`].
    fn clone(&self) -> Self {
        Self {
            input: self.input.clone(),
            _marker: std::marker::PhantomData,
        }
    }
}

impl<S, Ctx, Err> std::fmt::Debug for StepWrapper<S, Ctx, Err>
where
    S: Step<Ctx, Err>,
    S::Input: std::fmt::Debug,
    Ctx: Send + Sync,
    Err: Send + Sync,
{
    /// Formats both fields, giving the same shape as the derived output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StepWrapper")
            .field("input", &self.input)
            .field("_marker", &self._marker)
            .finish()
    }
}
impl<S, Ctx, Err> StepWrapper<S, Ctx, Err>
where
    S: Step<Ctx, Err>,
    Ctx: Send + Sync,
    Err: Send + Sync,
{
    /// Wraps `input` for the step type `S`.
    pub fn new(input: S::Input) -> Self {
        let _marker = std::marker::PhantomData;
        Self { input, _marker }
    }

    /// Borrows the wrapped input.
    pub fn input(&self) -> &S::Input {
        &self.input
    }

    /// Calls `S::execute` with `ctx` and the wrapped input and returns its
    /// result unchanged.
    pub async fn execute(&self, ctx: &mut Ctx) -> Result<StepOutcome, Err> {
        <S as Step<Ctx, Err>>::execute(ctx, &self.input).await
    }

    /// Calls `S::compensate` with `ctx` and the wrapped input and returns its
    /// result unchanged.
    pub async fn compensate(&self, ctx: &mut Ctx) -> Result<CompensationOutcome, Err> {
        <S as Step<Ctx, Err>>::compensate(ctx, &self.input).await
    }

    /// Returns the retry policy declared by the step type `S`.
    pub fn retry_policy(&self) -> RetryPolicy {
        <S as Step<Ctx, Err>>::retry_policy()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `NoRetry` rejects every attempt; `retries(3)` accepts attempts 0, 1 and
    /// 2 and rejects 3 and above.
    #[test]
    fn retry_policy_should_retry() {
        let never = RetryPolicy::NoRetry;
        for &attempt in &[0u8, 1] {
            assert!(!never.should_retry(attempt));
        }

        let up_to_three = RetryPolicy::retries(3);
        let expectations = [(0u8, true), (1, true), (2, true), (3, false), (4, false)];
        for &(attempt, allowed) in &expectations {
            assert_eq!(up_to_three.should_retry(attempt), allowed);
        }
    }

    /// `NoRetry` reports a zero backoff; a configured backoff is returned
    /// as given.
    #[test]
    fn retry_policy_backoff() {
        let cases = [
            (RetryPolicy::NoRetry, 0u64),
            (RetryPolicy::retries_with_backoff(3, 100), 100),
        ];
        for &(policy, expected_ms) in &cases {
            assert_eq!(policy.backoff_ms(), expected_ms);
        }
    }
}