Skip to main content

agent_engine/runtime/
subagent.rs

1use std::collections::HashMap;
2use std::sync::{Arc, RwLock};
3use tokio::sync::{mpsc, oneshot};
4use serde_json::Value;
5
6// ── SubagentResult ───────────────────────────────────────────────────────────────
7
/// Final outcome of a subagent run, delivered through the handle's result channel.
#[derive(Debug)]
pub struct SubagentResult {
    /// Text produced by the subagent (presumably its final answer — confirm with the producer).
    pub text: String,
    /// Model identifier reported for the run.
    pub model: String,
    /// Token usage counters; presumably totals across all turns — TODO confirm.
    pub input_tokens: u64,
    pub output_tokens: u64,
    /// Cache usage counters (reads and creations); units presumably tokens — TODO confirm.
    pub cache_read: u64,
    pub cache_creation: u64,
    /// TTL split of `cache_creation` across the subagent's turns.
    /// `None` only if no turn ever reported a split; otherwise the sum.
    pub cache_creation_5m: Option<u64>,
    /// Counterpart of `cache_creation_5m` for the other TTL bucket
    /// (presumably one hour, going by the name — confirm).
    pub cache_creation_1h: Option<u64>,
    /// NOTE(review): presumably the number of tool invocations — inferred from the name only.
    pub tool_count: u32,
}
22
23// ── SubagentStatus ───────────────────────────────────────────────────────────────
24
/// Lifecycle state of a subagent as recorded in its shared state.
#[derive(Debug, Clone, PartialEq)]
pub enum SubagentStatus {
    /// Initial state set by `SubagentState::new`; the only state treated as "not finished".
    Running,
    Completed,
    TimedOut,
    /// Carries a message string; presumably the failure reason — confirm with the producer.
    Failed(String),
}
32
33// ── SubagentState ────────────────────────────────────────────────────────────────
34
/// All mutable state shared between the subagent thread and its handle.
/// Collapsed behind a single RwLock so a status poll takes exactly one lock.
#[derive(Debug)]
pub struct SubagentState {
    /// Current lifecycle status; starts as `Running`.
    pub status: SubagentStatus,
    /// Output accumulated so far; exposed via `SubagentHandle::partial_output`.
    pub partial_text: String,
    /// Tool log entries; exposed via `SubagentHandle::tool_log`.
    /// The entry format is not visible in this file.
    pub tool_log: Vec<String>,
    /// JSON values describing the conversation; exposed via
    /// `SubagentHandle::conversation_state` (documented there as being for resume).
    pub conversation_state: Vec<Value>,
}
44
45impl SubagentState {
46    pub fn new() -> Self {
47        Self {
48            status: SubagentStatus::Running,
49            partial_text: String::new(),
50            tool_log: Vec::new(),
51            conversation_state: Vec::new(),
52        }
53    }
54}
55
56impl Default for SubagentState {
57    fn default() -> Self { Self::new() }
58}
59
60// ── SubagentHandle ───────────────────────────────────────────────────────────────
61
/// Caller-side handle to one subagent: identifying metadata, the shared state
/// it is polled through, and the channels used to steer, cancel and collect it.
pub struct SubagentHandle {
    /// Identifier; also the key under which `SubagentRegistry::register` stores the handle.
    pub id: String,
    pub agent_name: String,
    /// NOTE(review): presumably a shortened form of the task text for display — confirm.
    pub task_preview: String,
    pub model: String,
    pub system_prompt: String,
    /// Set to `Instant::now()` when the handle is constructed; basis for `elapsed_secs`.
    pub started_at: std::time::Instant,
    /// Timeout budget in seconds. Nothing in this file enforces it — presumably
    /// consumed by whatever drives the subagent; confirm.
    pub timeout_secs: u64,

    // Shared state updated by the subagent thread — one lock for everything.
    state: Arc<RwLock<SubagentState>>,

    // Channels
    /// Sender for steering messages; `None` means steering is unavailable.
    steer_tx: Option<mpsc::UnboundedSender<String>>,
    /// One-shot shutdown signal; taken (left as `None`) by the first `cancel` call.
    shutdown_tx: Option<oneshot::Sender<()>>,
    /// OS thread running the subagent. Stored for graceful shutdown (join).
    thread_handle: Option<std::thread::JoinHandle<()>>,

    // Final result
    /// Receiver for the final result; consumed by `collect`.
    result_rx: Option<oneshot::Receiver<SubagentResult>>,
}
84
85impl std::fmt::Debug for SubagentHandle {
86    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
87        f.debug_struct("SubagentHandle")
88            .field("id", &self.id)
89            .field("agent_name", &self.agent_name)
90            .field("model", &self.model)
91            .finish_non_exhaustive()
92    }
93}
94
95impl SubagentHandle {
96    /// Construct a new handle. The state Arc is shared with the spawned subagent thread.
97    #[allow(clippy::too_many_arguments)]
98    pub fn new(
99        id: String,
100        agent_name: String,
101        task_preview: String,
102        model: String,
103        system_prompt: String,
104        timeout_secs: u64,
105        state: Arc<RwLock<SubagentState>>,
106        steer_tx: Option<mpsc::UnboundedSender<String>>,
107        shutdown_tx: Option<oneshot::Sender<()>>,
108        result_rx: Option<oneshot::Receiver<SubagentResult>>,
109    ) -> Self {
110        Self {
111            id,
112            agent_name,
113            task_preview,
114            model,
115            system_prompt,
116            started_at: std::time::Instant::now(),
117            timeout_secs,
118            state,
119            steer_tx,
120            shutdown_tx,
121            thread_handle: None,
122            result_rx,
123        }
124    }
125
126    /// Current status snapshot.
127    pub fn status(&self) -> SubagentStatus {
128        self.state.read().unwrap().status.clone()
129    }
130
131    /// Partial output accumulated so far.
132    pub fn partial_output(&self) -> String {
133        self.state.read().unwrap().partial_text.clone()
134    }
135
136    /// Snapshot of the tool log.
137    pub fn tool_log(&self) -> Vec<String> {
138        self.state.read().unwrap().tool_log.clone()
139    }
140
141    /// Snapshot of conversation state (for resume).
142    pub fn conversation_state(&self) -> Vec<Value> {
143        self.state.read().unwrap().conversation_state.clone()
144    }
145
146    /// Seconds since this handle was created.
147    pub fn elapsed_secs(&self) -> f64 {
148        self.started_at.elapsed().as_secs_f64()
149    }
150
151    /// Send a steering message into the running subagent.
152    pub fn steer(&self, message: &str) -> Result<(), String> {
153        match &self.steer_tx {
154            Some(tx) => tx
155                .send(message.to_string())
156                .map_err(|e| format!("steer channel closed: {e}")),
157            None => Err("no steer channel on this handle".to_string()),
158        }
159    }
160
161    /// Signal the subagent to shut down.
162    /// Store the OS thread handle for graceful shutdown.
163    pub fn set_thread_handle(&mut self, handle: std::thread::JoinHandle<()>) {
164        self.thread_handle = Some(handle);
165    }
166
167    pub fn cancel(&mut self) {
168        if let Some(tx) = self.shutdown_tx.take() {
169            let _ = tx.send(());
170        }
171    }
172
173    /// True if the subagent is no longer running.
174    pub fn is_finished(&self) -> bool {
175        !matches!(self.status(), SubagentStatus::Running)
176    }
177
178    /// Consume the handle and wait for the final result.
179    pub async fn collect(mut self) -> Result<SubagentResult, String> {
180        match self.result_rx.take() {
181            Some(rx) => rx.await.map_err(|_| "subagent result channel dropped".to_string()),
182            None => Err("no result receiver — already collected or never set".to_string()),
183        }
184    }
185}
186
187// ── SubagentRegistry ─────────────────────────────────────────────────────────────
188
/// Collection of subagent handles keyed by handle id.
#[derive(Debug)]
pub struct SubagentRegistry {
    /// Handles indexed by their `id` field (see `register`).
    handles: HashMap<String, SubagentHandle>,
}
193
194impl SubagentRegistry {
195    pub fn new() -> Self {
196        Self {
197            handles: HashMap::new(),
198        }
199    }
200
201    /// Register a handle and return its id.
202    pub fn register(&mut self, handle: SubagentHandle) -> String {
203        let id = handle.id.clone();
204        self.handles.insert(id.clone(), handle);
205        id
206    }
207
208    pub fn get(&self, id: &str) -> Option<&SubagentHandle> {
209        self.handles.get(id)
210    }
211
212    pub fn get_mut(&mut self, id: &str) -> Option<&mut SubagentHandle> {
213        self.handles.get_mut(id)
214    }
215
216    pub fn remove(&mut self, id: &str) -> Option<SubagentHandle> {
217        self.handles.remove(id)
218    }
219
220    /// Returns (id, agent_name, status) for every tracked handle.
221    pub fn list_active(&self) -> Vec<(String, String, SubagentStatus)> {
222        self.handles
223            .values()
224            .map(|h| (h.id.clone(), h.agent_name.clone(), h.status()))
225            .collect()
226    }
227
228    /// Drop handles that are no longer running.
229    /// Iterate over all handles mutably (for bulk operations like cancel-all).
230    pub fn iter_mut_handles(&mut self) -> impl Iterator<Item = &mut SubagentHandle> {
231        self.handles.values_mut()
232    }
233
234    pub fn cleanup_finished(&mut self) {
235        let finished_ids: Vec<String> = self.handles.iter()
236            .filter(|(_, h)| h.is_finished())
237            .map(|(id, _)| id.clone())
238            .collect();
239        for id in finished_ids {
240            if let Some(mut handle) = self.handles.remove(&id) {
241                // Join the thread to avoid zombies/resource leaks
242                if let Some(th) = handle.thread_handle.take() {
243                    let _ = th.join();
244                }
245            }
246        }
247    }
248}
249
250impl Default for SubagentRegistry {
251    fn default() -> Self {
252        Self::new()
253    }
254}
255
256impl SubagentStatus {
257    pub fn as_str(&self) -> &str {
258        match self {
259            SubagentStatus::Running => "running",
260            SubagentStatus::Completed => "completed",
261            SubagentStatus::TimedOut => "timed_out",
262            SubagentStatus::Failed(_) => "failed",
263        }
264    }
265}
266
267
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::{mpsc, oneshot};

    /// A handle bundled with the receiving ends of its steer and shutdown
    /// channels, so tests can keep the channels open and observe what the
    /// handle sends through them.
    struct TestHandle {
        handle: SubagentHandle,
        steer_rx: mpsc::UnboundedReceiver<String>,
        shutdown_rx: oneshot::Receiver<()>,
    }

    /// Build a handle with live steer/shutdown channels. The result sender is
    /// dropped on return, so the result channel is already closed.
    fn make_test_handle(id: &str) -> TestHandle {
        let state = Arc::new(RwLock::new(SubagentState::new()));
        let (steer_tx, steer_rx) = mpsc::unbounded_channel();
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let (_result_tx, result_rx) = oneshot::channel();
        TestHandle {
            handle: SubagentHandle::new(
                id.to_string(),
                "test-agent".to_string(),
                "test task".to_string(),
                "claude-sonnet-4-6".to_string(),
                "You are a test agent.".to_string(),
                300,
                state,
                Some(steer_tx),
                Some(shutdown_tx),
                Some(result_rx),
            ),
            steer_rx,
            shutdown_rx,
        }
    }

    /// Build a bare handle. The steer/shutdown receivers are dropped here, so
    /// both channels are closed from the handle's point of view.
    fn make_handle(id: &str) -> SubagentHandle {
        make_test_handle(id).handle
    }

    #[test]
    fn handle_initial_status_is_running() {
        let h = make_handle("sa_1");
        assert_eq!(h.status(), SubagentStatus::Running);
        assert!(!h.is_finished());
    }

    #[test]
    fn handle_partial_output_empty_initially() {
        let h = make_handle("sa_1");
        assert_eq!(h.partial_output(), "");
        assert!(h.tool_log().is_empty());
        assert!(h.conversation_state().is_empty());
    }

    #[test]
    fn handle_status_reflects_state_change() {
        let h = make_handle("sa_1");
        {
            let mut s = h.state.write().unwrap();
            s.status = SubagentStatus::Completed;
            s.partial_text = "done!".to_string();
        }
        assert_eq!(h.status(), SubagentStatus::Completed);
        assert!(h.is_finished());
        assert_eq!(h.partial_output(), "done!");
    }

    #[test]
    fn handle_steer_sends_message() {
        let mut th = make_test_handle("sa_1");
        assert!(th.handle.steer("redirect").is_ok());
        // The message must actually arrive on the receiving end, exactly once.
        let delivered = th
            .steer_rx
            .try_recv()
            .expect("steered message should be queued");
        assert_eq!(delivered, "redirect");
        assert!(th.steer_rx.try_recv().is_err());
    }

    #[test]
    fn handle_steer_fails_without_channel() {
        let state = Arc::new(RwLock::new(SubagentState::new()));
        let (_, result_rx) = oneshot::channel();
        let h = SubagentHandle::new(
            "sa_1".into(), "test".into(), "task".into(),
            "model".into(), "prompt".into(), 300, state, None, None, Some(result_rx),
        );
        assert!(h.steer("msg").is_err());
    }

    #[test]
    fn handle_steer_fails_when_receiver_dropped() {
        let h = make_handle("sa_1"); // receiver already dropped
        let err = h.steer("msg").unwrap_err();
        assert!(err.starts_with("steer channel closed"));
    }

    #[test]
    fn handle_cancel_consumes_shutdown() {
        let mut th = make_test_handle("sa_1");
        assert!(th.shutdown_rx.try_recv().is_err()); // nothing sent yet
        th.handle.cancel(); // first call sends
        assert!(th.shutdown_rx.try_recv().is_ok());
        th.handle.cancel(); // second call is no-op (already taken)
    }

    #[test]
    fn handle_cancel_ignores_dropped_receiver() {
        let mut h = make_handle("sa_1"); // shutdown receiver already dropped
        h.cancel(); // send fails internally; must not panic
        h.cancel();
    }

    #[test]
    fn handle_elapsed_increases() {
        let h = make_handle("sa_1");
        std::thread::sleep(std::time::Duration::from_millis(10));
        assert!(h.elapsed_secs() > 0.0);
    }

    #[test]
    fn registry_register_and_get() {
        let mut reg = SubagentRegistry::new();
        let h = make_handle("sa_1");
        assert_eq!(reg.register(h), "sa_1");
        assert!(reg.get("sa_1").is_some());
        assert!(reg.get("sa_99").is_none());
    }

    #[test]
    fn registry_remove() {
        let mut reg = SubagentRegistry::new();
        reg.register(make_handle("sa_1"));
        assert!(reg.remove("sa_1").is_some());
        assert!(reg.get("sa_1").is_none());
    }

    #[test]
    fn registry_list_active() {
        let mut reg = SubagentRegistry::new();
        reg.register(make_handle("sa_1"));
        reg.register(make_handle("sa_2"));
        let active = reg.list_active();
        assert_eq!(active.len(), 2);
        assert!(active
            .iter()
            .all(|(_, name, status)| name == "test-agent" && *status == SubagentStatus::Running));
    }

    #[test]
    fn registry_cleanup_finished() {
        let mut reg = SubagentRegistry::new();
        let h = make_handle("sa_1");
        {
            let mut s = h.state.write().unwrap();
            s.status = SubagentStatus::Completed;
        }
        reg.register(h);
        reg.register(make_handle("sa_2")); // still running
        reg.cleanup_finished();
        assert!(reg.get("sa_1").is_none()); // completed, cleaned up
        assert!(reg.get("sa_2").is_some()); // still running, kept
    }

    #[test]
    fn subagent_state_new_defaults() {
        let s = SubagentState::new();
        assert_eq!(s.status, SubagentStatus::Running);
        assert!(s.partial_text.is_empty());
        assert!(s.tool_log.is_empty());
        assert!(s.conversation_state.is_empty());
    }

    #[test]
    fn subagent_status_as_str() {
        assert_eq!(SubagentStatus::Running.as_str(), "running");
        assert_eq!(SubagentStatus::Completed.as_str(), "completed");
        assert_eq!(SubagentStatus::TimedOut.as_str(), "timed_out");
        assert_eq!(SubagentStatus::Failed("oops".into()).as_str(), "failed");
    }
}