1use crate::types::{CheckpointMeta, StepResult, WorkflowRun};
9
10#[async_trait::async_trait]
13pub trait WorkflowStateStore: Send + Sync {
14 async fn load(&self, workflow_id: &str) -> Result<Option<WorkflowRun>, String>;
16
17 async fn save(&self, run: &WorkflowRun) -> Result<(), String>;
19
20 async fn commit_step(
22 &self,
23 workflow_id: &str,
24 step_index: usize,
25 result: StepResult,
26 ) -> Result<(), String>;
27
28 async fn save_checkpoint(
30 &self,
31 _workflow_id: &str,
32 _step_id: &str,
33 ) -> Result<CheckpointMeta, String> {
34 Err("Checkpoints not supported by this store".into())
35 }
36
37 async fn load_checkpoint(
39 &self,
40 _workflow_id: &str,
41 _checkpoint_id: &str,
42 ) -> Result<Option<WorkflowRun>, String> {
43 Ok(None)
44 }
45
46 async fn list_checkpoints(&self, _workflow_id: &str) -> Result<Vec<CheckpointMeta>, String> {
48 Ok(vec![])
49 }
50}
51
52pub struct InMemoryStore {
54 runs: std::sync::Mutex<std::collections::HashMap<String, WorkflowRun>>,
55}
56
57impl Default for InMemoryStore {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl InMemoryStore {
64 pub fn new() -> Self {
65 Self {
66 runs: std::sync::Mutex::new(std::collections::HashMap::new()),
67 }
68 }
69}
70
71#[async_trait::async_trait]
72impl WorkflowStateStore for InMemoryStore {
73 async fn load(&self, workflow_id: &str) -> Result<Option<WorkflowRun>, String> {
74 let map = self.runs.lock().map_err(|e| e.to_string())?;
75 Ok(map.get(workflow_id).cloned())
76 }
77
78 async fn save(&self, run: &WorkflowRun) -> Result<(), String> {
79 let mut map = self.runs.lock().map_err(|e| e.to_string())?;
80 map.insert(run.id().to_string(), run.clone());
81 Ok(())
82 }
83
84 async fn commit_step(
85 &self,
86 workflow_id: &str,
87 step_index: usize,
88 result: StepResult,
89 ) -> Result<(), String> {
90 let mut map = self.runs.lock().map_err(|e| e.to_string())?;
91 let run = map.get_mut(workflow_id).ok_or("Workflow not found")?;
92
93 if let (Some(step), Some(step_run)) = (
94 run.definition.steps.get(step_index),
95 run.step_runs.get_mut(step_index),
96 ) {
97 let step_id = step.id.clone();
98 step_run.status = result.status;
99 step_run.result = result.result.clone();
100 step_run.error = result.error;
101 step_run.completed_at = Some(chrono::Utc::now());
102
103 if let Some(ref result_val) = result.result {
105 let ctx = run
106 .context
107 .as_object_mut()
108 .expect("workflow context must be an object");
109 let steps = ctx
110 .entry("steps")
111 .or_insert(serde_json::json!({}))
112 .as_object_mut()
113 .expect("steps must be an object");
114 steps.insert(step_id, result_val.clone());
115 }
116
117 if let Some(updates) = result.context_updates {
119 if let (Some(ctx), Some(upd)) = (run.context.as_object_mut(), updates.as_object()) {
120 for (k, v) in upd {
121 ctx.insert(k.clone(), v.clone());
122 }
123 }
124 }
125 }
126
127 run.updated_at = chrono::Utc::now();
128 Ok(())
129 }
130}