bamboo_infrastructure/storage/
session_merge.rs1use std::sync::Arc;
35
36use bamboo_domain::session::types::Session;
37use bamboo_domain::storage::Storage;
38use bamboo_domain::RuntimeSessionPersistence;
39use dashmap::DashMap;
40use tokio::sync::{Mutex, OwnedMutexGuard};
41
/// Metadata keys that belong to the "authoritative metadata" group.
///
/// When a save is merged with the on-disk copy and the disk copy's
/// `metadata_version` is at least as new as the in-memory one, each key listed
/// here is taken from disk (or removed if disk lacks it), together with
/// `title`, `title_version` and `pinned`, instead of keeping the writer's value.
const AUTHORITATIVE_METADATA_KEYS: &[&str] = &[
    "gold_config",
];
43
44pub struct LockedSessionStore {
52 storage: Arc<dyn Storage>,
53 locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
54}
55
56impl LockedSessionStore {
57 pub fn new(storage: Arc<dyn Storage>) -> Self {
59 Self {
60 storage,
61 locks: Arc::new(DashMap::new()),
62 }
63 }
64
65 pub fn storage(&self) -> &Arc<dyn Storage> {
67 &self.storage
68 }
69
70 pub async fn acquire_lock(&self, session_id: &str) -> OwnedMutexGuard<()> {
75 let lock = self
76 .locks
77 .entry(session_id.to_string())
78 .or_insert_with(|| Arc::new(Mutex::new(())))
79 .clone();
80 lock.lock_owned().await
81 }
82
83 pub async fn commit_metadata(&self, session: &Session) -> std::io::Result<()> {
93 let _guard = self.acquire_lock(&session.id).await;
94 self.storage.save_session(session).await
95 }
96
97 pub async fn merge_save_runtime(&self, session: &mut Session) -> std::io::Result<()> {
107 let _guard = self.acquire_lock(&session.id).await;
108
109 let existing_message_count = self
115 .storage
116 .load_session(&session.id)
117 .await
118 .ok()
119 .flatten()
120 .map(|s| s.messages.len());
121 let incoming_message_count = session.messages.len();
122 if existing_message_count.is_some_and(|existing| existing > incoming_message_count) {
123 tracing::warn!(
124 "[{}] merge_save_runtime SHRINK: disk has {:?} messages, saving {} (last_role={:?}, updated_at={}); a stale writer is reverting a concurrent append",
125 session.id,
126 existing_message_count,
127 incoming_message_count,
128 session.messages.last().map(|m| format!("{:?}", m.role)),
129 session.updated_at,
130 );
131 } else {
132 tracing::debug!(
133 "[{}] merge_save_runtime: disk={:?} messages, saving {} (updated_at={})",
134 session.id,
135 existing_message_count,
136 incoming_message_count,
137 session.updated_at,
138 );
139 }
140
141 merge_authoritative_metadata_into_stale(&self.storage, session).await;
142 self.storage.save_session(session).await
143 }
144
145 pub async fn update_runtime_config<F>(
157 &self,
158 session_id: &str,
159 mutate: F,
160 ) -> std::io::Result<Option<Session>>
161 where
162 F: FnOnce(&mut Session),
163 {
164 let _guard = self.acquire_lock(session_id).await;
165 let Some(mut session) = self.storage.load_session(session_id).await? else {
166 return Ok(None);
167 };
168 mutate(&mut session);
169 self.storage.save_session(&session).await?;
170 Ok(Some(session))
171 }
172}
173
174#[async_trait::async_trait]
178impl RuntimeSessionPersistence for LockedSessionStore {
179 async fn save_runtime_session(&self, session: &mut Session) -> std::io::Result<()> {
180 self.merge_save_runtime(session).await
181 }
182}
183
184async fn merge_authoritative_metadata_into_stale(
193 storage: &Arc<dyn Storage>,
194 session: &mut Session,
195) {
196 if let Ok(Some(latest)) = storage.load_session(&session.id).await {
197 if latest.metadata_version >= session.metadata_version {
198 session.title = latest.title;
199 session.title_version = latest.title_version;
200 session.pinned = latest.pinned;
201 for key in AUTHORITATIVE_METADATA_KEYS {
202 if let Some(value) = latest.metadata.get(*key) {
203 session.metadata.insert((*key).to_string(), value.clone());
204 } else {
205 session.metadata.remove(*key);
206 }
207 }
208 session.metadata_version = latest.metadata_version;
209 }
210 }
211}
212
213pub async fn merge_save_session(
226 storage: &Arc<dyn Storage>,
227 session: &mut Session,
228) -> std::io::Result<()> {
229 merge_authoritative_metadata_into_stale(storage, session).await;
230 storage.save_session(session).await
231}
232
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::v2::SessionStoreV2;
    use bamboo_domain::session::types::{Message, Session};
    use bamboo_domain::ReasoningEffort;

    /// Creates a `SessionStoreV2` rooted in a fresh temp dir. The returned
    /// `TempDir` must stay alive for as long as the storage is used.
    async fn make_storage() -> (tempfile::TempDir, Arc<dyn Storage>) {
        let temp = tempfile::tempdir().unwrap();
        let storage = SessionStoreV2::new(temp.path().to_path_buf())
            .await
            .expect("storage init");
        (temp, Arc::new(storage) as Arc<dyn Storage>)
    }

    /// A brand-new session with the given id and a placeholder model.
    fn fresh(id: &str) -> Session {
        Session::new(id.to_string(), "test-model".to_string())
    }

    #[tokio::test]
    async fn update_runtime_config_preserves_concurrently_appended_messages() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage.clone());
        let session_id = "cfg-preserve";

        let mut initial = fresh(session_id);
        initial.add_message(Message::user("hello"));
        initial.add_message(Message::assistant("hi", None));
        storage.save_session(&initial).await.unwrap();

        // Another writer appends a message after the initial save.
        let mut after_chat = storage.load_session(session_id).await.unwrap().unwrap();
        after_chat.add_message(Message::user("second question"));
        storage.save_session(&after_chat).await.unwrap();
        assert_eq!(after_chat.messages.len(), 3);

        let updated = store
            .update_runtime_config(session_id, |s| {
                s.reasoning_effort = Some(ReasoningEffort::Max);
            })
            .await
            .unwrap()
            .expect("session exists");

        assert_eq!(updated.reasoning_effort, Some(ReasoningEffort::Max));
        assert_eq!(
            updated.messages.len(),
            3,
            "config patch must not revert a concurrently-appended message"
        );

        let on_disk = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(on_disk.messages.len(), 3);
        assert_eq!(on_disk.reasoning_effort, Some(ReasoningEffort::Max));
    }

    #[tokio::test]
    async fn update_runtime_config_returns_none_for_missing_session() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);

        let result = store
            .update_runtime_config("does-not-exist", |s| {
                s.reasoning_effort = Some(ReasoningEffort::Low);
            })
            .await
            .unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn merge_save_runtime_overwrites_messages_from_stale_snapshot() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage.clone());
        let session_id = "stale-clobber";

        let mut baseline = fresh(session_id);
        baseline.add_message(Message::user("hello"));
        storage.save_session(&baseline).await.unwrap();
        let mut stale_snapshot = storage.load_session(session_id).await.unwrap().unwrap();

        // A concurrent writer appends a second message after the snapshot was taken.
        let mut after_chat = storage.load_session(session_id).await.unwrap().unwrap();
        after_chat.add_message(Message::user("second"));
        storage.save_session(&after_chat).await.unwrap();
        let disk_len = storage
            .load_session(session_id)
            .await
            .unwrap()
            .unwrap()
            .messages
            .len();
        assert_eq!(disk_len, 2);

        store.merge_save_runtime(&mut stale_snapshot).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(
            after.messages.len(),
            1,
            "merge_save_runtime clobbers concurrent appends — this is why config patches must use update_runtime_config"
        );
    }

    #[tokio::test]
    async fn merge_preserves_disk_title_when_versions_equal() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-equal";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Set This".to_string();
        on_disk.title_version = 0;
        on_disk.metadata_version = 0;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale Default".to_string();
        runtime_copy.title_version = 0;
        runtime_copy.metadata_version = 0;
        runtime_copy.messages = vec![];

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "User Set This");
        assert_eq!(after.title_version, 0);
        assert_eq!(runtime_copy.title, "User Set This");
    }

    #[tokio::test]
    async fn merge_preserves_disk_when_disk_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-higher";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Title v3".to_string();
        on_disk.title_version = 3;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.title_version = 1;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "User Title v3");
        assert_eq!(after.title_version, 3);
        assert_eq!(after.metadata_version, 5);
    }

    #[tokio::test]
    async fn merge_now_preserves_disk_pinned_in_metadata_group() {
        let (_temp, storage) = make_storage().await;
        let session_id = "pinned-merge";

        let mut on_disk = fresh(session_id);
        on_disk.pinned = true;
        on_disk.metadata_version = 2;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.pinned = false;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert!(after.pinned, "disk pinned=true should win over runtime false");
        assert_eq!(after.metadata_version, 2);
    }

    #[tokio::test]
    async fn merge_keeps_in_memory_when_session_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-bumped";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Old".to_string();
        on_disk.title_version = 1;
        on_disk.metadata_version = 3;
        storage.save_session(&on_disk).await.unwrap();

        let mut authoritative_copy = fresh(session_id);
        authoritative_copy.title = "New Authoritative".to_string();
        authoritative_copy.title_version = 2;
        authoritative_copy.metadata_version = 4;
        authoritative_copy.pinned = true;

        merge_save_session(&storage, &mut authoritative_copy)
            .await
            .unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "New Authoritative");
        assert_eq!(after.title_version, 2);
        assert_eq!(after.metadata_version, 4);
        assert!(after.pinned);
    }

    #[tokio::test]
    async fn merge_keeps_runtime_messages_when_disk_only_changed_metadata() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-messages";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Fresh Title".to_string();
        on_disk.title_version = 2;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.metadata_version = 0;
        // Use the constructor rather than an exhaustive struct literal so the test
        // does not break whenever `Message` gains a field.
        runtime_copy.messages = vec![Message::user("keep me")];

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "Fresh Title");
        assert_eq!(after.metadata_version, 5);
        assert_eq!(after.messages.len(), 1);
        assert_eq!(after.messages[0].content, "keep me");
    }

    #[tokio::test]
    async fn locked_merge_save_runtime_serialises_concurrent_writes() {
        let (_temp, storage) = make_storage().await;
        let store = Arc::new(LockedSessionStore::new(storage));
        let session_id = "lock-serial".to_string();

        store
            .storage()
            .save_session(&fresh(&session_id))
            .await
            .unwrap();

        // Each writer does a read-modify-write under the session lock. Without the
        // lock both could read the same versions and one bump would be lost. The
        // assertions below do not depend on which writer runs first, so no sleep
        // is needed to order them.
        let spawn_writer = |title: &'static str| {
            let store = Arc::clone(&store);
            let sid = session_id.clone();
            tokio::spawn(async move {
                let _guard = store.acquire_lock(&sid).await;
                let mut s = store
                    .storage()
                    .load_session(&sid)
                    .await
                    .unwrap()
                    .unwrap();
                s.title = title.to_string();
                s.title_version = s.title_version.saturating_add(1);
                s.metadata_version = s.metadata_version.saturating_add(1);
                s.updated_at = chrono::Utc::now();
                store.storage().save_session(&s).await.unwrap();
                s.title_version
            })
        };

        let (ver_a, ver_b) = tokio::join!(spawn_writer("Writer A"), spawn_writer("Writer B"));
        let ver_a = ver_a.expect("writer A panicked");
        let ver_b = ver_b.expect("writer B panicked");

        assert_eq!(
            ver_a.abs_diff(ver_b),
            1,
            "serialised writers must each build on the other's bump"
        );

        let final_s = store
            .storage()
            .load_session(&session_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(final_s.metadata_version, 2);
        assert_eq!(final_s.title_version, ver_a.max(ver_b));
    }

    #[tokio::test]
    async fn commit_metadata_is_plain_save_inside_lock() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let session_id = "commit-plain";

        let mut s = fresh(session_id);
        s.title = "Committed".to_string();
        s.metadata_version = 1;
        s.title_version = 2;

        store.commit_metadata(&s).await.unwrap();

        let after = store
            .storage()
            .load_session(session_id)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(after.title, "Committed");
        assert_eq!(after.metadata_version, 1);
        assert_eq!(after.title_version, 2);
    }
}