1use std::sync::Arc;
16
17use bamboo_agent_core::storage::Storage;
18use bamboo_agent_core::Session;
19use bamboo_storage::LockedSessionStore;
20
21use crate::{read_cached_session, SessionCache};
22
23#[derive(Clone)]
26pub struct SessionRepository {
27 cache: SessionCache,
28 storage: Arc<dyn Storage>,
29 persistence: Arc<LockedSessionStore>,
30}
31
32impl SessionRepository {
33 pub fn new(
34 cache: SessionCache,
35 storage: Arc<dyn Storage>,
36 persistence: Arc<LockedSessionStore>,
37 ) -> Self {
38 Self {
39 cache,
40 storage,
41 persistence,
42 }
43 }
44
45 pub fn cache(&self) -> &SessionCache {
46 &self.cache
47 }
48
49 pub fn storage(&self) -> &Arc<dyn Storage> {
50 &self.storage
51 }
52
53 pub fn persistence(&self) -> &Arc<LockedSessionStore> {
54 &self.persistence
55 }
56
57 pub async fn load(&self, session_id: &str) -> Option<Session> {
60 if let Some(session) = read_cached_session(&self.cache, session_id) {
61 return Some(session);
62 }
63
64 match self.storage.load_session(session_id).await {
65 Ok(Some(session)) => {
66 self.cache.insert(
67 session_id.to_string(),
68 Arc::new(parking_lot::RwLock::new(session.clone())),
69 );
70 Some(session)
71 }
72 _ => None,
73 }
74 }
75
76 pub async fn try_load(&self, session_id: &str) -> std::io::Result<Option<Session>> {
80 if let Some(session) = read_cached_session(&self.cache, session_id) {
81 return Ok(Some(session));
82 }
83 let loaded = self.storage.load_session(session_id).await?;
84 if let Some(ref session) = loaded {
85 self.cache.insert(
86 session_id.to_string(),
87 Arc::new(parking_lot::RwLock::new(session.clone())),
88 );
89 }
90 Ok(loaded)
91 }
92
93 pub async fn save(&self, session: &mut Session) -> std::io::Result<()> {
97 self.persistence.merge_save_runtime(session).await?;
98 self.cache.insert(
99 session.id.clone(),
100 Arc::new(parking_lot::RwLock::new(session.clone())),
101 );
102 Ok(())
103 }
104
105 pub async fn load_or_create(&self, session_id: &str, model: &str) -> Session {
107 if let Some(session) = self.load(session_id).await {
108 return session;
109 }
110 Session::new(session_id.to_string(), model.to_string())
111 }
112
113 pub async fn load_merged(&self, session_id: &str) -> Option<Session> {
122 let memory_session = read_cached_session(&self.cache, session_id);
123 let storage_session = self
124 .storage
125 .load_session(session_id)
126 .await
127 .unwrap_or_default();
128
129 match (memory_session, storage_session) {
130 (Some(memory), Some(storage)) => {
131 let prefer_storage = should_prefer_storage(&memory, &storage);
132 let diverged = prefer_storage || memory.messages.len() != storage.messages.len();
133 let chosen_len = if prefer_storage {
134 storage.messages.len()
135 } else {
136 memory.messages.len()
137 };
138 macro_rules! merged_log {
139 ($level:ident) => {
140 tracing::$level!(
141 "[{}] load_session_merged: memory={} msgs (updated_at={}), storage={} msgs (updated_at={}), prefer_storage={} -> chose {} msgs",
142 session_id,
143 memory.messages.len(),
144 memory.updated_at,
145 storage.messages.len(),
146 storage.updated_at,
147 prefer_storage,
148 chosen_len,
149 )
150 };
151 }
152 if diverged {
153 merged_log!(debug);
154 } else {
155 merged_log!(trace);
156 }
157 let memory_updated_at = memory.updated_at;
158 let chosen = if prefer_storage { storage } else { memory };
159 if prefer_storage && chosen.updated_at >= memory_updated_at {
167 self.cache.insert(
168 session_id.to_string(),
169 Arc::new(parking_lot::RwLock::new(chosen.clone())),
170 );
171 }
172 Some(chosen)
173 }
174 (Some(memory), None) => Some(memory),
175 (None, Some(storage)) => {
176 self.cache.insert(
177 session_id.to_string(),
178 Arc::new(parking_lot::RwLock::new(storage.clone())),
179 );
180 Some(storage)
181 }
182 (None, None) => None,
183 }
184 }
185
186 pub async fn save_and_cache(&self, session: &mut Session) {
189 if let Err(error) = self.persistence.merge_save_runtime(session).await {
190 tracing::warn!("[{}] Failed to save session: {}", session.id, error);
191 }
192 self.cache.insert(
193 session.id.clone(),
194 Arc::new(parking_lot::RwLock::new(session.clone())),
195 );
196 }
197}
198
199fn should_prefer_storage(memory_session: &Session, storage_session: &Session) -> bool {
200 if storage_session.updated_at < memory_session.updated_at {
205 return false;
206 }
207 storage_session.updated_at > memory_session.updated_at
211 || (memory_session.pending_question.is_none() && storage_session.pending_question.is_some())
212}
213
214#[async_trait::async_trait]
218impl bamboo_domain::RuntimeSessionPersistence for SessionRepository {
219 async fn save_runtime_session(&self, session: &mut Session) -> std::io::Result<()> {
220 self.save(session).await
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::storage::Storage;
    use chrono::Utc;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// In-memory `Storage` double keyed by session id.
    #[derive(Default)]
    struct MapStorage {
        sessions: Mutex<HashMap<String, Session>>,
    }

    #[async_trait::async_trait]
    impl Storage for MapStorage {
        async fn save_session(&self, session: &Session) -> std::io::Result<()> {
            self.sessions
                .lock()
                .unwrap()
                .insert(session.id.clone(), session.clone());
            Ok(())
        }
        async fn load_session(&self, session_id: &str) -> std::io::Result<Option<Session>> {
            Ok(self.sessions.lock().unwrap().get(session_id).cloned())
        }
        async fn delete_session(&self, session_id: &str) -> std::io::Result<bool> {
            Ok(self.sessions.lock().unwrap().remove(session_id).is_some())
        }
    }

    /// Builds a repository with an empty cache over `storage`.
    fn test_repo(storage: Arc<dyn Storage>) -> SessionRepository {
        let cache: SessionCache = Arc::new(dashmap::DashMap::new());
        let persistence = Arc::new(LockedSessionStore::new(storage.clone()));
        SessionRepository::new(cache, storage, persistence)
    }

    /// Seeds the repository cache directly, bypassing storage.
    fn cache_put(repo: &SessionRepository, session: &Session) {
        repo.cache().insert(
            session.id.clone(),
            Arc::new(parking_lot::RwLock::new(session.clone())),
        );
    }

    #[tokio::test]
    async fn load_merged_does_not_regress_to_older_storage() {
        let storage: Arc<dyn Storage> = Arc::new(MapStorage::default());
        let repo = test_repo(storage.clone());
        let id = "s1";

        let mut stale = Session::new(id.to_string(), "m");
        stale.set_pending_question(
            "tc1".into(),
            "kind".into(),
            "q?".into(),
            vec!["OK".into()],
            true,
        );
        stale.updated_at = Utc::now() - chrono::Duration::seconds(10);
        storage.save_session(&stale).await.unwrap();

        let mut fresh = Session::new(id.to_string(), "m");
        fresh.updated_at = Utc::now();
        cache_put(&repo, &fresh);

        let merged = repo.load_merged(id).await.expect("session exists");
        assert!(
            merged.pending_question.is_none(),
            "must return the newer answered memory copy, not the stale storage one"
        );
        let cached = read_cached_session(repo.cache(), id).expect("cached");
        assert!(
            cached.pending_question.is_none(),
            "load_merged must never regress the cache to a stale storage copy"
        );
    }

    #[tokio::test]
    async fn load_merged_recovers_pending_question_from_same_age_storage() {
        let storage: Arc<dyn Storage> = Arc::new(MapStorage::default());
        let repo = test_repo(storage.clone());
        let id = "s2";
        let ts = Utc::now();

        let mut with_pending = Session::new(id.to_string(), "m");
        with_pending.set_pending_question(
            "tc".into(),
            "k".into(),
            "q".into(),
            vec!["OK".into()],
            true,
        );
        with_pending.updated_at = ts;
        storage.save_session(&with_pending).await.unwrap();

        let mut lost = with_pending.clone();
        lost.clear_pending_question();
        lost.updated_at = ts;
        cache_put(&repo, &lost);

        let merged = repo.load_merged(id).await.expect("session exists");
        assert!(
            merged.pending_question.is_some(),
            "same-age storage carrying a pending question must still be recovered"
        );
    }

    #[tokio::test]
    async fn load_falls_back_to_storage_and_populates_cache() {
        let storage: Arc<dyn Storage> = Arc::new(MapStorage::default());
        let repo = test_repo(storage.clone());
        let id = "s3";

        let session = Session::new(id.to_string(), "m");
        storage.save_session(&session).await.unwrap();
        assert!(
            read_cached_session(repo.cache(), id).is_none(),
            "cache starts empty"
        );

        let loaded = repo.load(id).await.expect("session is in storage");
        assert_eq!(loaded.id, id);
        assert!(
            read_cached_session(repo.cache(), id).is_some(),
            "a storage hit must be cached"
        );
    }

    #[tokio::test]
    async fn load_returns_none_for_unknown_session() {
        let repo = test_repo(Arc::new(MapStorage::default()));

        assert!(repo.load("missing").await.is_none());
        assert!(repo.try_load("missing").await.unwrap().is_none());
        assert!(
            read_cached_session(repo.cache(), "missing").is_none(),
            "a miss must not create a cache entry"
        );
    }

    #[test]
    fn should_prefer_storage_decision_table() {
        let ts = Utc::now();
        let mut memory = Session::new("s4".to_string(), "m");
        memory.updated_at = ts;

        let mut storage = memory.clone();
        assert!(
            !should_prefer_storage(&memory, &storage),
            "identical same-age copies keep memory"
        );

        storage.updated_at = ts + chrono::Duration::seconds(1);
        assert!(
            should_prefer_storage(&memory, &storage),
            "strictly newer storage wins"
        );

        storage.set_pending_question(
            "tc".into(),
            "k".into(),
            "q".into(),
            vec!["OK".into()],
            true,
        );
        storage.updated_at = ts;
        assert!(
            should_prefer_storage(&memory, &storage),
            "same-age storage with a pending question the memory copy lacks wins"
        );

        storage.updated_at = ts - chrono::Duration::seconds(1);
        assert!(
            !should_prefer_storage(&memory, &storage),
            "older storage never wins, even with a pending question"
        );
    }
}