use agents_core::persistence::{Checkpointer, ThreadId};
use agents_core::state::AgentStateSnapshot;
use anyhow::Context;
use async_trait::async_trait;
use redis::{aio::ConnectionManager, AsyncCommands};
use std::time::Duration;
#[derive(Clone)]
pub struct RedisCheckpointer {
connection: ConnectionManager,
namespace: String,
ttl: Option<Duration>,
}
impl RedisCheckpointer {
pub async fn new(url: &str) -> anyhow::Result<Self> {
Self::builder().url(url).build().await
}
pub fn builder() -> RedisCheckpointerBuilder {
RedisCheckpointerBuilder::default()
}
fn key_for_thread(&self, thread_id: &ThreadId) -> String {
format!("{}:thread:{}", self.namespace, thread_id)
}
fn threads_index_key(&self) -> String {
format!("{}:threads", self.namespace)
}
}
#[async_trait]
impl Checkpointer for RedisCheckpointer {
async fn save_state(
&self,
thread_id: &ThreadId,
state: &AgentStateSnapshot,
) -> anyhow::Result<()> {
let key = self.key_for_thread(thread_id);
let index_key = self.threads_index_key();
let json =
serde_json::to_string(state).context("Failed to serialize agent state to JSON")?;
let mut conn = self.connection.clone();
if let Some(ttl) = self.ttl {
conn.set_ex::<_, _, ()>(&key, json, ttl.as_secs())
.await
.context("Failed to save state to Redis with TTL")?;
} else {
conn.set::<_, _, ()>(&key, json)
.await
.context("Failed to save state to Redis")?;
}
conn.sadd::<_, _, ()>(&index_key, thread_id)
.await
.context("Failed to update thread index")?;
tracing::debug!(
thread_id = %thread_id,
namespace = %self.namespace,
"Saved agent state to Redis"
);
Ok(())
}
async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
let key = self.key_for_thread(thread_id);
let mut conn = self.connection.clone();
let json: Option<String> = conn
.get(&key)
.await
.context("Failed to load state from Redis")?;
match json {
Some(data) => {
let state: AgentStateSnapshot = serde_json::from_str(&data)
.context("Failed to deserialize agent state from JSON")?;
tracing::debug!(
thread_id = %thread_id,
namespace = %self.namespace,
"Loaded agent state from Redis"
);
Ok(Some(state))
}
None => {
tracing::debug!(
thread_id = %thread_id,
namespace = %self.namespace,
"No saved state found in Redis"
);
Ok(None)
}
}
}
async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
let key = self.key_for_thread(thread_id);
let index_key = self.threads_index_key();
let mut conn = self.connection.clone();
conn.del::<_, ()>(&key)
.await
.context("Failed to delete state from Redis")?;
conn.srem::<_, _, ()>(&index_key, thread_id)
.await
.context("Failed to update thread index")?;
tracing::debug!(
thread_id = %thread_id,
namespace = %self.namespace,
"Deleted thread from Redis"
);
Ok(())
}
async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
let index_key = self.threads_index_key();
let mut conn = self.connection.clone();
let threads: Vec<String> = conn
.smembers(&index_key)
.await
.context("Failed to list threads from Redis")?;
Ok(threads)
}
}
#[derive(Default)]
pub struct RedisCheckpointerBuilder {
url: Option<String>,
namespace: Option<String>,
ttl: Option<Duration>,
}
impl RedisCheckpointerBuilder {
pub fn url(mut self, url: impl Into<String>) -> Self {
self.url = Some(url.into());
self
}
pub fn namespace(mut self, namespace: impl Into<String>) -> Self {
self.namespace = Some(namespace.into());
self
}
pub fn ttl(mut self, ttl: Duration) -> Self {
self.ttl = Some(ttl);
self
}
pub async fn build(self) -> anyhow::Result<RedisCheckpointer> {
let url = self
.url
.ok_or_else(|| anyhow::anyhow!("Redis URL is required"))?;
let client = redis::Client::open(url.as_str()).context("Failed to create Redis client")?;
let connection = ConnectionManager::new(client)
.await
.context("Failed to establish Redis connection")?;
Ok(RedisCheckpointer {
connection,
namespace: self.namespace.unwrap_or_else(|| "agents".to_string()),
ttl: self.ttl,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use agents_core::state::TodoItem;
fn sample_state() -> AgentStateSnapshot {
let mut state = AgentStateSnapshot::default();
state.todos.push(TodoItem::pending("Test todo"));
state
.files
.insert("test.txt".to_string(), "content".to_string());
state
.scratchpad
.insert("key".to_string(), serde_json::json!("value"));
state
}
#[tokio::test]
#[ignore] async fn test_redis_save_and_load() {
let checkpointer = RedisCheckpointer::new("redis://127.0.0.1:6379")
.await
.expect("Failed to connect to Redis");
let thread_id = "test-thread".to_string();
let state = sample_state();
checkpointer
.save_state(&thread_id, &state)
.await
.expect("Failed to save state");
let loaded = checkpointer
.load_state(&thread_id)
.await
.expect("Failed to load state");
assert!(loaded.is_some());
let loaded_state = loaded.unwrap();
assert_eq!(loaded_state.todos.len(), 1);
assert_eq!(loaded_state.files.get("test.txt").unwrap(), "content");
checkpointer
.delete_thread(&thread_id)
.await
.expect("Failed to delete thread");
}
#[tokio::test]
#[ignore] async fn test_redis_list_threads() {
let checkpointer = RedisCheckpointer::builder()
.url("redis://127.0.0.1:6379")
.namespace("test-namespace")
.build()
.await
.expect("Failed to connect to Redis");
let state = sample_state();
checkpointer
.save_state(&"thread1".to_string(), &state)
.await
.unwrap();
checkpointer
.save_state(&"thread2".to_string(), &state)
.await
.unwrap();
let threads = checkpointer.list_threads().await.unwrap();
assert!(threads.contains(&"thread1".to_string()));
assert!(threads.contains(&"thread2".to_string()));
checkpointer
.delete_thread(&"thread1".to_string())
.await
.unwrap();
checkpointer
.delete_thread(&"thread2".to_string())
.await
.unwrap();
}
}