use std::sync::Arc;
use async_trait::async_trait;
use synaptic_core::{Store, SynapticError};
use crate::checkpoint::{Checkpoint, CheckpointConfig, Checkpointer};
/// A [`Checkpointer`] that persists graph checkpoints through a shared [`Store`].
///
/// Checkpoints are serialized to JSON and written under the namespace
/// `["checkpoints", <thread_id>]`, keyed by the checkpoint's `id`.
pub struct StoreCheckpointer {
    /// Shared handle to the backing store used for all reads and writes.
    store: Arc<dyn Store>,
}
impl StoreCheckpointer {
pub fn new(store: Arc<dyn Store>) -> Self {
Self { store }
}
}
#[async_trait]
impl Checkpointer for StoreCheckpointer {
    /// Serializes `checkpoint` to JSON and writes it to the store under
    /// `["checkpoints", config.thread_id]`, keyed by `checkpoint.id`.
    ///
    /// # Errors
    ///
    /// Returns [`SynapticError::Graph`] if the checkpoint cannot be serialized,
    /// or whatever error the underlying store's `put` reports.
    async fn put(
        &self,
        config: &CheckpointConfig,
        checkpoint: &Checkpoint,
    ) -> Result<(), SynapticError> {
        let value = serde_json::to_value(checkpoint)
            .map_err(|e| SynapticError::Graph(format!("failed to serialize checkpoint: {e}")))?;
        self.store
            .put(
                &[CHECKPOINT_NAMESPACE, &config.thread_id],
                &checkpoint.id,
                value,
            )
            .await
    }

    /// Loads a single checkpoint for the thread named by `config`.
    ///
    /// When `config.checkpoint_id` is set, that exact entry is looked up.
    /// Otherwise the thread's entries are searched and the one with the
    /// greatest store key is returned. Returns `Ok(None)` when nothing matches.
    ///
    /// # Errors
    ///
    /// Returns [`SynapticError::Graph`] if the stored value cannot be
    /// deserialized, or whatever error the underlying store reports.
    async fn get(&self, config: &CheckpointConfig) -> Result<Option<Checkpoint>, SynapticError> {
        // Fast path: the caller asked for one specific checkpoint.
        if let Some(target_id) = &config.checkpoint_id {
            let item = self
                .store
                .get(&[CHECKPOINT_NAMESPACE, &config.thread_id], target_id)
                .await?;
            return item.map(|item| decode_checkpoint(item.value)).transpose();
        }

        // NOTE(review): the third argument is presumably a result limit; if so,
        // threads holding more than 10_000 checkpoints may not have every entry
        // considered here — confirm against the `Store::search` contract.
        let items = self
            .store
            .search(&[CHECKPOINT_NAMESPACE, &config.thread_id], None, 10_000)
            .await?;

        // "Latest" means the greatest key; `put` uses the checkpoint id as the key.
        // NOTE(review): this presumes ids sort in creation order — confirm.
        // `max_by` yields `None` for an empty result, which maps to `Ok(None)`
        // without needing an explicit emptiness check or an `unwrap`.
        items
            .into_iter()
            .max_by(|a, b| a.key.cmp(&b.key))
            .map(|item| decode_checkpoint(item.value))
            .transpose()
    }

    /// Loads the checkpoints stored for the thread named by `config`,
    /// sorted ascending by checkpoint `id`.
    ///
    /// # Errors
    ///
    /// Returns [`SynapticError::Graph`] on the first stored value that cannot
    /// be deserialized, or whatever error the underlying store reports.
    async fn list(&self, config: &CheckpointConfig) -> Result<Vec<Checkpoint>, SynapticError> {
        // NOTE(review): same presumed 10_000-entry cap as in `get` — confirm.
        let items = self
            .store
            .search(&[CHECKPOINT_NAMESPACE, &config.thread_id], None, 10_000)
            .await?;
        let mut checkpoints: Vec<Checkpoint> = items
            .into_iter()
            .map(|item| decode_checkpoint(item.value))
            .collect::<Result<Vec<_>, _>>()?;
        checkpoints.sort_by(|a, b| a.id.cmp(&b.id));
        Ok(checkpoints)
    }
}

/// First namespace segment under which every thread's checkpoints are stored.
const CHECKPOINT_NAMESPACE: &str = "checkpoints";

/// Decodes a JSON value read back from the store into a [`Checkpoint`].
///
/// # Errors
///
/// Returns [`SynapticError::Graph`] when the value does not deserialize.
fn decode_checkpoint(value: serde_json::Value) -> Result<Checkpoint, SynapticError> {
    serde_json::from_value(value)
        .map_err(|e| SynapticError::Graph(format!("failed to deserialize checkpoint: {e}")))
}