claude_agent/session/manager.rs

1//! Session lifecycle management.
2
3use std::sync::Arc;
4
5use super::persistence::{MemoryPersistence, Persistence};
6use super::state::{Session, SessionConfig, SessionId, SessionMessage, SessionState};
7use super::{SessionError, SessionResult};
8
9pub struct SessionManager {
10    persistence: Arc<dyn Persistence>,
11}
12
13impl SessionManager {
14    pub fn new(persistence: Arc<dyn Persistence>) -> Self {
15        Self { persistence }
16    }
17
18    pub fn in_memory() -> Self {
19        Self::new(Arc::new(MemoryPersistence::new()))
20    }
21
22    pub fn backend_name(&self) -> &str {
23        self.persistence.name()
24    }
25
26    pub async fn create(&self, config: SessionConfig) -> SessionResult<Session> {
27        let session = Session::new(config);
28        self.persistence.save(&session).await?;
29        Ok(session)
30    }
31
32    pub async fn create_with_tenant(
33        &self,
34        config: SessionConfig,
35        tenant_id: impl Into<String>,
36    ) -> SessionResult<Session> {
37        let mut session = Session::new(config);
38        session.tenant_id = Some(tenant_id.into());
39        self.persistence.save(&session).await?;
40        Ok(session)
41    }
42
43    pub async fn get(&self, id: &SessionId) -> SessionResult<Session> {
44        let session = self
45            .persistence
46            .load(id)
47            .await?
48            .ok_or_else(|| SessionError::NotFound { id: id.to_string() })?;
49
50        if session.is_expired() {
51            self.persistence.delete(id).await?;
52            return Err(SessionError::Expired { id: id.to_string() });
53        }
54
55        Ok(session)
56    }
57
58    pub async fn get_by_str(&self, id: &str) -> SessionResult<Session> {
59        self.get(&SessionId::from(id)).await
60    }
61
62    pub async fn update(&self, session: &Session) -> SessionResult<()> {
63        self.persistence.save(session).await
64    }
65
66    pub async fn add_message(
67        &self,
68        session_id: &SessionId,
69        message: SessionMessage,
70    ) -> SessionResult<()> {
71        self.persistence.add_message(session_id, message).await
72    }
73
74    pub async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
75        self.persistence.delete(id).await
76    }
77
78    pub async fn list(&self) -> SessionResult<Vec<SessionId>> {
79        self.persistence.list(None).await
80    }
81
82    pub async fn list_for_tenant(&self, tenant_id: &str) -> SessionResult<Vec<SessionId>> {
83        self.persistence.list(Some(tenant_id)).await
84    }
85
86    pub async fn fork(&self, id: &SessionId) -> SessionResult<Session> {
87        let original = self.get(id).await?;
88
89        let mut forked = Session::new(original.config.clone());
90        forked.tenant_id = original.tenant_id.clone();
91        forked.summary = original.summary.clone();
92
93        // Copy messages up to current leaf
94        for msg in original.get_current_branch() {
95            let mut cloned = msg.clone();
96            cloned.is_sidechain = true;
97            forked.messages.push(cloned);
98        }
99
100        // Update leaf pointer
101        if let Some(last) = forked.messages.last() {
102            forked.current_leaf_id = Some(last.id.clone());
103        }
104
105        self.persistence.save(&forked).await?;
106        Ok(forked)
107    }
108
109    pub async fn complete(&self, id: &SessionId) -> SessionResult<()> {
110        let mut session = self.get(id).await?;
111        session.set_state(SessionState::Completed);
112        self.persistence.save(&session).await
113    }
114
115    pub async fn set_error(&self, id: &SessionId) -> SessionResult<()> {
116        let mut session = self.get(id).await?;
117        session.set_state(SessionState::Failed);
118        self.persistence.save(&session).await
119    }
120
121    pub async fn cleanup_expired(&self) -> SessionResult<usize> {
122        self.persistence.cleanup_expired().await
123    }
124
125    pub async fn exists(&self, id: &SessionId) -> SessionResult<bool> {
126        match self.persistence.load(id).await? {
127            Some(session) => Ok(!session.is_expired()),
128            None => Ok(false),
129        }
130    }
131}
132
133impl Default for SessionManager {
134    fn default() -> Self {
135        Self::in_memory()
136    }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ContentBlock;

    /// Each test gets its own isolated in-memory manager.
    fn fresh_manager() -> SessionManager {
        SessionManager::in_memory()
    }

    /// Creates a session with the default config and returns just its id.
    async fn create_default(mgr: &SessionManager) -> SessionId {
        mgr.create(SessionConfig::default()).await.unwrap().id
    }

    #[tokio::test]
    async fn test_session_manager_create() {
        let mgr = fresh_manager();
        let created = mgr.create(SessionConfig::default()).await.unwrap();

        assert_eq!(created.state, SessionState::Created);
        assert!(created.messages.is_empty());
    }

    #[tokio::test]
    async fn test_session_manager_get() {
        let mgr = fresh_manager();
        let id = create_default(&mgr).await;

        let fetched = mgr.get(&id).await.unwrap();
        assert_eq!(fetched.id, id);
    }

    #[tokio::test]
    async fn test_session_manager_not_found() {
        let mgr = fresh_manager();
        let missing = SessionId::new();

        let outcome = mgr.get(&missing).await;
        assert!(matches!(outcome, Err(SessionError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_session_manager_add_message() {
        let mgr = fresh_manager();
        let id = create_default(&mgr).await;

        let greeting = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        mgr.add_message(&id, greeting).await.unwrap();

        let fetched = mgr.get(&id).await.unwrap();
        assert_eq!(fetched.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_session_manager_fork() {
        let mgr = fresh_manager();
        let id = create_default(&mgr).await;

        // Seed a two-message exchange on the source session.
        let exchange = vec![
            SessionMessage::user(vec![ContentBlock::text("Hello")]),
            SessionMessage::assistant(vec![ContentBlock::text("Hi!")]),
        ];
        for message in exchange {
            mgr.add_message(&id, message).await.unwrap();
        }

        let forked = mgr.fork(&id).await.unwrap();

        // The fork is a distinct session carrying the same messages, all
        // flagged as sidechain copies.
        assert_ne!(forked.id, id);
        assert_eq!(forked.messages.len(), 2);
        assert!(forked.messages.iter().all(|m| m.is_sidechain));
    }

    #[tokio::test]
    async fn test_session_manager_complete() {
        let mgr = fresh_manager();
        let id = create_default(&mgr).await;

        mgr.complete(&id).await.unwrap();

        let fetched = mgr.get(&id).await.unwrap();
        assert_eq!(fetched.state, SessionState::Completed);
    }

    #[tokio::test]
    async fn test_session_manager_tenant_filtering() {
        let mgr = fresh_manager();

        for &tenant in &["tenant-a", "tenant-a", "tenant-b"] {
            mgr.create_with_tenant(SessionConfig::default(), tenant)
                .await
                .unwrap();
        }

        assert_eq!(mgr.list().await.unwrap().len(), 3);
        assert_eq!(mgr.list_for_tenant("tenant-a").await.unwrap().len(), 2);
        assert_eq!(mgr.list_for_tenant("tenant-b").await.unwrap().len(), 1);
    }

    #[tokio::test]
    async fn test_session_manager_expired() {
        let mgr = fresh_manager();

        // A zero TTL makes the session expire immediately.
        let short_lived = SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        };
        let id = mgr.create(short_lived).await.unwrap().id;

        // Give the clock a moment to move past the creation time.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let outcome = mgr.get(&id).await;
        assert!(matches!(outcome, Err(SessionError::Expired { .. })));
    }
}