use crate::dag::WorkflowDag;
use crate::error::{Result, WorkflowError};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use tokio::fs;
/// Serializable runtime state of one workflow execution: the overall status,
/// the per-task states keyed by task id, descriptive metadata and the mutable
/// execution context (variables/parameters/env).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowState {
/// Id of the workflow definition this execution belongs to.
pub workflow_id: String,
/// Unique id of this particular run; also used as the persistence file stem.
pub execution_id: String,
/// Overall lifecycle status of the run.
pub status: WorkflowStatus,
/// Per-task runtime state, keyed by task id.
pub task_states: HashMap<String, TaskState>,
/// Name, version, timing and tag metadata for the run.
pub metadata: WorkflowMetadata,
/// Mutable key/value context shared across tasks.
pub context: ExecutionContext,
}
/// Lifecycle status of a whole workflow execution.
/// `Completed`, `Failed` and `Cancelled` are terminal (see `WorkflowState::is_terminal`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkflowStatus {
/// Created but not yet started.
Pending,
/// Currently executing.
Running,
/// Finished successfully (terminal).
Completed,
/// Finished with an error (terminal).
Failed,
/// Stopped by request (terminal).
Cancelled,
/// Suspended; resumable via `WorkflowCheckpoint::prepare_for_resume`.
Paused,
}
/// Runtime state of a single task within a workflow execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskState {
/// Id of the task (duplicated here from the map key for self-contained serialization).
pub task_id: String,
/// Current lifecycle status of the task.
pub status: TaskStatus,
/// Number of times the task has been started (incremented by `start_task`).
pub attempts: u32,
/// When the most recent attempt began, if any.
pub started_at: Option<DateTime<Utc>>,
/// When the task reached a finished state (completed/failed/skipped), if it has.
pub completed_at: Option<DateTime<Utc>>,
/// Wall-clock time from `started_at` to completion, in milliseconds.
pub duration_ms: Option<u64>,
/// JSON output recorded on successful completion.
pub output: Option<serde_json::Value>,
/// Error message recorded on failure.
pub error: Option<String>,
/// Free-form log lines appended via `add_task_log`.
pub logs: Vec<String>,
}
/// Lifecycle status of a single task.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
/// Not yet started; eligible to run once dependencies complete.
Pending,
/// Currently executing.
Running,
/// Finished successfully; satisfies downstream dependencies.
Completed,
/// Finished with an error.
Failed,
/// Deliberately not executed.
Skipped,
/// Stopped by request.
Cancelled,
/// Failed but scheduled for another attempt; treated as pending by checkpoints.
WaitingRetry,
}
/// Descriptive and timing metadata for a workflow execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowMetadata {
/// Human-readable workflow name.
pub name: String,
/// Workflow version string (defaults to "1.0.0" in `WorkflowState::new`).
pub version: String,
/// When this execution record was created.
pub created_at: DateTime<Utc>,
/// When the workflow transitioned to `Running`, if it has.
pub started_at: Option<DateTime<Utc>>,
/// When the workflow reached a terminal state, if it has.
pub completed_at: Option<DateTime<Utc>>,
/// Wall-clock time from `started_at` to the terminal state, in milliseconds.
pub duration_ms: Option<u64>,
/// Optional owner identifier.
pub owner: Option<String>,
/// Arbitrary key/value tags.
pub tags: HashMap<String, String>,
}
/// Mutable key/value context carried through a workflow execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionContext {
/// Variables written during execution (see `set_variable`/`get_variable`).
pub variables: HashMap<String, serde_json::Value>,
/// Input parameters supplied to the run.
pub parameters: HashMap<String, serde_json::Value>,
/// Environment-style string settings.
pub env: HashMap<String, String>,
}
impl WorkflowState {
pub fn new(workflow_id: String, execution_id: String, name: String) -> Self {
Self {
workflow_id,
execution_id,
status: WorkflowStatus::Pending,
task_states: HashMap::new(),
metadata: WorkflowMetadata {
name,
version: "1.0.0".to_string(),
created_at: Utc::now(),
started_at: None,
completed_at: None,
duration_ms: None,
owner: None,
tags: HashMap::new(),
},
context: ExecutionContext {
variables: HashMap::new(),
parameters: HashMap::new(),
env: HashMap::new(),
},
}
}
pub fn init_task(&mut self, task_id: String) {
self.task_states.insert(
task_id.clone(),
TaskState {
task_id,
status: TaskStatus::Pending,
attempts: 0,
started_at: None,
completed_at: None,
duration_ms: None,
output: None,
error: None,
logs: Vec::new(),
},
);
}
pub fn start_task(&mut self, task_id: &str) -> Result<()> {
let task_state = self
.task_states
.get_mut(task_id)
.ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
task_state.status = TaskStatus::Running;
task_state.started_at = Some(Utc::now());
task_state.attempts += 1;
Ok(())
}
pub fn complete_task(
&mut self,
task_id: &str,
output: Option<serde_json::Value>,
) -> Result<()> {
let task_state = self
.task_states
.get_mut(task_id)
.ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
task_state.status = TaskStatus::Completed;
task_state.completed_at = Some(Utc::now());
task_state.output = output;
if let Some(started) = task_state.started_at {
task_state.duration_ms = Some(
(Utc::now() - started)
.num_milliseconds()
.try_into()
.unwrap_or(0),
);
}
Ok(())
}
pub fn fail_task(&mut self, task_id: &str, error: String) -> Result<()> {
let task_state = self
.task_states
.get_mut(task_id)
.ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
task_state.status = TaskStatus::Failed;
task_state.completed_at = Some(Utc::now());
task_state.error = Some(error);
if let Some(started) = task_state.started_at {
task_state.duration_ms = Some(
(Utc::now() - started)
.num_milliseconds()
.try_into()
.unwrap_or(0),
);
}
Ok(())
}
pub fn skip_task(&mut self, task_id: &str) -> Result<()> {
let task_state = self
.task_states
.get_mut(task_id)
.ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
task_state.status = TaskStatus::Skipped;
task_state.completed_at = Some(Utc::now());
Ok(())
}
pub fn add_task_log(&mut self, task_id: &str, log: String) -> Result<()> {
let task_state = self
.task_states
.get_mut(task_id)
.ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
task_state.logs.push(log);
Ok(())
}
pub fn start(&mut self) {
self.status = WorkflowStatus::Running;
self.metadata.started_at = Some(Utc::now());
}
pub fn complete(&mut self) {
self.status = WorkflowStatus::Completed;
self.metadata.completed_at = Some(Utc::now());
if let Some(started) = self.metadata.started_at {
self.metadata.duration_ms = Some(
(Utc::now() - started)
.num_milliseconds()
.try_into()
.unwrap_or(0),
);
}
}
pub fn fail(&mut self) {
self.status = WorkflowStatus::Failed;
self.metadata.completed_at = Some(Utc::now());
if let Some(started) = self.metadata.started_at {
self.metadata.duration_ms = Some(
(Utc::now() - started)
.num_milliseconds()
.try_into()
.unwrap_or(0),
);
}
}
pub fn cancel(&mut self) {
self.status = WorkflowStatus::Cancelled;
self.metadata.completed_at = Some(Utc::now());
if let Some(started) = self.metadata.started_at {
self.metadata.duration_ms = Some(
(Utc::now() - started)
.num_milliseconds()
.try_into()
.unwrap_or(0),
);
}
}
pub fn get_task_state(&self, task_id: &str) -> Option<&TaskState> {
self.task_states.get(task_id)
}
pub fn set_variable(&mut self, key: String, value: serde_json::Value) {
self.context.variables.insert(key, value);
}
pub fn get_variable(&self, key: &str) -> Option<&serde_json::Value> {
self.context.variables.get(key)
}
pub fn is_terminal(&self) -> bool {
matches!(
self.status,
WorkflowStatus::Completed | WorkflowStatus::Failed | WorkflowStatus::Cancelled
)
}
}
/// Versioned, serializable snapshot of a workflow execution: the full state
/// plus the DAG it was running, stamped with a monotonic sequence number.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowCheckpoint {
/// Checkpoint format version (see `CURRENT_VERSION`); checked on load.
pub version: u32,
/// When the checkpoint was taken.
pub created_at: DateTime<Utc>,
/// Monotonic sequence number used in the persisted file name.
pub sequence: u64,
/// Full workflow state at checkpoint time.
pub state: WorkflowState,
/// The workflow DAG, needed to evaluate task dependencies on resume.
pub dag: WorkflowDag,
}
impl WorkflowCheckpoint {
pub const CURRENT_VERSION: u32 = 1;
pub fn new(state: WorkflowState, dag: WorkflowDag, sequence: u64) -> Self {
Self {
version: Self::CURRENT_VERSION,
created_at: Utc::now(),
sequence,
state,
dag,
}
}
pub fn get_pending_tasks(&self) -> Vec<String> {
self.state
.task_states
.iter()
.filter(|(_, ts)| matches!(ts.status, TaskStatus::Pending | TaskStatus::WaitingRetry))
.map(|(id, _)| id.clone())
.collect()
}
pub fn get_interrupted_tasks(&self) -> Vec<String> {
self.state
.task_states
.iter()
.filter(|(_, ts)| ts.status == TaskStatus::Running)
.map(|(id, _)| id.clone())
.collect()
}
pub fn get_completed_tasks(&self) -> Vec<String> {
self.state
.task_states
.iter()
.filter(|(_, ts)| ts.status == TaskStatus::Completed)
.map(|(id, _)| id.clone())
.collect()
}
pub fn get_failed_tasks(&self) -> Vec<String> {
self.state
.task_states
.iter()
.filter(|(_, ts)| ts.status == TaskStatus::Failed)
.map(|(id, _)| id.clone())
.collect()
}
pub fn get_skipped_tasks(&self) -> Vec<String> {
self.state
.task_states
.iter()
.filter(|(_, ts)| ts.status == TaskStatus::Skipped)
.map(|(id, _)| id.clone())
.collect()
}
pub fn are_dependencies_satisfied(&self, task_id: &str) -> bool {
let dependencies = self.dag.get_dependencies(task_id);
dependencies.iter().all(|dep_id| {
self.state
.task_states
.get(dep_id)
.map(|ts| ts.status == TaskStatus::Completed)
.unwrap_or(false)
})
}
pub fn get_ready_tasks(&self) -> Vec<String> {
self.get_pending_tasks()
.into_iter()
.filter(|task_id| self.are_dependencies_satisfied(task_id))
.collect()
}
pub fn prepare_for_resume(&mut self) -> Result<()> {
let interrupted = self.get_interrupted_tasks();
for task_id in interrupted {
if let Some(task_state) = self.state.task_states.get_mut(&task_id) {
task_state.status = TaskStatus::Pending;
}
}
if self.state.status == WorkflowStatus::Paused {
self.state.status = WorkflowStatus::Running;
}
Ok(())
}
}
/// File-system persistence for workflow states and checkpoints.
/// States live at `<state_dir>/<execution_id>.json`; checkpoints under
/// `<state_dir>/checkpoints/` as numbered files plus a `_latest` copy.
pub struct StatePersistence {
/// Root directory for all persisted files; created lazily on first save.
state_dir: String,
}
impl StatePersistence {
pub fn new(state_dir: String) -> Self {
Self { state_dir }
}
pub async fn save(&self, state: &WorkflowState) -> Result<()> {
let dir_path = Path::new(&self.state_dir);
fs::create_dir_all(dir_path).await.map_err(|e| {
WorkflowError::persistence(format!("Failed to create state dir: {}", e))
})?;
let file_path = dir_path.join(format!("{}.json", state.execution_id));
let json = serde_json::to_string_pretty(state)?;
fs::write(&file_path, json)
.await
.map_err(|e| WorkflowError::persistence(format!("Failed to write state: {}", e)))?;
Ok(())
}
pub async fn load(&self, execution_id: &str) -> Result<WorkflowState> {
let file_path = Path::new(&self.state_dir).join(format!("{}.json", execution_id));
let json = fs::read_to_string(&file_path)
.await
.map_err(|e| WorkflowError::persistence(format!("Failed to read state: {}", e)))?;
let state = serde_json::from_str(&json)?;
Ok(state)
}
pub async fn delete(&self, execution_id: &str) -> Result<()> {
let file_path = Path::new(&self.state_dir).join(format!("{}.json", execution_id));
fs::remove_file(&file_path)
.await
.map_err(|e| WorkflowError::persistence(format!("Failed to delete state: {}", e)))?;
Ok(())
}
pub async fn list(&self) -> Result<Vec<String>> {
let dir_path = Path::new(&self.state_dir);
if !dir_path.exists() {
return Ok(Vec::new());
}
let mut entries = fs::read_dir(dir_path)
.await
.map_err(|e| WorkflowError::persistence(format!("Failed to read state dir: {}", e)))?;
let mut execution_ids = Vec::new();
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| WorkflowError::persistence(format!("Failed to read entry: {}", e)))?
{
let path = entry.path();
if path.extension().and_then(|s| s.to_str()) == Some("json") {
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
execution_ids.push(stem.to_string());
}
}
}
Ok(execution_ids)
}
pub async fn save_checkpoint(&self, checkpoint: &WorkflowCheckpoint) -> Result<()> {
let dir_path = Path::new(&self.state_dir).join("checkpoints");
fs::create_dir_all(&dir_path).await.map_err(|e| {
WorkflowError::persistence(format!("Failed to create checkpoint dir: {}", e))
})?;
let file_path = dir_path.join(format!(
"{}_checkpoint_{}.json",
checkpoint.state.execution_id, checkpoint.sequence
));
let json = serde_json::to_string_pretty(checkpoint)?;
fs::write(&file_path, json).await.map_err(|e| {
WorkflowError::persistence(format!("Failed to write checkpoint: {}", e))
})?;
let latest_path = dir_path.join(format!("{}_latest.json", checkpoint.state.execution_id));
let json_latest = serde_json::to_string_pretty(checkpoint)?;
fs::write(&latest_path, json_latest).await.map_err(|e| {
WorkflowError::persistence(format!("Failed to write latest checkpoint: {}", e))
})?;
Ok(())
}
pub async fn load_checkpoint(&self, execution_id: &str) -> Result<WorkflowCheckpoint> {
let latest_path = Path::new(&self.state_dir)
.join("checkpoints")
.join(format!("{}_latest.json", execution_id));
let json = fs::read_to_string(&latest_path)
.await
.map_err(|e| WorkflowError::persistence(format!("Failed to read checkpoint: {}", e)))?;
let checkpoint: WorkflowCheckpoint = serde_json::from_str(&json)?;
if checkpoint.version > WorkflowCheckpoint::CURRENT_VERSION {
return Err(WorkflowError::persistence(format!(
"Checkpoint version {} is newer than supported version {}",
checkpoint.version,
WorkflowCheckpoint::CURRENT_VERSION
)));
}
Ok(checkpoint)
}
pub async fn load_checkpoint_by_sequence(
&self,
execution_id: &str,
sequence: u64,
) -> Result<WorkflowCheckpoint> {
let file_path = Path::new(&self.state_dir)
.join("checkpoints")
.join(format!("{}_checkpoint_{}.json", execution_id, sequence));
let json = fs::read_to_string(&file_path)
.await
.map_err(|e| WorkflowError::persistence(format!("Failed to read checkpoint: {}", e)))?;
let checkpoint: WorkflowCheckpoint = serde_json::from_str(&json)?;
Ok(checkpoint)
}
pub async fn delete_checkpoint(&self, execution_id: &str, sequence: u64) -> Result<()> {
let file_path = Path::new(&self.state_dir)
.join("checkpoints")
.join(format!("{}_checkpoint_{}.json", execution_id, sequence));
fs::remove_file(&file_path).await.map_err(|e| {
WorkflowError::persistence(format!("Failed to delete checkpoint: {}", e))
})?;
Ok(())
}
pub async fn delete_all_checkpoints(&self, execution_id: &str) -> Result<()> {
let checkpoints_dir = Path::new(&self.state_dir).join("checkpoints");
if !checkpoints_dir.exists() {
return Ok(());
}
let mut entries = fs::read_dir(&checkpoints_dir).await.map_err(|e| {
WorkflowError::persistence(format!("Failed to read checkpoints dir: {}", e))
})?;
let prefix = format!("{}_", execution_id);
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| WorkflowError::persistence(format!("Failed to read entry: {}", e)))?
{
let path = entry.path();
if let Some(name) = path.file_name().and_then(|s| s.to_str()) {
if name.starts_with(&prefix) {
fs::remove_file(&path).await.map_err(|e| {
WorkflowError::persistence(format!("Failed to delete checkpoint: {}", e))
})?;
}
}
}
Ok(())
}
pub async fn list_checkpoints(&self, execution_id: &str) -> Result<Vec<u64>> {
let checkpoints_dir = Path::new(&self.state_dir).join("checkpoints");
if !checkpoints_dir.exists() {
return Ok(Vec::new());
}
let mut entries = fs::read_dir(&checkpoints_dir).await.map_err(|e| {
WorkflowError::persistence(format!("Failed to read checkpoints dir: {}", e))
})?;
let mut sequences = Vec::new();
let prefix = format!("{}_checkpoint_", execution_id);
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| WorkflowError::persistence(format!("Failed to read entry: {}", e)))?
{
let path = entry.path();
if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
if name.starts_with(&prefix) {
if let Some(seq_str) = name.strip_prefix(&prefix) {
if let Ok(seq) = seq_str.parse::<u64>() {
sequences.push(seq);
}
}
}
}
}
sequences.sort();
Ok(sequences)
}
pub async fn checkpoint_exists(&self, execution_id: &str) -> bool {
let latest_path = Path::new(&self.state_dir)
.join("checkpoints")
.join(format!("{}_latest.json", execution_id));
latest_path.exists()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared fixture: a fresh workflow "Test Workflow" (wf1/exec1).
    fn fresh_state() -> WorkflowState {
        WorkflowState::new(
            "wf1".to_string(),
            "exec1".to_string(),
            "Test Workflow".to_string(),
        )
    }

    /// The workflow walks Pending -> Running -> Completed, picking up
    /// start/completion timestamps and a duration along the way.
    #[test]
    fn test_workflow_state_lifecycle() {
        let mut state = fresh_state();
        assert_eq!(state.status, WorkflowStatus::Pending);

        state.start();
        assert_eq!(state.status, WorkflowStatus::Running);
        assert!(state.metadata.started_at.is_some());

        state.complete();
        assert_eq!(state.status, WorkflowStatus::Completed);
        assert!(state.metadata.completed_at.is_some());
        assert!(state.metadata.duration_ms.is_some());
    }

    /// A task walks Pending -> Running -> Completed and counts one attempt.
    #[test]
    fn test_task_state_lifecycle() {
        let mut state = fresh_state();
        state.init_task("task1".to_string());

        let status_of = |s: &WorkflowState| s.get_task_state("task1").map(|t| t.status);
        assert_eq!(status_of(&state), Some(TaskStatus::Pending));

        state.start_task("task1").ok();
        assert_eq!(status_of(&state), Some(TaskStatus::Running));
        assert_eq!(state.get_task_state("task1").map(|t| t.attempts), Some(1));

        state
            .complete_task("task1", Some(serde_json::json!({"result": "success"})))
            .ok();
        assert_eq!(status_of(&state), Some(TaskStatus::Completed));
    }

    /// Context variables round-trip through set_variable/get_variable.
    #[test]
    fn test_context_variables() {
        let mut state = fresh_state();
        state.set_variable("key1".to_string(), serde_json::json!("value1"));
        assert_eq!(
            state.get_variable("key1"),
            Some(&serde_json::json!("value1"))
        );
    }
}