Skip to main content

forge_agent/workflow/
state.rs

1//! Workflow state inspection API.
2//!
3//! Provides runtime state inspection for workflows including task status,
4//! progress tracking, and serialization for external monitoring.
5//!
6//! # Thread Safety
7//!
8//! This module provides [`ConcurrentState`] for thread-safe state access during
9//! parallel workflow execution. It uses `Arc<RwLock<T>>` to allow:
10//! - Multiple concurrent reads (tasks checking state)
11//! - Exclusive writes (executor updating task status)
12//!
13//! ## Thread-Safety Audit (Task 1 of Phase 12-02)
14//!
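//! **Findings:**
//! - `WorkflowState` stores its tasks in `Vec<TaskSummary>` - no internal synchronization
//! - `WorkflowExecutor.completed_tasks: HashSet<TaskId>` - no internal synchronization
//! - `WorkflowExecutor.failed_tasks: Vec<TaskId>` - no internal synchronization
//!
//! (These types are `Send + Sync`, but mutating any of them from several threads
//! at once requires wrapping it in a lock.)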
19//!
20//! **Decision:** Use `Arc<RwLock<T>>` instead of:
21//! - `Arc<Mutex<T>>`: RwLock allows concurrent reads
22//! - `dashmap`: Not needed - we don't require per-key concurrent access
23//!
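//! **Race concern identified:** in `execute_parallel()` (line 850 of the executor
//! at the time of the audit), completed tasks were recorded with:
//! ```ignore
//! self.completed_tasks.insert(task_id.clone());
//! ```
//! In safe Rust, unsynchronized shared mutation is rejected at compile time, so this
//! statement cannot race on its own. The concern is that parallel tasks need a
//! shared, synchronized place to record their results. The audit's resolution is
//! to route state mutations during parallel execution through `ConcurrentState`.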
29
30use crate::workflow::dag::TaskNode;
31use crate::workflow::executor::WorkflowExecutor;
32use crate::workflow::task::TaskId;
33use serde::{Deserialize, Serialize};
34use std::collections::HashSet;
35use std::sync::{Arc, RwLock};
36
37/// Status of a workflow execution.
38#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
39pub enum WorkflowStatus {
40    /// Workflow is pending execution
41    Pending,
42    /// Workflow is currently running
43    Running,
44    /// Workflow completed successfully
45    Completed,
46    /// Workflow failed
47    Failed,
48    /// Workflow was rolled back
49    RolledBack,
50}
51
52/// Status of an individual task.
53#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
54pub enum TaskStatus {
55    /// Task is pending execution
56    Pending,
57    /// Task is currently running
58    Running,
59    /// Task completed successfully
60    Completed,
61    /// Task failed
62    Failed,
63    /// Task was skipped
64    Skipped,
65}
66
67impl TaskStatus {
68    /// Creates TaskStatus from a parallel execution result.
69    pub(crate) fn from_parallel_result(success: bool) -> Self {
70        if success {
71            TaskStatus::Completed
72        } else {
73            TaskStatus::Failed
74        }
75    }
76}
77
78/// Summary of a task's state.
79#[derive(Clone, Debug, Serialize, Deserialize)]
80pub struct TaskSummary {
81    /// Task identifier
82    pub id: String,
83    /// Task name
84    pub name: String,
85    /// Current task status
86    pub status: TaskStatus,
87}
88
89impl TaskSummary {
90    /// Creates a new TaskSummary.
91    pub fn new(id: impl Into<String>, name: impl Into<String>, status: TaskStatus) -> Self {
92        Self {
93            id: id.into(),
94            name: name.into(),
95            status,
96        }
97    }
98}
99
100/// Snapshot of workflow execution state.
101///
102/// Provides a complete view of the workflow's current execution status
103/// including completed, pending, and failed tasks.
104#[derive(Clone, Debug, Serialize, Deserialize)]
105pub struct WorkflowState {
106    /// Workflow identifier
107    pub workflow_id: String,
108    /// Current workflow status
109    pub status: WorkflowStatus,
110    /// Currently executing task (if any)
111    pub current_task: Option<TaskSummary>,
112    /// Tasks that have completed
113    pub completed_tasks: Vec<TaskSummary>,
114    /// Tasks that are pending execution
115    pub pending_tasks: Vec<TaskSummary>,
116    /// Tasks that have failed
117    pub failed_tasks: Vec<TaskSummary>,
118}
119
120impl WorkflowState {
121    /// Creates a new WorkflowState.
122    pub fn new(workflow_id: impl Into<String>) -> Self {
123        Self {
124            workflow_id: workflow_id.into(),
125            status: WorkflowStatus::Pending,
126            current_task: None,
127            completed_tasks: Vec::new(),
128            pending_tasks: Vec::new(),
129            failed_tasks: Vec::new(),
130        }
131    }
132
133    /// Sets the workflow status.
134    pub fn with_status(mut self, status: WorkflowStatus) -> Self {
135        self.status = status;
136        self
137    }
138
139    /// Adds a completed task.
140    pub fn with_completed_task(mut self, task: TaskSummary) -> Self {
141        self.completed_tasks.push(task);
142        self
143    }
144
145    /// Adds a pending task.
146    pub fn with_pending_task(mut self, task: TaskSummary) -> Self {
147        self.pending_tasks.push(task);
148        self
149    }
150
151    /// Adds a failed task.
152    pub fn with_failed_task(mut self, task: TaskSummary) -> Self {
153        self.failed_tasks.push(task);
154        self
155    }
156
157    /// Sets the current task.
158    pub fn with_current_task(mut self, task: TaskSummary) -> Self {
159        self.current_task = Some(task);
160        self
161    }
162}
163
164/// Thread-safe wrapper for workflow state during parallel execution.
165///
166/// Uses `Arc<RwLock<T>>` to allow multiple concurrent reads with exclusive writes.
167/// This is optimal for workflow execution where:
168/// - Multiple tasks may read state concurrently
169/// - Only the executor writes state updates
170///
171/// # Example
172///
173/// ```ignore
174/// let state = ConcurrentState::new(WorkflowState::new("workflow-1"));
175///
176/// // Concurrent reads (tasks checking state)
177/// {
178///     let reader = state.read().unwrap();
179///     assert_eq!(reader.status, WorkflowStatus::Running);
180/// }
181///
182/// // Exclusive write (executor updating state)
183/// {
184///     let mut writer = state.write().unwrap();
185///     writer.status = WorkflowStatus::Completed;
186/// }
187/// ```
188#[derive(Clone)]
189pub struct ConcurrentState {
190    /// Inner state wrapped in Arc<RwLock for thread-safe access
191    inner: Arc<RwLock<WorkflowState>>,
192}
193
194impl ConcurrentState {
195    /// Creates a new ConcurrentState from a WorkflowState.
196    pub fn new(state: WorkflowState) -> Self {
197        Self {
198            inner: Arc::new(RwLock::new(state)),
199        }
200    }
201
202    /// Acquires a read lock, allowing concurrent access from multiple readers.
203    ///
204    /// # Returns
205    ///
206    /// A `RwLockReadGuard` that provides immutable access to the state.
207    ///
208    /// # Panics
209    ///
210    /// Panics if the lock is poisoned (another thread panicked while holding the lock).
211    pub fn read(&self) -> Result<std::sync::RwLockReadGuard<WorkflowState>, std::sync::PoisonError<std::sync::RwLockReadGuard<WorkflowState>>> {
212        self.inner.read()
213    }
214
215    /// Acquires a write lock, providing exclusive mutable access.
216    ///
217    /// # Returns
218    ///
219    /// A `RwLockWriteGuard` that provides mutable access to the state.
220    ///
221    /// # Panics
222    ///
223    /// Panics if the lock is poisoned (another thread panicked while holding the lock).
224    pub fn write(&self) -> Result<std::sync::RwLockWriteGuard<WorkflowState>, std::sync::PoisonError<std::sync::RwLockWriteGuard<WorkflowState>>> {
225        self.inner.write()
226    }
227
228    /// Attempts to acquire a read lock without blocking.
229    ///
230    /// # Returns
231    ///
232    /// - `Some(guard)` if the lock was acquired immediately
233    /// - `None` if the lock is held by a writer
234    pub fn try_read(&self) -> Option<std::sync::RwLockReadGuard<'_, WorkflowState>> {
235        self.inner.try_read().ok()
236    }
237
238    /// Attempts to acquire a write lock without blocking.
239    ///
240    /// # Returns
241    ///
242    /// - `Some(guard)` if the lock was acquired immediately
243    /// - `None` if the lock is held by another reader or writer
244    pub fn try_write(&self) -> Option<std::sync::RwLockWriteGuard<'_, WorkflowState>> {
245        self.inner.try_write().ok()
246    }
247
248    /// Returns the number of strong references to the inner state.
249    ///
250    /// Useful for debugging to see how many clones exist.
251    pub fn ref_count(&self) -> usize {
252        Arc::strong_count(&self.inner)
253    }
254}
255
256// SAFETY: ConcurrentState is Send + Sync because:
257// - Arc<T> is Send + Sync when T: Send + Sync
258// - RwLock<T> is Send + Sync when T: Send
259// - WorkflowState is Send (all fields are Send)
260unsafe impl Send for ConcurrentState {}
261unsafe impl Sync for ConcurrentState {}
262
#[cfg(test)]
mod concurrent_state_tests {
    use super::*;
    use std::sync::Barrier;
    use tokio::task::JoinSet;

    #[test]
    fn test_concurrent_state_creation() {
        let state = WorkflowState::new("workflow-1");
        let concurrent = ConcurrentState::new(state);

        let reader = concurrent.read().unwrap();
        assert_eq!(reader.workflow_id, "workflow-1");
        assert_eq!(reader.status, WorkflowStatus::Pending);
    }

    #[test]
    fn test_concurrent_state_clone_is_cheap() {
        let state = WorkflowState::new("workflow-1");
        let concurrent = ConcurrentState::new(state);

        // Clone is cheap (just an Arc increment)
        let cloned = concurrent.clone();
        assert_eq!(concurrent.ref_count(), 2);

        let cloned2 = cloned.clone();
        assert_eq!(concurrent.ref_count(), 3);

        // Dropping a handle releases its reference
        drop(cloned2);
        assert_eq!(concurrent.ref_count(), 2);

        // All clones share one state: a write through one is visible through another
        cloned.write().unwrap().status = WorkflowStatus::Running;
        assert_eq!(concurrent.read().unwrap().status, WorkflowStatus::Running);
    }

    #[test]
    fn test_concurrent_read_write() {
        let state = WorkflowState::new("workflow-1");
        let concurrent = ConcurrentState::new(state);

        // Read initial state
        {
            let reader = concurrent.read().unwrap();
            assert_eq!(reader.status, WorkflowStatus::Pending);
        }

        // Write new state
        {
            let mut writer = concurrent.write().unwrap();
            writer.status = WorkflowStatus::Completed;
        }

        // Read updated state
        {
            let reader = concurrent.read().unwrap();
            assert_eq!(reader.status, WorkflowStatus::Completed);
        }
    }

    #[test]
    fn test_try_read_write() {
        let state = WorkflowState::new("workflow-1");
        let concurrent = ConcurrentState::new(state);

        // Uncontended: both should succeed
        assert!(concurrent.try_read().is_some());
        assert!(concurrent.try_write().is_some());

        // A held read guard must keep writers out
        let reader = concurrent.read().unwrap();
        assert!(concurrent.try_write().is_none());
        drop(reader);
        assert!(concurrent.try_write().is_some());
    }

    // Uses real OS threads with `std::thread::scope`. Blocking on a
    // `std::sync::Barrier` inside tasks of the default current-thread tokio runtime
    // would block its only worker and deadlock.
    #[test]
    fn test_concurrent_state_thread_safety() {
        let state = WorkflowState::new("workflow-1").with_status(WorkflowStatus::Running);
        let concurrent = ConcurrentState::new(state);
        // Released only while both readers hold a read guard simultaneously
        let both_reading = Barrier::new(2);
        // Released only after both readers have dropped their guards (2 readers + 1 writer)
        let readers_done = Barrier::new(3);

        std::thread::scope(|s| {
            for _ in 0..2 {
                s.spawn(|| {
                    let (workflow_id, status) = {
                        let reader = concurrent.read().unwrap();
                        both_reading.wait();
                        (reader.workflow_id.clone(), reader.status.clone())
                    };
                    readers_done.wait();
                    // Asserted after the barrier, so a failure cannot strand the other threads
                    assert_eq!(workflow_id, "workflow-1");
                    assert_eq!(status, WorkflowStatus::Running);
                });
            }

            s.spawn(|| {
                // Writer runs strictly after both readers have finished reading
                readers_done.wait();
                concurrent.write().unwrap().status = WorkflowStatus::Completed;
            });
        });

        // Verify final state
        let reader = concurrent.read().unwrap();
        assert_eq!(reader.status, WorkflowStatus::Completed);
    }

    #[tokio::test]
    async fn test_concurrent_state_stress_test() {
        let state = WorkflowState::new("workflow-stress");
        let concurrent = Arc::new(ConcurrentState::new(state));

        let mut handles = JoinSet::new();

        // Spawn 10 concurrent readers/writers
        for i in 0..10 {
            let concurrent_clone = Arc::clone(&concurrent);
            handles.spawn(async move {
                // Read (guard dropped before await)
                {
                    let _reader = concurrent_clone.read().unwrap();
                }

                // Small delay
                tokio::time::sleep(std::time::Duration::from_millis(1)).await;

                // Write (if even number); the guard is not held across an await
                if i % 2 == 0 {
                    let mut writer = concurrent_clone.write().unwrap();
                    writer.completed_tasks.push(TaskSummary::new(
                        format!("task-{}", i),
                        format!("Task {}", i),
                        TaskStatus::Completed,
                    ));
                }
            });
        }

        // Wait for all tasks
        while let Some(result) = handles.join_next().await {
            result.expect("Task should complete successfully");
        }

        // Verify no corruption - should have 5 completed tasks (even numbers 0,2,4,6,8)
        let reader = concurrent.read().unwrap();
        assert_eq!(reader.completed_tasks.len(), 5);
    }
}
416
417impl WorkflowExecutor {
418    /// Returns a snapshot of the current workflow state.
419    ///
420    /// This method provides a complete view of the workflow's execution
421    /// status including all tasks and their current states.
422    ///
423    /// # Returns
424    ///
425    /// A `WorkflowState` snapshot containing task summaries and status.
426    ///
427    /// # Example
428    ///
429    /// ```ignore
430    /// let mut executor = WorkflowExecutor::new(workflow);
431    /// executor.execute().await?;
432    /// let state = executor.state();
433    /// println!("Status: {:?}", state.status);
434    /// println!("Completed: {}", state.completed_tasks.len());
435    /// ```
436    pub fn state(&self) -> WorkflowState {
437        // Determine workflow status based on completion state
438        let status = if self.failed_tasks.is_empty() && self.completed_tasks.is_empty() {
439            WorkflowStatus::Pending
440        } else if !self.failed_tasks.is_empty() {
441            WorkflowStatus::Failed
442        } else if self.completed_tasks.len() == self.workflow.task_count() {
443            WorkflowStatus::Completed
444        } else {
445            WorkflowStatus::Running
446        };
447
448        // Build completed task summaries
449        let completed_tasks: Vec<TaskSummary> = self
450            .completed_tasks
451            .iter()
452            .map(|id| {
453                let name = self.get_task_name(id)
454                    .unwrap_or_else(|| "Unknown".to_string());
455                TaskSummary::new(
456                    id.as_str(),
457                    &name,
458                    TaskStatus::Completed,
459                )
460            })
461            .collect();
462
463        // Build pending task summaries
464        let pending_task_ids: HashSet<_> = self.workflow.task_ids().into_iter().collect();
465        let completed_ids: HashSet<_> = self.completed_tasks.iter().cloned().collect();
466        let failed_ids: HashSet<_> = self.failed_tasks.iter().cloned().collect();
467
468        let pending_tasks: Vec<TaskSummary> = pending_task_ids
469            .difference(&completed_ids)
470            .filter(|id| !failed_ids.contains(id))
471            .map(|id| {
472                let name = self.get_task_name(id)
473                    .unwrap_or_else(|| "Unknown".to_string());
474                TaskSummary::new(
475                    id.as_str(),
476                    &name,
477                    TaskStatus::Pending,
478                )
479            })
480            .collect();
481
482        // Build failed task summaries
483        let failed_tasks: Vec<TaskSummary> = self
484            .failed_tasks
485            .iter()
486            .map(|id| {
487                let name = self.get_task_name(id)
488                    .unwrap_or_else(|| "Unknown".to_string());
489                TaskSummary::new(
490                    id.as_str(),
491                    &name,
492                    TaskStatus::Failed,
493                )
494            })
495            .collect();
496
497        WorkflowState {
498            workflow_id: format!("workflow-{:?}", self.audit_log.tx_id()),
499            status,
500            current_task: None,
501            completed_tasks,
502            pending_tasks,
503            failed_tasks,
504        }
505    }
506
507    /// Helper method to get task name from workflow.
508    fn get_task_name(&self, id: &TaskId) -> Option<String> {
509        self.workflow.task_name(id)
510    }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::dag::Workflow;
    use crate::workflow::task::{TaskContext, TaskError, TaskResult, WorkflowTask};
    use async_trait::async_trait;

    /// Task that always succeeds; only its id and name matter for these tests.
    struct MockTask {
        id: TaskId,
        name: String,
    }

    impl MockTask {
        fn new(id: impl Into<TaskId>, name: &str) -> Self {
            Self {
                id: id.into(),
                name: name.to_string(),
            }
        }
    }

    #[async_trait]
    impl WorkflowTask for MockTask {
        async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
            Ok(TaskResult::Success)
        }

        fn id(&self) -> TaskId {
            self.id.clone()
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_task_summary_creation() {
        let summary = TaskSummary::new("task-1", "Task 1", TaskStatus::Pending);
        assert_eq!(summary.id, "task-1");
        assert_eq!(summary.name, "Task 1");
        assert_eq!(summary.status, TaskStatus::Pending);
    }

    #[test]
    fn test_workflow_state_creation() {
        let state = WorkflowState::new("workflow-1");
        assert_eq!(state.workflow_id, "workflow-1");
        assert_eq!(state.status, WorkflowStatus::Pending);
        assert!(state.current_task.is_none());
        assert!(state.completed_tasks.is_empty());
        assert!(state.pending_tasks.is_empty());
        assert!(state.failed_tasks.is_empty());
    }

    #[test]
    fn test_workflow_state_builder() {
        let state = WorkflowState::new("workflow-1")
            .with_status(WorkflowStatus::Running)
            .with_completed_task(TaskSummary::new("task-1", "Task 1", TaskStatus::Completed))
            .with_pending_task(TaskSummary::new("task-2", "Task 2", TaskStatus::Pending))
            .with_failed_task(TaskSummary::new("task-3", "Task 3", TaskStatus::Failed))
            .with_current_task(TaskSummary::new("task-4", "Task 4", TaskStatus::Running));

        assert_eq!(state.status, WorkflowStatus::Running);
        assert_eq!(state.completed_tasks.len(), 1);
        assert_eq!(state.pending_tasks.len(), 1);
        assert_eq!(state.failed_tasks.len(), 1);
        assert_eq!(state.failed_tasks[0].id, "task-3");
        assert_eq!(
            state.current_task.as_ref().map(|task| task.id.as_str()),
            Some("task-4")
        );
    }

    #[tokio::test]
    async fn test_workflow_state_snapshot() {
        let mut workflow = Workflow::new();
        workflow.add_task(Box::new(MockTask::new("task-1", "Task 1")));
        workflow.add_task(Box::new(MockTask::new("task-2", "Task 2")));
        workflow.add_task(Box::new(MockTask::new("task-3", "Task 3")));

        let executor = WorkflowExecutor::new(workflow);
        let state = executor.state();

        // Before execution, all tasks should be pending and nothing has run or failed
        assert_eq!(state.status, WorkflowStatus::Pending);
        assert!(state.current_task.is_none());
        assert!(state.completed_tasks.is_empty());
        assert!(state.failed_tasks.is_empty());
        assert_eq!(state.pending_tasks.len(), 3);
        assert!(state
            .pending_tasks
            .iter()
            .all(|task| task.status == TaskStatus::Pending));

        // Names are resolved from the workflow rather than reported as "Unknown"
        let mut names: Vec<&str> = state
            .pending_tasks
            .iter()
            .map(|task| task.name.as_str())
            .collect();
        names.sort_unstable();
        assert_eq!(names, ["Task 1", "Task 2", "Task 3"]);
    }

    // Plain #[test]: nothing here awaits, and `test_progress_empty_workflow` shows
    // constructing an executor and querying progress needs no runtime.
    #[test]
    fn test_progress_calculation() {
        let mut workflow = Workflow::new();
        workflow.add_task(Box::new(MockTask::new("task-1", "Task 1")));
        workflow.add_task(Box::new(MockTask::new("task-2", "Task 2")));
        workflow.add_task(Box::new(MockTask::new("task-3", "Task 3")));
        workflow.add_task(Box::new(MockTask::new("task-4", "Task 4")));

        let executor = WorkflowExecutor::new(workflow);

        // Before execution: 0/4 = 0.0
        assert_eq!(executor.progress(), 0.0);
    }

    #[test]
    fn test_progress_empty_workflow() {
        let workflow = Workflow::new();
        let executor = WorkflowExecutor::new(workflow);

        // Empty workflow: 0 tasks = 0.0 progress
        assert_eq!(executor.progress(), 0.0);
    }

    #[test]
    fn test_state_serialization() {
        let state = WorkflowState::new("workflow-1")
            .with_status(WorkflowStatus::Completed)
            .with_completed_task(TaskSummary::new("task-1", "Task 1", TaskStatus::Completed));

        // Serialize to JSON
        let json = serde_json::to_string(&state).unwrap();
        assert!(json.contains("workflow-1"));
        assert!(json.contains("Completed"));

        // Deserialize back and check the nested task survived the round trip
        let deserialized: WorkflowState = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.workflow_id, "workflow-1");
        assert_eq!(deserialized.status, WorkflowStatus::Completed);
        assert_eq!(deserialized.completed_tasks.len(), 1);
        let task = &deserialized.completed_tasks[0];
        assert_eq!(task.id, "task-1");
        assert_eq!(task.name, "Task 1");
        assert_eq!(task.status, TaskStatus::Completed);
        assert!(deserialized.current_task.is_none());
    }

    #[test]
    fn test_task_status_equality() {
        assert_eq!(TaskStatus::Pending, TaskStatus::Pending);
        assert_ne!(TaskStatus::Pending, TaskStatus::Running);
        assert_eq!(WorkflowStatus::Completed, WorkflowStatus::Completed);
        assert_ne!(WorkflowStatus::Completed, WorkflowStatus::Failed);
    }
}