// oxigdal_workflow/engine/state.rs

1//! Workflow state management and persistence.
2
3use crate::dag::WorkflowDag;
4use crate::error::{Result, WorkflowError};
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::path::Path;
9use tokio::fs;
10
/// Workflow execution state.
///
/// A serializable snapshot of a single workflow run: overall status,
/// per-task states, run metadata, and the shared execution context.
/// Persisted/restored as JSON by [`StatePersistence`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowState {
    /// Workflow ID.
    pub workflow_id: String,
    /// Workflow execution ID (unique per run).
    pub execution_id: String,
    /// Current workflow status.
    pub status: WorkflowStatus,
    /// Task states, keyed by task ID.
    pub task_states: HashMap<String, TaskState>,
    /// Workflow metadata (name, version, timings, owner, tags).
    pub metadata: WorkflowMetadata,
    /// Execution context shared across tasks.
    pub context: ExecutionContext,
}
27
/// Workflow execution status.
///
/// `Completed`, `Failed`, and `Cancelled` are terminal states
/// (see [`WorkflowState::is_terminal`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum WorkflowStatus {
    /// Workflow is pending execution.
    Pending,
    /// Workflow is currently running.
    Running,
    /// Workflow completed successfully (terminal).
    Completed,
    /// Workflow failed (terminal).
    Failed,
    /// Workflow was cancelled (terminal).
    Cancelled,
    /// Workflow is paused; resumable via checkpoint recovery.
    Paused,
}
44
/// Individual task state.
///
/// Tracks one task's lifecycle within a workflow run: status transitions,
/// attempt count, timing, output/error payloads, and accumulated logs.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskState {
    /// Task ID.
    pub task_id: String,
    /// Current task status.
    pub status: TaskStatus,
    /// Number of attempts (incremented each time the task is started).
    pub attempts: u32,
    /// Task start time (set when the task transitions to `Running`).
    pub started_at: Option<DateTime<Utc>>,
    /// Task completion time (set on completion, failure, or skip).
    pub completed_at: Option<DateTime<Utc>>,
    /// Task duration in milliseconds (computed from `started_at` on finish).
    pub duration_ms: Option<u64>,
    /// Task output (if any), as arbitrary JSON.
    pub output: Option<serde_json::Value>,
    /// Task error message (if any).
    pub error: Option<String>,
    /// Task logs, appended in order.
    pub logs: Vec<String>,
}
67
/// Task execution status.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    /// Task is pending execution.
    Pending,
    /// Task is currently running.
    Running,
    /// Task completed successfully.
    Completed,
    /// Task failed.
    Failed,
    /// Task was skipped (due to conditionals).
    Skipped,
    /// Task was cancelled.
    Cancelled,
    /// Task is waiting for retry; treated as runnable by checkpoint recovery.
    WaitingRetry,
}
86
/// Workflow metadata.
///
/// Descriptive and timing information for a workflow run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowMetadata {
    /// Workflow name.
    pub name: String,
    /// Workflow version (defaults to "1.0.0" on creation).
    pub version: String,
    /// Workflow creation time.
    pub created_at: DateTime<Utc>,
    /// Workflow start time (set when the workflow starts running).
    pub started_at: Option<DateTime<Utc>>,
    /// Workflow completion time (set on completion, failure, or cancel).
    pub completed_at: Option<DateTime<Utc>>,
    /// Total duration in milliseconds (computed from `started_at` on finish).
    pub duration_ms: Option<u64>,
    /// Workflow creator/owner.
    pub owner: Option<String>,
    /// Custom tags.
    pub tags: HashMap<String, String>,
}
107
/// Execution context shared across tasks.
///
/// All three maps hold arbitrary JSON/string values; `variables` is the only
/// one mutated through [`WorkflowState`] accessors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionContext {
    /// Shared variables (read/written via `set_variable`/`get_variable`).
    pub variables: HashMap<String, serde_json::Value>,
    /// Workflow parameters.
    pub parameters: HashMap<String, serde_json::Value>,
    /// Environment variables.
    pub env: HashMap<String, String>,
}
118
119impl WorkflowState {
120    /// Create a new workflow state.
121    pub fn new(workflow_id: String, execution_id: String, name: String) -> Self {
122        Self {
123            workflow_id,
124            execution_id,
125            status: WorkflowStatus::Pending,
126            task_states: HashMap::new(),
127            metadata: WorkflowMetadata {
128                name,
129                version: "1.0.0".to_string(),
130                created_at: Utc::now(),
131                started_at: None,
132                completed_at: None,
133                duration_ms: None,
134                owner: None,
135                tags: HashMap::new(),
136            },
137            context: ExecutionContext {
138                variables: HashMap::new(),
139                parameters: HashMap::new(),
140                env: HashMap::new(),
141            },
142        }
143    }
144
145    /// Initialize a task state.
146    pub fn init_task(&mut self, task_id: String) {
147        self.task_states.insert(
148            task_id.clone(),
149            TaskState {
150                task_id,
151                status: TaskStatus::Pending,
152                attempts: 0,
153                started_at: None,
154                completed_at: None,
155                duration_ms: None,
156                output: None,
157                error: None,
158                logs: Vec::new(),
159            },
160        );
161    }
162
163    /// Mark a task as running.
164    pub fn start_task(&mut self, task_id: &str) -> Result<()> {
165        let task_state = self
166            .task_states
167            .get_mut(task_id)
168            .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
169
170        task_state.status = TaskStatus::Running;
171        task_state.started_at = Some(Utc::now());
172        task_state.attempts += 1;
173
174        Ok(())
175    }
176
177    /// Mark a task as completed.
178    pub fn complete_task(
179        &mut self,
180        task_id: &str,
181        output: Option<serde_json::Value>,
182    ) -> Result<()> {
183        let task_state = self
184            .task_states
185            .get_mut(task_id)
186            .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
187
188        task_state.status = TaskStatus::Completed;
189        task_state.completed_at = Some(Utc::now());
190        task_state.output = output;
191
192        if let Some(started) = task_state.started_at {
193            task_state.duration_ms = Some(
194                (Utc::now() - started)
195                    .num_milliseconds()
196                    .try_into()
197                    .unwrap_or(0),
198            );
199        }
200
201        Ok(())
202    }
203
204    /// Mark a task as failed.
205    pub fn fail_task(&mut self, task_id: &str, error: String) -> Result<()> {
206        let task_state = self
207            .task_states
208            .get_mut(task_id)
209            .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
210
211        task_state.status = TaskStatus::Failed;
212        task_state.completed_at = Some(Utc::now());
213        task_state.error = Some(error);
214
215        if let Some(started) = task_state.started_at {
216            task_state.duration_ms = Some(
217                (Utc::now() - started)
218                    .num_milliseconds()
219                    .try_into()
220                    .unwrap_or(0),
221            );
222        }
223
224        Ok(())
225    }
226
227    /// Mark a task as skipped.
228    pub fn skip_task(&mut self, task_id: &str) -> Result<()> {
229        let task_state = self
230            .task_states
231            .get_mut(task_id)
232            .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
233
234        task_state.status = TaskStatus::Skipped;
235        task_state.completed_at = Some(Utc::now());
236
237        Ok(())
238    }
239
240    /// Add a log entry for a task.
241    pub fn add_task_log(&mut self, task_id: &str, log: String) -> Result<()> {
242        let task_state = self
243            .task_states
244            .get_mut(task_id)
245            .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;
246
247        task_state.logs.push(log);
248
249        Ok(())
250    }
251
252    /// Start the workflow execution.
253    pub fn start(&mut self) {
254        self.status = WorkflowStatus::Running;
255        self.metadata.started_at = Some(Utc::now());
256    }
257
258    /// Mark the workflow as completed.
259    pub fn complete(&mut self) {
260        self.status = WorkflowStatus::Completed;
261        self.metadata.completed_at = Some(Utc::now());
262
263        if let Some(started) = self.metadata.started_at {
264            self.metadata.duration_ms = Some(
265                (Utc::now() - started)
266                    .num_milliseconds()
267                    .try_into()
268                    .unwrap_or(0),
269            );
270        }
271    }
272
273    /// Mark the workflow as failed.
274    pub fn fail(&mut self) {
275        self.status = WorkflowStatus::Failed;
276        self.metadata.completed_at = Some(Utc::now());
277
278        if let Some(started) = self.metadata.started_at {
279            self.metadata.duration_ms = Some(
280                (Utc::now() - started)
281                    .num_milliseconds()
282                    .try_into()
283                    .unwrap_or(0),
284            );
285        }
286    }
287
288    /// Mark the workflow as cancelled.
289    pub fn cancel(&mut self) {
290        self.status = WorkflowStatus::Cancelled;
291        self.metadata.completed_at = Some(Utc::now());
292
293        if let Some(started) = self.metadata.started_at {
294            self.metadata.duration_ms = Some(
295                (Utc::now() - started)
296                    .num_milliseconds()
297                    .try_into()
298                    .unwrap_or(0),
299            );
300        }
301    }
302
303    /// Get task state.
304    pub fn get_task_state(&self, task_id: &str) -> Option<&TaskState> {
305        self.task_states.get(task_id)
306    }
307
308    /// Set a context variable.
309    pub fn set_variable(&mut self, key: String, value: serde_json::Value) {
310        self.context.variables.insert(key, value);
311    }
312
313    /// Get a context variable.
314    pub fn get_variable(&self, key: &str) -> Option<&serde_json::Value> {
315        self.context.variables.get(key)
316    }
317
318    /// Check if the workflow is terminal (completed, failed, or cancelled).
319    pub fn is_terminal(&self) -> bool {
320        matches!(
321            self.status,
322            WorkflowStatus::Completed | WorkflowStatus::Failed | WorkflowStatus::Cancelled
323        )
324    }
325}
326
/// Workflow checkpoint containing both state and DAG for recovery.
///
/// Serialized to JSON by [`StatePersistence::save_checkpoint`]; the bundled
/// DAG lets a run be resumed without re-reading the workflow definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowCheckpoint {
    /// Checkpoint format version for forward-compatibility checks on load.
    pub version: u32,
    /// Timestamp when checkpoint was created.
    pub created_at: DateTime<Utc>,
    /// Checkpoint sequence number (increments with each save).
    pub sequence: u64,
    /// The workflow state.
    pub state: WorkflowState,
    /// The workflow DAG definition.
    pub dag: WorkflowDag,
}
341
342impl WorkflowCheckpoint {
343    /// Current checkpoint format version.
344    pub const CURRENT_VERSION: u32 = 1;
345
346    /// Create a new checkpoint from state and DAG.
347    pub fn new(state: WorkflowState, dag: WorkflowDag, sequence: u64) -> Self {
348        Self {
349            version: Self::CURRENT_VERSION,
350            created_at: Utc::now(),
351            sequence,
352            state,
353            dag,
354        }
355    }
356
357    /// Get tasks that need to be executed (pending or failed but retriable).
358    pub fn get_pending_tasks(&self) -> Vec<String> {
359        self.state
360            .task_states
361            .iter()
362            .filter(|(_, ts)| matches!(ts.status, TaskStatus::Pending | TaskStatus::WaitingRetry))
363            .map(|(id, _)| id.clone())
364            .collect()
365    }
366
367    /// Get tasks that were running when checkpoint was saved (need retry).
368    pub fn get_interrupted_tasks(&self) -> Vec<String> {
369        self.state
370            .task_states
371            .iter()
372            .filter(|(_, ts)| ts.status == TaskStatus::Running)
373            .map(|(id, _)| id.clone())
374            .collect()
375    }
376
377    /// Get tasks that completed successfully.
378    pub fn get_completed_tasks(&self) -> Vec<String> {
379        self.state
380            .task_states
381            .iter()
382            .filter(|(_, ts)| ts.status == TaskStatus::Completed)
383            .map(|(id, _)| id.clone())
384            .collect()
385    }
386
387    /// Get tasks that failed (not retriable).
388    pub fn get_failed_tasks(&self) -> Vec<String> {
389        self.state
390            .task_states
391            .iter()
392            .filter(|(_, ts)| ts.status == TaskStatus::Failed)
393            .map(|(id, _)| id.clone())
394            .collect()
395    }
396
397    /// Get tasks that were skipped.
398    pub fn get_skipped_tasks(&self) -> Vec<String> {
399        self.state
400            .task_states
401            .iter()
402            .filter(|(_, ts)| ts.status == TaskStatus::Skipped)
403            .map(|(id, _)| id.clone())
404            .collect()
405    }
406
407    /// Check if all dependencies for a task are satisfied.
408    pub fn are_dependencies_satisfied(&self, task_id: &str) -> bool {
409        let dependencies = self.dag.get_dependencies(task_id);
410        dependencies.iter().all(|dep_id| {
411            self.state
412                .task_states
413                .get(dep_id)
414                .map(|ts| ts.status == TaskStatus::Completed)
415                .unwrap_or(false)
416        })
417    }
418
419    /// Get tasks ready to execute (pending with satisfied dependencies).
420    pub fn get_ready_tasks(&self) -> Vec<String> {
421        self.get_pending_tasks()
422            .into_iter()
423            .filter(|task_id| self.are_dependencies_satisfied(task_id))
424            .collect()
425    }
426
427    /// Prepare state for resumption by resetting interrupted tasks.
428    pub fn prepare_for_resume(&mut self) -> Result<()> {
429        // Reset interrupted (running) tasks to pending for retry
430        let interrupted = self.get_interrupted_tasks();
431        for task_id in interrupted {
432            if let Some(task_state) = self.state.task_states.get_mut(&task_id) {
433                task_state.status = TaskStatus::Pending;
434                // Keep attempt count for proper retry tracking
435            }
436        }
437
438        // Reset workflow status to running
439        if self.state.status == WorkflowStatus::Paused {
440            self.state.status = WorkflowStatus::Running;
441        }
442
443        Ok(())
444    }
445}
446
/// State persistence manager.
///
/// Stores workflow states as `<state_dir>/<execution_id>.json` and
/// checkpoints under `<state_dir>/checkpoints/`.
pub struct StatePersistence {
    /// Directory for state storage.
    state_dir: String,
}
452
453impl StatePersistence {
454    /// Create a new state persistence manager.
455    pub fn new(state_dir: String) -> Self {
456        Self { state_dir }
457    }
458
459    /// Save workflow state to disk.
460    pub async fn save(&self, state: &WorkflowState) -> Result<()> {
461        let dir_path = Path::new(&self.state_dir);
462        fs::create_dir_all(dir_path).await.map_err(|e| {
463            WorkflowError::persistence(format!("Failed to create state dir: {}", e))
464        })?;
465
466        let file_path = dir_path.join(format!("{}.json", state.execution_id));
467        let json = serde_json::to_string_pretty(state)?;
468
469        fs::write(&file_path, json)
470            .await
471            .map_err(|e| WorkflowError::persistence(format!("Failed to write state: {}", e)))?;
472
473        Ok(())
474    }
475
476    /// Load workflow state from disk.
477    pub async fn load(&self, execution_id: &str) -> Result<WorkflowState> {
478        let file_path = Path::new(&self.state_dir).join(format!("{}.json", execution_id));
479
480        let json = fs::read_to_string(&file_path)
481            .await
482            .map_err(|e| WorkflowError::persistence(format!("Failed to read state: {}", e)))?;
483
484        let state = serde_json::from_str(&json)?;
485        Ok(state)
486    }
487
488    /// Delete workflow state from disk.
489    pub async fn delete(&self, execution_id: &str) -> Result<()> {
490        let file_path = Path::new(&self.state_dir).join(format!("{}.json", execution_id));
491
492        fs::remove_file(&file_path)
493            .await
494            .map_err(|e| WorkflowError::persistence(format!("Failed to delete state: {}", e)))?;
495
496        Ok(())
497    }
498
499    /// List all workflow states.
500    pub async fn list(&self) -> Result<Vec<String>> {
501        let dir_path = Path::new(&self.state_dir);
502
503        if !dir_path.exists() {
504            return Ok(Vec::new());
505        }
506
507        let mut entries = fs::read_dir(dir_path)
508            .await
509            .map_err(|e| WorkflowError::persistence(format!("Failed to read state dir: {}", e)))?;
510
511        let mut execution_ids = Vec::new();
512
513        while let Some(entry) = entries
514            .next_entry()
515            .await
516            .map_err(|e| WorkflowError::persistence(format!("Failed to read entry: {}", e)))?
517        {
518            let path = entry.path();
519            if path.extension().and_then(|s| s.to_str()) == Some("json") {
520                if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
521                    execution_ids.push(stem.to_string());
522                }
523            }
524        }
525
526        Ok(execution_ids)
527    }
528
529    /// Save a workflow checkpoint (state + DAG) to disk.
530    pub async fn save_checkpoint(&self, checkpoint: &WorkflowCheckpoint) -> Result<()> {
531        let dir_path = Path::new(&self.state_dir).join("checkpoints");
532        fs::create_dir_all(&dir_path).await.map_err(|e| {
533            WorkflowError::persistence(format!("Failed to create checkpoint dir: {}", e))
534        })?;
535
536        let file_path = dir_path.join(format!(
537            "{}_checkpoint_{}.json",
538            checkpoint.state.execution_id, checkpoint.sequence
539        ));
540        let json = serde_json::to_string_pretty(checkpoint)?;
541
542        fs::write(&file_path, json).await.map_err(|e| {
543            WorkflowError::persistence(format!("Failed to write checkpoint: {}", e))
544        })?;
545
546        // Also save a "latest" symlink/copy for easy access
547        let latest_path = dir_path.join(format!("{}_latest.json", checkpoint.state.execution_id));
548        let json_latest = serde_json::to_string_pretty(checkpoint)?;
549        fs::write(&latest_path, json_latest).await.map_err(|e| {
550            WorkflowError::persistence(format!("Failed to write latest checkpoint: {}", e))
551        })?;
552
553        Ok(())
554    }
555
556    /// Load the latest checkpoint for an execution.
557    pub async fn load_checkpoint(&self, execution_id: &str) -> Result<WorkflowCheckpoint> {
558        let latest_path = Path::new(&self.state_dir)
559            .join("checkpoints")
560            .join(format!("{}_latest.json", execution_id));
561
562        let json = fs::read_to_string(&latest_path)
563            .await
564            .map_err(|e| WorkflowError::persistence(format!("Failed to read checkpoint: {}", e)))?;
565
566        let checkpoint: WorkflowCheckpoint = serde_json::from_str(&json)?;
567
568        // Validate checkpoint version
569        if checkpoint.version > WorkflowCheckpoint::CURRENT_VERSION {
570            return Err(WorkflowError::persistence(format!(
571                "Checkpoint version {} is newer than supported version {}",
572                checkpoint.version,
573                WorkflowCheckpoint::CURRENT_VERSION
574            )));
575        }
576
577        Ok(checkpoint)
578    }
579
580    /// Load a specific checkpoint by sequence number.
581    pub async fn load_checkpoint_by_sequence(
582        &self,
583        execution_id: &str,
584        sequence: u64,
585    ) -> Result<WorkflowCheckpoint> {
586        let file_path = Path::new(&self.state_dir)
587            .join("checkpoints")
588            .join(format!("{}_checkpoint_{}.json", execution_id, sequence));
589
590        let json = fs::read_to_string(&file_path)
591            .await
592            .map_err(|e| WorkflowError::persistence(format!("Failed to read checkpoint: {}", e)))?;
593
594        let checkpoint: WorkflowCheckpoint = serde_json::from_str(&json)?;
595        Ok(checkpoint)
596    }
597
598    /// Delete a checkpoint.
599    pub async fn delete_checkpoint(&self, execution_id: &str, sequence: u64) -> Result<()> {
600        let file_path = Path::new(&self.state_dir)
601            .join("checkpoints")
602            .join(format!("{}_checkpoint_{}.json", execution_id, sequence));
603
604        fs::remove_file(&file_path).await.map_err(|e| {
605            WorkflowError::persistence(format!("Failed to delete checkpoint: {}", e))
606        })?;
607
608        Ok(())
609    }
610
611    /// Delete all checkpoints for an execution.
612    pub async fn delete_all_checkpoints(&self, execution_id: &str) -> Result<()> {
613        let checkpoints_dir = Path::new(&self.state_dir).join("checkpoints");
614
615        if !checkpoints_dir.exists() {
616            return Ok(());
617        }
618
619        let mut entries = fs::read_dir(&checkpoints_dir).await.map_err(|e| {
620            WorkflowError::persistence(format!("Failed to read checkpoints dir: {}", e))
621        })?;
622
623        let prefix = format!("{}_", execution_id);
624
625        while let Some(entry) = entries
626            .next_entry()
627            .await
628            .map_err(|e| WorkflowError::persistence(format!("Failed to read entry: {}", e)))?
629        {
630            let path = entry.path();
631            if let Some(name) = path.file_name().and_then(|s| s.to_str()) {
632                if name.starts_with(&prefix) {
633                    fs::remove_file(&path).await.map_err(|e| {
634                        WorkflowError::persistence(format!("Failed to delete checkpoint: {}", e))
635                    })?;
636                }
637            }
638        }
639
640        Ok(())
641    }
642
643    /// List all checkpoints for an execution (sorted by sequence).
644    pub async fn list_checkpoints(&self, execution_id: &str) -> Result<Vec<u64>> {
645        let checkpoints_dir = Path::new(&self.state_dir).join("checkpoints");
646
647        if !checkpoints_dir.exists() {
648            return Ok(Vec::new());
649        }
650
651        let mut entries = fs::read_dir(&checkpoints_dir).await.map_err(|e| {
652            WorkflowError::persistence(format!("Failed to read checkpoints dir: {}", e))
653        })?;
654
655        let mut sequences = Vec::new();
656        let prefix = format!("{}_checkpoint_", execution_id);
657
658        while let Some(entry) = entries
659            .next_entry()
660            .await
661            .map_err(|e| WorkflowError::persistence(format!("Failed to read entry: {}", e)))?
662        {
663            let path = entry.path();
664            if let Some(name) = path.file_stem().and_then(|s| s.to_str()) {
665                if name.starts_with(&prefix) {
666                    if let Some(seq_str) = name.strip_prefix(&prefix) {
667                        if let Ok(seq) = seq_str.parse::<u64>() {
668                            sequences.push(seq);
669                        }
670                    }
671                }
672            }
673        }
674
675        sequences.sort();
676        Ok(sequences)
677    }
678
679    /// Check if a checkpoint exists for an execution.
680    pub async fn checkpoint_exists(&self, execution_id: &str) -> bool {
681        let latest_path = Path::new(&self.state_dir)
682            .join("checkpoints")
683            .join(format!("{}_latest.json", execution_id));
684        latest_path.exists()
685    }
686}
687
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workflow_state_lifecycle() {
        let mut wf = WorkflowState::new("wf1".into(), "exec1".into(), "Test Workflow".into());

        // A fresh state starts out pending.
        assert_eq!(wf.status, WorkflowStatus::Pending);

        // Starting moves it to running and stamps the start time.
        wf.start();
        assert_eq!(wf.status, WorkflowStatus::Running);
        assert!(wf.metadata.started_at.is_some());

        // Completing stamps the end time and computes a duration.
        wf.complete();
        assert_eq!(wf.status, WorkflowStatus::Completed);
        assert!(wf.metadata.completed_at.is_some());
        assert!(wf.metadata.duration_ms.is_some());
    }

    #[test]
    fn test_task_state_lifecycle() {
        let mut wf = WorkflowState::new("wf1".into(), "exec1".into(), "Test Workflow".into());
        let status_of = |s: &WorkflowState| s.get_task_state("task1").map(|t| t.status);

        // Initialization registers the task as pending.
        wf.init_task("task1".into());
        assert_eq!(status_of(&wf), Some(TaskStatus::Pending));

        // Starting flips it to running and counts one attempt.
        wf.start_task("task1").ok();
        assert_eq!(status_of(&wf), Some(TaskStatus::Running));
        assert_eq!(wf.get_task_state("task1").map(|t| t.attempts), Some(1));

        // Completion with an output payload ends the lifecycle.
        wf.complete_task("task1", Some(serde_json::json!({"result": "success"})))
            .ok();
        assert_eq!(status_of(&wf), Some(TaskStatus::Completed));
    }

    #[test]
    fn test_context_variables() {
        let mut wf = WorkflowState::new("wf1".into(), "exec1".into(), "Test Workflow".into());

        // A stored variable round-trips through the context map.
        wf.set_variable("key1".into(), serde_json::json!("value1"));
        assert_eq!(wf.get_variable("key1"), Some(&serde_json::json!("value1")));
    }
}