1use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::sync::{RwLock, oneshot};
8use tokio::task::JoinHandle;
9use tracing::warn;
10
11use crate::session::{
12 Persistence, Session, SessionConfig, SessionId, SessionMessage, SessionState, SessionType,
13};
14use crate::types::{ContentBlock, Message, Role};
15
16use super::AgentResult;
17
18struct TaskRuntime {
19 handle: Option<JoinHandle<()>>,
20 cancel_tx: Option<oneshot::Sender<()>>,
21}
22
23#[derive(Clone)]
24pub struct TaskRegistry {
25 runtime: Arc<RwLock<HashMap<String, TaskRuntime>>>,
26 persistence: Arc<dyn Persistence>,
27 parent_session_id: Option<SessionId>,
28 default_ttl: Option<Duration>,
29}
30
31impl TaskRegistry {
32 pub fn new(persistence: Arc<dyn Persistence>) -> Self {
33 Self {
34 runtime: Arc::new(RwLock::new(HashMap::new())),
35 persistence,
36 parent_session_id: None,
37 default_ttl: Some(Duration::from_secs(3600)),
38 }
39 }
40
41 pub fn parent_session(mut self, parent_id: SessionId) -> Self {
42 self.parent_session_id = Some(parent_id);
43 self
44 }
45
46 pub fn ttl(mut self, ttl: Duration) -> Self {
47 self.default_ttl = Some(ttl);
48 self
49 }
50
51 pub async fn register(
52 &self,
53 id: String,
54 agent_type: String,
55 description: String,
56 ) -> oneshot::Receiver<()> {
57 let (cancel_tx, cancel_rx) = oneshot::channel();
58
59 let config = SessionConfig {
60 ttl_secs: self.default_ttl.map(|d| d.as_secs()),
61 ..Default::default()
62 };
63
64 let session = match self.parent_session_id {
65 Some(parent_id) => Session::new_subagent(parent_id, &agent_type, &description, config),
66 None => {
67 let mut s = Session::new(config);
68 s.session_type = SessionType::Subagent {
69 agent_type,
70 description,
71 };
72 s
73 }
74 };
75
76 let session_id = SessionId::from(id.as_str());
77 let mut session = session;
78 session.id = session_id;
79 session.state = SessionState::Active;
80
81 if let Err(e) = self.persistence.save(&session).await {
82 warn!(session_id = %session_id, error = %e, "Failed to save task session on register");
83 }
84
85 let mut runtime = self.runtime.write().await;
86 runtime.insert(
87 id,
88 TaskRuntime {
89 handle: None,
90 cancel_tx: Some(cancel_tx),
91 },
92 );
93
94 cancel_rx
95 }
96
97 pub async fn set_handle(&self, id: &str, handle: JoinHandle<()>) {
98 let mut runtime = self.runtime.write().await;
99 if let Some(rt) = runtime.get_mut(id) {
100 rt.handle = Some(handle);
101 }
102 }
103
104 pub async fn complete(&self, id: &str, result: AgentResult) {
105 let session_id = SessionId::from(id);
106
107 match self.persistence.load(&session_id).await {
108 Ok(Some(mut session)) => {
109 session.state = SessionState::Completed;
110
111 for msg in &result.messages {
112 let content: Vec<ContentBlock> = msg.content.clone();
113 let session_msg = match msg.role {
114 Role::User => SessionMessage::user(content),
115 Role::Assistant => SessionMessage::assistant(content),
116 };
117 session.add_message(session_msg);
118 }
119
120 if let Err(e) = self.persistence.save(&session).await {
121 warn!(session_id = %session_id, error = %e, "Failed to save completed task session");
122 }
123 }
124 Ok(None) => {
125 warn!(session_id = %session_id, "Task session not found on complete");
126 }
127 Err(e) => {
128 warn!(session_id = %session_id, error = %e, "Failed to load task session on complete");
129 }
130 }
131
132 let mut runtime = self.runtime.write().await;
133 runtime.remove(id);
134 }
135
136 pub async fn fail(&self, id: &str, error: String) {
137 let session_id = SessionId::from(id);
138
139 match self.persistence.load(&session_id).await {
140 Ok(Some(mut session)) => {
141 session.state = SessionState::Failed;
142 session.error = Some(error);
143 if let Err(e) = self.persistence.save(&session).await {
144 warn!(session_id = %session_id, error = %e, "Failed to save failed task session");
145 }
146 }
147 Ok(None) => {
148 warn!(session_id = %session_id, "Task session not found on fail");
149 }
150 Err(e) => {
151 warn!(session_id = %session_id, error = %e, "Failed to load task session on fail");
152 }
153 }
154
155 let mut runtime = self.runtime.write().await;
156 runtime.remove(id);
157 }
158
159 pub async fn cancel(&self, id: &str) -> bool {
160 let session_id = SessionId::from(id);
161
162 let cancelled = {
163 let mut runtime = self.runtime.write().await;
164 if let Some(rt) = runtime.get_mut(id) {
165 if let Some(tx) = rt.cancel_tx.take() {
166 let _ = tx.send(());
167 }
168 if let Some(handle) = rt.handle.take() {
169 handle.abort();
170 }
171 runtime.remove(id);
172 true
173 } else {
174 false
175 }
176 };
177
178 if cancelled {
179 match self.persistence.load(&session_id).await {
180 Ok(Some(mut session)) => {
181 session.state = SessionState::Cancelled;
182 if let Err(e) = self.persistence.save(&session).await {
183 warn!(session_id = %session_id, error = %e, "Failed to save cancelled task session");
184 }
185 }
186 Ok(None) => {
187 warn!(session_id = %session_id, "Task session not found on cancel");
188 }
189 Err(e) => {
190 warn!(session_id = %session_id, error = %e, "Failed to load task session on cancel");
191 }
192 }
193 }
194
195 cancelled
196 }
197
198 pub async fn get_status(&self, id: &str) -> Option<SessionState> {
199 let session_id = SessionId::from(id);
200 self.persistence
201 .load(&session_id)
202 .await
203 .ok()
204 .flatten()
205 .map(|s| s.state)
206 }
207
208 pub async fn get_result(
209 &self,
210 id: &str,
211 ) -> Option<(SessionState, Option<String>, Option<String>)> {
212 let session_id = SessionId::from(id);
213 self.persistence
214 .load(&session_id)
215 .await
216 .ok()
217 .flatten()
218 .map(|s| {
219 let text = s.messages.last().and_then(|m| {
220 m.content.iter().find_map(|c| match c {
221 ContentBlock::Text { text, .. } => Some(text.clone()),
222 _ => None,
223 })
224 });
225 (s.state, text, s.error)
226 })
227 }
228
229 pub async fn wait_for_completion(
230 &self,
231 id: &str,
232 timeout: Duration,
233 ) -> Option<(SessionState, Option<String>, Option<String>)> {
234 let deadline = std::time::Instant::now() + timeout;
235 let poll_interval = Duration::from_millis(100);
236
237 loop {
238 if let Some((state, output, error)) = self.get_result(id).await {
239 if state != SessionState::Active && state != SessionState::WaitingForTools {
240 return Some((state, output, error));
241 }
242 } else {
243 return None;
244 }
245
246 if std::time::Instant::now() >= deadline {
247 return self.get_result(id).await;
248 }
249
250 tokio::time::sleep(poll_interval).await;
251 }
252 }
253
254 pub async fn list_running(&self) -> Vec<(String, String, Duration)> {
255 let runtime = self.runtime.read().await;
256 let mut result = Vec::new();
257
258 for id in runtime.keys() {
259 let session_id = SessionId::from(id.as_str());
260 if let Ok(Some(session)) = self.persistence.load(&session_id).await
261 && session.is_running()
262 {
263 let description = match &session.session_type {
264 SessionType::Subagent { description, .. } => description.clone(),
265 _ => String::new(),
266 };
267 let elapsed = (chrono::Utc::now() - session.created_at)
268 .to_std()
269 .unwrap_or_default();
270 result.push((id.clone(), description, elapsed));
271 }
272 }
273
274 result
275 }
276
277 pub async fn cleanup_completed(&self) -> usize {
278 self.persistence.cleanup_expired().await.unwrap_or(0)
279 }
280
281 pub async fn running_count(&self) -> usize {
282 self.runtime.read().await.len()
283 }
284
285 pub async fn save_messages(&self, id: &str, messages: Vec<Message>) {
286 let session_id = SessionId::from(id);
287
288 match self.persistence.load(&session_id).await {
289 Ok(Some(mut session)) => {
290 for msg in messages {
291 let content: Vec<ContentBlock> = msg.content;
292 let session_msg = match msg.role {
293 Role::User => SessionMessage::user(content),
294 Role::Assistant => SessionMessage::assistant(content),
295 };
296 session.add_message(session_msg);
297 }
298 if let Err(e) = self.persistence.save(&session).await {
299 warn!(session_id = %session_id, error = %e, "Failed to save task messages");
300 }
301 }
302 Ok(None) => {
303 warn!(session_id = %session_id, "Task session not found when saving messages");
304 }
305 Err(e) => {
306 warn!(session_id = %session_id, error = %e, "Failed to load task session when saving messages");
307 }
308 }
309 }
310
311 pub async fn get_messages(&self, id: &str) -> Option<Vec<Message>> {
312 let session_id = SessionId::from(id);
313 self.persistence
314 .load(&session_id)
315 .await
316 .ok()
317 .flatten()
318 .map(|s| s.to_api_messages())
319 }
320
321 pub async fn get_session(&self, id: &str) -> Option<Session> {
322 let session_id = SessionId::from(id);
323 self.persistence.load(&session_id).await.ok().flatten()
324 }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentState;
    use crate::session::MemoryPersistence;
    use crate::types::{StopReason, Usage};

    // One fixed task id per test so tests never touch each other's sessions.
    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000001";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000002";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000003";
    const TASK_4_UUID: &str = "00000000-0000-0000-0000-000000000004";

    /// Registry backed by a fresh in-memory persistence store.
    fn test_registry() -> TaskRegistry {
        let store = Arc::new(MemoryPersistence::new());
        TaskRegistry::new(store)
    }

    /// Minimal successful agent result carrying no messages.
    fn mock_result() -> AgentResult {
        AgentResult {
            text: String::from("Test result"),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: Default::default(),
            session_id: String::from("test-session"),
            structured_output: None,
            messages: Vec::new(),
            uuid: String::from("test-uuid"),
        }
    }

    #[tokio::test]
    async fn test_register_and_complete() {
        let registry = test_registry();
        let _cancel_rx = registry
            .register(
                TASK_1_UUID.to_string(),
                "Explore".to_string(),
                "Test task".to_string(),
            )
            .await;

        let initial = registry.get_status(TASK_1_UUID).await;
        assert_eq!(initial, Some(SessionState::Active));

        registry.complete(TASK_1_UUID, mock_result()).await;

        let (state, _output, _error) = registry
            .get_result(TASK_1_UUID)
            .await
            .expect("completed task should still have a session");
        assert_eq!(state, SessionState::Completed);
    }

    #[tokio::test]
    async fn test_fail_task() {
        let registry = test_registry();
        registry
            .register(
                TASK_2_UUID.to_string(),
                "Explore".to_string(),
                "Failing task".to_string(),
            )
            .await;

        let reason = "Something went wrong";
        registry.fail(TASK_2_UUID, reason.to_string()).await;

        let (state, _output, error) = registry
            .get_result(TASK_2_UUID)
            .await
            .expect("failed task should still have a session");
        assert_eq!(state, SessionState::Failed);
        assert_eq!(error.as_deref(), Some(reason));
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let registry = test_registry();
        registry
            .register(
                TASK_3_UUID.to_string(),
                "Explore".to_string(),
                "Cancellable task".to_string(),
            )
            .await;

        let first_cancel = registry.cancel(TASK_3_UUID).await;
        assert!(first_cancel, "a registered task should be cancellable");

        let status = registry.get_status(TASK_3_UUID).await;
        assert_eq!(status, Some(SessionState::Cancelled));

        let second_cancel = registry.cancel(TASK_3_UUID).await;
        assert!(!second_cancel, "a task can only be cancelled once");
    }

    #[tokio::test]
    async fn test_not_found() {
        let registry = test_registry();
        let missing = "nonexistent";

        assert!(registry.get_status(missing).await.is_none());
        assert!(registry.get_result(missing).await.is_none());
    }

    #[tokio::test]
    async fn test_messages() {
        let registry = test_registry();
        registry
            .register(
                TASK_4_UUID.to_string(),
                "Explore".to_string(),
                "Message test".to_string(),
            )
            .await;

        let greeting = Message::user("Hello");
        let reply = Message {
            role: Role::Assistant,
            content: vec![ContentBlock::text("Hi there!")],
        };
        registry.save_messages(TASK_4_UUID, vec![greeting, reply]).await;

        let stored = registry
            .get_messages(TASK_4_UUID)
            .await
            .expect("saved messages should be retrievable");
        assert_eq!(stored.len(), 2);
    }
}