use crate::types::WorkflowDefinition;
use chrono::{DateTime, Utc};
use distri_types::TaskStatus;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Serde default for [`WorkflowExecutionState::input`] and
/// [`WorkflowExecutionState::context`].
///
/// Returns an empty JSON object, the same value [`WorkflowExecutionState::new`]
/// starts with. A plain `#[serde(default)]` would produce `null` instead, so
/// records missing these fields would disagree with freshly created ones.
fn empty_json_object() -> serde_json::Value {
    serde_json::json!({})
}

/// Persisted state of one workflow run: identifiers, the definition stored with
/// the run, and the JSON `input`/`context` it carries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowExecutionState {
    /// Identifier of this run; [`InMemoryWorkflowStore`] keys runs by this value.
    pub run_task_id: String,
    /// Id of the agent associated with the run.
    pub agent_id: String,
    /// Id of the thread associated with the run.
    pub thread_id: String,
    /// Id of the user associated with the run.
    pub user_id: String,
    /// Optional workspace id; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub workspace_id: Option<String>,
    /// Workflow definition stored with the run.
    pub definition: WorkflowDefinition,
    /// Optional entry point name; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub entry_point: Option<String>,
    /// Input payload; deserializes to `{}` when the field is absent.
    #[serde(default = "empty_json_object")]
    pub input: serde_json::Value,
    /// Run context; `WorkflowStore::update_context` replaces it wholesale.
    /// Deserializes to `{}` when the field is absent.
    #[serde(default = "empty_json_object")]
    pub context: serde_json::Value,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Last modification time (UTC).
    /// [`InMemoryWorkflowStore`] refreshes it in `update_context`.
    pub updated_at: DateTime<Utc>,
}
impl WorkflowExecutionState {
    /// Creates the state for a new run.
    ///
    /// Initial values:
    /// - `workspace_id` and `entry_point` start as `None`.
    /// - `input` and `context` start as empty JSON objects.
    /// - `created_at` and `updated_at` are both set to the same current UTC instant.
    #[must_use]
    pub fn new(
        run_task_id: impl Into<String>,
        agent_id: impl Into<String>,
        thread_id: impl Into<String>,
        user_id: impl Into<String>,
        definition: WorkflowDefinition,
    ) -> Self {
        let now = Utc::now();
        Self {
            run_task_id: run_task_id.into(),
            agent_id: agent_id.into(),
            thread_id: thread_id.into(),
            user_id: user_id.into(),
            workspace_id: None,
            definition,
            entry_point: None,
            input: serde_json::json!({}),
            context: serde_json::json!({}),
            created_at: now,
            updated_at: now,
        }
    }

    /// Sets `workspace_id` (passing `None` clears it) and returns the updated state.
    #[must_use]
    pub fn with_workspace_id(mut self, workspace_id: Option<String>) -> Self {
        self.workspace_id = workspace_id;
        self
    }

    /// Sets `entry_point` (passing `None` clears it) and returns the updated state.
    #[must_use]
    pub fn with_entry_point(mut self, entry_point: Option<String>) -> Self {
        self.entry_point = entry_point;
        self
    }

    /// Replaces `input` and returns the updated state.
    #[must_use]
    pub fn with_input(mut self, input: serde_json::Value) -> Self {
        self.input = input;
        self
    }

    /// Replaces `context` and returns the updated state.
    #[must_use]
    pub fn with_context(mut self, context: serde_json::Value) -> Self {
        self.context = context;
        self
    }
}
/// Progress record for one step of a workflow run, written through
/// `WorkflowStore::upsert_step`.
///
/// `Default` allows building partial records with struct-update syntax
/// (`..Default::default()`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct WorkflowStepState {
/// Step identifier; `upsert_step` replaces an existing record with the same value.
pub step_id: String,
/// Current status; falls back to `TaskStatus::default()` when absent from
/// serialized data.
#[serde(default)]
pub status: TaskStatus,
/// Step result, if recorded; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub result: Option<serde_json::Value>,
/// Error message, if recorded; omitted from serialized output when `None`.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub error: Option<String>,
/// Start time (UTC), if recorded.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub started_at: Option<DateTime<Utc>>,
/// Completion time (UTC), if recorded.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub completed_at: Option<DateTime<Utc>>,
/// NOTE(review): presumably the id of a task this step is waiting on — confirm
/// against the executor.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub wait_task_id: Option<String>,
}
/// Storage backend for workflow runs and their per-step state.
///
/// Implementations must be `Send + Sync` so a single store can be shared
/// between tasks.
#[async_trait::async_trait]
pub trait WorkflowStore: Send + Sync {
/// Stores a run under `state.run_task_id`.
///
/// Whether an existing run with the same id is rejected or replaced is up to
/// the implementation; the in-memory store replaces it.
async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()>;
/// Fetches a run by id, returning `Ok(None)` when it is unknown.
async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>>;
/// Replaces the run's `context` with `context`.
///
/// The in-memory store also refreshes `updated_at` and returns an error if the
/// run does not exist.
async fn update_context(
&self,
run_task_id: &str,
context: serde_json::Value,
) -> anyhow::Result<()>;
/// Removes the run together with all of its step states.
async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()>;
/// Inserts `step`, or replaces the existing record for the same
/// `step.step_id` within the run.
async fn upsert_step(
&self,
run_task_id: &str,
step: WorkflowStepState,
) -> anyhow::Result<()>;
/// Fetches one step state of a run, or `Ok(None)` when it is not recorded.
async fn get_step(
&self,
run_task_id: &str,
step_id: &str,
) -> anyhow::Result<Option<WorkflowStepState>>;
/// Lists all step states recorded for the run.
///
/// The in-memory store returns them in first-insertion order, and an empty
/// vector for unknown runs.
async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>>;
}
/// `WorkflowStore` implementation backed by in-process hash maps.
///
/// Nothing is persisted: all state is dropped together with the store.
#[derive(Default)]
pub struct InMemoryWorkflowStore {
/// Run states keyed by `run_task_id`.
runs: std::sync::Mutex<HashMap<String, WorkflowExecutionState>>,
/// Step states per `run_task_id`, kept in the order each `step_id` was first
/// upserted.
steps: std::sync::Mutex<HashMap<String, Vec<WorkflowStepState>>>,
}
impl InMemoryWorkflowStore {
    /// Returns an empty store with no runs and no step states.
    pub fn new() -> Self {
        Self {
            runs: Default::default(),
            steps: Default::default(),
        }
    }
}
#[async_trait::async_trait]
impl WorkflowStore for InMemoryWorkflowStore {
async fn create_run(&self, state: WorkflowExecutionState) -> anyhow::Result<()> {
let mut runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
runs.insert(state.run_task_id.clone(), state);
Ok(())
}
async fn get_run(&self, run_task_id: &str) -> anyhow::Result<Option<WorkflowExecutionState>> {
let runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
Ok(runs.get(run_task_id).cloned())
}
async fn update_context(
&self,
run_task_id: &str,
context: serde_json::Value,
) -> anyhow::Result<()> {
let mut runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
let row = runs
.get_mut(run_task_id)
.ok_or_else(|| anyhow::anyhow!("workflow run not found: {run_task_id}"))?;
row.context = context;
row.updated_at = Utc::now();
Ok(())
}
async fn delete_run(&self, run_task_id: &str) -> anyhow::Result<()> {
let mut runs = self.runs.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
let mut steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
runs.remove(run_task_id);
steps.remove(run_task_id);
Ok(())
}
async fn upsert_step(
&self,
run_task_id: &str,
step: WorkflowStepState,
) -> anyhow::Result<()> {
let mut steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
let bucket = steps.entry(run_task_id.to_string()).or_default();
if let Some(existing) = bucket.iter_mut().find(|s| s.step_id == step.step_id) {
*existing = step;
} else {
bucket.push(step);
}
Ok(())
}
async fn get_step(
&self,
run_task_id: &str,
step_id: &str,
) -> anyhow::Result<Option<WorkflowStepState>> {
let steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
Ok(steps
.get(run_task_id)
.and_then(|bucket| bucket.iter().find(|s| s.step_id == step_id).cloned()))
}
async fn list_steps(&self, run_task_id: &str) -> anyhow::Result<Vec<WorkflowStepState>> {
let steps = self.steps.lock().map_err(|e| anyhow::anyhow!(e.to_string()))?;
Ok(steps.get(run_task_id).cloned().unwrap_or_default())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{WorkflowDefinition, WorkflowStep};

    /// Minimal one-step definition; the in-memory store only stores it.
    fn sample_def() -> WorkflowDefinition {
        WorkflowDefinition::new(vec![WorkflowStep::checkpoint("c", "Checkpoint", "ok")])
    }

    /// Run state for `run_task_id`/`agent_id` with fixed thread and user ids.
    fn sample_state(run_task_id: &str, agent_id: &str) -> WorkflowExecutionState {
        WorkflowExecutionState::new(
            run_task_id,
            agent_id,
            "thread-test",
            "user-test",
            sample_def(),
        )
    }

    /// Step state with the given id and status; all optional fields unset.
    fn step(step_id: &str, status: TaskStatus) -> WorkflowStepState {
        WorkflowStepState {
            step_id: step_id.into(),
            status,
            ..Default::default()
        }
    }

    #[tokio::test]
    async fn create_get_roundtrip() {
        let store = InMemoryWorkflowStore::new();
        let state = sample_state("run-1", "agent-1")
            .with_entry_point(Some("main".into()))
            .with_input(serde_json::json!({"x": 1}));
        store.create_run(state).await.unwrap();
        let got = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(got.agent_id, "agent-1");
        assert_eq!(got.entry_point.as_deref(), Some("main"));
        assert_eq!(got.input, serde_json::json!({"x": 1}));
    }

    #[tokio::test]
    async fn get_run_missing_is_none() {
        let store = InMemoryWorkflowStore::new();
        assert!(store.get_run("missing").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn update_context_mutates_only_context() {
        let store = InMemoryWorkflowStore::new();
        store
            .create_run(sample_state("run-1", "agent-1").with_input(serde_json::json!({"x": 1})))
            .await
            .unwrap();
        let new_ctx = serde_json::json!({"steps": {"c": "ok"}});
        store
            .update_context("run-1", new_ctx.clone())
            .await
            .unwrap();
        let got = store.get_run("run-1").await.unwrap().unwrap();
        assert_eq!(got.context, new_ctx);
        assert_eq!(got.input, serde_json::json!({"x": 1}));
    }

    #[tokio::test]
    async fn update_context_missing_run_errors() {
        let store = InMemoryWorkflowStore::new();
        let err = store
            .update_context("missing", serde_json::json!({}))
            .await
            .unwrap_err();
        assert!(err.to_string().contains("missing"));
    }

    #[tokio::test]
    async fn upsert_step_insert_then_update() {
        let store = InMemoryWorkflowStore::new();
        store
            .upsert_step("run-1", step("fetch", TaskStatus::Running))
            .await
            .unwrap();
        let got = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(got.status, TaskStatus::Running);
        let s2 = WorkflowStepState {
            result: Some(serde_json::json!({"docs": []})),
            ..step("fetch", TaskStatus::Completed)
        };
        store.upsert_step("run-1", s2).await.unwrap();
        let got = store.get_step("run-1", "fetch").await.unwrap().unwrap();
        assert_eq!(got.status, TaskStatus::Completed);
        assert!(got.result.is_some());
    }

    #[tokio::test]
    async fn upsert_step_update_keeps_position() {
        let store = InMemoryWorkflowStore::new();
        for id in ["a", "b"] {
            store
                .upsert_step("run-1", step(id, TaskStatus::Pending))
                .await
                .unwrap();
        }
        store
            .upsert_step("run-1", step("a", TaskStatus::Completed))
            .await
            .unwrap();
        let steps = store.list_steps("run-1").await.unwrap();
        assert_eq!(
            steps.iter().map(|s| s.step_id.as_str()).collect::<Vec<_>>(),
            vec!["a", "b"]
        );
        assert_eq!(steps[0].status, TaskStatus::Completed);
    }

    #[tokio::test]
    async fn list_steps_preserves_insertion_order_and_is_per_run() {
        let store = InMemoryWorkflowStore::new();
        for id in ["a", "b", "c"] {
            store
                .upsert_step("run-1", step(id, TaskStatus::Pending))
                .await
                .unwrap();
        }
        store
            .upsert_step("run-2", step("x", TaskStatus::Pending))
            .await
            .unwrap();
        let r1 = store.list_steps("run-1").await.unwrap();
        let r2 = store.list_steps("run-2").await.unwrap();
        assert_eq!(
            r1.iter().map(|s| s.step_id.as_str()).collect::<Vec<_>>(),
            vec!["a", "b", "c"]
        );
        assert_eq!(r2.len(), 1);
    }

    #[tokio::test]
    async fn lookups_on_unknown_ids_are_empty() {
        let store = InMemoryWorkflowStore::new();
        store
            .upsert_step("run-1", step("a", TaskStatus::Pending))
            .await
            .unwrap();
        assert!(store.get_step("run-1", "zzz").await.unwrap().is_none());
        assert!(store.get_step("run-2", "a").await.unwrap().is_none());
        assert!(store.list_steps("run-2").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn delete_run_cascades_to_steps() {
        let store = InMemoryWorkflowStore::new();
        store
            .create_run(sample_state("run-1", "agent-1"))
            .await
            .unwrap();
        store
            .upsert_step(
                "run-1",
                WorkflowStepState {
                    step_id: "s".into(),
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        store.delete_run("run-1").await.unwrap();
        assert!(store.get_run("run-1").await.unwrap().is_none());
        assert!(store.list_steps("run-1").await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn delete_missing_run_is_ok() {
        let store = InMemoryWorkflowStore::new();
        store.delete_run("missing").await.unwrap();
    }
}