1use crate::agent::AgentEvent;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, VecDeque};
11use tokio::sync::RwLock;
12use tokio_util::sync::CancellationToken;
13
/// Lifecycle state of a tracked subagent task.
///
/// Serialized in `snake_case` (for example `"running"`). The enum is
/// `#[non_exhaustive]`, so matches outside this crate need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
#[non_exhaustive]
pub enum SubagentStatus {
    /// Initial state of every snapshot the tracker creates.
    Running,
    /// Set when a `SubagentEnd` event reports `success == true`.
    Completed,
    /// Set when a `SubagentEnd` event reports `success == false`.
    Failed,
    /// Set by `InMemorySubagentTaskTracker::cancel`; a later `SubagentEnd`
    /// event does not overwrite it.
    Cancelled,
}
23
/// One progress update recorded for a subagent task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubagentProgressEntry {
    /// Milliseconds since the Unix epoch at which the tracker recorded the entry.
    pub timestamp_ms: u64,
    /// Status text copied verbatim from the `SubagentProgress` event.
    pub status: String,
    /// Metadata payload copied from the `SubagentProgress` event.
    pub metadata: serde_json::Value,
}
30
/// Point-in-time view of a single subagent task as seen by the tracker.
///
/// All `*_ms` fields are milliseconds since the Unix epoch.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubagentTaskSnapshot {
    /// Identifier of the task; also the key under which the tracker stores it.
    pub task_id: String,
    /// Session that spawned the task. Empty when the snapshot was created by a
    /// progress or end event that arrived before any start event.
    pub parent_session_id: String,
    /// Session id carried by the event that created (or last started) the task.
    pub child_session_id: String,
    /// Agent name from the start or end event; empty when the snapshot was
    /// created by a progress event.
    pub agent: String,
    /// Description from the start event; empty until a start event is seen.
    pub description: String,
    /// Current lifecycle state.
    pub status: SubagentStatus,
    /// Time at which the tracker first created this snapshot.
    pub started_ms: u64,
    /// Time of the most recent recorded event or cancellation.
    pub updated_ms: u64,
    /// Time at which a `SubagentEnd` event was recorded; omitted from the
    /// serialized form while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub finished_ms: Option<u64>,
    /// Output text from the `SubagentEnd` event; omitted while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output: Option<String>,
    /// Success flag from the `SubagentEnd` event; omitted while `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub success: Option<bool>,
    /// Progress entries in the order the tracker recorded them.
    pub progress: Vec<SubagentProgressEntry>,
}
49
/// In-memory registry of subagent task snapshots and their cancellation tokens.
///
/// Each piece of state sits behind its own async `RwLock`. The `Default`
/// value has no cap on terminal tasks (`max_terminal_tasks` is `None`).
#[derive(Debug, Default)]
pub struct InMemorySubagentTaskTracker {
    /// Snapshots keyed by task id.
    tasks: RwLock<HashMap<String, SubagentTaskSnapshot>>,
    /// Cancellation tokens keyed by task id. An entry stays until it is removed
    /// by `cancel`, `clear_canceller`, eviction, or `replace_snapshots`.
    cancellers: RwLock<HashMap<String, CancellationToken>>,
    /// Ids of tasks that left the `Running` state, oldest first; the front is
    /// evicted first. Only written to when `max_terminal_tasks` is `Some`.
    terminal_order: RwLock<VecDeque<String>>,
    /// Maximum number of terminal tasks to retain; `None` disables eviction.
    max_terminal_tasks: Option<usize>,
}
63
64impl InMemorySubagentTaskTracker {
65 pub fn new() -> Self {
66 Self::default()
67 }
68
69 pub fn with_max_terminal_tasks(max: usize) -> Self {
72 Self {
73 tasks: RwLock::new(HashMap::new()),
74 cancellers: RwLock::new(HashMap::new()),
75 terminal_order: RwLock::new(VecDeque::new()),
76 max_terminal_tasks: Some(max),
77 }
78 }
79
80 async fn mark_terminal_and_evict(&self, task_id: &str) {
85 let cap = match self.max_terminal_tasks {
86 Some(n) => n,
87 None => return,
88 };
89 let mut order = self.terminal_order.write().await;
97 let mut tasks = self.tasks.write().await;
98 let mut cancellers = self.cancellers.write().await;
99 if !order.iter().any(|id| id == task_id) {
100 order.push_back(task_id.to_string());
101 }
102 while order.len() > cap {
103 if let Some(victim) = order.pop_front() {
104 tasks.remove(&victim);
105 cancellers.remove(&victim);
106 }
107 }
108 }
109
110 pub async fn register_canceller(&self, task_id: &str, token: CancellationToken) {
114 self.cancellers
115 .write()
116 .await
117 .insert(task_id.to_string(), token);
118 }
119
120 pub async fn clear_canceller(&self, task_id: &str) {
121 self.cancellers.write().await.remove(task_id);
122 }
123
124 pub async fn cancel(&self, task_id: &str) -> bool {
130 let token = self.cancellers.write().await.remove(task_id);
131 match token {
132 Some(token) => {
133 token.cancel();
134 let now = now_ms();
135 let transitioned = {
136 let mut tasks = self.tasks.write().await;
137 if let Some(entry) = tasks.get_mut(task_id) {
138 if entry.status == SubagentStatus::Running {
139 entry.status = SubagentStatus::Cancelled;
140 entry.updated_ms = now;
141 true
142 } else {
143 false
144 }
145 } else {
146 false
147 }
148 };
149 if transitioned {
150 self.mark_terminal_and_evict(task_id).await;
151 }
152 true
153 }
154 None => false,
155 }
156 }
157
158 pub async fn record_event(&self, event: &AgentEvent) {
160 match event {
161 AgentEvent::SubagentStart {
162 task_id,
163 session_id,
164 parent_session_id,
165 agent,
166 description,
167 } => {
168 let now = now_ms();
169 let mut tasks = self.tasks.write().await;
170 tasks
171 .entry(task_id.clone())
172 .and_modify(|task| {
173 task.parent_session_id = parent_session_id.clone();
176 task.child_session_id = session_id.clone();
177 task.agent = agent.clone();
178 task.description = description.clone();
179 task.updated_ms = now;
180 })
181 .or_insert_with(|| SubagentTaskSnapshot {
182 task_id: task_id.clone(),
183 parent_session_id: parent_session_id.clone(),
184 child_session_id: session_id.clone(),
185 agent: agent.clone(),
186 description: description.clone(),
187 status: SubagentStatus::Running,
188 started_ms: now,
189 updated_ms: now,
190 finished_ms: None,
191 output: None,
192 success: None,
193 progress: Vec::new(),
194 });
195 }
196 AgentEvent::SubagentProgress {
197 task_id,
198 session_id,
199 status,
200 metadata,
201 } => {
202 let now = now_ms();
203 let mut tasks = self.tasks.write().await;
204 let entry = tasks
205 .entry(task_id.clone())
206 .or_insert_with(|| SubagentTaskSnapshot {
207 task_id: task_id.clone(),
208 parent_session_id: String::new(),
209 child_session_id: session_id.clone(),
210 agent: String::new(),
211 description: String::new(),
212 status: SubagentStatus::Running,
213 started_ms: now,
214 updated_ms: now,
215 finished_ms: None,
216 output: None,
217 success: None,
218 progress: Vec::new(),
219 });
220 entry.updated_ms = now;
221 entry.progress.push(SubagentProgressEntry {
222 timestamp_ms: now,
223 status: status.clone(),
224 metadata: metadata.clone(),
225 });
226 }
227 AgentEvent::SubagentEnd {
228 task_id,
229 session_id,
230 agent,
231 output,
232 success,
233 } => {
234 let now = now_ms();
235 let was_running = {
236 let mut tasks = self.tasks.write().await;
237 let entry =
238 tasks
239 .entry(task_id.clone())
240 .or_insert_with(|| SubagentTaskSnapshot {
241 task_id: task_id.clone(),
242 parent_session_id: String::new(),
243 child_session_id: session_id.clone(),
244 agent: agent.clone(),
245 description: String::new(),
246 status: SubagentStatus::Running,
247 started_ms: now,
248 updated_ms: now,
249 finished_ms: None,
250 output: None,
251 success: None,
252 progress: Vec::new(),
253 });
254 let was_running = entry.status == SubagentStatus::Running;
255 if entry.status != SubagentStatus::Cancelled {
259 entry.status = if *success {
260 SubagentStatus::Completed
261 } else {
262 SubagentStatus::Failed
263 };
264 }
265 entry.updated_ms = now;
266 entry.finished_ms = Some(now);
267 entry.output = Some(output.clone());
268 entry.success = Some(*success);
269 was_running
270 };
271 if was_running {
272 self.mark_terminal_and_evict(task_id).await;
273 }
274 }
275 _ => {}
276 }
277 }
278
279 pub async fn get(&self, task_id: &str) -> Option<SubagentTaskSnapshot> {
280 self.tasks.read().await.get(task_id).cloned()
281 }
282
283 pub async fn list(&self) -> Vec<SubagentTaskSnapshot> {
284 let mut tasks = self
285 .tasks
286 .read()
287 .await
288 .values()
289 .cloned()
290 .collect::<Vec<_>>();
291 tasks.sort_by_key(|task| task.started_ms);
292 tasks
293 }
294
295 pub async fn list_pending(&self) -> Vec<SubagentTaskSnapshot> {
296 self.list()
297 .await
298 .into_iter()
299 .filter(|task| task.status == SubagentStatus::Running)
300 .collect()
301 }
302
303 pub async fn list_for_parent(&self, parent_session_id: &str) -> Vec<SubagentTaskSnapshot> {
304 self.list()
305 .await
306 .into_iter()
307 .filter(|task| task.parent_session_id == parent_session_id)
308 .collect()
309 }
310
311 pub async fn replace_snapshots(&self, snapshots: Vec<SubagentTaskSnapshot>) {
323 let mut map = HashMap::with_capacity(snapshots.len());
324 for snap in snapshots {
325 map.insert(snap.task_id.clone(), snap);
326 }
327 *self.tasks.write().await = map;
328 self.cancellers.write().await.clear();
330 }
331}
332
/// Returns the current wall-clock time in milliseconds since the Unix epoch.
///
/// Yields `0` when the system clock reports a time before the epoch.
fn now_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
340
341#[cfg(test)]
342mod tests {
343 use super::*;
344
345 fn start_event(task_id: &str, parent: &str, child: &str) -> AgentEvent {
346 AgentEvent::SubagentStart {
347 task_id: task_id.to_string(),
348 session_id: child.to_string(),
349 parent_session_id: parent.to_string(),
350 agent: "explore".to_string(),
351 description: "find things".to_string(),
352 }
353 }
354
355 fn progress_event(task_id: &str, child: &str, status: &str) -> AgentEvent {
356 AgentEvent::SubagentProgress {
357 task_id: task_id.to_string(),
358 session_id: child.to_string(),
359 status: status.to_string(),
360 metadata: serde_json::json!({}),
361 }
362 }
363
364 fn end_event(task_id: &str, child: &str, success: bool) -> AgentEvent {
365 AgentEvent::SubagentEnd {
366 task_id: task_id.to_string(),
367 session_id: child.to_string(),
368 agent: "explore".to_string(),
369 output: "done".to_string(),
370 success,
371 }
372 }
373
374 #[tokio::test]
375 async fn lifecycle_start_progress_end_transitions_status() {
376 let tracker = InMemorySubagentTaskTracker::new();
377
378 tracker
379 .record_event(&start_event("task-1", "parent", "child"))
380 .await;
381 let snap = tracker.get("task-1").await.unwrap();
382 assert_eq!(snap.status, SubagentStatus::Running);
383 assert_eq!(snap.parent_session_id, "parent");
384 assert_eq!(snap.child_session_id, "child");
385 assert!(snap.finished_ms.is_none());
386
387 tracker
388 .record_event(&progress_event("task-1", "child", "tool_completed: bash"))
389 .await;
390 let snap = tracker.get("task-1").await.unwrap();
391 assert_eq!(snap.status, SubagentStatus::Running);
392 assert_eq!(snap.progress.len(), 1);
393
394 tracker
395 .record_event(&end_event("task-1", "child", true))
396 .await;
397 let snap = tracker.get("task-1").await.unwrap();
398 assert_eq!(snap.status, SubagentStatus::Completed);
399 assert_eq!(snap.success, Some(true));
400 assert_eq!(snap.output.as_deref(), Some("done"));
401 assert!(snap.finished_ms.is_some());
402 }
403
404 #[tokio::test]
405 async fn failed_end_event_marks_status_failed() {
406 let tracker = InMemorySubagentTaskTracker::new();
407 tracker
408 .record_event(&start_event("task-2", "parent", "child"))
409 .await;
410 tracker
411 .record_event(&end_event("task-2", "child", false))
412 .await;
413 let snap = tracker.get("task-2").await.unwrap();
414 assert_eq!(snap.status, SubagentStatus::Failed);
415 assert_eq!(snap.success, Some(false));
416 }
417
418 #[tokio::test]
419 async fn pending_list_excludes_completed_tasks() {
420 let tracker = InMemorySubagentTaskTracker::new();
421 tracker
422 .record_event(&start_event("task-a", "parent", "child-a"))
423 .await;
424 tracker
425 .record_event(&start_event("task-b", "parent", "child-b"))
426 .await;
427 tracker
428 .record_event(&end_event("task-a", "child-a", true))
429 .await;
430
431 let pending = tracker.list_pending().await;
432 assert_eq!(pending.len(), 1);
433 assert_eq!(pending[0].task_id, "task-b");
434 }
435
436 #[tokio::test]
437 async fn list_for_parent_filters_by_session() {
438 let tracker = InMemorySubagentTaskTracker::new();
439 tracker
440 .record_event(&start_event("task-a", "session-1", "child-a"))
441 .await;
442 tracker
443 .record_event(&start_event("task-b", "session-2", "child-b"))
444 .await;
445
446 let mine = tracker.list_for_parent("session-1").await;
447 assert_eq!(mine.len(), 1);
448 assert_eq!(mine[0].task_id, "task-a");
449 }
450
451 #[tokio::test]
452 async fn end_before_start_still_records_terminal_state() {
453 let tracker = InMemorySubagentTaskTracker::new();
454 tracker
455 .record_event(&end_event("task-late", "child", true))
456 .await;
457 let snap = tracker.get("task-late").await.unwrap();
458 assert_eq!(snap.status, SubagentStatus::Completed);
459 }
460
461 #[tokio::test]
462 async fn non_subagent_events_are_ignored() {
463 let tracker = InMemorySubagentTaskTracker::new();
464 tracker
465 .record_event(&AgentEvent::TextDelta {
466 text: "ignore me".to_string(),
467 })
468 .await;
469 assert!(tracker.list().await.is_empty());
470 }
471
472 #[tokio::test]
473 async fn cancel_fires_token_and_marks_snapshot_cancelled() {
474 let tracker = InMemorySubagentTaskTracker::new();
475 tracker
476 .record_event(&start_event("task-c", "parent", "child"))
477 .await;
478
479 let token = CancellationToken::new();
480 tracker.register_canceller("task-c", token.clone()).await;
481 assert!(!token.is_cancelled());
482
483 let fired = tracker.cancel("task-c").await;
484 assert!(fired, "cancel should report success");
485 assert!(token.is_cancelled(), "registered token should be triggered");
486
487 let snap = tracker.get("task-c").await.unwrap();
488 assert_eq!(snap.status, SubagentStatus::Cancelled);
489 }
490
491 #[tokio::test]
492 async fn cancel_returns_false_for_unknown_task() {
493 let tracker = InMemorySubagentTaskTracker::new();
494 assert!(!tracker.cancel("task-does-not-exist").await);
495 }
496
497 #[tokio::test]
498 async fn late_subagent_end_does_not_downgrade_cancelled_status() {
499 let tracker = InMemorySubagentTaskTracker::new();
500 tracker
501 .record_event(&start_event("task-d", "parent", "child"))
502 .await;
503 let token = CancellationToken::new();
504 tracker.register_canceller("task-d", token).await;
505 assert!(tracker.cancel("task-d").await);
506
507 tracker
510 .record_event(&end_event("task-d", "child", false))
511 .await;
512 let snap = tracker.get("task-d").await.unwrap();
513 assert_eq!(snap.status, SubagentStatus::Cancelled);
514 assert!(snap.finished_ms.is_some());
515 assert_eq!(snap.success, Some(false));
516 }
517
518 #[tokio::test]
519 async fn clear_canceller_disarms_future_cancel_calls() {
520 let tracker = InMemorySubagentTaskTracker::new();
521 tracker
522 .record_event(&start_event("task-e", "parent", "child"))
523 .await;
524 let token = CancellationToken::new();
525 tracker.register_canceller("task-e", token.clone()).await;
526 tracker.clear_canceller("task-e").await;
527
528 assert!(!tracker.cancel("task-e").await);
529 assert!(!token.is_cancelled());
530 }
531
532 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
533 async fn concurrent_record_and_cancel_under_terminal_cap_does_not_deadlock() {
534 let tracker = std::sync::Arc::new(InMemorySubagentTaskTracker::with_max_terminal_tasks(8));
538 let mut handles = Vec::new();
539 for i in 0..60 {
540 let t = std::sync::Arc::clone(&tracker);
541 handles.push(tokio::spawn(async move {
542 let task_id = format!("t-{i}");
543 let child = format!("c-{i}");
544 t.record_event(&start_event(&task_id, "parent", &child))
545 .await;
546 if i % 2 == 0 {
547 t.register_canceller(&task_id, CancellationToken::new())
548 .await;
549 let _ = t.cancel(&task_id).await;
550 } else {
551 t.record_event(&end_event(&task_id, &child, true)).await;
552 }
553 }));
554 }
555 for h in handles {
556 h.await.unwrap();
557 }
558 let terminal = tracker
560 .list()
561 .await
562 .into_iter()
563 .filter(|t| t.status != SubagentStatus::Running)
564 .count();
565 assert!(
566 terminal <= 8,
567 "terminal cap must hold under load, got {terminal}"
568 );
569 }
570
571 #[tokio::test]
572 async fn max_terminal_tasks_evicts_oldest_completed_only() {
573 let tracker = InMemorySubagentTaskTracker::with_max_terminal_tasks(2);
574
575 for i in 0..3 {
577 let task_id = format!("done-{i}");
578 tracker
579 .record_event(&start_event(&task_id, "parent", "child"))
580 .await;
581 tracker
582 .record_event(&end_event(&task_id, "child", true))
583 .await;
584 }
585
586 let list = tracker.list().await;
588 let ids: Vec<&str> = list.iter().map(|t| t.task_id.as_str()).collect();
589 assert_eq!(ids.len(), 2);
590 assert!(ids.contains(&"done-1"));
591 assert!(ids.contains(&"done-2"));
592 assert!(
593 !ids.contains(&"done-0"),
594 "oldest terminal entry must be evicted"
595 );
596 }
597
598 #[tokio::test]
599 async fn max_terminal_tasks_never_evicts_running_tasks() {
600 let tracker = InMemorySubagentTaskTracker::with_max_terminal_tasks(1);
601
602 tracker
606 .record_event(&start_event("running", "parent", "child"))
607 .await;
608 for i in 0..3 {
609 let task_id = format!("done-{i}");
610 tracker
611 .record_event(&start_event(&task_id, "parent", "child"))
612 .await;
613 tracker
614 .record_event(&end_event(&task_id, "child", true))
615 .await;
616 }
617
618 let list = tracker.list().await;
619 let ids: Vec<&str> = list.iter().map(|t| t.task_id.as_str()).collect();
620 assert!(
621 ids.contains(&"running"),
622 "running task must never be evicted"
623 );
624 assert!(ids.contains(&"done-2"));
626 assert!(!ids.contains(&"done-0"));
627 assert!(!ids.contains(&"done-1"));
628 assert_eq!(list.len(), 2);
629 }
630
631 #[tokio::test]
632 async fn cancel_path_also_participates_in_terminal_cap() {
633 let tracker = InMemorySubagentTaskTracker::with_max_terminal_tasks(1);
634
635 for i in 0..2 {
637 let task_id = format!("c-{i}");
638 tracker
639 .record_event(&start_event(&task_id, "parent", "child"))
640 .await;
641 tracker
642 .register_canceller(&task_id, CancellationToken::new())
643 .await;
644 assert!(tracker.cancel(&task_id).await);
645 }
646
647 let list = tracker.list().await;
648 assert_eq!(list.len(), 1);
649 assert_eq!(list[0].task_id, "c-1");
650 }
651}