Skip to main content

claude_agent/agent/
task_registry.rs

//! Task registry for managing background agent tasks with Session-based persistence.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::sync::{RwLock, oneshot};
8use tokio::task::JoinHandle;
9use tracing::warn;
10
11use crate::session::{
12    Persistence, Session, SessionConfig, SessionId, SessionMessage, SessionState, SessionType,
13};
14use crate::types::{ContentBlock, Message, Role};
15
16use super::AgentResult;
17
18struct TaskRuntime {
19    handle: Option<JoinHandle<()>>,
20    cancel_tx: Option<oneshot::Sender<()>>,
21}
22
23#[derive(Clone)]
24pub struct TaskRegistry {
25    runtime: Arc<RwLock<HashMap<String, TaskRuntime>>>,
26    persistence: Arc<dyn Persistence>,
27    parent_session_id: Option<SessionId>,
28    default_ttl: Option<Duration>,
29}
30
31impl TaskRegistry {
32    pub fn new(persistence: Arc<dyn Persistence>) -> Self {
33        Self {
34            runtime: Arc::new(RwLock::new(HashMap::new())),
35            persistence,
36            parent_session_id: None,
37            default_ttl: Some(Duration::from_secs(3600)),
38        }
39    }
40
41    pub fn parent_session(mut self, parent_id: SessionId) -> Self {
42        self.parent_session_id = Some(parent_id);
43        self
44    }
45
46    pub fn ttl(mut self, ttl: Duration) -> Self {
47        self.default_ttl = Some(ttl);
48        self
49    }
50
51    pub async fn register(
52        &self,
53        id: String,
54        agent_type: String,
55        description: String,
56    ) -> oneshot::Receiver<()> {
57        let (cancel_tx, cancel_rx) = oneshot::channel();
58
59        let config = SessionConfig {
60            ttl_secs: self.default_ttl.map(|d| d.as_secs()),
61            ..Default::default()
62        };
63
64        let session = match self.parent_session_id {
65            Some(parent_id) => Session::new_subagent(parent_id, &agent_type, &description, config),
66            None => {
67                let mut s = Session::new(config);
68                s.session_type = SessionType::Subagent {
69                    agent_type,
70                    description,
71                };
72                s
73            }
74        };
75
76        let session_id = SessionId::from(id.as_str());
77        let mut session = session;
78        session.id = session_id;
79        session.state = SessionState::Active;
80
81        if let Err(e) = self.persistence.save(&session).await {
82            warn!(session_id = %session_id, error = %e, "Failed to save task session on register");
83        }
84
85        let mut runtime = self.runtime.write().await;
86        runtime.insert(
87            id,
88            TaskRuntime {
89                handle: None,
90                cancel_tx: Some(cancel_tx),
91            },
92        );
93
94        cancel_rx
95    }
96
97    pub async fn set_handle(&self, id: &str, handle: JoinHandle<()>) {
98        let mut runtime = self.runtime.write().await;
99        if let Some(rt) = runtime.get_mut(id) {
100            rt.handle = Some(handle);
101        }
102    }
103
104    pub async fn complete(&self, id: &str, result: AgentResult) {
105        let session_id = SessionId::from(id);
106
107        match self.persistence.load(&session_id).await {
108            Ok(Some(mut session)) => {
109                session.state = SessionState::Completed;
110
111                for msg in &result.messages {
112                    let content: Vec<ContentBlock> = msg.content.clone();
113                    let session_msg = match msg.role {
114                        Role::User => SessionMessage::user(content),
115                        Role::Assistant => SessionMessage::assistant(content),
116                    };
117                    session.add_message(session_msg);
118                }
119
120                if let Err(e) = self.persistence.save(&session).await {
121                    warn!(session_id = %session_id, error = %e, "Failed to save completed task session");
122                }
123            }
124            Ok(None) => {
125                warn!(session_id = %session_id, "Task session not found on complete");
126            }
127            Err(e) => {
128                warn!(session_id = %session_id, error = %e, "Failed to load task session on complete");
129            }
130        }
131
132        let mut runtime = self.runtime.write().await;
133        runtime.remove(id);
134    }
135
136    pub async fn fail(&self, id: &str, error: String) {
137        let session_id = SessionId::from(id);
138
139        match self.persistence.load(&session_id).await {
140            Ok(Some(mut session)) => {
141                session.state = SessionState::Failed;
142                session.error = Some(error);
143                if let Err(e) = self.persistence.save(&session).await {
144                    warn!(session_id = %session_id, error = %e, "Failed to save failed task session");
145                }
146            }
147            Ok(None) => {
148                warn!(session_id = %session_id, "Task session not found on fail");
149            }
150            Err(e) => {
151                warn!(session_id = %session_id, error = %e, "Failed to load task session on fail");
152            }
153        }
154
155        let mut runtime = self.runtime.write().await;
156        runtime.remove(id);
157    }
158
159    pub async fn cancel(&self, id: &str) -> bool {
160        let session_id = SessionId::from(id);
161
162        let cancelled = {
163            let mut runtime = self.runtime.write().await;
164            if let Some(rt) = runtime.get_mut(id) {
165                if let Some(tx) = rt.cancel_tx.take() {
166                    let _ = tx.send(());
167                }
168                if let Some(handle) = rt.handle.take() {
169                    handle.abort();
170                }
171                runtime.remove(id);
172                true
173            } else {
174                false
175            }
176        };
177
178        if cancelled {
179            match self.persistence.load(&session_id).await {
180                Ok(Some(mut session)) => {
181                    session.state = SessionState::Cancelled;
182                    if let Err(e) = self.persistence.save(&session).await {
183                        warn!(session_id = %session_id, error = %e, "Failed to save cancelled task session");
184                    }
185                }
186                Ok(None) => {
187                    warn!(session_id = %session_id, "Task session not found on cancel");
188                }
189                Err(e) => {
190                    warn!(session_id = %session_id, error = %e, "Failed to load task session on cancel");
191                }
192            }
193        }
194
195        cancelled
196    }
197
198    pub async fn get_status(&self, id: &str) -> Option<SessionState> {
199        let session_id = SessionId::from(id);
200        self.persistence
201            .load(&session_id)
202            .await
203            .ok()
204            .flatten()
205            .map(|s| s.state)
206    }
207
208    pub async fn get_result(
209        &self,
210        id: &str,
211    ) -> Option<(SessionState, Option<String>, Option<String>)> {
212        let session_id = SessionId::from(id);
213        self.persistence
214            .load(&session_id)
215            .await
216            .ok()
217            .flatten()
218            .map(|s| {
219                let text = s.messages.last().and_then(|m| {
220                    m.content.iter().find_map(|c| match c {
221                        ContentBlock::Text { text, .. } => Some(text.clone()),
222                        _ => None,
223                    })
224                });
225                (s.state, text, s.error)
226            })
227    }
228
229    pub async fn wait_for_completion(
230        &self,
231        id: &str,
232        timeout: Duration,
233    ) -> Option<(SessionState, Option<String>, Option<String>)> {
234        let deadline = std::time::Instant::now() + timeout;
235        let poll_interval = Duration::from_millis(100);
236
237        loop {
238            if let Some((state, output, error)) = self.get_result(id).await {
239                if state != SessionState::Active && state != SessionState::WaitingForTools {
240                    return Some((state, output, error));
241                }
242            } else {
243                return None;
244            }
245
246            if std::time::Instant::now() >= deadline {
247                return self.get_result(id).await;
248            }
249
250            tokio::time::sleep(poll_interval).await;
251        }
252    }
253
254    pub async fn list_running(&self) -> Vec<(String, String, Duration)> {
255        let runtime = self.runtime.read().await;
256        let mut result = Vec::new();
257
258        for id in runtime.keys() {
259            let session_id = SessionId::from(id.as_str());
260            if let Ok(Some(session)) = self.persistence.load(&session_id).await
261                && session.is_running()
262            {
263                let description = match &session.session_type {
264                    SessionType::Subagent { description, .. } => description.clone(),
265                    _ => String::new(),
266                };
267                let elapsed = (chrono::Utc::now() - session.created_at)
268                    .to_std()
269                    .unwrap_or_default();
270                result.push((id.clone(), description, elapsed));
271            }
272        }
273
274        result
275    }
276
277    pub async fn cleanup_completed(&self) -> usize {
278        self.persistence.cleanup_expired().await.unwrap_or(0)
279    }
280
281    pub async fn running_count(&self) -> usize {
282        self.runtime.read().await.len()
283    }
284
285    pub async fn save_messages(&self, id: &str, messages: Vec<Message>) {
286        let session_id = SessionId::from(id);
287
288        match self.persistence.load(&session_id).await {
289            Ok(Some(mut session)) => {
290                for msg in messages {
291                    let content: Vec<ContentBlock> = msg.content;
292                    let session_msg = match msg.role {
293                        Role::User => SessionMessage::user(content),
294                        Role::Assistant => SessionMessage::assistant(content),
295                    };
296                    session.add_message(session_msg);
297                }
298                if let Err(e) = self.persistence.save(&session).await {
299                    warn!(session_id = %session_id, error = %e, "Failed to save task messages");
300                }
301            }
302            Ok(None) => {
303                warn!(session_id = %session_id, "Task session not found when saving messages");
304            }
305            Err(e) => {
306                warn!(session_id = %session_id, error = %e, "Failed to load task session when saving messages");
307            }
308        }
309    }
310
311    pub async fn get_messages(&self, id: &str) -> Option<Vec<Message>> {
312        let session_id = SessionId::from(id);
313        self.persistence
314            .load(&session_id)
315            .await
316            .ok()
317            .flatten()
318            .map(|s| s.to_api_messages())
319    }
320
321    pub async fn get_session(&self, id: &str) -> Option<Session> {
322        let session_id = SessionId::from(id);
323        self.persistence.load(&session_id).await.ok().flatten()
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentState;
    use crate::session::MemoryPersistence;
    use crate::types::{StopReason, Usage};

    /// Builds a registry over fresh in-memory persistence so tests never share state.
    fn test_registry() -> TaskRegistry {
        TaskRegistry::new(Arc::new(MemoryPersistence::new()))
    }

    // Use valid UUIDs for tests to ensure consistent session IDs
    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000001";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000002";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000003";
    const TASK_4_UUID: &str = "00000000-0000-0000-0000-000000000004";
    const TASK_5_UUID: &str = "00000000-0000-0000-0000-000000000005";
    const TASK_6_UUID: &str = "00000000-0000-0000-0000-000000000006";
    const TASK_7_UUID: &str = "00000000-0000-0000-0000-000000000007";

    /// Minimal successful `AgentResult` with an empty transcript.
    fn mock_result() -> AgentResult {
        AgentResult {
            text: "Test result".to_string(),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: Default::default(),
            session_id: "test-session".to_string(),
            structured_output: None,
            messages: Vec::new(),
            uuid: "test-uuid".to_string(),
        }
    }

    #[tokio::test]
    async fn test_register_and_complete() {
        let registry = test_registry();
        let _cancel_rx = registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Test task".into())
            .await;

        assert_eq!(
            registry.get_status(TASK_1_UUID).await,
            Some(SessionState::Active)
        );

        registry.complete(TASK_1_UUID, mock_result()).await;

        let (status, _, _) = registry.get_result(TASK_1_UUID).await.unwrap();
        assert_eq!(status, SessionState::Completed);
    }

    #[tokio::test]
    async fn test_fail_task() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Failing task".into())
            .await;

        registry
            .fail(TASK_2_UUID, "Something went wrong".into())
            .await;

        let (status, _, error) = registry.get_result(TASK_2_UUID).await.unwrap();
        assert_eq!(status, SessionState::Failed);
        assert_eq!(error, Some("Something went wrong".to_string()));
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let registry = test_registry();
        registry
            .register(
                TASK_3_UUID.into(),
                "Explore".into(),
                "Cancellable task".into(),
            )
            .await;

        assert!(registry.cancel(TASK_3_UUID).await);
        assert_eq!(
            registry.get_status(TASK_3_UUID).await,
            Some(SessionState::Cancelled)
        );

        // The runtime entry is gone, so a second cancel is a no-op.
        assert!(!registry.cancel(TASK_3_UUID).await);
    }

    #[tokio::test]
    async fn test_not_found() {
        let registry = test_registry();
        assert!(registry.get_status("nonexistent").await.is_none());
        assert!(registry.get_result("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_messages() {
        let registry = test_registry();
        registry
            .register(TASK_4_UUID.into(), "Explore".into(), "Message test".into())
            .await;

        let messages = vec![
            Message::user("Hello"),
            Message {
                role: Role::Assistant,
                content: vec![ContentBlock::text("Hi there!")],
            },
        ];

        registry.save_messages(TASK_4_UUID, messages).await;

        let loaded = registry.get_messages(TASK_4_UUID).await.unwrap();
        assert_eq!(loaded.len(), 2);

        // The reported output is the text of the last saved message.
        let (_, output, _) = registry.get_result(TASK_4_UUID).await.unwrap();
        assert_eq!(output.as_deref(), Some("Hi there!"));
    }

    #[tokio::test]
    async fn test_cancel_unregistered_task() {
        let registry = test_registry();
        assert!(!registry.cancel(TASK_5_UUID).await);
        assert_eq!(registry.running_count().await, 0);
    }

    #[tokio::test]
    async fn test_cancel_signals_receiver() {
        let registry = test_registry();
        let cancel_rx = registry
            .register(TASK_5_UUID.into(), "Explore".into(), "Signalled task".into())
            .await;
        assert_eq!(registry.running_count().await, 1);

        assert!(registry.cancel(TASK_5_UUID).await);
        assert_eq!(registry.running_count().await, 0);
        assert!(
            cancel_rx.await.is_ok(),
            "cancel should send on the cancellation channel"
        );
    }

    #[tokio::test]
    async fn test_wait_for_completion_returns_terminal_result() {
        let registry = test_registry();
        let _cancel_rx = registry
            .register(TASK_6_UUID.into(), "Explore".into(), "Waited task".into())
            .await;
        registry.fail(TASK_6_UUID, "boom".into()).await;

        let (state, _, error) = registry
            .wait_for_completion(TASK_6_UUID, Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(state, SessionState::Failed);
        assert_eq!(error.as_deref(), Some("boom"));
    }

    #[tokio::test]
    async fn test_wait_for_completion_times_out_while_active() {
        let registry = test_registry();
        let _cancel_rx = registry
            .register(TASK_7_UUID.into(), "Explore".into(), "Slow task".into())
            .await;

        let (state, _, _) = registry
            .wait_for_completion(TASK_7_UUID, Duration::from_millis(10))
            .await
            .unwrap();
        assert_eq!(state, SessionState::Active);
        assert_eq!(registry.running_count().await, 1);
    }
}