1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::sync::{RwLock, oneshot};
8use tokio::task::JoinHandle;
9
10use crate::session::{
11 Persistence, Session, SessionConfig, SessionId, SessionMessage, SessionState, SessionType,
12};
13use crate::types::{ContentBlock, Message, Role};
14
15use super::AgentResult;
16
17struct TaskRuntime {
18 handle: Option<JoinHandle<()>>,
19 cancel_tx: Option<oneshot::Sender<()>>,
20}
21
22#[derive(Clone)]
23pub struct TaskRegistry {
24 runtime: Arc<RwLock<HashMap<String, TaskRuntime>>>,
25 persistence: Arc<dyn Persistence>,
26 parent_session_id: Option<SessionId>,
27 default_ttl: Option<Duration>,
28}
29
30impl TaskRegistry {
31 pub fn new(persistence: Arc<dyn Persistence>) -> Self {
32 Self {
33 runtime: Arc::new(RwLock::new(HashMap::new())),
34 persistence,
35 parent_session_id: None,
36 default_ttl: Some(Duration::from_secs(3600)),
37 }
38 }
39
40 pub fn with_parent_session(mut self, parent_id: SessionId) -> Self {
41 self.parent_session_id = Some(parent_id);
42 self
43 }
44
45 pub fn with_ttl(mut self, ttl: Duration) -> Self {
46 self.default_ttl = Some(ttl);
47 self
48 }
49
50 pub async fn register(
51 &self,
52 id: String,
53 agent_type: String,
54 description: String,
55 ) -> oneshot::Receiver<()> {
56 let (cancel_tx, cancel_rx) = oneshot::channel();
57
58 let config = SessionConfig {
59 ttl_secs: self.default_ttl.map(|d| d.as_secs()),
60 ..Default::default()
61 };
62
63 let session = match self.parent_session_id {
64 Some(parent_id) => Session::new_subagent(parent_id, &agent_type, &description, config),
65 None => {
66 let mut s = Session::new(config);
67 s.session_type = SessionType::Subagent {
68 agent_type,
69 description,
70 };
71 s
72 }
73 };
74
75 let session_id = SessionId::from(id.as_str());
76 let mut session = session;
77 session.id = session_id;
78 session.state = SessionState::Active;
79
80 let _ = self.persistence.save(&session).await;
81
82 let mut runtime = self.runtime.write().await;
83 runtime.insert(
84 id,
85 TaskRuntime {
86 handle: None,
87 cancel_tx: Some(cancel_tx),
88 },
89 );
90
91 cancel_rx
92 }
93
94 pub async fn set_handle(&self, id: &str, handle: JoinHandle<()>) {
95 let mut runtime = self.runtime.write().await;
96 if let Some(rt) = runtime.get_mut(id) {
97 rt.handle = Some(handle);
98 }
99 }
100
101 pub async fn complete(&self, id: &str, result: AgentResult) {
102 let session_id = SessionId::from(id);
103
104 if let Ok(Some(mut session)) = self.persistence.load(&session_id).await {
105 session.state = SessionState::Completed;
106
107 for msg in &result.messages {
108 let content: Vec<ContentBlock> = msg.content.clone();
109 let session_msg = match msg.role {
110 Role::User => SessionMessage::user(content),
111 Role::Assistant => SessionMessage::assistant(content),
112 };
113 session.add_message(session_msg);
114 }
115
116 let _ = self.persistence.save(&session).await;
117 }
118
119 let mut runtime = self.runtime.write().await;
120 runtime.remove(id);
121 }
122
123 pub async fn fail(&self, id: &str, error: String) {
124 let session_id = SessionId::from(id);
125
126 if let Ok(Some(mut session)) = self.persistence.load(&session_id).await {
127 session.state = SessionState::Failed;
128 session.error = Some(error);
129 let _ = self.persistence.save(&session).await;
130 }
131
132 let mut runtime = self.runtime.write().await;
133 runtime.remove(id);
134 }
135
136 pub async fn cancel(&self, id: &str) -> bool {
137 let session_id = SessionId::from(id);
138
139 let cancelled = {
140 let mut runtime = self.runtime.write().await;
141 if let Some(rt) = runtime.get_mut(id) {
142 if let Some(tx) = rt.cancel_tx.take() {
143 let _ = tx.send(());
144 }
145 if let Some(handle) = rt.handle.take() {
146 handle.abort();
147 }
148 runtime.remove(id);
149 true
150 } else {
151 false
152 }
153 };
154
155 if cancelled && let Ok(Some(mut session)) = self.persistence.load(&session_id).await {
156 session.state = SessionState::Cancelled;
157 let _ = self.persistence.save(&session).await;
158 }
159
160 cancelled
161 }
162
163 pub async fn get_status(&self, id: &str) -> Option<SessionState> {
164 let session_id = SessionId::from(id);
165 self.persistence
166 .load(&session_id)
167 .await
168 .ok()
169 .flatten()
170 .map(|s| s.state)
171 }
172
173 pub async fn get_result(
174 &self,
175 id: &str,
176 ) -> Option<(SessionState, Option<String>, Option<String>)> {
177 let session_id = SessionId::from(id);
178 self.persistence
179 .load(&session_id)
180 .await
181 .ok()
182 .flatten()
183 .map(|s| {
184 let text = s.messages.last().and_then(|m| {
185 m.content.iter().find_map(|c| match c {
186 ContentBlock::Text { text, .. } => Some(text.clone()),
187 _ => None,
188 })
189 });
190 (s.state, text, s.error)
191 })
192 }
193
194 pub async fn wait_for_completion(
195 &self,
196 id: &str,
197 timeout: Duration,
198 ) -> Option<(SessionState, Option<String>, Option<String>)> {
199 let deadline = std::time::Instant::now() + timeout;
200 let poll_interval = Duration::from_millis(100);
201
202 loop {
203 if let Some((state, output, error)) = self.get_result(id).await {
204 if state != SessionState::Active && state != SessionState::WaitingForTools {
205 return Some((state, output, error));
206 }
207 } else {
208 return None;
209 }
210
211 if std::time::Instant::now() >= deadline {
212 return self.get_result(id).await;
213 }
214
215 tokio::time::sleep(poll_interval).await;
216 }
217 }
218
219 pub async fn list_running(&self) -> Vec<(String, String, Duration)> {
220 let runtime = self.runtime.read().await;
221 let mut result = Vec::new();
222
223 for id in runtime.keys() {
224 let session_id = SessionId::from(id.as_str());
225 if let Ok(Some(session)) = self.persistence.load(&session_id).await
226 && session.is_running()
227 {
228 let description = match &session.session_type {
229 SessionType::Subagent { description, .. } => description.clone(),
230 _ => String::new(),
231 };
232 let elapsed = (chrono::Utc::now() - session.created_at)
233 .to_std()
234 .unwrap_or_default();
235 result.push((id.clone(), description, elapsed));
236 }
237 }
238
239 result
240 }
241
242 pub async fn cleanup_completed(&self) -> usize {
243 self.persistence.cleanup_expired().await.unwrap_or(0)
244 }
245
246 pub async fn running_count(&self) -> usize {
247 self.runtime.read().await.len()
248 }
249
250 pub async fn save_messages(&self, id: &str, messages: Vec<Message>) {
251 let session_id = SessionId::from(id);
252
253 if let Ok(Some(mut session)) = self.persistence.load(&session_id).await {
254 for msg in messages {
255 let content: Vec<ContentBlock> = msg.content;
256 let session_msg = match msg.role {
257 Role::User => SessionMessage::user(content),
258 Role::Assistant => SessionMessage::assistant(content),
259 };
260 session.add_message(session_msg);
261 }
262 let _ = self.persistence.save(&session).await;
263 }
264 }
265
266 pub async fn get_messages(&self, id: &str) -> Option<Vec<Message>> {
267 let session_id = SessionId::from(id);
268 self.persistence
269 .load(&session_id)
270 .await
271 .ok()
272 .flatten()
273 .map(|s| s.to_api_messages())
274 }
275
276 pub async fn get_session(&self, id: &str) -> Option<Session> {
277 let session_id = SessionId::from(id);
278 self.persistence.load(&session_id).await.ok().flatten()
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentState;
    use crate::session::MemoryPersistence;
    use crate::types::{StopReason, Usage};

    /// Registry backed by a fresh in-memory persistence store.
    fn test_registry() -> TaskRegistry {
        TaskRegistry::new(Arc::new(MemoryPersistence::new()))
    }

    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000001";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000002";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000003";
    const TASK_4_UUID: &str = "00000000-0000-0000-0000-000000000004";

    /// Minimal successful agent result with no messages.
    fn mock_result() -> AgentResult {
        AgentResult {
            text: "Test result".to_string(),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: Default::default(),
            session_id: "test-session".to_string(),
            structured_output: None,
            messages: Vec::new(),
            uuid: "test-uuid".to_string(),
        }
    }

    #[tokio::test]
    async fn test_register_and_complete() {
        let registry = test_registry();
        let _cancel_rx = registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Test task".into())
            .await;

        assert_eq!(
            registry.get_status(TASK_1_UUID).await,
            Some(SessionState::Active)
        );

        registry.complete(TASK_1_UUID, mock_result()).await;

        let (status, _, _) = registry.get_result(TASK_1_UUID).await.unwrap();
        assert_eq!(status, SessionState::Completed);
    }

    #[tokio::test]
    async fn test_fail_task() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Failing task".into())
            .await;

        registry
            .fail(TASK_2_UUID, "Something went wrong".into())
            .await;

        let (status, _, error) = registry.get_result(TASK_2_UUID).await.unwrap();
        assert_eq!(status, SessionState::Failed);
        assert_eq!(error, Some("Something went wrong".to_string()));
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let registry = test_registry();
        registry
            .register(
                TASK_3_UUID.into(),
                "Explore".into(),
                "Cancellable task".into(),
            )
            .await;

        assert!(registry.cancel(TASK_3_UUID).await);
        assert_eq!(
            registry.get_status(TASK_3_UUID).await,
            Some(SessionState::Cancelled)
        );

        assert!(!registry.cancel(TASK_3_UUID).await);
    }

    #[tokio::test]
    async fn test_cancel_signals_receiver() {
        let registry = test_registry();
        let cancel_rx = registry
            .register(TASK_3_UUID.into(), "Explore".into(), "Signal test".into())
            .await;

        assert!(registry.cancel(TASK_3_UUID).await);
        // The receiver sees an explicit cancel (`Ok`), not a dropped sender (`Err`).
        assert!(cancel_rx.await.is_ok());
    }

    #[tokio::test]
    async fn test_not_found() {
        let registry = test_registry();
        assert!(registry.get_status("nonexistent").await.is_none());
        assert!(registry.get_result("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_running_count_tracks_lifecycle() {
        let registry = test_registry();
        assert_eq!(registry.running_count().await, 0);

        let _rx1 = registry
            .register(TASK_1_UUID.into(), "Explore".into(), "First".into())
            .await;
        let _rx2 = registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Second".into())
            .await;
        assert_eq!(registry.running_count().await, 2);

        registry.complete(TASK_1_UUID, mock_result()).await;
        assert_eq!(registry.running_count().await, 1);

        registry.fail(TASK_2_UUID, "boom".into()).await;
        assert_eq!(registry.running_count().await, 0);
    }

    #[tokio::test]
    async fn test_wait_for_completion_returns_terminal_state() {
        let registry = test_registry();
        let _cancel_rx = registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Wait test".into())
            .await;
        registry.complete(TASK_1_UUID, mock_result()).await;

        let (status, _, _) = registry
            .wait_for_completion(TASK_1_UUID, Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(status, SessionState::Completed);
    }

    #[tokio::test]
    async fn test_wait_for_completion_times_out_while_active() {
        let registry = test_registry();
        let _cancel_rx = registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Timeout test".into())
            .await;

        let (status, _, _) = registry
            .wait_for_completion(TASK_2_UUID, Duration::from_millis(20))
            .await
            .unwrap();
        assert_eq!(status, SessionState::Active);
    }

    #[tokio::test]
    async fn test_wait_for_completion_unknown_task() {
        let registry = test_registry();
        assert!(
            registry
                .wait_for_completion("nonexistent", Duration::from_millis(20))
                .await
                .is_none()
        );
    }

    #[tokio::test]
    async fn test_messages() {
        let registry = test_registry();
        registry
            .register(TASK_4_UUID.into(), "Explore".into(), "Message test".into())
            .await;

        let messages = vec![
            Message::user("Hello"),
            Message {
                role: Role::Assistant,
                content: vec![ContentBlock::text("Hi there!")],
            },
        ];

        registry.save_messages(TASK_4_UUID, messages).await;

        let loaded = registry.get_messages(TASK_4_UUID).await.unwrap();
        assert_eq!(loaded.len(), 2);
    }
}
396}