// bamboo_infrastructure/storage/session_merge.rs

//! Merge-aware session save helper.
//!
//! Provides [`merge_save_session`], which preserves any concurrent UI edits to
//! the authoritative metadata group (`title`, `title_version`, `pinned`,
//! `metadata_version`) before writing the runtime-modified session to storage.
//! It re-reads the latest persisted copy and only keeps the in-memory values
//! when the caller's `metadata_version` strictly exceeds the on-disk one.
//!
//! ## Field-by-field merge policy
//!
//! All authoritative metadata fields are grouped under `metadata_version`:
//! when `disk.metadata_version >= session.metadata_version`, the on-disk
//! `title`, `title_version`, `pinned`, and `metadata_version` overwrite the
//! in-memory values before writing. Authoritative writers bump
//! `metadata_version` (and `title_version` for title edits) before calling so
//! their values survive the merge; non-authoritative writers don't bump, so
//! their metadata is overwritten by any equal-or-newer disk copy.
//!
//! ## Save primitives
//!
//! - **`merge_save_session`** — stateless merge+save. Still works for
//!   non-authoritative writers that hold `Arc<dyn Storage>` directly.
//! - **`LockedSessionStore::merge_save_runtime`** — per-session-locked variant
//!   that additionally serializes writes for the same session. Prefer this for
//!   server-side paths where an authoritative writer may race with a runtime
//!   save.
//! - **`LockedSessionStore::commit_metadata`** — plain save (no merge) inside
//!   a per-session lock, for authoritative writers that have already loaded,
//!   mutated, and bumped the session. The lock covers only the save, not the
//!   caller's earlier load. A writer that needs the whole load→mutate→save
//!   sequence to be atomic should hold `LockedSessionStore::acquire_lock` and
//!   save through `LockedSessionStore::storage` instead. Calling
//!   `commit_metadata` while holding that guard would deadlock.
//!
//! Bare [`Storage::save_session`] is reserved for first-write paths (e.g. new
//! session creation) where there is no prior on-disk copy to merge against.
33
34use std::sync::Arc;
35
36use bamboo_domain::session::types::Session;
37use bamboo_domain::storage::Storage;
38use bamboo_domain::RuntimeSessionPersistence;
39use dashmap::DashMap;
40use tokio::sync::{Mutex, OwnedMutexGuard};
41
42// ── LockedSessionStore ────────────────────────────────────────────────
43
44/// Wraps a [`Storage`] implementation with per-session write serialization.
45///
46/// Under the hood it maintains a `DashMap<String, Arc<Mutex<()>>>` so that
47/// only writes targeting the *same* session are serialised; different
48/// sessions proceed concurrently.
49pub struct LockedSessionStore {
50    storage: Arc<dyn Storage>,
51    locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
52}
53
54impl LockedSessionStore {
55    /// Wrap an existing storage backend.
56    pub fn new(storage: Arc<dyn Storage>) -> Self {
57        Self {
58            storage,
59            locks: Arc::new(DashMap::new()),
60        }
61    }
62
63    /// Borrow the inner storage for read-only access.
64    pub fn storage(&self) -> &Arc<dyn Storage> {
65        &self.storage
66    }
67
68    /// Acquire a per-session serialization guard.
69    ///
70    /// Only writes for the **same** session are serialised; writes for
71    /// different sessions can proceed concurrently.
72    pub async fn acquire_lock(&self, session_id: &str) -> OwnedMutexGuard<()> {
73        let lock = self
74            .locks
75            .entry(session_id.to_string())
76            .or_insert_with(|| Arc::new(Mutex::new(())))
77            .clone();
78        lock.lock_owned().await
79    }
80
81    /// Authoritative metadata commit.
82    ///
83    /// The caller must have already loaded the latest session, mutated the
84    /// metadata fields, and bumped `metadata_version` (and `title_version` if
85    /// applicable).  This method simply acquires the per-session lock and
86    /// performs a plain `storage.save_session`.
87    ///
88    /// The lock guarantees that no other write for this session interleaves
89    /// between the caller's load and this save, so merge is unnecessary.
90    pub async fn commit_metadata(&self, session: &Session) -> std::io::Result<()> {
91        let _guard = self.acquire_lock(&session.id).await;
92        self.storage.save_session(session).await
93    }
94
95    /// Runtime / non-authoritative save with per-session lock.
96    ///
97    /// Inside the lock: reload disk, merge the authoritative metadata group
98    /// (`title`, `title_version`, `pinned`, `metadata_version`) from disk into
99    /// the in-memory copy if disk's `metadata_version >= session.metadata_version`,
100    /// then save.
101    ///
102    /// This is the locked equivalent of [`merge_save_session`]; prefer it for
103    /// server-side paths where an authoritative write may race with this save.
104    pub async fn merge_save_runtime(&self, session: &mut Session) -> std::io::Result<()> {
105        let _guard = self.acquire_lock(&session.id).await;
106        merge_authoritative_metadata_into_stale(&self.storage, session).await;
107        self.storage.save_session(session).await
108    }
109}
110
111/// Infrastructure implementation of the domain runtime-persistence port.
112/// Server should assemble this as `Arc<dyn RuntimeSessionPersistence>` and must
113/// not define a separate adapter layer for the same behavior.
114#[async_trait::async_trait]
115impl RuntimeSessionPersistence for LockedSessionStore {
116    async fn save_runtime_session(&self, session: &mut Session) -> std::io::Result<()> {
117        self.merge_save_runtime(session).await
118    }
119}
120
121// ── Internal merge helper ─────────────────────────────────────────────
122
123/// Re-read the on-disk session and, when the disk copy carries a
124/// `metadata_version >= session.metadata_version`, overwrite the in-memory
125/// authoritative metadata fields with the disk values.
126///
127/// This is the core staleness-correction: non-authoritative writers call it
128/// before saving so they don't accidentally revert a concurrent UI edit.
129async fn merge_authoritative_metadata_into_stale(
130    storage: &Arc<dyn Storage>,
131    session: &mut Session,
132) {
133    if let Ok(Some(latest)) = storage.load_session(&session.id).await {
134        if latest.metadata_version >= session.metadata_version {
135            session.title = latest.title;
136            session.title_version = latest.title_version;
137            session.pinned = latest.pinned;
138            session.metadata_version = latest.metadata_version;
139        }
140    }
141}
142
143// ── Free merge-save function ──────────────────────────────────────────
144
145/// Save a session while preserving any concurrent UI edits to the
146/// authoritative metadata group.
147///
148/// Behaviour: if the on-disk session has `metadata_version >=
149/// session.metadata_version`, the on-disk `title`, `title_version`, `pinned`
150/// and `metadata_version` overwrite the in-memory values before writing.
151///
152/// This is the stateless variant (no per-session lock). Prefer
153/// [`LockedSessionStore::merge_save_runtime`] for server-side paths where an
154/// authoritative writer may race with this save.
155pub async fn merge_save_session(
156    storage: &Arc<dyn Storage>,
157    session: &mut Session,
158) -> std::io::Result<()> {
159    merge_authoritative_metadata_into_stale(storage, session).await;
160    storage.save_session(session).await
161}
162
// ── Tests ─────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    // `super::*` also brings in the parent's imports (`Arc`, `Storage`,
    // `Session`, `RuntimeSessionPersistence`).
    use super::*;
    use crate::storage::v2::SessionStoreV2;

    async fn make_storage() -> (tempfile::TempDir, Arc<dyn Storage>) {
        let temp = tempfile::tempdir().unwrap();
        let storage = SessionStoreV2::new(temp.path().to_path_buf())
            .await
            .expect("storage init");
        (temp, Arc::new(storage) as Arc<dyn Storage>)
    }

    fn fresh(id: &str) -> Session {
        Session::new(id.to_string(), "test-model".to_string())
    }

    // ── Free-function merge tests (updated for metadata-group) ──────

    #[tokio::test]
    async fn merge_preserves_disk_title_when_versions_equal() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-equal";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Set This".to_string();
        on_disk.title_version = 0;
        on_disk.metadata_version = 0;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale Default".to_string();
        runtime_copy.title_version = 0;
        runtime_copy.metadata_version = 0;
        runtime_copy.messages = vec![];

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "User Set This");
        assert_eq!(after.title_version, 0);
        assert_eq!(runtime_copy.title, "User Set This");
    }

    #[tokio::test]
    async fn merge_preserves_disk_when_disk_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-higher";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Title v3".to_string();
        on_disk.title_version = 3;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.title_version = 1;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "User Title v3");
        assert_eq!(after.title_version, 3);
        assert_eq!(after.metadata_version, 5);
    }

    #[tokio::test]
    async fn merge_now_preserves_disk_pinned_in_metadata_group() {
        let (_temp, storage) = make_storage().await;
        let session_id = "pinned-merge";

        let mut on_disk = fresh(session_id);
        on_disk.pinned = true;
        on_disk.metadata_version = 2;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.pinned = false;
        runtime_copy.metadata_version = 0;

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert!(after.pinned, "disk pinned=true should win over runtime false");
        assert_eq!(after.metadata_version, 2);
    }

    #[tokio::test]
    async fn merge_keeps_in_memory_when_session_version_higher() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-bumped";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Old".to_string();
        on_disk.title_version = 1;
        on_disk.metadata_version = 3;
        storage.save_session(&on_disk).await.unwrap();

        let mut authoritative_copy = fresh(session_id);
        authoritative_copy.title = "New Authoritative".to_string();
        authoritative_copy.title_version = 2;
        authoritative_copy.metadata_version = 4;
        authoritative_copy.pinned = true;

        merge_save_session(&storage, &mut authoritative_copy).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "New Authoritative");
        assert_eq!(after.title_version, 2);
        assert_eq!(after.metadata_version, 4);
        assert!(after.pinned);
    }

    #[tokio::test]
    async fn merge_keeps_runtime_messages_when_disk_only_changed_metadata() {
        let (_temp, storage) = make_storage().await;
        let session_id = "merge-messages";

        let mut on_disk = fresh(session_id);
        on_disk.title = "Fresh Title".to_string();
        on_disk.title_version = 2;
        on_disk.metadata_version = 5;
        storage.save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.metadata_version = 0;
        runtime_copy.messages = vec![bamboo_domain::session::types::Message {
            role: bamboo_domain::session::types::Role::User,
            content: "keep me".to_string(),
            id: "msg-1".to_string(),
            created_at: chrono::Utc::now(),
            reasoning: None,
            content_parts: None,
            image_ocr: None,
            phase: None,
            tool_calls: None,
            tool_call_id: None,
            tool_success: None,
            compressed: false,
            compressed_by_event_id: None,
            never_compress: false,
            compression_level: 0,
            metadata: None,
        }];

        merge_save_session(&storage, &mut runtime_copy).await.unwrap();

        let after = storage.load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "Fresh Title");
        assert_eq!(after.metadata_version, 5);
        assert_eq!(after.messages.len(), 1);
        assert_eq!(after.messages[0].content, "keep me");
    }

    // ── LockedSessionStore tests ────────────────────────────────────

    /// Two authoritative writers doing load→bump→save under `acquire_lock`
    /// must observe each other's bump (no lost update).
    #[tokio::test]
    async fn acquire_lock_serialises_concurrent_authoritative_writes() {
        let (_temp, storage) = make_storage().await;
        let store = Arc::new(LockedSessionStore::new(storage));
        let session_id = "lock-serial".to_string();

        // Seed with base version.
        let base = fresh(&session_id);
        store.storage().save_session(&base).await.unwrap();

        // Two concurrent authoritative writers each load, bump and save while
        // holding the per-session guard.
        let store_a = store.clone();
        let store_b = store.clone();
        let sid_a = session_id.clone();
        let sid_b = session_id.clone();

        let a = tokio::spawn(async move {
            let _guard = store_a.acquire_lock(&sid_a).await;
            let mut s = store_a.storage().load_session(&sid_a).await.unwrap().unwrap();
            s.title = "Writer A".to_string();
            s.title_version = s.title_version.saturating_add(1);
            s.metadata_version = s.metadata_version.saturating_add(1);
            s.updated_at = chrono::Utc::now();
            store_a.storage().save_session(&s).await.unwrap();
            s.title_version
        });

        // Tiny yield so A goes first.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let b = tokio::spawn(async move {
            let _guard = store_b.acquire_lock(&sid_b).await;
            let mut s = store_b.storage().load_session(&sid_b).await.unwrap().unwrap();
            s.title = "Writer B".to_string();
            s.title_version = s.title_version.saturating_add(1);
            s.metadata_version = s.metadata_version.saturating_add(1);
            s.updated_at = chrono::Utc::now();
            store_b.storage().save_session(&s).await.unwrap();
            s.title_version
        });

        let (ver_a, ver_b) = tokio::join!(a, b);
        let (ver_a, ver_b) = (ver_a.unwrap(), ver_b.unwrap());
        let final_s = store.storage().load_session(&session_id).await.unwrap().unwrap();
        assert_ne!(ver_a, ver_b, "concurrent writers must produce distinct versions");
        assert_eq!(final_s.metadata_version, 2);
    }

    #[tokio::test]
    async fn locked_merge_save_runtime_preserves_newer_disk_metadata() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let session_id = "locked-merge";

        let mut on_disk = fresh(session_id);
        on_disk.title = "User Title".to_string();
        on_disk.title_version = 1;
        on_disk.pinned = true;
        on_disk.metadata_version = 3;
        store.storage().save_session(&on_disk).await.unwrap();

        let mut runtime_copy = fresh(session_id);
        runtime_copy.title = "Stale".to_string();
        runtime_copy.pinned = false;
        runtime_copy.metadata_version = 0;

        store.merge_save_runtime(&mut runtime_copy).await.unwrap();

        let after = store.storage().load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "User Title");
        assert_eq!(after.title_version, 1);
        assert!(after.pinned);
        assert_eq!(after.metadata_version, 3);
        assert_eq!(runtime_copy.title, "User Title");
    }

    #[tokio::test]
    async fn runtime_persistence_port_saves_first_write() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let session_id = "port-first-write";

        let mut s = fresh(session_id);
        s.title = "Brand New".to_string();

        store.save_runtime_session(&mut s).await.unwrap();

        let after = store.storage().load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "Brand New");
    }

    #[tokio::test]
    async fn commit_metadata_is_plain_save_inside_lock() {
        let (_temp, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let session_id = "commit-plain";

        let mut s = fresh(session_id);
        s.title = "Committed".to_string();
        s.metadata_version = 1;
        s.title_version = 2;

        store.commit_metadata(&s).await.unwrap();

        let after = store.storage().load_session(session_id).await.unwrap().unwrap();
        assert_eq!(after.title, "Committed");
        assert_eq!(after.metadata_version, 1);
        assert_eq!(after.title_version, 2);
    }
}
394}