1use crate::dag::{ResourcePool, TaskNode, WorkflowDag, create_execution_plan};
4use crate::engine::state::{
5 StatePersistence, TaskStatus, WorkflowCheckpoint, WorkflowState, WorkflowStatus,
6};
7use crate::error::{Result, WorkflowError};
8use async_trait::async_trait;
9use std::sync::Arc;
10use std::sync::atomic::{AtomicU64, Ordering};
11use std::time::Duration;
12use tokio::sync::{RwLock, Semaphore};
13use tokio::time::timeout;
14use tracing::{debug, error, info, warn};
15
/// Pluggable strategy for executing a single task of a workflow DAG.
///
/// Implementations are shared behind an `Arc` and may be used across
/// await points, hence the `Send + Sync` bound.
#[async_trait]
pub trait TaskExecutor: Send + Sync {
    /// Run `task` with the given per-invocation `context`, returning the
    /// task's output on success or an error describing the failure.
    async fn execute(&self, task: &TaskNode, context: &ExecutionContext) -> Result<TaskOutput>;
}
22
/// Per-invocation context handed to a `TaskExecutor::execute` call.
#[derive(Debug, Clone)]
pub struct ExecutionContext {
    /// Identifier of the workflow execution this task belongs to.
    pub execution_id: String,
    /// Identifier of the task currently being executed.
    pub task_id: String,
    /// Shared handle to the live workflow state.
    pub state: Arc<RwLock<WorkflowState>>,
    /// Outputs of this task's dependencies, keyed by dependency task id
    /// (only dependencies that produced output appear here).
    pub inputs: std::collections::HashMap<String, serde_json::Value>,
}
35
/// Result produced by a successful task execution.
#[derive(Debug, Clone)]
pub struct TaskOutput {
    /// Optional structured output; recorded on the task state and surfaced
    /// as input to dependent tasks.
    pub data: Option<serde_json::Value>,
    /// Log lines appended to the task's log in the workflow state.
    pub logs: Vec<String>,
}
44
/// Tuning knobs for `WorkflowExecutor`.
#[derive(Debug, Clone)]
pub struct ExecutorConfig {
    /// Sizes the executor's internal semaphore.
    pub max_concurrent_tasks: usize,
    /// Persist workflow state (and checkpoints) to disk.
    pub enable_persistence: bool,
    /// Directory where persisted state and checkpoints are written.
    pub state_dir: String,
    /// Resource pool carried in the config (not consulted in this module).
    pub resource_pool: ResourcePool,
    /// Retry failed tasks according to each task's retry policy.
    pub retry_on_failure: bool,
    /// Abort the workflow as soon as any level reports a failed task.
    pub stop_on_failure: bool,
    /// Number of completed tasks between automatic checkpoints.
    pub checkpoint_interval: usize,
    /// Master switch for checkpoint creation.
    pub enable_checkpointing: bool,
}
65
66impl Default for ExecutorConfig {
67 fn default() -> Self {
68 Self {
69 max_concurrent_tasks: 10,
70 enable_persistence: true,
71 state_dir: "/tmp/oxigdal-workflow".to_string(),
72 resource_pool: ResourcePool::default(),
73 retry_on_failure: true,
74 stop_on_failure: false,
75 checkpoint_interval: 1, enable_checkpointing: true,
77 }
78 }
79}
80
/// Drives a `WorkflowDag` to completion using a pluggable `TaskExecutor`,
/// with optional state persistence and checkpoint-based crash recovery.
pub struct WorkflowExecutor<E: TaskExecutor> {
    // Behavioral configuration (concurrency limit, persistence, retries).
    config: ExecutorConfig,
    // Shared strategy that actually runs individual tasks.
    task_executor: Arc<E>,
    // Present only when `config.enable_persistence` is true.
    persistence: Option<StatePersistence>,
    // Sized from `max_concurrent_tasks` but never acquired anywhere in
    // this file: levels currently run their tasks sequentially.
    _semaphore: Arc<Semaphore>,
    // Monotonic sequence number for the next checkpoint to write.
    checkpoint_sequence: AtomicU64,
    // Tasks completed since the last checkpoint, compared against
    // `config.checkpoint_interval`.
    tasks_since_checkpoint: AtomicU64,
}
96
97impl<E: TaskExecutor> WorkflowExecutor<E> {
98 pub fn new(config: ExecutorConfig, task_executor: E) -> Self {
100 let semaphore = Arc::new(Semaphore::new(config.max_concurrent_tasks));
101 let persistence = if config.enable_persistence {
102 Some(StatePersistence::new(config.state_dir.clone()))
103 } else {
104 None
105 };
106
107 Self {
108 config,
109 task_executor: Arc::new(task_executor),
110 persistence,
111 _semaphore: semaphore,
112 checkpoint_sequence: AtomicU64::new(0),
113 tasks_since_checkpoint: AtomicU64::new(0),
114 }
115 }
116
    /// Count one finished task and persist a checkpoint once
    /// `checkpoint_interval` tasks have completed since the last one.
    ///
    /// No-op when checkpointing is disabled or persistence is absent.
    async fn maybe_save_checkpoint(&self, state: &WorkflowState, dag: &WorkflowDag) -> Result<()> {
        if !self.config.enable_checkpointing {
            return Ok(());
        }

        let persistence = match &self.persistence {
            Some(p) => p,
            None => return Ok(()),
        };

        // fetch_add returns the previous value, so +1 is the count
        // including the task that just finished.
        let tasks_completed = self.tasks_since_checkpoint.fetch_add(1, Ordering::SeqCst) + 1;

        if tasks_completed >= self.config.checkpoint_interval as u64 {
            self.tasks_since_checkpoint.store(0, Ordering::SeqCst);
            let seq = self.checkpoint_sequence.fetch_add(1, Ordering::SeqCst);

            // Snapshot both state and DAG so a resume needs nothing but
            // the checkpoint itself.
            let checkpoint = WorkflowCheckpoint::new(state.clone(), dag.clone(), seq);
            persistence.save_checkpoint(&checkpoint).await?;

            debug!(
                "Saved checkpoint {} for execution {}",
                seq, state.execution_id
            );
        }

        Ok(())
    }
145
146 async fn save_checkpoint_now(&self, state: &WorkflowState, dag: &WorkflowDag) -> Result<()> {
148 if !self.config.enable_checkpointing {
149 return Ok(());
150 }
151
152 let persistence = match &self.persistence {
153 Some(p) => p,
154 None => return Ok(()),
155 };
156
157 self.tasks_since_checkpoint.store(0, Ordering::SeqCst);
158 let seq = self.checkpoint_sequence.fetch_add(1, Ordering::SeqCst);
159
160 let checkpoint = WorkflowCheckpoint::new(state.clone(), dag.clone(), seq);
161 persistence.save_checkpoint(&checkpoint).await?;
162
163 info!(
164 "Saved checkpoint {} for execution {}",
165 seq, state.execution_id
166 );
167 Ok(())
168 }
169
    /// Validate `dag` and run it to completion, returning the final state.
    ///
    /// Initializes per-task state, persists the initial state plus a
    /// checkpoint, then executes the DAG level by level. After each level a
    /// periodic checkpoint may be taken. With `stop_on_failure` set, the
    /// first level containing a failed task marks the workflow failed and
    /// returns early; otherwise failures are tolerated and the final status
    /// is derived from the task states at the end.
    pub async fn execute(
        &self,
        workflow_id: String,
        execution_id: String,
        dag: WorkflowDag,
    ) -> Result<WorkflowState> {
        info!(
            "Starting workflow execution: workflow_id={}, execution_id={}",
            workflow_id, execution_id
        );

        dag.validate()?;

        let mut state = WorkflowState::new(workflow_id.clone(), execution_id.clone(), workflow_id);

        // Register every task up front so task_states covers the whole DAG.
        for task in dag.tasks() {
            state.init_task(task.id.clone());
        }

        state.start();

        if let Some(ref persistence) = self.persistence {
            persistence.save(&state).await?;
        }

        // Checkpoint the pristine state so even a crash before the first
        // task completes leaves something to resume from.
        self.save_checkpoint_now(&state, &dag).await?;

        let state_arc = Arc::new(RwLock::new(state));

        let execution_plan = create_execution_plan(&dag)?;

        info!(
            "Execution plan created with {} levels",
            execution_plan.len()
        );

        for (level_idx, level) in execution_plan.iter().enumerate() {
            info!("Executing level {} with {} tasks", level_idx, level.len());

            let results = self.execute_level(&dag, &state_arc, level).await;

            // Periodic checkpoint under a read lock; maybe_save_checkpoint
            // decides whether the interval has been reached.
            {
                let state_guard = state_arc.read().await;
                self.maybe_save_checkpoint(&state_guard, &dag).await?;
            }

            let failed_tasks: Vec<_> = results
                .iter()
                .filter_map(|(task_id, result)| {
                    if result.is_err() {
                        Some(task_id.clone())
                    } else {
                        None
                    }
                })
                .collect();

            if !failed_tasks.is_empty() {
                error!("Tasks failed: {:?}", failed_tasks);

                if self.config.stop_on_failure {
                    warn!("Stopping workflow execution due to failures");
                    let mut state_guard = state_arc.write().await;
                    state_guard.fail();

                    if let Some(ref persistence) = self.persistence {
                        persistence.save(&state_guard).await?;
                    }

                    self.save_checkpoint_now(&state_guard, &dag).await?;

                    drop(state_guard);

                    // NOTE(review): the unwrap_or_else fallback calls
                    // block_in_place + blocking_read, which panic on a
                    // current-thread tokio runtime. It only fires if another
                    // clone of the state Arc is still alive — confirm
                    // executors drop their ExecutionContext before returning.
                    return Ok(Arc::try_unwrap(state_arc)
                        .map(|rw| rw.into_inner())
                        .unwrap_or_else(|arc| {
                            tokio::task::block_in_place(|| arc.blocking_read().clone())
                        }));
                }
            }
        }

        let mut state_guard = state_arc.write().await;

        // The workflow completes only if every task finished or was
        // deliberately skipped; any other status fails the workflow.
        let all_completed = state_guard
            .task_states
            .values()
            .all(|ts| ts.status == TaskStatus::Completed || ts.status == TaskStatus::Skipped);

        if all_completed {
            state_guard.complete();
        } else {
            state_guard.fail();
        }

        if let Some(ref persistence) = self.persistence {
            persistence.save(&state_guard).await?;
        }

        // Final checkpoint records the terminal status.
        self.save_checkpoint_now(&state_guard, &dag).await?;

        info!(
            "Workflow execution completed: status={:?}",
            state_guard.status
        );

        drop(state_guard);

        Ok(Arc::try_unwrap(state_arc)
            .map(|rw| rw.into_inner())
            .unwrap_or_else(|arc| tokio::task::block_in_place(|| arc.blocking_read().clone())))
    }
297
    /// Run every task of one DAG level, collecting `(task_id, result)`
    /// pairs in level order.
    ///
    /// Tasks are awaited one at a time; `_semaphore` /
    /// `max_concurrent_tasks` are not consulted here, so execution within a
    /// level is currently sequential.
    async fn execute_level(
        &self,
        dag: &WorkflowDag,
        state: &Arc<RwLock<WorkflowState>>,
        level: &[String],
    ) -> Vec<(String, Result<()>)> {
        let mut results = Vec::new();

        for task_id in level {
            let result = self
                .execute_task(
                    task_id,
                    dag,
                    state,
                    &*self.task_executor,
                    self.config.retry_on_failure,
                )
                .await;
            results.push((task_id.clone(), result));
        }

        results
    }
322
    /// Run a single task with retry, backoff and timeout handling.
    ///
    /// If any dependency did not complete, the task is marked skipped and
    /// `Ok(())` is returned (skipping is not an error). Otherwise the task
    /// is started once, then attempted up to `max_attempts` times (1 when
    /// retries are disabled). Between attempts an exponential backoff delay
    /// is applied, capped at `retry.max_delay_ms`; each attempt is bounded
    /// by the task's timeout (300s when unset). On success, output and logs
    /// are recorded; after the final failure the task is marked failed and
    /// the last error is propagated.
    async fn execute_task(
        &self,
        task_id: &str,
        dag: &WorkflowDag,
        state: &Arc<RwLock<WorkflowState>>,
        executor: &E,
        retry_on_failure: bool,
    ) -> Result<()> {
        let task = dag
            .get_task(task_id)
            .ok_or_else(|| WorkflowError::not_found(format!("Task '{}'", task_id)))?;

        debug!("Executing task: {}", task_id);

        if !self.check_dependencies(task_id, dag, state).await? {
            warn!("Skipping task {} due to failed dependencies", task_id);
            let mut state_guard = state.write().await;
            state_guard.skip_task(task_id)?;
            return Ok(());
        }

        // Recorded once, before the retry loop: retries do not re-start the
        // task in the state machine.
        {
            let mut state_guard = state.write().await;
            state_guard.start_task(task_id)?;
        }

        let max_attempts = if retry_on_failure {
            task.retry.max_attempts
        } else {
            1
        };

        let mut last_error = None;

        // NOTE(review): if retry.max_attempts is 0 this loop never runs and
        // the task fails with "Unknown error" — confirm the retry policy
        // forbids 0.
        for attempt in 0..max_attempts {
            if attempt > 0 {
                debug!("Retrying task {} (attempt {})", task_id, attempt + 1);

                // Exponential backoff: delay_ms * multiplier^attempt, capped
                // at max_delay_ms.
                let delay_ms =
                    task.retry.delay_ms as f64 * task.retry.backoff_multiplier.powi(attempt as i32);
                let delay_ms = delay_ms.min(task.retry.max_delay_ms as f64) as u64;

                tokio::time::sleep(Duration::from_millis(delay_ms)).await;
            }

            // Inputs are re-gathered on every attempt in case dependency
            // state changed between retries.
            let inputs = self.gather_inputs(task_id, dag, state).await?;

            let ctx = ExecutionContext {
                execution_id: {
                    let state_guard = state.read().await;
                    state_guard.execution_id.clone()
                },
                task_id: task_id.to_string(),
                state: Arc::clone(state),
                inputs,
            };

            // Wall-clock bound per attempt; 300s when the task sets none.
            let task_timeout = Duration::from_secs(task.timeout_secs.unwrap_or(300));
            let execute_result = timeout(task_timeout, executor.execute(task, &ctx)).await;

            match execute_result {
                Ok(Ok(output)) => {
                    let mut state_guard = state.write().await;
                    state_guard.complete_task(task_id, output.data)?;

                    for log in output.logs {
                        state_guard.add_task_log(task_id, log)?;
                    }

                    info!("Task {} completed successfully", task_id);
                    return Ok(());
                }
                Ok(Err(e)) => {
                    warn!("Task {} failed: {}", task_id, e);
                    last_error = Some(e);
                }
                Err(_) => {
                    // Outer Err means the timeout elapsed before the
                    // executor finished.
                    let timeout_error =
                        WorkflowError::task_timeout(task_id, task_timeout.as_secs());
                    warn!("Task {} timed out", task_id);
                    last_error = Some(timeout_error);
                }
            }
        }

        // All attempts exhausted: record the failure on the task state and
        // propagate the last observed error.
        let error = last_error.unwrap_or_else(|| WorkflowError::execution("Unknown error"));
        let mut state_guard = state.write().await;
        state_guard.fail_task(task_id, error.to_string())?;

        error!("Task {} failed after {} attempts", task_id, max_attempts);
        Err(error)
    }
424
425 async fn check_dependencies(
427 &self,
428 task_id: &str,
429 dag: &WorkflowDag,
430 state: &Arc<RwLock<WorkflowState>>,
431 ) -> Result<bool> {
432 let dependencies = dag.get_dependencies(task_id);
433 let state_guard = state.read().await;
434
435 for dep_id in dependencies {
436 if let Some(dep_state) = state_guard.get_task_state(&dep_id) {
437 if dep_state.status != TaskStatus::Completed {
438 return Ok(false);
439 }
440 } else {
441 return Ok(false);
442 }
443 }
444
445 Ok(true)
446 }
447
448 async fn gather_inputs(
450 &self,
451 task_id: &str,
452 dag: &WorkflowDag,
453 state: &Arc<RwLock<WorkflowState>>,
454 ) -> Result<std::collections::HashMap<String, serde_json::Value>> {
455 let dependencies = dag.get_dependencies(task_id);
456 let state_guard = state.read().await;
457 let mut inputs = std::collections::HashMap::new();
458
459 for dep_id in dependencies {
460 if let Some(dep_state) = state_guard.get_task_state(&dep_id) {
461 if let Some(ref output) = dep_state.output {
462 inputs.insert(dep_id.clone(), output.clone());
463 }
464 }
465 }
466
467 Ok(inputs)
468 }
469
470 pub async fn resume(&self, execution_id: String) -> Result<WorkflowState> {
479 let persistence = self
480 .persistence
481 .as_ref()
482 .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;
483
484 let mut checkpoint = persistence.load_checkpoint(&execution_id).await.map_err(|e| {
486 WorkflowError::state(format!(
487 "Failed to load checkpoint for recovery: {}. Ensure checkpointing was enabled during execution.",
488 e
489 ))
490 })?;
491
492 if checkpoint.state.is_terminal() {
493 return Err(WorkflowError::state("Cannot resume a terminal workflow"));
494 }
495
496 info!(
497 "Resuming workflow execution: execution_id={}, checkpoint_sequence={}",
498 execution_id, checkpoint.sequence
499 );
500
501 checkpoint.prepare_for_resume()?;
503
504 self.checkpoint_sequence
506 .store(checkpoint.sequence + 1, Ordering::SeqCst);
507
508 self.resume_from_checkpoint(checkpoint).await
510 }
511
    /// Re-execute a workflow from a checkpoint that has already been
    /// prepared for resume: completed/skipped tasks are left alone, every
    /// other task in each level is run again. Mirrors `execute` from the
    /// level loop onward.
    async fn resume_from_checkpoint(
        &self,
        checkpoint: WorkflowCheckpoint,
    ) -> Result<WorkflowState> {
        let dag = checkpoint.dag.clone();

        // Task buckets are computed up front for logging; the actual skip
        // decision below is made per level against the live state.
        let completed = checkpoint.get_completed_tasks();
        let pending = checkpoint.get_pending_tasks();
        let interrupted = checkpoint.get_interrupted_tasks();
        let failed = checkpoint.get_failed_tasks();

        let mut state = checkpoint.state;

        // The checkpoint may have been taken while not Running; force the
        // status back so task transitions are legal again.
        if state.status != WorkflowStatus::Running {
            state.status = WorkflowStatus::Running;
        }

        info!(
            "Recovery state: {} completed, {} pending, {} interrupted, {} failed",
            completed.len(),
            pending.len(),
            interrupted.len(),
            failed.len()
        );

        if let Some(ref persistence) = self.persistence {
            persistence.save(&state).await?;
        }

        let state_arc = Arc::new(RwLock::new(state));

        let execution_plan = create_execution_plan(&dag)?;

        info!("Resuming execution with {} levels", execution_plan.len());

        for (level_idx, level) in execution_plan.iter().enumerate() {
            // Re-run only tasks that are not Completed/Skipped; a task with
            // no recorded state defaults to being executed.
            let tasks_to_execute: Vec<String> = {
                let state_guard = state_arc.read().await;
                level
                    .iter()
                    .filter(|task_id| {
                        state_guard
                            .get_task_state(task_id)
                            .map(|ts| {
                                !matches!(ts.status, TaskStatus::Completed | TaskStatus::Skipped)
                            })
                            .unwrap_or(true)
                    })
                    .cloned()
                    .collect()
            };

            if tasks_to_execute.is_empty() {
                debug!("Level {} has no tasks to execute, skipping", level_idx);
                continue;
            }

            info!(
                "Resuming level {} with {} tasks (skipping {} completed)",
                level_idx,
                tasks_to_execute.len(),
                level.len() - tasks_to_execute.len()
            );

            let results = self
                .execute_level(&dag, &state_arc, &tasks_to_execute)
                .await;

            // Periodic checkpoint, same cadence logic as in `execute`.
            {
                let state_guard = state_arc.read().await;
                self.maybe_save_checkpoint(&state_guard, &dag).await?;
            }

            let failed_tasks: Vec<_> = results
                .iter()
                .filter_map(|(task_id, result)| {
                    if result.is_err() {
                        Some(task_id.clone())
                    } else {
                        None
                    }
                })
                .collect();

            if !failed_tasks.is_empty() {
                error!("Tasks failed during resume: {:?}", failed_tasks);

                if self.config.stop_on_failure {
                    warn!("Stopping resumed workflow execution due to failures");
                    let mut state_guard = state_arc.write().await;
                    state_guard.fail();

                    if let Some(ref persistence) = self.persistence {
                        persistence.save(&state_guard).await?;
                    }

                    self.save_checkpoint_now(&state_guard, &dag).await?;

                    drop(state_guard);

                    // NOTE(review): block_in_place + blocking_read panic on
                    // a current-thread runtime; the fallback fires only if a
                    // clone of the state Arc is still alive — confirm.
                    return Ok(Arc::try_unwrap(state_arc)
                        .map(|rw| rw.into_inner())
                        .unwrap_or_else(|arc| {
                            tokio::task::block_in_place(|| arc.blocking_read().clone())
                        }));
                }
            }
        }

        let mut state_guard = state_arc.write().await;

        // Same completion rule as `execute`: everything Completed or
        // Skipped means success, anything else fails the workflow.
        let all_completed = state_guard
            .task_states
            .values()
            .all(|ts| ts.status == TaskStatus::Completed || ts.status == TaskStatus::Skipped);

        if all_completed {
            state_guard.complete();
        } else {
            state_guard.fail();
        }

        if let Some(ref persistence) = self.persistence {
            persistence.save(&state_guard).await?;
        }

        self.save_checkpoint_now(&state_guard, &dag).await?;

        info!(
            "Resumed workflow execution completed: status={:?}",
            state_guard.status
        );

        drop(state_guard);

        Ok(Arc::try_unwrap(state_arc)
            .map(|rw| rw.into_inner())
            .unwrap_or_else(|arc| tokio::task::block_in_place(|| arc.blocking_read().clone())))
    }
665
666 pub async fn resume_from_sequence(
668 &self,
669 execution_id: String,
670 sequence: u64,
671 ) -> Result<WorkflowState> {
672 let persistence = self
673 .persistence
674 .as_ref()
675 .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;
676
677 let mut checkpoint = persistence
678 .load_checkpoint_by_sequence(&execution_id, sequence)
679 .await?;
680
681 if checkpoint.state.is_terminal() {
682 return Err(WorkflowError::state("Cannot resume a terminal workflow"));
683 }
684
685 info!(
686 "Resuming workflow from specific checkpoint: execution_id={}, sequence={}",
687 execution_id, sequence
688 );
689
690 checkpoint.prepare_for_resume()?;
692
693 self.checkpoint_sequence
695 .store(sequence + 1, Ordering::SeqCst);
696
697 self.resume_from_checkpoint(checkpoint).await
698 }
699
    /// Summarize the latest checkpoint of `execution_id` without resuming:
    /// checkpoint metadata, per-status task buckets, and whether a resume
    /// is possible (i.e. the workflow is not terminal).
    pub async fn get_recovery_info(&self, execution_id: &str) -> Result<RecoveryInfo> {
        let persistence = self
            .persistence
            .as_ref()
            .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;

        let checkpoint = persistence.load_checkpoint(execution_id).await?;

        Ok(RecoveryInfo {
            execution_id: execution_id.to_string(),
            checkpoint_sequence: checkpoint.sequence,
            checkpoint_created_at: checkpoint.created_at,
            workflow_status: checkpoint.state.status,
            completed_tasks: checkpoint.get_completed_tasks(),
            pending_tasks: checkpoint.get_pending_tasks(),
            interrupted_tasks: checkpoint.get_interrupted_tasks(),
            failed_tasks: checkpoint.get_failed_tasks(),
            skipped_tasks: checkpoint.get_skipped_tasks(),
            // Only non-terminal workflows can be resumed.
            can_resume: !checkpoint.state.is_terminal(),
        })
    }
722
723 pub async fn list_checkpoints(&self, execution_id: &str) -> Result<Vec<u64>> {
725 let persistence = self
726 .persistence
727 .as_ref()
728 .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;
729
730 persistence.list_checkpoints(execution_id).await
731 }
732
733 pub async fn cleanup_checkpoints(
735 &self,
736 execution_id: &str,
737 keep_count: usize,
738 ) -> Result<usize> {
739 let persistence = self
740 .persistence
741 .as_ref()
742 .ok_or_else(|| WorkflowError::state("Persistence is not enabled"))?;
743
744 let checkpoints = persistence.list_checkpoints(execution_id).await?;
745
746 if checkpoints.len() <= keep_count {
747 return Ok(0);
748 }
749
750 let to_delete = checkpoints.len() - keep_count;
751 let mut deleted = 0;
752
753 for seq in checkpoints.iter().take(to_delete) {
754 if persistence
755 .delete_checkpoint(execution_id, *seq)
756 .await
757 .is_ok()
758 {
759 deleted += 1;
760 }
761 }
762
763 Ok(deleted)
764 }
765}
766
/// Snapshot of a checkpoint's recovery-relevant facts, as returned by
/// `WorkflowExecutor::get_recovery_info`.
#[derive(Debug, Clone)]
pub struct RecoveryInfo {
    /// Execution the checkpoint belongs to.
    pub execution_id: String,
    /// Sequence number of the inspected checkpoint.
    pub checkpoint_sequence: u64,
    /// When the checkpoint was written.
    pub checkpoint_created_at: chrono::DateTime<chrono::Utc>,
    /// Workflow status recorded in the checkpoint.
    pub workflow_status: WorkflowStatus,
    /// Tasks that had completed at checkpoint time.
    pub completed_tasks: Vec<String>,
    /// Tasks not yet started at checkpoint time.
    pub pending_tasks: Vec<String>,
    /// Tasks that were in flight when the checkpoint was taken.
    pub interrupted_tasks: Vec<String>,
    /// Tasks that had failed at checkpoint time.
    pub failed_tasks: Vec<String>,
    /// Tasks that were skipped at checkpoint time.
    pub skipped_tasks: Vec<String>,
    /// True when the checkpointed workflow is not terminal.
    pub can_resume: bool,
}
791
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::graph::{ResourceRequirements, RetryPolicy};
    use crate::engine::state::WorkflowStatus;
    use std::collections::HashMap;

    // Executor stub that always succeeds with a fixed JSON payload.
    struct DummyExecutor;

    #[async_trait]
    impl TaskExecutor for DummyExecutor {
        async fn execute(
            &self,
            _task: &TaskNode,
            _context: &ExecutionContext,
        ) -> Result<TaskOutput> {
            Ok(TaskOutput {
                data: Some(serde_json::json!({"result": "success"})),
                logs: vec!["Task executed".to_string()],
            })
        }
    }

    // Build a minimal task with default retry/resource settings and a
    // 60-second timeout.
    fn create_test_task(id: &str) -> TaskNode {
        TaskNode {
            id: id.to_string(),
            name: id.to_string(),
            description: None,
            config: serde_json::json!({}),
            retry: RetryPolicy::default(),
            timeout_secs: Some(60),
            resources: ResourceRequirements::default(),
            metadata: HashMap::new(),
        }
    }

    // End-to-end: a single-task DAG runs to Completed under the default
    // config (note: defaults persist state/checkpoints under /tmp).
    #[tokio::test]
    async fn test_simple_workflow() {
        let mut dag = WorkflowDag::new();
        dag.add_task(create_test_task("task1")).ok();

        let executor = WorkflowExecutor::new(ExecutorConfig::default(), DummyExecutor);

        let result = executor
            .execute("wf1".to_string(), "exec1".to_string(), dag)
            .await;

        assert!(result.is_ok());
        let state = result.expect("Expected workflow state");
        assert_eq!(state.status, WorkflowStatus::Completed);
    }
}