use serde::{Deserialize, Serialize};
use std::marker::PhantomData;
use crate::hlist::{BlockResult, InstructionList};
use crate::store::now_millis;
/// Typestate marker for an [`Execution`] that has been built but not yet started.
#[derive(Clone, Copy, Debug, Default)]
pub struct New;

/// Typestate marker for an [`Execution`] whose step list reported a pause; it can be
/// continued with `resume`.
#[derive(Clone, Copy, Debug)]
pub struct Paused;

/// Typestate marker for an [`Execution`] whose step list ran to completion.
#[derive(Clone, Copy, Debug)]
pub struct Completed;

/// Typestate marker for an [`Execution`] that ended in failure, including the case
/// where compensation itself failed.
#[derive(Clone, Copy, Debug)]
pub struct Failed;
/// Which kind of step action is currently being run.
///
/// A fresh [`ExecutionState`] starts in [`ExecutionPhase::Forward`] and is switched to
/// [`ExecutionPhase::Compensating`] by [`ExecutionState::begin_compensation`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionPhase {
    /// Running the steps' forward actions (the default phase).
    #[default]
    Forward,
    /// Running compensation actions.
    Compensating,
}
/// Outcome recorded for a single step run or compensation run (see [`StepTiming::outcome`]).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum StepStatus {
    /// The step finished and execution may move on.
    Continue,
    /// The step requested that execution pause.
    Pause,
    /// The step's forward action failed.
    Failed,
    /// The step's compensation action succeeded.
    Compensated,
    /// The step's compensation action failed.
    CompensationFailed,
}
/// Wall-clock timing and outcome for one run of a step (forward or compensation).
///
/// Timestamps come from `now_millis` (presumably milliseconds since the Unix epoch —
/// confirm in `crate::store`).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct StepTiming {
    /// Index of the step this record belongs to.
    pub step_index: usize,
    /// `true` when this record is for the step's compensation action.
    pub is_compensation: bool,
    /// Time at which the record was created.
    pub started_at: u64,
    /// Time at which [`StepTiming::complete`] was called, or `None` while still running.
    pub completed_at: Option<u64>,
    /// Outcome passed to [`StepTiming::complete`], or `None` while still running.
    pub outcome: Option<StepStatus>,
}
impl StepTiming {
    /// Opens a timing record for `step_index`, stamping `started_at` with the current time.
    ///
    /// The record starts incomplete: `completed_at` and `outcome` are both `None`.
    pub fn new(step_index: usize, is_compensation: bool) -> Self {
        let started_at = now_millis();
        Self {
            step_index,
            is_compensation,
            started_at,
            completed_at: None,
            outcome: None,
        }
    }

    /// Closes the record with `outcome`, stamping `completed_at` with the current time.
    ///
    /// Calling this again overwrites both the previous outcome and completion time.
    pub fn complete(&mut self, outcome: StepStatus) {
        let finished_at = now_millis();
        self.outcome = Some(outcome);
        self.completed_at = Some(finished_at);
    }

    /// Elapsed time between start and completion, or `None` if not yet completed.
    ///
    /// Uses saturating subtraction so a clock that stepped backwards yields `0`
    /// rather than underflowing.
    pub fn duration_ms(&self) -> Option<u64> {
        let finished_at = self.completed_at?;
        Some(finished_at.saturating_sub(self.started_at))
    }
}
/// Serializable progress of an [`Execution`]: position, retry counter, phase and timings.
///
/// `Default` yields index 0, nothing executed, no retries, [`ExecutionPhase::Forward`],
/// no start time and no timing records.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct ExecutionState {
    // Index of the step currently being worked on.
    index: usize,
    // High-water mark maintained by `mark_executed`; never decreases.
    executed: usize,
    // Retries of the current step; reset to 0 by `advance`.
    retry_count: u8,
    // Forward or compensating.
    phase: ExecutionPhase,
    // Set once by `mark_started`, on the first run only.
    started_at: Option<u64>,
    // One record per `record_step_start` call, in call order.
    step_timings: Vec<StepTiming>,
}
impl ExecutionState {
    /// Creates a fresh state: index 0, nothing executed, no retries, forward phase,
    /// not started, and no timing records.
    pub fn new() -> Self {
        Self::default()
    }

    /// Index of the step currently being worked on.
    pub fn current_index(&self) -> usize {
        self.index
    }

    /// High-water mark recorded by [`ExecutionState::mark_executed`].
    pub fn executed_count(&self) -> usize {
        self.executed
    }

    /// Number of retries recorded for the current step since the last `advance`.
    pub fn retry_count(&self) -> u8 {
        self.retry_count
    }

    /// Current execution phase.
    pub fn phase(&self) -> ExecutionPhase {
        self.phase
    }

    /// Moves to the next step and resets the retry counter.
    pub fn advance(&mut self) {
        self.index += 1;
        self.retry_count = 0;
    }

    /// Raises the executed count to the current index if it is lower; it never decreases.
    // NOTE(review): whether this yields "number of completed steps" depends on callers
    // invoking it after `advance` — confirm against `crate::hlist`.
    pub fn mark_executed(&mut self) {
        self.executed = self.executed.max(self.index);
    }

    /// Records one more retry of the current step.
    ///
    /// Saturates at `u8::MAX` instead of overflowing. A plain `+= 1` panics in debug builds
    /// and wraps to `0` in release builds, which would make a heavily retried step look
    /// as though it had never been retried.
    pub fn increment_retry(&mut self) {
        self.retry_count = self.retry_count.saturating_add(1);
    }

    /// Switches the phase to [`ExecutionPhase::Compensating`].
    ///
    /// No method here switches the phase back to forward.
    pub fn begin_compensation(&mut self) {
        self.phase = ExecutionPhase::Compensating;
    }

    /// `true` once [`ExecutionState::begin_compensation`] has been called.
    pub fn is_compensating(&self) -> bool {
        self.phase == ExecutionPhase::Compensating
    }

    /// Stamps the start time on the first call only, so resumed runs keep the original time.
    pub fn mark_started(&mut self) {
        if self.started_at.is_none() {
            self.started_at = Some(now_millis());
        }
    }

    /// Time of the first run, or `None` if the execution was never started.
    pub fn started_at(&self) -> Option<u64> {
        self.started_at
    }

    /// Appends a new, incomplete timing record for `step_index`.
    pub fn record_step_start(&mut self, step_index: usize, is_compensation: bool) {
        self.step_timings
            .push(StepTiming::new(step_index, is_compensation));
    }

    /// Completes the most recently started timing record with `outcome`.
    ///
    /// Does nothing if no record exists. It does not check whether the last record
    /// was already completed; if it was, the earlier outcome is overwritten.
    pub fn record_step_end(&mut self, outcome: StepStatus) {
        if let Some(timing) = self.step_timings.last_mut() {
            timing.complete(outcome);
        }
    }

    /// All timing records, in the order they were started.
    pub fn step_timings(&self) -> &[StepTiming] {
        &self.step_timings
    }

    /// Returns the earliest timing record for `step_index` in the given phase.
    ///
    /// If `record_step_start` was called more than once for the same step and phase,
    /// later records are not returned; use [`ExecutionState::step_timings`] to see them.
    pub fn timing_for_step(&self, step_index: usize, is_compensation: bool) -> Option<&StepTiming> {
        self.step_timings
            .iter()
            .find(|t| t.step_index == step_index && t.is_compensation == is_compensation)
    }
}
/// Outcome of running an [`Execution`], carrying the execution re-tagged with the
/// typestate that matches what happened.
pub enum ExecutionResult<Ctx, Err, Steps>
where
    Steps: InstructionList<Ctx, Err>,
    Ctx: Send + Sync,
    Err: Send + Sync + Clone,
{
    /// The step list paused; the execution can be resumed.
    Paused(Execution<Ctx, Err, Steps, Paused>),
    /// The step list ran to completion.
    Completed(Execution<Ctx, Err, Steps, Completed>),
    /// The step list failed with the given error.
    Failed(Execution<Ctx, Err, Steps, Failed>, Err),
    /// The step list failed, and compensation then failed as well.
    CompensationFailed {
        /// The execution in its failed state.
        execution: Execution<Ctx, Err, Steps, Failed>,
        /// The error that triggered compensation.
        original_error: Err,
        /// The error raised while compensating.
        compensation_error: Err,
        /// Step position reported by the step list (presumably the index whose
        /// compensation failed — confirm against `crate::hlist`).
        failed_at: usize,
    },
}
impl<Ctx, Err, Steps> ExecutionResult<Ctx, Err, Steps>
where
    Steps: InstructionList<Ctx, Err>,
    Ctx: Send + Sync,
    Err: Send + Sync + Clone,
{
    /// `true` for [`ExecutionResult::Paused`].
    pub fn is_paused(&self) -> bool {
        // Exhaustive on purpose: adding a variant forces this to be revisited.
        match self {
            Self::Paused(_) => true,
            Self::Completed(_) | Self::Failed(..) | Self::CompensationFailed { .. } => false,
        }
    }

    /// `true` for [`ExecutionResult::Completed`].
    pub fn is_completed(&self) -> bool {
        match self {
            Self::Completed(_) => true,
            Self::Paused(_) | Self::Failed(..) | Self::CompensationFailed { .. } => false,
        }
    }

    /// `true` for both [`ExecutionResult::Failed`] and
    /// [`ExecutionResult::CompensationFailed`].
    pub fn is_failed(&self) -> bool {
        match self {
            Self::Failed(..) | Self::CompensationFailed { .. } => true,
            Self::Paused(_) | Self::Completed(_) => false,
        }
    }
}
/// A run of a step list over a context, tagged at the type level with its lifecycle
/// state (`New`, `Paused`, `Completed` or `Failed`).
///
/// Serialization covers the steps, the context and the [`ExecutionState`]. The typestate
/// marker and `Err` exist only in a skipped `PhantomData`, so:
/// - `Err` needs no serde impls, and the bounds below no longer demand them;
/// - `Serialize` only needs `Steps`/`Ctx: Serialize`, and `Deserialize` only needs
///   `Steps`/`Ctx: DeserializeOwned` (previously each impl required both);
/// - deserialization does not check the marker. A value takes whatever `State` the
///   caller names, so callers should deserialize into the state that was serialized
///   (presumably `Paused` for persisted executions — confirm with callers).
#[derive(Serialize, Deserialize)]
#[serde(bound(
    serialize = "Steps: Serialize, Ctx: Serialize",
    deserialize = "Steps: serde::de::DeserializeOwned, Ctx: serde::de::DeserializeOwned"
))]
pub struct Execution<Ctx, Err, Steps, State = New>
where
    Steps: InstructionList<Ctx, Err>,
    Ctx: Send + Sync,
    Err: Send + Sync + Clone,
{
    steps: Steps,
    ctx: Ctx,
    state: ExecutionState,
    // Not serialized; restored via `PhantomData::default()` on deserialize.
    #[serde(skip)]
    _marker: PhantomData<(Err, State)>,
}
impl<Ctx, Err, Steps> Execution<Ctx, Err, Steps, New>
where
Steps: InstructionList<Ctx, Err>,
Ctx: Send + Sync,
Err: Send + Sync + Clone,
{
pub fn new(steps: Steps, ctx: Ctx) -> Self {
Self {
steps,
ctx,
state: ExecutionState::new(),
_marker: PhantomData,
}
}
pub async fn start(self) -> ExecutionResult<Ctx, Err, Steps> {
self.run_internal().await
}
}
impl<Ctx, Err, Steps> Execution<Ctx, Err, Steps, Paused>
where
    Steps: InstructionList<Ctx, Err>,
    Ctx: Send + Sync,
    Err: Send + Sync + Clone,
{
    /// Consumes the paused execution and continues running from its saved state.
    ///
    /// The original start time is kept, because `mark_started` only stamps it once.
    pub async fn resume(self) -> ExecutionResult<Ctx, Err, Steps> {
        Self::run_internal(self).await
    }

    /// Borrows the context as it was when the execution paused.
    pub fn context(&self) -> &Ctx {
        let Self { ctx, .. } = self;
        ctx
    }
}
impl<Ctx, Err, Steps> Execution<Ctx, Err, Steps, Completed>
where
    Steps: InstructionList<Ctx, Err>,
    Ctx: Send + Sync,
    Err: Send + Sync + Clone,
{
    /// Borrows the final context of the completed execution.
    pub fn context(&self) -> &Ctx {
        let Self { ctx, .. } = self;
        ctx
    }

    /// Consumes the execution and returns its final context, dropping the steps and state.
    pub fn into_context(self) -> Ctx {
        let Self { ctx, .. } = self;
        ctx
    }
}
impl<Ctx, Err, Steps> Execution<Ctx, Err, Steps, Failed>
where
    Steps: InstructionList<Ctx, Err>,
    Ctx: Send + Sync,
    Err: Send + Sync + Clone,
{
    /// Borrows the context as it was when the execution failed.
    pub fn context(&self) -> &Ctx {
        let Self { ctx, .. } = self;
        ctx
    }

    /// Consumes the failed execution and returns its context, dropping the steps and state.
    pub fn into_context(self) -> Ctx {
        let Self { ctx, .. } = self;
        ctx
    }
}
impl<Ctx, Err, Steps, State> Execution<Ctx, Err, Steps, State>
where
Steps: InstructionList<Ctx, Err>,
Ctx: Send + Sync,
Err: Send + Sync + Clone,
{
async fn run_internal(mut self) -> ExecutionResult<Ctx, Err, Steps> {
self.state.mark_started();
match self.steps.execute_all(&mut self.ctx, &mut self.state).await {
BlockResult::Completed => ExecutionResult::Completed(Execution {
steps: self.steps,
ctx: self.ctx,
state: self.state,
_marker: PhantomData,
}),
BlockResult::Paused => ExecutionResult::Paused(Execution {
steps: self.steps,
ctx: self.ctx,
state: self.state,
_marker: PhantomData,
}),
BlockResult::Failed(e) => ExecutionResult::Failed(
Execution {
steps: self.steps,
ctx: self.ctx,
state: self.state,
_marker: PhantomData,
},
e,
),
BlockResult::CompensationFailed {
original_error,
compensation_error,
failed_at,
} => ExecutionResult::CompensationFailed {
execution: Execution {
steps: self.steps,
ctx: self.ctx,
state: self.state,
_marker: PhantomData,
},
original_error,
compensation_error,
failed_at,
},
}
}
}