Skip to main content

bamboo_infrastructure/storage/
session_merge.rs

1//! Merge-aware session save helper.
2//!
3//! Provides [`merge_save_session`], which preserves any concurrent UI edits to
4//! the authoritative metadata group (`title`, `title_version`, `pinned`,
5//! `metadata_version`) before writing the runtime-modified session to storage.
6//! Re-reads the latest persisted copy and only takes in-memory values when the
7//! caller's `metadata_version` strictly exceeds disk's.
8//!
9//! ## Field-by-field merge policy
10//!
//! All authoritative metadata fields are grouped under `metadata_version`:
//! when `disk.metadata_version >= session.metadata_version`, the on-disk
//! `title`, `title_version`, `pinned`, and `metadata_version` — plus the
//! `metadata` entries named in `AUTHORITATIVE_METADATA_KEYS` (currently
//! `gold_config`) — overwrite the in-memory values before writing.
//! Authoritative writers bump `metadata_version` (and `title_version` for
//! title edits) before calling so their values survive the merge;
//! non-authoritative writers don't bump, so their copies of these fields are
//! replaced by whatever is on disk.
18//!
//! ## Save primitives
20//!
21//! - **`merge_save_session`** — stateless merge+save. Still works for
22//!   non-authoritative writers that hold `Arc<dyn Storage>` directly.
23//! - **`LockedSessionStore::merge_save_runtime`** — per-session-locked variant
24//!   that additionally serializes writes for the same session. Prefer this for
25//!   server-side paths where an authoritative writer may race with a runtime
26//!   save.
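//! - **`LockedSessionStore::commit_metadata`** — plain save that holds the
//!   per-session lock for the duration of the write only. For authoritative
//!   writers that have just loaded, mutated, and bumped the latest copy; no
//!   merge is performed. It acquires the lock itself, so it must not be called
//!   while the caller already holds a guard from `acquire_lock`.
//! - **`LockedSessionStore::update_runtime_config`** — loads the latest
//!   session, applies a closure, and saves, all inside the per-session lock.
//!   For config-only patches that must not overwrite `messages`.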
30//!
31//! Bare [`Storage::save_session`] is reserved for first-write paths (e.g. new
32//! session creation) where there is no prior on-disk copy to merge against.
33
34use std::sync::Arc;
35
36use bamboo_domain::session::types::Session;
37use bamboo_domain::storage::Storage;
38use bamboo_domain::RuntimeSessionPersistence;
39use dashmap::DashMap;
40use tokio::sync::{Mutex, OwnedMutexGuard};
41
/// Keys of `Session::metadata` that belong to the authoritative metadata
/// group.
///
/// During a merge these entries are copied from the on-disk session (or
/// removed from the in-memory one when absent on disk) together with `title`,
/// `title_version`, `pinned` and `metadata_version`.
const AUTHORITATIVE_METADATA_KEYS: &[&str] = &["gold_config"];
43
44// ── LockedSessionStore ────────────────────────────────────────────────
45
46/// Wraps a [`Storage`] implementation with per-session write serialization.
47///
48/// Under the hood it maintains a `DashMap<String, Arc<Mutex<()>>>` so that
49/// only writes targeting the *same* session are serialised; different
50/// sessions proceed concurrently.
51pub struct LockedSessionStore {
52    storage: Arc<dyn Storage>,
53    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
54}
55
56impl LockedSessionStore {
57    /// Wrap an existing storage backend.
58    pub fn new(storage: Arc<dyn Storage>) -> Self {
59        Self {
60            storage,
61            locks: Arc::new(DashMap::new()),
62        }
63    }
64
65    /// Borrow the inner storage for read-only access.
66    pub fn storage(&self) -> &Arc<dyn Storage> {
67        &self.storage
68    }
69
70    /// Acquire a per-session serialization guard.
71    ///
72    /// Only writes for the **same** session are serialised; writes for
73    /// different sessions can proceed concurrently.
74    pub async fn acquire_lock(&self, session_id: &str) -> OwnedMutexGuard<()> {
75        let lock = self
76            .locks
77            .entry(session_id.to_string())
78            .or_insert_with(|| Arc::new(Mutex::new(())))
79            .clone();
80        lock.lock_owned().await
81    }
82
83    /// Authoritative metadata commit.
84    ///
85    /// The caller must have already loaded the latest session, mutated the
86    /// metadata fields, and bumped `metadata_version` (and `title_version` if
87    /// applicable).  This method simply acquires the per-session lock and
88    /// performs a plain `storage.save_session`.
89    ///
90    /// The lock guarantees that no other write for this session interleaves
91    /// between the caller's load and this save, so merge is unnecessary.
92    pub async fn commit_metadata(&self, session: &Session) -> std::io::Result<()> {
93        let _guard = self.acquire_lock(&session.id).await;
94        self.storage.save_session(session).await
95    }
96
97    /// Runtime / non-authoritative save with per-session lock.
98    ///
99    /// Inside the lock: reload disk, merge the authoritative metadata group
100    /// (`title`, `title_version`, `pinned`, `metadata_version`) from disk into
101    /// the in-memory copy if disk's `metadata_version >= session.metadata_version`,
102    /// then save.
103    ///
104    /// This is the locked equivalent of [`merge_save_session`]; prefer it for
105    /// server-side paths where an authoritative write may race with this save.
106    pub async fn merge_save_runtime(&self, session: &mut Session) -> std::io::Result<()> {
107        let _guard = self.acquire_lock(&session.id).await;
108
109        // DIAGNOSTIC: merge_save_runtime overwrites the whole `messages` array
110        // (it only merges authoritative metadata, not messages). If the incoming
111        // session is stale (fewer messages than what is already on disk), this save
112        // silently reverts a concurrent append (e.g. a just-persisted user message).
113        // Log a SHRINK warning so we can identify the stale writer.
114        let existing_message_count = self
115            .storage
116            .load_session(&session.id)
117            .await
118            .ok()
119            .flatten()
120            .map(|s| s.messages.len());
121        let incoming_message_count = session.messages.len();
122        if existing_message_count.is_some_and(|existing| existing > incoming_message_count) {
123            tracing::warn!(
124                "[{}] merge_save_runtime SHRINK: disk has {:?} messages, saving {} (last_role={:?}, updated_at={}); a stale writer is reverting a concurrent append",
125                session.id,
126                existing_message_count,
127                incoming_message_count,
128                session.messages.last().map(|m| format!("{:?}", m.role)),
129                session.updated_at,
130            );
131        } else {
132            tracing::debug!(
133                "[{}] merge_save_runtime: disk={:?} messages, saving {} (updated_at={})",
134                session.id,
135                existing_message_count,
136                incoming_message_count,
137                session.updated_at,
138            );
139        }
140
141        merge_authoritative_metadata_into_stale(&self.storage, session).await;
142        self.storage.save_session(session).await
143    }
144
145    /// Apply a config-only mutation to a session without ever clobbering its
146    /// `messages` (or other concurrently-written state).
147    ///
148    /// Unlike [`Self::merge_save_runtime`], the caller does NOT pass a session
149    /// snapshot. Instead this loads the **latest** session from storage *inside*
150    /// the per-session lock, applies `mutate` (intended for small config fields
151    /// like `model_ref` / `reasoning_effort`), and saves. Because the load and
152    /// save both happen under the lock, a concurrent append (e.g. `POST /chat`
153    /// adding a user message) can never be reverted by this write.
154    ///
155    /// Returns the saved session, or `None` if it does not exist.
156    pub async fn update_runtime_config<F>(
157        &self,
158        session_id: &str,
159        mutate: F,
160    ) -> std::io::Result<Option<Session>>
161    where
162        F: FnOnce(&mut Session),
163    {
164        let _guard = self.acquire_lock(session_id).await;
165        let Some(mut session) = self.storage.load_session(session_id).await? else {
166            return Ok(None);
167        };
168        mutate(&mut session);
169        self.storage.save_session(&session).await?;
170        Ok(Some(session))
171    }
172}
173
174/// Infrastructure implementation of the domain runtime-persistence port.
175/// Server should assemble this as `Arc<dyn RuntimeSessionPersistence>` and must
176/// not define a separate adapter layer for the same behavior.
177#[async_trait::async_trait]
178impl RuntimeSessionPersistence for LockedSessionStore {
179    async fn save_runtime_session(&self, session: &mut Session) -> std::io::Result<()> {
180        self.merge_save_runtime(session).await
181    }
182}
183
184// ── Internal merge helper ─────────────────────────────────────────────
185
186/// Re-read the on-disk session and, when the disk copy carries a
187/// `metadata_version >= session.metadata_version`, overwrite the in-memory
188/// authoritative metadata fields with the disk values.
189///
190/// This is the core staleness-correction: non-authoritative writers call it
191/// before saving so they don't accidentally revert a concurrent UI edit.
192async fn merge_authoritative_metadata_into_stale(
193    storage: &Arc<dyn Storage>,
194    session: &mut Session,
195) {
196    if let Ok(Some(latest)) = storage.load_session(&session.id).await {
197        if latest.metadata_version >= session.metadata_version {
198            session.title = latest.title;
199            session.title_version = latest.title_version;
200            session.pinned = latest.pinned;
201            for key in AUTHORITATIVE_METADATA_KEYS {
202                if let Some(value) = latest.metadata.get(*key) {
203                    session.metadata.insert((*key).to_string(), value.clone());
204                } else {
205                    session.metadata.remove(*key);
206                }
207            }
208            session.metadata_version = latest.metadata_version;
209        }
210    }
211}
212
213// ── Free merge-save function ──────────────────────────────────────────
214
215/// Save a session while preserving any concurrent UI edits to the
216/// authoritative metadata group.
217///
218/// Behaviour: if the on-disk session has `metadata_version >=
219/// session.metadata_version`, the on-disk `title`, `title_version`, `pinned`
220/// and `metadata_version` overwrite the in-memory values before writing.
221///
222/// This is the stateless variant (no per-session lock). Prefer
223/// [`LockedSessionStore::merge_save_runtime`] for server-side paths where an
224/// authoritative writer may race with this save.
225pub async fn merge_save_session(
226    storage: &Arc<dyn Storage>,
227    session: &mut Session,
228) -> std::io::Result<()> {
229    merge_authoritative_metadata_into_stale(storage, session).await;
230    storage.save_session(session).await
231}
232
233// ── Tests ─────────────────────────────────────────────────────────────
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::v2::SessionStoreV2;
    use bamboo_domain::session::types::{Message, Session};
    use bamboo_domain::ReasoningEffort;

    /// Create a storage backend rooted in a fresh temp dir. The `TempDir` is
    /// returned so the directory lives as long as the test holds it.
    async fn make_storage() -> (tempfile::TempDir, Arc<dyn Storage>) {
        let temp = tempfile::tempdir().unwrap();
        let storage = SessionStoreV2::new(temp.path().to_path_buf())
            .await
            .expect("storage init");
        (temp, Arc::new(storage) as Arc<dyn Storage>)
    }

    /// A brand-new session with the given id and a placeholder model name.
    fn fresh(id: &str) -> Session {
        Session::new(id.to_string(), "test-model".to_string())
    }

    /// Load a session that the test expects to exist.
    async fn reload(storage: &Arc<dyn Storage>, id: &str) -> Session {
        storage
            .load_session(id)
            .await
            .expect("load must succeed")
            .expect("session must exist")
    }

    /// Authoritative-writer simulation: take the per-session lock, load the
    /// latest copy, set the title, bump both versions, and save directly
    /// through the inner storage. Returns the session as saved.
    async fn retitle_under_lock(
        store: Arc<LockedSessionStore>,
        session_id: String,
        title: &'static str,
    ) -> Session {
        let _guard = store.acquire_lock(&session_id).await;
        let mut s = reload(store.storage(), &session_id).await;
        s.title = title.to_string();
        s.title_version = s.title_version.saturating_add(1);
        s.metadata_version = s.metadata_version.saturating_add(1);
        s.updated_at = chrono::Utc::now();
        store.storage().save_session(&s).await.unwrap();
        s
    }

    // ── update_runtime_config: config patches must never clobber messages ──

    #[tokio::test]
    async fn update_runtime_config_preserves_concurrently_appended_messages() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage.clone());
        let session_id = "cfg-preserve";

        // Persisted baseline: one user + one assistant turn.
        let mut initial = fresh(session_id);
        initial.add_message(Message::user("hello"));
        initial.add_message(Message::assistant("hi", None));
        storage.save_session(&initial).await.unwrap();

        // Simulate `POST /chat` appending a new user message to disk.
        let mut after_chat = reload(&storage, session_id).await;
        after_chat.add_message(Message::user("second question"));
        storage.save_session(&after_chat).await.unwrap();
        assert_eq!(after_chat.messages.len(), 3);

        // A config-only patch must load the freshest session and preserve the
        // appended message (this is the regression that broke message sending
        // on existing sessions).
        let updated = store
            .update_runtime_config(session_id, |s| {
                s.reasoning_effort = Some(ReasoningEffort::Max);
            })
            .await
            .unwrap()
            .expect("session exists");

        assert_eq!(updated.reasoning_effort, Some(ReasoningEffort::Max));
        assert_eq!(
            updated.messages.len(),
            3,
            "config patch must not revert a concurrently-appended message"
        );

        let on_disk = reload(&storage, session_id).await;
        assert_eq!(on_disk.messages.len(), 3);
        assert_eq!(on_disk.reasoning_effort, Some(ReasoningEffort::Max));
    }

    #[tokio::test]
    async fn update_runtime_config_returns_none_for_missing_session() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);

        let result = store
            .update_runtime_config("does-not-exist", |s| {
                s.reasoning_effort = Some(ReasoningEffort::Low);
            })
            .await
            .unwrap();

        assert!(result.is_none());
    }

    #[tokio::test]
    async fn merge_save_runtime_overwrites_messages_from_stale_snapshot() {
        // Characterization of the bug that motivated `update_runtime_config`:
        // `merge_save_runtime` writes the caller's `messages` verbatim, so a
        // stale snapshot reverts a concurrent append. Config-only writers must
        // therefore use `update_runtime_config`, never `merge_save_runtime`.
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage.clone());
        let session_id = "stale-clobber";

        // A handler loads the session (1 message) …
        let mut baseline = fresh(session_id);
        baseline.add_message(Message::user("hello"));
        storage.save_session(&baseline).await.unwrap();
        let mut stale_snapshot = reload(&storage, session_id).await;

        // … then `POST /chat` appends a second message to disk …
        let mut after_chat = reload(&storage, session_id).await;
        after_chat.add_message(Message::user("second"));
        storage.save_session(&after_chat).await.unwrap();
        assert_eq!(reload(&storage, session_id).await.messages.len(), 2);

        // … and the stale handler saves via merge_save_runtime -> append reverted.
        store.merge_save_runtime(&mut stale_snapshot).await.unwrap();
        let after = reload(&storage, session_id).await;
        assert_eq!(
            after.messages.len(),
            1,
            "merge_save_runtime clobbers concurrent appends — this is why config patches must use update_runtime_config"
        );
    }

    // ── Free-function merge tests (updated for metadata-group) ──────

    #[tokio::test]
    async fn merge_preserves_disk_title_when_versions_equal() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-equal";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Set This".to_string();
        on_disk.title_version = 0;
        on_disk.metadata_version = 0;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale Default".to_string();
        runtime_copy.title_version = 0;
        runtime_copy.metadata_version = 0;
        runtime_copy.messages = vec![];

        merge_save_session(&storage, &mut runtime_copy)
            .await
            .unwrap();

        // Equal versions: disk wins, both on disk and in the caller's copy.
        let after = reload(&storage, session_id).await;
        assert_eq!(after.title, "User Set This");
        assert_eq!(after.title_version, 0);
        assert_eq!(runtime_copy.title, "User Set This");
    }

    #[tokio::test]
    async fn merge_preserves_disk_when_disk_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-higher";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Title v3".to_string();
        on_disk.title_version = 3;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.title_version = 1;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy)
            .await
            .unwrap();

        let after = reload(&storage, session_id).await;
        assert_eq!(after.title, "User Title v3");
        assert_eq!(after.title_version, 3);
        assert_eq!(after.metadata_version, 5);
    }

    #[tokio::test]
    async fn merge_now_preserves_disk_pinned_in_metadata_group() {
        let (_temp, storage) = make_storage().await;
        let session_id = "pinned-merge";

        let mut on_disk = fresh(session_id);
        on_disk.pinned = true;
        on_disk.metadata_version = 2;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.pinned = false;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy)
            .await
            .unwrap();

        let after = reload(&storage, session_id).await;
        assert!(
            after.pinned,
            "disk pinned=true should win over runtime false"
        );
        assert_eq!(after.metadata_version, 2);
    }

    #[tokio::test]
    async fn merge_keeps_in_memory_when_session_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-bumped";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Old".to_string();
        on_disk.title_version = 1;
        on_disk.metadata_version = 3;
        storage.save_session(&on_disk).await.unwrap();

        let mut authoritative_copy = fresh(session_id);
        authoritative_copy.title = "New Authoritative".to_string();
        authoritative_copy.title_version = 2;
        authoritative_copy.metadata_version = 4;
        authoritative_copy.pinned = true;

        merge_save_session(&storage, &mut authoritative_copy)
            .await
            .unwrap();

        let after = reload(&storage, session_id).await;
        assert_eq!(after.title, "New Authoritative");
        assert_eq!(after.title_version, 2);
        assert_eq!(after.metadata_version, 4);
        assert!(after.pinned);
    }

    #[tokio::test]
    async fn merge_keeps_runtime_messages_when_disk_only_changed_metadata() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-messages";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Fresh Title".to_string();
        on_disk.title_version = 2;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.metadata_version = 0;
        // Use the constructor rather than a full struct literal so this test
        // keeps compiling when `Message` gains fields.
        runtime_copy.messages = vec![Message::user("keep me")];

        merge_save_session(&storage, &mut runtime_copy)
            .await
            .unwrap();

        // Metadata comes from disk, messages come from the runtime copy.
        let after = reload(&storage, session_id).await;
        assert_eq!(after.title, "Fresh Title");
        assert_eq!(after.metadata_version, 5);
        assert_eq!(after.messages.len(), 1);
        assert_eq!(after.messages[0].content, "keep me");
    }

    // ── LockedSessionStore tests ────────────────────────────────────

    #[tokio::test]
    async fn locked_merge_save_runtime_serialises_concurrent_writes() {
        let (_temp, storage) = make_storage().await;
        let store = Arc::new(LockedSessionStore::new(storage));
        let session_id = "lock-serial".to_string();

        // Seed with base version.
        let base = fresh(&session_id);
        store.storage().save_session(&base).await.unwrap();

        // Two authoritative writers each load→bump→save under the per-session
        // lock. They are spawned back to back with no delay so they can
        // actually contend for the lock; the assertions below hold whichever
        // writer gets it first.
        let a = tokio::spawn(retitle_under_lock(
            store.clone(),
            session_id.clone(),
            "Writer A",
        ));
        let b = tokio::spawn(retitle_under_lock(
            store.clone(),
            session_id.clone(),
            "Writer B",
        ));

        let (saved_a, saved_b) = tokio::join!(a, b);
        let ver_a = saved_a.unwrap().title_version;
        let ver_b = saved_b.unwrap().title_version;

        let final_s = reload(store.storage(), &session_id).await;
        assert!(
            ver_a != ver_b,
            "concurrent writers must produce distinct versions"
        );
        // Neither bump was lost: two increments from the seeded version.
        assert_eq!(final_s.metadata_version, 2);
    }

    #[tokio::test]
    async fn commit_metadata_is_plain_save_inside_lock() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let session_id = "commit-plain";

        let mut s = fresh(session_id);
        s.title = "Committed".to_string();
        s.metadata_version = 1;
        s.title_version = 2;

        store.commit_metadata(&s).await.unwrap();

        let after = reload(store.storage(), session_id).await;
        assert_eq!(after.title, "Committed");
        assert_eq!(after.metadata_version, 1);
        assert_eq!(after.title_version, 2);
    }
}
601}