// claude_agent/session/persistence.rs

//! Session persistence backends.
//!
//! Defines the [`Persistence`] trait and [`MemoryPersistence`], an in-memory implementation.

3use std::collections::HashMap;
4use std::sync::Arc;
5
6use tokio::sync::RwLock;
7use uuid::Uuid;
8
9use super::state::{Session, SessionId, SessionMessage};
10use super::types::{QueueItem, SummarySnapshot};
11use super::{SessionError, SessionResult};
12
13#[async_trait::async_trait]
14pub trait Persistence: Send + Sync {
15    fn name(&self) -> &str;
16
17    // Core CRUD
18    async fn save(&self, session: &Session) -> SessionResult<()>;
19    async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>>;
20    async fn delete(&self, id: &SessionId) -> SessionResult<bool>;
21    async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>>;
22
23    // Summaries
24    async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()>;
25    async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>>;
26
27    // Queue
28    async fn enqueue(
29        &self,
30        session_id: &SessionId,
31        content: String,
32        priority: i32,
33    ) -> SessionResult<QueueItem>;
34    async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>>;
35    async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool>;
36    async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>>;
37
38    // Cleanup
39    async fn cleanup_expired(&self) -> SessionResult<usize>;
40
41    /// Append a message to an existing session.
42    ///
43    /// Concurrency contract: implementations may hold a write lock for the duration
44    /// of this call. Callers must not hold other persistence locks to avoid deadlocks.
45    /// The default implementation performs a load-modify-save cycle; backends should
46    /// override this with a more efficient single-lock approach when possible.
47    async fn add_message(
48        &self,
49        session_id: &SessionId,
50        message: SessionMessage,
51    ) -> SessionResult<()> {
52        let mut session = self
53            .load(session_id)
54            .await?
55            .ok_or_else(|| SessionError::NotFound {
56                id: session_id.to_string(),
57            })?;
58        session.add_message(message);
59        self.save(&session).await
60    }
61}
62
63#[derive(Debug, Default)]
64pub struct MemoryPersistence {
65    sessions: Arc<RwLock<HashMap<String, Session>>>,
66    summaries: Arc<RwLock<HashMap<String, Vec<SummarySnapshot>>>>,
67    queue: Arc<RwLock<HashMap<String, Vec<QueueItem>>>>,
68}
69
70impl MemoryPersistence {
71    pub fn new() -> Self {
72        Self::default()
73    }
74
75    pub async fn count(&self) -> usize {
76        self.sessions.read().await.len()
77    }
78
79    pub async fn clear(&self) {
80        self.sessions.write().await.clear();
81        self.summaries.write().await.clear();
82        self.queue.write().await.clear();
83    }
84}
85
86#[async_trait::async_trait]
87impl Persistence for MemoryPersistence {
88    fn name(&self) -> &str {
89        "memory"
90    }
91
92    async fn save(&self, session: &Session) -> SessionResult<()> {
93        self.sessions
94            .write()
95            .await
96            .insert(session.id.to_string(), session.clone());
97        Ok(())
98    }
99
100    async fn add_message(
101        &self,
102        session_id: &SessionId,
103        message: SessionMessage,
104    ) -> SessionResult<()> {
105        let mut sessions = self.sessions.write().await;
106        if let Some(session) = sessions.get_mut(&session_id.to_string()) {
107            session.add_message(message);
108            Ok(())
109        } else {
110            Err(SessionError::NotFound {
111                id: session_id.to_string(),
112            })
113        }
114    }
115
116    async fn load(&self, id: &SessionId) -> SessionResult<Option<Session>> {
117        Ok(self.sessions.read().await.get(&id.to_string()).cloned())
118    }
119
120    async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
121        let key = id.to_string();
122        let mut sessions = self.sessions.write().await;
123        let mut summaries = self.summaries.write().await;
124        let mut queue = self.queue.write().await;
125        summaries.remove(&key);
126        queue.remove(&key);
127        Ok(sessions.remove(&key).is_some())
128    }
129
130    async fn list(&self, tenant_id: Option<&str>) -> SessionResult<Vec<SessionId>> {
131        Ok(self
132            .sessions
133            .read()
134            .await
135            .values()
136            .filter(|s| {
137                tenant_id
138                    .map(|t| s.tenant_id.as_deref() == Some(t))
139                    .unwrap_or(true)
140            })
141            .map(|s| s.id)
142            .collect())
143    }
144
145    async fn add_summary(&self, snapshot: SummarySnapshot) -> SessionResult<()> {
146        self.summaries
147            .write()
148            .await
149            .entry(snapshot.session_id.to_string())
150            .or_default()
151            .push(snapshot);
152        Ok(())
153    }
154
155    async fn get_summaries(&self, session_id: &SessionId) -> SessionResult<Vec<SummarySnapshot>> {
156        Ok(self
157            .summaries
158            .read()
159            .await
160            .get(&session_id.to_string())
161            .cloned()
162            .unwrap_or_default())
163    }
164
165    async fn enqueue(
166        &self,
167        session_id: &SessionId,
168        content: String,
169        priority: i32,
170    ) -> SessionResult<QueueItem> {
171        let item = QueueItem::enqueue(*session_id, content).priority(priority);
172        self.queue
173            .write()
174            .await
175            .entry(session_id.to_string())
176            .or_default()
177            .push(item.clone());
178        Ok(item)
179    }
180
181    async fn dequeue(&self, session_id: &SessionId) -> SessionResult<Option<QueueItem>> {
182        let mut queue = self.queue.write().await;
183        if let Some(items) = queue.get_mut(&session_id.to_string()) {
184            items.sort_by(|a, b| b.priority.cmp(&a.priority));
185            if let Some(pos) = items
186                .iter()
187                .position(|i| i.status == super::types::QueueStatus::Pending)
188            {
189                items[pos].start_processing();
190                return Ok(Some(items[pos].clone()));
191            }
192        }
193        Ok(None)
194    }
195
196    async fn cancel_queued(&self, item_id: Uuid) -> SessionResult<bool> {
197        for items in self.queue.write().await.values_mut() {
198            if let Some(item) = items.iter_mut().find(|i| i.id == item_id) {
199                item.cancel();
200                return Ok(true);
201            }
202        }
203        Ok(false)
204    }
205
206    async fn pending_queue(&self, session_id: &SessionId) -> SessionResult<Vec<QueueItem>> {
207        Ok(self
208            .queue
209            .read()
210            .await
211            .get(&session_id.to_string())
212            .map(|items| {
213                items
214                    .iter()
215                    .filter(|i| i.status == super::types::QueueStatus::Pending)
216                    .cloned()
217                    .collect()
218            })
219            .unwrap_or_default())
220    }
221
222    async fn cleanup_expired(&self) -> SessionResult<usize> {
223        // Hold all three write locks simultaneously to prevent races where a
224        // concurrent operation could observe a session removed from `sessions`
225        // but still present in `summaries` or `queue`.
226        let mut sessions = self.sessions.write().await;
227        let mut summaries = self.summaries.write().await;
228        let mut queue = self.queue.write().await;
229
230        let expired_keys: Vec<String> = sessions
231            .iter()
232            .filter(|(_, s)| s.is_expired())
233            .map(|(k, _)| k.clone())
234            .collect();
235
236        for key in &expired_keys {
237            sessions.remove(key);
238            summaries.remove(key);
239            queue.remove(key);
240        }
241
242        Ok(expired_keys.len())
243    }
244}
245
246pub struct PersistenceFactory;
247
248impl PersistenceFactory {
249    pub fn memory() -> Arc<dyn Persistence> {
250        Arc::new(MemoryPersistence::new())
251    }
252
253    /// Create a JSONL persistence backend (requires `jsonl` feature).
254    #[cfg(feature = "jsonl")]
255    pub async fn jsonl(
256        config: super::persistence_jsonl::JsonlConfig,
257    ) -> SessionResult<Arc<dyn Persistence>> {
258        Ok(Arc::new(
259            super::persistence_jsonl::JsonlPersistence::new(config).await?,
260        ))
261    }
262
263    /// Create a JSONL persistence backend with default configuration (requires `jsonl` feature).
264    #[cfg(feature = "jsonl")]
265    pub async fn jsonl_default() -> SessionResult<Arc<dyn Persistence>> {
266        Ok(Arc::new(
267            super::persistence_jsonl::JsonlPersistence::default_config().await?,
268        ))
269    }
270}
271
#[cfg(test)]
mod tests {
    use super::*;
    use crate::session::state::SessionConfig;
    use crate::types::ContentBlock;

    #[tokio::test]
    async fn test_save_load() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        let loaded = persistence.load(&id).await.unwrap();

        assert!(loaded.is_some());
        assert_eq!(loaded.unwrap().id, id);
    }

    #[tokio::test]
    async fn test_delete() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        assert!(persistence.delete(&id).await.unwrap());
        assert!(persistence.load(&id).await.unwrap().is_none());
    }

    /// Deleting a session also drops its summaries and queue items, and a
    /// second delete reports that nothing was removed.
    #[tokio::test]
    async fn test_delete_cascades() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "Summary"))
            .await
            .unwrap();
        persistence
            .enqueue(&id, "Queued".to_string(), 0)
            .await
            .unwrap();

        assert!(persistence.delete(&id).await.unwrap());
        assert!(persistence.get_summaries(&id).await.unwrap().is_empty());
        assert!(persistence.pending_queue(&id).await.unwrap().is_empty());
        assert!(!persistence.delete(&id).await.unwrap());
    }

    #[tokio::test]
    async fn test_list_by_tenant() {
        let persistence = MemoryPersistence::new();

        let mut s1 = Session::new(SessionConfig::default());
        s1.tenant_id = Some("tenant-a".to_string());

        let mut s2 = Session::new(SessionConfig::default());
        s2.tenant_id = Some("tenant-b".to_string());

        persistence.save(&s1).await.unwrap();
        persistence.save(&s2).await.unwrap();

        assert_eq!(persistence.list(None).await.unwrap().len(), 2);
        assert_eq!(persistence.list(Some("tenant-a")).await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_add_message() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_message(&id, SessionMessage::user(vec![ContentBlock::text("Hello")]))
            .await
            .unwrap();

        let loaded = persistence.load(&id).await.unwrap().unwrap();
        assert_eq!(loaded.messages.len(), 1);
    }

    /// Appending to a session that was never saved is an error, not a no-op.
    #[tokio::test]
    async fn test_add_message_missing_session() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());

        let result = persistence
            .add_message(
                &session.id,
                SessionMessage::user(vec![ContentBlock::text("Hello")]),
            )
            .await;
        assert!(matches!(result, Err(SessionError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_summaries() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "First"))
            .await
            .unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "Second"))
            .await
            .unwrap();

        let summaries = persistence.get_summaries(&id).await.unwrap();
        assert_eq!(summaries.len(), 2);
    }

    #[tokio::test]
    async fn test_queue_priority() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .enqueue(&id, "Low".to_string(), 1)
            .await
            .unwrap();
        persistence
            .enqueue(&id, "High".to_string(), 10)
            .await
            .unwrap();

        let next = persistence.dequeue(&id).await.unwrap().unwrap();
        assert_eq!(next.content, "High");
    }

    /// Highest priority is served first; equal priorities are served in
    /// insertion order; an exhausted queue yields `None`.
    #[tokio::test]
    async fn test_queue_fifo_within_priority() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        for (content, priority) in [("first", 5), ("second", 5), ("urgent", 10)] {
            persistence
                .enqueue(&id, content.to_string(), priority)
                .await
                .unwrap();
        }

        let mut served = Vec::new();
        for _ in 0..3 {
            served.push(persistence.dequeue(&id).await.unwrap().unwrap().content);
        }
        assert_eq!(served, ["urgent", "first", "second"]);
        assert!(persistence.dequeue(&id).await.unwrap().is_none());
    }

    /// A cancelled item is no longer pending or dequeueable; unknown ids
    /// report `false` rather than an error.
    #[tokio::test]
    async fn test_cancel_queued() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        let item = persistence
            .enqueue(&id, "Cancel me".to_string(), 0)
            .await
            .unwrap();

        assert!(persistence.cancel_queued(item.id).await.unwrap());
        assert!(persistence.pending_queue(&id).await.unwrap().is_empty());
        assert!(persistence.dequeue(&id).await.unwrap().is_none());
        assert!(!persistence.cancel_queued(Uuid::nil()).await.unwrap());
    }

    #[tokio::test]
    async fn test_cleanup_expired() {
        let persistence = MemoryPersistence::new();
        let config = SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        };
        let session = Session::new(config);

        persistence.save(&session).await.unwrap();
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        assert_eq!(persistence.cleanup_expired().await.unwrap(), 1);
        assert_eq!(persistence.count().await, 0);
    }

    /// `clear` empties sessions, summaries, and queues alike.
    #[tokio::test]
    async fn test_clear() {
        let persistence = MemoryPersistence::new();
        let session = Session::new(SessionConfig::default());
        let id = session.id;

        persistence.save(&session).await.unwrap();
        persistence
            .add_summary(SummarySnapshot::new(id, "Summary"))
            .await
            .unwrap();
        persistence
            .enqueue(&id, "Queued".to_string(), 0)
            .await
            .unwrap();

        persistence.clear().await;

        assert_eq!(persistence.count().await, 0);
        assert!(persistence.get_summaries(&id).await.unwrap().is_empty());
        assert!(persistence.pending_queue(&id).await.unwrap().is_empty());
    }
}