// distri_workflow/workflow_store.rs

//! The single store for workflow execution data.
//!
//! Replaces the earlier two-trait split (`WorkflowRunStore` over
//! `workflow_runs`, `WorkflowStepExecutionStore` over
//! `workflow_step_executions`). Every workflow needs both run-level
//! and step-level state, so surfacing them as two trait objects on
//! the orchestrator was a needless complication.
//!
//! What lives here:
//!
//!   - [`WorkflowExecutionState`] — the run-level extras (definition
//!     snapshot, entry point, input, shared context) that a bare
//!     `Task` row can't carry.
//!   - [`WorkflowStepState`] — per-step extras (status, result, error,
//!     timestamps, and the optional `wait_task_id` for wait-style steps
//!     that need to be A2A-addressable).
//!   - [`WorkflowStore`] trait — a single CRUD surface over both.
//!   - [`InMemoryWorkflowStore`] — an in-process implementation for
//!     tests and the standalone OSS runner.
//!
//! What is **not** here: the run's status, the tree shape, and the
//! tasks/messages/events history. Those live on the canonical `Task`
//! tree (`TaskStore`). This store is the workflow-specific sidecar:
//! a workflow run = a `Task` (status + tree) + a `WorkflowStore` entry
//! (definition + context + step results).
//!
//! Implementations are free to use one or two collections internally
//! (a Redis impl typically uses `wf:run:{id}` JSON + `wf:steps:{id}`
//! HASH for cheap per-step updates); the trait keeps that invisible.

29use crate::types::WorkflowDefinition;
30use chrono::{DateTime, Utc};
31use distri_types::TaskStatus;
32use serde::{Deserialize, Serialize};
33use std::collections::HashMap;
34
35/// Run-level state for one workflow execution. Keyed by `run_task_id`
36/// (the run's root `Task` id).
37///
38/// `thread_id` / `user_id` / `workspace_id` are snapshotted at run
39/// start so resume re-builds an `ExecutorContext` without any task
40/// store lookup (which would itself need tenant context — a chicken
41/// and egg).
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct WorkflowExecutionState {
44    pub run_task_id: String,
45    pub agent_id: String,
46    pub thread_id: String,
47    pub user_id: String,
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub workspace_id: Option<String>,
50    /// Workflow definition snapshotted at run start. Later edits to
51    /// the agent config cannot corrupt an in-flight run.
52    pub definition: WorkflowDefinition,
53    #[serde(default, skip_serializing_if = "Option::is_none")]
54    pub entry_point: Option<String>,
55    #[serde(default)]
56    pub input: serde_json::Value,
57    /// Shared bag steps accumulate results into — read at template
58    /// resolution time as `{steps.X}`, `{input.Y}`, `{env.Z}`.
59    #[serde(default)]
60    pub context: serde_json::Value,
61    pub created_at: DateTime<Utc>,
62    pub updated_at: DateTime<Utc>,
63}
64
65impl WorkflowExecutionState {
66    pub fn new(
67        run_task_id: impl Into<String>,
68        agent_id: impl Into<String>,
69        thread_id: impl Into<String>,
70        user_id: impl Into<String>,
71        definition: WorkflowDefinition,
72    ) -> Self {
73        let now = Utc::now();
74        Self {
75            run_task_id: run_task_id.into(),
76            agent_id: agent_id.into(),
77            thread_id: thread_id.into(),
78            user_id: user_id.into(),
79            workspace_id: None,
80            definition,
81            entry_point: None,
82            input: serde_json::json!({}),
83            context: serde_json::json!({}),
84            created_at: now,
85            updated_at: now,
86        }
87    }
88
89    pub fn with_workspace_id(mut self, workspace_id: Option<String>) -> Self {
90        self.workspace_id = workspace_id;
91        self
92    }
93
94    pub fn with_entry_point(mut self, entry_point: Option<String>) -> Self {
95        self.entry_point = entry_point;
96        self
97    }
98
99    pub fn with_input(mut self, input: serde_json::Value) -> Self {
100        self.input = input;
101        self
102    }
103
104    pub fn with_context(mut self, context: serde_json::Value) -> Self {
105        self.context = context;
106        self
107    }
108}
109
110/// Per-step state. Keyed by `(run_task_id, step_id)`.
111///
112/// `step_id` is the definition-level identifier ("fetch", "summarize");
113/// `wait_task_id` is `Some(task_id)` only for wait-style steps
114/// (`ExternalToolCall`, `WaitForInput`, `WaitForEvent`) that create a
115/// child `Task` in `InputRequired` so external parties can resume them
116/// via `/complete-tool` or A2A `message/send` with `taskId`. Regular
117/// steps execute in-process and have `wait_task_id = None`.
118#[derive(Debug, Clone, Serialize, Deserialize, Default)]
119pub struct WorkflowStepState {
120    pub step_id: String,
121    #[serde(default)]
122    pub status: TaskStatus,
123    #[serde(default, skip_serializing_if = "Option::is_none")]
124    pub result: Option<serde_json::Value>,
125    #[serde(default, skip_serializing_if = "Option::is_none")]
126    pub error: Option<String>,
127    #[serde(default, skip_serializing_if = "Option::is_none")]
128    pub started_at: Option<DateTime<Utc>>,
129    #[serde(default, skip_serializing_if = "Option::is_none")]
130    pub completed_at: Option<DateTime<Utc>>,
131    #[serde(default, skip_serializing_if = "Option::is_none")]
132    pub wait_task_id: Option<String>,
133}
134
135/// Persist and load workflow execution state. One trait, both
136/// run-level and step-level CRUD — keeps orchestrator wiring to a
137/// single field. Implementations: in-memory (tests + OSS server-cli),
138/// Redis (cloud).
139#[async_trait::async_trait]
140pub trait WorkflowStore: Send + Sync {
141    // ── run-level ──────────────────────────────────────────────────
142
143    /// Insert a new run record. Existing record under the same
144    /// `run_task_id` is overwritten (treat the call as create-or-resume).
145    async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()>;
146
147    /// Load the run-level state for a workflow.
148    async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>>;
149
150    /// Update the shared `context` bag (called after step results
151    /// merge in). Other run-level fields are immutable for the life
152    /// of the run.
153    async fn update_context(
154        &self,
155        run_task_id: &str,
156        context: serde_json::Value,
157    ) -> anyhow::Result<()>;
158
159    /// Drop a run and all its step rows.
160    async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()>;
161
162    // ── step-level ─────────────────────────────────────────────────
163
164    /// Insert or update one step's state under a run.
165    async fn upsert_step(
166        &self,
167        run_task_id: &str,
168        step: WorkflowStepState,
169    ) -> anyhow::Result<()>;
170
171    /// Load one step's state.
172    async fn get_step(
173        &self,
174        run_task_id: &str,
175        step_id: &str,
176    ) -> anyhow::Result<Option<WorkflowStepState>>;
177
178    /// List all step states for a run, in insertion order.
179    async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>>;
180}
181
182/// In-memory [`WorkflowStore`] for tests and the standalone OSS
183/// runner. Two HashMaps wrapped in one struct — the trait surface
184/// stays singular regardless of the internal layout.
185#[derive(Default)]
186pub struct InMemoryWorkflowStore {
187    runs: std::sync::Mutex<HashMap<String, WorkflowExecutionState>>,
188    /// `run_task_id -> (step_id -> WorkflowStepState)` — preserves
189    /// step insertion order via `Vec`-backed map.
190    steps: std::sync::Mutex<HashMap<String, Vec<WorkflowStepState>>>,
191}
192
193impl InMemoryWorkflowStore {
194    pub fn new() -> Self {
195        Self::default()
196    }
197}
198
199#[async_trait::async_trait]
200impl WorkflowStore for InMemoryWorkflowStore {
201    async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()> {
202        let mut runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
203        runs.insert(state.run_task_id.clone(), state);
204        Ok(())
205    }
206
207    async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>> {
208        let runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
209        Ok(runs.get(run_task_id).cloned())
210    }
211
212    async fn update_context(
213        &self,
214        run_task_id: &str,
215        context: serde_json::Value,
216    ) -> anyhow::Result<()> {
217        let mut runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
218        let row = runs
219            .get_mut(run_task_id)
220            .ok_or_else(|| anyhow::anyhow!("workflow run not found: {run_task_id}"))?;
221        row.context = context;
222        row.updated_at = Utc::now();
223        Ok(())
224    }
225
226    async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()> {
227        let mut runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
228        let mut steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
229        runs.remove(run_task_id);
230        steps.remove(run_task_id);
231        Ok(())
232    }
233
234    async fn upsert_step(
235        &self,
236        run_task_id: &str,
237        step: WorkflowStepState,
238    ) -> anyhow::Result<()> {
239        let mut steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
240        let bucket = steps.entry(run_task_id.to_string()).or_default();
241        if let Some(existing) = bucket.iter_mut().find(|s| s.step_id == step.step_id) {
242            *existing = step;
243        } else {
244            bucket.push(step);
245        }
246        Ok(())
247    }
248
249    async fn get_step(
250        &self,
251        run_task_id: &str,
252        step_id: &str,
253    ) -> anyhow::Result<Option<WorkflowStepState>> {
254        let steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
255        Ok(steps
256            .get(run_task_id)
257            .and_then(|bucket| bucket.iter().find(|s| s.step_id == step_id).cloned()))
258    }
259
260    async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>> {
261        let steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
262        Ok(steps.get(run_task_id).cloned().unwrap_or_default())
263    }
264}
265
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{WorkflowDefinition, WorkflowStep};

    /// Minimal one-step definition. The store only stores it, never
    /// interprets it.
    fn sample_def() -> WorkflowDefinition {
        WorkflowDefinition::new(vec![WorkflowStep::checkpoint("c", "Checkpoint", "ok")])
    }

    /// Run record with fixed agent/thread/user ids and `sample_def()`.
    fn sample_run(run_task_id: &str) -> WorkflowExecutionState {
        WorkflowExecutionState::new(run_task_id, "agent-1", "thread-test", "user-test", sample_def())
    }

    /// Step record with only `step_id` and `status` set.
    fn step(step_id: &str, status: TaskStatus) -> WorkflowStepState {
        WorkflowStepState {
            step_id: step_id.into(),
            status,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn create_get_roundtrip() {
        let store = InMemoryWorkflowStore::new();
        let state = sample_run("run-1")
            .with_entry_point(Some("main".into()))
            .with_input(serde_json::json!({"x": 1}));
        store.create_run(state).await.unwrap();
        let got = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(got.agent_id, "agent-1");
        assert_eq!(got.entry_point.as_deref(), Some("main"));
        assert_eq!(got.input, serde_json::json!({"x": 1}));
    }

    #[tokio::test]
    async fn create_run_overwrites_existing_record() {
        let store = InMemoryWorkflowStore::new();
        store
            .create_run(sample_run("run-1").with_input(serde_json::json!({"v": 1})))
            .await
            .unwrap();
        store
            .create_run(sample_run("run-1").with_input(serde_json::json!({"v": 2})))
            .await
            .unwrap();
        let got = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(got.input, serde_json::json!({"v": 2}));
    }

    #[tokio::test]
    async fn update_context_mutates_only_context() {
        let store = InMemoryWorkflowStore::new();
        store
            .create_run(sample_run("run-1").with_input(serde_json::json!({"x": 1})))
            .await
            .unwrap();
        let new_ctx = serde_json::json!({"steps": {"c": "ok"}});
        store
            .update_context("run-1", new_ctx.clone())
            .await
            .unwrap();
        let got = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(got.context, new_ctx);
        assert_eq!(got.input, serde_json::json!({"x": 1}));
    }

    #[tokio::test]
    async fn update_context_on_missing_run_errors() {
        let store = InMemoryWorkflowStore::new();
        let err = store
            .update_context("missing", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(err.to_string().contains("workflow run not found"));
    }

    #[tokio::test]
    async fn upsert_step_insert_then_update() {
        let store = InMemoryWorkflowStore::new();
        store
            .upsert_step("run-1", step("fetch", TaskStatus::Running))
            .await
            .unwrap();
        let got = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(got.status, TaskStatus::Running);

        let completed = WorkflowStepState {
            step_id: "fetch".into(),
            status: TaskStatus::Completed,
            result: Some(serde_json::json!({"docs": []})),
            ..Default::default()
        };
        store.upsert_step("run-1", completed).await.unwrap();
        let got = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(got.status, TaskStatus::Completed);
        assert!(got.result.is_some());
    }

    #[tokio::test]
    async fn unknown_run_or_step_reads_as_empty() {
        let store = InMemoryWorkflowStore::new();
        assert!(store.get_run("nope").await.unwrap().is_none());
        assert!(store.get_step("run-1", "fetch").await.unwrap().is_none());
        assert!(store.list_steps("nope").await.unwrap().is_empty());

        store
            .upsert_step("run-1", step("fetch", TaskStatus::Running))
            .await
            .unwrap();
        assert!(store.get_step("run-1", "other").await.unwrap().is_none());
        assert!(store.get_step("run-2", "fetch").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn list_steps_preserves_insertion_order_and_is_per_run() {
        let store = InMemoryWorkflowStore::new();
        for id in ["a", "b", "c"] {
            store
                .upsert_step("run-1", step(id, TaskStatus::Pending))
                .await
                .unwrap();
        }
        store
            .upsert_step("run-2", step("x", TaskStatus::Pending))
            .await
            .unwrap();
        let r1 = store.list_steps("run-1").await.unwrap();
        let r2 = store.list_steps("run-2").await.unwrap();
        assert_eq!(r1.iter().map(|s| s.step_id.as_str()).collect::<Vec<_>>(), vec!["a", "b", "c"]);
        assert_eq!(r2.len(), 1);
    }

    #[tokio::test]
    async fn upsert_existing_step_keeps_its_position() {
        let store = InMemoryWorkflowStore::new();
        for id in ["a", "b", "c"] {
            store
                .upsert_step("run-1", step(id, TaskStatus::Pending))
                .await
                .unwrap();
        }
        store
            .upsert_step("run-1", step("a", TaskStatus::Completed))
            .await
            .unwrap();
        let steps = store.list_steps("run-1").await.unwrap();
        assert_eq!(steps.iter().map(|s| s.step_id.as_str()).collect::<Vec<_>>(), vec!["a", "b", "c"]);
        assert_eq!(steps[0].status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn delete_run_cascades_to_steps() {
        let store = InMemoryWorkflowStore::new();
        store.create_run(sample_run("run-1")).await.unwrap();
        store
            .upsert_step("run-1", step("s", TaskStatus::default()))
            .await
            .unwrap();
        store.delete_run("run-1").await.unwrap();
        assert!(store.get_run("run-1").await.unwrap().is_none());
        assert!(store.list_steps("run-1").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn delete_run_leaves_other_runs_intact() {
        let store = InMemoryWorkflowStore::new();
        for run in ["run-1", "run-2"] {
            store.create_run(sample_run(run)).await.unwrap();
            store
                .upsert_step(run, step("s", TaskStatus::Pending))
                .await
                .unwrap();
        }
        store.delete_run("run-1").await.unwrap();
        assert!(store.get_run("run-2").await.unwrap().is_some());
        assert_eq!(store.list_steps("run-2").await.unwrap().len(), 1);
    }
}