use std::collections::HashMap;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use crate::error::PersistError;
/// A persisted snapshot of one workflow run.
///
/// Stores key checkpoints by [`run_id`](Self::run_id); saving a checkpoint
/// with an existing `run_id` replaces the previous one in the redb backend.
///
/// NOTE(review): how a runtime resumes from this snapshot is defined outside
/// this file — confirm against the workflow executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowCheckpoint {
/// Name of the workflow this run belongs to.
pub workflow_name: String,
/// Identifier of the run; used as the storage key by checkpoint stores.
pub run_id: Uuid,
/// Presumably the time the snapshot was taken; the redb backend's `list`
/// orders checkpoints newest-first by this value.
pub timestamp: DateTime<Utc>,
/// Workflow state entries as JSON values, keyed by name.
pub state: HashMap<String, serde_json::Value>,
/// Events captured with the snapshot — presumably not yet processed; confirm
/// against the runtime.
pub pending_events: Vec<SerializedEvent>,
/// Free-form extra information attached to the checkpoint.
pub metadata: HashMap<String, serde_json::Value>,
}
/// An event stored as a type name plus a JSON payload, so it can be
/// persisted without its concrete Rust type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedEvent {
/// Identifier of the event's type (the tests use e.g. `"blazen::StartEvent"`).
pub event_type: String,
/// The event's payload as an arbitrary JSON value.
pub data: serde_json::Value,
}
/// Persistence backend for [`WorkflowCheckpoint`]s, addressed by `run_id`.
///
/// The `Send + Sync` supertraits allow one store to be shared across threads.
#[async_trait]
pub trait CheckpointStore: Send + Sync {
/// Persists `checkpoint` under its `run_id`.
///
/// The redb backend replaces any checkpoint already stored for that id.
async fn save(&self, checkpoint: &WorkflowCheckpoint) -> Result<(), PersistError>;
/// Loads the checkpoint stored for `run_id`, or `Ok(None)` if there is none.
async fn load(&self, run_id: &Uuid) -> Result<Option<WorkflowCheckpoint>, PersistError>;
/// Returns every stored checkpoint.
///
/// The redb backend orders the result newest-first by `timestamp`.
async fn list(&self) -> Result<Vec<WorkflowCheckpoint>, PersistError>;
/// Removes the checkpoint stored for `run_id`.
///
/// The redb backend reports success even when no such checkpoint exists.
async fn delete(&self, run_id: &Uuid) -> Result<(), PersistError>;
}
#[cfg(feature = "redb")]
mod redb_backend {
use std::path::Path;
use async_trait::async_trait;
use redb::ReadableDatabase;
use redb::ReadableTable;
use uuid::Uuid;
use super::{CheckpointStore, WorkflowCheckpoint};
use crate::error::PersistError;
/// The single redb table holding all checkpoints.
///
/// Key: the raw 16 bytes of the checkpoint's `run_id`.
/// Value: the encoded `WorkflowCheckpoint`. `save` writes MessagePack; reads
/// also fall back to JSON.
const CHECKPOINTS: redb::TableDefinition<&[u8], &[u8]> =
redb::TableDefinition::new("checkpoints");
/// [`CheckpointStore`] implementation backed by an embedded redb database.
///
/// Checkpoints live in one `checkpoints` table, keyed by the raw bytes of
/// their `run_id`, with MessagePack-encoded values.
pub struct RedbCheckpointStore {
/// Open database handle; the `checkpoints` table is created on construction.
db: redb::Database,
}
impl RedbCheckpointStore {
    /// Opens (creating if necessary) a file-backed checkpoint store at `path`.
    ///
    /// # Errors
    ///
    /// Returns a [`PersistError`] if the database cannot be created or opened,
    /// or if creating the `checkpoints` table fails.
    pub fn new(path: impl AsRef<Path>) -> Result<Self, PersistError> {
        let db = redb::Database::create(path)?;
        Self::from_database(db)
    }

    /// Creates a store backed by redb's in-memory backend.
    ///
    /// Nothing is written to disk, so the data is lost when the store is
    /// dropped. Mainly useful for tests.
    ///
    /// # Errors
    ///
    /// Database-construction failures are reported as
    /// [`PersistError::Database`]. Failures while creating the
    /// `checkpoints` table are converted via `?`.
    pub fn in_memory() -> Result<Self, PersistError> {
        let backend = redb::backends::InMemoryBackend::new();
        let db = redb::Database::builder()
            .create_with_backend(backend)
            .map_err(|e| PersistError::Database(e.to_string()))?;
        Self::from_database(db)
    }

    /// Ensures the `checkpoints` table exists, then wraps `db` in a store.
    ///
    /// Shared by both constructors. The table is created eagerly so that read
    /// transactions in `load`/`list` can open it before the first `save`.
    fn from_database(db: redb::Database) -> Result<Self, PersistError> {
        let write_txn = db.begin_write()?;
        {
            // Opening a table inside a write transaction creates it if missing.
            // The handle borrows the transaction, so drop it before `commit`.
            let _table = write_txn.open_table(CHECKPOINTS)?;
        }
        write_txn.commit()?;
        Ok(Self { db })
    }
}
impl std::fmt::Debug for RedbCheckpointStore {
    /// Formats the store as `RedbCheckpointStore { .. }` without printing
    /// the underlying database handle.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug = formatter.debug_struct("RedbCheckpointStore");
        // No fields are listed; the trailing `..` marks the omitted state.
        debug.finish_non_exhaustive()
    }
}
/// Decodes one stored checkpoint record.
///
/// `save` writes MessagePack (`rmp_serde::to_vec_named`). If MessagePack
/// decoding fails, the bytes are retried as JSON. When both fail, the JSON
/// error is the one returned.
fn decode_checkpoint(bytes: &[u8]) -> Result<WorkflowCheckpoint, PersistError> {
    rmp_serde::from_slice(bytes)
        .or_else(|_| serde_json::from_slice(bytes).map_err(PersistError::from))
}

/// NOTE(review): every method below runs synchronous redb transactions
/// (including `commit`) directly inside the async body. Nothing is offloaded
/// to a blocking thread pool, so slow disk I/O will stall the calling task's
/// executor thread.
#[async_trait]
impl CheckpointStore for RedbCheckpointStore {
    /// Writes `checkpoint` as MessagePack under its `run_id`, replacing any
    /// existing entry, and commits the transaction.
    async fn save(&self, checkpoint: &WorkflowCheckpoint) -> Result<(), PersistError> {
        let value = rmp_serde::to_vec_named(checkpoint)?;
        let write_txn = self.db.begin_write()?;
        {
            // The table handle borrows the transaction; drop it before commit.
            let mut table = write_txn.open_table(CHECKPOINTS)?;
            table.insert(checkpoint.run_id.as_bytes().as_slice(), value.as_slice())?;
        }
        write_txn.commit()?;
        Ok(())
    }

    /// Returns the checkpoint for `run_id`, or `Ok(None)` if none is stored.
    async fn load(&self, run_id: &Uuid) -> Result<Option<WorkflowCheckpoint>, PersistError> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(CHECKPOINTS)?;
        let Some(guard) = table.get(run_id.as_bytes().as_slice())? else {
            return Ok(None);
        };
        let bytes: &[u8] = guard.value();
        let checkpoint = decode_checkpoint(bytes)?;
        Ok(Some(checkpoint))
    }

    /// Returns all stored checkpoints, newest `timestamp` first.
    ///
    /// Fails on the first record that cannot be read or decoded.
    async fn list(&self) -> Result<Vec<WorkflowCheckpoint>, PersistError> {
        let read_txn = self.db.begin_read()?;
        let table = read_txn.open_table(CHECKPOINTS)?;
        let mut checkpoints = table
            .iter()?
            .map(|entry| -> Result<WorkflowCheckpoint, PersistError> {
                let (_key, value) = entry?;
                let bytes: &[u8] = value.value();
                decode_checkpoint(bytes)
            })
            .collect::<Result<Vec<_>, _>>()?;
        // Stable sort: checkpoints with equal timestamps keep table order.
        checkpoints.sort_by_key(|c| std::cmp::Reverse(c.timestamp));
        Ok(checkpoints)
    }

    /// Removes the checkpoint for `run_id`. Removing an absent id is not an
    /// error; the previous value, if any, is discarded.
    async fn delete(&self, run_id: &Uuid) -> Result<(), PersistError> {
        let write_txn = self.db.begin_write()?;
        {
            let mut table = write_txn.open_table(CHECKPOINTS)?;
            table.remove(run_id.as_bytes().as_slice())?;
        }
        write_txn.commit()?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use chrono::Utc;

    use super::*;
    use crate::checkpoint::SerializedEvent;

    /// Builds a checkpoint with a fresh run id, a `counter = 42` state entry
    /// and one pending start event.
    fn sample_checkpoint(name: &str) -> WorkflowCheckpoint {
        WorkflowCheckpoint {
            workflow_name: name.to_owned(),
            run_id: Uuid::new_v4(),
            timestamp: Utc::now(),
            state: {
                let mut m = std::collections::HashMap::new();
                m.insert("counter".to_owned(), serde_json::json!(42));
                m
            },
            pending_events: vec![SerializedEvent {
                event_type: "blazen::StartEvent".to_owned(),
                data: serde_json::json!({"input": "hello"}),
            }],
            metadata: std::collections::HashMap::new(),
        }
    }

    #[tokio::test]
    async fn save_and_load() {
        let store = RedbCheckpointStore::in_memory().unwrap();
        let cp = sample_checkpoint("test_workflow");
        let run_id = cp.run_id;
        store.save(&cp).await.unwrap();
        let loaded = store.load(&run_id).await.unwrap().unwrap();
        assert_eq!(loaded.workflow_name, "test_workflow");
        assert_eq!(loaded.run_id, run_id);
        assert_eq!(loaded.state["counter"], serde_json::json!(42));
    }

    #[tokio::test]
    async fn load_missing_returns_none() {
        let store = RedbCheckpointStore::in_memory().unwrap();
        let result = store.load(&Uuid::new_v4()).await.unwrap();
        assert!(result.is_none());
    }

    #[tokio::test]
    async fn list_returns_all_sorted() {
        let store = RedbCheckpointStore::in_memory().unwrap();
        let mut cp1 = sample_checkpoint("wf_a");
        cp1.timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-01T00:00:00Z")
            .unwrap()
            .with_timezone(&Utc);
        let mut cp2 = sample_checkpoint("wf_b");
        cp2.timestamp = chrono::DateTime::parse_from_rfc3339("2025-06-15T12:00:00Z")
            .unwrap()
            .with_timezone(&Utc);
        store.save(&cp1).await.unwrap();
        store.save(&cp2).await.unwrap();
        let list = store.list().await.unwrap();
        assert_eq!(list.len(), 2);
        assert_eq!(list[0].workflow_name, "wf_b");
        assert_eq!(list[1].workflow_name, "wf_a");
    }

    #[tokio::test]
    async fn list_empty_store_returns_empty() {
        let store = RedbCheckpointStore::in_memory().unwrap();
        assert!(store.list().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn delete_removes_checkpoint() {
        let store = RedbCheckpointStore::in_memory().unwrap();
        let cp = sample_checkpoint("delete_me");
        let run_id = cp.run_id;
        store.save(&cp).await.unwrap();
        assert!(store.load(&run_id).await.unwrap().is_some());
        store.delete(&run_id).await.unwrap();
        assert!(store.load(&run_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn delete_missing_is_ok() {
        let store = RedbCheckpointStore::in_memory().unwrap();
        store.delete(&Uuid::new_v4()).await.unwrap();
    }

    #[tokio::test]
    async fn save_overwrites_existing() {
        let store = RedbCheckpointStore::in_memory().unwrap();
        let mut cp = sample_checkpoint("overwrite");
        let run_id = cp.run_id;
        store.save(&cp).await.unwrap();
        cp.state.insert("counter".to_owned(), serde_json::json!(99));
        store.save(&cp).await.unwrap();
        let loaded = store.load(&run_id).await.unwrap().unwrap();
        assert_eq!(loaded.state["counter"], serde_json::json!(99));
    }

    /// Records that are not valid MessagePack (here: JSON written directly
    /// into the table) must still be readable through the JSON fallback.
    #[tokio::test]
    async fn load_and_list_fall_back_to_json() {
        let store = RedbCheckpointStore::in_memory().unwrap();
        let cp = sample_checkpoint("legacy_json");
        let run_id = cp.run_id;
        let json = serde_json::to_vec(&cp).unwrap();

        let write_txn = store.db.begin_write().unwrap();
        {
            let mut table = write_txn.open_table(CHECKPOINTS).unwrap();
            table
                .insert(run_id.as_bytes().as_slice(), json.as_slice())
                .unwrap();
        }
        write_txn.commit().unwrap();

        let loaded = store.load(&run_id).await.unwrap().unwrap();
        assert_eq!(loaded.workflow_name, "legacy_json");
        assert_eq!(loaded.state["counter"], serde_json::json!(42));

        let listed = store.list().await.unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].run_id, run_id);
    }
}
}
#[cfg(feature = "redb")]
pub use redb_backend::RedbCheckpointStore;