bamboo_infrastructure/storage/
session_merge.rs1use std::sync::Arc;
35
36use bamboo_domain::session::types::Session;
37use bamboo_domain::storage::Storage;
38use bamboo_domain::RuntimeSessionPersistence;
39use dashmap::DashMap;
40use tokio::sync::{Mutex, OwnedMutexGuard};
41
42pub struct LockedSessionStore {
50 storage: Arc<dyn Storage>,
51 locks: Arc<DashMap<String, Arc<Mutex<()>>>>,
52}
53
54impl LockedSessionStore {
55 pub fn new(storage: Arc<dyn Storage>) -> Self {
57 Self {
58 storage,
59 locks: Arc::new(DashMap::new()),
60 }
61 }
62
63 pub fn storage(&self) -> &Arc<dyn Storage> {
65 &self.storage
66 }
67
68 pub async fn acquire_lock(&self, session_id: &str) -> OwnedMutexGuard<()> {
73 let lock = self
74 .locks
75 .entry(session_id.to_string())
76 .or_insert_with(|| Arc::new(Mutex::new(())))
77 .clone();
78 lock.lock_owned().await
79 }
80
81 pub async fn commit_metadata(&self, session: &Session) -> std::io::Result<()> {
91 let _guard = self.acquire_lock(&session.id).await;
92 self.storage.save_session(session).await
93 }
94
95 pub async fn merge_save_runtime(&self, session: &mut Session) -> std::io::Result<()> {
105 let _guard = self.acquire_lock(&session.id).await;
106 merge_authoritative_metadata_into_stale(&self.storage, session).await;
107 self.storage.save_session(session).await
108 }
109}
110
111#[async_trait::async_trait]
115impl RuntimeSessionPersistence for LockedSessionStore {
116 async fn save_runtime_session(&self, session: &mut Session) -> std::io::Result<()> {
117 self.merge_save_runtime(session).await
118 }
119}
120
121async fn merge_authoritative_metadata_into_stale(
130 storage: &Arc<dyn Storage>,
131 session: &mut Session,
132) {
133 if let Ok(Some(latest)) = storage.load_session(&session.id).await {
134 if latest.metadata_version >= session.metadata_version {
135 session.title = latest.title;
136 session.title_version = latest.title_version;
137 session.pinned = latest.pinned;
138 session.metadata_version = latest.metadata_version;
139 }
140 }
141}
142
143pub async fn merge_save_session(
156 storage: &Arc<dyn Storage>,
157 session: &mut Session,
158) -> std::io::Result<()> {
159 merge_authoritative_metadata_into_stale(storage, session).await;
160 storage.save_session(session).await
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::v2::SessionStoreV2;
    use bamboo_domain::session::types::{Message, Role, Session};

    /// Builds a `SessionStoreV2` rooted in a new temporary directory.
    ///
    /// The `TempDir` is handed back alongside the store so the caller decides
    /// how long the directory lives.
    async fn make_storage() -> (tempfile::TempDir, Arc<dyn Storage>) {
        let dir = tempfile::tempdir().unwrap();
        let root = dir.path().to_path_buf();
        let store = SessionStoreV2::new(root).await.expect("storage init");
        let store: Arc<dyn Storage> = Arc::new(store);
        (dir, store)
    }

    /// A new session with the given id and a fixed model name.
    fn fresh(id: &str) -> Session {
        Session::new(String::from(id), String::from("test-model"))
    }

    /// Loads the session stored under `id`, panicking if the load fails or
    /// nothing is stored.
    async fn reload(storage: &Arc<dyn Storage>, id: &str) -> Session {
        storage.load_session(id).await.unwrap().unwrap()
    }

    // Equal metadata versions: the stored title wins and is also written back
    // into the in-memory copy.
    #[tokio::test]
    async fn merge_preserves_disk_title_when_versions_equal() {
        let (_dir, storage) = make_storage().await;
        let id = "merge-equal";

        let mut persisted = fresh(id);
        persisted.title = String::from("User Set This");
        persisted.title_version = 0;
        persisted.metadata_version = 0;
        storage.save_session(&persisted).await.unwrap();

        let mut stale = fresh(id);
        stale.title = String::from("Stale Default");
        stale.title_version = 0;
        stale.metadata_version = 0;
        stale.messages = Vec::new();

        merge_save_session(&storage, &mut stale).await.unwrap();

        let reloaded = reload(&storage, id).await;
        assert_eq!(reloaded.title, "User Set This");
        assert_eq!(reloaded.title_version, 0);
        assert_eq!(stale.title, "User Set This");
    }

    // Stored metadata version is higher: stored title and versions survive.
    #[tokio::test]
    async fn merge_preserves_disk_when_disk_version_higher() {
        let (_dir, storage) = make_storage().await;
        let id = "merge-higher";

        let mut persisted = fresh(id);
        persisted.title = String::from("User Title v3");
        persisted.title_version = 3;
        persisted.metadata_version = 5;
        storage.save_session(&persisted).await.unwrap();

        let mut stale = fresh(id);
        stale.title = String::from("Stale");
        stale.title_version = 1;
        stale.metadata_version = 0;

        merge_save_session(&storage, &mut stale).await.unwrap();

        let reloaded = reload(&storage, id).await;
        assert_eq!(reloaded.title, "User Title v3");
        assert_eq!(reloaded.title_version, 3);
        assert_eq!(reloaded.metadata_version, 5);
    }

    // `pinned` travels with the rest of the metadata group.
    #[tokio::test]
    async fn merge_now_preserves_disk_pinned_in_metadata_group() {
        let (_dir, storage) = make_storage().await;
        let id = "pinned-merge";

        let mut persisted = fresh(id);
        persisted.pinned = true;
        persisted.metadata_version = 2;
        storage.save_session(&persisted).await.unwrap();

        let mut stale = fresh(id);
        stale.pinned = false;
        stale.metadata_version = 0;

        merge_save_session(&storage, &mut stale).await.unwrap();

        let reloaded = reload(&storage, id).await;
        assert!(reloaded.pinned, "disk pinned=true should win over runtime false");
        assert_eq!(reloaded.metadata_version, 2);
    }

    // In-memory metadata version is strictly higher: in-memory values are kept.
    #[tokio::test]
    async fn merge_keeps_in_memory_when_session_version_higher() {
        let (_dir, storage) = make_storage().await;
        let id = "merge-bumped";

        let mut persisted = fresh(id);
        persisted.title = String::from("Old");
        persisted.title_version = 1;
        persisted.metadata_version = 3;
        storage.save_session(&persisted).await.unwrap();

        let mut newer = fresh(id);
        newer.title = String::from("New Authoritative");
        newer.title_version = 2;
        newer.metadata_version = 4;
        newer.pinned = true;

        merge_save_session(&storage, &mut newer).await.unwrap();

        let reloaded = reload(&storage, id).await;
        assert_eq!(reloaded.title, "New Authoritative");
        assert_eq!(reloaded.title_version, 2);
        assert_eq!(reloaded.metadata_version, 4);
        assert!(reloaded.pinned);
    }

    // Metadata comes from the stored copy while the in-memory messages are
    // what ends up saved.
    #[tokio::test]
    async fn merge_keeps_runtime_messages_when_disk_only_changed_metadata() {
        let (_dir, storage) = make_storage().await;
        let id = "merge-messages";

        let mut persisted = fresh(id);
        persisted.title = String::from("Fresh Title");
        persisted.title_version = 2;
        persisted.metadata_version = 5;
        storage.save_session(&persisted).await.unwrap();

        let kept_message = Message {
            role: Role::User,
            content: String::from("keep me"),
            id: String::from("msg-1"),
            created_at: chrono::Utc::now(),
            reasoning: None,
            content_parts: None,
            image_ocr: None,
            phase: None,
            tool_calls: None,
            tool_call_id: None,
            tool_success: None,
            compressed: false,
            compressed_by_event_id: None,
            never_compress: false,
            compression_level: 0,
            metadata: None,
        };

        let mut stale = fresh(id);
        stale.title = String::from("Stale");
        stale.metadata_version = 0;
        stale.messages = vec![kept_message];

        merge_save_session(&storage, &mut stale).await.unwrap();

        let reloaded = reload(&storage, id).await;
        assert_eq!(reloaded.title, "Fresh Title");
        assert_eq!(reloaded.metadata_version, 5);
        assert_eq!(reloaded.messages.len(), 1);
        assert_eq!(reloaded.messages[0].content, "keep me");
    }

    // Two tasks each do load -> bump -> save while holding the session lock;
    // both bumps must be visible in the final stored session.
    #[tokio::test]
    async fn locked_merge_save_runtime_serialises_concurrent_writes() {
        let (_dir, storage) = make_storage().await;
        let store = Arc::new(LockedSessionStore::new(storage));
        let session_id = String::from("lock-serial");

        let seed = fresh(&session_id);
        store.storage().save_session(&seed).await.unwrap();

        // One writer body shared by both tasks: take the lock, read the
        // stored session, bump both versions, write it back, and report the
        // title version that was written.
        let bump_title =
            |store: Arc<LockedSessionStore>, sid: String, title: &'static str| async move {
                let _guard = store.acquire_lock(&sid).await;
                let mut s = store.storage().load_session(&sid).await.unwrap().unwrap();
                s.title = title.to_string();
                s.title_version = s.title_version.saturating_add(1);
                s.metadata_version = s.metadata_version.saturating_add(1);
                s.updated_at = chrono::Utc::now();
                store.storage().save_session(&s).await.unwrap();
                s.title_version
            };

        let writer_a = tokio::spawn(bump_title(
            Arc::clone(&store),
            session_id.clone(),
            "Writer A",
        ));

        // Short pause before the second writer is spawned — presumably to let
        // writer A reach the lock first.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let writer_b = tokio::spawn(bump_title(
            Arc::clone(&store),
            session_id.clone(),
            "Writer B",
        ));

        let (ver_a, ver_b) = tokio::join!(writer_a, writer_b);
        let final_s = reload(store.storage(), &session_id).await;
        assert!(
            ver_a.unwrap() != ver_b.unwrap(),
            "concurrent writers must produce distinct versions"
        );
        assert_eq!(final_s.metadata_version, 2);
    }

    // `commit_metadata` stores the session exactly as passed in.
    #[tokio::test]
    async fn commit_metadata_is_plain_save_inside_lock() {
        let (_dir, storage) = make_storage().await;
        let store = LockedSessionStore::new(storage);
        let id = "commit-plain";

        let mut committed = fresh(id);
        committed.title = String::from("Committed");
        committed.metadata_version = 1;
        committed.title_version = 2;

        store.commit_metadata(&committed).await.unwrap();

        let reloaded = reload(store.storage(), id).await;
        assert_eq!(reloaded.title, "Committed");
        assert_eq!(reloaded.metadata_version, 1);
        assert_eq!(reloaded.title_version, 2);
    }
}