1use crate::types::WorkflowDefinition;
30use chrono::{DateTime, Utc};
31use distri_types::TaskStatus;
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34
/// Persisted record of one workflow run.
///
/// Holds the run's identifiers (run, agent, thread, user, optional
/// workspace), the [`WorkflowDefinition`] being executed, the run's input,
/// and the mutable `context` that [`WorkflowStore::update_context`]
/// replaces wholesale.
///
/// Field order and serde attributes determine the serialized shape; changing
/// them changes what is written to and read back from storage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowExecutionState {
    /// Identifier of the run; `InMemoryWorkflowStore` keys runs by this value.
    pub run_task_id: String,
    /// Agent id associated with the run.
    pub agent_id: String,
    /// Thread id associated with the run.
    pub thread_id: String,
    /// User id associated with the run.
    pub user_id: String,
    /// Optional workspace id; omitted from serialized output when `None` and
    /// read back as `None` when absent.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub workspace_id: Option<String>,
    /// The workflow definition this run executes.
    pub definition: WorkflowDefinition,
    /// Optional entry point name; omitted from serialized output when `None`.
    /// NOTE(review): presumably selects where execution starts — confirm
    /// against the executor.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub entry_point: Option<String>,
    /// Input supplied to the run. `new` initializes it to `{}`, but a payload
    /// missing this field deserializes to `null` (the `Default` of
    /// `serde_json::Value`). NOTE(review): confirm readers accept both.
    #[serde(default)]
    pub input: serde_json::Value,
    /// Mutable run context, replaced via `update_context`. Same defaulting
    /// caveat as `input`: `{}` from `new`, `null` when absent on deserialize.
    #[serde(default)]
    pub context: serde_json::Value,
    /// Creation timestamp (UTC).
    pub created_at: DateTime<Utc>,
    /// Last-modification timestamp (UTC); the in-memory store refreshes it in
    /// `update_context`.
    pub updated_at: DateTime<Utc>,
}
64
65impl WorkflowExecutionState {
66 pub fn new(
67 run_task_id: impl Into<String>,
68 agent_id: impl Into<String>,
69 thread_id: impl Into<String>,
70 user_id: impl Into<String>,
71 definition: WorkflowDefinition,
72 ) -> Self {
73 let now = Utc::now();
74 Self {
75 run_task_id: run_task_id.into(),
76 agent_id: agent_id.into(),
77 thread_id: thread_id.into(),
78 user_id: user_id.into(),
79 workspace_id: None,
80 definition,
81 entry_point: None,
82 input: serde_json::json!({}),
83 context: serde_json::json!({}),
84 created_at: now,
85 updated_at: now,
86 }
87 }
88
89 pub fn with_workspace_id(mut self, workspace_id: Option<String>) -> Self {
90 self.workspace_id = workspace_id;
91 self
92 }
93
94 pub fn with_entry_point(mut self, entry_point: Option<String>) -> Self {
95 self.entry_point = entry_point;
96 self
97 }
98
99 pub fn with_input(mut self, input: serde_json::Value) -> Self {
100 self.input = input;
101 self
102 }
103
104 pub fn with_context(mut self, context: serde_json::Value) -> Self {
105 self.context = context;
106 self
107 }
108}
109
/// Persisted state of a single step within a workflow run.
///
/// Steps are identified by `step_id` within their run; the in-memory store's
/// `upsert_step` replaces an existing entry with the same `step_id`.
/// `Default` gives an empty `step_id`, `TaskStatus::default()`, and `None`
/// for every optional field, which is what `..Default::default()` relies on.
///
/// Field order and serde attributes determine the serialized shape.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WorkflowStepState {
    /// Identifier of the step within its run.
    pub step_id: String,
    /// Current status; `TaskStatus::default()` when absent on deserialize.
    #[serde(default)]
    pub status: TaskStatus,
    /// Optional result value; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub result: Option<serde_json::Value>,
    /// Optional error message; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Optional start timestamp (UTC); omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub started_at: Option<DateTime<Utc>>,
    /// Optional completion timestamp (UTC); omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<DateTime<Utc>>,
    /// Optional task id; omitted when `None`.
    /// NOTE(review): presumably the task this step is waiting on — confirm.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub wait_task_id: Option<String>,
}
134
/// Persistence backend for workflow runs and their per-step state.
///
/// Implementations must be `Send + Sync`; methods are async via
/// `async_trait`. Runs are addressed by `run_task_id`, steps by
/// `(run_task_id, step_id)`.
#[async_trait::async_trait]
pub trait WorkflowStore: Send + Sync {
    /// Stores a run record keyed by `state.run_task_id`.
    /// NOTE(review): behavior when the id already exists is up to the
    /// implementation; `InMemoryWorkflowStore` silently replaces it.
    async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()>;

    /// Fetches a run by id; `Ok(None)` when it does not exist.
    async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>>;

    /// Replaces the stored `context` of an existing run.
    /// `InMemoryWorkflowStore` errors if the run is missing and refreshes
    /// `updated_at` on success.
    async fn update_context(
        &self,
        run_task_id: &str,
        context: serde_json::Value,
    ) -> anyhow::Result<()>;

    /// Deletes a run. `InMemoryWorkflowStore` also deletes the run's steps and
    /// succeeds when the run does not exist.
    async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()>;

    /// Inserts `step` for the run, or replaces the existing step with the same
    /// `step_id`. `InMemoryWorkflowStore` does not require the run to exist.
    async fn upsert_step(&self, run_task_id: &str, step: WorkflowStepState) -> anyhow::Result<()>;

    /// Fetches one step of a run; `Ok(None)` when the run or step is unknown.
    async fn get_step(
        &self,
        run_task_id: &str,
        step_id: &str,
    ) -> anyhow::Result<Option<WorkflowStepState>>;

    /// Lists all steps of a run (empty when unknown). `InMemoryWorkflowStore`
    /// returns them in first-insertion order.
    async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>>;
}
177
/// Process-local [`WorkflowStore`] backed by two mutex-guarded hash maps.
///
/// Nothing is persisted beyond the lifetime of the value. Uses
/// `std::sync::Mutex`; the `WorkflowStore` impl for this type never awaits
/// while holding a guard.
#[derive(Default)]
pub struct InMemoryWorkflowStore {
    /// Runs keyed by `run_task_id`.
    runs: std::sync::Mutex<HashMap<String, WorkflowExecutionState>>,
    /// Steps per `run_task_id`, stored in a `Vec` in first-insertion order
    /// (so step lookups are linear scans).
    steps: std::sync::Mutex<HashMap<String, Vec<WorkflowStepState>>>,
}
188
189impl InMemoryWorkflowStore {
190 pub fn new() -> Self {
191 Self::default()
192 }
193}
194
195#[async_trait::async_trait]
196impl WorkflowStore for InMemoryWorkflowStore {
197 async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()> {
198 let mut runs = self
199 .runs
200 .lock()
201 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
202 runs.insert(state.run_task_id.clone(), state);
203 Ok(())
204 }
205
206 async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>> {
207 let runs = self
208 .runs
209 .lock()
210 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
211 Ok(runs.get(run_task_id).cloned())
212 }
213
214 async fn update_context(
215 &self,
216 run_task_id: &str,
217 context: serde_json::Value,
218 ) -> anyhow::Result<()> {
219 let mut runs = self
220 .runs
221 .lock()
222 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
223 let row = runs
224 .get_mut(run_task_id)
225 .ok_or_else(|| anyhow::anyhow!("workflow run not found: {run_task_id}"))?;
226 row.context = context;
227 row.updated_at = Utc::now();
228 Ok(())
229 }
230
231 async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()> {
232 let mut runs = self
233 .runs
234 .lock()
235 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
236 let mut steps = self
237 .steps
238 .lock()
239 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
240 runs.remove(run_task_id);
241 steps.remove(run_task_id);
242 Ok(())
243 }
244
245 async fn upsert_step(&self, run_task_id: &str, step: WorkflowStepState) -> anyhow::Result<()> {
246 let mut steps = self
247 .steps
248 .lock()
249 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
250 let bucket = steps.entry(run_task_id.to_string()).or_default();
251 if let Some(existing) = bucket.iter_mut().find(|s| s.step_id == step.step_id) {
252 *existing = step;
253 } else {
254 bucket.push(step);
255 }
256 Ok(())
257 }
258
259 async fn get_step(
260 &self,
261 run_task_id: &str,
262 step_id: &str,
263 ) -> anyhow::Result<Option<WorkflowStepState>> {
264 let steps = self
265 .steps
266 .lock()
267 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
268 Ok(steps
269 .get(run_task_id)
270 .and_then(|bucket| bucket.iter().find(|s| s.step_id == step_id).cloned()))
271 }
272
273 async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>> {
274 let steps = self
275 .steps
276 .lock()
277 .map_err(|e| anyhow::anyhow!(e.to_string()))?;
278 Ok(steps.get(run_task_id).cloned().unwrap_or_default())
279 }
280}
281
#[cfg(test)]
mod tests {
    //! Behavioral tests for `InMemoryWorkflowStore`, covering both the happy
    //! paths and the missing-run / missing-step paths.
    use super::*;
    use crate::types::{WorkflowDefinition, WorkflowStep};

    /// Minimal one-step definition used wherever a run record is needed.
    fn sample_def() -> WorkflowDefinition {
        WorkflowDefinition::new(vec![WorkflowStep::checkpoint("c", "Checkpoint", "ok")])
    }

    #[tokio::test]
    async fn create_get_roundtrip() {
        let store = InMemoryWorkflowStore::new();
        let state = WorkflowExecutionState::new(
            "run-1",
            "agent-1",
            "thread-test",
            "user-test",
            sample_def(),
        )
        .with_entry_point(Some("main".into()))
        .with_input(serde_json::json!({"x": 1}));
        store.create_run(state).await.unwrap();
        let got = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(got.agent_id, "agent-1");
        assert_eq!(got.entry_point.as_deref(), Some("main"));
        assert_eq!(got.input, serde_json::json!({"x": 1}));
    }

    #[tokio::test]
    async fn update_context_mutates_only_context() {
        let store = InMemoryWorkflowStore::new();
        store
            .create_run(
                WorkflowExecutionState::new(
                    "run-1",
                    "agent-1",
                    "thread-test",
                    "user-test",
                    sample_def(),
                )
                .with_input(serde_json::json!({"x": 1})),
            )
            .await
            .unwrap();
        let new_ctx = serde_json::json!({"steps": {"c": "ok"}});
        store
            .update_context("run-1", new_ctx.clone())
            .await
            .unwrap();
        let got = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(got.context, new_ctx);
        assert_eq!(got.input, serde_json::json!({"x": 1}));
    }

    /// `update_context` must refuse to create a run implicitly.
    #[tokio::test]
    async fn update_context_on_missing_run_errors() {
        let store = InMemoryWorkflowStore::new();
        let err = store
            .update_context("missing", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(err.to_string().contains("workflow run not found: missing"));
        assert!(store.get_run("missing").await.unwrap().is_none());
    }

    /// Lookups on unknown ids return empty results, and delete is idempotent.
    #[tokio::test]
    async fn unknown_run_lookups_are_empty_and_delete_is_ok() {
        let store = InMemoryWorkflowStore::new();
        assert!(store.get_run("nope").await.unwrap().is_none());
        assert!(store.get_step("nope", "s").await.unwrap().is_none());
        assert!(store.list_steps("nope").await.unwrap().is_empty());
        store.delete_run("nope").await.unwrap();
    }

    #[tokio::test]
    async fn upsert_step_insert_then_update() {
        let store = InMemoryWorkflowStore::new();
        let s1 = WorkflowStepState {
            step_id: "fetch".into(),
            status: TaskStatus::Running,
            ..Default::default()
        };
        store.upsert_step("run-1", s1).await.unwrap();
        let got = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(got.status, TaskStatus::Running);

        let s2 = WorkflowStepState {
            step_id: "fetch".into(),
            status: TaskStatus::Completed,
            result: Some(serde_json::json!({"docs": []})),
            ..Default::default()
        };
        store.upsert_step("run-1", s2).await.unwrap();
        let got = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(got.status, TaskStatus::Completed);
        assert!(got.result.is_some());
        // Upserting an existing step_id must not create a duplicate entry.
        assert_eq!(store.list_steps("run-1").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn list_steps_preserves_insertion_order_and_is_per_run() {
        let store = InMemoryWorkflowStore::new();
        for id in ["a", "b", "c"] {
            store
                .upsert_step(
                    "run-1",
                    WorkflowStepState {
                        step_id: id.into(),
                        status: TaskStatus::Pending,
                        ..Default::default()
                    },
                )
                .await
                .unwrap();
        }
        store
            .upsert_step(
                "run-2",
                WorkflowStepState {
                    step_id: "x".into(),
                    status: TaskStatus::Pending,
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        let r1 = store.list_steps("run-1").await.unwrap();
        let r2 = store.list_steps("run-2").await.unwrap();
        assert_eq!(
            r1.iter().map(|s| s.step_id.as_str()).collect::<Vec<_>>(),
            vec!["a", "b", "c"]
        );
        assert_eq!(r2.len(), 1);
    }

    #[tokio::test]
    async fn delete_run_cascades_to_steps() {
        let store = InMemoryWorkflowStore::new();
        store
            .create_run(WorkflowExecutionState::new(
                "run-1",
                "agent-1",
                "thread-test",
                "user-test",
                sample_def(),
            ))
            .await
            .unwrap();
        store
            .upsert_step(
                "run-1",
                WorkflowStepState {
                    step_id: "s".into(),
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        store.delete_run("run-1").await.unwrap();
        assert!(store.get_run("run-1").await.unwrap().is_none());
        assert!(store.list_steps("run-1").await.unwrap().is_empty());
    }
}