use crate::{WorkflowRetryPolicy, WorkflowState};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
/// Snapshot of a workflow's progress: which tasks finished, failed or were
/// still running, together with the workflow state at that moment.
///
/// Serializable via serde (see `WorkflowCheckpoint::to_json` / `from_json`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowCheckpoint {
    /// Identifier of the workflow this checkpoint belongs to.
    pub workflow_id: Uuid,
    /// Creation time in whole seconds since the Unix epoch.
    pub timestamp: u64,
    /// Tasks recorded as completed.
    pub completed_tasks: Vec<Uuid>,
    /// Tasks recorded as failed, each paired with its error message.
    pub failed_tasks: Vec<(Uuid, String)>,
    /// Tasks started but not yet recorded as completed or failed.
    pub in_progress_tasks: Vec<Uuid>,
    /// Workflow state captured with this checkpoint.
    pub state: WorkflowState,
    /// Version number; `WorkflowCheckpoint::new` initialises it to 1.
    /// NOTE(review): presumably the checkpoint format version — confirm.
    pub version: u32,
}
impl WorkflowCheckpoint {
    /// Creates an empty checkpoint for `workflow_id` holding `state`.
    ///
    /// The checkpoint is stamped with the current wall-clock time in whole
    /// seconds since the Unix epoch, starts at `version` 1 and has no tasks
    /// recorded.
    pub fn new(workflow_id: Uuid, state: WorkflowState) -> Self {
        Self {
            workflow_id,
            // A system clock set before the Unix epoch yields 0 instead of panicking.
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
            completed_tasks: Vec::new(),
            failed_tasks: Vec::new(),
            in_progress_tasks: Vec::new(),
            state,
            version: 1,
        }
    }

    /// Records `task_id` as completed and drops it from the in-progress list.
    ///
    /// Recording the same task more than once keeps a single entry, matching
    /// the duplicate guard in [`record_in_progress`](Self::record_in_progress),
    /// so the completed count is not inflated. Failures previously recorded
    /// for the task are left in `failed_tasks`.
    pub fn record_completed(&mut self, task_id: Uuid) {
        if !self.completed_tasks.contains(&task_id) {
            self.completed_tasks.push(task_id);
        }
        self.in_progress_tasks.retain(|&id| id != task_id);
    }

    /// Records a failure of `task_id` with its `error` message and drops the
    /// task from the in-progress list.
    ///
    /// Every call appends an entry, so repeated failures of one task are all kept.
    pub fn record_failed(&mut self, task_id: Uuid, error: String) {
        self.failed_tasks.push((task_id, error));
        self.in_progress_tasks.retain(|&id| id != task_id);
    }

    /// Marks `task_id` as in progress; does nothing if it already is.
    pub fn record_in_progress(&mut self, task_id: Uuid) {
        if !self.in_progress_tasks.contains(&task_id) {
            self.in_progress_tasks.push(task_id);
        }
    }

    /// Returns `true` if `task_id` has been recorded as completed.
    pub fn is_completed(&self, task_id: &Uuid) -> bool {
        self.completed_tasks.contains(task_id)
    }

    /// Returns `true` if at least one failure has been recorded for `task_id`.
    pub fn is_failed(&self, task_id: &Uuid) -> bool {
        self.failed_tasks.iter().any(|(id, _)| id == task_id)
    }

    /// Tasks that were started but not yet recorded as completed or failed
    /// when this checkpoint was taken.
    pub fn tasks_to_retry(&self) -> &[Uuid] {
        &self.in_progress_tasks
    }

    /// Serializes the checkpoint to a JSON string.
    ///
    /// # Errors
    /// Propagates any error reported by `serde_json::to_string`.
    pub fn to_json(&self) -> Result<String, serde_json::Error> {
        serde_json::to_string(self)
    }

    /// Parses a checkpoint from JSON such as that produced by
    /// [`to_json`](Self::to_json).
    ///
    /// # Errors
    /// Returns the `serde_json` error if `json` is malformed or does not match
    /// the checkpoint's structure.
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str(json)
    }
}
/// One-line summary: the workflow id plus the size of each task list.
impl std::fmt::Display for WorkflowCheckpoint {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            workflow_id,
            completed_tasks,
            failed_tasks,
            in_progress_tasks,
            ..
        } = self;
        let (done, failed, running) = (
            completed_tasks.len(),
            failed_tasks.len(),
            in_progress_tasks.len(),
        );
        write!(
            f,
            "WorkflowCheckpoint[id={}, completed={}, failed={}, in_progress={}]",
            workflow_id, done, failed, running
        )
    }
}
/// Settings that govern how an interrupted workflow is recovered.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowRecoveryPolicy {
    /// `true` for `auto_recover()`, `false` for `manual()`.
    /// NOTE(review): presumably enables automatic recovery — confirm at call sites.
    pub auto_recovery: bool,
    /// Set by both constructors; presumably resume from a saved `WorkflowCheckpoint`.
    pub resume_from_checkpoint: bool,
    /// `true` for `auto_recover()`; presumably re-runs tasks recorded as failed.
    pub replay_failed: bool,
    /// Maximum checkpoint age in seconds accepted by `is_checkpoint_valid`;
    /// `None` disables the age check.
    pub max_checkpoint_age: Option<u64>,
    /// Optional retry policy attached via `with_retry_policy`.
    pub retry_policy: Option<WorkflowRetryPolicy>,
}
impl WorkflowRecoveryPolicy {
    /// Policy with `auto_recovery`, `resume_from_checkpoint` and
    /// `replay_failed` all enabled, accepting checkpoints up to one hour
    /// (3600 s) old and carrying no retry policy.
    pub fn auto_recover() -> Self {
        Self {
            auto_recovery: true,
            resume_from_checkpoint: true,
            replay_failed: true,
            max_checkpoint_age: Some(3600),
            retry_policy: None,
        }
    }

    /// Policy with `auto_recovery` disabled. It resumes from a checkpoint but
    /// does not replay failed tasks, and it accepts checkpoints of any age.
    pub fn manual() -> Self {
        Self {
            auto_recovery: false,
            resume_from_checkpoint: true,
            replay_failed: false,
            max_checkpoint_age: None,
            retry_policy: None,
        }
    }

    /// Builder-style setter for the maximum accepted checkpoint age, in seconds.
    pub fn with_max_checkpoint_age(mut self, seconds: u64) -> Self {
        self.max_checkpoint_age = Some(seconds);
        self
    }

    /// Builder-style setter that attaches `policy` as the retry policy.
    pub fn with_retry_policy(mut self, policy: WorkflowRetryPolicy) -> Self {
        self.retry_policy = Some(policy);
        self
    }

    /// Returns `true` if `checkpoint` is recent enough to be used under this policy.
    ///
    /// - With no `max_checkpoint_age`, every checkpoint is accepted.
    /// - Otherwise the checkpoint's age (current time minus its `timestamp`, in
    ///   seconds) must not exceed the limit.
    /// - A checkpoint stamped in the future counts as age 0.
    /// - A system clock set before the Unix epoch is read as time 0 rather than
    ///   causing a panic.
    pub fn is_checkpoint_valid(&self, checkpoint: &WorkflowCheckpoint) -> bool {
        match self.max_checkpoint_age {
            None => true,
            Some(max_age) => {
                let now = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .unwrap_or_default()
                    .as_secs();
                // Saturating so clock skew (future timestamps) cannot underflow.
                now.saturating_sub(checkpoint.timestamp) <= max_age
            }
        }
    }
}
/// Renders the recovery mode followed by each enabled option, e.g.
/// `WorkflowRecoveryPolicy[auto resume replay_failed]`.
impl std::fmt::Display for WorkflowRecoveryPolicy {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mode = if self.auto_recovery { "auto" } else { "manual" };
        let options = [
            (self.resume_from_checkpoint, " resume"),
            (self.replay_failed, " replay_failed"),
        ];

        f.write_str("WorkflowRecoveryPolicy[")?;
        f.write_str(mode)?;
        for &(enabled, label) in options.iter() {
            if enabled {
                f.write_str(label)?;
            }
        }
        f.write_str("]")
    }
}
/// Three-part `major.minor.patch` version attached to workflow state.
///
/// Whether one version can move to another is decided by
/// `StateVersion::is_compatible`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct StateVersion {
    /// Major component; must match exactly for versions to be compatible.
    pub major: u32,
    /// Minor component; a compatible target must not have a lower minor.
    pub minor: u32,
    /// Patch component; ignored by the compatibility check.
    pub patch: u32,
}
impl StateVersion {
pub fn new(major: u32, minor: u32, patch: u32) -> Self {
Self {
major,
minor,
patch,
}
}
pub fn current() -> Self {
Self {
major: 1,
minor: 0,
patch: 0,
}
}
pub fn is_compatible(&self, other: &StateVersion) -> bool {
self.major == other.major && self.minor <= other.minor
}
}
impl std::fmt::Display for StateVersion {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}.{}.{}", self.major, self.minor, self.patch)
}
}
/// Errors that can arise while migrating workflow state between versions.
#[derive(Debug, Clone)]
pub enum StateMigrationError {
    /// The current and requested versions are not compatible
    /// (returned by `VersionedWorkflowState::migrate_to`).
    IncompatibleVersion {
        /// Version the state is currently at.
        from: StateVersion,
        /// Requested target version.
        to: StateVersion,
    },
    /// A migration step failed; the string describes why.
    /// NOTE(review): not constructed in the code shown here.
    MigrationFailed(String),
    /// The given version is not supported.
    /// NOTE(review): not constructed in the code shown here.
    UnsupportedVersion(StateVersion),
}
impl std::fmt::Display for StateMigrationError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::IncompatibleVersion { from, to } => {
write!(f, "Incompatible state version: {} -> {}", from, to)
}
Self::MigrationFailed(msg) => write!(f, "State migration failed: {}", msg),
Self::UnsupportedVersion(version) => {
write!(f, "Unsupported state version: {}", version)
}
}
}
}
// Marks `StateMigrationError` as a standard error so it can be boxed as
// `dyn Error` and propagated with `?`. The default trait methods suffice
// because no variant wraps an underlying error.
impl std::error::Error for StateMigrationError {}
/// Hook for performing the actual data transformation between two state versions.
///
/// NOTE(review): nothing in the code shown here implements or calls this
/// trait. `VersionedWorkflowState::migrate_to` only records the version change.
pub trait StateMigration {
    /// Migrates from version `from` to version `to`.
    ///
    /// # Errors
    /// Returns a [`StateMigrationError`] if the migration cannot be performed.
    fn migrate(&self, from: StateVersion, to: StateVersion) -> Result<(), StateMigrationError>;
}
/// Workflow state tagged with the version it conforms to, plus a log of
/// version changes applied to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionedWorkflowState {
    /// Version the state currently conforms to.
    pub version: StateVersion,
    /// The wrapped workflow state.
    pub state: WorkflowState,
    /// `(from, to, unix_seconds)` entries appended by `migrate_to`, oldest first.
    pub migration_history: Vec<(StateVersion, StateVersion, u64)>,
}
impl VersionedWorkflowState {
    /// Wraps `state` at [`StateVersion::current`] with an empty migration history.
    pub fn new(state: WorkflowState) -> Self {
        Self {
            version: StateVersion::current(),
            state,
            migration_history: Vec::new(),
        }
    }

    /// Moves this state to version `target` and records the step in the history.
    ///
    /// Migrating to the version the state is already at is a no-op and records
    /// nothing. Only `version` and `migration_history` change; `state` itself
    /// is left untouched.
    ///
    /// # Errors
    /// Returns [`StateMigrationError::IncompatibleVersion`] when
    /// [`can_migrate_to`](Self::can_migrate_to) is `false` for `target`.
    pub fn migrate_to(&mut self, target: StateVersion) -> Result<(), StateMigrationError> {
        if self.version == target {
            return Ok(());
        }
        if !self.can_migrate_to(&target) {
            return Err(StateMigrationError::IncompatibleVersion {
                from: self.version,
                to: target,
            });
        }
        // A system clock set before the Unix epoch yields 0 instead of panicking.
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        self.migration_history
            .push((self.version, target, timestamp));
        self.version = target;
        Ok(())
    }

    /// Returns `true` if the current version is compatible with `target`
    /// according to [`StateVersion::is_compatible`].
    pub fn can_migrate_to(&self, target: &StateVersion) -> bool {
        self.version.is_compatible(target)
    }

    /// Recorded migrations as `(from, to, unix_seconds)` tuples, oldest first.
    pub fn get_migration_history(&self) -> &[(StateVersion, StateVersion, u64)] {
        &self.migration_history
    }
}