mod types;
pub use types::{Checkpoint, CheckpointMetadata, PendingWrite};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use crate::harness::ids::CheckpointId;
use crate::{Result, RustAgentsError};
/// Asynchronous storage for [`Checkpoint`]s, grouped by thread id.
///
/// The `Send + Sync` supertrait bound lets a single checkpointer be shared between threads.
#[async_trait]
pub trait Checkpointer<State>: Send + Sync
where
    State: Send + Sync + 'static,
{
    /// Stores `checkpoint` and returns its id.
    async fn put(&self, checkpoint: Checkpoint<State>) -> Result<CheckpointId>;

    /// Looks up a checkpoint stored under `thread_id`.
    ///
    /// In `InMemoryCheckpointer`:
    /// - `Some(id)` selects the checkpoint with that id.
    /// - `None` selects the most recently stored one.
    /// - `Ok(None)` is returned when nothing matches.
    ///
    /// NOTE(review): confirm this is the intended contract for every implementation.
    async fn get(
        &self,
        thread_id: &str,
        checkpoint_id: Option<&str>,
    ) -> Result<Option<Checkpoint<State>>>;

    /// Returns summary metadata for the checkpoints stored under `thread_id`.
    async fn list(&self, thread_id: &str) -> Result<Vec<CheckpointMetadata>>;
}
/// A [`Checkpointer`] that keeps every checkpoint in a process-local map.
///
/// Clones share the same underlying storage.
pub struct InMemoryCheckpointer<State> {
    /// Checkpoints keyed by thread id, each list in the order the checkpoints were `put`.
    /// Guarded by a `std::sync::Mutex`, so every access may fail if the lock is poisoned.
    inner: Arc<Mutex<HashMap<String, Vec<Checkpoint<State>>>>>,
}
impl<State> InMemoryCheckpointer<State> {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn count(&self, thread_id: &str) -> usize {
self.inner
.lock()
.map(|m| m.get(thread_id).map(|v| v.len()).unwrap_or(0))
.unwrap_or(0)
}
}
impl<State> Default for InMemoryCheckpointer<State> {
fn default() -> Self {
Self::new()
}
}
impl<State> Clone for InMemoryCheckpointer<State> {
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
}
}
}
/// Builds the error returned when the in-memory checkpointer's mutex is poisoned.
fn lock_err() -> RustAgentsError {
    let message = String::from("in-memory checkpointer lock poisoned");
    RustAgentsError::Checkpoint(message)
}
/// Builds the [`CheckpointMetadata`] summary for a stored checkpoint.
///
/// Identity fields, `namespace`, and `next_nodes` are cloned from `c`. `has_interrupts` is true
/// when `c.interrupts` is non-empty. Two fields are read from `c.metadata`:
/// - `"source"`: falls back to `"loop"` when missing or not a string.
/// - `"step"`: falls back to `0` when missing, when `as_u64` rejects the value, or when the
///   value does not fit in `usize`.
fn metadata_of<State>(c: &Checkpoint<State>) -> CheckpointMetadata {
    let source = c
        .metadata
        .get("source")
        .and_then(|v| v.as_str())
        .unwrap_or("loop")
        .to_string();
    // A plain `as usize` cast would silently truncate large values on 32-bit targets. Treat
    // out-of-range steps like any other unusable value instead.
    let step = c
        .metadata
        .get("step")
        .and_then(|v| v.as_u64())
        .and_then(|n| usize::try_from(n).ok())
        .unwrap_or(0);
    CheckpointMetadata {
        thread_id: c.thread_id.clone(),
        checkpoint_id: c.checkpoint_id.clone(),
        parent_checkpoint_id: c.parent_checkpoint_id.clone(),
        namespace: c.namespace.clone(),
        next_nodes: c.next_nodes.clone(),
        has_interrupts: !c.interrupts.is_empty(),
        source,
        step,
    }
}
// Each method takes the mutex once, works synchronously, and releases it before returning.
// No `.await` happens while the guard is held.
#[async_trait]
impl<State> Checkpointer<State> for InMemoryCheckpointer<State>
where
    State: Clone + Send + Sync + 'static,
{
    /// Appends `checkpoint` to its thread's history and returns its id.
    ///
    /// Existing checkpoints with the same id are kept; nothing is overwritten.
    ///
    /// # Errors
    /// Returns the lock-poisoned checkpoint error if the mutex is poisoned.
    async fn put(&self, checkpoint: Checkpoint<State>) -> Result<CheckpointId> {
        let checkpoint_id = CheckpointId::new(checkpoint.checkpoint_id.clone());
        let thread_key = checkpoint.thread_id.clone();
        self.inner
            .lock()
            .map_err(|_| lock_err())?
            .entry(thread_key)
            .or_default()
            .push(checkpoint);
        Ok(checkpoint_id)
    }

    /// Returns a clone of the requested checkpoint for `thread_id`.
    ///
    /// - `Some(id)`: returns the first stored checkpoint with that id.
    /// - `None`: returns the most recently stored checkpoint.
    /// - Unknown threads or ids yield `Ok(None)`.
    ///
    /// # Errors
    /// Returns the lock-poisoned checkpoint error if the mutex is poisoned.
    async fn get(
        &self,
        thread_id: &str,
        checkpoint_id: Option<&str>,
    ) -> Result<Option<Checkpoint<State>>> {
        let threads = self.inner.lock().map_err(|_| lock_err())?;
        let snapshot = threads
            .get(thread_id)
            .and_then(|history| match checkpoint_id {
                Some(wanted) => history.iter().find(|cp| cp.checkpoint_id == wanted),
                None => history.last(),
            })
            .cloned();
        Ok(snapshot)
    }

    /// Returns metadata for every checkpoint of `thread_id`, in insertion order.
    ///
    /// An unknown thread yields an empty list.
    ///
    /// # Errors
    /// Returns the lock-poisoned checkpoint error if the mutex is poisoned.
    async fn list(&self, thread_id: &str) -> Result<Vec<CheckpointMetadata>> {
        let threads = self.inner.lock().map_err(|_| lock_err())?;
        let summaries = match threads.get(thread_id) {
            Some(history) => history.iter().map(metadata_of).collect(),
            None => Vec::new(),
        };
        Ok(summaries)
    }
}
#[cfg(test)]
mod test;