mod file;
#[cfg(feature = "sqlite")]
mod sqlite;
mod types;
pub use file::FileCheckpointer;
#[cfg(feature = "sqlite")]
pub use sqlite::SqliteCheckpointer;
pub use types::{
Checkpoint, CheckpointConfig, CheckpointMetadata, CheckpointSource, CheckpointTuple,
DurabilityMode, PendingWrite,
};
use std::collections::{HashMap, HashSet};
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use crate::harness::ids::CheckpointId;
use crate::{Result, TinyAgentsError};
#[async_trait]
pub trait Checkpointer<State>: Send + Sync
where
State: Send + Sync + 'static,
{
async fn put(&self, checkpoint: Checkpoint<State>) -> Result<CheckpointId>;
async fn get(
&self,
thread_id: &str,
checkpoint_id: Option<&str>,
) -> Result<Option<Checkpoint<State>>>;
async fn list(&self, thread_id: &str) -> Result<Vec<CheckpointMetadata>>;
async fn get_tuple(&self, config: CheckpointConfig) -> Result<Option<CheckpointTuple<State>>> {
let Some(checkpoint) = self
.get(&config.thread_id, config.checkpoint_id.as_deref())
.await?
else {
return Ok(None);
};
let resolved = CheckpointConfig {
thread_id: checkpoint.thread_id.clone(),
checkpoint_id: Some(checkpoint.checkpoint_id.clone()),
namespace: checkpoint.namespace.clone(),
};
let parent_config =
checkpoint
.parent_checkpoint_id
.as_ref()
.map(|parent| CheckpointConfig {
thread_id: checkpoint.thread_id.clone(),
checkpoint_id: Some(parent.clone()),
namespace: checkpoint.namespace.clone(),
});
let pending_writes = checkpoint.pending_writes.clone();
Ok(Some(CheckpointTuple {
config: resolved,
checkpoint,
parent_config,
pending_writes,
}))
}
async fn list_threads(&self) -> Result<Vec<String>>;
async fn delete_thread(&self, thread_id: &str) -> Result<()>;
async fn delete_checkpoints(&self, thread_id: &str, ids: &[String]) -> Result<usize>;
async fn delete_by_run(&self, thread_id: &str, run_id: &str) -> Result<usize> {
let ids: Vec<String> = self
.list(thread_id)
.await?
.into_iter()
.filter(|m| m.run_id.as_deref() == Some(run_id))
.map(|m| m.checkpoint_id)
.collect();
self.delete_checkpoints(thread_id, &ids).await
}
async fn copy_thread(&self, source_thread: &str, target_thread: &str) -> Result<()> {
let metas = self.list(source_thread).await?;
for meta in metas {
let Some(mut checkpoint) = self.get(source_thread, Some(&meta.checkpoint_id)).await?
else {
continue;
};
checkpoint.thread_id = target_thread.to_string();
self.put(checkpoint).await?;
}
Ok(())
}
async fn prune(&self, thread_id: &str, keep_last: usize) -> Result<usize> {
let metas = self.list(thread_id).await?;
if metas.is_empty() {
return Ok(0);
}
let keep_last = keep_last.max(1).min(metas.len());
let mut parent_of: HashMap<&str, Option<&str>> = HashMap::new();
for m in &metas {
parent_of.insert(m.checkpoint_id.as_str(), m.parent_checkpoint_id.as_deref());
}
let mut protected: HashSet<String> = HashSet::new();
for m in metas.iter().rev().take(keep_last) {
protected.insert(m.checkpoint_id.clone());
}
let window: Vec<String> = protected.iter().cloned().collect();
for start in window {
let mut cursor = parent_of.get(start.as_str()).copied().flatten();
while let Some(parent) = cursor {
if !protected.insert(parent.to_string()) {
break; }
cursor = parent_of.get(parent).copied().flatten();
}
}
let to_delete: Vec<String> = metas
.iter()
.map(|m| m.checkpoint_id.clone())
.filter(|id| !protected.contains(id))
.collect();
self.delete_checkpoints(thread_id, &to_delete).await
}
}
/// Checkpoint store that keeps everything in process memory.
///
/// The storage sits behind an `Arc<Mutex<..>>`, and cloning the checkpointer
/// clones only the `Arc`, so every clone operates on the same underlying map.
pub struct InMemoryCheckpointer<State> {
// Thread id -> checkpoints stored for that thread, in the order they were pushed.
inner: Arc<Mutex<HashMap<String, Vec<Checkpoint<State>>>>>,
}
impl<State> InMemoryCheckpointer<State> {
pub fn new() -> Self {
Self {
inner: Arc::new(Mutex::new(HashMap::new())),
}
}
pub fn count(&self, thread_id: &str) -> usize {
self.inner
.lock()
.map(|m| m.get(thread_id).map(|v| v.len()).unwrap_or(0))
.unwrap_or(0)
}
}
/// Hand-written so that `State` itself does not need to implement `Default`.
impl<State> Default for InMemoryCheckpointer<State> {
    /// Builds an empty checkpointer, exactly like `InMemoryCheckpointer::new`.
    fn default() -> Self {
        Self {
            inner: Arc::default(),
        }
    }
}
/// Hand-written so that `State` itself does not need to implement `Clone`.
impl<State> Clone for InMemoryCheckpointer<State> {
    /// Returns a handle to the same storage; only the `Arc` is cloned, not
    /// the stored checkpoints.
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
/// Builds the checkpoint error reported when the in-memory store's mutex
/// cannot be locked because it is poisoned.
fn lock_err() -> TinyAgentsError {
    let message = String::from("in-memory checkpointer lock poisoned");
    TinyAgentsError::Checkpoint(message)
}
/// Free-function form of `Checkpoint::to_metadata`, convenient to pass
/// directly to iterator adaptors such as `map`.
fn metadata_of<S>(checkpoint: &Checkpoint<S>) -> CheckpointMetadata {
    checkpoint.to_metadata()
}
/// [`Checkpointer`] implementation backed by the in-process map.
///
/// Every method takes the mutex for the duration of the call and fails with
/// the `lock_err()` checkpoint error when the mutex is poisoned. No lock is
/// held across an `.await`.
#[async_trait]
impl<State> Checkpointer<State> for InMemoryCheckpointer<State>
where
    State: Clone + Send + Sync + 'static,
{
    /// Stores `checkpoint` as the newest entry of its thread and returns its id.
    ///
    /// If the thread already holds a checkpoint with the same id, that older
    /// entry is replaced, so ids stay unique within a thread and lookups by id
    /// always see the latest stored version.
    async fn put(&self, checkpoint: Checkpoint<State>) -> Result<CheckpointId> {
        let id = CheckpointId::new(checkpoint.checkpoint_id.clone());
        let mut map = self.inner.lock().map_err(|_| lock_err())?;
        let list = map.entry(checkpoint.thread_id.clone()).or_default();
        // Drop any stale copy first: appending blindly would make
        // `get(Some(id))` return the old version and `list` report the id twice.
        list.retain(|existing| existing.checkpoint_id != checkpoint.checkpoint_id);
        list.push(checkpoint);
        Ok(id)
    }

    /// Returns a clone of the requested checkpoint.
    ///
    /// With `Some(id)` the checkpoint with that id is returned; with `None`
    /// the most recently stored checkpoint of the thread. `Ok(None)` is
    /// returned for an unknown thread, an unknown id, or an empty thread.
    async fn get(
        &self,
        thread_id: &str,
        checkpoint_id: Option<&str>,
    ) -> Result<Option<Checkpoint<State>>> {
        let map = self.inner.lock().map_err(|_| lock_err())?;
        let Some(list) = map.get(thread_id) else {
            return Ok(None);
        };
        let found = match checkpoint_id {
            Some(id) => list.iter().find(|c| c.checkpoint_id == id),
            None => list.last(),
        };
        Ok(found.cloned())
    }

    /// Returns the metadata of the thread's checkpoints in storage order
    /// (oldest first); an unknown thread yields an empty list.
    async fn list(&self, thread_id: &str) -> Result<Vec<CheckpointMetadata>> {
        let map = self.inner.lock().map_err(|_| lock_err())?;
        Ok(map
            .get(thread_id)
            .map(|list| list.iter().map(metadata_of).collect())
            .unwrap_or_default())
    }

    /// Returns the ids of all threads that currently hold at least one
    /// checkpoint, in unspecified (hash map) order.
    async fn list_threads(&self) -> Result<Vec<String>> {
        let map = self.inner.lock().map_err(|_| lock_err())?;
        Ok(map
            .iter()
            // Threads emptied by `delete_checkpoints` keep an empty entry.
            .filter(|(_, list)| !list.is_empty())
            .map(|(thread, _)| thread.clone())
            .collect())
    }

    /// Removes the thread and all of its checkpoints; unknown threads are ignored.
    async fn delete_thread(&self, thread_id: &str) -> Result<()> {
        let mut map = self.inner.lock().map_err(|_| lock_err())?;
        map.remove(thread_id);
        Ok(())
    }

    /// Removes the checkpoints of `thread_id` whose ids are in `ids` and
    /// returns how many were removed.
    ///
    /// Returns `0` without locking when `ids` is empty, and `0` for an unknown
    /// thread. The thread's (possibly empty) entry is left in the map.
    async fn delete_checkpoints(&self, thread_id: &str, ids: &[String]) -> Result<usize> {
        if ids.is_empty() {
            return Ok(0);
        }
        // Set lookup keeps the `retain` pass linear in the thread length.
        let drop: HashSet<&str> = ids.iter().map(String::as_str).collect();
        let mut map = self.inner.lock().map_err(|_| lock_err())?;
        let Some(list) = map.get_mut(thread_id) else {
            return Ok(0);
        };
        let before = list.len();
        list.retain(|c| !drop.contains(c.checkpoint_id.as_str()));
        Ok(before - list.len())
    }
}
#[cfg(test)]
mod test;