claude_agent/agent/task_registry.rs

1//! Task registry for managing background agent tasks with Session-based persistence.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::time::Duration;
6
7use tokio::sync::{RwLock, oneshot};
8use tokio::task::JoinHandle;
9
10use crate::session::{
11    Persistence, Session, SessionConfig, SessionId, SessionMessage, SessionState, SessionType,
12};
13use crate::types::{ContentBlock, Message, Role};
14
15use super::AgentResult;
16
17struct TaskRuntime {
18    handle: Option<JoinHandle<()>>,
19    cancel_tx: Option<oneshot::Sender<()>>,
20}
21
22#[derive(Clone)]
23pub struct TaskRegistry {
24    runtime: Arc<RwLock<HashMap<String, TaskRuntime>>>,
25    persistence: Arc<dyn Persistence>,
26    parent_session_id: Option<SessionId>,
27    default_ttl: Option<Duration>,
28}
29
30impl TaskRegistry {
31    pub fn new(persistence: Arc<dyn Persistence>) -> Self {
32        Self {
33            runtime: Arc::new(RwLock::new(HashMap::new())),
34            persistence,
35            parent_session_id: None,
36            default_ttl: Some(Duration::from_secs(3600)),
37        }
38    }
39
40    pub fn with_parent_session(mut self, parent_id: SessionId) -> Self {
41        self.parent_session_id = Some(parent_id);
42        self
43    }
44
45    pub fn with_ttl(mut self, ttl: Duration) -> Self {
46        self.default_ttl = Some(ttl);
47        self
48    }
49
50    pub async fn register(
51        &self,
52        id: String,
53        agent_type: String,
54        description: String,
55    ) -> oneshot::Receiver<()> {
56        let (cancel_tx, cancel_rx) = oneshot::channel();
57
58        let config = SessionConfig {
59            ttl_secs: self.default_ttl.map(|d| d.as_secs()),
60            ..Default::default()
61        };
62
63        let session = match self.parent_session_id {
64            Some(parent_id) => Session::new_subagent(parent_id, &agent_type, &description, config),
65            None => {
66                let mut s = Session::new(config);
67                s.session_type = SessionType::Subagent {
68                    agent_type,
69                    description,
70                };
71                s
72            }
73        };
74
75        let session_id = SessionId::from(id.as_str());
76        let mut session = session;
77        session.id = session_id;
78        session.state = SessionState::Active;
79
80        let _ = self.persistence.save(&session).await;
81
82        let mut runtime = self.runtime.write().await;
83        runtime.insert(
84            id,
85            TaskRuntime {
86                handle: None,
87                cancel_tx: Some(cancel_tx),
88            },
89        );
90
91        cancel_rx
92    }
93
94    pub async fn set_handle(&self, id: &str, handle: JoinHandle<()>) {
95        let mut runtime = self.runtime.write().await;
96        if let Some(rt) = runtime.get_mut(id) {
97            rt.handle = Some(handle);
98        }
99    }
100
101    pub async fn complete(&self, id: &str, result: AgentResult) {
102        let session_id = SessionId::from(id);
103
104        if let Ok(Some(mut session)) = self.persistence.load(&session_id).await {
105            session.state = SessionState::Completed;
106
107            for msg in &result.messages {
108                let content: Vec<ContentBlock> = msg.content.clone();
109                let session_msg = match msg.role {
110                    Role::User => SessionMessage::user(content),
111                    Role::Assistant => SessionMessage::assistant(content),
112                };
113                session.add_message(session_msg);
114            }
115
116            let _ = self.persistence.save(&session).await;
117        }
118
119        let mut runtime = self.runtime.write().await;
120        runtime.remove(id);
121    }
122
123    pub async fn fail(&self, id: &str, error: String) {
124        let session_id = SessionId::from(id);
125
126        if let Ok(Some(mut session)) = self.persistence.load(&session_id).await {
127            session.state = SessionState::Failed;
128            session.error = Some(error);
129            let _ = self.persistence.save(&session).await;
130        }
131
132        let mut runtime = self.runtime.write().await;
133        runtime.remove(id);
134    }
135
136    pub async fn cancel(&self, id: &str) -> bool {
137        let session_id = SessionId::from(id);
138
139        let cancelled = {
140            let mut runtime = self.runtime.write().await;
141            if let Some(rt) = runtime.get_mut(id) {
142                if let Some(tx) = rt.cancel_tx.take() {
143                    let _ = tx.send(());
144                }
145                if let Some(handle) = rt.handle.take() {
146                    handle.abort();
147                }
148                runtime.remove(id);
149                true
150            } else {
151                false
152            }
153        };
154
155        if cancelled && let Ok(Some(mut session)) = self.persistence.load(&session_id).await {
156            session.state = SessionState::Cancelled;
157            let _ = self.persistence.save(&session).await;
158        }
159
160        cancelled
161    }
162
163    pub async fn get_status(&self, id: &str) -> Option<SessionState> {
164        let session_id = SessionId::from(id);
165        self.persistence
166            .load(&session_id)
167            .await
168            .ok()
169            .flatten()
170            .map(|s| s.state)
171    }
172
173    pub async fn get_result(
174        &self,
175        id: &str,
176    ) -> Option<(SessionState, Option<String>, Option<String>)> {
177        let session_id = SessionId::from(id);
178        self.persistence
179            .load(&session_id)
180            .await
181            .ok()
182            .flatten()
183            .map(|s| {
184                let text = s.messages.last().and_then(|m| {
185                    m.content.iter().find_map(|c| match c {
186                        ContentBlock::Text { text, .. } => Some(text.clone()),
187                        _ => None,
188                    })
189                });
190                (s.state, text, s.error)
191            })
192    }
193
194    pub async fn wait_for_completion(
195        &self,
196        id: &str,
197        timeout: Duration,
198    ) -> Option<(SessionState, Option<String>, Option<String>)> {
199        let deadline = std::time::Instant::now() + timeout;
200        let poll_interval = Duration::from_millis(100);
201
202        loop {
203            if let Some((state, output, error)) = self.get_result(id).await {
204                if state != SessionState::Active && state != SessionState::WaitingForTools {
205                    return Some((state, output, error));
206                }
207            } else {
208                return None;
209            }
210
211            if std::time::Instant::now() >= deadline {
212                return self.get_result(id).await;
213            }
214
215            tokio::time::sleep(poll_interval).await;
216        }
217    }
218
219    pub async fn list_running(&self) -> Vec<(String, String, Duration)> {
220        let runtime = self.runtime.read().await;
221        let mut result = Vec::new();
222
223        for id in runtime.keys() {
224            let session_id = SessionId::from(id.as_str());
225            if let Ok(Some(session)) = self.persistence.load(&session_id).await
226                && session.is_running()
227            {
228                let description = match &session.session_type {
229                    SessionType::Subagent { description, .. } => description.clone(),
230                    _ => String::new(),
231                };
232                let elapsed = (chrono::Utc::now() - session.created_at)
233                    .to_std()
234                    .unwrap_or_default();
235                result.push((id.clone(), description, elapsed));
236            }
237        }
238
239        result
240    }
241
242    pub async fn cleanup_completed(&self) -> usize {
243        self.persistence.cleanup_expired().await.unwrap_or(0)
244    }
245
246    pub async fn running_count(&self) -> usize {
247        self.runtime.read().await.len()
248    }
249
250    pub async fn save_messages(&self, id: &str, messages: Vec<Message>) {
251        let session_id = SessionId::from(id);
252
253        if let Ok(Some(mut session)) = self.persistence.load(&session_id).await {
254            for msg in messages {
255                let content: Vec<ContentBlock> = msg.content;
256                let session_msg = match msg.role {
257                    Role::User => SessionMessage::user(content),
258                    Role::Assistant => SessionMessage::assistant(content),
259                };
260                session.add_message(session_msg);
261            }
262            let _ = self.persistence.save(&session).await;
263        }
264    }
265
266    pub async fn get_messages(&self, id: &str) -> Option<Vec<Message>> {
267        let session_id = SessionId::from(id);
268        self.persistence
269            .load(&session_id)
270            .await
271            .ok()
272            .flatten()
273            .map(|s| s.to_api_messages())
274    }
275
276    pub async fn get_session(&self, id: &str) -> Option<Session> {
277        let session_id = SessionId::from(id);
278        self.persistence.load(&session_id).await.ok().flatten()
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::AgentState;
    use crate::session::MemoryPersistence;
    use crate::types::{StopReason, Usage};

    /// Builds a registry over fresh in-memory persistence.
    fn test_registry() -> TaskRegistry {
        TaskRegistry::new(Arc::new(MemoryPersistence::new()))
    }

    // Use valid UUIDs for tests to ensure consistent session IDs
    const TASK_1_UUID: &str = "00000000-0000-0000-0000-000000000001";
    const TASK_2_UUID: &str = "00000000-0000-0000-0000-000000000002";
    const TASK_3_UUID: &str = "00000000-0000-0000-0000-000000000003";
    const TASK_4_UUID: &str = "00000000-0000-0000-0000-000000000004";

    /// A minimal successful `AgentResult` with no messages.
    fn mock_result() -> AgentResult {
        AgentResult {
            text: "Test result".to_string(),
            usage: Usage::default(),
            tool_calls: 0,
            iterations: 1,
            stop_reason: StopReason::EndTurn,
            state: AgentState::Completed,
            metrics: Default::default(),
            session_id: "test-session".to_string(),
            structured_output: None,
            messages: Vec::new(),
            uuid: "test-uuid".to_string(),
        }
    }

    #[tokio::test]
    async fn test_register_and_complete() {
        let registry = test_registry();
        let _cancel_rx = registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Test task".into())
            .await;

        assert_eq!(
            registry.get_status(TASK_1_UUID).await,
            Some(SessionState::Active)
        );

        registry.complete(TASK_1_UUID, mock_result()).await;

        let (status, _, _) = registry.get_result(TASK_1_UUID).await.unwrap();
        assert_eq!(status, SessionState::Completed);
    }

    #[tokio::test]
    async fn test_fail_task() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Failing task".into())
            .await;

        registry
            .fail(TASK_2_UUID, "Something went wrong".into())
            .await;

        let (status, _, error) = registry.get_result(TASK_2_UUID).await.unwrap();
        assert_eq!(status, SessionState::Failed);
        assert_eq!(error, Some("Something went wrong".to_string()));
    }

    #[tokio::test]
    async fn test_cancel_task() {
        let registry = test_registry();
        registry
            .register(
                TASK_3_UUID.into(),
                "Explore".into(),
                "Cancellable task".into(),
            )
            .await;

        assert!(registry.cancel(TASK_3_UUID).await);
        assert_eq!(
            registry.get_status(TASK_3_UUID).await,
            Some(SessionState::Cancelled)
        );

        assert!(!registry.cancel(TASK_3_UUID).await);
    }

    #[tokio::test]
    async fn test_cancel_unknown_task() {
        let registry = test_registry();
        assert!(!registry.cancel(TASK_1_UUID).await);
        assert_eq!(registry.get_status(TASK_1_UUID).await, None);
    }

    #[tokio::test]
    async fn test_cancel_signals_receiver() {
        let registry = test_registry();
        let cancel_rx = registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Signalled task".into())
            .await;
        let handle = tokio::spawn(std::future::pending::<()>());
        registry.set_handle(TASK_1_UUID, handle).await;

        assert!(registry.cancel(TASK_1_UUID).await);
        // `cancel` sends on the channel, so the receiver yields `Ok(())` rather than
        // the "sender dropped" error.
        assert!(cancel_rx.await.is_ok());
    }

    #[tokio::test]
    async fn test_running_count_and_list() {
        let registry = test_registry();
        registry
            .register(TASK_1_UUID.into(), "Explore".into(), "Listed task".into())
            .await;

        assert_eq!(registry.running_count().await, 1);
        let running = registry.list_running().await;
        assert_eq!(running.len(), 1);
        assert_eq!(running[0].0, TASK_1_UUID);
        assert_eq!(running[0].1, "Listed task");

        registry.complete(TASK_1_UUID, mock_result()).await;
        assert_eq!(registry.running_count().await, 0);
        assert!(registry.list_running().await.is_empty());
    }

    #[tokio::test]
    async fn test_wait_for_completion() {
        let registry = test_registry();
        registry
            .register(TASK_2_UUID.into(), "Explore".into(), "Waited task".into())
            .await;

        // Still running: a zero timeout returns the current snapshot without blocking.
        let (state, _, _) = registry
            .wait_for_completion(TASK_2_UUID, Duration::ZERO)
            .await
            .unwrap();
        assert_eq!(state, SessionState::Active);

        registry.fail(TASK_2_UUID, "boom".into()).await;
        let (state, _, error) = registry
            .wait_for_completion(TASK_2_UUID, Duration::from_secs(1))
            .await
            .unwrap();
        assert_eq!(state, SessionState::Failed);
        assert_eq!(error.as_deref(), Some("boom"));

        assert!(
            registry
                .wait_for_completion("nonexistent", Duration::ZERO)
                .await
                .is_none()
        );
    }

    #[tokio::test]
    async fn test_not_found() {
        let registry = test_registry();
        assert!(registry.get_status("nonexistent").await.is_none());
        assert!(registry.get_result("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_messages() {
        let registry = test_registry();
        registry
            .register(TASK_4_UUID.into(), "Explore".into(), "Message test".into())
            .await;

        let messages = vec![
            Message::user("Hello"),
            Message {
                role: Role::Assistant,
                content: vec![ContentBlock::text("Hi there!")],
            },
        ];

        registry.save_messages(TASK_4_UUID, messages).await;

        let loaded = registry.get_messages(TASK_4_UUID).await.unwrap();
        assert_eq!(loaded.len(), 2);
    }
}