1use std::collections::HashMap;
29use std::future::Future;
30use std::pin::Pin;
31use std::sync::Arc;
32use std::time::{Duration, SystemTime, UNIX_EPOCH};
33
34use async_trait::async_trait;
35use parking_lot::RwLock;
36use serde::{Deserialize, Serialize};
37use tracing::{debug, error, info, warn};
38
/// Identifier of one durable execution; the key under which its state is stored.
pub type ExecutionId = ::std::string::String;

/// Identifier of a single step within an execution.
pub type StepId = ::std::string::String;
44
/// Lifecycle state of a single step.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StepStatus {
    /// The step has not started yet.
    Pending,
    /// An attempt of the step is in progress.
    Running,
    /// The step finished successfully; treated as complete.
    Completed,
    /// The step's final attempt returned an error; treated as failed.
    Failed,
    /// The step was skipped; treated as complete.
    /// NOTE(review): nothing in this view sets this status.
    Skipped,
    /// The step's final attempt exceeded its timeout; treated as failed.
    TimedOut,
}
61
/// Persisted record of one step's most recent execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StepResult {
    /// Id of the step this record belongs to.
    pub step_id: StepId,
    /// Current lifecycle state of the step.
    pub status: StepStatus,
    /// Value returned by the handler on success.
    pub output: Option<String>,
    /// Error message from the final failed or timed-out attempt.
    pub error: Option<String>,
    /// When the executor started the step, in Unix seconds.
    pub started_at: Option<u64>,
    /// When the step finished (successfully or not), in Unix seconds.
    pub completed_at: Option<u64>,
    /// Elapsed time across all attempts, in milliseconds.
    pub duration_ms: Option<u64>,
    /// Zero-based index of the latest attempt (0 = first try, no retries).
    pub retry_count: u32,
}
82
83impl StepResult {
84 pub fn pending(step_id: impl Into<String>) -> Self {
85 Self {
86 step_id: step_id.into(),
87 status: StepStatus::Pending,
88 output: None,
89 error: None,
90 started_at: None,
91 completed_at: None,
92 duration_ms: None,
93 retry_count: 0,
94 }
95 }
96
97 pub fn is_complete(&self) -> bool {
98 matches!(self.status, StepStatus::Completed | StepStatus::Skipped)
99 }
100
101 pub fn is_failed(&self) -> bool {
102 matches!(self.status, StepStatus::Failed | StepStatus::TimedOut)
103 }
104}
105
/// Persisted snapshot of one execution's progress; this is the value an
/// [`ExecutionStore`] saves and loads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionState {
    /// Id under which this state is stored.
    pub execution_id: ExecutionId,
    /// Index of the step the executor most recently ran.
    /// Starts at 0; a resumed run picks up from here.
    pub current_step: usize,
    /// Number of steps the execution was created with.
    pub total_steps: usize,
    /// One result record per step, in step order.
    pub step_results: Vec<StepResult>,
    /// Creation time, in Unix seconds.
    pub started_at: u64,
    /// Time of the last modification, in Unix seconds.
    pub updated_at: u64,
    /// Completion time in Unix seconds; `None` until the execution completes.
    pub completed_at: Option<u64>,
    /// Overall status of the execution.
    pub status: ExecutionStatus,
    /// Free-form key/value annotations.
    /// Created empty and not read by the executor in this file.
    pub metadata: HashMap<String, String>,
}
128
129impl ExecutionState {
130 pub fn new(execution_id: impl Into<String>, step_ids: &[String]) -> Self {
131 let now = SystemTime::now()
132 .duration_since(UNIX_EPOCH)
133 .unwrap_or_default()
134 .as_secs();
135
136 Self {
137 execution_id: execution_id.into(),
138 current_step: 0,
139 total_steps: step_ids.len(),
140 step_results: step_ids.iter().map(StepResult::pending).collect(),
141 started_at: now,
142 updated_at: now,
143 completed_at: None,
144 status: ExecutionStatus::Pending,
145 metadata: HashMap::new(),
146 }
147 }
148
149 pub fn is_complete(&self) -> bool {
150 matches!(self.status, ExecutionStatus::Completed)
152 }
153
154 pub fn is_failed(&self) -> bool {
155 matches!(self.status, ExecutionStatus::Failed)
156 }
157
158 pub fn progress(&self) -> f32 {
159 if self.total_steps == 0 {
160 return 1.0;
161 }
162 let completed = self.step_results.iter().filter(|r| r.is_complete()).count();
163 completed as f32 / self.total_steps as f32
164 }
165
166 fn touch(&mut self) {
167 self.updated_at = SystemTime::now()
168 .duration_since(UNIX_EPOCH)
169 .unwrap_or_default()
170 .as_secs();
171 }
172}
173
/// Overall status of an execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ExecutionStatus {
    /// Created but not yet run.
    Pending,
    /// A run is in progress (set when `run` starts).
    Running,
    /// Every step was processed.
    /// Running again is rejected with `DurableError::AlreadyCompleted`.
    Completed,
    /// A step failed and `continue_on_failure` was off.
    /// The execution can be resumed.
    Failed,
    /// Marked paused via `pause`; the next `run` resumes it.
    Paused,
}
188
/// Errors produced by durable executions and their stores.
#[derive(Debug, thiserror::Error)]
pub enum DurableError {
    /// A step's handler returned `Err` on its final attempt.
    /// Fields: step id, handler message.
    #[error("Step '{0}' failed: {1}")]
    StepFailed(StepId, String),

    /// A step's final attempt exceeded its timeout.
    /// Fields: step id, timeout.
    #[error("Step '{0}' timed out after {1:?}")]
    StepTimeout(StepId, Duration),

    /// No execution with the given id exists.
    /// NOTE(review): not constructed in this view.
    #[error("Execution '{0}' not found")]
    ExecutionNotFound(ExecutionId),

    /// A step with the given id could not be found.
    #[error("Step '{0}' not found")]
    StepNotFound(StepId),

    /// The execution has already completed and cannot be run again.
    #[error("Execution already completed")]
    AlreadyCompleted,

    /// Failure inside a storage backend.
    /// Presumably for `ExecutionStore` implementations; not constructed in this view.
    #[error("Storage error: {0}")]
    StorageError(String),

    /// Failure (de)serializing persisted state.
    /// Presumably for `ExecutionStore` implementations; not constructed in this view.
    #[error("Serialization error: {0}")]
    SerializationError(String),

    /// A step exhausted its retries.
    /// Fields: retry limit, step id.
    #[error("Max retries ({0}) exceeded for step '{1}'")]
    MaxRetriesExceeded(u32, StepId),
}
216
/// Execution-wide settings, applied to every step unless the step overrides them.
#[derive(Debug, Clone)]
pub struct DurableConfig {
    /// Per-attempt timeout for steps that do not set their own.
    pub default_timeout: Duration,
    /// Retries after the first attempt, for steps that do not set their own.
    /// A step therefore runs at most `max_retries + 1` times.
    pub max_retries: u32,
    /// Pause between a failed or timed-out attempt and the next retry.
    pub retry_delay: Duration,
    /// When `true`, a step that exhausts its retries is recorded as failed and
    /// the remaining steps still run.
    /// When `false`, the execution stops at that step.
    pub continue_on_failure: bool,
    /// When `true`, the executor also saves state as individual steps finish,
    /// rather than only when a run starts, stops on a failure, or completes.
    pub persist_per_step: bool,
}

impl Default for DurableConfig {
    /// Defaults:
    /// - five-minute step timeout;
    /// - three retries, one second apart;
    /// - stop on the first failing step;
    /// - per-step persistence enabled.
    fn default() -> Self {
        const FIVE_MINUTES: Duration = Duration::from_secs(5 * 60);
        const ONE_SECOND: Duration = Duration::from_secs(1);

        DurableConfig {
            max_retries: 3,
            retry_delay: ONE_SECOND,
            default_timeout: FIVE_MINUTES,
            persist_per_step: true,
            continue_on_failure: false,
        }
    }
}
243
/// Persistence backend for [`ExecutionState`], keyed by execution id.
///
/// The trait requires `Send + Sync`, so a single store can be shared (e.g.
/// behind an `Arc`) by several executions.
#[async_trait]
pub trait ExecutionStore: Send + Sync {
    /// Persists `state` under `state.execution_id`.
    ///
    /// The executor calls this repeatedly with updated snapshots of the same
    /// execution. Implementations are presumably expected to overwrite the
    /// earlier snapshot, as the in-memory store does.
    async fn save(&self, state: &ExecutionState) -> Result<(), DurableError>;

    /// Loads the state saved for `execution_id`, or `Ok(None)` if there is none.
    async fn load(&self, execution_id: &str) -> Result<Option<ExecutionState>, DurableError>;

    /// Removes the state saved for `execution_id`.
    async fn delete(&self, execution_id: &str) -> Result<(), DurableError>;

    /// Returns the ids of all stored executions.
    async fn list(&self) -> Result<Vec<ExecutionId>, DurableError>;
}
259
260#[derive(Default)]
262pub struct MemoryExecutionStore {
263 states: RwLock<HashMap<ExecutionId, ExecutionState>>,
264}
265
266impl MemoryExecutionStore {
267 pub fn new() -> Self {
268 Self::default()
269 }
270}
271
272#[async_trait]
273impl ExecutionStore for MemoryExecutionStore {
274 async fn save(&self, state: &ExecutionState) -> Result<(), DurableError> {
275 self.states
276 .write()
277 .insert(state.execution_id.clone(), state.clone());
278 Ok(())
279 }
280
281 async fn load(&self, execution_id: &str) -> Result<Option<ExecutionState>, DurableError> {
282 Ok(self.states.read().get(execution_id).cloned())
283 }
284
285 async fn delete(&self, execution_id: &str) -> Result<(), DurableError> {
286 self.states.write().remove(execution_id);
287 Ok(())
288 }
289
290 async fn list(&self) -> Result<Vec<ExecutionId>, DurableError> {
291 Ok(self.states.read().keys().cloned().collect())
292 }
293}
294
295pub type StepFn =
297 Box<dyn Fn() -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> + Send + Sync>;
298
299pub struct ExecutionStep {
301 pub id: StepId,
302 pub name: String,
303 pub timeout: Option<Duration>,
304 pub max_retries: Option<u32>,
305 pub handler: StepFn,
306}
307
308impl ExecutionStep {
309 pub fn new<F, Fut>(id: impl Into<String>, handler: F) -> Self
310 where
311 F: Fn() -> Fut + Send + Sync + 'static,
312 Fut: Future<Output = Result<String, String>> + Send + 'static,
313 {
314 let id = id.into();
315 Self {
316 name: id.clone(),
317 id,
318 timeout: None,
319 max_retries: None,
320 handler: Box::new(move || Box::pin(handler())),
321 }
322 }
323
324 pub fn with_name(mut self, name: impl Into<String>) -> Self {
325 self.name = name.into();
326 self
327 }
328
329 pub fn with_timeout(mut self, timeout: Duration) -> Self {
330 self.timeout = Some(timeout);
331 self
332 }
333
334 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
335 self.max_retries = Some(max_retries);
336 self
337 }
338}
339
340#[derive(Debug, Clone)]
342pub struct ExecutionResult {
343 pub state: ExecutionState,
345 pub success: bool,
347 pub outputs: HashMap<StepId, String>,
349 pub total_duration_ms: u64,
351}
352
353impl ExecutionResult {
354 pub fn get_output(&self, step_id: &str) -> Option<&String> {
355 self.outputs.get(step_id)
356 }
357}
358
359pub struct DurableExecution<S: ExecutionStore> {
361 execution_id: ExecutionId,
362 steps: Vec<ExecutionStep>,
363 config: DurableConfig,
364 store: Arc<S>,
365}
366
367impl DurableExecution<MemoryExecutionStore> {
368 pub fn in_memory(execution_id: impl Into<String>) -> Self {
370 Self::new(execution_id, Arc::new(MemoryExecutionStore::new()))
371 }
372}
373
374impl<S: ExecutionStore> DurableExecution<S> {
375 pub fn new(execution_id: impl Into<String>, store: Arc<S>) -> Self {
376 Self {
377 execution_id: execution_id.into(),
378 steps: Vec::new(),
379 config: DurableConfig::default(),
380 store,
381 }
382 }
383
384 pub fn with_config(mut self, config: DurableConfig) -> Self {
385 self.config = config;
386 self
387 }
388
389 pub fn add_step<F, Fut>(mut self, id: impl Into<String>, handler: F) -> Self
391 where
392 F: Fn() -> Fut + Send + Sync + 'static,
393 Fut: Future<Output = Result<String, String>> + Send + 'static,
394 {
395 self.steps.push(ExecutionStep::new(id, handler));
396 self
397 }
398
399 pub fn add_step_config(mut self, step: ExecutionStep) -> Self {
401 self.steps.push(step);
402 self
403 }
404
405 pub async fn run(&self) -> Result<ExecutionResult, DurableError> {
407 let mut state = match self.store.load(&self.execution_id).await? {
409 Some(existing) => {
410 if existing.is_complete() {
411 return Err(DurableError::AlreadyCompleted);
412 }
413 info!(
414 execution_id = %self.execution_id,
415 current_step = existing.current_step,
416 "Resuming execution"
417 );
418 existing
419 }
420 None => {
421 let step_ids: Vec<String> = self.steps.iter().map(|s| s.id.clone()).collect();
422 let state = ExecutionState::new(&self.execution_id, &step_ids);
423 self.store.save(&state).await?;
424 info!(execution_id = %self.execution_id, "Starting new execution");
425 state
426 }
427 };
428
429 state.status = ExecutionStatus::Running;
430 state.touch();
431 self.store.save(&state).await?;
432
433 let start = std::time::Instant::now();
434 let mut outputs = HashMap::new();
435
436 for step_idx in state.current_step..self.steps.len() {
438 let step = &self.steps[step_idx];
439
440 if state.step_results[step_idx].is_complete() {
442 if let Some(output) = &state.step_results[step_idx].output {
443 outputs.insert(step.id.clone(), output.clone());
444 }
445 continue;
446 }
447
448 debug!(
449 execution_id = %self.execution_id,
450 step_id = %step.id,
451 step_idx,
452 "Executing step"
453 );
454
455 let result = self
457 .execute_step(step, &mut state.step_results[step_idx])
458 .await;
459
460 state.current_step = step_idx;
461 state.touch();
462
463 match result {
464 Ok(output) => {
465 outputs.insert(step.id.clone(), output);
466 if self.config.persist_per_step {
467 self.store.save(&state).await?;
468 }
469 }
470 Err(e) => {
471 error!(
472 execution_id = %self.execution_id,
473 step_id = %step.id,
474 error = %e,
475 "Step failed"
476 );
477
478 if !self.config.continue_on_failure {
479 state.status = ExecutionStatus::Failed;
480 state.touch();
481 self.store.save(&state).await?;
482 return Err(e);
483 }
484 }
485 }
486 }
487
488 state.status = ExecutionStatus::Completed;
490 state.completed_at = Some(
491 SystemTime::now()
492 .duration_since(UNIX_EPOCH)
493 .unwrap_or_default()
494 .as_secs(),
495 );
496 state.touch();
497 self.store.save(&state).await?;
498
499 info!(
500 execution_id = %self.execution_id,
501 duration_ms = start.elapsed().as_millis(),
502 "Execution completed"
503 );
504
505 Ok(ExecutionResult {
506 success: state.step_results.iter().all(|r| r.is_complete()),
507 state,
508 outputs,
509 total_duration_ms: start.elapsed().as_millis() as u64,
510 })
511 }
512
513 async fn execute_step(
514 &self,
515 step: &ExecutionStep,
516 result: &mut StepResult,
517 ) -> Result<String, DurableError> {
518 let timeout = step.timeout.unwrap_or(self.config.default_timeout);
519 let max_retries = step.max_retries.unwrap_or(self.config.max_retries);
520
521 let start = std::time::Instant::now();
522 result.started_at = Some(
523 SystemTime::now()
524 .duration_since(UNIX_EPOCH)
525 .unwrap_or_default()
526 .as_secs(),
527 );
528 result.status = StepStatus::Running;
529
530 for attempt in 0..=max_retries {
531 result.retry_count = attempt;
532
533 match tokio::time::timeout(timeout, (step.handler)()).await {
534 Ok(Ok(output)) => {
535 result.status = StepStatus::Completed;
536 result.output = Some(output.clone());
537 result.completed_at = Some(
538 SystemTime::now()
539 .duration_since(UNIX_EPOCH)
540 .unwrap_or_default()
541 .as_secs(),
542 );
543 result.duration_ms = Some(start.elapsed().as_millis() as u64);
544
545 debug!(
546 step_id = %step.id,
547 attempt,
548 duration_ms = result.duration_ms,
549 "Step completed"
550 );
551
552 return Ok(output);
553 }
554 Ok(Err(e)) => {
555 warn!(
556 step_id = %step.id,
557 attempt,
558 max_retries,
559 error = %e,
560 "Step attempt failed"
561 );
562
563 if attempt < max_retries {
564 tokio::time::sleep(self.config.retry_delay).await;
565 continue;
566 }
567
568 result.status = StepStatus::Failed;
569 result.error = Some(e.clone());
570 result.completed_at = Some(
571 SystemTime::now()
572 .duration_since(UNIX_EPOCH)
573 .unwrap_or_default()
574 .as_secs(),
575 );
576 result.duration_ms = Some(start.elapsed().as_millis() as u64);
577
578 return Err(DurableError::StepFailed(step.id.clone(), e));
579 }
580 Err(_) => {
581 warn!(
582 step_id = %step.id,
583 attempt,
584 timeout_secs = timeout.as_secs(),
585 "Step timed out"
586 );
587
588 if attempt < max_retries {
589 tokio::time::sleep(self.config.retry_delay).await;
590 continue;
591 }
592
593 result.status = StepStatus::TimedOut;
594 result.error = Some(format!("Timed out after {:?}", timeout));
595 result.completed_at = Some(
596 SystemTime::now()
597 .duration_since(UNIX_EPOCH)
598 .unwrap_or_default()
599 .as_secs(),
600 );
601 result.duration_ms = Some(start.elapsed().as_millis() as u64);
602
603 return Err(DurableError::StepTimeout(step.id.clone(), timeout));
604 }
605 }
606 }
607
608 Err(DurableError::MaxRetriesExceeded(
609 max_retries,
610 step.id.clone(),
611 ))
612 }
613
614 pub async fn pause(&self) -> Result<(), DurableError> {
616 if let Some(mut state) = self.store.load(&self.execution_id).await? {
617 state.status = ExecutionStatus::Paused;
618 state.touch();
619 self.store.save(&state).await?;
620 info!(execution_id = %self.execution_id, "Execution paused");
621 }
622 Ok(())
623 }
624
625 pub async fn state(&self) -> Result<Option<ExecutionState>, DurableError> {
627 self.store.load(&self.execution_id).await
628 }
629
630 pub async fn reset(&self) -> Result<(), DurableError> {
632 self.store.delete(&self.execution_id).await
633 }
634}
635
/// Aggregate counters describing durable executions.
///
/// NOTE(review): nothing in this view populates or reads these counters.
/// Confirm whether they are maintained elsewhere. The per-field descriptions
/// below are inferred from the field names only.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DurableStats {
    /// Presumably: number of executions started.
    pub total_executions: u64,
    /// Presumably: number of executions that reached `Completed`.
    pub completed_executions: u64,
    /// Presumably: number of executions that ended `Failed`.
    pub failed_executions: u64,
    /// Presumably: number of step executions performed.
    pub total_steps_executed: u64,
    /// Presumably: number of retry attempts across all steps.
    pub total_retries: u64,
}
645
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};

    // Two succeeding steps: both outputs are reported and the execution completes.
    #[tokio::test]
    async fn test_simple_execution() {
        let execution = DurableExecution::in_memory("test_1")
            .add_step("step1", || async { Ok("result1".to_string()) })
            .add_step("step2", || async { Ok("result2".to_string()) });

        let result = execution.run().await.unwrap();

        assert!(result.success);
        assert_eq!(result.outputs.get("step1").unwrap(), "result1");
        assert_eq!(result.outputs.get("step2").unwrap(), "result2");
        assert_eq!(result.state.status, ExecutionStatus::Completed);
    }

    // With retries disabled and continue_on_failure off, the first failing
    // step aborts the run with StepFailed naming that step.
    #[tokio::test]
    async fn test_execution_with_failure() {
        let config = DurableConfig {
            max_retries: 0,
            ..Default::default()
        };

        let execution = DurableExecution::in_memory("test_fail")
            .with_config(config)
            .add_step("step1", || async { Ok("ok".to_string()) })
            .add_step("step2", || async { Err("failed".to_string()) });

        let result = execution.run().await;

        assert!(result.is_err());
        match result {
            Err(DurableError::StepFailed(id, _)) => assert_eq!(id, "step2"),
            _ => panic!("Expected StepFailed error"),
        }
    }

    // Two executions share one store and one call counter.
    // The first run fails at step2 (counter 0) and leaves a Failed state in the
    // store. The second run resumes from that state; step2 sees counter 1 and
    // succeeds.
    #[tokio::test]
    async fn test_execution_resume() {
        let store = Arc::new(MemoryExecutionStore::new());
        let attempt = Arc::new(AtomicU32::new(0));

        {
            // First run: expected to fail at step2.
            let attempt_clone = attempt.clone();
            let config = DurableConfig {
                max_retries: 0,
                ..Default::default()
            };

            let execution = DurableExecution::new("test_resume", store.clone())
                .with_config(config)
                .add_step("step1", || async { Ok("done".to_string()) })
                .add_step("step2", move || {
                    let current = attempt_clone.fetch_add(1, Ordering::SeqCst);
                    async move {
                        if current == 0 {
                            Err("first attempt fails".to_string())
                        } else {
                            Ok("success".to_string())
                        }
                    }
                });

            // The error is ignored on purpose; the failed state remains in `store`.
            let _ = execution.run().await;
        }

        {
            // Second run: resumes the stored execution and should succeed.
            let attempt_clone = attempt.clone();
            let config = DurableConfig {
                max_retries: 0,
                ..Default::default()
            };

            let execution = DurableExecution::new("test_resume", store.clone())
                .with_config(config)
                .add_step("step1", || async { Ok("done".to_string()) })
                .add_step("step2", move || {
                    let current = attempt_clone.fetch_add(1, Ordering::SeqCst);
                    async move {
                        if current == 0 {
                            Err("first attempt fails".to_string())
                        } else {
                            Ok("success".to_string())
                        }
                    }
                });

            let result = execution.run().await.unwrap();
            assert!(result.success);
        }
    }

    // A step that fails twice and then succeeds.
    // It completes within max_retries = 2 (three attempts in total).
    #[tokio::test]
    async fn test_step_with_retries() {
        let attempt = AtomicU32::new(0);

        let config = DurableConfig {
            max_retries: 2,
            retry_delay: Duration::from_millis(10),
            ..Default::default()
        };

        let execution = DurableExecution::in_memory("test_retry")
            .with_config(config)
            .add_step("flaky", move || {
                let current = attempt.fetch_add(1, Ordering::SeqCst);
                async move {
                    if current < 2 {
                        Err("temporary failure".to_string())
                    } else {
                        Ok("finally worked".to_string())
                    }
                }
            });

        let result = execution.run().await.unwrap();
        assert!(result.success);
        assert_eq!(result.outputs.get("flaky").unwrap(), "finally worked");
    }

    // After a successful run, the stored state reflects completion.
    #[tokio::test]
    async fn test_execution_state() {
        let store = Arc::new(MemoryExecutionStore::new());

        let execution = DurableExecution::new("test_state", store.clone())
            .add_step("step1", || async { Ok("done".to_string()) });

        execution.run().await.unwrap();

        let state = execution.state().await.unwrap().unwrap();
        assert_eq!(state.status, ExecutionStatus::Completed);
        assert!(state.completed_at.is_some());
        assert_eq!(state.step_results[0].status, StepStatus::Completed);
    }

    // reset() removes the stored state.
    #[tokio::test]
    async fn test_execution_reset() {
        let store = Arc::new(MemoryExecutionStore::new());

        let execution = DurableExecution::new("test_reset", store.clone())
            .add_step("step1", || async { Ok("done".to_string()) });

        execution.run().await.unwrap();
        assert!(execution.state().await.unwrap().is_some());

        execution.reset().await.unwrap();
        assert!(execution.state().await.unwrap().is_none());
    }

    // progress() is the fraction of complete steps.
    #[tokio::test]
    async fn test_progress() {
        let state = ExecutionState::new(
            "test",
            &["s1".to_string(), "s2".to_string(), "s3".to_string()],
        );
        assert_eq!(state.progress(), 0.0);

        let mut state = state;
        state.step_results[0].status = StepStatus::Completed;
        assert!((state.progress() - 0.333).abs() < 0.01);

        state.step_results[1].status = StepStatus::Completed;
        state.step_results[2].status = StepStatus::Completed;
        assert_eq!(state.progress(), 1.0);
    }

    // With continue_on_failure, later steps still run and the execution is
    // marked Completed. However, success is false and the failed step has no
    // output.
    #[tokio::test]
    async fn test_continue_on_failure() {
        let config = DurableConfig {
            max_retries: 0,
            continue_on_failure: true,
            ..Default::default()
        };

        let execution = DurableExecution::in_memory("test_continue")
            .with_config(config)
            .add_step("step1", || async { Ok("ok".to_string()) })
            .add_step("step2", || async { Err("fails".to_string()) })
            .add_step("step3", || async { Ok("also ok".to_string()) });

        let result = execution.run().await.unwrap();

        assert_eq!(result.state.status, ExecutionStatus::Completed);
        assert!(!result.success);
        assert!(result.outputs.contains_key("step1"));
        assert!(result.outputs.contains_key("step3"));
        assert!(!result.outputs.contains_key("step2"));
    }

    // StepResult helper predicates for pending, completed, failed and skipped states.
    #[tokio::test]
    async fn test_step_result_states() {
        let result = StepResult::pending("test");
        assert!(!result.is_complete());
        assert!(!result.is_failed());

        let mut result = result;
        result.status = StepStatus::Completed;
        assert!(result.is_complete());

        result.status = StepStatus::Failed;
        assert!(result.is_failed());

        result.status = StepStatus::Skipped;
        assert!(result.is_complete());
    }

    // Round-trip through the in-memory store: save, load, list, delete.
    #[tokio::test]
    async fn test_memory_store() {
        let store = MemoryExecutionStore::new();

        let state = ExecutionState::new("exec1", &["s1".to_string()]);
        store.save(&state).await.unwrap();

        let loaded = store.load("exec1").await.unwrap();
        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().execution_id, "exec1");

        let ids = store.list().await.unwrap();
        assert_eq!(ids.len(), 1);

        store.delete("exec1").await.unwrap();
        assert!(store.load("exec1").await.unwrap().is_none());
    }
}
876}