Skip to main content

bob_adapters/
store_memory.rs

//! # In-Memory Session Store
//!
//! In-memory session store — implements [`SessionStore`] via `scc::HashMap`.
//!
//! ## Overview
//!
//! This adapter provides a thread-safe, in-memory session store backed by
//! [`scc::HashMap`](https://docs.rs/scc/latest/scc/struct.HashMap.html).
//!
//! Suitable for:
//! - Development and testing
//! - Single-process CLI applications
//! - Scenarios where persistence across restarts is not required
//!
//! Not suitable for:
//! - Multi-process deployments
//! - Production environments requiring persistence
//! - Horizontal scaling
//!
//! ## CAS Support
//!
//! The `save_if_version` method performs an atomic compare-and-swap: the
//! session is persisted only when the stored version matches the expected
//! version (a session that does not exist yet counts as version 0). On
//! success the stored version becomes `expected + 1` and is returned; on a
//! mismatch a `StoreError::VersionConflict` is returned and nothing is written.
25
26use bob_core::{
27    error::StoreError,
28    ports::SessionStore,
29    types::{SessionId, SessionState},
30};
31
32/// Thread-safe, in-memory session store backed by [`scc::HashMap`].
33///
34/// Suitable for single-process / CLI usage where persistence across
35/// restarts is not required.
36#[derive(Debug)]
37pub struct InMemorySessionStore {
38    inner: scc::HashMap<SessionId, SessionState>,
39}
40
41impl Default for InMemorySessionStore {
42    fn default() -> Self {
43        Self::new()
44    }
45}
46
47impl InMemorySessionStore {
48    /// Create an empty store.
49    #[must_use]
50    pub fn new() -> Self {
51        Self { inner: scc::HashMap::new() }
52    }
53}
54
55#[async_trait::async_trait]
56impl SessionStore for InMemorySessionStore {
57    async fn load(&self, id: &SessionId) -> Result<Option<SessionState>, StoreError> {
58        let state = self.inner.read_async(id, |_k, v| v.clone()).await;
59        Ok(state)
60    }
61
62    async fn save(&self, id: &SessionId, state: &SessionState) -> Result<(), StoreError> {
63        let entry = self.inner.entry_async(id.clone()).await;
64        match entry {
65            scc::hash_map::Entry::Occupied(mut occ) => {
66                let new_version = occ.get().version.saturating_add(1);
67                let mut updated = state.clone();
68                updated.version = new_version;
69                occ.get_mut().clone_from(&updated);
70            }
71            scc::hash_map::Entry::Vacant(vac) => {
72                let mut initial = state.clone();
73                initial.version = initial.version.max(1);
74                let _ = vac.insert_entry(initial);
75            }
76        }
77        Ok(())
78    }
79
80    async fn save_if_version(
81        &self,
82        id: &SessionId,
83        state: &SessionState,
84        expected_version: u64,
85    ) -> Result<u64, StoreError> {
86        let entry = self.inner.entry_async(id.clone()).await;
87        match entry {
88            scc::hash_map::Entry::Occupied(mut occ) => {
89                if occ.get().version != expected_version {
90                    return Err(StoreError::VersionConflict {
91                        expected: expected_version,
92                        actual: occ.get().version,
93                    });
94                }
95                let new_version = expected_version.saturating_add(1);
96                let mut updated = state.clone();
97                updated.version = new_version;
98                occ.get_mut().clone_from(&updated);
99                Ok(new_version)
100            }
101            scc::hash_map::Entry::Vacant(vac) => {
102                if expected_version != 0 {
103                    return Err(StoreError::VersionConflict {
104                        expected: expected_version,
105                        actual: 0,
106                    });
107                }
108                let mut initial = state.clone();
109                initial.version = 1;
110                let _ = vac.insert_entry(initial);
111                Ok(1)
112            }
113        }
114    }
115}
116
117// ── Tests ────────────────────────────────────────────────────────────
118
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use bob_core::types::Message;

    use super::*;

    #[tokio::test]
    async fn load_missing_returns_none() {
        let store = InMemorySessionStore::new();
        let result = store.load(&"nonexistent".to_string()).await;
        assert!(result.is_ok());
        assert!(result.ok().flatten().is_none());
    }

    #[tokio::test]
    async fn roundtrip_save_load() {
        let store = InMemorySessionStore::new();
        let id = "sess-1".to_string();
        let state = SessionState {
            messages: vec![Message::text(bob_core::types::Role::User, "hello")],
            ..SessionState::default()
        };

        store.save(&id, &state).await.ok();
        let loaded = store.load(&id).await.ok().flatten();
        assert!(loaded.is_some());
        assert_eq!(loaded.as_ref().map(|s| s.messages.len()), Some(1));
        assert_eq!(loaded.as_ref().map(|s| s.version), Some(1));
    }

    #[tokio::test]
    async fn save_increments_version() {
        let store = InMemorySessionStore::new();
        let id = "sess-v".to_string();
        let state = SessionState::default();

        store.save(&id, &state).await.ok();
        let v1 = store.load(&id).await.ok().flatten().unwrap_or_default().version;
        assert_eq!(v1, 1);

        store.save(&id, &state).await.ok();
        let v2 = store.load(&id).await.ok().flatten().unwrap_or_default().version;
        assert_eq!(v2, 2);
    }

    #[tokio::test]
    async fn save_if_version_succeeds_on_match() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas".to_string();
        let state = SessionState::default();

        // First save (version starts at 0 -> becomes 1)
        store.save(&id, &state).await.ok();
        let loaded = store.load(&id).await.ok().flatten().unwrap_or_default();
        assert_eq!(loaded.version, 1);

        // CAS with correct version
        let new_version = store.save_if_version(&id, &state, 1).await;
        assert!(new_version.is_ok());
        assert_eq!(new_version.unwrap_or_default(), 2);
    }

    #[tokio::test]
    async fn save_if_version_fails_on_mismatch() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas-fail".to_string();
        let state = SessionState::default();

        store.save(&id, &state).await.ok();

        // CAS with stale version
        let result = store.save_if_version(&id, &state, 0).await;
        assert!(result.is_err());
        if let Err(StoreError::VersionConflict { expected, actual }) = result {
            assert_eq!(expected, 0);
            assert_eq!(actual, 1);
        } else {
            panic!("expected VersionConflict");
        }
    }

    #[tokio::test]
    async fn save_if_version_creates_missing_session_when_expecting_zero() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas-new".to_string();
        let state = SessionState::default();

        // An absent session counts as version 0.
        let result = store.save_if_version(&id, &state, 0).await;
        assert!(matches!(result, Ok(1)));

        let loaded = store.load(&id).await.ok().flatten();
        assert_eq!(loaded.as_ref().map(|s| s.version), Some(1));
    }

    #[tokio::test]
    async fn save_if_version_rejects_nonzero_expectation_for_missing_session() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas-missing".to_string();
        let state = SessionState::default();

        let result = store.save_if_version(&id, &state, 3).await;
        assert!(matches!(
            result,
            Err(StoreError::VersionConflict { expected: 3, actual: 0 })
        ));

        // A rejected CAS must not create the session.
        let loaded = store.load(&id).await.ok().flatten();
        assert!(loaded.is_none());
    }

    #[tokio::test]
    async fn save_if_version_rejects_reuse_of_consumed_version() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas-stale".to_string();

        let first = SessionState {
            messages: vec![Message::text(bob_core::types::Role::User, "first")],
            ..SessionState::default()
        };
        let second = SessionState {
            messages: vec![
                Message::text(bob_core::types::Role::User, "first"),
                Message::text(bob_core::types::Role::Assistant, "second"),
            ],
            ..SessionState::default()
        };

        assert!(matches!(store.save_if_version(&id, &first, 0).await, Ok(1)));
        assert!(matches!(store.save_if_version(&id, &second, 1).await, Ok(2)));

        // A writer still holding version 1 must lose.
        let stale = store.save_if_version(&id, &first, 1).await;
        assert!(matches!(
            stale,
            Err(StoreError::VersionConflict { expected: 1, actual: 2 })
        ));

        // The winning write is left intact.
        let loaded = store.load(&id).await.ok().flatten();
        assert_eq!(loaded.as_ref().map(|s| s.messages.len()), Some(2));
        assert_eq!(loaded.as_ref().map(|s| s.version), Some(2));
    }

    #[tokio::test]
    async fn overwrite_existing_session() {
        let store = InMemorySessionStore::new();
        let id = "sess-2".to_string();

        let state1 = SessionState {
            messages: vec![Message::text(bob_core::types::Role::User, "first")],
            ..SessionState::default()
        };
        store.save(&id, &state1).await.ok();

        let state2 = SessionState {
            messages: vec![
                Message::text(bob_core::types::Role::User, "first"),
                Message::text(bob_core::types::Role::Assistant, "second"),
            ],
            ..SessionState::default()
        };
        store.save(&id, &state2).await.ok();

        let loaded = store.load(&id).await.ok().flatten();
        assert_eq!(loaded.as_ref().map(|s| s.messages.len()), Some(2));
        assert_eq!(loaded.as_ref().map(|s| s.version), Some(2));
    }

    #[tokio::test]
    async fn arc_dyn_session_store_works() {
        let store: Arc<dyn SessionStore> = Arc::new(InMemorySessionStore::new());
        let id = "sess-arc".to_string();
        let state = SessionState::default();
        store.save(&id, &state).await.ok();
        let loaded = store.load(&id).await.ok().flatten();
        assert!(loaded.is_some());
    }
}