Skip to main content

claude_agent/session/
manager.rs

1//! Session lifecycle management.
2
3use std::sync::Arc;
4
5use super::persistence::{MemoryPersistence, Persistence};
6use super::state::{Session, SessionConfig, SessionId, SessionMessage, SessionState};
7use super::{SessionError, SessionResult};
8
9pub struct SessionManager {
10    persistence: Arc<dyn Persistence>,
11}
12
13impl SessionManager {
14    pub fn new(persistence: Arc<dyn Persistence>) -> Self {
15        Self { persistence }
16    }
17
18    pub fn in_memory() -> Self {
19        Self::new(Arc::new(MemoryPersistence::new()))
20    }
21
22    pub async fn create(&self, config: SessionConfig) -> SessionResult<Session> {
23        let session = Session::new(config);
24        self.persistence.save(&session).await?;
25        Ok(session)
26    }
27
28    pub async fn create_with_tenant(
29        &self,
30        config: SessionConfig,
31        tenant_id: impl Into<String>,
32    ) -> SessionResult<Session> {
33        let mut session = Session::new(config);
34        session.tenant_id = Some(tenant_id.into());
35        self.persistence.save(&session).await?;
36        Ok(session)
37    }
38
39    pub async fn get(&self, id: &SessionId) -> SessionResult<Session> {
40        let session = self
41            .persistence
42            .load(id)
43            .await?
44            .ok_or_else(|| SessionError::NotFound { id: id.to_string() })?;
45
46        if session.is_expired() {
47            self.persistence.delete(id).await?;
48            return Err(SessionError::Expired { id: id.to_string() });
49        }
50
51        Ok(session)
52    }
53
54    pub async fn get_by_str(&self, id: &str) -> SessionResult<Session> {
55        self.get(&SessionId::from(id)).await
56    }
57
58    pub async fn update(&self, session: &Session) -> SessionResult<()> {
59        self.persistence.save(session).await
60    }
61
62    pub async fn add_message(
63        &self,
64        session_id: &SessionId,
65        message: SessionMessage,
66    ) -> SessionResult<()> {
67        self.persistence.add_message(session_id, message).await
68    }
69
70    pub async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
71        self.persistence.delete(id).await
72    }
73
74    pub async fn list(&self) -> SessionResult<Vec<SessionId>> {
75        self.persistence.list(None).await
76    }
77
78    pub async fn list_for_tenant(&self, tenant_id: &str) -> SessionResult<Vec<SessionId>> {
79        self.persistence.list(Some(tenant_id)).await
80    }
81
82    pub async fn fork(&self, id: &SessionId) -> SessionResult<Session> {
83        let original = self.get(id).await?;
84
85        let mut forked = Session::new(original.config.clone());
86        forked.parent_id = Some(original.id);
87        forked.tenant_id = original.tenant_id.clone();
88        forked.summary = original.summary.clone();
89
90        // Copy messages up to current leaf
91        for msg in original.current_branch() {
92            let mut cloned = msg.clone();
93            cloned.is_sidechain = true;
94            forked.messages.push(cloned);
95        }
96
97        // Update leaf pointer
98        if let Some(last) = forked.messages.last() {
99            forked.current_leaf_id = Some(last.id.clone());
100        }
101
102        self.persistence.save(&forked).await?;
103        Ok(forked)
104    }
105
106    pub async fn complete(&self, id: &SessionId) -> SessionResult<()> {
107        let mut session = self.get(id).await?;
108        session.set_state(SessionState::Completed);
109        self.persistence.save(&session).await
110    }
111
112    pub async fn set_error(&self, id: &SessionId) -> SessionResult<()> {
113        let mut session = self.get(id).await?;
114        session.set_state(SessionState::Failed);
115        self.persistence.save(&session).await
116    }
117
118    pub async fn cleanup_expired(&self) -> SessionResult<usize> {
119        self.persistence.cleanup_expired().await
120    }
121
122    pub async fn exists(&self, id: &SessionId) -> SessionResult<bool> {
123        match self.persistence.load(id).await? {
124            Some(session) => Ok(!session.is_expired()),
125            None => Ok(false),
126        }
127    }
128}
129
130impl Default for SessionManager {
131    fn default() -> Self {
132        Self::in_memory()
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ContentBlock;

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();

        assert_eq!(session.state, SessionState::Created);
        assert!(session.messages.is_empty());
    }

    #[tokio::test]
    async fn test_session_manager_get() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let restored = manager.get(&session_id).await.unwrap();
        assert_eq!(restored.id, session_id);
    }

    #[tokio::test]
    async fn test_session_manager_not_found() {
        let manager = SessionManager::in_memory();
        let fake_id = SessionId::new();

        let result = manager.get(&fake_id).await;
        assert!(matches!(result, Err(SessionError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_session_manager_add_message() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let message = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        manager.add_message(&session_id, message).await.unwrap();

        let restored = manager.get(&session_id).await.unwrap();
        assert_eq!(restored.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_session_manager_fork() {
        let manager = SessionManager::in_memory();

        // Create original session with messages
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let msg1 = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        manager.add_message(&session_id, msg1).await.unwrap();

        let msg2 = SessionMessage::assistant(vec![ContentBlock::text("Hi!")]);
        manager.add_message(&session_id, msg2).await.unwrap();

        // Fork
        let forked = manager.fork(&session_id).await.unwrap();

        // Forked session should have the same messages
        assert_eq!(forked.messages.len(), 2);
        assert_ne!(forked.id, session_id);
        assert_eq!(forked.parent_id, Some(session_id));

        // Messages should be marked as sidechain
        assert!(forked.messages.iter().all(|m| m.is_sidechain));
    }

    /// Forking an unknown id surfaces the `NotFound` error from `get`.
    #[tokio::test]
    async fn test_session_manager_fork_not_found() {
        let manager = SessionManager::in_memory();

        let result = manager.fork(&SessionId::new()).await;
        assert!(matches!(result, Err(SessionError::NotFound { .. })));
    }

    /// Forking a session with no messages yields an empty child that still
    /// points at its parent.
    #[tokio::test]
    async fn test_session_manager_fork_empty() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let forked = manager.fork(&session_id).await.unwrap();

        assert!(forked.messages.is_empty());
        assert_eq!(forked.parent_id, Some(session_id));
    }

    /// The fork inherits the parent's tenant.
    #[tokio::test]
    async fn test_session_manager_fork_preserves_tenant() {
        let manager = SessionManager::in_memory();
        let session = manager
            .create_with_tenant(SessionConfig::default(), "tenant-a")
            .await
            .unwrap();
        let session_id = session.id;

        let forked = manager.fork(&session_id).await.unwrap();
        assert_eq!(forked.tenant_id.as_deref(), Some("tenant-a"));
    }

    #[tokio::test]
    async fn test_session_manager_complete() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        manager.complete(&session_id).await.unwrap();

        let completed = manager.get(&session_id).await.unwrap();
        assert_eq!(completed.state, SessionState::Completed);
    }

    /// `set_error` persists the `Failed` state.
    #[tokio::test]
    async fn test_session_manager_set_error() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        manager.set_error(&session_id).await.unwrap();

        let failed = manager.get(&session_id).await.unwrap();
        assert_eq!(failed.state, SessionState::Failed);
    }

    /// `exists` tracks creation and deletion; `get` reports `NotFound` afterwards.
    #[tokio::test]
    async fn test_session_manager_exists_and_delete() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        assert!(manager.exists(&session_id).await.unwrap());

        manager.delete(&session_id).await.unwrap();

        assert!(!manager.exists(&session_id).await.unwrap());
        let result = manager.get(&session_id).await;
        assert!(matches!(result, Err(SessionError::NotFound { .. })));
    }

    /// `exists` is false for an id that was never created.
    #[tokio::test]
    async fn test_session_manager_exists_unknown() {
        let manager = SessionManager::in_memory();
        assert!(!manager.exists(&SessionId::new()).await.unwrap());
    }

    #[tokio::test]
    async fn test_session_manager_tenant_filtering() {
        let manager = SessionManager::in_memory();

        let _s1 = manager
            .create_with_tenant(SessionConfig::default(), "tenant-a")
            .await
            .unwrap();
        let _s2 = manager
            .create_with_tenant(SessionConfig::default(), "tenant-a")
            .await
            .unwrap();
        let _s3 = manager
            .create_with_tenant(SessionConfig::default(), "tenant-b")
            .await
            .unwrap();

        let all = manager.list().await.unwrap();
        assert_eq!(all.len(), 3);

        let tenant_a = manager.list_for_tenant("tenant-a").await.unwrap();
        assert_eq!(tenant_a.len(), 2);

        let tenant_b = manager.list_for_tenant("tenant-b").await.unwrap();
        assert_eq!(tenant_b.len(), 1);
    }

    #[tokio::test]
    async fn test_session_manager_expired() {
        let manager = SessionManager::in_memory();

        let config = SessionConfig {
            ttl_secs: Some(0), // Expire immediately
            ..Default::default()
        };
        let session = manager.create(config).await.unwrap();
        let session_id = session.id;

        // Wait for expiry
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let result = manager.get(&session_id).await;
        assert!(matches!(result, Err(SessionError::Expired { .. })));
    }

    /// `exists` reports an expired session as absent.
    #[tokio::test]
    async fn test_session_manager_exists_expired() {
        let manager = SessionManager::in_memory();

        let config = SessionConfig {
            ttl_secs: Some(0), // Expire immediately
            ..Default::default()
        };
        let session = manager.create(config).await.unwrap();
        let session_id = session.id;

        // Wait for expiry
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        assert!(!manager.exists(&session_id).await.unwrap());
    }
}