// oxigdal_workflow/engine/executor.rs

//! Workflow execution engine.
2
3use crate::dag::{ResourcePool, TaskNode, WorkflowDag, create_execution_plan};
4use crate::engine::state::{
5    StatePersistence, TaskStatus, WorkflowCheckpoint, WorkflowState, WorkflowStatus,
6};
7use crate::error::{Result, WorkflowError};
8use async_trait::async_trait;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::Duration;
12use tokio::sync::{RwLock, Semaphore};
13use tokio::time::timeout;
14use tracing::{debug, error, info, warn};
15
/// Task executor trait - implement this to define custom task execution logic.
#[async_trait]
pub trait TaskExecutor: Send + Sync {
    /// Execute a task.
    ///
    /// Implementations receive the task definition and an [`ExecutionContext`]
    /// carrying the shared workflow state plus the outputs of completed
    /// dependency tasks. Returning `Err` marks the attempt as failed; the
    /// engine may retry it according to the task's retry policy.
    async fn execute(&self, task: &TaskNode, context: &ExecutionContext) -> Result<TaskOutput>;
}
22
/// Execution context provided to task executors.
#[derive(Debug, Clone)]
pub struct ExecutionContext {
    /// Workflow execution ID (unique per run).
    pub execution_id: String,
    /// ID of the task currently being executed.
    pub task_id: String,
    /// Shared workflow state; executors may read or update it under the lock.
    pub state: Arc<RwLock<WorkflowState>>,
    /// Input data from previous tasks: outputs of completed dependencies,
    /// keyed by dependency task ID (see `gather_inputs`).
    pub inputs: std::collections::HashMap<String, serde_json::Value>,
}
35
/// Task execution output.
#[derive(Debug, Clone)]
pub struct TaskOutput {
    /// Output data; recorded in the workflow state and fed as input to
    /// dependent tasks.
    pub data: Option<serde_json::Value>,
    /// Execution logs, appended to the task's log entries in the state.
    pub logs: Vec<String>,
}
44
/// Workflow executor configuration.
#[derive(Debug, Clone)]
pub struct ExecutorConfig {
    /// Maximum concurrent tasks (sizes the internal semaphore; note that
    /// execution within a level is currently sequential — see `execute_level`).
    pub max_concurrent_tasks: usize,
    /// Enable state persistence (required for checkpointing and resume).
    pub enable_persistence: bool,
    /// Directory where persisted state and checkpoints are stored.
    pub state_dir: String,
    /// Resource pool. NOTE(review): not referenced by the executor in this
    /// file — presumably consumed elsewhere (scheduler/DAG); confirm.
    pub resource_pool: ResourcePool,
    /// Retry failed tasks according to each task's retry policy.
    pub retry_on_failure: bool,
    /// Stop the whole workflow as soon as a level has failures.
    pub stop_on_failure: bool,
    /// Checkpoint interval (save checkpoint every N tasks).
    pub checkpoint_interval: usize,
    /// Enable checkpoint-based recovery.
    pub enable_checkpointing: bool,
}
65
impl Default for ExecutorConfig {
    /// Conservative defaults: persistence and checkpointing on, workflow
    /// continues past failures, tasks retried.
    fn default() -> Self {
        Self {
            max_concurrent_tasks: 10,
            enable_persistence: true,
            state_dir: "/tmp/oxigdal-workflow".to_string(),
            resource_pool: ResourcePool::default(),
            retry_on_failure: true,
            stop_on_failure: false,
            checkpoint_interval: 1, // Save after each task by default
            enable_checkpointing: true,
        }
    }
}
80
/// Workflow executor.
///
/// Drives a workflow DAG through its topological execution plan, delegating
/// each task to the supplied [`TaskExecutor`] and persisting state and
/// checkpoints when enabled.
pub struct WorkflowExecutor<E: TaskExecutor> {
    /// Configuration.
    config: ExecutorConfig,
    /// Task executor implementation.
    task_executor: Arc<E>,
    /// State persistence backend (`None` when persistence is disabled).
    persistence: Option<StatePersistence>,
    /// Semaphore for limiting concurrent tasks (reserved for future parallel
    /// execution; never acquired in this file).
    _semaphore: Arc<Semaphore>,
    /// Monotonically increasing checkpoint sequence counter.
    checkpoint_sequence: AtomicU64,
    /// Tasks completed since the last checkpoint was saved.
    tasks_since_checkpoint: AtomicU64,
}
96
97impl<E: TaskExecutor> WorkflowExecutor<E> {
98    /// Create a new workflow executor.
99    pub fn new(config: ExecutorConfig, task_executor: E) -> Self {
100        let semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
101        let persistence = if config.enable_persistence {
102            Some(StatePersistence::new(config.state_dir.clone()))
103        } else {
104            None
105        };
106
107        Self {
108            config,
109            task_executor: Arc::new(task_executor),
110            persistence,
111            _semaphore: semaphore,
112            checkpoint_sequence: AtomicU64::new(0),
113            tasks_since_checkpoint: AtomicU64::new(0),
114        }
115    }
116
117    /// Save a checkpoint if conditions are met.
118    async fn maybe_save_checkpoint(&self, state: &WorkflowState, dag: &WorkflowDag) -> Result<()> {
119        if !self.config.enable_checkpointing {
120            return Ok(());
121        }
122
123        let persistence = match &self.persistence {
124            Some(p) => p,
125            None => return Ok(()),
126        };
127
128        let tasks_completed = self.tasks_since_checkpoint.fetch_add(1, Ordering::SeqCst) + 1;
129
130        if tasks_completed >= self.config.checkpoint_interval as u64 {
131            self.tasks_since_checkpoint.store(0, Ordering::SeqCst);
132            let seq = self.checkpoint_sequence.fetch_add(1, Ordering::SeqCst);
133
134            let checkpoint = WorkflowCheckpoint::new(state.clone(), dag.clone(), seq);
135            persistence.save_checkpoint(&checkpoint).await?;
136
137            debug!(
138                "Saved checkpoint {} for execution {}",
139                seq, state.execution_id
140            );
141        }
142
143        Ok(())
144    }
145
146    /// Force save a checkpoint immediately.
147    async fn save_checkpoint_now(&self, state: &WorkflowState, dag: &WorkflowDag) -> Result<()> {
148        if !self.config.enable_checkpointing {
149            return Ok(());
150        }
151
152        let persistence = match &self.persistence {
153            Some(p) => p,
154            None => return Ok(()),
155        };
156
157        self.tasks_since_checkpoint.store(0, Ordering::SeqCst);
158        let seq = self.checkpoint_sequence.fetch_add(1, Ordering::SeqCst);
159
160        let checkpoint = WorkflowCheckpoint::new(state.clone(), dag.clone(), seq);
161        persistence.save_checkpoint(&checkpoint).await?;
162
163        info!(
164            "Saved checkpoint {} for execution {}",
165            seq, state.execution_id
166        );
167        Ok(())
168    }
169
170    /// Execute a workflow.
171    pub async fn execute(
172        &self,
173        workflow_id: String,
174        execution_id: String,
175        dag: WorkflowDag,
176    ) -> Result<WorkflowState> {
177        info!(
178            "Starting workflow execution: workflow_id={}, execution_id={}",
179            workflow_id, execution_id
180        );
181
182        // Validate DAG
183        dag.validate()?;
184
185        // Create initial workflow state
186        let mut state = WorkflowState::new(workflow_id.clone(), execution_id.clone(), workflow_id);
187
188        // Initialize task states
189        for task in dag.tasks() {
190            state.init_task(task.id.clone());
191        }
192
193        state.start();
194
195        // Save initial state
196        if let Some(ref persistence) = self.persistence {
197            persistence.save(&state).await?;
198        }
199
200        // Save initial checkpoint with DAG
201        self.save_checkpoint_now(&state, &dag).await?;
202
203        let state_arc = Arc::new(RwLock::new(state));
204
205        // Create execution plan
206        let execution_plan = create_execution_plan(&dag)?;
207
208        info!(
209            "Execution plan created with {} levels",
210            execution_plan.len()
211        );
212
213        // Execute tasks level by level
214        for (level_idx, level) in execution_plan.iter().enumerate() {
215            info!("Executing level {} with {} tasks", level_idx, level.len());
216
217            let results = self.execute_level(&dag, &state_arc, level).await;
218
219            // Save checkpoint after each level
220            {
221                let state_guard = state_arc.read().await;
222                self.maybe_save_checkpoint(&state_guard, &dag).await?;
223            }
224
225            // Check for failures
226            let failed_tasks: Vec<_> = results
227                .iter()
228                .filter_map(|(task_id, result)| {
229                    if result.is_err() {
230                        Some(task_id.clone())
231                    } else {
232                        None
233                    }
234                })
235                .collect();
236
237            if !failed_tasks.is_empty() {
238                error!("Tasks failed: {:?}", failed_tasks);
239
240                if self.config.stop_on_failure {
241                    warn!("Stopping workflow execution due to failures");
242                    let mut state_guard = state_arc.write().await;
243                    state_guard.fail();
244
245                    if let Some(ref persistence) = self.persistence {
246                        persistence.save(&state_guard).await?;
247                    }
248
249                    // Save final checkpoint on failure
250                    self.save_checkpoint_now(&state_guard, &dag).await?;
251
252                    drop(state_guard);
253
254                    return Ok(Arc::try_unwrap(state_arc)
255                        .map(|rw| rw.into_inner())
256                        .unwrap_or_else(|arc| {
257                            tokio::task::block_in_place(|| arc.blocking_read().clone())
258                        }));
259                }
260            }
261        }
262
263        // Complete workflow
264        let mut state_guard = state_arc.write().await;
265
266        // Check if all tasks completed successfully
267        let all_completed = state_guard
268            .task_states
269            .values()
270            .all(|ts| ts.status == TaskStatus::Completed || ts.status == TaskStatus::Skipped);
271
272        if all_completed {
273            state_guard.complete();
274        } else {
275            state_guard.fail();
276        }
277
278        // Save final state
279        if let Some(ref persistence) = self.persistence {
280            persistence.save(&state_guard).await?;
281        }
282
283        // Save final checkpoint
284        self.save_checkpoint_now(&state_guard, &dag).await?;
285
286        info!(
287            "Workflow execution completed: status={:?}",
288            state_guard.status
289        );
290
291        drop(state_guard);
292
293        Ok(Arc::try_unwrap(state_arc)
294            .map(|rw| rw.into_inner())
295            .unwrap_or_else(|arc| tokio::task::block_in_place(|| arc.blocking_read().clone())))
296    }
297
298    /// Execute a level of tasks in parallel.
299    async fn execute_level(
300        &self,
301        dag: &WorkflowDag,
302        state: &Arc<RwLock<WorkflowState>>,
303        level: &[String],
304    ) -> Vec<(String, Result<()>)> {
305        let mut results = Vec::new();
306
307        for task_id in level {
308            let result = self
309                .execute_task(
310                    task_id,
311                    dag,
312                    state,
313                    &*self.task_executor,
314                    self.config.retry_on_failure,
315                )
316                .await;
317            results.push((task_id.clone(), result));
318        }
319
320        results
321    }
322
    /// Execute a single task, including dependency checks, retries with
    /// exponential backoff, and a per-attempt timeout.
    ///
    /// State transitions performed on the shared `WorkflowState`:
    /// - dependencies unmet -> task marked Skipped, returns `Ok(())`;
    /// - started -> `start_task`;
    /// - success -> `complete_task` with output data, logs appended;
    /// - all attempts exhausted -> `fail_task`, last error returned.
    async fn execute_task(
        &self,
        task_id: &str,
        dag: &WorkflowDag,
        state: &Arc<RwLock<WorkflowState>>,
        executor: &E,
        retry_on_failure: bool,
    ) -> Result<()> {
        let task = dag
            .get_task(task_id)
            .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;

        debug!("Executing task: {}", task_id);

        // Check dependencies; an unmet dependency skips (not fails) the task.
        if !self.check_dependencies(task_id, dag, state).await? {
            warn!("Skipping task {} due to failed dependencies", task_id);
            let mut state_guard = state.write().await;
            state_guard.skip_task(task_id)?;
            return Ok(());
        }

        // Mark task as started (scoped so the write lock is released before
        // the potentially long-running execution below).
        {
            let mut state_guard = state.write().await;
            state_guard.start_task(task_id)?;
        }

        // Execute with retry; when retries are disabled, exactly one attempt
        // is made regardless of the task's retry policy.
        let max_attempts = if retry_on_failure {
            task.retry.max_attempts
        } else {
            1
        };

        let mut last_error = None;

        for attempt in 0..max_attempts {
            if attempt > 0 {
                debug!("Retrying task {} (attempt {})", task_id, attempt + 1);

                // Exponential backoff: base delay * multiplier^attempt,
                // capped at the policy's max delay.
                let delay_ms =
                    task.retry.delay_ms as f64 * task.retry.backoff_multiplier.powi(attempt as i32);
                let delay_ms = delay_ms.min(task.retry.max_delay_ms as f64) as u64;

                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            }

            // Gather inputs from dependencies (re-read on every attempt).
            let inputs = self.gather_inputs(task_id, dag, state).await?;

            // Create execution context; the read lock is held only long
            // enough to copy the execution ID.
            let ctx = ExecutionContext {
                execution_id: {
                    let state_guard = state.read().await;
                    state_guard.execution_id.clone()
                },
                task_id: task_id.to_string(),
                state: Arc::clone(state),
                inputs,
            };

            // Execute task with timeout (300s default when the task does not
            // specify one).
            let task_timeout = Duration::from_secs(task.timeout_secs.unwrap_or(300));
            let execute_result = timeout(task_timeout, executor.execute(task, &ctx)).await;

            match execute_result {
                Ok(Ok(output)) => {
                    // Task succeeded: record output and logs, then return.
                    let mut state_guard = state.write().await;
                    state_guard.complete_task(task_id, output.data)?;

                    for log in output.logs {
                        state_guard.add_task_log(task_id, log)?;
                    }

                    info!("Task {} completed successfully", task_id);
                    return Ok(());
                }
                Ok(Err(e)) => {
                    // Executor reported failure: remember the error and retry
                    // if attempts remain.
                    warn!("Task {} failed: {}", task_id, e);
                    last_error = Some(e);
                }
                Err(_) => {
                    // The timeout elapsed before the executor finished.
                    let timeout_error =
                        WorkflowError::task_timeout(task_id, task_timeout.as_secs());
                    warn!("Task {} timed out", task_id);
                    last_error = Some(timeout_error);
                }
            }
        }

        // All attempts failed: mark the task failed and propagate the error.
        let error = last_error.unwrap_or_else(|| WorkflowError::execution("Unknown error"));
        let mut state_guard = state.write().await;
        state_guard.fail_task(task_id, error.to_string())?;

        error!("Task {} failed after {} attempts", task_id, max_attempts);
        Err(error)
    }
424
425    /// Check if all task dependencies are met.
426    async fn check_dependencies(
427        &self,
428        task_id: &str,
429        dag: &WorkflowDag,
430        state: &Arc<RwLock<WorkflowState>>,
431    ) -> Result<bool> {
432        let dependencies = dag.get_dependencies(task_id);
433        let state_guard = state.read().await;
434
435        for dep_id in dependencies {
436            if let Some(dep_state) = state_guard.get_task_state(&dep_id) {
437                if dep_state.status != TaskStatus::Completed {
438                    return Ok(false);
439                }
440            } else {
441                return Ok(false);
442            }
443        }
444
445        Ok(true)
446    }
447
448    /// Gather inputs from task dependencies.
449    async fn gather_inputs(
450        &self,
451        task_id: &str,
452        dag: &WorkflowDag,
453        state: &Arc<RwLock<WorkflowState>>,
454    ) -> Result<std::collections::HashMap<String, serde_json::Value>> {
455        let dependencies = dag.get_dependencies(task_id);
456        let state_guard = state.read().await;
457        let mut inputs = std::collections::HashMap::new();
458
459        for dep_id in dependencies {
460            if let Some(dep_state) = state_guard.get_task_state(&dep_id) {
461                if let Some(ref output) = dep_state.output {
462                    inputs.insert(dep_id.clone(), output.clone());
463                }
464            }
465        }
466
467        Ok(inputs)
468    }
469
470    /// Resume a workflow from a saved checkpoint.
471    ///
472    /// This method reconstructs the DAG from the checkpoint and continues
473    /// execution from where it left off, handling:
474    /// - Completed tasks (skipped)
475    /// - Interrupted tasks (reset to pending for retry)
476    /// - Failed tasks (depending on configuration)
477    /// - Pending tasks (executed normally)
478    pub async fn resume(&self, execution_id: String) -> Result<WorkflowState> {
479        let persistence = self
480            .persistence
481            .as_ref()
482            .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;
483
484        // Try to load checkpoint first (includes DAG)
485        let mut checkpoint = persistence.load_checkpoint(&execution_id).await.map_err(|e| {
486            WorkflowError::state(format!(
487                "Failed to load checkpoint for recovery: {}. Ensure checkpointing was enabled during execution.",
488                e
489            ))
490        })?;
491
492        if checkpoint.state.is_terminal() {
493            return Err(WorkflowError::state("Cannot resume a terminal workflow"));
494        }
495
496        info!(
497            "Resuming workflow execution: execution_id={}, checkpoint_sequence={}",
498            execution_id, checkpoint.sequence
499        );
500
501        // Prepare checkpoint for resumption (reset interrupted tasks)
502        checkpoint.prepare_for_resume()?;
503
504        // Update checkpoint sequence for this executor
505        self.checkpoint_sequence
506            .store(checkpoint.sequence + 1, Ordering::SeqCst);
507
508        // Resume execution with recovered DAG and state
509        self.resume_from_checkpoint(checkpoint).await
510    }
511
512    /// Resume execution from a checkpoint.
513    async fn resume_from_checkpoint(
514        &self,
515        checkpoint: WorkflowCheckpoint,
516    ) -> Result<WorkflowState> {
517        let dag = checkpoint.dag.clone();
518
519        // Log recovery information (before moving state)
520        let completed = checkpoint.get_completed_tasks();
521        let pending = checkpoint.get_pending_tasks();
522        let interrupted = checkpoint.get_interrupted_tasks();
523        let failed = checkpoint.get_failed_tasks();
524
525        let mut state = checkpoint.state;
526
527        // Ensure workflow is in running state
528        if state.status != WorkflowStatus::Running {
529            state.status = WorkflowStatus::Running;
530        }
531
532        info!(
533            "Recovery state: {} completed, {} pending, {} interrupted, {} failed",
534            completed.len(),
535            pending.len(),
536            interrupted.len(),
537            failed.len()
538        );
539
540        // Save state update
541        if let Some(ref persistence) = self.persistence {
542            persistence.save(&state).await?;
543        }
544
545        let state_arc = Arc::new(RwLock::new(state));
546
547        // Create execution plan from DAG
548        let execution_plan = create_execution_plan(&dag)?;
549
550        info!("Resuming execution with {} levels", execution_plan.len());
551
552        // Execute tasks level by level, skipping completed ones
553        for (level_idx, level) in execution_plan.iter().enumerate() {
554            // Filter out already completed or skipped tasks
555            let tasks_to_execute: Vec<String> = {
556                let state_guard = state_arc.read().await;
557                level
558                    .iter()
559                    .filter(|task_id| {
560                        state_guard
561                            .get_task_state(task_id)
562                            .map(|ts| {
563                                !matches!(ts.status, TaskStatus::Completed | TaskStatus::Skipped)
564                            })
565                            .unwrap_or(true)
566                    })
567                    .cloned()
568                    .collect()
569            };
570
571            if tasks_to_execute.is_empty() {
572                debug!("Level {} has no tasks to execute, skipping", level_idx);
573                continue;
574            }
575
576            info!(
577                "Resuming level {} with {} tasks (skipping {} completed)",
578                level_idx,
579                tasks_to_execute.len(),
580                level.len() - tasks_to_execute.len()
581            );
582
583            let results = self
584                .execute_level(&dag, &state_arc, &tasks_to_execute)
585                .await;
586
587            // Save checkpoint after each level
588            {
589                let state_guard = state_arc.read().await;
590                self.maybe_save_checkpoint(&state_guard, &dag).await?;
591            }
592
593            // Check for failures
594            let failed_tasks: Vec<_> = results
595                .iter()
596                .filter_map(|(task_id, result)| {
597                    if result.is_err() {
598                        Some(task_id.clone())
599                    } else {
600                        None
601                    }
602                })
603                .collect();
604
605            if !failed_tasks.is_empty() {
606                error!("Tasks failed during resume: {:?}", failed_tasks);
607
608                if self.config.stop_on_failure {
609                    warn!("Stopping resumed workflow execution due to failures");
610                    let mut state_guard = state_arc.write().await;
611                    state_guard.fail();
612
613                    if let Some(ref persistence) = self.persistence {
614                        persistence.save(&state_guard).await?;
615                    }
616
617                    // Save final checkpoint on failure
618                    self.save_checkpoint_now(&state_guard, &dag).await?;
619
620                    drop(state_guard);
621
622                    return Ok(Arc::try_unwrap(state_arc)
623                        .map(|rw| rw.into_inner())
624                        .unwrap_or_else(|arc| {
625                            tokio::task::block_in_place(|| arc.blocking_read().clone())
626                        }));
627                }
628            }
629        }
630
631        // Complete workflow
632        let mut state_guard = state_arc.write().await;
633
634        // Check if all tasks completed successfully
635        let all_completed = state_guard
636            .task_states
637            .values()
638            .all(|ts| ts.status == TaskStatus::Completed || ts.status == TaskStatus::Skipped);
639
640        if all_completed {
641            state_guard.complete();
642        } else {
643            state_guard.fail();
644        }
645
646        // Save final state
647        if let Some(ref persistence) = self.persistence {
648            persistence.save(&state_guard).await?;
649        }
650
651        // Save final checkpoint
652        self.save_checkpoint_now(&state_guard, &dag).await?;
653
654        info!(
655            "Resumed workflow execution completed: status={:?}",
656            state_guard.status
657        );
658
659        drop(state_guard);
660
661        Ok(Arc::try_unwrap(state_arc)
662            .map(|rw| rw.into_inner())
663            .unwrap_or_else(|arc| tokio::task::block_in_place(|| arc.blocking_read().clone())))
664    }
665
666    /// Resume a workflow from a specific checkpoint sequence.
667    pub async fn resume_from_sequence(
668        &self,
669        execution_id: String,
670        sequence: u64,
671    ) -> Result<WorkflowState> {
672        let persistence = self
673            .persistence
674            .as_ref()
675            .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;
676
677        let mut checkpoint = persistence
678            .load_checkpoint_by_sequence(&execution_id, sequence)
679            .await?;
680
681        if checkpoint.state.is_terminal() {
682            return Err(WorkflowError::state("Cannot resume a terminal workflow"));
683        }
684
685        info!(
686            "Resuming workflow from specific checkpoint: execution_id={}, sequence={}",
687            execution_id, sequence
688        );
689
690        // Prepare checkpoint for resumption
691        checkpoint.prepare_for_resume()?;
692
693        // Update checkpoint sequence
694        self.checkpoint_sequence
695            .store(sequence + 1, Ordering::SeqCst);
696
697        self.resume_from_checkpoint(checkpoint).await
698    }
699
700    /// Get recovery information for an execution.
701    pub async fn get_recovery_info(&self, execution_id: &str) -> Result<RecoveryInfo> {
702        let persistence = self
703            .persistence
704            .as_ref()
705            .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;
706
707        let checkpoint = persistence.load_checkpoint(execution_id).await?;
708
709        Ok(RecoveryInfo {
710            execution_id: execution_id.to_string(),
711            checkpoint_sequence: checkpoint.sequence,
712            checkpoint_created_at: checkpoint.created_at,
713            workflow_status: checkpoint.state.status,
714            completed_tasks: checkpoint.get_completed_tasks(),
715            pending_tasks: checkpoint.get_pending_tasks(),
716            interrupted_tasks: checkpoint.get_interrupted_tasks(),
717            failed_tasks: checkpoint.get_failed_tasks(),
718            skipped_tasks: checkpoint.get_skipped_tasks(),
719            can_resume: !checkpoint.state.is_terminal(),
720        })
721    }
722
723    /// List available checkpoints for an execution.
724    pub async fn list_checkpoints(&self, execution_id: &str) -> Result<Vec<u64>> {
725        let persistence = self
726            .persistence
727            .as_ref()
728            .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;
729
730        persistence.list_checkpoints(execution_id).await
731    }
732
733    /// Clean up old checkpoints, keeping only the latest N.
734    pub async fn cleanup_checkpoints(
735        &self,
736        execution_id: &str,
737        keep_count: usize,
738    ) -> Result<usize> {
739        let persistence = self
740            .persistence
741            .as_ref()
742            .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;
743
744        let checkpoints = persistence.list_checkpoints(execution_id).await?;
745
746        if checkpoints.len() <= keep_count {
747            return Ok(0);
748        }
749
750        let to_delete = checkpoints.len() - keep_count;
751        let mut deleted = 0;
752
753        for seq in checkpoints.iter().take(to_delete) {
754            if persistence
755                .delete_checkpoint(execution_id, *seq)
756                .await
757                .is_ok()
758            {
759                deleted += 1;
760            }
761        }
762
763        Ok(deleted)
764    }
765}
766
/// Information about workflow recovery state.
///
/// Produced by `WorkflowExecutor::get_recovery_info` from the latest
/// checkpoint of an execution.
#[derive(Debug, Clone)]
pub struct RecoveryInfo {
    /// Execution ID.
    pub execution_id: String,
    /// Latest checkpoint sequence number.
    pub checkpoint_sequence: u64,
    /// When the checkpoint was created.
    pub checkpoint_created_at: chrono::DateTime<chrono::Utc>,
    /// Workflow status recorded in the checkpoint.
    pub workflow_status: WorkflowStatus,
    /// Tasks that completed successfully.
    pub completed_tasks: Vec<String>,
    /// Tasks that are pending execution.
    pub pending_tasks: Vec<String>,
    /// Tasks that were interrupted (running when the checkpoint was saved).
    pub interrupted_tasks: Vec<String>,
    /// Tasks that failed.
    pub failed_tasks: Vec<String>,
    /// Tasks that were skipped.
    pub skipped_tasks: Vec<String>,
    /// Whether the workflow can be resumed (i.e. it is not terminal).
    pub can_resume: bool,
}
791
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::graph::{ResourceRequirements, RetryPolicy};
    use crate::engine::state::WorkflowStatus;
    use std::collections::HashMap;

    /// Trivial executor that always succeeds with a fixed JSON payload.
    struct DummyExecutor;

    #[async_trait]
    impl TaskExecutor for DummyExecutor {
        async fn execute(
            &self,
            _task: &TaskNode,
            _context: &ExecutionContext,
        ) -> Result<TaskOutput> {
            Ok(TaskOutput {
                data: Some(serde_json::json!({"result": "success"})),
                logs: vec!["Task executed".to_string()],
            })
        }
    }

    /// Build a minimal task node with default retry/resource settings.
    fn create_test_task(id: &str) -> TaskNode {
        TaskNode {
            id: id.to_string(),
            name: id.to_string(),
            description: None,
            config: serde_json::json!({}),
            retry: RetryPolicy::default(),
            timeout_secs: Some(60),
            resources: ResourceRequirements::default(),
            metadata: HashMap::new(),
        }
    }

    // End-to-end smoke test: a single-task workflow should run to completion
    // with the default configuration.
    #[tokio::test]
    async fn test_simple_workflow() {
        let mut dag = WorkflowDag::new();
        dag.add_task(create_test_task("task1")).ok();

        let executor = WorkflowExecutor::new(ExecutorConfig::default(), DummyExecutor);

        let result = executor
            .execute("wf1".to_string(), "exec1".to_string(), dag)
            .await;

        assert!(result.is_ok());
        let state = result.expect("Expected workflow state");
        assert_eq!(state.status, WorkflowStatus::Completed);
    }
}