1use std::sync::Arc;
16
17use bamboo_agent_core::storage::Storage;
18use bamboo_agent_core::Session;
19use bamboo_storage::LockedSessionStore;
20
21use crate::{read_cached_session, SessionCache};
22
23#[derive(Clone)]
26pub struct SessionRepository {
27 cache: SessionCache,
28 storage: Arc<dyn Storage>,
29 persistence: Arc<LockedSessionStore>,
30}
31
32impl SessionRepository {
33 pub fn new(
34 cache: SessionCache,
35 storage: Arc<dyn Storage>,
36 persistence: Arc<LockedSessionStore>,
37 ) -> Self {
38 Self {
39 cache,
40 storage,
41 persistence,
42 }
43 }
44
45 pub fn cache(&self) -> &SessionCache {
46 &self.cache
47 }
48
49 pub fn storage(&self) -> &Arc<dyn Storage> {
50 &self.storage
51 }
52
53 pub fn persistence(&self) -> &Arc<LockedSessionStore> {
54 &self.persistence
55 }
56
57 pub async fn load(&self, session_id: &str) -> Option<Session> {
60 if let Some(session) = read_cached_session(&self.cache, session_id) {
61 return Some(session);
62 }
63
64 match self.storage.load_session(session_id).await {
65 Ok(Some(session)) => {
66 self.cache.insert(
67 session_id.to_string(),
68 Arc::new(parking_lot::RwLock::new(session.clone())),
69 );
70 Some(session)
71 }
72 _ => None,
73 }
74 }
75
76 pub async fn try_load(&self, session_id: &str) -> std::io::Result<Option<Session>> {
80 if let Some(session) = read_cached_session(&self.cache, session_id) {
81 return Ok(Some(session));
82 }
83 let loaded = self.storage.load_session(session_id).await?;
84 if let Some(ref session) = loaded {
85 self.cache.insert(
86 session_id.to_string(),
87 Arc::new(parking_lot::RwLock::new(session.clone())),
88 );
89 }
90 Ok(loaded)
91 }
92
93 pub async fn save(&self, session: &mut Session) -> std::io::Result<()> {
97 self.persistence.merge_save_runtime(session).await?;
98 self.cache.insert(
99 session.id.clone(),
100 Arc::new(parking_lot::RwLock::new(session.clone())),
101 );
102 Ok(())
103 }
104
105 pub async fn load_or_create(&self, session_id: &str, model: &str) -> Session {
107 if let Some(session) = self.load(session_id).await {
108 return session;
109 }
110 Session::new(session_id.to_string(), model.to_string())
111 }
112
113 pub async fn load_merged(&self, session_id: &str) -> Option<Session> {
122 let memory_session = read_cached_session(&self.cache, session_id);
123 let storage_session = self
124 .storage
125 .load_session(session_id)
126 .await
127 .unwrap_or_default();
128
129 match (memory_session, storage_session) {
130 (Some(memory), Some(storage)) => {
131 let prefer_storage = should_prefer_storage(&memory, &storage);
132 let diverged = prefer_storage || memory.messages.len() != storage.messages.len();
133 let chosen_len = if prefer_storage {
134 storage.messages.len()
135 } else {
136 memory.messages.len()
137 };
138 macro_rules! merged_log {
139 ($level:ident) => {
140 tracing::$level!(
141 "[{}] load_session_merged: memory={} msgs (updated_at={}), storage={} msgs (updated_at={}), prefer_storage={} -> chose {} msgs",
142 session_id,
143 memory.messages.len(),
144 memory.updated_at,
145 storage.messages.len(),
146 storage.updated_at,
147 prefer_storage,
148 chosen_len,
149 )
150 };
151 }
152 if diverged {
153 merged_log!(debug);
154 } else {
155 merged_log!(trace);
156 }
157 let memory_updated_at = memory.updated_at;
158 let chosen = if prefer_storage { storage } else { memory };
159 if prefer_storage && chosen.updated_at >= memory_updated_at {
167 self.cache.insert(
168 session_id.to_string(),
169 Arc::new(parking_lot::RwLock::new(chosen.clone())),
170 );
171 }
172 Some(chosen)
173 }
174 (Some(memory), None) => Some(memory),
175 (None, Some(storage)) => {
176 self.cache.insert(
177 session_id.to_string(),
178 Arc::new(parking_lot::RwLock::new(storage.clone())),
179 );
180 Some(storage)
181 }
182 (None, None) => None,
183 }
184 }
185
186 pub async fn save_and_cache(&self, session: &mut Session) {
189 if let Err(error) = self.persistence.merge_save_runtime(session).await {
190 tracing::warn!("[{}] Failed to save session: {}", session.id, error);
191 }
192 self.cache.insert(
193 session.id.clone(),
194 Arc::new(parking_lot::RwLock::new(session.clone())),
195 );
196 }
197}
198
199fn should_prefer_storage(memory_session: &Session, storage_session: &Session) -> bool {
200 if storage_session.updated_at < memory_session.updated_at {
205 return false;
206 }
207 storage_session.updated_at > memory_session.updated_at
211 || (memory_session.pending_question.is_none() && storage_session.pending_question.is_some())
212}
213
214#[async_trait::async_trait]
218impl bamboo_domain::RuntimeSessionPersistence for SessionRepository {
219 async fn save_runtime_session(&self, session: &mut Session) -> std::io::Result<()> {
220 self.save(session).await
221 }
222
223 async fn append_token_usage_record(
224 &self,
225 session_id: &str,
226 json_line: &str,
227 ) -> std::io::Result<()> {
228 self.storage
229 .append_token_usage_record(session_id, json_line)
230 .await
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use bamboo_agent_core::storage::Storage;
    use chrono::Utc;
    use std::collections::HashMap;
    use std::sync::Mutex;

    /// In-memory `Storage` test double backed by a mutex-guarded map.
    #[derive(Default)]
    struct InMemoryStorage {
        by_id: Mutex<HashMap<String, Session>>,
    }

    #[async_trait::async_trait]
    impl Storage for InMemoryStorage {
        async fn save_session(&self, session: &Session) -> std::io::Result<()> {
            {
                let mut map = self.by_id.lock().unwrap();
                map.insert(session.id.clone(), session.clone());
            }
            Ok(())
        }

        async fn load_session(&self, session_id: &str) -> std::io::Result<Option<Session>> {
            let found = self.by_id.lock().unwrap().get(session_id).cloned();
            Ok(found)
        }

        async fn delete_session(&self, session_id: &str) -> std::io::Result<bool> {
            let removed = self.by_id.lock().unwrap().remove(session_id);
            Ok(removed.is_some())
        }
    }

    /// Builds a repository with an empty cache whose reads and writes hit `storage`.
    fn repo_over(storage: Arc<dyn Storage>) -> SessionRepository {
        let persistence = Arc::new(LockedSessionStore::new(Arc::clone(&storage)));
        let cache: SessionCache = Arc::new(dashmap::DashMap::new());
        SessionRepository::new(cache, storage, persistence)
    }

    /// Puts a copy of `session` straight into the repository cache, bypassing storage.
    fn seed_cache(repo: &SessionRepository, session: &Session) {
        let entry = Arc::new(parking_lot::RwLock::new(session.clone()));
        repo.cache().insert(session.id.clone(), entry);
    }

    #[tokio::test]
    async fn load_merged_does_not_regress_to_older_storage() {
        let backend: Arc<dyn Storage> = Arc::new(InMemoryStorage::default());
        let repo = repo_over(Arc::clone(&backend));
        let sid = "s1";

        // Storage holds an older copy that still has a pending question.
        let mut older = Session::new(sid.to_string(), "m");
        older.set_pending_question(
            "tc1".into(),
            "kind".into(),
            "q?".into(),
            vec!["OK".into()],
            true,
        );
        older.updated_at = Utc::now() - chrono::Duration::seconds(10);
        backend.save_session(&older).await.unwrap();

        // The cache holds a newer copy with no pending question.
        let mut newer = Session::new(sid.to_string(), "m");
        newer.updated_at = Utc::now();
        seed_cache(&repo, &newer);

        let result = repo.load_merged(sid).await.expect("session exists");
        assert!(
            result.pending_question.is_none(),
            "must return the newer answered memory copy, not the stale storage one"
        );

        let in_cache = read_cached_session(repo.cache(), sid).expect("cached");
        assert!(
            in_cache.pending_question.is_none(),
            "load_merged must never regress the cache to a stale storage copy"
        );
    }

    #[tokio::test]
    async fn load_merged_recovers_pending_question_from_same_age_storage() {
        let backend: Arc<dyn Storage> = Arc::new(InMemoryStorage::default());
        let repo = repo_over(Arc::clone(&backend));
        let sid = "s2";
        let stamp = Utc::now();

        // Storage has the pending question...
        let mut stored = Session::new(sid.to_string(), "m");
        stored.set_pending_question(
            "tc".into(),
            "k".into(),
            "q".into(),
            vec!["OK".into()],
            true,
        );
        stored.updated_at = stamp;
        backend.save_session(&stored).await.unwrap();

        // ...while the same-age cached copy has lost it.
        let mut cached_copy = stored.clone();
        cached_copy.clear_pending_question();
        cached_copy.updated_at = stamp;
        seed_cache(&repo, &cached_copy);

        let result = repo.load_merged(sid).await.expect("session exists");
        assert!(
            result.pending_question.is_some(),
            "same-age storage carrying a pending question must still be recovered"
        );
    }
}
346}