use crate::state::AgentStateSnapshot;
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Key identifying one thread whose agent state can be checkpointed.
///
/// This is a plain alias for `String`; no format is enforced here.
pub type ThreadId = String;
/// Free-form configuration for a checkpointer.
///
/// Serializable via serde and defaults to an empty parameter map.
// NOTE(review): nothing in the code visible here reads this struct; it is
// presumably consumed by backend-specific checkpointers elsewhere — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct CheckpointerConfig {
/// Arbitrary named parameters, each an arbitrary JSON value.
pub params: HashMap<String, serde_json::Value>,
}
/// Async storage interface for per-thread [`AgentStateSnapshot`]s.
///
/// Implementors must be `Send + Sync`. Every method reports failure through
/// `anyhow::Result`.
#[async_trait]
pub trait Checkpointer: Send + Sync {
/// Saves `state` as the snapshot associated with `thread_id`.
///
/// The in-memory implementation in this file replaces any snapshot already
/// stored under the same id.
async fn save_state(
&self,
thread_id: &ThreadId,
state: &AgentStateSnapshot,
) -> anyhow::Result<()>;
/// Loads the snapshot stored for `thread_id`.
///
/// Returns `Ok(None)` when no snapshot is stored for that thread.
async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>>;
/// Removes the snapshot stored for `thread_id`.
///
/// The in-memory implementation in this file succeeds even when the thread
/// was never saved.
async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()>;
/// Lists the ids of all threads that currently have a stored snapshot.
///
/// No ordering is promised by this signature; the in-memory implementation
/// returns ids in `HashMap` iteration order.
async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>>;
}
/// [`Checkpointer`] that keeps snapshots in a `HashMap` inside this value.
///
/// The map is wrapped in a `std::sync::RwLock`, so it can be used through a
/// shared reference. Nothing is written outside the struct itself.
#[derive(Debug, Default)]
pub struct InMemoryCheckpointer {
/// Stored snapshots keyed by thread id.
states: std::sync::RwLock<HashMap<ThreadId, AgentStateSnapshot>>,
}
impl InMemoryCheckpointer {
    /// Creates a checkpointer that holds no threads yet.
    ///
    /// Equivalent to the derived `Default` implementation.
    pub fn new() -> Self {
        Default::default()
    }
}
#[async_trait]
impl Checkpointer for InMemoryCheckpointer {
    /// Stores a clone of `state` under a clone of `thread_id`, replacing any
    /// snapshot previously held for that id.
    ///
    /// # Errors
    /// Returns an error when the write lock cannot be acquired.
    async fn save_state(
        &self,
        thread_id: &ThreadId,
        state: &AgentStateSnapshot,
    ) -> anyhow::Result<()> {
        let mut guard = match self.states.write() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(anyhow::anyhow!(
                    "Failed to acquire write lock on in-memory checkpointer"
                ));
            }
        };
        guard.insert(thread_id.clone(), state.clone());
        tracing::debug!(thread_id = %thread_id, "Saved agent state to memory");
        Ok(())
    }

    /// Returns a clone of the snapshot stored for `thread_id`, or `None` when
    /// the thread has no stored snapshot.
    ///
    /// # Errors
    /// Returns an error when the read lock cannot be acquired.
    async fn load_state(&self, thread_id: &ThreadId) -> anyhow::Result<Option<AgentStateSnapshot>> {
        let guard = match self.states.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(anyhow::anyhow!(
                    "Failed to acquire read lock on in-memory checkpointer"
                ));
            }
        };
        // The debug event is emitted only on a hit; a miss stays silent.
        let snapshot = guard.get(thread_id).map(|stored| {
            let copy = stored.clone();
            tracing::debug!(thread_id = %thread_id, "Loaded agent state from memory");
            copy
        });
        Ok(snapshot)
    }

    /// Removes whatever snapshot is stored for `thread_id`; succeeds even if
    /// the thread was not present.
    ///
    /// # Errors
    /// Returns an error when the write lock cannot be acquired.
    async fn delete_thread(&self, thread_id: &ThreadId) -> anyhow::Result<()> {
        let mut guard = match self.states.write() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(anyhow::anyhow!(
                    "Failed to acquire write lock on in-memory checkpointer"
                ));
            }
        };
        // The removed value (if any) is intentionally discarded.
        let _ = guard.remove(thread_id);
        tracing::debug!(thread_id = %thread_id, "Deleted thread from memory");
        Ok(())
    }

    /// Returns the ids of all stored threads, in the map's iteration order.
    ///
    /// # Errors
    /// Returns an error when the read lock cannot be acquired.
    async fn list_threads(&self) -> anyhow::Result<Vec<ThreadId>> {
        let guard = match self.states.read() {
            Ok(guard) => guard,
            Err(_) => {
                return Err(anyhow::anyhow!(
                    "Failed to acquire read lock on in-memory checkpointer"
                ));
            }
        };
        let mut ids: Vec<ThreadId> = Vec::with_capacity(guard.len());
        ids.extend(guard.keys().cloned());
        Ok(ids)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::{TodoItem, TodoStatus};

    /// Builds a snapshot holding one pending todo, one file and one
    /// scratchpad entry.
    fn sample_state() -> AgentStateSnapshot {
        let mut snapshot = AgentStateSnapshot::default();
        let todo = TodoItem {
            content: String::from("Test todo"),
            status: TodoStatus::Pending,
        };
        snapshot.todos.push(todo);
        snapshot
            .files
            .insert(String::from("test.txt"), String::from("content"));
        snapshot
            .scratchpad
            .insert(String::from("key"), serde_json::json!("value"));
        snapshot
    }

    /// A saved snapshot comes back with the same todos, files and scratchpad.
    #[tokio::test]
    async fn in_memory_checkpointer_save_and_load() {
        let store = InMemoryCheckpointer::new();
        let id = String::from("test-thread");
        let snapshot = sample_state();

        store.save_state(&id, &snapshot).await.unwrap();

        let restored = store.load_state(&id).await.unwrap();
        assert!(restored.is_some());
        let restored = restored.unwrap();
        assert_eq!(restored.todos.len(), 1);
        assert_eq!(restored.todos[0].content, "Test todo");
        assert_eq!(restored.files.get("test.txt").unwrap(), "content");
        assert_eq!(
            restored.scratchpad.get("key").unwrap(),
            &serde_json::json!("value")
        );
    }

    /// Loading an id that was never saved yields `None`, not an error.
    #[tokio::test]
    async fn in_memory_checkpointer_nonexistent_thread() {
        let store = InMemoryCheckpointer::new();
        let missing = String::from("nonexistent");

        let outcome = store.load_state(&missing).await.unwrap();

        assert!(outcome.is_none());
    }

    /// After deletion the thread can no longer be loaded.
    #[tokio::test]
    async fn in_memory_checkpointer_delete_thread() {
        let store = InMemoryCheckpointer::new();
        let id = String::from("test-thread");
        let snapshot = sample_state();

        store.save_state(&id, &snapshot).await.unwrap();
        assert!(store.load_state(&id).await.unwrap().is_some());

        store.delete_thread(&id).await.unwrap();
        assert!(store.load_state(&id).await.unwrap().is_none());
    }

    /// Every saved thread id is reported by `list_threads`.
    #[tokio::test]
    async fn in_memory_checkpointer_list_threads() {
        let store = InMemoryCheckpointer::new();
        let snapshot = sample_state();
        let first = String::from("thread1");
        let second = String::from("thread2");

        store.save_state(&first, &snapshot).await.unwrap();
        store.save_state(&second, &snapshot).await.unwrap();

        let listed = store.list_threads().await.unwrap();
        assert_eq!(listed.len(), 2);
        assert!(listed.contains(&first));
        assert!(listed.contains(&second));
    }
}