use std::{
    collections::{HashMap, VecDeque},
    sync::Arc,
    time::{Duration, Instant},
};

use tokio::{
    sync::{mpsc, Mutex, RwLock},
    time::{sleep, timeout},
};
use tracing::{debug, error, info, instrument, warn};

use crate::workflow::{
    RollbackStrategy, StageId, Version,
    WorkflowContext, WorkflowDefinition, WorkflowError, WorkflowEvent,
    WorkflowHandle, WorkflowId, WorkflowMetrics, WorkflowResult, WorkflowStage,
    WorkflowStatus, WorkflowRegistry, StateStore,
};

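/// Configuration for the [`WorkflowEngine`].
///
/// A minimal sketch of overriding a single knob while keeping the other
/// defaults (field names as declared below):
///
/// ```ignore
/// let config = WorkflowEngineConfig {
///     worker_count: 8,
///     ..WorkflowEngineConfig::default()
/// };
/// ```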
#[derive(Debug, Clone)]
pub struct WorkflowEngineConfig {
    /// Maximum number of workflow instances that may run concurrently.
    pub max_concurrent_workflows: usize,
    /// Timeout applied to a stage when neither the stage nor the workflow
    /// definition specifies one.
    pub default_timeout: Duration,
    /// Whether tracing spans are emitted for workflow execution.
    pub enable_tracing: bool,
    /// How often running workflows are checkpointed to the state store.
    pub checkpoint_interval: Duration,
    /// Maximum number of retries for system-level failures.
    pub max_system_retries: u32,
    /// Number of worker tasks that drain the event queue.
    pub worker_count: usize,
}

impl Default for WorkflowEngineConfig {
    fn default() -> Self {
        Self {
            max_concurrent_workflows: 1000,
            default_timeout: Duration::from_secs(300),
            enable_tracing: true,
            checkpoint_interval: Duration::from_secs(10),
            max_system_retries: 3,
            worker_count: 4,
        }
    }
}

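/// Drives workflow execution: worker tasks drain a shared event queue, a
/// background task periodically checkpoints running executors, and completed
/// or failed workflows are removed from the executor map.
///
/// A minimal end-to-end sketch, mirroring the test at the bottom of this file
/// (`InMemoryStateStore`, `load_defaults`, and the `"nat_traversal_basic"`
/// workflow id are taken from that test, not guaranteed elsewhere):
///
/// ```ignore
/// let registry = Arc::new(WorkflowRegistry::new());
/// registry.load_defaults().await?;
///
/// let engine = WorkflowEngine::new(
///     WorkflowEngineConfig::default(),
///     registry,
///     Arc::new(InMemoryStateStore::new()),
/// );
/// engine.start().await?;
///
/// let handle = engine.start_workflow(
///     "nat_traversal_basic",
///     &Version { major: 1, minor: 0, patch: 0 },
///     HashMap::new(),
/// ).await?;
/// ```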
pub struct WorkflowEngine {
    config: WorkflowEngineConfig,
    /// Registry of known workflow definitions, keyed by id and version.
    registry: Arc<WorkflowRegistry>,
    /// Persistent store used for checkpointing and resuming workflows.
    state_store: Arc<dyn StateStore>,
    /// Executors for currently running workflow instances.
    executors: Arc<RwLock<HashMap<WorkflowId, WorkflowExecutor>>>,
    /// Queue of pending (workflow, event) pairs consumed by the workers.
    event_queue: Arc<Mutex<VecDeque<(WorkflowId, WorkflowEvent)>>>,
    /// Channel used to signal shutdown to the background tasks.
    shutdown_tx: mpsc::Sender<()>,
    shutdown_rx: Arc<Mutex<mpsc::Receiver<()>>>,
}

impl WorkflowEngine {
    pub fn new(
        config: WorkflowEngineConfig,
        registry: Arc<WorkflowRegistry>,
        state_store: Arc<dyn StateStore>,
    ) -> Self {
        let (shutdown_tx, shutdown_rx) = mpsc::channel(1);

        Self {
            config,
            registry,
            state_store,
            executors: Arc::new(RwLock::new(HashMap::new())),
            event_queue: Arc::new(Mutex::new(VecDeque::new())),
            shutdown_tx,
            shutdown_rx: Arc::new(Mutex::new(shutdown_rx)),
        }
    }

    pub async fn start(&self) -> Result<(), WorkflowError> {
        info!("Starting workflow engine with {} workers", self.config.worker_count);

        for worker_id in 0..self.config.worker_count {
            let engine = self.clone();
            tokio::spawn(async move {
                engine.worker_loop(worker_id).await;
            });
        }

        let engine = self.clone();
        tokio::spawn(async move {
            engine.checkpoint_loop().await;
        });

        Ok(())
    }

    pub async fn stop(&self) -> Result<(), WorkflowError> {
        info!("Stopping workflow engine");

        let _ = self.shutdown_tx.send(()).await;

        let timeout_duration = Duration::from_secs(30);
        let start = Instant::now();

        loop {
            let executors = self.executors.read().await;
            if executors.is_empty() {
                break;
            }

            if start.elapsed() > timeout_duration {
                warn!("Timeout waiting for workflows to complete");
                break;
            }

            drop(executors);
            sleep(Duration::from_millis(100)).await;
        }

        Ok(())
    }

    /// Starts a new instance of the registered workflow `workflow_id` at `version`.
    #[instrument(skip(self, input))]
    pub async fn start_workflow(
        &self,
        workflow_id: &str,
        version: &Version,
        input: HashMap<String, Vec<u8>>,
    ) -> Result<WorkflowHandle, WorkflowError> {
        let definition = self.registry.get(workflow_id, version).await
            .ok_or_else(|| WorkflowError {
                code: "WORKFLOW_NOT_FOUND".to_string(),
                message: format!("Workflow {} version {} not found", workflow_id, version),
                stage: None,
                trace: None,
                recovery_hints: vec!["Check workflow ID and version".to_string()],
            })?;

        let executors = self.executors.read().await;
        if executors.len() >= self.config.max_concurrent_workflows {
            return Err(WorkflowError {
                code: "MAX_WORKFLOWS_REACHED".to_string(),
                message: "Maximum concurrent workflows reached".to_string(),
                stage: None,
                trace: None,
                recovery_hints: vec!["Wait for existing workflows to complete".to_string()],
            });
        }
        drop(executors);

        let instance_id = WorkflowId::generate();

        let (event_tx, event_rx) = mpsc::channel(100);
        let handle = WorkflowHandle::new(instance_id, event_tx);

        let executor = WorkflowExecutor::new(
            instance_id,
            definition,
            input,
            event_rx,
            handle.clone(),
            self.state_store.clone(),
            self.config.clone(),
        );

        let mut executors = self.executors.write().await;
        executors.insert(instance_id, executor);

        self.event_queue.lock().await.push_back((instance_id, WorkflowEvent::Start));

        info!("Started workflow {} instance {}", workflow_id, instance_id);
        Ok(handle)
    }

    /// Resumes a previously persisted workflow instance from the state store.
    pub async fn resume_workflow(
        &self,
        instance_id: WorkflowId,
    ) -> Result<WorkflowHandle, WorkflowError> {
        let state = self.state_store.load(&instance_id).await?;

        let definition = self.registry.get(&state.workflow_id, &state.version).await
            .ok_or_else(|| WorkflowError {
                code: "WORKFLOW_NOT_FOUND".to_string(),
                message: format!("Workflow {} version {} not found", state.workflow_id, state.version),
                stage: None,
                trace: None,
                recovery_hints: vec!["Check workflow ID and version".to_string()],
            })?;

        let (event_tx, event_rx) = mpsc::channel(100);
        let handle = WorkflowHandle::new(instance_id, event_tx);

        let mut executor = WorkflowExecutor::new(
            instance_id,
            definition,
            state.input.clone(),
            event_rx,
            handle.clone(),
            self.state_store.clone(),
            self.config.clone(),
        );

        executor.restore_state(state).await?;

        let mut executors = self.executors.write().await;
        executors.insert(instance_id, executor);

        info!("Resumed workflow instance {}", instance_id);
        Ok(handle)
    }

    // The span is attached via `#[instrument]` rather than a manually entered
    // guard: holding `span.enter()` across `.await` points would make this
    // future non-`Send` and break `tokio::spawn`.
    #[instrument(name = "workflow_worker", level = "debug", skip(self))]
    async fn worker_loop(&self, worker_id: usize) {
        debug!("Worker {} started", worker_id);

        loop {
            if self.shutdown_rx.lock().await.try_recv().is_ok() {
                debug!("Worker {} shutting down", worker_id);
                break;
            }

            let event = {
                let mut queue = self.event_queue.lock().await;
                queue.pop_front()
            };

            if let Some((workflow_id, event)) = event {
                if let Err(e) = self.process_event(workflow_id, event).await {
                    error!("Error processing event for workflow {}: {:?}", workflow_id, e);
                }
            } else {
                sleep(Duration::from_millis(10)).await;
            }
        }
    }

    async fn process_event(
        &self,
        workflow_id: WorkflowId,
        event: WorkflowEvent,
    ) -> Result<(), WorkflowError> {
        let mut executors = self.executors.write().await;

        if let Some(executor) = executors.get_mut(&workflow_id) {
            executor.process_event(event).await?;

            let status = executor.handle.status().await;
            match status {
                WorkflowStatus::Completed { .. }
                | WorkflowStatus::Failed { .. }
                | WorkflowStatus::Cancelled => {
                    executors.remove(&workflow_id);
                    info!("Workflow {} completed with status: {:?}", workflow_id, status);
                }
                _ => {}
            }
        }

        Ok(())
    }

    async fn checkpoint_loop(&self) {
        let mut interval = tokio::time::interval(self.config.checkpoint_interval);

        loop {
            interval.tick().await;

            if self.shutdown_rx.lock().await.try_recv().is_ok() {
                break;
            }

            let executors = self.executors.read().await;
            for (id, executor) in executors.iter() {
                if let Err(e) = executor.checkpoint().await {
                    error!("Failed to checkpoint workflow {}: {:?}", id, e);
                }
            }
        }
    }
}

impl Clone for WorkflowEngine {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            registry: self.registry.clone(),
            state_store: self.state_store.clone(),
            executors: self.executors.clone(),
            event_queue: self.event_queue.clone(),
            shutdown_tx: self.shutdown_tx.clone(),
            shutdown_rx: self.shutdown_rx.clone(),
        }
    }
}

/// Per-instance state machine that drives a single workflow to completion.
struct WorkflowExecutor {
    id: WorkflowId,
    definition: WorkflowDefinition,
    /// Mutable execution context (current stage, state blob, metrics).
    context: WorkflowContext,
    /// Receiver for events sent through the workflow's handle.
    event_rx: mpsc::Receiver<WorkflowEvent>,
    handle: WorkflowHandle,
    state_store: Arc<dyn StateStore>,
    config: WorkflowEngineConfig,
    /// Retry counters per stage, reset when the stage completes.
    retry_attempts: HashMap<StageId, u32>,
    start_time: Instant,
    last_checkpoint: Instant,
}

impl WorkflowExecutor {
    fn new(
        id: WorkflowId,
        definition: WorkflowDefinition,
        input: HashMap<String, Vec<u8>>,
        event_rx: mpsc::Receiver<WorkflowEvent>,
        handle: WorkflowHandle,
        state_store: Arc<dyn StateStore>,
        config: WorkflowEngineConfig,
    ) -> Self {
        let context = WorkflowContext {
            workflow_id: id,
            current_stage: definition.initial_stage.clone(),
            state: input,
            metrics: WorkflowMetrics::default(),
            stage_start: Instant::now(),
        };

        Self {
            id,
            definition,
            context,
            event_rx,
            handle,
            state_store,
            config,
            retry_attempts: HashMap::new(),
            start_time: Instant::now(),
            last_checkpoint: Instant::now(),
        }
    }

    async fn process_event(&mut self, event: WorkflowEvent) -> Result<(), WorkflowError> {
        debug!("Processing event {:?} for workflow {}", event, self.id);

        match event {
            WorkflowEvent::Start => {
                self.handle.update_status(WorkflowStatus::Running {
                    current_stage: self.definition.initial_stage.clone(),
                }).await;
                self.execute_stage(self.definition.initial_stage.clone()).await?;
            }
            WorkflowEvent::StageCompleted { stage_id } => {
                self.handle_stage_completion(stage_id).await?;
            }
            WorkflowEvent::StageFailed { stage_id, error } => {
                self.handle_stage_failure(stage_id, error).await?;
            }
            WorkflowEvent::Timeout { stage_id } => {
                self.handle_stage_timeout(stage_id).await?;
            }
            WorkflowEvent::Cancel => {
                self.handle_cancellation().await?;
            }
            _ => {}
        }

        Ok(())
    }

    async fn execute_stage(&mut self, stage_id: StageId) -> Result<(), WorkflowError> {
        info!("Executing stage {} for workflow {}", stage_id, self.id);

        let stage = self.definition.stages.iter()
            .find(|s| s.id == stage_id)
            .ok_or_else(|| WorkflowError {
                code: "STAGE_NOT_FOUND".to_string(),
                message: format!("Stage {} not found", stage_id),
                stage: Some(stage_id.clone()),
                trace: None,
                recovery_hints: vec![],
            })?
            .clone();

        self.context.current_stage = stage_id.clone();
        self.context.stage_start = Instant::now();

        if !self.check_preconditions(&stage).await? {
            return Err(WorkflowError {
                code: "PRECONDITION_FAILED".to_string(),
                message: format!("Preconditions not met for stage {}", stage_id),
                stage: Some(stage_id),
                trace: None,
                recovery_hints: vec!["Check stage preconditions".to_string()],
            });
        }

        let stage_timeout = stage.max_duration
            .or_else(|| self.definition.timeouts.get(&stage_id).cloned())
            .unwrap_or(self.config.default_timeout);

        let result = timeout(stage_timeout, self.execute_stage_actions(&stage)).await;

        match result {
            Ok(Ok(())) => {
                if self.check_postconditions(&stage).await? {
                    self.handle.send_event(WorkflowEvent::StageCompleted {
                        stage_id: stage_id.clone(),
                    }).await?;
                } else {
                    self.handle.send_event(WorkflowEvent::StageFailed {
                        stage_id: stage_id.clone(),
                        error: "Postconditions not met".to_string(),
                    }).await?;
                }
            }
            Ok(Err(e)) => {
                self.handle.send_event(WorkflowEvent::StageFailed {
                    stage_id: stage_id.clone(),
                    error: e.message.clone(),
                }).await?;
            }
            Err(_) => {
                self.handle.send_event(WorkflowEvent::Timeout {
                    stage_id: stage_id.clone(),
                }).await?;
            }
        }

        Ok(())
    }

    async fn execute_stage_actions(&mut self, stage: &WorkflowStage) -> Result<(), WorkflowError> {
        for (i, action) in stage.actions.iter().enumerate() {
            debug!("Executing action {} for stage {}", i, stage.id);

            action.execute(&mut self.context).await?;
        }

        // Record per-stage metrics once all of the stage's actions have run.
        self.context.metrics.stages_executed += 1;

        let duration = self.context.stage_start.elapsed();
        self.context.metrics.stage_durations.insert(stage.id.clone(), duration);

        Ok(())
    }

    async fn check_preconditions(&self, stage: &WorkflowStage) -> Result<bool, WorkflowError> {
        for condition in &stage.preconditions {
            if !condition.check(&self.context).await {
                debug!("Precondition {} failed for stage {}", condition.description(), stage.id);
                return Ok(false);
            }
        }
        Ok(true)
    }

    async fn check_postconditions(&self, stage: &WorkflowStage) -> Result<bool, WorkflowError> {
        for condition in &stage.postconditions {
            if !condition.check(&self.context).await {
                debug!("Postcondition {} failed for stage {}", condition.description(), stage.id);
                return Ok(false);
            }
        }
        Ok(true)
    }

    async fn handle_stage_completion(&mut self, stage_id: StageId) -> Result<(), WorkflowError> {
        info!("Stage {} completed for workflow {}", stage_id, self.id);

        self.retry_attempts.remove(&stage_id);

        if self.definition.final_stages.contains(&stage_id) {
            self.complete_workflow().await?;
            return Ok(());
        }

        let event = WorkflowEvent::StageCompleted { stage_id: stage_id.clone() };
        if let Some(next_stage) = self.definition.transitions.get(&(stage_id, event)) {
            self.execute_stage(next_stage.clone()).await?;
        } else {
            self.complete_workflow().await?;
        }

        Ok(())
    }

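    /// Applies the stage's error-handling policy in order: retry with the
    /// handler's backoff while attempts remain, then fall back to the
    /// configured fallback stage, then propagate the failure to the whole
    /// workflow, and otherwise run the stage's rollback strategy.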
    async fn handle_stage_failure(
        &mut self,
        stage_id: StageId,
        error: String,
    ) -> Result<(), WorkflowError> {
        warn!("Stage {} failed for workflow {}: {}", stage_id, self.id, error);

        self.context.metrics.error_count += 1;

        if let Some(handler) = self.definition.error_handlers.get(&stage_id) {
            let attempts = self.retry_attempts.entry(stage_id.clone()).or_insert(0);
            *attempts += 1;

            if *attempts <= handler.max_retries {
                let delay = handler.backoff.calculate_delay(*attempts - 1);
                info!(
                    "Retrying stage {} after {:?} (attempt {}/{})",
                    stage_id, delay, attempts, handler.max_retries
                );

                sleep(delay).await;

                self.context.metrics.retry_count += 1;

                self.execute_stage(stage_id).await?;
            } else if let Some(fallback) = &handler.fallback_stage {
                info!(
                    "Max retries exceeded for stage {}, going to fallback {}",
                    stage_id, fallback
                );
                self.execute_stage(fallback.clone()).await?;
            } else if handler.propagate {
                self.fail_workflow(WorkflowError {
                    code: "STAGE_FAILED".to_string(),
                    message: error,
                    stage: Some(stage_id),
                    trace: None,
                    recovery_hints: vec![],
                }).await?;
            } else if let Some(stage) = self.definition.stages.iter().find(|s| s.id == stage_id) {
                if let Some(rollback) = &stage.rollback {
                    self.execute_rollback(rollback.clone(), stage_id).await?;
                }
            }
        } else {
            self.fail_workflow(WorkflowError {
                code: "STAGE_FAILED".to_string(),
                message: error,
                stage: Some(stage_id),
                trace: None,
                recovery_hints: vec![],
            }).await?;
        }

        Ok(())
    }

    async fn handle_stage_timeout(&mut self, stage_id: StageId) -> Result<(), WorkflowError> {
        warn!("Stage {} timed out for workflow {}", stage_id, self.id);

        self.handle_stage_failure(stage_id, "Stage execution timed out".to_string()).await
    }

    async fn execute_rollback(
        &mut self,
        strategy: RollbackStrategy,
        failed_stage: StageId,
    ) -> Result<(), WorkflowError> {
        info!("Executing rollback for stage {} in workflow {}", failed_stage, self.id);

        match strategy {
            RollbackStrategy::None => Ok(()),
            RollbackStrategy::Compensate { actions } => {
                // Compensating actions are currently only logged, not executed.
                for action_name in actions {
                    debug!("Executing compensating action: {}", action_name);
                }
                Ok(())
            }
            RollbackStrategy::RestoreCheckpoint { checkpoint_id } => {
                // Checkpoint restoration is currently only logged.
                debug!("Restoring from checkpoint: {}", checkpoint_id);
                Ok(())
            }
            RollbackStrategy::JumpToStage { stage_id } => self.execute_stage(stage_id).await,
        }
    }

    async fn handle_cancellation(&mut self) -> Result<(), WorkflowError> {
        info!("Workflow {} cancelled", self.id);

        self.handle.update_status(WorkflowStatus::Cancelled).await;

        Ok(())
    }

    async fn complete_workflow(&mut self) -> Result<(), WorkflowError> {
        let duration = self.start_time.elapsed();

        info!("Workflow {} completed successfully in {:?}", self.id, duration);

        let result = WorkflowResult {
            output: self.context.state.clone(),
            duration,
            metrics: self.context.metrics.clone(),
        };

        self.handle.update_status(WorkflowStatus::Completed { result }).await;

        self.checkpoint().await?;

        Ok(())
    }

    async fn fail_workflow(&mut self, error: WorkflowError) -> Result<(), WorkflowError> {
        error!("Workflow {} failed: {:?}", self.id, error);

        self.handle.update_status(WorkflowStatus::Failed { error: error.clone() }).await;

        self.checkpoint().await?;

        Ok(())
    }

    async fn checkpoint(&self) -> Result<(), WorkflowError> {
        debug!("Checkpointing workflow {}", self.id);

        // NOTE: persisting the executor's state via `self.state_store` is not
        // implemented here; this is currently a logging stub.
        Ok(())
    }

    async fn restore_state(
        &mut self,
        _state: crate::workflow::WorkflowState,
    ) -> Result<(), WorkflowError> {
        // NOTE: restoring executor state from a persisted `WorkflowState` is
        // not implemented here.
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::workflow::InMemoryStateStore;

    #[tokio::test]
    async fn test_workflow_engine_basic() {
        let registry = Arc::new(WorkflowRegistry::new());
        registry.load_defaults().await.unwrap();

        let state_store = Arc::new(InMemoryStateStore::new());
        let engine = WorkflowEngine::new(
            WorkflowEngineConfig::default(),
            registry,
            state_store,
        );

        engine.start().await.unwrap();

        let handle = engine.start_workflow(
            "nat_traversal_basic",
            &Version { major: 1, minor: 0, patch: 0 },
            HashMap::new(),
        ).await.unwrap();

        assert_eq!(handle.status().await, WorkflowStatus::Initializing);

        engine.stop().await.unwrap();
    }
}