claude_agent/session/
persistence.rs1use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::state::{Session, SessionId, SessionMessage};
10use super::types::{QueueItem, SummarySnapshot};
11use super::{SessionError, SessionResult};
12
13#[async_trait::async_trait]
14pub trait Persistence: Send + Sync {
15 fn name(&self) -> &str;
16
17 async fn save(&self, session: &Session) -> SessionResult<()>;
19 async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>>;
20 async fn delete(&self, id: &SessionId) -> SessionResult<bool>;
21 async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>>;
22 async fn list_children(&self, parent_id: &SessionId) -> SessionResult<Vec<SessionId>>;
23
24 async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()>;
26 async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>>;
27
28 async fn enqueue(
30 &self,
31 session_id: &SessionId,
32 content: String,
33 priority: i32,
34 ) -> SessionResult<QueueItem>;
35 async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>>;
36 async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool>;
37 async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>>;
38
39 async fn cleanup_expired(&self) -> SessionResult<usize>;
41
42 async fn add_message(
44 &self,
45 session_id: &SessionId,
46 message: SessionMessage,
47 ) -> SessionResult<()> {
48 let mut session = self
49 .load(session_id)
50 .await?
51 .ok_or_else(|| SessionError::NotFound {
52 id: session_id.to_string(),
53 })?;
54 session.add_message(message);
55 self.save(&session).await
56 }
57}
58
59#[derive(Debug, Default)]
60pub struct MemoryPersistence {
61 sessions: Arc<RwLock<HashMap<String, Session>>>,
62 summaries: Arc<RwLock<HashMap<String, Vec<SummarySnapshot>>>>,
63 queue: Arc<RwLock<HashMap<String, Vec<QueueItem>>>>,
64}
65
66impl MemoryPersistence {
67 pub fn new() -> Self {
68 Self::default()
69 }
70
71 pub async fn count(&self) -> usize {
72 self.sessions.read().await.len()
73 }
74
75 pub async fn clear(&self) {
76 self.sessions.write().await.clear();
77 self.summaries.write().await.clear();
78 self.queue.write().await.clear();
79 }
80}
81
82#[async_trait::async_trait]
83impl Persistence for MemoryPersistence {
84 fn name(&self) -> &str {
85 "memory"
86 }
87
88 async fn save(&self, session: &Session) -> SessionResult<()> {
89 self.sessions
90 .write()
91 .await
92 .insert(session.id.to_string(), session.clone());
93 Ok(())
94 }
95
96 async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>> {
97 Ok(self.sessions.read().await.get(&id.to_string()).cloned())
98 }
99
100 async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
101 let key = id.to_string();
102 self.summaries.write().await.remove(&key);
103 self.queue.write().await.remove(&key);
104 Ok(self.sessions.write().await.remove(&key).is_some())
105 }
106
107 async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>> {
108 Ok(self
109 .sessions
110 .read()
111 .await
112 .values()
113 .filter(|s| {
114 tenant_id
115 .map(|t| s.tenant_id.as_deref() == Some(t))
116 .unwrap_or(true)
117 })
118 .map(|s| s.id)
119 .collect())
120 }
121
122 async fn list_children(&self, parent_id: &SessionId) -> SessionResult<Vec<SessionId>> {
123 Ok(self
124 .sessions
125 .read()
126 .await
127 .values()
128 .filter(|s| s.parent_id.as_ref() == Some(parent_id))
129 .map(|s| s.id)
130 .collect())
131 }
132
133 async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()> {
134 self.summaries
135 .write()
136 .await
137 .entry(snapshot.session_id.to_string())
138 .or_default()
139 .push(snapshot);
140 Ok(())
141 }
142
143 async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>> {
144 Ok(self
145 .summaries
146 .read()
147 .await
148 .get(&session_id.to_string())
149 .cloned()
150 .unwrap_or_default())
151 }
152
153 async fn enqueue(
154 &self,
155 session_id: &SessionId,
156 content: String,
157 priority: i32,
158 ) -> SessionResult<QueueItem> {
159 let item = QueueItem::enqueue(*session_id, content).with_priority(priority);
160 self.queue
161 .write()
162 .await
163 .entry(session_id.to_string())
164 .or_default()
165 .push(item.clone());
166 Ok(item)
167 }
168
169 async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>> {
170 let mut queue = self.queue.write().await;
171 if let Some(items) = queue.get_mut(&session_id.to_string()) {
172 items.sort_by(|a, b| b.priority.cmp(&a.priority));
173 if let Some(pos) = items
174 .iter()
175 .position(|i| i.status == super::types::QueueStatus::Pending)
176 {
177 items[pos].start_processing();
178 return Ok(Some(items[pos].clone()));
179 }
180 }
181 Ok(None)
182 }
183
184 async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool> {
185 for items in self.queue.write().await.values_mut() {
186 if let Some(item) = items.iter_mut().find(|i| i.id == item_id) {
187 item.cancel();
188 return Ok(true);
189 }
190 }
191 Ok(false)
192 }
193
194 async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>> {
195 Ok(self
196 .queue
197 .read()
198 .await
199 .get(&session_id.to_string())
200 .map(|items| {
201 items
202 .iter()
203 .filter(|i| i.status == super::types::QueueStatus::Pending)
204 .cloned()
205 .collect()
206 })
207 .unwrap_or_default())
208 }
209
210 async fn cleanup_expired(&self) -> SessionResult<usize> {
211 let mut sessions = self.sessions.write().await;
212 let before = sessions.len();
213 sessions.retain(|_, s| !s.is_expired());
214 Ok(before - sessions.len())
215 }
216}
217
218pub struct PersistenceFactory;
219
220impl PersistenceFactory {
221 pub fn memory() -> Arc<dyn Persistence> {
222 Arc::new(MemoryPersistence::new())
223 }
224
225 #[cfg(feature = "jsonl")]
227 pub async fn jsonl(
228 config: super::persistence_jsonl::JsonlConfig,
229 ) -> SessionResult<Arc<dyn Persistence>> {
230 Ok(Arc::new(
231 super::persistence_jsonl::JsonlPersistence::new(config).await?,
232 ))
233 }
234
235 #[cfg(feature = "jsonl")]
237 pub async fn jsonl_default() -> SessionResult<Arc<dyn Persistence>> {
238 Ok(Arc::new(
239 super::persistence_jsonl::JsonlPersistence::default_config().await?,
240 ))
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::state::SessionConfig;
    use crate::types::ContentBlock;

    #[tokio::test]
    async fn test_save_load() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        let loaded = persistence.load(&id).await.unwrap();

        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, id);
    }

    #[tokio::test]
    async fn test_delete() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        assert!(persistence.delete(&id).await.unwrap());
        assert!(persistence.load(&id).await.unwrap().is_none());
    }

    /// `delete` also drops the session's summaries, and reports `false`
    /// for an id that was never stored.
    #[tokio::test]
    async fn test_delete_removes_summaries_and_reports_missing() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "Summary"))
            .await
            .unwrap();

        assert!(persistence.delete(&id).await.unwrap());
        assert!(persistence.get_summaries(&id).await.unwrap().is_empty());
        assert!(!persistence.delete(&id).await.unwrap());
    }

    #[tokio::test]
    async fn test_list_by_tenant() {
        let persistence = MemoryPersistence::new();

        let mut s1 = Session::new(SessionConfig::default());
        s1.tenant_id = Some("tenant-a".to_string());

        let mut s2 = Session::new(SessionConfig::default());
        s2.tenant_id = Some("tenant-b".to_string());

        persistence.save(&s1).await.unwrap();
        persistence.save(&s2).await.unwrap();

        assert_eq!(persistence.list(None).await.unwrap().len(), 2);
        assert_eq!(persistence.list(Some("tenant-a")).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_list_children() {
        let persistence = MemoryPersistence::new();
        let parent = Session::new(SessionConfig::default());
        let mut child = Session::new(SessionConfig::default());
        child.parent_id = Some(parent.id);

        persistence.save(&parent).await.unwrap();
        persistence.save(&child).await.unwrap();

        let children = persistence.list_children(&parent.id).await.unwrap();
        assert_eq!(children, vec![child.id]);
        assert!(persistence.list_children(&child.id).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_add_message() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_message(&id, SessionMessage::user(vec![ContentBlock::text("Hello")]))
            .await
            .unwrap();

        let loaded = persistence.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 1);
    }

    /// Appending to a session that was never saved must fail with `NotFound`.
    #[tokio::test]
    async fn test_add_message_missing_session() {
        let persistence = MemoryPersistence::new();
        let id = Session::new(SessionConfig::default()).id;

        let result = persistence
            .add_message(&id, SessionMessage::user(vec![ContentBlock::text("Hello")]))
            .await;

        assert!(matches!(result, Err(SessionError::NotFound { .. })));
        assert_eq!(persistence.count().await, 0);
    }

    #[tokio::test]
    async fn test_summaries() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "First"))
            .await
            .unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "Second"))
            .await
            .unwrap();

        let summaries = persistence.get_summaries(&id).await.unwrap();
        assert_eq!(summaries.len(), 2);
    }

    #[tokio::test]
    async fn test_queue_priority() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .enqueue(&id, "Low".to_string(), 1)
            .await
            .unwrap();
        persistence
            .enqueue(&id, "High".to_string(), 10)
            .await
            .unwrap();

        let next = persistence.dequeue(&id).await.unwrap().unwrap();
        assert_eq!(next.content, "High");
    }

    /// Items with equal priority come out in insertion order, and an
    /// exhausted queue yields `None`.
    #[tokio::test]
    async fn test_queue_fifo_within_priority() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .enqueue(&id, "First".to_string(), 5)
            .await
            .unwrap();
        persistence
            .enqueue(&id, "Second".to_string(), 5)
            .await
            .unwrap();

        let first = persistence.dequeue(&id).await.unwrap().unwrap();
        assert_eq!(first.content, "First");
        let second = persistence.dequeue(&id).await.unwrap().unwrap();
        assert_eq!(second.content, "Second");
        assert!(persistence.dequeue(&id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_cancel_queued() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        let item = persistence
            .enqueue(&id, "Task".to_string(), 0)
            .await
            .unwrap();
        assert_eq!(persistence.pending_queue(&id).await.unwrap().len(), 1);

        assert!(persistence.cancel_queued(item.id).await.unwrap());
        assert!(persistence.pending_queue(&id).await.unwrap().is_empty());
        assert!(!persistence.cancel_queued(Uuid::nil()).await.unwrap());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let persistence = MemoryPersistence::new();
        let config = SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        };
        let session = Session::new(config);

        persistence.save(&session).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        assert_eq!(persistence.cleanup_expired().await.unwrap(), 1);
        assert_eq!(persistence.count().await, 0);
    }
}