1use crate::types::WorkflowDefinition;
30use chrono::{DateTime, Utc};
31use distri_types::TaskStatus;
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct WorkflowExecutionState {
44 pub run_task_id: String,
45 pub agent_id: String,
46 pub thread_id: String,
47 pub user_id: String,
48 #[serde(default, skip_serializing_if = "Option::is_none")]
49 pub workspace_id: Option<String>,
50 pub definition: WorkflowDefinition,
53 #[serde(default, skip_serializing_if = "Option::is_none")]
54 pub entry_point: Option<String>,
55 #[serde(default)]
56 pub input: serde_json::Value,
57 #[serde(default)]
60 pub context: serde_json::Value,
61 pub created_at: DateTime<Utc>,
62 pub updated_at: DateTime<Utc>,
63}
64
65impl WorkflowExecutionState {
66 pub fn new(
67 run_task_id: impl Into<String>,
68 agent_id: impl Into<String>,
69 thread_id: impl Into<String>,
70 user_id: impl Into<String>,
71 definition: WorkflowDefinition,
72 ) -> Self {
73 let now = Utc::now();
74 Self {
75 run_task_id: run_task_id.into(),
76 agent_id: agent_id.into(),
77 thread_id: thread_id.into(),
78 user_id: user_id.into(),
79 workspace_id: None,
80 definition,
81 entry_point: None,
82 input: serde_json::json!({}),
83 context: serde_json::json!({}),
84 created_at: now,
85 updated_at: now,
86 }
87 }
88
89 pub fn with_workspace_id(mut self, workspace_id: Option<String>) -> Self {
90 self.workspace_id = workspace_id;
91 self
92 }
93
94 pub fn with_entry_point(mut self, entry_point: Option<String>) -> Self {
95 self.entry_point = entry_point;
96 self
97 }
98
99 pub fn with_input(mut self, input: serde_json::Value) -> Self {
100 self.input = input;
101 self
102 }
103
104 pub fn with_context(mut self, context: serde_json::Value) -> Self {
105 self.context = context;
106 self
107 }
108}
109
110#[derive(Debug, Clone, Serialize, Deserialize, Default)]
119pub struct WorkflowStepState {
120 pub step_id: String,
121 #[serde(default)]
122 pub status: TaskStatus,
123 #[serde(default, skip_serializing_if = "Option::is_none")]
124 pub result: Option<serde_json::Value>,
125 #[serde(default, skip_serializing_if = "Option::is_none")]
126 pub error: Option<String>,
127 #[serde(default, skip_serializing_if = "Option::is_none")]
128 pub started_at: Option<DateTime<Utc>>,
129 #[serde(default, skip_serializing_if = "Option::is_none")]
130 pub completed_at: Option<DateTime<Utc>>,
131 #[serde(default, skip_serializing_if = "Option::is_none")]
132 pub wait_task_id: Option<String>,
133}
134
135#[async_trait::async_trait]
140pub trait WorkflowStore: Send + Sync {
141 async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()>;
146
147 async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>>;
149
150 async fn update_context(
154 &self,
155 run_task_id: &str,
156 context: serde_json::Value,
157 ) -> anyhow::Result<()>;
158
159 async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()>;
161
162 async fn upsert_step(
166 &self,
167 run_task_id: &str,
168 step: WorkflowStepState,
169 ) -> anyhow::Result<()>;
170
171 async fn get_step(
173 &self,
174 run_task_id: &str,
175 step_id: &str,
176 ) -> anyhow::Result<Option<WorkflowStepState>>;
177
178 async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>>;
180}
181
182#[derive(Default)]
186pub struct InMemoryWorkflowStore {
187 runs: std::sync::Mutex<HashMap<String, WorkflowExecutionState>>,
188 steps: std::sync::Mutex<HashMap<String, Vec<WorkflowStepState>>>,
191}
192
193impl InMemoryWorkflowStore {
194 pub fn new() -> Self {
195 Self::default()
196 }
197}
198
199#[async_trait::async_trait]
200impl WorkflowStore for InMemoryWorkflowStore {
201 async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()> {
202 let mut runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
203 runs.insert(state.run_task_id.clone(), state);
204 Ok(())
205 }
206
207 async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>> {
208 let runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
209 Ok(runs.get(run_task_id).cloned())
210 }
211
212 async fn update_context(
213 &self,
214 run_task_id: &str,
215 context: serde_json::Value,
216 ) -> anyhow::Result<()> {
217 let mut runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
218 let row = runs
219 .get_mut(run_task_id)
220 .ok_or_else(|| anyhow::anyhow!("workflow run not found: {run_task_id}"))?;
221 row.context = context;
222 row.updated_at = Utc::now();
223 Ok(())
224 }
225
226 async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()> {
227 let mut runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
228 let mut steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
229 runs.remove(run_task_id);
230 steps.remove(run_task_id);
231 Ok(())
232 }
233
234 async fn upsert_step(
235 &self,
236 run_task_id: &str,
237 step: WorkflowStepState,
238 ) -> anyhow::Result<()> {
239 let mut steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
240 let bucket = steps.entry(run_task_id.to_string()).or_default();
241 if let Some(existing) = bucket.iter_mut().find(|s| s.step_id == step.step_id) {
242 *existing = step;
243 } else {
244 bucket.push(step);
245 }
246 Ok(())
247 }
248
249 async fn get_step(
250 &self,
251 run_task_id: &str,
252 step_id: &str,
253 ) -> anyhow::Result<Option<WorkflowStepState>> {
254 let steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
255 Ok(steps
256 .get(run_task_id)
257 .and_then(|bucket| bucket.iter().find(|s| s.step_id == step_id).cloned()))
258 }
259
260 async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>> {
261 let steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
262 Ok(steps.get(run_task_id).cloned().unwrap_or_default())
263 }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{WorkflowDefinition, WorkflowStep};

    /// Minimal one-step definition shared by the run-level tests.
    fn sample_def() -> WorkflowDefinition {
        WorkflowDefinition::new(vec![WorkflowStep::checkpoint("c", "Checkpoint", "ok")])
    }

    #[tokio::test]
    async fn create_get_roundtrip() {
        let store = InMemoryWorkflowStore::new();
        let state = WorkflowExecutionState::new("run-1", "agent-1", "thread-test", "user-test", sample_def())
            .with_entry_point(Some("main".into()))
            .with_input(serde_json::json!({"x": 1}));
        store.create_run(state).await.unwrap();
        let got = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(got.agent_id, "agent-1");
        assert_eq!(got.entry_point.as_deref(), Some("main"));
        assert_eq!(got.input, serde_json::json!({"x": 1}));
    }

    #[tokio::test]
    async fn update_context_mutates_only_context() {
        let store = InMemoryWorkflowStore::new();
        store
            .create_run(
                WorkflowExecutionState::new("run-1", "agent-1", "thread-test", "user-test", sample_def())
                    .with_input(serde_json::json!({"x": 1})),
            )
            .await
            .unwrap();
        let new_ctx = serde_json::json!({"steps": {"c": "ok"}});
        store
            .update_context("run-1", new_ctx.clone())
            .await
            .unwrap();
        let got = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(got.context, new_ctx);
        assert_eq!(got.input, serde_json::json!({"x": 1}));
        assert!(got.updated_at >= got.created_at);
    }

    /// `update_context` must reject unknown runs instead of silently creating one.
    #[tokio::test]
    async fn update_context_on_missing_run_errors() {
        let store = InMemoryWorkflowStore::new();
        let err = store
            .update_context("missing", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(err.to_string().contains("workflow run not found: missing"));
        assert!(store.get_run("missing").await.unwrap().is_none());
    }

    /// Lookups and deletes on a run that was never created succeed with empty results.
    #[tokio::test]
    async fn operations_on_unknown_run_are_empty() {
        let store = InMemoryWorkflowStore::new();
        assert!(store.get_run("nope").await.unwrap().is_none());
        assert!(store.get_step("nope", "s").await.unwrap().is_none());
        assert!(store.list_steps("nope").await.unwrap().is_empty());
        store.delete_run("nope").await.unwrap();
    }

    #[tokio::test]
    async fn upsert_step_insert_then_update() {
        let store = InMemoryWorkflowStore::new();
        let s1 = WorkflowStepState {
            step_id: "fetch".into(),
            status: TaskStatus::Running,
            ..Default::default()
        };
        store.upsert_step("run-1", s1).await.unwrap();
        let got = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(got.status, TaskStatus::Running);

        let s2 = WorkflowStepState {
            step_id: "fetch".into(),
            status: TaskStatus::Completed,
            result: Some(serde_json::json!({"docs": []})),
            ..Default::default()
        };
        store.upsert_step("run-1", s2).await.unwrap();
        let got = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(got.status, TaskStatus::Completed);
        assert!(got.result.is_some());
        // An update must replace the record, not append a duplicate.
        assert_eq!(store.list_steps("run-1").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn list_steps_preserves_insertion_order_and_is_per_run() {
        let store = InMemoryWorkflowStore::new();
        for id in ["a", "b", "c"] {
            store
                .upsert_step(
                    "run-1",
                    WorkflowStepState {
                        step_id: id.into(),
                        status: TaskStatus::Pending,
                        ..Default::default()
                    },
                )
                .await
                .unwrap();
        }
        store
            .upsert_step(
                "run-2",
                WorkflowStepState {
                    step_id: "x".into(),
                    status: TaskStatus::Pending,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        let r1 = store.list_steps("run-1").await.unwrap();
        let r2 = store.list_steps("run-2").await.unwrap();
        assert_eq!(r1.iter().map(|s| s.step_id.as_str()).collect::<Vec<_>>(), vec!["a", "b", "c"]);
        assert_eq!(r2.iter().map(|s| s.step_id.as_str()).collect::<Vec<_>>(), vec!["x"]);
    }

    #[tokio::test]
    async fn delete_run_cascades_to_steps() {
        let store = InMemoryWorkflowStore::new();
        for run in ["run-1", "run-2"] {
            store
                .create_run(WorkflowExecutionState::new(run, "agent-1", "thread-test", "user-test", sample_def()))
                .await
                .unwrap();
            store
                .upsert_step(
                    run,
                    WorkflowStepState {
                        step_id: "s".into(),
                        ..Default::default()
                    },
                )
                .await
                .unwrap();
        }
        store.delete_run("run-1").await.unwrap();
        assert!(store.get_run("run-1").await.unwrap().is_none());
        assert!(store.list_steps("run-1").await.unwrap().is_empty());
        // Deleting one run must not touch another run's data.
        assert!(store.get_run("run-2").await.unwrap().is_some());
        assert_eq!(store.list_steps("run-2").await.unwrap().len(), 1);
    }
}