// bamboo_infrastructure/storage/session_merge.rs
//! Merge-aware session save helpers.
//!
//! Provides [`merge_save_session`], which preserves concurrent UI edits to
//! the authoritative metadata group (`title`, `title_version`, `pinned`,
//! `metadata_version`) when writing a runtime-modified session to storage.
//! It re-reads the latest persisted copy and keeps the in-memory metadata only
//! when the caller's `metadata_version` strictly exceeds the on-disk value.
//!
//! ## Field-by-field merge policy
//!
//! All authoritative metadata fields are grouped under `metadata_version`:
//! when `disk.metadata_version >= session.metadata_version`, the on-disk
//! `title`, `title_version`, `pinned`, and `metadata_version` overwrite the
//! in-memory values before writing. Authoritative writers bump
//! `metadata_version` (and `title_version` for title edits) before saving so
//! their values survive the merge; non-authoritative writers do not bump, so
//! their metadata is replaced by the persisted values whenever a copy exists.
//!
//! ## Three save primitives
//!
//! - **`merge_save_session`** — stateless merge+save. Still works for
//!   non-authoritative writers that hold `Arc<dyn Storage>` directly.
//! - **`LockedSessionStore::merge_save_runtime`** — per-session-locked variant
//!   that additionally serializes writes for the same session. Prefer this for
//!   server-side paths where an authoritative writer may race with a runtime
//!   save.
//! - **`LockedSessionStore::commit_metadata`** — plain save inside a
//!   per-session lock, for authoritative writers that have already performed
//!   load→mutate→bump. No merge is performed; the session is written as-is.
//!
//! Bare [`Storage::save_session`] is reserved for first-write paths (e.g. new
//! session creation) where there is no prior on-disk copy to merge against.
33
34use std::sync::Arc;
35
36use bamboo_domain::session::types::Session;
37use bamboo_domain::storage::Storage;
38use bamboo_domain::RuntimeSessionPersistence;
39use dashmap::DashMap;
40use tokio::sync::{Mutex, OwnedMutexGuard};
41
42// ── LockedSessionStore ────────────────────────────────────────────────
43
44/// Wraps a [`Storage`] implementation with per-session write serialization.
45///
46/// Under the hood it maintains a `DashMap<String, Arc<Mutex<()>>>` so that
47/// only writes targeting the *same* session are serialised; different
48/// sessions proceed concurrently.
49pub struct LockedSessionStore {
50    storage: Arc<dyn Storage>,
51    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
52}
53
54impl LockedSessionStore {
55    /// Wrap an existing storage backend.
56    pub fn new(storage: Arc<dyn Storage>) -> Self {
57        Self {
58            storage,
59            locks: Arc::new(DashMap::new()),
60        }
61    }
62
63    /// Borrow the inner storage for read-only access.
64    pub fn storage(&self) -> &Arc<dyn Storage> {
65        &self.storage
66    }
67
68    /// Acquire a per-session serialization guard.
69    ///
70    /// Only writes for the **same** session are serialised; writes for
71    /// different sessions can proceed concurrently.
72    pub async fn acquire_lock(&self, session_id: &str) -> OwnedMutexGuard<()> {
73        let lock = self
74            .locks
75            .entry(session_id.to_string())
76            .or_insert_with(|| Arc::new(Mutex::new(())))
77            .clone();
78        lock.lock_owned().await
79    }
80
81    /// Authoritative metadata commit.
82    ///
83    /// The caller must have already loaded the latest session, mutated the
84    /// metadata fields, and bumped `metadata_version` (and `title_version` if
85    /// applicable).  This method simply acquires the per-session lock and
86    /// performs a plain `storage.save_session`.
87    ///
88    /// The lock guarantees that no other write for this session interleaves
89    /// between the caller's load and this save, so merge is unnecessary.
90    pub async fn commit_metadata(&self, session: &Session) -> std::io::Result<()> {
91        let _guard = self.acquire_lock(&session.id).await;
92        self.storage.save_session(session).await
93    }
94
95    /// Runtime / non-authoritative save with per-session lock.
96    ///
97    /// Inside the lock: reload disk, merge the authoritative metadata group
98    /// (`title`, `title_version`, `pinned`, `metadata_version`) from disk into
99    /// the in-memory copy if disk's `metadata_version >= session.metadata_version`,
100    /// then save.
101    ///
102    /// This is the locked equivalent of [`merge_save_session`]; prefer it for
103    /// server-side paths where an authoritative write may race with this save.
104    pub async fn merge_save_runtime(&self, session: &mut Session) -> std::io::Result<()> {
105        let _guard = self.acquire_lock(&session.id).await;
106        merge_authoritative_metadata_into_stale(&self.storage, session).await;
107        self.storage.save_session(session).await
108    }
109}
110
111/// Infrastructure implementation of the domain runtime-persistence port.
112/// Server should assemble this as `Arc<dyn RuntimeSessionPersistence>` and must
113/// not define a separate adapter layer for the same behavior.
114#[async_trait::async_trait]
115impl RuntimeSessionPersistence for LockedSessionStore {
116    async fn save_runtime_session(&self, session: &mut Session) -> std::io::Result<()> {
117        self.merge_save_runtime(session).await
118    }
119}
120
121// ── Internal merge helper ─────────────────────────────────────────────
122
123/// Re-read the on-disk session and, when the disk copy carries a
124/// `metadata_version >= session.metadata_version`, overwrite the in-memory
125/// authoritative metadata fields with the disk values.
126///
127/// This is the core staleness-correction: non-authoritative writers call it
128/// before saving so they don't accidentally revert a concurrent UI edit.
129async fn merge_authoritative_metadata_into_stale(
130    storage: &Arc<dyn Storage>,
131    session: &mut Session,
132) {
133    if let Ok(Some(latest)) = storage.load_session(&session.id).await {
134        if latest.metadata_version >= session.metadata_version {
135            session.title = latest.title;
136            session.title_version = latest.title_version;
137            session.pinned = latest.pinned;
138            session.metadata_version = latest.metadata_version;
139        }
140    }
141}
142
143// ── Free merge-save function ──────────────────────────────────────────
144
145/// Save a session while preserving any concurrent UI edits to the
146/// authoritative metadata group.
147///
148/// Behaviour: if the on-disk session has `metadata_version >=
149/// session.metadata_version`, the on-disk `title`, `title_version`, `pinned`
150/// and `metadata_version` overwrite the in-memory values before writing.
151///
152/// This is the stateless variant (no per-session lock). Prefer
153/// [`LockedSessionStore::merge_save_runtime`] for server-side paths where an
154/// authoritative writer may race with this save.
155pub async fn merge_save_session(
156    storage: &Arc<dyn Storage>,
157    session: &mut Session,
158) -> std::io::Result<()> {
159    merge_authoritative_metadata_into_stale(storage, session).await;
160    storage.save_session(session).await
161}
162
163// ── Tests ─────────────────────────────────────────────────────────────
164
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::v2::SessionStoreV2;
    use bamboo_domain::session::types::Session;

    async fn make_storage() -> (tempfile::TempDir, Arc<dyn Storage>) {
        let temp = tempfile::tempdir().unwrap();
        let storage = SessionStoreV2::new(temp.path().to_path_buf())
            .await
            .expect("storage init");
        (temp, Arc::new(storage) as Arc<dyn Storage>)
    }

    fn fresh(id: &str) -> Session {
        Session::new(id.to_string(), "test-model".to_string())
    }

    /// Read back the persisted copy of `session_id`, panicking if it is absent.
    async fn load(storage: &Arc<dyn Storage>, session_id: &str) -> Session {
        storage
            .load_session(session_id)
            .await
            .unwrap()
            .expect("session should be persisted")
    }

    // ── Free-function merge tests (metadata-group policy) ───────────

    #[tokio::test]
    async fn merge_preserves_disk_title_when_versions_equal() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-equal";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Set This".to_string();
        on_disk.title_version = 0;
        on_disk.metadata_version = 0;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale Default".to_string();
        runtime_copy.title_version = 0;
        runtime_copy.metadata_version = 0;
        runtime_copy.messages = vec![];

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = load(&storage, session_id).await;
        assert_eq!(after.title, "User Set This");
        assert_eq!(after.title_version, 0);
        assert_eq!(runtime_copy.title, "User Set This");
    }

    #[tokio::test]
    async fn merge_preserves_disk_when_disk_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-higher";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Title v3".to_string();
        on_disk.title_version = 3;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.title_version = 1;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = load(&storage, session_id).await;
        assert_eq!(after.title, "User Title v3");
        assert_eq!(after.title_version, 3);
        assert_eq!(after.metadata_version, 5);
    }

    #[tokio::test]
    async fn merge_now_preserves_disk_pinned_in_metadata_group() {
        let (_temp, storage) = make_storage().await;
        let session_id = "pinned-merge";

        let mut on_disk = fresh(session_id);
        on_disk.pinned = true;
        on_disk.metadata_version = 2;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.pinned = false;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = load(&storage, session_id).await;
        assert!(
            after.pinned,
            "disk pinned=true should win over runtime false"
        );
        assert_eq!(after.metadata_version, 2);
    }

    #[tokio::test]
    async fn merge_keeps_in_memory_when_session_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-bumped";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Old".to_string();
        on_disk.title_version = 1;
        on_disk.metadata_version = 3;
        storage.save_session(&on_disk).await.unwrap();

        let mut authoritative_copy = fresh(session_id);
        authoritative_copy.title = "New Authoritative".to_string();
        authoritative_copy.title_version = 2;
        authoritative_copy.metadata_version = 4;
        authoritative_copy.pinned = true;

        merge_save_session(&storage, &mut authoritative_copy)
            .await
            .unwrap();

        let after = load(&storage, session_id).await;
        assert_eq!(after.title, "New Authoritative");
        assert_eq!(after.title_version, 2);
        assert_eq!(after.metadata_version, 4);
        assert!(after.pinned);
    }

    #[tokio::test]
    async fn merge_keeps_in_memory_when_nothing_on_disk() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-first-write";

        let mut first = fresh(session_id);
        first.title = "Brand New".to_string();
        first.metadata_version = 0;

        merge_save_session(&storage, &mut first).await.unwrap();

        let after = load(&storage, session_id).await;
        assert_eq!(after.title, "Brand New");
        assert_eq!(first.title, "Brand New");
    }

    #[tokio::test]
    async fn merge_keeps_runtime_messages_when_disk_only_changed_metadata() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-messages";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Fresh Title".to_string();
        on_disk.title_version = 2;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.metadata_version = 0;
        runtime_copy.messages = vec![bamboo_domain::session::types::Message {
            role: bamboo_domain::session::types::Role::User,
            content: "keep me".to_string(),
            id: "msg-1".to_string(),
            created_at: chrono::Utc::now(),
            reasoning: None,
            content_parts: None,
            image_ocr: None,
            phase: None,
            tool_calls: None,
            tool_call_id: None,
            tool_success: None,
            compressed: false,
            compressed_by_event_id: None,
            never_compress: false,
            compression_level: 0,
            metadata: None,
        }];

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = load(&storage, session_id).await;
        assert_eq!(after.title, "Fresh Title");
        assert_eq!(after.metadata_version, 5);
        assert_eq!(after.messages.len(), 1);
        assert_eq!(after.messages[0].content, "keep me");
    }

    // ── LockedSessionStore tests ────────────────────────────────────

    #[tokio::test]
    async fn acquire_lock_serialises_concurrent_authoritative_writers() {
        let (_temp, storage) = make_storage().await;
        let store = Arc::new(LockedSessionStore::new(storage));
        let session_id = "lock-serial".to_string();

        // Seed with base version.
        let base = fresh(&session_id);
        store.storage().save_session(&base).await.unwrap();

        // Two concurrent authoritative writers each load→bump→save while
        // holding the per-session guard from `acquire_lock`.
        let store_a = store.clone();
        let store_b = store.clone();
        let sid_a = session_id.clone();
        let sid_b = session_id.clone();

        let a = tokio::spawn(async move {
            let _guard = store_a.acquire_lock(&sid_a).await;
            let mut s = load(store_a.storage(), &sid_a).await;
            s.title = "Writer A".to_string();
            s.title_version = s.title_version.saturating_add(1);
            s.metadata_version = s.metadata_version.saturating_add(1);
            s.updated_at = chrono::Utc::now();
            store_a.storage().save_session(&s).await.unwrap();
            s.title_version
        });

        // Tiny yield so A goes first.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let b = tokio::spawn(async move {
            let _guard = store_b.acquire_lock(&sid_b).await;
            let mut s = load(store_b.storage(), &sid_b).await;
            s.title = "Writer B".to_string();
            s.title_version = s.title_version.saturating_add(1);
            s.metadata_version = s.metadata_version.saturating_add(1);
            s.updated_at = chrono::Utc::now();
            store_b.storage().save_session(&s).await.unwrap();
            s.title_version
        });

        let (ver_a, ver_b) = tokio::join!(a, b);
        let final_s = load(store.storage(), &session_id).await;
        assert!(
            ver_a.unwrap() != ver_b.unwrap(),
            "concurrent writers must produce distinct versions"
        );
        assert_eq!(final_s.metadata_version, 2);
    }

    #[tokio::test]
    async fn locked_merge_save_runtime_preserves_newer_disk_metadata() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let session_id = "locked-merge";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Title".to_string();
        on_disk.title_version = 2;
        on_disk.pinned = true;
        on_disk.metadata_version = 4;
        store.storage().save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.pinned = false;
        runtime_copy.metadata_version = 1;

        store.merge_save_runtime(&mut runtime_copy).await.unwrap();

        let after = load(store.storage(), session_id).await;
        assert_eq!(after.title, "User Title");
        assert_eq!(after.title_version, 2);
        assert!(after.pinned);
        assert_eq!(after.metadata_version, 4);
        // The caller's in-memory copy is updated with the merged values too.
        assert_eq!(runtime_copy.title, "User Title");
        assert_eq!(runtime_copy.metadata_version, 4);
    }

    #[tokio::test]
    async fn runtime_persistence_port_routes_through_locked_merge() {
        let (_temp, storage) = make_storage().await;
        let port: Arc<dyn RuntimeSessionPersistence> =
            Arc::new(LockedSessionStore::new(Arc::clone(&storage)));
        let session_id = "port-merge";

        let mut on_disk = fresh(session_id);
        on_disk.title = "From UI".to_string();
        on_disk.metadata_version = 3;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.metadata_version = 0;

        port.save_runtime_session(&mut runtime_copy).await.unwrap();

        let after = load(&storage, session_id).await;
        assert_eq!(after.title, "From UI");
        assert_eq!(after.metadata_version, 3);
    }

    #[tokio::test]
    async fn commit_metadata_is_plain_save_inside_lock() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let session_id = "commit-plain";

        let mut s = fresh(session_id);
        s.title = "Committed".to_string();
        s.metadata_version = 1;
        s.title_version = 2;

        store.commit_metadata(&s).await.unwrap();

        let after = load(store.storage(), session_id).await;
        assert_eq!(after.title, "Committed");
        assert_eq!(after.metadata_version, 1);
        assert_eq!(after.title_version, 2);
    }
}