// claude_agent/session/manager.rs

use std::sync::Arc;

use super::persistence::{MemoryPersistence, Persistence};
use super::state::{Session, SessionConfig, SessionId, SessionMessage, SessionState};
use super::{SessionError, SessionResult};
9pub struct SessionManager {
10 persistence: Arc<dyn Persistence>,
11}
12
13impl SessionManager {
14 pub fn new(persistence: Arc<dyn Persistence>) -> Self {
15 Self { persistence }
16 }
17
18 pub fn in_memory() -> Self {
19 Self::new(Arc::new(MemoryPersistence::new()))
20 }
21
22 pub fn backend_name(&self) -> &str {
23 self.persistence.name()
24 }
25
26 pub async fn create(&self, config: SessionConfig) -> SessionResult<Session> {
27 let session = Session::new(config);
28 self.persistence.save(&session).await?;
29 Ok(session)
30 }
31
32 pub async fn create_with_tenant(
33 &self,
34 config: SessionConfig,
35 tenant_id: impl Into<String>,
36 ) -> SessionResult<Session> {
37 let mut session = Session::new(config);
38 session.tenant_id = Some(tenant_id.into());
39 self.persistence.save(&session).await?;
40 Ok(session)
41 }
42
43 pub async fn get(&self, id: &SessionId) -> SessionResult<Session> {
44 let session = self
45 .persistence
46 .load(id)
47 .await?
48 .ok_or_else(|| SessionError::NotFound { id: id.to_string() })?;
49
50 if session.is_expired() {
51 self.persistence.delete(id).await?;
52 return Err(SessionError::Expired { id: id.to_string() });
53 }
54
55 Ok(session)
56 }
57
58 pub async fn get_by_str(&self, id: &str) -> SessionResult<Session> {
59 self.get(&SessionId::from(id)).await
60 }
61
62 pub async fn update(&self, session: &Session) -> SessionResult<()> {
63 self.persistence.save(session).await
64 }
65
66 pub async fn add_message(
67 &self,
68 session_id: &SessionId,
69 message: SessionMessage,
70 ) -> SessionResult<()> {
71 self.persistence.add_message(session_id, message).await
72 }
73
74 pub async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
75 self.persistence.delete(id).await
76 }
77
78 pub async fn list(&self) -> SessionResult<Vec<SessionId>> {
79 self.persistence.list(None).await
80 }
81
82 pub async fn list_for_tenant(&self, tenant_id: &str) -> SessionResult<Vec<SessionId>> {
83 self.persistence.list(Some(tenant_id)).await
84 }
85
86 pub async fn fork(&self, id: &SessionId) -> SessionResult<Session> {
87 let original = self.get(id).await?;
88
89 let mut forked = Session::new(original.config.clone());
90 forked.tenant_id = original.tenant_id.clone();
91 forked.summary = original.summary.clone();
92
93 for msg in original.get_current_branch() {
95 let mut cloned = msg.clone();
96 cloned.is_sidechain = true;
97 forked.messages.push(cloned);
98 }
99
100 if let Some(last) = forked.messages.last() {
102 forked.current_leaf_id = Some(last.id.clone());
103 }
104
105 self.persistence.save(&forked).await?;
106 Ok(forked)
107 }
108
109 pub async fn complete(&self, id: &SessionId) -> SessionResult<()> {
110 let mut session = self.get(id).await?;
111 session.set_state(SessionState::Completed);
112 self.persistence.save(&session).await
113 }
114
115 pub async fn set_error(&self, id: &SessionId) -> SessionResult<()> {
116 let mut session = self.get(id).await?;
117 session.set_state(SessionState::Failed);
118 self.persistence.save(&session).await
119 }
120
121 pub async fn cleanup_expired(&self) -> SessionResult<usize> {
122 self.persistence.cleanup_expired().await
123 }
124
125 pub async fn exists(&self, id: &SessionId) -> SessionResult<bool> {
126 match self.persistence.load(id).await? {
127 Some(session) => Ok(!session.is_expired()),
128 None => Ok(false),
129 }
130 }
131}
132
133impl Default for SessionManager {
134 fn default() -> Self {
135 Self::in_memory()
136 }
137}
138
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ContentBlock;

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();

        assert_eq!(session.state, SessionState::Created);
        assert!(session.messages.is_empty());
    }

    #[tokio::test]
    async fn test_session_manager_get() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let restored = manager.get(&session_id).await.unwrap();
        assert_eq!(restored.id, session_id);
    }

    #[tokio::test]
    async fn test_session_manager_not_found() {
        let manager = SessionManager::in_memory();
        let fake_id = SessionId::new();

        let result = manager.get(&fake_id).await;
        assert!(matches!(result, Err(SessionError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_session_manager_add_message() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let message = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        manager.add_message(&session_id, message).await.unwrap();

        let restored = manager.get(&session_id).await.unwrap();
        assert_eq!(restored.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_session_manager_fork() {
        let manager = SessionManager::in_memory();

        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let msg1 = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        manager.add_message(&session_id, msg1).await.unwrap();

        let msg2 = SessionMessage::assistant(vec![ContentBlock::text("Hi!")]);
        manager.add_message(&session_id, msg2).await.unwrap();

        let forked = manager.fork(&session_id).await.unwrap();

        assert_eq!(forked.messages.len(), 2);
        assert_ne!(forked.id, session_id);
        assert!(forked.messages.iter().all(|m| m.is_sidechain));

        // The fork is saved under its own id, and the original keeps its messages.
        assert!(manager.exists(&forked.id).await.unwrap());
        let original = manager.get(&session_id).await.unwrap();
        assert_eq!(original.messages.len(), 2);
    }

    #[tokio::test]
    async fn test_session_manager_complete() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        manager.complete(&session_id).await.unwrap();

        let completed = manager.get(&session_id).await.unwrap();
        assert_eq!(completed.state, SessionState::Completed);
    }

    #[tokio::test]
    async fn test_session_manager_set_error() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        manager.set_error(&session_id).await.unwrap();

        let failed = manager.get(&session_id).await.unwrap();
        assert_eq!(failed.state, SessionState::Failed);
    }

    #[tokio::test]
    async fn test_session_manager_delete_and_exists() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        assert!(manager.exists(&session_id).await.unwrap());

        manager.delete(&session_id).await.unwrap();

        assert!(!manager.exists(&session_id).await.unwrap());
        assert!(matches!(
            manager.get(&session_id).await,
            Err(SessionError::NotFound { .. })
        ));
    }

    #[tokio::test]
    async fn test_session_manager_tenant_filtering() {
        let manager = SessionManager::in_memory();

        let _s1 = manager
            .create_with_tenant(SessionConfig::default(), "tenant-a")
            .await
            .unwrap();
        let _s2 = manager
            .create_with_tenant(SessionConfig::default(), "tenant-a")
            .await
            .unwrap();
        let _s3 = manager
            .create_with_tenant(SessionConfig::default(), "tenant-b")
            .await
            .unwrap();

        let all = manager.list().await.unwrap();
        assert_eq!(all.len(), 3);

        let tenant_a = manager.list_for_tenant("tenant-a").await.unwrap();
        assert_eq!(tenant_a.len(), 2);

        let tenant_b = manager.list_for_tenant("tenant-b").await.unwrap();
        assert_eq!(tenant_b.len(), 1);
    }

    #[tokio::test]
    async fn test_session_manager_expired() {
        let manager = SessionManager::in_memory();

        let config = SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        };
        let session = manager.create(config).await.unwrap();
        let session_id = session.id;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let result = manager.get(&session_id).await;
        assert!(matches!(result, Err(SessionError::Expired { .. })));
    }
}