bob_adapters/
store_memory.rs1use bob_core::{
27 error::StoreError,
28 ports::SessionStore,
29 types::{SessionId, SessionState},
30};
31
32#[derive(Debug)]
37pub struct InMemorySessionStore {
38 inner: scc::HashMap<SessionId, SessionState>,
39}
40
41impl Default for InMemorySessionStore {
42 fn default() -> Self {
43 Self::new()
44 }
45}
46
47impl InMemorySessionStore {
48 #[must_use]
50 pub fn new() -> Self {
51 Self { inner: scc::HashMap::new() }
52 }
53}
54
55#[async_trait::async_trait]
56impl SessionStore for InMemorySessionStore {
57 async fn load(&self, id: &SessionId) -> Result<Option<SessionState>, StoreError> {
58 let state = self.inner.read_async(id, |_k, v| v.clone()).await;
59 Ok(state)
60 }
61
62 async fn save(&self, id: &SessionId, state: &SessionState) -> Result<(), StoreError> {
63 let entry = self.inner.entry_async(id.clone()).await;
64 match entry {
65 scc::hash_map::Entry::Occupied(mut occ) => {
66 let new_version = occ.get().version.saturating_add(1);
67 let mut updated = state.clone();
68 updated.version = new_version;
69 occ.get_mut().clone_from(&updated);
70 }
71 scc::hash_map::Entry::Vacant(vac) => {
72 let mut initial = state.clone();
73 initial.version = initial.version.max(1);
74 let _ = vac.insert_entry(initial);
75 }
76 }
77 Ok(())
78 }
79
80 async fn save_if_version(
81 &self,
82 id: &SessionId,
83 state: &SessionState,
84 expected_version: u64,
85 ) -> Result<u64, StoreError> {
86 let entry = self.inner.entry_async(id.clone()).await;
87 match entry {
88 scc::hash_map::Entry::Occupied(mut occ) => {
89 if occ.get().version != expected_version {
90 return Err(StoreError::VersionConflict {
91 expected: expected_version,
92 actual: occ.get().version,
93 });
94 }
95 let new_version = expected_version.saturating_add(1);
96 let mut updated = state.clone();
97 updated.version = new_version;
98 occ.get_mut().clone_from(&updated);
99 Ok(new_version)
100 }
101 scc::hash_map::Entry::Vacant(vac) => {
102 if expected_version != 0 {
103 return Err(StoreError::VersionConflict {
104 expected: expected_version,
105 actual: 0,
106 });
107 }
108 let mut initial = state.clone();
109 initial.version = 1;
110 let _ = vac.insert_entry(initial);
111 Ok(1)
112 }
113 }
114 }
115}
116
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use bob_core::types::Message;

    use super::*;

    #[tokio::test]
    async fn load_missing_returns_none() {
        let store = InMemorySessionStore::new();
        let result = store.load(&"nonexistent".to_string()).await;
        assert!(result.is_ok());
        assert!(result.ok().flatten().is_none());
    }

    #[tokio::test]
    async fn roundtrip_save_load() {
        let store = InMemorySessionStore::new();
        let id = "sess-1".to_string();
        let state = SessionState {
            messages: vec![Message::text(bob_core::types::Role::User, "hello")],
            ..SessionState::default()
        };

        store.save(&id, &state).await.ok();
        let loaded = store.load(&id).await.ok().flatten();
        assert!(loaded.is_some());
        assert_eq!(loaded.as_ref().map(|s| s.messages.len()), Some(1));
        assert_eq!(loaded.as_ref().map(|s| s.version), Some(1));
    }

    #[tokio::test]
    async fn save_increments_version() {
        let store = InMemorySessionStore::new();
        let id = "sess-v".to_string();
        let state = SessionState::default();

        store.save(&id, &state).await.ok();
        let v1 = store.load(&id).await.ok().flatten().unwrap_or_default().version;
        assert_eq!(v1, 1);

        store.save(&id, &state).await.ok();
        let v2 = store.load(&id).await.ok().flatten().unwrap_or_default().version;
        assert_eq!(v2, 2);
    }

    #[tokio::test]
    async fn save_if_version_succeeds_on_match() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas".to_string();
        let state = SessionState::default();

        store.save(&id, &state).await.ok();
        let loaded = store.load(&id).await.ok().flatten().unwrap_or_default();
        assert_eq!(loaded.version, 1);

        let new_version = store.save_if_version(&id, &state, 1).await;
        assert!(new_version.is_ok());
        assert_eq!(new_version.unwrap_or_default(), 2);
    }

    #[tokio::test]
    async fn save_if_version_replaces_payload() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas-payload".to_string();

        store.save(&id, &SessionState::default()).await.ok();

        let updated = SessionState {
            messages: vec![
                Message::text(bob_core::types::Role::User, "first"),
                Message::text(bob_core::types::Role::Assistant, "second"),
            ],
            ..SessionState::default()
        };
        let result = store.save_if_version(&id, &updated, 1).await;
        assert_eq!(result.ok(), Some(2));

        let loaded = store.load(&id).await.ok().flatten();
        assert_eq!(loaded.as_ref().map(|s| s.messages.len()), Some(2));
        assert_eq!(loaded.as_ref().map(|s| s.version), Some(2));
    }

    #[tokio::test]
    async fn save_if_version_fails_on_mismatch() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas-fail".to_string();
        let state = SessionState::default();

        store.save(&id, &state).await.ok();

        let result = store.save_if_version(&id, &state, 0).await;
        assert!(result.is_err());
        if let Err(StoreError::VersionConflict { expected, actual }) = result {
            assert_eq!(expected, 0);
            assert_eq!(actual, 1);
        } else {
            panic!("expected VersionConflict");
        }
    }

    #[tokio::test]
    async fn save_if_version_conflict_leaves_state_untouched() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas-keep".to_string();
        let original = SessionState {
            messages: vec![Message::text(bob_core::types::Role::User, "keep")],
            ..SessionState::default()
        };
        store.save(&id, &original).await.ok();

        let result = store.save_if_version(&id, &SessionState::default(), 42).await;
        assert!(result.is_err());

        let loaded = store.load(&id).await.ok().flatten();
        assert_eq!(loaded.as_ref().map(|s| s.messages.len()), Some(1));
        assert_eq!(loaded.as_ref().map(|s| s.version), Some(1));
    }

    #[tokio::test]
    async fn save_if_version_creates_missing_session_at_version_one() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas-new".to_string();

        let result = store.save_if_version(&id, &SessionState::default(), 0).await;
        assert_eq!(result.ok(), Some(1));

        let loaded = store.load(&id).await.ok().flatten();
        assert_eq!(loaded.as_ref().map(|s| s.version), Some(1));
    }

    #[tokio::test]
    async fn save_if_version_rejects_missing_session_with_nonzero_expectation() {
        let store = InMemorySessionStore::new();
        let id = "sess-cas-missing".to_string();

        let result = store.save_if_version(&id, &SessionState::default(), 3).await;
        assert!(result.is_err());
        if let Err(StoreError::VersionConflict { expected, actual }) = result {
            assert_eq!(expected, 3);
            assert_eq!(actual, 0);
        } else {
            panic!("expected VersionConflict");
        }

        // A rejected CAS on a missing session must not create it.
        assert!(store.load(&id).await.ok().flatten().is_none());
    }

    #[tokio::test]
    async fn overwrite_existing_session() {
        let store = InMemorySessionStore::new();
        let id = "sess-2".to_string();

        let state1 = SessionState {
            messages: vec![Message::text(bob_core::types::Role::User, "first")],
            ..SessionState::default()
        };
        store.save(&id, &state1).await.ok();

        let state2 = SessionState {
            messages: vec![
                Message::text(bob_core::types::Role::User, "first"),
                Message::text(bob_core::types::Role::Assistant, "second"),
            ],
            ..SessionState::default()
        };
        store.save(&id, &state2).await.ok();

        let loaded = store.load(&id).await.ok().flatten();
        assert_eq!(loaded.as_ref().map(|s| s.messages.len()), Some(2));
        assert_eq!(loaded.as_ref().map(|s| s.version), Some(2));
    }

    #[tokio::test]
    async fn arc_dyn_session_store_works() {
        let store: Arc<dyn SessionStore> = Arc::new(InMemorySessionStore::new());
        let id = "sess-arc".to_string();
        let state = SessionState::default();
        store.save(&id, &state).await.ok();
        let loaded = store.load(&id).await.ok().flatten();
        assert!(loaded.is_some());
    }
}