use super::Checkpoint;
use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
/// Asynchronous storage backend for [`Checkpoint`]s.
///
/// Checkpoints are addressed individually by `id` (`load`, `delete`) and
/// looked up per run by `run_id` (`load_latest`, `list`). Implementors must be
/// `Send + Sync`, and every operation is fallible, reporting errors through
/// `anyhow::Result`.
// NOTE(review): the trait itself does not fix overwrite-on-save, result
// ordering, or what deleting a missing id does; the per-method notes below
// describe the contract suggested by the signatures — confirm with
// implementors before relying on anything beyond the types.
#[async_trait]
pub trait CheckpointStore: Send + Sync {
    /// Stores `checkpoint`.
    async fn save(&self, checkpoint: Checkpoint) -> anyhow::Result<()>;
    /// Looks up a single checkpoint by `id`; `Ok(None)` presumably means
    /// "no such checkpoint".
    async fn load(&self, id: &str) -> anyhow::Result<Option<Checkpoint>>;
    /// Returns the latest checkpoint recorded for `run_id`, or `Ok(None)`
    /// (presumably when the run has no checkpoints).
    async fn load_latest(&self, run_id: &str) -> anyhow::Result<Option<Checkpoint>>;
    /// Returns the checkpoints recorded for `run_id`.
    async fn list(&self, run_id: &str) -> anyhow::Result<Vec<Checkpoint>>;
    /// Removes the checkpoint with the given `id`.
    async fn delete(&self, id: &str) -> anyhow::Result<()>;
}
/// A [`CheckpointStore`] that keeps every checkpoint in a process-local
/// `HashMap`, keyed by checkpoint `id`.
///
/// Nothing is persisted: the checkpoints disappear once the last handle to
/// the map is dropped.
///
/// Cloning is cheap and does **not** copy the data. The map lives behind an
/// `Arc`, so every clone is a handle to the *same* storage and observes saves
/// and deletes made through any other handle.
#[derive(Clone, Default)]
pub struct InMemoryCheckpointStore {
    /// Checkpoints indexed by their `id`, shared between clones and guarded by
    /// an `RwLock` for concurrent access.
    checkpoints: Arc<RwLock<HashMap<String, Checkpoint>>>,
}
impl InMemoryCheckpointStore {
pub fn new() -> Self {
Self::default()
}
}
#[async_trait]
impl CheckpointStore for InMemoryCheckpointStore {
async fn save(&self, checkpoint: Checkpoint) -> anyhow::Result<()> {
let mut store = self.checkpoints.write().unwrap();
store.insert(checkpoint.id.clone(), checkpoint);
Ok(())
}
async fn load(&self, id: &str) -> anyhow::Result<Option<Checkpoint>> {
let store = self.checkpoints.read().unwrap();
Ok(store.get(id).cloned())
}
async fn load_latest(&self, run_id: &str) -> anyhow::Result<Option<Checkpoint>> {
let store = self.checkpoints.read().unwrap();
let latest = store
.values()
.filter(|c| c.run_id.as_str() == run_id)
.max_by_key(|c| c.created_at);
Ok(latest.cloned())
}
async fn list(&self, run_id: &str) -> anyhow::Result<Vec<Checkpoint>> {
let store = self.checkpoints.read().unwrap();
let checkpoints: Vec<_> = store
.values()
.filter(|c| c.run_id.as_str() == run_id)
.cloned()
.collect();
Ok(checkpoints)
}
async fn delete(&self, id: &str) -> anyhow::Result<()> {
let mut store = self.checkpoints.write().unwrap();
store.remove(id);
Ok(())
}
}