bamboo_infrastructure/storage/
session_merge.rs1use std::sync::Arc;
35
36use bamboo_domain::session::types::Session;
37use bamboo_domain::storage::Storage;
38use bamboo_domain::RuntimeSessionPersistence;
39use dashmap::DashMap;
40use tokio::sync::{Mutex, OwnedMutexGuard};
41
42pub struct LockedSessionStore {
50 storage: Arc<dyn Storage>,
51 locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
52}
53
54impl LockedSessionStore {
55 pub fn new(storage: Arc<dyn Storage>) -> Self {
57 Self {
58 storage,
59 locks: Arc::new(DashMap::new()),
60 }
61 }
62
63 pub fn storage(&self) -> &Arc<dyn Storage> {
65 &self.storage
66 }
67
68 pub async fn acquire_lock(&self, session_id: &str) -> OwnedMutexGuard<()> {
73 let lock = self
74 .locks
75 .entry(session_id.to_string())
76 .or_insert_with(|| Arc::new(Mutex::new(())))
77 .clone();
78 lock.lock_owned().await
79 }
80
81 pub async fn commit_metadata(&self, session: &Session) -> std::io::Result<()> {
91 let _guard = self.acquire_lock(&session.id).await;
92 self.storage.save_session(session).await
93 }
94
95 pub async fn merge_save_runtime(&self, session: &mut Session) -> std::io::Result<()> {
105 let _guard = self.acquire_lock(&session.id).await;
106 merge_authoritative_metadata_into_stale(&self.storage, session).await;
107 self.storage.save_session(session).await
108 }
109}
110
111#[async_trait::async_trait]
115impl RuntimeSessionPersistence for LockedSessionStore {
116 async fn save_runtime_session(&self, session: &mut Session) -> std::io::Result<()> {
117 self.merge_save_runtime(session).await
118 }
119}
120
121async fn merge_authoritative_metadata_into_stale(
130 storage: &Arc<dyn Storage>,
131 session: &mut Session,
132) {
133 if let Ok(Some(latest)) = storage.load_session(&session.id).await {
134 if latest.metadata_version >= session.metadata_version {
135 session.title = latest.title;
136 session.title_version = latest.title_version;
137 session.pinned = latest.pinned;
138 session.metadata_version = latest.metadata_version;
139 }
140 }
141}
142
143pub async fn merge_save_session(
156 storage: &Arc<dyn Storage>,
157 session: &mut Session,
158) -> std::io::Result<()> {
159 merge_authoritative_metadata_into_stale(storage, session).await;
160 storage.save_session(session).await
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::v2::SessionStoreV2;
    use bamboo_domain::session::types::Session;

    /// Builds a `SessionStoreV2` backed by a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the storage so the caller keeps it
    /// alive for the whole test; dropping it removes the directory.
    async fn make_storage() -> (tempfile::TempDir, Arc<dyn Storage>) {
        let temp = tempfile::tempdir().unwrap();
        let storage = SessionStoreV2::new(temp.path().to_path_buf())
            .await
            .expect("storage init");
        (temp, Arc::new(storage) as Arc<dyn Storage>)
    }

    /// Returns a new session with the given id and a placeholder model name.
    fn fresh(id: &str) -> Session {
        Session::new(id.to_string(), "test-model".to_string())
    }

    /// With equal `metadata_version`s the stored metadata wins, and the
    /// in-memory copy is updated to match what was written.
    #[tokio::test]
    async fn merge_preserves_disk_title_when_versions_equal() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-equal";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Set This".to_string();
        on_disk.title_version = 0;
        on_disk.metadata_version = 0;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale Default".to_string();
        runtime_copy.title_version = 0;
        runtime_copy.metadata_version = 0;
        runtime_copy.messages = vec![];

        merge_save_session(&storage, &mut runtime_copy)
            .await
            .unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "User Set This");
        assert_eq!(after.title_version, 0);
        assert_eq!(runtime_copy.title, "User Set This");
    }

    /// A higher stored `metadata_version` overrides a stale runtime copy.
    #[tokio::test]
    async fn merge_preserves_disk_when_disk_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-higher";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Title v3".to_string();
        on_disk.title_version = 3;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.title_version = 1;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy)
            .await
            .unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "User Title v3");
        assert_eq!(after.title_version, 3);
        assert_eq!(after.metadata_version, 5);
    }

    /// `pinned` is part of the merged metadata group.
    #[tokio::test]
    async fn merge_now_preserves_disk_pinned_in_metadata_group() {
        let (_temp, storage) = make_storage().await;
        let session_id = "pinned-merge";

        let mut on_disk = fresh(session_id);
        on_disk.pinned = true;
        on_disk.metadata_version = 2;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.pinned = false;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy)
            .await
            .unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert!(
            after.pinned,
            "disk pinned=true should win over runtime false"
        );
        assert_eq!(after.metadata_version, 2);
    }

    /// A strictly newer in-memory `metadata_version` keeps the in-memory values.
    #[tokio::test]
    async fn merge_keeps_in_memory_when_session_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-bumped";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Old".to_string();
        on_disk.title_version = 1;
        on_disk.metadata_version = 3;
        storage.save_session(&on_disk).await.unwrap();

        let mut authoritative_copy = fresh(session_id);
        authoritative_copy.title = "New Authoritative".to_string();
        authoritative_copy.title_version = 2;
        authoritative_copy.metadata_version = 4;
        authoritative_copy.pinned = true;

        merge_save_session(&storage, &mut authoritative_copy)
            .await
            .unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "New Authoritative");
        assert_eq!(after.title_version, 2);
        assert_eq!(after.metadata_version, 4);
        assert!(after.pinned);
    }

    /// Only the metadata group comes from disk; runtime messages are kept.
    #[tokio::test]
    async fn merge_keeps_runtime_messages_when_disk_only_changed_metadata() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-messages";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Fresh Title".to_string();
        on_disk.title_version = 2;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.metadata_version = 0;
        runtime_copy.messages = vec![bamboo_domain::session::types::Message {
            role: bamboo_domain::session::types::Role::User,
            content: "keep me".to_string(),
            id: "msg-1".to_string(),
            created_at: chrono::Utc::now(),
            reasoning: None,
            content_parts: None,
            image_ocr: None,
            phase: None,
            tool_calls: None,
            tool_call_id: None,
            tool_success: None,
            compressed: false,
            compressed_by_event_id: None,
            never_compress: false,
            compression_level: 0,
            metadata: None,
        }];

        merge_save_session(&storage, &mut runtime_copy)
            .await
            .unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "Fresh Title");
        assert_eq!(after.metadata_version, 5);
        assert_eq!(after.messages.len(), 1);
        assert_eq!(after.messages[0].content, "keep me");
    }

    /// The runtime-persistence trait path (locked merge-save) also lets
    /// committed metadata win over a stale runtime snapshot.
    #[tokio::test]
    async fn locked_runtime_save_preserves_committed_metadata() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let session_id = "locked-runtime-merge";

        let mut committed = fresh(session_id);
        committed.title = "Renamed By User".to_string();
        committed.title_version = 1;
        committed.metadata_version = 3;
        committed.pinned = true;
        store.commit_metadata(&committed).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.pinned = false;
        runtime_copy.metadata_version = 0;

        store
            .save_runtime_session(&mut runtime_copy)
            .await
            .unwrap();

        let after = store
            .storage()
            .load_session(session_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(after.title, "Renamed By User");
        assert_eq!(after.title_version, 1);
        assert_eq!(after.metadata_version, 3);
        assert!(after.pinned);
        assert_eq!(runtime_copy.title, "Renamed By User");
    }

    /// Two tasks doing load-modify-save under `acquire_lock` for the same
    /// session run one after the other: the second observes the first's
    /// write, so they produce distinct versions and both bumps land.
    #[tokio::test]
    async fn acquire_lock_serialises_concurrent_writes() {
        let (_temp, storage) = make_storage().await;
        let store = Arc::new(LockedSessionStore::new(storage));
        let session_id = "lock-serial".to_string();

        let base = fresh(&session_id);
        store.storage().save_session(&base).await.unwrap();

        let store_a = store.clone();
        let store_b = store.clone();
        let sid_a = session_id.clone();
        let sid_b = session_id.clone();

        let a = tokio::spawn(async move {
            let _guard = store_a.acquire_lock(&sid_a).await;
            let mut s = store_a
                .storage()
                .load_session(&sid_a)
                .await
                .unwrap()
                .unwrap();
            s.title = "Writer A".to_string();
            s.title_version = s.title_version.saturating_add(1);
            s.metadata_version = s.metadata_version.saturating_add(1);
            s.updated_at = chrono::Utc::now();
            store_a.storage().save_session(&s).await.unwrap();
            s.title_version
        });

        // Give A a head start so B is likely to contend for the held lock.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let b = tokio::spawn(async move {
            let _guard = store_b.acquire_lock(&sid_b).await;
            let mut s = store_b
                .storage()
                .load_session(&sid_b)
                .await
                .unwrap()
                .unwrap();
            s.title = "Writer B".to_string();
            s.title_version = s.title_version.saturating_add(1);
            s.metadata_version = s.metadata_version.saturating_add(1);
            s.updated_at = chrono::Utc::now();
            store_b.storage().save_session(&s).await.unwrap();
            s.title_version
        });

        let (ver_a, ver_b) = tokio::join!(a, b);
        let final_s = store
            .storage()
            .load_session(&session_id)
            .await
            .unwrap()
            .unwrap();
        assert_ne!(
            ver_a.unwrap(),
            ver_b.unwrap(),
            "concurrent writers must produce distinct versions"
        );
        assert_eq!(final_s.metadata_version, 2);
    }

    /// `commit_metadata` writes the caller's values verbatim (no merge).
    #[tokio::test]
    async fn commit_metadata_is_plain_save_inside_lock() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let session_id = "commit-plain";

        let mut s = fresh(session_id);
        s.title = "Committed".to_string();
        s.metadata_version = 1;
        s.title_version = 2;

        store.commit_metadata(&s).await.unwrap();

        let after = store
            .storage()
            .load_session(session_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(after.title, "Committed");
        assert_eq!(after.metadata_version, 1);
        assert_eq!(after.title_version, 2);
    }
}