1use crate::workflow::dag::TaskNode;
31use crate::workflow::executor::WorkflowExecutor;
32use crate::workflow::task::TaskId;
33use serde::{Deserialize, Serialize};
34use std::collections::HashSet;
35use std::sync::{Arc, RwLock};
36
37#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
39pub enum WorkflowStatus {
40 Pending,
42 Running,
44 Completed,
46 Failed,
48 RolledBack,
50}
51
52#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
54pub enum TaskStatus {
55 Pending,
57 Running,
59 Completed,
61 Failed,
63 Skipped,
65}
66
67impl TaskStatus {
68 pub(crate) fn from_parallel_result(success: bool) -> Self {
70 if success {
71 TaskStatus::Completed
72 } else {
73 TaskStatus::Failed
74 }
75 }
76}
77
78#[derive(Clone, Debug, Serialize, Deserialize)]
80pub struct TaskSummary {
81 pub id: String,
83 pub name: String,
85 pub status: TaskStatus,
87}
88
89impl TaskSummary {
90 pub fn new(id: impl Into<String>, name: impl Into<String>, status: TaskStatus) -> Self {
92 Self {
93 id: id.into(),
94 name: name.into(),
95 status,
96 }
97 }
98}
99
100#[derive(Clone, Debug, Serialize, Deserialize)]
105pub struct WorkflowState {
106 pub workflow_id: String,
108 pub status: WorkflowStatus,
110 pub current_task: Option<TaskSummary>,
112 pub completed_tasks: Vec<TaskSummary>,
114 pub pending_tasks: Vec<TaskSummary>,
116 pub failed_tasks: Vec<TaskSummary>,
118}
119
120impl WorkflowState {
121 pub fn new(workflow_id: impl Into<String>) -> Self {
123 Self {
124 workflow_id: workflow_id.into(),
125 status: WorkflowStatus::Pending,
126 current_task: None,
127 completed_tasks: Vec::new(),
128 pending_tasks: Vec::new(),
129 failed_tasks: Vec::new(),
130 }
131 }
132
133 pub fn with_status(mut self, status: WorkflowStatus) -> Self {
135 self.status = status;
136 self
137 }
138
139 pub fn with_completed_task(mut self, task: TaskSummary) -> Self {
141 self.completed_tasks.push(task);
142 self
143 }
144
145 pub fn with_pending_task(mut self, task: TaskSummary) -> Self {
147 self.pending_tasks.push(task);
148 self
149 }
150
151 pub fn with_failed_task(mut self, task: TaskSummary) -> Self {
153 self.failed_tasks.push(task);
154 self
155 }
156
157 pub fn with_current_task(mut self, task: TaskSummary) -> Self {
159 self.current_task = Some(task);
160 self
161 }
162}
163
164#[derive(Clone)]
189pub struct ConcurrentState {
190 inner: Arc<RwLock<WorkflowState>>,
192}
193
194impl ConcurrentState {
195 pub fn new(state: WorkflowState) -> Self {
197 Self {
198 inner: Arc::new(RwLock::new(state)),
199 }
200 }
201
202 pub fn read(&self) -> Result<std::sync::RwLockReadGuard<WorkflowState>, std::sync::PoisonError<std::sync::RwLockReadGuard<WorkflowState>>> {
212 self.inner.read()
213 }
214
215 pub fn write(&self) -> Result<std::sync::RwLockWriteGuard<WorkflowState>, std::sync::PoisonError<std::sync::RwLockWriteGuard<WorkflowState>>> {
225 self.inner.write()
226 }
227
228 pub fn try_read(&self) -> Option<std::sync::RwLockReadGuard<'_, WorkflowState>> {
235 self.inner.try_read().ok()
236 }
237
238 pub fn try_write(&self) -> Option<std::sync::RwLockWriteGuard<'_, WorkflowState>> {
245 self.inner.try_write().ok()
246 }
247
248 pub fn ref_count(&self) -> usize {
252 Arc::strong_count(&self.inner)
253 }
254}
255
256unsafe impl Send for ConcurrentState {}
261unsafe impl Sync for ConcurrentState {}
262
#[cfg(test)]
mod concurrent_state_tests {
    use super::*;
    // Async-aware barrier: `wait().await` parks the task, not the runtime thread.
    use tokio::sync::Barrier;
    use tokio::task::JoinSet;

    #[test]
    fn test_concurrent_state_creation() {
        let state = WorkflowState::new("workflow-1");
        let concurrent = ConcurrentState::new(state);

        let reader = concurrent.read().unwrap();
        assert_eq!(reader.workflow_id, "workflow-1");
        assert_eq!(reader.status, WorkflowStatus::Pending);
    }

    #[test]
    fn test_concurrent_state_clone_is_cheap() {
        let state = WorkflowState::new("workflow-1");
        let concurrent = ConcurrentState::new(state);

        // Each clone shares the same `Arc`; only the strong count changes.
        let cloned = concurrent.clone();
        assert_eq!(concurrent.ref_count(), 2);

        // Bound to `_cloned2` (not `_`) so the clone stays alive for the assertion.
        let _cloned2 = cloned.clone();
        assert_eq!(concurrent.ref_count(), 3);
    }

    #[test]
    fn test_concurrent_read_write() {
        let state = WorkflowState::new("workflow-1");
        let concurrent = ConcurrentState::new(state);

        {
            let reader = concurrent.read().unwrap();
            assert_eq!(reader.status, WorkflowStatus::Pending);
        }

        {
            let mut writer = concurrent.write().unwrap();
            writer.status = WorkflowStatus::Completed;
        }

        {
            let reader = concurrent.read().unwrap();
            assert_eq!(reader.status, WorkflowStatus::Completed);
        }
    }

    #[test]
    fn test_try_read_write() {
        let state = WorkflowState::new("workflow-1");
        let concurrent = ConcurrentState::new(state);

        // Uncontended: both non-blocking acquisitions succeed.
        assert!(concurrent.try_read().is_some());

        assert!(concurrent.try_write().is_some());
    }

    #[tokio::test]
    async fn test_concurrent_state_thread_safety() {
        let state = WorkflowState::new("workflow-1").with_status(WorkflowStatus::Running);
        let concurrent = Arc::new(ConcurrentState::new(state));
        // `#[tokio::test]` runs on a single-threaded runtime. A blocking
        // `std::sync::Barrier` would park that one thread inside the first task's
        // `wait()`, so the other tasks could never reach the barrier (deadlock).
        // The tokio barrier yields while waiting instead.
        let barrier = Arc::new(Barrier::new(3));
        let mut handles = JoinSet::new();

        let concurrent1 = Arc::clone(&concurrent);
        let barrier1 = Arc::clone(&barrier);
        handles.spawn(async move {
            barrier1.wait().await;
            let reader = concurrent1.read().unwrap();
            assert_eq!(reader.workflow_id, "workflow-1");
        });

        let concurrent2 = Arc::clone(&concurrent);
        let barrier2 = Arc::clone(&barrier);
        handles.spawn(async move {
            barrier2.wait().await;
            let reader = concurrent2.read().unwrap();
            assert_eq!(reader.status, WorkflowStatus::Running);
        });

        let concurrent3 = Arc::clone(&concurrent);
        let barrier3 = Arc::clone(&barrier);
        handles.spawn(async move {
            barrier3.wait().await;
            // Give the readers a head start so they observe `Running`.
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            // The std guard is taken only after the last `.await`, so the task's
            // future stays `Send`.
            let mut writer = concurrent3.write().unwrap();
            writer.status = WorkflowStatus::Completed;
        });

        while let Some(result) = handles.join_next().await {
            result.expect("Task should complete successfully");
        }

        let reader = concurrent.read().unwrap();
        assert_eq!(reader.status, WorkflowStatus::Completed);
    }

    #[tokio::test]
    async fn test_concurrent_state_stress_test() {
        let state = WorkflowState::new("workflow-stress");
        let concurrent = Arc::new(ConcurrentState::new(state));

        let mut handles = JoinSet::new();

        for i in 0..10 {
            let concurrent_clone = Arc::clone(&concurrent);
            handles.spawn(async move {
                {
                    // Read guard is dropped before the `.await` below.
                    let _reader = concurrent_clone.read().unwrap();
                }

                tokio::time::sleep(std::time::Duration::from_millis(1)).await;

                // Even-numbered tasks (0, 2, 4, 6, 8) each record one completion.
                if i % 2 == 0 {
                    let mut writer = concurrent_clone.write().unwrap();
                    writer.completed_tasks.push(TaskSummary::new(
                        format!("task-{}", i),
                        format!("Task {}", i),
                        TaskStatus::Completed,
                    ));
                }
            });
        }

        while let Some(result) = handles.join_next().await {
            result.expect("Task should complete successfully");
        }

        let reader = concurrent.read().unwrap();
        assert_eq!(reader.completed_tasks.len(), 5);
    }
}
416
417impl WorkflowExecutor {
418 pub fn state(&self) -> WorkflowState {
437 let status = if self.failed_tasks.is_empty() && self.completed_tasks.is_empty() {
439 WorkflowStatus::Pending
440 } else if !self.failed_tasks.is_empty() {
441 WorkflowStatus::Failed
442 } else if self.completed_tasks.len() == self.workflow.task_count() {
443 WorkflowStatus::Completed
444 } else {
445 WorkflowStatus::Running
446 };
447
448 let completed_tasks: Vec<TaskSummary> = self
450 .completed_tasks
451 .iter()
452 .map(|id| {
453 let name = self.get_task_name(id)
454 .unwrap_or_else(|| "Unknown".to_string());
455 TaskSummary::new(
456 id.as_str(),
457 &name,
458 TaskStatus::Completed,
459 )
460 })
461 .collect();
462
463 let pending_task_ids: HashSet<_> = self.workflow.task_ids().into_iter().collect();
465 let completed_ids: HashSet<_> = self.completed_tasks.iter().cloned().collect();
466 let failed_ids: HashSet<_> = self.failed_tasks.iter().cloned().collect();
467
468 let pending_tasks: Vec<TaskSummary> = pending_task_ids
469 .difference(&completed_ids)
470 .filter(|id| !failed_ids.contains(id))
471 .map(|id| {
472 let name = self.get_task_name(id)
473 .unwrap_or_else(|| "Unknown".to_string());
474 TaskSummary::new(
475 id.as_str(),
476 &name,
477 TaskStatus::Pending,
478 )
479 })
480 .collect();
481
482 let failed_tasks: Vec<TaskSummary> = self
484 .failed_tasks
485 .iter()
486 .map(|id| {
487 let name = self.get_task_name(id)
488 .unwrap_or_else(|| "Unknown".to_string());
489 TaskSummary::new(
490 id.as_str(),
491 &name,
492 TaskStatus::Failed,
493 )
494 })
495 .collect();
496
497 WorkflowState {
498 workflow_id: format!("workflow-{:?}", self.audit_log.tx_id()),
499 status,
500 current_task: None,
501 completed_tasks,
502 pending_tasks,
503 failed_tasks,
504 }
505 }
506
507 fn get_task_name(&self, id: &TaskId) -> Option<String> {
509 self.workflow.task_name(id)
510 }
511}
512
#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::dag::Workflow;
    use crate::workflow::task::{TaskContext, TaskError, TaskResult, WorkflowTask};
    use async_trait::async_trait;

    /// Minimal task that always succeeds; only its id and name matter here.
    struct MockTask {
        id: TaskId,
        name: String,
    }

    impl MockTask {
        fn new(id: impl Into<TaskId>, name: &str) -> Self {
            Self {
                id: id.into(),
                name: name.to_string(),
            }
        }
    }

    #[async_trait]
    impl WorkflowTask for MockTask {
        async fn execute(&self, _context: &TaskContext) -> Result<TaskResult, TaskError> {
            Ok(TaskResult::Success)
        }

        fn id(&self) -> TaskId {
            self.id.clone()
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_task_summary_creation() {
        let summary = TaskSummary::new("task-1", "Task 1", TaskStatus::Pending);
        assert_eq!(summary.id, "task-1");
        assert_eq!(summary.name, "Task 1");
        assert_eq!(summary.status, TaskStatus::Pending);
    }

    #[test]
    fn test_task_status_from_parallel_result() {
        assert_eq!(TaskStatus::from_parallel_result(true), TaskStatus::Completed);
        assert_eq!(TaskStatus::from_parallel_result(false), TaskStatus::Failed);
    }

    #[test]
    fn test_workflow_state_creation() {
        let state = WorkflowState::new("workflow-1");
        assert_eq!(state.workflow_id, "workflow-1");
        assert_eq!(state.status, WorkflowStatus::Pending);
        assert!(state.current_task.is_none());
        assert!(state.completed_tasks.is_empty());
        assert!(state.pending_tasks.is_empty());
        assert!(state.failed_tasks.is_empty());
    }

    #[test]
    fn test_workflow_state_builder() {
        let state = WorkflowState::new("workflow-1")
            .with_status(WorkflowStatus::Running)
            .with_completed_task(TaskSummary::new("task-1", "Task 1", TaskStatus::Completed))
            .with_pending_task(TaskSummary::new("task-2", "Task 2", TaskStatus::Pending))
            .with_failed_task(TaskSummary::new("task-3", "Task 3", TaskStatus::Failed))
            .with_current_task(TaskSummary::new("task-4", "Task 4", TaskStatus::Running));

        assert_eq!(state.status, WorkflowStatus::Running);
        assert_eq!(state.completed_tasks.len(), 1);
        assert_eq!(state.pending_tasks.len(), 1);
        assert_eq!(state.failed_tasks.len(), 1);
        assert_eq!(state.current_task.map(|t| t.id).as_deref(), Some("task-4"));
    }

    // The snapshot and progress APIs are synchronous, so these are plain `#[test]`s;
    // spinning up a tokio runtime for code that never awaits adds nothing.
    #[test]
    fn test_workflow_state_snapshot() {
        let mut workflow = Workflow::new();
        workflow.add_task(Box::new(MockTask::new("task-1", "Task 1")));
        workflow.add_task(Box::new(MockTask::new("task-2", "Task 2")));
        workflow.add_task(Box::new(MockTask::new("task-3", "Task 3")));

        let executor = WorkflowExecutor::new(workflow);
        let state = executor.state();

        // Nothing has run yet: every task is pending.
        assert_eq!(state.status, WorkflowStatus::Pending);
        assert_eq!(state.pending_tasks.len(), 3);
        assert_eq!(state.completed_tasks.len(), 0);
        assert!(state.failed_tasks.is_empty());
        assert!(state.current_task.is_none());
        assert!(state
            .pending_tasks
            .iter()
            .all(|t| t.status == TaskStatus::Pending));
    }

    /// Only the initial (nothing executed) progress value is checked here.
    #[test]
    fn test_progress_calculation() {
        let mut workflow = Workflow::new();
        workflow.add_task(Box::new(MockTask::new("task-1", "Task 1")));
        workflow.add_task(Box::new(MockTask::new("task-2", "Task 2")));
        workflow.add_task(Box::new(MockTask::new("task-3", "Task 3")));
        workflow.add_task(Box::new(MockTask::new("task-4", "Task 4")));

        let executor = WorkflowExecutor::new(workflow);

        assert_eq!(executor.progress(), 0.0);
    }

    #[test]
    fn test_progress_empty_workflow() {
        let workflow = Workflow::new();
        let executor = WorkflowExecutor::new(workflow);

        assert_eq!(executor.progress(), 0.0);
    }

    #[test]
    fn test_state_serialization() {
        let state = WorkflowState::new("workflow-1")
            .with_status(WorkflowStatus::Completed)
            .with_completed_task(TaskSummary::new("task-1", "Task 1", TaskStatus::Completed));

        let json = serde_json::to_string(&state).unwrap();
        assert!(json.contains("workflow-1"));
        assert!(json.contains("Completed"));

        let deserialized: WorkflowState = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.workflow_id, "workflow-1");
        assert_eq!(deserialized.status, WorkflowStatus::Completed);
        assert_eq!(deserialized.completed_tasks.len(), 1);
    }

    #[test]
    fn test_task_status_equality() {
        assert_eq!(TaskStatus::Pending, TaskStatus::Pending);
        assert_ne!(TaskStatus::Pending, TaskStatus::Running);
        assert_eq!(WorkflowStatus::Completed, WorkflowStatus::Completed);
        assert_ne!(WorkflowStatus::Completed, WorkflowStatus::Failed);
    }
}