use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::Duration;
#[async_trait]
pub trait StateStoreJsonExt: StateStore {
async fn set_json<T: Serialize + Send + Sync>(
&self,
key: &str,
value: &T,
ttl: Option<Duration>,
) -> anyhow::Result<()> {
let bytes = serde_json::to_vec(value)?;
self.set(key, &bytes, ttl).await
}
async fn get_json<T: for<'de> Deserialize<'de>>(&self, key: &str) -> anyhow::Result<Option<T>> {
match self.get(key).await? {
Some(bytes) => Ok(Some(serde_json::from_slice(&bytes)?)),
None => Ok(None),
}
}
}
impl<S: StateStore + ?Sized> StateStoreJsonExt for S {}
use crate::kernel::{ExecutionId, ExecutionState, StepId, TenantId, UserId};
use super::StorageBackend;
/// Point-in-time capture of an execution's state.
///
/// `sequence_number` records the position this snapshot corresponds to, so that
/// `is_fresh` can compare it against the event store's current sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionSnapshot {
    /// Execution this snapshot belongs to.
    pub execution_id: ExecutionId,
    /// Tenant that owns the execution.
    pub tenant_id: TenantId,
    /// User associated with the execution, if any. Omitted from serialized output when
    /// `None`, and defaulted when the field is absent on input, so payloads without a
    /// `user_id` key still deserialize.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub user_id: Option<UserId>,
    /// Execution state at the time the snapshot was taken.
    pub state: ExecutionState,
    /// Identifier of the current step, if any (both constructors leave this `None`).
    pub current_step_id: Option<StepId>,
    /// JSON output values keyed by step id.
    pub step_outputs: HashMap<StepId, serde_json::Value>,
    /// JSON variable values keyed by name.
    pub variables: HashMap<String, serde_json::Value>,
    /// UTC timestamp; the constructors set it to `Utc::now()` at creation.
    pub timestamp: DateTime<Utc>,
    /// Sequence number compared against the event store's sequence by `is_fresh`.
    pub sequence_number: u64,
}
impl ExecutionSnapshot {
    /// Creates a snapshot with no associated user.
    ///
    /// Equivalent to [`Self::with_user`] with `user_id = None`: no current step,
    /// empty `step_outputs` and `variables`, and `timestamp` set to `Utc::now()`.
    #[must_use]
    pub fn new(
        execution_id: ExecutionId,
        tenant_id: TenantId,
        state: ExecutionState,
        sequence_number: u64,
    ) -> Self {
        // Delegate so both constructors share one field initialisation and cannot
        // drift apart when fields are added.
        Self::with_user(execution_id, tenant_id, None, state, sequence_number)
    }

    /// Creates a snapshot attributed to `user_id`, which may be `None`.
    ///
    /// The snapshot starts with no current step, empty `step_outputs` and
    /// `variables`, and `timestamp` set to the current UTC time.
    #[must_use]
    pub fn with_user(
        execution_id: ExecutionId,
        tenant_id: TenantId,
        user_id: Option<UserId>,
        state: ExecutionState,
        sequence_number: u64,
    ) -> Self {
        Self {
            execution_id,
            tenant_id,
            user_id,
            state,
            current_step_id: None,
            step_outputs: HashMap::new(),
            variables: HashMap::new(),
            timestamp: Utc::now(),
            sequence_number,
        }
    }

    /// Returns `true` when this snapshot is at least as recent as the event store.
    ///
    /// A snapshot whose `sequence_number` equals `event_store_sequence` counts as fresh.
    #[must_use]
    pub fn is_fresh(&self, event_store_sequence: u64) -> bool {
        self.sequence_number >= event_store_sequence
    }
}
/// Storage interface for execution snapshots plus a generic byte-valued key/value store.
///
/// Implementors must also implement [`StorageBackend`]. `exists` and
/// `delete_execution_state` have default implementations built on the required methods.
#[async_trait]
pub trait StateStore: StorageBackend {
    /// Persists `snapshot`.
    async fn save_snapshot(&self, snapshot: ExecutionSnapshot) -> anyhow::Result<()>;

    /// Loads the snapshot for `execution_id`; expected to yield `Ok(None)` when
    /// nothing is stored for it.
    async fn load_snapshot(
        &self,
        execution_id: &ExecutionId,
    ) -> anyhow::Result<Option<ExecutionSnapshot>>;

    /// Removes the snapshot for `execution_id`.
    async fn delete_snapshot(&self, execution_id: &ExecutionId) -> anyhow::Result<()>;

    /// Stores raw bytes under `key`. How `ttl` is honoured is implementation-defined.
    async fn set(&self, key: &str, value: &[u8], ttl: Option<Duration>) -> anyhow::Result<()>;

    /// Fetches the raw bytes stored under `key`, or `None` if there are none.
    async fn get(&self, key: &str) -> anyhow::Result<Option<Vec<u8>>>;

    /// Removes `key`.
    async fn delete(&self, key: &str) -> anyhow::Result<()>;

    /// Reports whether `key` currently has a value.
    ///
    /// The default fetches the full value via `get`; implementors may override this
    /// with a cheaper check.
    async fn exists(&self, key: &str) -> anyhow::Result<bool> {
        self.get(key).await.map(|found| found.is_some())
    }

    /// Removes all stored state for `execution_id`.
    ///
    /// The default only calls `delete_snapshot`; override it if an implementation
    /// keeps additional per-execution data.
    async fn delete_execution_state(&self, execution_id: &ExecutionId) -> anyhow::Result<()> {
        self.delete_snapshot(execution_id).await
    }

    /// Lists execution ids that have snapshots for `tenant_id`. `limit` is expected to
    /// cap the number of ids returned (NOTE(review): confirm against implementations).
    async fn list_snapshots(
        &self,
        tenant_id: &TenantId,
        limit: usize,
    ) -> anyhow::Result<Vec<ExecutionId>>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `is_fresh` accepts equal and older event-store sequences, rejects newer ones.
    #[test]
    fn test_snapshot_freshness() {
        let snapshot = ExecutionSnapshot::new(
            ExecutionId::new(),
            TenantId::from("test"),
            ExecutionState::Running,
            10,
        );
        assert!(snapshot.is_fresh(10));
        assert!(snapshot.is_fresh(5));
        assert!(!snapshot.is_fresh(15));
    }

    /// Variables and state survive a JSON round trip.
    #[test]
    fn test_snapshot_serialization() {
        let mut snapshot = ExecutionSnapshot::new(
            ExecutionId::new(),
            TenantId::from("test"),
            ExecutionState::Running,
            5,
        );
        snapshot
            .variables
            .insert("foo".to_string(), serde_json::json!("bar"));
        let json = serde_json::to_string(&snapshot).unwrap();
        let parsed: ExecutionSnapshot = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.state, ExecutionState::Running);
        assert_eq!(parsed.variables.get("foo").unwrap(), "bar");
    }

    /// Constructors start empty, and a `None` user is omitted on serialization and
    /// defaulted back to `None` on deserialization.
    #[test]
    fn test_snapshot_without_user_round_trip() {
        let snapshot = ExecutionSnapshot::with_user(
            ExecutionId::new(),
            TenantId::from("test"),
            None,
            ExecutionState::Running,
            3,
        );
        assert!(snapshot.user_id.is_none());
        assert!(snapshot.current_step_id.is_none());
        assert!(snapshot.step_outputs.is_empty());
        assert!(snapshot.variables.is_empty());

        let value = serde_json::to_value(&snapshot).unwrap();
        assert!(value.get("user_id").is_none());

        let parsed: ExecutionSnapshot = serde_json::from_value(value).unwrap();
        assert!(parsed.user_id.is_none());
        assert_eq!(parsed.sequence_number, 3);
        assert_eq!(parsed.state, ExecutionState::Running);

        let plain = ExecutionSnapshot::new(
            ExecutionId::new(),
            TenantId::from("test"),
            ExecutionState::Running,
            3,
        );
        assert!(plain.user_id.is_none());
        assert_eq!(plain.sequence_number, snapshot.sequence_number);
    }
}