use std::{collections::HashMap, sync::Arc};
use serde::{Deserialize, Serialize};
use crate::{legend, CompensationOutcome, RetryPolicy, Step, StepOutcome};
#[derive(thiserror::Error, Clone, Debug, Serialize, Deserialize, PartialEq)]
pub enum MathError {
#[error("Overflow")]
Overflow,
#[error("Underflow")]
Underflow,
#[error("Transient")]
Transient,
}
/// Mutable state shared by every math test step.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct MathContext {
    /// Named results; `Add`/`Sub` write computed values here and `Flaky` writes `1`.
    pub r: HashMap<Arc<str>, u8>,
    /// Undo log used by `Add`/`Sub`: for each name, the value it held before its first
    /// logged write (`None` when the name was absent). `Halt` clears it.
    pub to_rollback: HashMap<Arc<str>, Option<u8>>,
    /// Countdown read by `Add`: `Some(n)` with `n > 0` is decremented per `Add`; once it
    /// is `Some(0)`, `Add` returns `StepOutcome::Pause`. `None` disables pausing.
    pub pause_after: Option<usize>,
    /// How many more times `Flaky` will fail with `MathError::Transient` before succeeding.
    pub fail_count: usize,
}
/// Step that stores `a + b` under `name` in `MathContext::r`.
///
/// Input is `(name, a, b)`.
pub struct Add;

#[async_trait::async_trait]
impl Step<MathContext, MathError> for Add {
    type Input = (Arc<str>, u8, u8);

    /// Writes `a + b` into `ctx.r[name]` and logs the prior value for compensation.
    ///
    /// # Errors
    /// Returns `MathError::Overflow` if the `u8` sum overflows. `ctx` is not modified in
    /// that case.
    ///
    /// # Pausing
    /// If `ctx.pause_after` is `Some(0)`, the step returns `StepOutcome::Pause` *after* its
    /// write has been applied. If it is `Some(n)` with `n > 0`, the counter is decremented
    /// and the step continues. The counter is never reset here, so every later `Add` also
    /// pauses until the caller changes it.
    async fn execute(ctx: &mut MathContext, input: &Self::Input) -> Result<StepOutcome, MathError> {
        let (name, a, b) = input;
        let sum = a.checked_add(*b).ok_or(MathError::Overflow)?;
        let previous = ctx.r.insert(Arc::clone(name), sum);
        // Keep only the first logged value for `name`, so compensation restores what the
        // key held before any logged write rather than an intermediate result.
        ctx.to_rollback.entry(Arc::clone(name)).or_insert(previous);
        if let Some(count) = ctx.pause_after.as_mut() {
            if *count == 0 {
                return Ok(StepOutcome::Pause);
            }
            *count -= 1;
        }
        Ok(StepOutcome::Continue)
    }

    /// Undoes the logged write for `name`.
    ///
    /// If the logged value is `Some(prev)`, it is restored. If the logged value is `None`,
    /// the key is removed. If there is no log entry at all (already compensated, or cleared
    /// by `Halt`), `ctx.r` is left untouched. Always completes.
    async fn compensate(
        ctx: &mut MathContext,
        input: &Self::Input,
    ) -> Result<CompensationOutcome, MathError> {
        let (name, _, _) = input;
        match ctx.to_rollback.remove(name) {
            Some(Some(prev)) => {
                ctx.r.insert(Arc::clone(name), prev);
            }
            Some(None) => {
                ctx.r.remove(name);
            }
            None => {}
        }
        Ok(CompensationOutcome::Completed)
    }
}
/// Step that stores `a - b` under `name` in `MathContext::r`.
///
/// Input is `(name, a, b)`. Unlike `Add`, this step never consults `pause_after`.
pub struct Sub;

#[async_trait::async_trait]
impl Step<MathContext, MathError> for Sub {
    type Input = (Arc<str>, u8, u8);

    /// Writes `a - b` into `ctx.r[name]` and logs the prior value for compensation.
    ///
    /// # Errors
    /// Returns `MathError::Underflow` if `b > a`. `ctx` is not modified in that case.
    async fn execute(ctx: &mut MathContext, input: &Self::Input) -> Result<StepOutcome, MathError> {
        let (name, a, b) = input;
        let difference = a.checked_sub(*b).ok_or(MathError::Underflow)?;
        let previous = ctx.r.insert(Arc::clone(name), difference);
        // Keep only the first logged value for `name`, so compensation restores what the
        // key held before any logged write rather than an intermediate result.
        ctx.to_rollback.entry(Arc::clone(name)).or_insert(previous);
        Ok(StepOutcome::Continue)
    }

    /// Undoes the logged write for `name`.
    ///
    /// If the logged value is `Some(prev)`, it is restored. If the logged value is `None`,
    /// the key is removed. If there is no log entry at all (already compensated, or cleared
    /// by `Halt`), `ctx.r` is left untouched. Always completes.
    async fn compensate(
        ctx: &mut MathContext,
        input: &Self::Input,
    ) -> Result<CompensationOutcome, MathError> {
        let (name, _, _) = input;
        match ctx.to_rollback.remove(name) {
            Some(Some(prev)) => {
                ctx.r.insert(Arc::clone(name), prev);
            }
            Some(None) => {
                ctx.r.remove(name);
            }
            None => {}
        }
        Ok(CompensationOutcome::Completed)
    }
}
/// Step that wipes the `Add`/`Sub` undo log, so their later compensations leave
/// `MathContext::r` untouched.
pub struct Halt;

#[async_trait::async_trait]
impl Step<MathContext, MathError> for Halt {
    type Input = ();

    /// Clears `ctx.to_rollback` and always continues.
    async fn execute(ctx: &mut MathContext, _input: &Self::Input) -> Result<StepOutcome, MathError> {
        // With the log empty, `Add`/`Sub` compensation hits its "no entry" arm.
        let undo_log = &mut ctx.to_rollback;
        undo_log.clear();
        Ok(StepOutcome::Continue)
    }

    /// Nothing to undo; always completes.
    async fn compensate(
        _ctx: &mut MathContext,
        _input: &Self::Input,
    ) -> Result<CompensationOutcome, MathError> {
        let outcome = CompensationOutcome::Completed;
        Ok(outcome)
    }
}
/// Step that fails with `MathError::Transient` while `MathContext::fail_count` is
/// non-zero, then writes `1` under its input name.
///
/// It does not use `to_rollback`, so `Halt` has no effect on its compensation.
pub struct Flaky;

#[async_trait::async_trait]
impl Step<MathContext, MathError> for Flaky {
    type Input = Arc<str>;

    /// Consumes one scheduled failure per call until `fail_count` reaches zero. After that,
    /// it writes `1` under `input` and continues.
    async fn execute(ctx: &mut MathContext, input: &Self::Input) -> Result<StepOutcome, MathError> {
        match ctx.fail_count.checked_sub(1) {
            Some(remaining) => {
                ctx.fail_count = remaining;
                Err(MathError::Transient)
            }
            None => {
                ctx.r.insert(Arc::clone(input), 1);
                Ok(StepOutcome::Continue)
            }
        }
    }

    /// Removes `input` from `ctx.r` unconditionally. Any value held before `execute` is not
    /// restored. Always completes.
    async fn compensate(
        ctx: &mut MathContext,
        input: &Self::Input,
    ) -> Result<CompensationOutcome, MathError> {
        ctx.r.remove(input);
        Ok(CompensationOutcome::Completed)
    }

    /// Uses `RetryPolicy::retries(3)`.
    ///
    /// NOTE(review): presumably this means up to 3 retries after the first failure; confirm
    /// against `RetryPolicy`.
    fn retry_policy() -> RetryPolicy {
        RetryPolicy::retries(3)
    }
}
// Legend `Math` over `MathContext`/`MathError` with steps `add` (Add), `sub` (Sub) and
// `halt` (Halt).
// NOTE(review): what `legend!` generates (type names, step ordering, input shape) is
// defined by the macro elsewhere in the crate; confirm there.
legend! {
Math<MathContext, MathError> {
add: Add,
sub: Sub,
halt: Halt,
}
}
// Legend `MathWithFlaky` over `MathContext`/`MathError` with steps `add` (Add),
// `flaky` (Flaky, which carries its own retry policy) and `halt` (Halt).
// NOTE(review): generated items and step ordering come from the `legend!` macro; confirm.
legend! {
MathWithFlaky<MathContext, MathError> {
add: Add,
flaky: Flaky,
halt: Halt,
}
}
// Legend `SingleAdd` over `MathContext`/`MathError` containing only the `add` (Add) step.
legend! {
SingleAdd<MathContext, MathError> {
add: Add,
}
}
// Legend `SingleHalt` over `MathContext`/`MathError` containing only the `halt` (Halt) step.
legend! {
SingleHalt<MathContext, MathError> {
halt: Halt,
}
}