Skip to main content

bob_adapters/
checkpoint_memory.rs

1//! In-memory turn checkpoint store adapter.
2
3use bob_core::{
4    error::StoreError,
5    ports::TurnCheckpointStorePort,
6    types::{SessionId, TurnCheckpoint},
7};
8
9/// In-memory checkpoint store keyed by session id.
10#[derive(Debug, Default)]
11pub struct InMemoryCheckpointStore {
12    inner: scc::HashMap<SessionId, TurnCheckpoint>,
13}
14
15impl InMemoryCheckpointStore {
16    #[must_use]
17    pub fn new() -> Self {
18        Self::default()
19    }
20}
21
22#[async_trait::async_trait]
23impl TurnCheckpointStorePort for InMemoryCheckpointStore {
24    async fn save_checkpoint(&self, checkpoint: &TurnCheckpoint) -> Result<(), StoreError> {
25        let entry = self.inner.entry_async(checkpoint.session_id.clone()).await;
26        match entry {
27            scc::hash_map::Entry::Occupied(mut occ) => {
28                occ.get_mut().clone_from(checkpoint);
29            }
30            scc::hash_map::Entry::Vacant(vac) => {
31                let _ = vac.insert_entry(checkpoint.clone());
32            }
33        }
34        Ok(())
35    }
36
37    async fn load_latest(
38        &self,
39        session_id: &SessionId,
40    ) -> Result<Option<TurnCheckpoint>, StoreError> {
41        Ok(self.inner.read_async(session_id, |_k, v| v.clone()).await)
42    }
43}
44
#[cfg(test)]
mod tests {
    use bob_core::types::TokenUsage;

    use super::*;

    /// A checkpoint saved for a session can be loaded back by that session id.
    #[tokio::test]
    async fn roundtrip_checkpoint() {
        let session = "s1".to_string();
        let store = InMemoryCheckpointStore::new();

        let usage = TokenUsage { prompt_tokens: 10, completion_tokens: 5 };
        let checkpoint =
            TurnCheckpoint { session_id: session.clone(), step: 2, tool_calls: 1, usage };

        let save_result = store.save_checkpoint(&checkpoint).await;
        assert!(save_result.is_ok());

        let load_result = store.load_latest(&session).await;
        assert!(load_result.is_ok());

        let loaded_step = load_result.ok().flatten().map(|found| found.step);
        assert_eq!(loaded_step, Some(2));
    }
}