claude_agent/session/
persistence.rs1use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::state::{Session, SessionId, SessionMessage};
10use super::types::{QueueItem, SummarySnapshot};
11use super::{SessionError, SessionResult};
12
13#[async_trait::async_trait]
14pub trait Persistence: Send + Sync {
15 fn name(&self) -> &str;
16
17 async fn save(&self, session: &Session) -> SessionResult<()>;
19 async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>>;
20 async fn delete(&self, id: &SessionId) -> SessionResult<bool>;
21 async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>>;
22
23 async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()>;
25 async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>>;
26
27 async fn enqueue(
29 &self,
30 session_id: &SessionId,
31 content: String,
32 priority: i32,
33 ) -> SessionResult<QueueItem>;
34 async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>>;
35 async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool>;
36 async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>>;
37
38 async fn cleanup_expired(&self) -> SessionResult<usize>;
40
41 async fn add_message(
48 &self,
49 session_id: &SessionId,
50 message: SessionMessage,
51 ) -> SessionResult<()> {
52 let mut session = self
53 .load(session_id)
54 .await?
55 .ok_or_else(|| SessionError::NotFound {
56 id: session_id.to_string(),
57 })?;
58 session.add_message(message);
59 self.save(&session).await
60 }
61}
62
63#[derive(Debug, Default)]
64pub struct MemoryPersistence {
65 sessions: Arc<RwLock<HashMap<String, Session>>>,
66 summaries: Arc<RwLock<HashMap<String, Vec<SummarySnapshot>>>>,
67 queue: Arc<RwLock<HashMap<String, Vec<QueueItem>>>>,
68}
69
70impl MemoryPersistence {
71 pub fn new() -> Self {
72 Self::default()
73 }
74
75 pub async fn count(&self) -> usize {
76 self.sessions.read().await.len()
77 }
78
79 pub async fn clear(&self) {
80 self.sessions.write().await.clear();
81 self.summaries.write().await.clear();
82 self.queue.write().await.clear();
83 }
84}
85
86#[async_trait::async_trait]
87impl Persistence for MemoryPersistence {
88 fn name(&self) -> &str {
89 "memory"
90 }
91
92 async fn save(&self, session: &Session) -> SessionResult<()> {
93 self.sessions
94 .write()
95 .await
96 .insert(session.id.to_string(), session.clone());
97 Ok(())
98 }
99
100 async fn add_message(
101 &self,
102 session_id: &SessionId,
103 message: SessionMessage,
104 ) -> SessionResult<()> {
105 let mut sessions = self.sessions.write().await;
106 if let Some(session) = sessions.get_mut(&session_id.to_string()) {
107 session.add_message(message);
108 Ok(())
109 } else {
110 Err(SessionError::NotFound {
111 id: session_id.to_string(),
112 })
113 }
114 }
115
116 async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>> {
117 Ok(self.sessions.read().await.get(&id.to_string()).cloned())
118 }
119
120 async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
121 let key = id.to_string();
122 let mut sessions = self.sessions.write().await;
123 let mut summaries = self.summaries.write().await;
124 let mut queue = self.queue.write().await;
125 summaries.remove(&key);
126 queue.remove(&key);
127 Ok(sessions.remove(&key).is_some())
128 }
129
130 async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>> {
131 Ok(self
132 .sessions
133 .read()
134 .await
135 .values()
136 .filter(|s| {
137 tenant_id
138 .map(|t| s.tenant_id.as_deref() == Some(t))
139 .unwrap_or(true)
140 })
141 .map(|s| s.id)
142 .collect())
143 }
144
145 async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()> {
146 self.summaries
147 .write()
148 .await
149 .entry(snapshot.session_id.to_string())
150 .or_default()
151 .push(snapshot);
152 Ok(())
153 }
154
155 async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>> {
156 Ok(self
157 .summaries
158 .read()
159 .await
160 .get(&session_id.to_string())
161 .cloned()
162 .unwrap_or_default())
163 }
164
165 async fn enqueue(
166 &self,
167 session_id: &SessionId,
168 content: String,
169 priority: i32,
170 ) -> SessionResult<QueueItem> {
171 let item = QueueItem::enqueue(*session_id, content).priority(priority);
172 self.queue
173 .write()
174 .await
175 .entry(session_id.to_string())
176 .or_default()
177 .push(item.clone());
178 Ok(item)
179 }
180
181 async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>> {
182 let mut queue = self.queue.write().await;
183 if let Some(items) = queue.get_mut(&session_id.to_string()) {
184 items.sort_by(|a, b| b.priority.cmp(&a.priority));
185 if let Some(pos) = items
186 .iter()
187 .position(|i| i.status == super::types::QueueStatus::Pending)
188 {
189 items[pos].start_processing();
190 return Ok(Some(items[pos].clone()));
191 }
192 }
193 Ok(None)
194 }
195
196 async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool> {
197 for items in self.queue.write().await.values_mut() {
198 if let Some(item) = items.iter_mut().find(|i| i.id == item_id) {
199 item.cancel();
200 return Ok(true);
201 }
202 }
203 Ok(false)
204 }
205
206 async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>> {
207 Ok(self
208 .queue
209 .read()
210 .await
211 .get(&session_id.to_string())
212 .map(|items| {
213 items
214 .iter()
215 .filter(|i| i.status == super::types::QueueStatus::Pending)
216 .cloned()
217 .collect()
218 })
219 .unwrap_or_default())
220 }
221
222 async fn cleanup_expired(&self) -> SessionResult<usize> {
223 let mut sessions = self.sessions.write().await;
227 let mut summaries = self.summaries.write().await;
228 let mut queue = self.queue.write().await;
229
230 let expired_keys: Vec<String> = sessions
231 .iter()
232 .filter(|(_, s)| s.is_expired())
233 .map(|(k, _)| k.clone())
234 .collect();
235
236 for key in &expired_keys {
237 sessions.remove(key);
238 summaries.remove(key);
239 queue.remove(key);
240 }
241
242 Ok(expired_keys.len())
243 }
244}
245
246pub struct PersistenceFactory;
247
248impl PersistenceFactory {
249 pub fn memory() -> Arc<dyn Persistence> {
250 Arc::new(MemoryPersistence::new())
251 }
252
253 #[cfg(feature = "jsonl")]
255 pub async fn jsonl(
256 config: super::persistence_jsonl::JsonlConfig,
257 ) -> SessionResult<Arc<dyn Persistence>> {
258 Ok(Arc::new(
259 super::persistence_jsonl::JsonlPersistence::new(config).await?,
260 ))
261 }
262
263 #[cfg(feature = "jsonl")]
265 pub async fn jsonl_default() -> SessionResult<Arc<dyn Persistence>> {
266 Ok(Arc::new(
267 super::persistence_jsonl::JsonlPersistence::default_config().await?,
268 ))
269 }
270}
271
#[cfg(test)]
mod tests {
    //! Behavioral tests for the in-memory backend.

    use super::*;
    use crate::session::state::SessionConfig;
    use crate::types::ContentBlock;

    #[tokio::test]
    async fn test_save_load() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        let loaded = persistence.load(&id).await.unwrap();

        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, id);
    }

    #[tokio::test]
    async fn test_delete() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        assert!(persistence.delete(&id).await.unwrap());
        assert!(persistence.load(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_delete_missing_returns_false() {
        let persistence = MemoryPersistence::new();
        let id = Session::new(SessionConfig::default()).id;

        assert!(!persistence.delete(&id).await.unwrap());
    }

    #[tokio::test]
    async fn test_delete_removes_summaries_and_queue() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "Summary"))
            .await
            .unwrap();
        persistence
            .enqueue(&id, "Queued".to_string(), 0)
            .await
            .unwrap();

        assert!(persistence.delete(&id).await.unwrap());
        assert!(persistence.get_summaries(&id).await.unwrap().is_empty());
        assert!(persistence.pending_queue(&id).await.unwrap().is_empty());
        assert!(persistence.dequeue(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_list_by_tenant() {
        let persistence = MemoryPersistence::new();

        let mut s1 = Session::new(SessionConfig::default());
        s1.tenant_id = Some("tenant-a".to_string());

        let mut s2 = Session::new(SessionConfig::default());
        s2.tenant_id = Some("tenant-b".to_string());

        persistence.save(&s1).await.unwrap();
        persistence.save(&s2).await.unwrap();

        assert_eq!(persistence.list(None).await.unwrap().len(), 2);
        assert_eq!(persistence.list(Some("tenant-a")).await.unwrap().len(), 1);
        assert!(persistence.list(Some("tenant-c")).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_add_message() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_message(&id, SessionMessage::user(vec![ContentBlock::text("Hello")]))
            .await
            .unwrap();

        let loaded = persistence.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_add_message_missing_session() {
        let persistence = MemoryPersistence::new();
        let id = Session::new(SessionConfig::default()).id;

        let result = persistence
            .add_message(&id, SessionMessage::user(vec![ContentBlock::text("Hello")]))
            .await;

        assert!(matches!(result, Err(SessionError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_summaries() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "First"))
            .await
            .unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "Second"))
            .await
            .unwrap();

        let summaries = persistence.get_summaries(&id).await.unwrap();
        assert_eq!(summaries.len(), 2);
    }

    #[tokio::test]
    async fn test_queue_priority() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .enqueue(&id, "Low".to_string(), 1)
            .await
            .unwrap();
        persistence
            .enqueue(&id, "High".to_string(), 10)
            .await
            .unwrap();

        let next = persistence.dequeue(&id).await.unwrap().unwrap();
        assert_eq!(next.content, "High");
    }

    #[tokio::test]
    async fn test_queue_fifo_within_priority() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .enqueue(&id, "First".to_string(), 5)
            .await
            .unwrap();
        persistence
            .enqueue(&id, "Second".to_string(), 5)
            .await
            .unwrap();

        // Equal priorities must come out in enqueue order (the sort is stable).
        let next = persistence.dequeue(&id).await.unwrap().unwrap();
        assert_eq!(next.content, "First");
    }

    #[tokio::test]
    async fn test_dequeue_empty() {
        let persistence = MemoryPersistence::new();
        let id = Session::new(SessionConfig::default()).id;

        assert!(persistence.dequeue(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cancel_queued() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        let item = persistence
            .enqueue(&id, "Cancel me".to_string(), 0)
            .await
            .unwrap();

        assert!(persistence.cancel_queued(item.id).await.unwrap());
        assert!(!persistence.cancel_queued(Uuid::nil()).await.unwrap());
    }

    #[tokio::test]
    async fn test_clear() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "Summary"))
            .await
            .unwrap();
        persistence
            .enqueue(&id, "Queued".to_string(), 0)
            .await
            .unwrap();

        persistence.clear().await;

        assert_eq!(persistence.count().await, 0);
        assert!(persistence.get_summaries(&id).await.unwrap().is_empty());
        assert!(persistence.pending_queue(&id).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let persistence = MemoryPersistence::new();
        let config = SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        };
        let session = Session::new(config);

        persistence.save(&session).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        assert_eq!(persistence.cleanup_expired().await.unwrap(), 1);
        assert_eq!(persistence.count().await, 0);
    }
}