Skip to main content

distri_workflow/
workflow_store.rs

1//! The single store for workflow execution data.
2//!
3//! Replaces the earlier two-trait split (`WorkflowRunStore` over
4//! `workflow_runs`, `WorkflowStepExecutionStore` over
5//! `workflow_step_executions`) — every workflow needs both run-level
6//! and step-level state, and surfacing them as two trait objects on
7//! the orchestrator was needless complication.
8//!
9//! What lives here:
10//!
11//!   - [`WorkflowExecutionState`] — the run-level extras (definition
12//!     snapshot, entry point, input, shared context) that a bare
13//!     `Task` row can't carry.
14//!   - [`WorkflowStepState`] — per-step extras (status, result, error,
15//!     timestamps, and the optional `wait_task_id` for wait-style steps
16//!     that need to be A2A-addressable).
17//!   - [`WorkflowStore`] trait — a single CRUD surface over both.
18//!
19//! What is **not** here: the run's status, the tree shape, the
20//! tasks/messages/events history. Those live on the canonical `Task`
21//! tree (`TaskStore`). This store is the workflow-specific sidecar —
22//! a workflow run = a `Task` (status + tree) + a `WorkflowStore` entry
23//! (definition + context + step results).
24//!
25//! Implementations are free to use one or two collections internally
26//! (a Redis impl typically uses `wf:run:{id}` JSON + `wf:steps:{id}`
27//! HASH for cheap per-step updates); the trait keeps that invisible.
28
29use crate::types::WorkflowDefinition;
30use chrono::{DateTime, Utc};
31use distri_types::TaskStatus;
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34
35/// Run-level state for one workflow execution. Keyed by `run_task_id`
36/// (the run's root `Task` id).
37///
38/// `thread_id` / `user_id` / `workspace_id` are snapshotted at run
39/// start so resume re-builds an `ExecutorContext` without any task
40/// store lookup (which would itself need tenant context — a chicken
41/// and egg).
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct WorkflowExecutionState {
44    pub run_task_id: String,
45    pub agent_id: String,
46    pub thread_id: String,
47    pub user_id: String,
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub workspace_id: Option<String>,
50    /// Workflow definition snapshotted at run start. Later edits to
51    /// the agent config cannot corrupt an in-flight run.
52    pub definition: WorkflowDefinition,
53    #[serde(default, skip_serializing_if = "Option::is_none")]
54    pub entry_point: Option<String>,
55    #[serde(default)]
56    pub input: serde_json::Value,
57    /// Shared bag steps accumulate results into — read at template
58    /// resolution time as `{steps.X}`, `{input.Y}`, `{env.Z}`.
59    #[serde(default)]
60    pub context: serde_json::Value,
61    pub created_at: DateTime<Utc>,
62    pub updated_at: DateTime<Utc>,
63}
64
65impl WorkflowExecutionState {
66    pub fn new(
67        run_task_id: impl Into<String>,
68        agent_id: impl Into<String>,
69        thread_id: impl Into<String>,
70        user_id: impl Into<String>,
71        definition: WorkflowDefinition,
72    ) -> Self {
73        let now = Utc::now();
74        Self {
75            run_task_id: run_task_id.into(),
76            agent_id: agent_id.into(),
77            thread_id: thread_id.into(),
78            user_id: user_id.into(),
79            workspace_id: None,
80            definition,
81            entry_point: None,
82            input: serde_json::json!({}),
83            context: serde_json::json!({}),
84            created_at: now,
85            updated_at: now,
86        }
87    }
88
89    pub fn with_workspace_id(mut self, workspace_id: Option<String>) -> Self {
90        self.workspace_id = workspace_id;
91        self
92    }
93
94    pub fn with_entry_point(mut self, entry_point: Option<String>) -> Self {
95        self.entry_point = entry_point;
96        self
97    }
98
99    pub fn with_input(mut self, input: serde_json::Value) -> Self {
100        self.input = input;
101        self
102    }
103
104    pub fn with_context(mut self, context: serde_json::Value) -> Self {
105        self.context = context;
106        self
107    }
108}
109
110/// Per-step state. Keyed by `(run_task_id, step_id)`.
111///
112/// `step_id` is the definition-level identifier ("fetch", "summarize");
113/// `wait_task_id` is `Some(task_id)` only for wait-style steps
114/// (`ExternalToolCall`, `WaitForInput`, `WaitForEvent`) that create a
115/// child `Task` in `InputRequired` so external parties can resume them
116/// via `/complete-tool` or A2A `message/send` with `taskId`. Regular
117/// steps execute in-process and have `wait_task_id = None`.
118#[derive(Debug, Clone, Serialize, Deserialize, Default)]
119pub struct WorkflowStepState {
120    pub step_id: String,
121    #[serde(default)]
122    pub status: TaskStatus,
123    #[serde(default, skip_serializing_if = "Option::is_none")]
124    pub result: Option<serde_json::Value>,
125    #[serde(default, skip_serializing_if = "Option::is_none")]
126    pub error: Option<String>,
127    #[serde(default, skip_serializing_if = "Option::is_none")]
128    pub started_at: Option<DateTime<Utc>>,
129    #[serde(default, skip_serializing_if = "Option::is_none")]
130    pub completed_at: Option<DateTime<Utc>>,
131    #[serde(default, skip_serializing_if = "Option::is_none")]
132    pub wait_task_id: Option<String>,
133}
134
135/// Persist and load workflow execution state. One trait, both
136/// run-level and step-level CRUD — keeps orchestrator wiring to a
137/// single field. Implementations: in-memory (tests + OSS server-cli),
138/// Redis (cloud).
139#[async_trait::async_trait]
140pub trait WorkflowStore: Send + Sync {
141    // ── run-level ──────────────────────────────────────────────────
142
143    /// Insert a new run record. Existing record under the same
144    /// `run_task_id` is overwritten (treat the call as create-or-resume).
145    async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()>;
146
147    /// Load the run-level state for a workflow.
148    async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>>;
149
150    /// Update the shared `context` bag (called after step results
151    /// merge in). Other run-level fields are immutable for the life
152    /// of the run.
153    async fn update_context(
154        &self,
155        run_task_id: &str,
156        context: serde_json::Value,
157    ) -> anyhow::Result<()>;
158
159    /// Drop a run and all its step rows.
160    async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()>;
161
162    // ── step-level ─────────────────────────────────────────────────
163
164    /// Insert or update one step's state under a run.
165    async fn upsert_step(&self, run_task_id: &str, step: WorkflowStepState) -> anyhow::Result<()>;
166
167    /// Load one step's state.
168    async fn get_step(
169        &self,
170        run_task_id: &str,
171        step_id: &str,
172    ) -> anyhow::Result<Option<WorkflowStepState>>;
173
174    /// List all step states for a run, in insertion order.
175    async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>>;
176}
177
178/// In-memory [`WorkflowStore`] for tests and the standalone OSS
179/// runner. Two HashMaps wrapped in one struct — the trait surface
180/// stays singular regardless of the internal layout.
181#[derive(Default)]
182pub struct InMemoryWorkflowStore {
183    runs: std::sync::Mutex<HashMap<String, WorkflowExecutionState>>,
184    /// `run_task_id -> (step_id -> WorkflowStepState)` — preserves
185    /// step insertion order via `Vec`-backed map.
186    steps: std::sync::Mutex<HashMap<String, Vec<WorkflowStepState>>>,
187}
188
189impl InMemoryWorkflowStore {
190    pub fn new() -> Self {
191        Self::default()
192    }
193}
194
195#[async_trait::async_trait]
196impl WorkflowStore for InMemoryWorkflowStore {
197    async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()> {
198        let mut runs = self
199            .runs
200            .lock()
201            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
202        runs.insert(state.run_task_id.clone(), state);
203        Ok(())
204    }
205
206    async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>> {
207        let runs = self
208            .runs
209            .lock()
210            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
211        Ok(runs.get(run_task_id).cloned())
212    }
213
214    async fn update_context(
215        &self,
216        run_task_id: &str,
217        context: serde_json::Value,
218    ) -> anyhow::Result<()> {
219        let mut runs = self
220            .runs
221            .lock()
222            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
223        let row = runs
224            .get_mut(run_task_id)
225            .ok_or_else(|| anyhow::anyhow!("workflow run not found: {run_task_id}"))?;
226        row.context = context;
227        row.updated_at = Utc::now();
228        Ok(())
229    }
230
231    async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()> {
232        let mut runs = self
233            .runs
234            .lock()
235            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
236        let mut steps = self
237            .steps
238            .lock()
239            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
240        runs.remove(run_task_id);
241        steps.remove(run_task_id);
242        Ok(())
243    }
244
245    async fn upsert_step(&self, run_task_id: &str, step: WorkflowStepState) -> anyhow::Result<()> {
246        let mut steps = self
247            .steps
248            .lock()
249            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
250        let bucket = steps.entry(run_task_id.to_string()).or_default();
251        if let Some(existing) = bucket.iter_mut().find(|s| s.step_id == step.step_id) {
252            *existing = step;
253        } else {
254            bucket.push(step);
255        }
256        Ok(())
257    }
258
259    async fn get_step(
260        &self,
261        run_task_id: &str,
262        step_id: &str,
263    ) -> anyhow::Result<Option<WorkflowStepState>> {
264        let steps = self
265            .steps
266            .lock()
267            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
268        Ok(steps
269            .get(run_task_id)
270            .and_then(|bucket| bucket.iter().find(|s| s.step_id == step_id).cloned()))
271    }
272
273    async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>> {
274        let steps = self
275            .steps
276            .lock()
277            .map_err(|e| anyhow::anyhow!(e.to_string()))?;
278        Ok(steps.get(run_task_id).cloned().unwrap_or_default())
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{WorkflowDefinition, WorkflowStep};

    /// One-step definition shared by every test.
    fn sample_def() -> WorkflowDefinition {
        let steps = vec![WorkflowStep::checkpoint("c", "Checkpoint", "ok")];
        WorkflowDefinition::new(steps)
    }

    /// Run state with the fixed agent/thread/user ids used throughout.
    fn run_state(run_task_id: &str) -> WorkflowExecutionState {
        WorkflowExecutionState::new(
            run_task_id,
            "agent-1",
            "thread-test",
            "user-test",
            sample_def(),
        )
    }

    /// Step state carrying only an id and a status.
    fn step_with(step_id: &str, status: TaskStatus) -> WorkflowStepState {
        WorkflowStepState {
            step_id: step_id.into(),
            status,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn create_get_roundtrip() {
        let store = InMemoryWorkflowStore::new();
        let input = serde_json::json!({"x": 1});
        let state = run_state("run-1")
            .with_entry_point(Some("main".into()))
            .with_input(input.clone());

        store.create_run(state).await.unwrap();

        let loaded = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(loaded.agent_id, "agent-1");
        assert_eq!(loaded.entry_point.as_deref(), Some("main"));
        assert_eq!(loaded.input, input);
    }

    #[tokio::test]
    async fn update_context_mutates_only_context() {
        let store = InMemoryWorkflowStore::new();
        let input = serde_json::json!({"x": 1});
        store
            .create_run(run_state("run-1").with_input(input.clone()))
            .await
            .unwrap();

        let replacement = serde_json::json!({"steps": {"c": "ok"}});
        store
            .update_context("run-1", replacement.clone())
            .await
            .unwrap();

        let loaded = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(loaded.context, replacement);
        assert_eq!(loaded.input, input);
    }

    #[tokio::test]
    async fn upsert_step_insert_then_update() {
        let store = InMemoryWorkflowStore::new();

        // First write inserts the step.
        store
            .upsert_step("run-1", step_with("fetch", TaskStatus::Running))
            .await
            .unwrap();
        let inserted = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(inserted.status, TaskStatus::Running);

        // Second write under the same step id replaces it.
        let finished = WorkflowStepState {
            result: Some(serde_json::json!({"docs": []})),
            ..step_with("fetch", TaskStatus::Completed)
        };
        store.upsert_step("run-1", finished).await.unwrap();
        let updated = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(updated.status, TaskStatus::Completed);
        assert!(updated.result.is_some());
    }

    #[tokio::test]
    async fn list_steps_preserves_insertion_order_and_is_per_run() {
        let store = InMemoryWorkflowStore::new();
        for step_id in ["a", "b", "c"] {
            store
                .upsert_step("run-1", step_with(step_id, TaskStatus::Pending))
                .await
                .unwrap();
        }
        store
            .upsert_step("run-2", step_with("x", TaskStatus::Pending))
            .await
            .unwrap();

        let first_run = store.list_steps("run-1").await.unwrap();
        let second_run = store.list_steps("run-2").await.unwrap();

        let first_ids: Vec<&str> = first_run.iter().map(|s| s.step_id.as_str()).collect();
        assert_eq!(first_ids, vec!["a", "b", "c"]);
        assert_eq!(second_run.len(), 1);
    }

    #[tokio::test]
    async fn delete_run_cascades_to_steps() {
        let store = InMemoryWorkflowStore::new();
        store.create_run(run_state("run-1")).await.unwrap();
        let step = WorkflowStepState {
            step_id: "s".into(),
            ..Default::default()
        };
        store.upsert_step("run-1", step).await.unwrap();

        store.delete_run("run-1").await.unwrap();

        assert!(store.get_run("run-1").await.unwrap().is_none());
        assert!(store.list_steps("run-1").await.unwrap().is_empty());
    }
}