use async_trait::async_trait;
use redis::AsyncCommands;
use uuid::Uuid;
use crate::checkpoint::{CheckpointStore, WorkflowCheckpoint};
use crate::error::PersistError;
/// Namespace prefix for every checkpoint key. A run's key is this prefix
/// immediately followed by its run id, and `list` scans for `KEY_PREFIX*`.
const KEY_PREFIX: &str = "blazen:checkpoint:";
/// A [`CheckpointStore`] that keeps workflow checkpoints in a Valkey
/// (Redis-protocol) server.
///
/// Each checkpoint is stored as one MessagePack-encoded string value under
/// the key `blazen:checkpoint:<run_id>`.
#[derive(Clone)]
pub struct ValkeyCheckpointStore {
    /// Connection handle; each operation works on its own clone because
    /// redis commands require `&mut` access.
    conn: redis::aio::ConnectionManager,
    /// Expiry applied to every saved checkpoint, in seconds.
    /// `None` means checkpoints never expire.
    ttl_seconds: Option<u64>,
}
impl ValkeyCheckpointStore {
    /// Connects to the server at `url`. Checkpoints saved through this store
    /// never expire.
    ///
    /// # Errors
    ///
    /// Returns [`PersistError::Redis`] if `url` cannot be parsed. Returns the
    /// converted redis error if the initial connection cannot be established.
    pub async fn new(url: &str) -> Result<Self, PersistError> {
        Ok(Self {
            conn: Self::connect(url).await?,
            ttl_seconds: None,
        })
    }

    /// Connects to the server at `url`. Every checkpoint saved through this
    /// store expires `ttl_seconds` seconds after it is written.
    ///
    /// # Errors
    ///
    /// Returns [`PersistError::Redis`] if `ttl_seconds` is zero or `url` cannot
    /// be parsed. Returns the converted redis error if the initial connection
    /// cannot be established.
    pub async fn with_ttl(url: &str, ttl_seconds: u64) -> Result<Self, PersistError> {
        // Redis rejects `SET ... EX 0` ("invalid expire time"), so a zero TTL
        // would make every `save` fail. Reject it here, before connecting.
        if ttl_seconds == 0 {
            return Err(PersistError::Redis(String::from(
                "ttl_seconds must be greater than zero",
            )));
        }
        Ok(Self {
            conn: Self::connect(url).await?,
            ttl_seconds: Some(ttl_seconds),
        })
    }

    /// Parses `url` and builds the connection manager. Both constructors use
    /// this helper.
    async fn connect(url: &str) -> Result<redis::aio::ConnectionManager, PersistError> {
        let client = redis::Client::open(url)
            .map_err(|e| PersistError::Redis(format!("failed to parse URL: {e}")))?;
        Ok(redis::aio::ConnectionManager::new(client).await?)
    }

    /// Builds the storage key for `run_id`: `KEY_PREFIX` followed by the
    /// id's `Display` form.
    fn key(run_id: &Uuid) -> String {
        format!("{KEY_PREFIX}{run_id}")
    }
}
impl std::fmt::Debug for ValkeyCheckpointStore {
    /// Prints only the TTL setting. The connection handle is left out, and
    /// the output is marked non-exhaustive (`..`) to show that fields are hidden.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("ValkeyCheckpointStore");
        builder.field("ttl_seconds", &self.ttl_seconds);
        builder.finish_non_exhaustive()
    }
}
#[async_trait]
impl CheckpointStore for ValkeyCheckpointStore {
async fn save(&self, checkpoint: &WorkflowCheckpoint) -> Result<(), PersistError> {
let key = Self::key(&checkpoint.run_id);
let value: Vec<u8> = rmp_serde::to_vec_named(checkpoint)?;
let mut conn = self.conn.clone();
if let Some(ttl) = self.ttl_seconds {
let () = conn.set_ex(&key, &value, ttl).await?;
} else {
let () = conn.set(&key, &value).await?;
}
Ok(())
}
async fn load(&self, run_id: &Uuid) -> Result<Option<WorkflowCheckpoint>, PersistError> {
let key = Self::key(run_id);
let mut conn = self.conn.clone();
let value: Option<Vec<u8>> = conn.get(&key).await?;
match value {
Some(bytes) => {
let checkpoint: WorkflowCheckpoint = rmp_serde::from_slice(&bytes)
.or_else(|_| serde_json::from_slice(&bytes).map_err(PersistError::from))?;
Ok(Some(checkpoint))
}
None => Ok(None),
}
}
async fn list(&self) -> Result<Vec<WorkflowCheckpoint>, PersistError> {
let mut conn = self.conn.clone();
let pattern = format!("{KEY_PREFIX}*");
let mut cursor: u64 = 0;
let mut all_keys: Vec<String> = Vec::new();
loop {
let result: (u64, Vec<String>) = redis::cmd("SCAN")
.arg(cursor)
.arg("MATCH")
.arg(&pattern)
.arg("COUNT")
.arg(100)
.query_async(&mut conn)
.await?;
let (next_cursor, keys) = result;
all_keys.extend(keys);
cursor = next_cursor;
if cursor == 0 {
break;
}
}
let mut checkpoints = Vec::with_capacity(all_keys.len());
for key in &all_keys {
let value: Option<Vec<u8>> = conn.get(key).await?;
if let Some(bytes) = value {
match rmp_serde::from_slice::<WorkflowCheckpoint>(&bytes)
.or_else(|_| serde_json::from_slice(&bytes))
{
Ok(cp) => checkpoints.push(cp),
Err(e) => {
tracing::warn!(
key = %key,
error = %e,
"skipping malformed checkpoint entry"
);
}
}
}
}
checkpoints.sort_by_key(|c| std::cmp::Reverse(c.timestamp));
Ok(checkpoints)
}
async fn delete(&self, run_id: &Uuid) -> Result<(), PersistError> {
let key = Self::key(run_id);
let mut conn = self.conn.clone();
let () = conn.del(&key).await?;
Ok(())
}
}