claude_agent/session/
persistence.rs

1//! Session Persistence Backends
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::state::{Session, SessionId, SessionMessage};
10use super::types::{QueueItem, SummarySnapshot};
11use super::{SessionError, SessionResult};
12
13#[async_trait::async_trait]
14pub trait Persistence: Send + Sync {
15    fn name(&self) -> &str;
16
17    // Core CRUD
18    async fn save(&self, session: &Session) -> SessionResult<()>;
19    async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>>;
20    async fn delete(&self, id: &SessionId) -> SessionResult<bool>;
21    async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>>;
22    async fn list_children(&self, parent_id: &SessionId) -> SessionResult<Vec<SessionId>>;
23
24    // Summaries
25    async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()>;
26    async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>>;
27
28    // Queue
29    async fn enqueue(
30        &self,
31        session_id: &SessionId,
32        content: String,
33        priority: i32,
34    ) -> SessionResult<QueueItem>;
35    async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>>;
36    async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool>;
37    async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>>;
38
39    // Cleanup
40    async fn cleanup_expired(&self) -> SessionResult<usize>;
41
42    // Default implementations (convenience methods)
43    async fn add_message(
44        &self,
45        session_id: &SessionId,
46        message: SessionMessage,
47    ) -> SessionResult<()> {
48        let mut session = self
49            .load(session_id)
50            .await?
51            .ok_or_else(|| SessionError::NotFound {
52                id: session_id.to_string(),
53            })?;
54        session.add_message(message);
55        self.save(&session).await
56    }
57}
58
59#[derive(Debug, Default)]
60pub struct MemoryPersistence {
61    sessions: Arc<RwLock<HashMap<String, Session>>>,
62    summaries: Arc<RwLock<HashMap<String, Vec<SummarySnapshot>>>>,
63    queue: Arc<RwLock<HashMap<String, Vec<QueueItem>>>>,
64}
65
66impl MemoryPersistence {
67    pub fn new() -> Self {
68        Self::default()
69    }
70
71    pub async fn count(&self) -> usize {
72        self.sessions.read().await.len()
73    }
74
75    pub async fn clear(&self) {
76        self.sessions.write().await.clear();
77        self.summaries.write().await.clear();
78        self.queue.write().await.clear();
79    }
80}
81
82#[async_trait::async_trait]
83impl Persistence for MemoryPersistence {
84    fn name(&self) -> &str {
85        "memory"
86    }
87
88    async fn save(&self, session: &Session) -> SessionResult<()> {
89        self.sessions
90            .write()
91            .await
92            .insert(session.id.to_string(), session.clone());
93        Ok(())
94    }
95
96    async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>> {
97        Ok(self.sessions.read().await.get(&id.to_string()).cloned())
98    }
99
100    async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
101        let key = id.to_string();
102        self.summaries.write().await.remove(&key);
103        self.queue.write().await.remove(&key);
104        Ok(self.sessions.write().await.remove(&key).is_some())
105    }
106
107    async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>> {
108        Ok(self
109            .sessions
110            .read()
111            .await
112            .values()
113            .filter(|s| {
114                tenant_id
115                    .map(|t| s.tenant_id.as_deref() == Some(t))
116                    .unwrap_or(true)
117            })
118            .map(|s| s.id)
119            .collect())
120    }
121
122    async fn list_children(&self, parent_id: &SessionId) -> SessionResult<Vec<SessionId>> {
123        Ok(self
124            .sessions
125            .read()
126            .await
127            .values()
128            .filter(|s| s.parent_id.as_ref() == Some(parent_id))
129            .map(|s| s.id)
130            .collect())
131    }
132
133    async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()> {
134        self.summaries
135            .write()
136            .await
137            .entry(snapshot.session_id.to_string())
138            .or_default()
139            .push(snapshot);
140        Ok(())
141    }
142
143    async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>> {
144        Ok(self
145            .summaries
146            .read()
147            .await
148            .get(&session_id.to_string())
149            .cloned()
150            .unwrap_or_default())
151    }
152
153    async fn enqueue(
154        &self,
155        session_id: &SessionId,
156        content: String,
157        priority: i32,
158    ) -> SessionResult<QueueItem> {
159        let item = QueueItem::enqueue(*session_id, content).with_priority(priority);
160        self.queue
161            .write()
162            .await
163            .entry(session_id.to_string())
164            .or_default()
165            .push(item.clone());
166        Ok(item)
167    }
168
169    async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>> {
170        let mut queue = self.queue.write().await;
171        if let Some(items) = queue.get_mut(&session_id.to_string()) {
172            items.sort_by(|a, b| b.priority.cmp(&a.priority));
173            if let Some(pos) = items
174                .iter()
175                .position(|i| i.status == super::types::QueueStatus::Pending)
176            {
177                items[pos].start_processing();
178                return Ok(Some(items[pos].clone()));
179            }
180        }
181        Ok(None)
182    }
183
184    async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool> {
185        for items in self.queue.write().await.values_mut() {
186            if let Some(item) = items.iter_mut().find(|i| i.id == item_id) {
187                item.cancel();
188                return Ok(true);
189            }
190        }
191        Ok(false)
192    }
193
194    async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>> {
195        Ok(self
196            .queue
197            .read()
198            .await
199            .get(&session_id.to_string())
200            .map(|items| {
201                items
202                    .iter()
203                    .filter(|i| i.status == super::types::QueueStatus::Pending)
204                    .cloned()
205                    .collect()
206            })
207            .unwrap_or_default())
208    }
209
210    async fn cleanup_expired(&self) -> SessionResult<usize> {
211        let mut sessions = self.sessions.write().await;
212        let before = sessions.len();
213        sessions.retain(|_, s| !s.is_expired());
214        Ok(before - sessions.len())
215    }
216}
217
218pub struct PersistenceFactory;
219
220impl PersistenceFactory {
221    pub fn memory() -> Arc<dyn Persistence> {
222        Arc::new(MemoryPersistence::new())
223    }
224
225    /// Create a JSONL persistence backend (requires `jsonl` feature).
226    #[cfg(feature = "jsonl")]
227    pub async fn jsonl(
228        config: super::persistence_jsonl::JsonlConfig,
229    ) -> SessionResult<Arc<dyn Persistence>> {
230        Ok(Arc::new(
231            super::persistence_jsonl::JsonlPersistence::new(config).await?,
232        ))
233    }
234
235    /// Create a JSONL persistence backend with default configuration (requires `jsonl` feature).
236    #[cfg(feature = "jsonl")]
237    pub async fn jsonl_default() -> SessionResult<Arc<dyn Persistence>> {
238        Ok(Arc::new(
239            super::persistence_jsonl::JsonlPersistence::default_config().await?,
240        ))
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::state::SessionConfig;
    use crate::types::ContentBlock;

    #[tokio::test]
    async fn test_save_load() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        let loaded = persistence.load(&id).await.unwrap();

        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, id);
    }

    #[tokio::test]
    async fn test_load_missing() {
        let persistence = MemoryPersistence::new();
        let never_saved = Session::new(SessionConfig::default());

        assert!(persistence.load(&never_saved.id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_delete() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        assert!(persistence.delete(&id).await.unwrap());
        assert!(persistence.load(&id).await.unwrap().is_none());
        // A second delete finds nothing to remove.
        assert!(!persistence.delete(&id).await.unwrap());
    }

    #[tokio::test]
    async fn test_list_by_tenant() {
        let persistence = MemoryPersistence::new();

        let mut s1 = Session::new(SessionConfig::default());
        s1.tenant_id = Some("tenant-a".to_string());

        let mut s2 = Session::new(SessionConfig::default());
        s2.tenant_id = Some("tenant-b".to_string());

        persistence.save(&s1).await.unwrap();
        persistence.save(&s2).await.unwrap();

        assert_eq!(persistence.list(None).await.unwrap().len(), 2);
        assert_eq!(persistence.list(Some("tenant-a")).await.unwrap().len(), 1);
        assert!(persistence.list(Some("tenant-c")).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_add_message() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_message(&id, SessionMessage::user(vec![ContentBlock::text("Hello")]))
            .await
            .unwrap();

        let loaded = persistence.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_add_message_missing_session() {
        let persistence = MemoryPersistence::new();
        let never_saved = Session::new(SessionConfig::default());

        let err = persistence
            .add_message(
                &never_saved.id,
                SessionMessage::user(vec![ContentBlock::text("Hello")]),
            )
            .await
            .unwrap_err();
        assert!(matches!(err, SessionError::NotFound { .. }));
    }

    #[tokio::test]
    async fn test_summaries() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "First"))
            .await
            .unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "Second"))
            .await
            .unwrap();

        let summaries = persistence.get_summaries(&id).await.unwrap();
        assert_eq!(summaries.len(), 2);

        let other = Session::new(SessionConfig::default());
        assert!(persistence.get_summaries(&other.id).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_queue_priority() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .enqueue(&id, "Low".to_string(), 1)
            .await
            .unwrap();
        persistence
            .enqueue(&id, "High".to_string(), 10)
            .await
            .unwrap();

        let next = persistence.dequeue(&id).await.unwrap().unwrap();
        assert_eq!(next.content, "High");

        // The dequeued item is no longer pending, so the lower-priority one comes next.
        let next = persistence.dequeue(&id).await.unwrap().unwrap();
        assert_eq!(next.content, "Low");

        assert!(persistence.dequeue(&id).await.unwrap().is_none());
        assert!(persistence.pending_queue(&id).await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_cancel_queued() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        let item = persistence
            .enqueue(&id, "Work".to_string(), 0)
            .await
            .unwrap();
        assert_eq!(persistence.pending_queue(&id).await.unwrap().len(), 1);

        assert!(persistence.cancel_queued(item.id).await.unwrap());
        assert!(persistence.pending_queue(&id).await.unwrap().is_empty());
        assert!(persistence.dequeue(&id).await.unwrap().is_none());

        assert!(!persistence.cancel_queued(Uuid::nil()).await.unwrap());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let persistence = MemoryPersistence::new();
        let config = SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        };
        let session = Session::new(config);

        persistence.save(&session).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        assert_eq!(persistence.cleanup_expired().await.unwrap(), 1);
        assert_eq!(persistence.count().await, 0);
    }
}
362}