claude_agent/session/
manager.rs1use std::sync::Arc;
4
5use super::persistence::{MemoryPersistence, Persistence};
6use super::state::{Session, SessionConfig, SessionId, SessionMessage, SessionState};
7use super::{SessionError, SessionResult};
8
9pub struct SessionManager {
10 persistence: Arc<dyn Persistence>,
11}
12
13impl SessionManager {
14 pub fn new(persistence: Arc<dyn Persistence>) -> Self {
15 Self { persistence }
16 }
17
18 pub fn in_memory() -> Self {
19 Self::new(Arc::new(MemoryPersistence::new()))
20 }
21
22 pub async fn create(&self, config: SessionConfig) -> SessionResult<Session> {
23 let session = Session::new(config);
24 self.persistence.save(&session).await?;
25 Ok(session)
26 }
27
28 pub async fn create_with_tenant(
29 &self,
30 config: SessionConfig,
31 tenant_id: impl Into<String>,
32 ) -> SessionResult<Session> {
33 let mut session = Session::new(config);
34 session.tenant_id = Some(tenant_id.into());
35 self.persistence.save(&session).await?;
36 Ok(session)
37 }
38
39 pub async fn get(&self, id: &SessionId) -> SessionResult<Session> {
40 let session = self
41 .persistence
42 .load(id)
43 .await?
44 .ok_or_else(|| SessionError::NotFound { id: id.to_string() })?;
45
46 if session.is_expired() {
47 self.persistence.delete(id).await?;
48 return Err(SessionError::Expired { id: id.to_string() });
49 }
50
51 Ok(session)
52 }
53
54 pub async fn get_by_str(&self, id: &str) -> SessionResult<Session> {
55 self.get(&SessionId::from(id)).await
56 }
57
58 pub async fn update(&self, session: &Session) -> SessionResult<()> {
59 self.persistence.save(session).await
60 }
61
62 pub async fn add_message(
63 &self,
64 session_id: &SessionId,
65 message: SessionMessage,
66 ) -> SessionResult<()> {
67 self.persistence.add_message(session_id, message).await
68 }
69
70 pub async fn delete(&self, id: &SessionId) -> SessionResult<bool> {
71 self.persistence.delete(id).await
72 }
73
74 pub async fn list(&self) -> SessionResult<Vec<SessionId>> {
75 self.persistence.list(None).await
76 }
77
78 pub async fn list_for_tenant(&self, tenant_id: &str) -> SessionResult<Vec<SessionId>> {
79 self.persistence.list(Some(tenant_id)).await
80 }
81
82 pub async fn fork(&self, id: &SessionId) -> SessionResult<Session> {
83 let original = self.get(id).await?;
84
85 let mut forked = Session::new(original.config.clone());
86 forked.parent_id = Some(original.id);
87 forked.tenant_id = original.tenant_id.clone();
88 forked.summary = original.summary.clone();
89
90 for msg in original.current_branch() {
92 let mut cloned = msg.clone();
93 cloned.is_sidechain = true;
94 forked.messages.push(cloned);
95 }
96
97 if let Some(last) = forked.messages.last() {
99 forked.current_leaf_id = Some(last.id.clone());
100 }
101
102 self.persistence.save(&forked).await?;
103 Ok(forked)
104 }
105
106 pub async fn complete(&self, id: &SessionId) -> SessionResult<()> {
107 let mut session = self.get(id).await?;
108 session.set_state(SessionState::Completed);
109 self.persistence.save(&session).await
110 }
111
112 pub async fn set_error(&self, id: &SessionId) -> SessionResult<()> {
113 let mut session = self.get(id).await?;
114 session.set_state(SessionState::Failed);
115 self.persistence.save(&session).await
116 }
117
118 pub async fn cleanup_expired(&self) -> SessionResult<usize> {
119 self.persistence.cleanup_expired().await
120 }
121
122 pub async fn exists(&self, id: &SessionId) -> SessionResult<bool> {
123 match self.persistence.load(id).await? {
124 Some(session) => Ok(!session.is_expired()),
125 None => Ok(false),
126 }
127 }
128}
129
130impl Default for SessionManager {
131 fn default() -> Self {
132 Self::in_memory()
133 }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ContentBlock;

    #[tokio::test]
    async fn test_session_manager_create() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();

        assert_eq!(session.state, SessionState::Created);
        assert!(session.messages.is_empty());
    }

    #[tokio::test]
    async fn test_session_manager_get() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let restored = manager.get(&session_id).await.unwrap();
        assert_eq!(restored.id, session_id);
    }

    #[tokio::test]
    async fn test_session_manager_not_found() {
        let manager = SessionManager::in_memory();
        let fake_id = SessionId::new();

        let result = manager.get(&fake_id).await;
        assert!(matches!(result, Err(SessionError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_session_manager_add_message() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let message = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        manager.add_message(&session_id, message).await.unwrap();

        let restored = manager.get(&session_id).await.unwrap();
        assert_eq!(restored.messages.len(), 1);
    }

    #[tokio::test]
    async fn test_session_manager_fork() {
        let manager = SessionManager::in_memory();

        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let msg1 = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        manager.add_message(&session_id, msg1).await.unwrap();

        let msg2 = SessionMessage::assistant(vec![ContentBlock::text("Hi!")]);
        manager.add_message(&session_id, msg2).await.unwrap();

        let forked = manager.fork(&session_id).await.unwrap();

        assert_eq!(forked.messages.len(), 2);
        assert_ne!(forked.id, session_id);
        assert_eq!(forked.parent_id, Some(session_id));

        assert!(forked.messages.iter().all(|m| m.is_sidechain));
    }

    // Forking copies tenant ownership and must leave the source history intact.
    #[tokio::test]
    async fn test_session_manager_fork_preserves_tenant_and_source() {
        let manager = SessionManager::in_memory();
        let session = manager
            .create_with_tenant(SessionConfig::default(), "tenant-a")
            .await
            .unwrap();
        let session_id = session.id;

        let msg = SessionMessage::user(vec![ContentBlock::text("Hello")]);
        manager.add_message(&session_id, msg).await.unwrap();

        let forked = manager.fork(&session_id).await.unwrap();
        assert_eq!(forked.tenant_id.as_deref(), Some("tenant-a"));

        let source = manager.get(&session_id).await.unwrap();
        assert_eq!(source.messages.len(), 1);
    }

    // A session without messages can still be forked; the child is persisted.
    #[tokio::test]
    async fn test_session_manager_fork_empty() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        let forked = manager.fork(&session_id).await.unwrap();
        assert!(forked.messages.is_empty());
        assert_eq!(forked.parent_id, Some(session_id));

        let reloaded = manager.get(&forked.id).await.unwrap();
        assert_eq!(reloaded.id, forked.id);
    }

    #[tokio::test]
    async fn test_session_manager_complete() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        manager.complete(&session_id).await.unwrap();

        let completed = manager.get(&session_id).await.unwrap();
        assert_eq!(completed.state, SessionState::Completed);
    }

    #[tokio::test]
    async fn test_session_manager_set_error() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        manager.set_error(&session_id).await.unwrap();

        let failed = manager.get(&session_id).await.unwrap();
        assert_eq!(failed.state, SessionState::Failed);
    }

    #[tokio::test]
    async fn test_session_manager_exists() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();

        assert!(manager.exists(&session.id).await.unwrap());
        assert!(!manager.exists(&SessionId::new()).await.unwrap());
    }

    #[tokio::test]
    async fn test_session_manager_delete() {
        let manager = SessionManager::in_memory();
        let session = manager.create(SessionConfig::default()).await.unwrap();
        let session_id = session.id;

        manager.delete(&session_id).await.unwrap();

        let result = manager.get(&session_id).await;
        assert!(matches!(result, Err(SessionError::NotFound { .. })));
    }

    #[tokio::test]
    async fn test_session_manager_tenant_filtering() {
        let manager = SessionManager::in_memory();

        let _s1 = manager
            .create_with_tenant(SessionConfig::default(), "tenant-a")
            .await
            .unwrap();
        let _s2 = manager
            .create_with_tenant(SessionConfig::default(), "tenant-a")
            .await
            .unwrap();
        let _s3 = manager
            .create_with_tenant(SessionConfig::default(), "tenant-b")
            .await
            .unwrap();

        let all = manager.list().await.unwrap();
        assert_eq!(all.len(), 3);

        let tenant_a = manager.list_for_tenant("tenant-a").await.unwrap();
        assert_eq!(tenant_a.len(), 2);

        let tenant_b = manager.list_for_tenant("tenant-b").await.unwrap();
        assert_eq!(tenant_b.len(), 1);
    }

    #[tokio::test]
    async fn test_session_manager_expired() {
        let manager = SessionManager::in_memory();

        let config = SessionConfig {
            ttl_secs: Some(0),
            ..Default::default()
        };
        let session = manager.create(config).await.unwrap();
        let session_id = session.id;

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        // `exists` reports expired sessions as absent without evicting them.
        assert!(!manager.exists(&session_id).await.unwrap());

        let result = manager.get(&session_id).await;
        assert!(matches!(result, Err(SessionError::Expired { .. })));

        // `get` evicts the expired session, so a second lookup finds nothing.
        let again = manager.get(&session_id).await;
        assert!(matches!(again, Err(SessionError::NotFound { .. })));
    }
}