Skip to main content

enact_core/graph/
checkpoint_store.rs

1//! Checkpoint store trait and implementations
2
3use super::Checkpoint;
4use async_trait::async_trait;
5use std::collections::HashMap;
6use std::sync::{Arc, RwLock};
7
8/// Checkpoint store trait
9#[async_trait]
10pub trait CheckpointStore: Send + Sync {
11    /// Save a checkpoint
12    async fn save(&self, checkpoint: Checkpoint) -> anyhow::Result<()>;
13
14    /// Load a checkpoint by ID
15    async fn load(&self, id: &str) -> anyhow::Result<Option<Checkpoint>>;
16
17    /// Load latest checkpoint for a run
18    async fn load_latest(&self, run_id: &str) -> anyhow::Result<Option<Checkpoint>>;
19
20    /// List checkpoints for a run
21    async fn list(&self, run_id: &str) -> anyhow::Result<Vec<Checkpoint>>;
22
23    /// Delete a checkpoint
24    async fn delete(&self, id: &str) -> anyhow::Result<()>;
25}
26
27/// In-memory checkpoint store (for testing/development)
28#[derive(Default)]
29pub struct InMemoryCheckpointStore {
30    checkpoints: Arc<RwLock<HashMap<String, Checkpoint>>>,
31}
32
33impl InMemoryCheckpointStore {
34    pub fn new() -> Self {
35        Self::default()
36    }
37}
38
39#[async_trait]
40impl CheckpointStore for InMemoryCheckpointStore {
41    async fn save(&self, checkpoint: Checkpoint) -> anyhow::Result<()> {
42        let mut store = self.checkpoints.write().unwrap();
43        store.insert(checkpoint.id.clone(), checkpoint);
44        Ok(())
45    }
46
47    async fn load(&self, id: &str) -> anyhow::Result<Option<Checkpoint>> {
48        let store = self.checkpoints.read().unwrap();
49        Ok(store.get(id).cloned())
50    }
51
52    async fn load_latest(&self, run_id: &str) -> anyhow::Result<Option<Checkpoint>> {
53        let store = self.checkpoints.read().unwrap();
54        let latest = store
55            .values()
56            .filter(|c| c.run_id.as_str() == run_id)
57            .max_by_key(|c| c.created_at);
58        Ok(latest.cloned())
59    }
60
61    async fn list(&self, run_id: &str) -> anyhow::Result<Vec<Checkpoint>> {
62        let store = self.checkpoints.read().unwrap();
63        let checkpoints: Vec<_> = store
64            .values()
65            .filter(|c| c.run_id.as_str() == run_id)
66            .cloned()
67            .collect();
68        Ok(checkpoints)
69    }
70
71    async fn delete(&self, id: &str) -> anyhow::Result<()> {
72        let mut store = self.checkpoints.write().unwrap();
73        store.remove(id);
74        Ok(())
75    }
76}