use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::Arc;
use anyhow::{Context, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;
/// Progress record for a single workflow task.
///
/// The stores in this file persist and retrieve one checkpoint per `task_id`. The type is
/// serializable with serde; the file-backed store writes it as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowCheckpoint {
/// Identifier of the task; the stores in this file use it as the lookup key.
pub task_id: String,
/// Identifier of the agent supplied at construction. Checkpoints created implicitly by
/// `mark_step_complete` in this file use the placeholder `"unknown"`.
pub agent_id: String,
/// Counter that starts at 0 and is incremented by one on every `mark_step_complete` call.
pub step_index: u32,
/// Tool-use IDs recorded as completed; queried by [`WorkflowCheckpoint::is_completed`].
pub completed_tool_ids: HashSet<String>,
/// Side-effect records in the order they were appended by `mark_step_complete`.
pub side_effects_log: Vec<SideEffectRecord>,
/// Value of `chrono::Utc::now().timestamp()` taken at creation and refreshed on every
/// `mark_step_complete`.
pub updated_at: i64,
}
impl WorkflowCheckpoint {
    /// Creates an empty checkpoint for `task_id` attributed to `agent_id`.
    ///
    /// The step counter starts at zero, no tool-use IDs are recorded, the side-effect log
    /// is empty, and `updated_at` is set from the current UTC clock.
    pub fn new(task_id: impl Into<String>, agent_id: impl Into<String>) -> Self {
        let task_id: String = task_id.into();
        let agent_id: String = agent_id.into();
        let updated_at = chrono::Utc::now().timestamp();

        Self {
            task_id,
            agent_id,
            step_index: 0,
            completed_tool_ids: HashSet::new(),
            side_effects_log: Vec::new(),
            updated_at,
        }
    }

    /// Returns `true` when `tool_use_id` is present in `completed_tool_ids`.
    pub fn is_completed(&self, tool_use_id: &str) -> bool {
        let finished = &self.completed_tool_ids;
        finished.contains(tool_use_id)
    }
}
/// One entry in a checkpoint's side-effect log, describing a completed tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SideEffectRecord {
/// Identifier of the tool invocation this record belongs to.
pub tool_use_id: String,
/// Name of the tool that ran (the tests in this file use `"write_file"`).
pub tool_name: String,
/// Optional target of the effect; the tests in this file pass a file path.
pub target: Option<String>,
/// Value of `chrono::Utc::now().timestamp()` taken when the record was constructed.
pub completed_at: i64,
/// Caller-supplied flag; presumably whether the effect can be undone. Nothing in this
/// file interprets it. NOTE(review): confirm semantics against the callers.
pub reversible: bool,
}
impl SideEffectRecord {
    /// Builds a record for a tool call, stamping `completed_at` from the current UTC clock.
    ///
    /// `target` and `reversible` are stored exactly as given.
    pub fn new(
        tool_use_id: impl Into<String>,
        tool_name: impl Into<String>,
        target: Option<String>,
        reversible: bool,
    ) -> Self {
        let tool_use_id: String = tool_use_id.into();
        let tool_name: String = tool_name.into();
        let completed_at = chrono::Utc::now().timestamp();

        Self {
            tool_use_id,
            tool_name,
            target,
            completed_at,
            reversible,
        }
    }
}
/// Asynchronous storage backend for [`WorkflowCheckpoint`]s, keyed by task ID.
///
/// Implementors must be `Send + Sync`. The behavior notes on each method describe what the
/// two implementations in this file do.
#[async_trait]
pub trait WorkflowStateStore: Send + Sync {
/// Stores `cp` under `cp.task_id`, replacing any checkpoint already stored for that task.
async fn save_checkpoint(&self, cp: &WorkflowCheckpoint) -> Result<()>;
/// Returns the checkpoint stored for `task_id`, or `Ok(None)` when there is none.
async fn load_checkpoint(&self, task_id: &str) -> Result<Option<WorkflowCheckpoint>>;
/// Records `tool_use_id` as completed for `task_id` and appends `effect` to the log.
///
/// The implementations in this file also increment `step_index`, refresh `updated_at`,
/// and create the checkpoint (with agent ID `"unknown"`) when none exists yet.
async fn mark_step_complete(
&self,
task_id: &str,
tool_use_id: &str,
effect: SideEffectRecord,
) -> Result<()>;
/// Removes the checkpoint for `task_id`. The implementations in this file return `Ok(())`
/// when nothing was stored.
async fn delete_checkpoint(&self, task_id: &str) -> Result<()>;
}
/// Checkpoint store that keeps each checkpoint as a JSON file inside a directory.
#[derive(Debug, Clone)]
pub struct FsWorkflowStateStore {
/// Directory that holds the checkpoint files; created by `new` if it is missing.
dir: PathBuf,
}
impl FsWorkflowStateStore {
    /// Returns the default state directory, `<home>/.brainwires/workflow`.
    ///
    /// # Errors
    /// Fails when the home directory cannot be determined.
    pub fn default_path() -> Result<PathBuf> {
        let home = dirs::home_dir().context("cannot determine home directory")?;
        Ok(home.join(".brainwires").join("workflow"))
    }

    /// Creates a store rooted at `dir`, creating the directory (and any missing parents).
    ///
    /// # Errors
    /// Fails when the directory cannot be created.
    pub fn new(dir: PathBuf) -> Result<Self> {
        std::fs::create_dir_all(&dir)
            .with_context(|| format!("cannot create workflow state dir: {}", dir.display()))?;
        Ok(Self { dir })
    }

    /// Creates a store rooted at [`FsWorkflowStateStore::default_path`].
    ///
    /// # Errors
    /// Fails when the home directory is unknown or the directory cannot be created.
    pub fn with_default_path() -> Result<Self> {
        Self::new(Self::default_path()?)
    }

    /// Maps a task ID to the path of its checkpoint file inside the store directory.
    ///
    /// Characters other than alphanumerics, `-` and `_` are replaced with `_`, so the
    /// result is always a single file name directly inside `dir`.
    ///
    /// That replacement is lossy: `a/b`, `a.b` and `a_b` would all become `a_b`, and
    /// distinct tasks would then share (and overwrite) one checkpoint. To keep the mapping
    /// distinct, any ID that had to be altered gets a digest of the *original* ID
    /// appended after a `.` separator. IDs that are already clean keep the plain
    /// `<id>.json` name, so their existing files are still found. Because a clean ID can
    /// never contain `.`, a suffixed name cannot clash with a clean one.
    ///
    /// NOTE: checkpoints previously saved for IDs containing special characters were
    /// stored under the un-suffixed name and will not be found under the new one.
    fn checkpoint_path(&self, task_id: &str) -> PathBuf {
        let is_safe = |c: char| c.is_alphanumeric() || c == '-' || c == '_';
        let safe_id: String = task_id
            .chars()
            .map(|c| if is_safe(c) { c } else { '_' })
            .collect();

        if safe_id == task_id {
            return self.dir.join(format!("{safe_id}.json"));
        }

        // FNV-1a (64-bit) over the raw ID bytes. A hand-rolled, fixed algorithm is used on
        // purpose: the file name is persistent state, and the standard library's default
        // hasher does not promise the same output across Rust releases.
        const FNV_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
        const FNV_PRIME: u64 = 0x0000_0100_0000_01b3;
        let digest = task_id.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
            (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        });

        self.dir.join(format!("{safe_id}.{digest:016x}.json"))
    }
}
#[async_trait]
impl WorkflowStateStore for FsWorkflowStateStore {
    /// Serializes `cp` as pretty-printed JSON and stores it at the task's checkpoint path.
    ///
    /// The JSON is first written to a sibling `.json.tmp` file and then renamed onto the
    /// final path, so the final path is never opened for a partial write. If either step
    /// fails, the temporary file is removed on a best-effort basis so failed saves do not
    /// leave stray files behind.
    ///
    /// # Errors
    /// Fails when serialization, the temporary write, or the rename fails.
    async fn save_checkpoint(&self, cp: &WorkflowCheckpoint) -> Result<()> {
        let path = self.checkpoint_path(&cp.task_id);
        let tmp = path.with_extension("json.tmp");
        let json = serde_json::to_string_pretty(cp).context("serialize checkpoint")?;

        if let Err(e) = tokio::fs::write(&tmp, &json).await {
            // Best effort: the write error is what the caller needs to see.
            let _ = tokio::fs::remove_file(&tmp).await;
            return Err(e).with_context(|| format!("write checkpoint tmp: {}", tmp.display()));
        }
        if let Err(e) = tokio::fs::rename(&tmp, &path).await {
            // Best effort: the rename error is what the caller needs to see.
            let _ = tokio::fs::remove_file(&tmp).await;
            return Err(e).with_context(|| format!("rename checkpoint: {}", path.display()));
        }
        Ok(())
    }

    /// Reads and deserializes the checkpoint for `task_id`.
    ///
    /// Returns `Ok(None)` when the checkpoint file does not exist.
    ///
    /// # Errors
    /// Fails when the file cannot be read for any other reason or does not contain a
    /// valid checkpoint; both errors name the offending file.
    async fn load_checkpoint(&self, task_id: &str) -> Result<Option<WorkflowCheckpoint>> {
        let path = self.checkpoint_path(task_id);
        match tokio::fs::read_to_string(&path).await {
            Ok(json) => {
                let cp: WorkflowCheckpoint = serde_json::from_str(&json)
                    .with_context(|| format!("deserialize checkpoint: {}", path.display()))?;
                Ok(Some(cp))
            }
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(None),
            Err(e) => Err(e).with_context(|| format!("read checkpoint: {}", path.display())),
        }
    }

    /// Records `tool_use_id` as completed for `task_id` and appends `effect` to the log.
    ///
    /// Loads the current checkpoint (or starts a new one with agent ID `"unknown"`),
    /// bumps `step_index`, refreshes `updated_at`, and saves the result.
    ///
    /// NOTE: this is a load-modify-save sequence with no locking, so two concurrent calls
    /// for the same task can overwrite each other's update.
    ///
    /// # Errors
    /// Propagates any error from loading or saving the checkpoint.
    async fn mark_step_complete(
        &self,
        task_id: &str,
        tool_use_id: &str,
        effect: SideEffectRecord,
    ) -> Result<()> {
        let mut cp = self
            .load_checkpoint(task_id)
            .await?
            .unwrap_or_else(|| WorkflowCheckpoint::new(task_id, "unknown"));
        cp.completed_tool_ids.insert(tool_use_id.to_string());
        cp.side_effects_log.push(effect);
        cp.step_index += 1;
        cp.updated_at = chrono::Utc::now().timestamp();
        self.save_checkpoint(&cp).await
    }

    /// Deletes the checkpoint file for `task_id`; a missing file counts as success.
    ///
    /// # Errors
    /// Fails when the file exists but cannot be removed.
    async fn delete_checkpoint(&self, task_id: &str) -> Result<()> {
        let path = self.checkpoint_path(task_id);
        match tokio::fs::remove_file(&path).await {
            Ok(()) => Ok(()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(e).with_context(|| format!("delete checkpoint: {}", path.display())),
        }
    }
}
/// Checkpoint store that keeps everything in a process-local map; nothing is written to disk.
#[derive(Debug, Default)]
pub struct InMemoryWorkflowStateStore {
/// Checkpoints keyed by task ID, guarded by an async (tokio) mutex.
checkpoints: Arc<Mutex<std::collections::HashMap<String, WorkflowCheckpoint>>>,
}
impl InMemoryWorkflowStateStore {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl WorkflowStateStore for InMemoryWorkflowStateStore {
    /// Stores a clone of `cp` under its task ID, replacing any previous entry.
    async fn save_checkpoint(&self, cp: &WorkflowCheckpoint) -> Result<()> {
        let mut table = self.checkpoints.lock().await;
        table.insert(cp.task_id.clone(), cp.clone());
        Ok(())
    }

    /// Returns a clone of the checkpoint stored for `task_id`, or `None` if absent.
    async fn load_checkpoint(&self, task_id: &str) -> Result<Option<WorkflowCheckpoint>> {
        let table = self.checkpoints.lock().await;
        let found = table.get(task_id).cloned();
        Ok(found)
    }

    /// Records `tool_use_id` as completed for `task_id` and appends `effect` to the log.
    ///
    /// A checkpoint with agent ID `"unknown"` is created when the task has none yet. The
    /// whole update happens while the map lock is held.
    async fn mark_step_complete(
        &self,
        task_id: &str,
        tool_use_id: &str,
        effect: SideEffectRecord,
    ) -> Result<()> {
        let mut table = self.checkpoints.lock().await;
        let entry = table
            .entry(task_id.to_owned())
            .or_insert_with(|| WorkflowCheckpoint::new(task_id, "unknown"));

        entry.completed_tool_ids.insert(tool_use_id.to_owned());
        entry.side_effects_log.push(effect);
        entry.step_index += 1;
        entry.updated_at = chrono::Utc::now().timestamp();
        Ok(())
    }

    /// Removes the checkpoint for `task_id`; removing a missing entry is not an error.
    async fn delete_checkpoint(&self, task_id: &str) -> Result<()> {
        let mut table = self.checkpoints.lock().await;
        table.remove(task_id);
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a file-backed store rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the store so the directory stays alive
    /// for as long as the test holds on to it.
    fn temp_fs_store() -> (tempfile::TempDir, FsWorkflowStateStore) {
        let dir = tempfile::tempdir().unwrap();
        let store = FsWorkflowStateStore::new(dir.path().to_path_buf()).unwrap();
        (dir, store)
    }

    /// Saving then loading through the in-memory store returns the same identifiers.
    #[tokio::test]
    async fn in_memory_roundtrip() {
        let store = InMemoryWorkflowStateStore::new();
        let before = store.load_checkpoint("t1").await.unwrap();
        assert!(before.is_none());

        let original = WorkflowCheckpoint::new("t1", "agent-1");
        store.save_checkpoint(&original).await.unwrap();

        let after = store.load_checkpoint("t1").await.unwrap().unwrap();
        assert_eq!(after.task_id, "t1");
        assert_eq!(after.agent_id, "agent-1");
    }

    /// Marking a step records its tool-use ID (and only that one) and bumps the counter.
    #[tokio::test]
    async fn mark_step_and_skip() {
        let store = InMemoryWorkflowStateStore::new();
        let record = SideEffectRecord::new("use-1", "write_file", Some("src/main.rs".into()), true);
        store
            .mark_step_complete("t2", "use-1", record)
            .await
            .unwrap();

        let snapshot = store.load_checkpoint("t2").await.unwrap().unwrap();
        assert!(snapshot.is_completed("use-1"));
        assert!(!snapshot.is_completed("use-2"));
        assert_eq!(snapshot.step_index, 1);
    }

    /// A deleted checkpoint can no longer be loaded.
    #[tokio::test]
    async fn delete_removes_checkpoint() {
        let store = InMemoryWorkflowStateStore::new();
        let original = WorkflowCheckpoint::new("t3", "a");
        store.save_checkpoint(&original).await.unwrap();
        store.delete_checkpoint("t3").await.unwrap();

        let after = store.load_checkpoint("t3").await.unwrap();
        assert!(after.is_none());
    }

    /// Saving then loading through the file-backed store returns the same identifiers.
    #[tokio::test]
    async fn fs_save_and_load_roundtrip() {
        let (_dir, store) = temp_fs_store();
        let before = store.load_checkpoint("task-a").await.unwrap();
        assert!(before.is_none());

        let original = WorkflowCheckpoint::new("task-a", "agent-x");
        store.save_checkpoint(&original).await.unwrap();

        let after = store.load_checkpoint("task-a").await.unwrap().unwrap();
        assert_eq!(after.task_id, "task-a");
        assert_eq!(after.agent_id, "agent-x");
    }

    /// After a successful save only the final file exists, not the temporary one.
    #[tokio::test]
    async fn fs_atomic_write_produces_no_tmp_file_after_save() {
        let (dir, store) = temp_fs_store();
        let original = WorkflowCheckpoint::new("atomic-task", "a");
        store.save_checkpoint(&original).await.unwrap();

        let leftover = dir.path().join("atomic-task.json.tmp");
        let final_file = dir.path().join("atomic-task.json");
        assert!(!leftover.exists(), ".tmp file should be gone after rename");
        assert!(final_file.exists());
    }

    /// Marking a step for an unknown task creates its checkpoint on the fly.
    #[tokio::test]
    async fn fs_mark_step_creates_checkpoint_implicitly() {
        let (_dir, store) = temp_fs_store();
        let record = SideEffectRecord::new("use-99", "write_file", Some("foo.rs".into()), true);
        store
            .mark_step_complete("fresh-task", "use-99", record)
            .await
            .unwrap();

        let snapshot = store.load_checkpoint("fresh-task").await.unwrap().unwrap();
        assert!(snapshot.is_completed("use-99"));
        assert_eq!(snapshot.step_index, 1);
        assert_eq!(snapshot.side_effects_log.len(), 1);
        assert_eq!(snapshot.side_effects_log[0].tool_name, "write_file");
    }

    /// Deleting twice, or deleting something never saved, succeeds.
    #[tokio::test]
    async fn fs_delete_is_idempotent() {
        let (_dir, store) = temp_fs_store();
        let original = WorkflowCheckpoint::new("del-task", "a");
        store.save_checkpoint(&original).await.unwrap();

        store.delete_checkpoint("del-task").await.unwrap();
        store.delete_checkpoint("del-task").await.unwrap();
        store.delete_checkpoint("never-existed").await.unwrap();
    }

    /// Task IDs containing path separators, dots or spaces round-trip and land in a
    /// single file directly inside the store directory.
    #[tokio::test]
    async fn fs_checkpoint_path_sanitizes_special_chars() {
        let (dir, store) = temp_fs_store();
        let original = WorkflowCheckpoint::new("proj/task.1 final", "a");
        store.save_checkpoint(&original).await.unwrap();

        let after = store
            .load_checkpoint("proj/task.1 final")
            .await
            .unwrap()
            .unwrap();
        assert_eq!(after.task_id, "proj/task.1 final");

        let entry_count = std::fs::read_dir(dir.path())
            .unwrap()
            .filter_map(Result::ok)
            .count();
        assert_eq!(entry_count, 1, "should be exactly one file, no subdirs");
    }
}