//! hh_cli/core/agent/subagent_manager.rs
//!
//! Sub-agent lifecycle management: accepting/resuming tasks, bounding
//! concurrency and nesting depth, and emitting lifecycle events to the
//! parent session store.
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};

use serde::{Deserialize, Serialize};
use tokio::sync::{Mutex, Semaphore};
use tokio::time::{Duration, sleep};
use uuid::Uuid;

use crate::session::types::{SubAgentFailureReason, SubAgentLifecycleStatus};
use crate::session::{SessionEvent, SessionStore, event_id};
12
13const MAX_EVENT_CONTENT_BYTES: usize = 16 * 1024;
14const MAX_PARENT_SUMMARY_BYTES: usize = 2048;
15
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
17#[serde(rename_all = "snake_case")]
18pub enum SubagentStatus {
19    Pending,
20    Running,
21    Completed,
22    Failed,
23    Cancelled,
24}
25
26impl SubagentStatus {
27    pub fn is_terminal(&self) -> bool {
28        matches!(
29            self,
30            SubagentStatus::Completed | SubagentStatus::Failed | SubagentStatus::Cancelled
31        )
32    }
33
34    pub fn label(&self) -> &'static str {
35        match self {
36            SubagentStatus::Pending => "queued",
37            SubagentStatus::Running => "running",
38            SubagentStatus::Completed => "done",
39            SubagentStatus::Failed => "error",
40            SubagentStatus::Cancelled => "cancelled",
41        }
42    }
43
44    fn as_lifecycle_status(&self) -> SubAgentLifecycleStatus {
45        match self {
46            Self::Pending => SubAgentLifecycleStatus::Pending,
47            Self::Running => SubAgentLifecycleStatus::Running,
48            Self::Completed => SubAgentLifecycleStatus::Completed,
49            Self::Failed => SubAgentLifecycleStatus::Failed,
50            Self::Cancelled => SubAgentLifecycleStatus::Cancelled,
51        }
52    }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct SubagentNode {
57    pub task_id: String,
58    pub name: String,
59    pub parent_task_id: Option<String>,
60    pub parent_session_id: String,
61    pub agent_name: String,
62    pub prompt: String,
63    pub depth: usize,
64    pub session_id: String,
65    pub status: SubagentStatus,
66    pub started_at: u64,
67    pub updated_at: u64,
68    pub summary: Option<String>,
69    pub error: Option<String>,
70    pub failure_reason: Option<SubAgentFailureReason>,
71    pub progress_seq: u64,
72}
73
74#[derive(Debug, Clone)]
75pub struct SubagentRequest {
76    pub name: String,
77    pub description: String,
78    pub prompt: String,
79    pub subagent_type: String,
80    pub resume_task_id: Option<String>,
81    pub parent_session_id: String,
82    pub parent_task_id: Option<String>,
83    pub depth: usize,
84}
85
86#[derive(Debug, Clone, Serialize)]
87pub struct SubagentAcceptance {
88    pub task_id: String,
89    pub status: String,
90    pub message: String,
91}
92
93#[derive(Debug, Clone)]
94pub struct SubagentExecutionRequest {
95    pub task_id: String,
96    pub name: String,
97    pub parent_session_id: String,
98    pub parent_task_id: Option<String>,
99    pub description: String,
100    pub prompt: String,
101    pub subagent_type: String,
102    pub child_session_id: String,
103    pub depth: usize,
104}
105
106#[derive(Debug, Clone)]
107pub struct SubagentExecutionResult {
108    pub status: SubagentStatus,
109    pub summary: String,
110    pub error: Option<String>,
111    pub failure_reason: Option<SubAgentFailureReason>,
112}
113
114type SubagentExecutionFuture = Pin<Box<dyn Future<Output = SubagentExecutionResult> + Send>>;
115pub type SubagentExecutor =
116    Arc<dyn Fn(SubagentExecutionRequest) -> SubagentExecutionFuture + Send + Sync>;
117
118#[derive(Clone)]
119pub struct SubagentManager {
120    inner: Arc<Mutex<SubagentManagerState>>,
121    queue: Arc<Semaphore>,
122    max_depth: usize,
123    executor: SubagentExecutor,
124}
125
126#[derive(Default)]
127struct SubagentManagerState {
128    by_task_id: HashMap<String, SubagentNode>,
129    children_by_parent: HashMap<String, Vec<String>>,
130}
131
132impl SubagentManager {
133    pub fn new(max_parallel: usize, max_depth: usize, executor: SubagentExecutor) -> Self {
134        Self {
135            inner: Arc::new(Mutex::new(SubagentManagerState::default())),
136            queue: Arc::new(Semaphore::new(max_parallel.max(1))),
137            max_depth,
138            executor,
139        }
140    }
141
142    pub async fn start_or_resume(
143        &self,
144        request: SubagentRequest,
145        parent_session: SessionStore,
146    ) -> anyhow::Result<SubagentAcceptance> {
147        let child_depth = request.depth.saturating_add(1);
148        if child_depth > self.max_depth {
149            anyhow::bail!(
150                "sub-agent depth {} exceeds configured limit {}",
151                child_depth,
152                self.max_depth
153            );
154        }
155
156        let now = now_secs();
157        let mut state = self.inner.lock().await;
158
159        let (task_id, child_session_id, should_spawn) =
160            if let Some(task_id) = request.resume_task_id.as_ref() {
161                let Some(existing) = state.by_task_id.get_mut(task_id) else {
162                    anyhow::bail!("unknown task_id '{}'", task_id);
163                };
164                if existing.parent_session_id != request.parent_session_id {
165                    anyhow::bail!(
166                        "task_id '{}' is not owned by current parent session",
167                        task_id
168                    );
169                }
170
171                if matches!(
172                    existing.status,
173                    SubagentStatus::Pending | SubagentStatus::Running
174                ) {
175                    return Ok(SubagentAcceptance {
176                        task_id: task_id.clone(),
177                        status: existing.status.label().to_string(),
178                        message: "sub-agent is already active".to_string(),
179                    });
180                }
181
182                existing.status = SubagentStatus::Pending;
183                existing.updated_at = now;
184                existing.name = request.name.clone();
185                existing.summary = None;
186                existing.error = None;
187                existing.failure_reason = None;
188
189                (task_id.clone(), existing.session_id.clone(), true)
190            } else {
191                let task_id = Uuid::now_v7().to_string();
192                let child_session_id = Uuid::new_v4().to_string();
193                let node = SubagentNode {
194                    task_id: task_id.clone(),
195                    name: request.name.clone(),
196                    parent_task_id: request.parent_task_id.clone(),
197                    parent_session_id: request.parent_session_id.clone(),
198                    agent_name: request.subagent_type.clone(),
199                    prompt: request.prompt.clone(),
200                    depth: child_depth,
201                    session_id: child_session_id.clone(),
202                    status: SubagentStatus::Pending,
203                    started_at: now,
204                    updated_at: now,
205                    summary: None,
206                    error: None,
207                    failure_reason: None,
208                    progress_seq: 0,
209                };
210                state.by_task_id.insert(task_id.clone(), node);
211                state
212                    .children_by_parent
213                    .entry(request.parent_session_id.clone())
214                    .or_default()
215                    .push(task_id.clone());
216                (task_id, child_session_id, true)
217            };
218
219        drop(state);
220
221        parent_session.append(&SessionEvent::SubAgentStart {
222            id: event_id(),
223            task_id: Some(task_id.clone()),
224            name: Some(request.name.clone()),
225            parent_id: request.parent_task_id.clone(),
226            parent_session_id: Some(request.parent_session_id.clone()),
227            agent_name: Some(request.subagent_type.clone()),
228            session_id: Some(child_session_id.clone()),
229            status: SubAgentLifecycleStatus::Pending,
230            created_at: now,
231            updated_at: now,
232            prompt: bounded_text(&request.prompt, MAX_EVENT_CONTENT_BYTES),
233            depth: child_depth,
234        })?;
235
236        if should_spawn {
237            let execution = SubagentExecutionRequest {
238                task_id: task_id.clone(),
239                name: request.name,
240                parent_session_id: request.parent_session_id,
241                parent_task_id: request.parent_task_id,
242                description: request.description,
243                prompt: request.prompt,
244                subagent_type: request.subagent_type,
245                child_session_id,
246                depth: child_depth,
247            };
248            self.spawn_task(parent_session, execution);
249        }
250
251        Ok(SubagentAcceptance {
252            task_id,
253            status: SubagentStatus::Pending.label().to_string(),
254            message: "sub-agent accepted".to_string(),
255        })
256    }
257
258    pub async fn list_for_parent(&self, parent_session_id: &str) -> Vec<SubagentNode> {
259        let state = self.inner.lock().await;
260        let mut nodes = state
261            .children_by_parent
262            .get(parent_session_id)
263            .into_iter()
264            .flat_map(|task_ids| task_ids.iter())
265            .filter_map(|task_id| state.by_task_id.get(task_id))
266            .cloned()
267            .collect::<Vec<_>>();
268        nodes.sort_by(|a, b| {
269            a.started_at
270                .cmp(&b.started_at)
271                .then(a.task_id.cmp(&b.task_id))
272        });
273        nodes
274    }
275
276    pub async fn wait_for_terminal(
277        &self,
278        parent_session_id: &str,
279        task_id: &str,
280    ) -> anyhow::Result<SubagentNode> {
281        loop {
282            let maybe_node = {
283                let state = self.inner.lock().await;
284                let Some(node) = state.by_task_id.get(task_id) else {
285                    anyhow::bail!("unknown task_id '{task_id}'");
286                };
287                if node.parent_session_id != parent_session_id {
288                    anyhow::bail!(
289                        "task_id '{}' is not owned by current parent session",
290                        task_id
291                    );
292                }
293                if node.status.is_terminal() {
294                    Some(node.clone())
295                } else {
296                    None
297                }
298            };
299
300            if let Some(node) = maybe_node {
301                return Ok(node);
302            }
303
304            sleep(Duration::from_millis(50)).await;
305        }
306    }
307
308    pub async fn wait_for_all(&self, parent_session_id: &str) {
309        loop {
310            let pending = {
311                let state = self.inner.lock().await;
312                state
313                    .children_by_parent
314                    .get(parent_session_id)
315                    .into_iter()
316                    .flat_map(|task_ids| task_ids.iter())
317                    .filter_map(|task_id| state.by_task_id.get(task_id))
318                    .any(|node| !node.status.is_terminal())
319            };
320            if !pending {
321                return;
322            }
323            sleep(Duration::from_millis(50)).await;
324        }
325    }
326
327    fn spawn_task(&self, parent_session: SessionStore, execution: SubagentExecutionRequest) {
328        let queue = Arc::clone(&self.queue);
329        let manager = self.clone();
330        let executor = Arc::clone(&self.executor);
331        tokio::spawn(async move {
332            let task_id = execution.task_id.clone();
333            let permit = match queue.acquire_owned().await {
334                Ok(permit) => permit,
335                Err(_) => {
336                    manager
337                        .finish_task(
338                            &parent_session,
339                            &task_id,
340                            SubagentExecutionResult {
341                                status: SubagentStatus::Failed,
342                                summary: "sub-agent queue is unavailable".to_string(),
343                                error: Some("queue unavailable".to_string()),
344                                failure_reason: Some(SubAgentFailureReason::RuntimeError),
345                            },
346                        )
347                        .await;
348                    return;
349                }
350            };
351
352            manager.mark_running(&parent_session, &task_id).await;
353            let result = executor(execution).await;
354            manager.finish_task(&parent_session, &task_id, result).await;
355            drop(permit);
356        });
357    }
358
359    async fn mark_running(&self, parent_session: &SessionStore, task_id: &str) {
360        let mut state = self.inner.lock().await;
361        let Some(node) = state.by_task_id.get_mut(task_id) else {
362            return;
363        };
364        if node.status.is_terminal() {
365            return;
366        }
367        node.status = SubagentStatus::Running;
368        node.updated_at = now_secs();
369        node.progress_seq = node.progress_seq.saturating_add(1);
370        let seq = node.progress_seq;
371        let _ = parent_session.append(&SessionEvent::SubAgentProgress {
372            id: event_id(),
373            task_id: Some(task_id.to_string()),
374            seq,
375            content: "sub-agent execution started".to_string(),
376        });
377    }
378
379    async fn finish_task(
380        &self,
381        parent_session: &SessionStore,
382        task_id: &str,
383        mut result: SubagentExecutionResult,
384    ) {
385        let mut state = self.inner.lock().await;
386        let Some(node) = state.by_task_id.get_mut(task_id) else {
387            return;
388        };
389
390        if node.status.is_terminal() {
391            return;
392        }
393
394        if !result.status.is_terminal() {
395            result.status = if result.error.is_some() {
396                SubagentStatus::Failed
397            } else {
398                SubagentStatus::Completed
399            };
400        }
401
402        node.status = result.status.clone();
403        node.updated_at = now_secs();
404        node.summary = Some(bounded_text(&result.summary, MAX_PARENT_SUMMARY_BYTES));
405        node.error = result
406            .error
407            .as_ref()
408            .map(|text| bounded_text(text, MAX_EVENT_CONTENT_BYTES));
409        node.failure_reason = result.failure_reason.clone();
410
411        let output = node.error.clone().unwrap_or_else(|| result.summary.clone());
412        let _ = parent_session.append(&SessionEvent::SubAgentResult {
413            id: event_id(),
414            task_id: Some(task_id.to_string()),
415            status: node.status.as_lifecycle_status(),
416            summary: node.summary.clone(),
417            failure_reason: node.failure_reason.clone(),
418            is_error: matches!(
419                node.status,
420                SubagentStatus::Failed | SubagentStatus::Cancelled
421            ),
422            output: bounded_text(&output, MAX_EVENT_CONTENT_BYTES),
423        });
424    }
425}
426
427fn now_secs() -> u64 {
428    SystemTime::now()
429        .duration_since(UNIX_EPOCH)
430        .map_or(0, |duration| duration.as_secs())
431}
432
433fn bounded_text(input: &str, max_bytes: usize) -> String {
434    if input.len() <= max_bytes {
435        return input.to_string();
436    }
437
438    let mut out = String::with_capacity(max_bytes + 32);
439    for ch in input.chars() {
440        if out.len() + ch.len_utf8() > max_bytes {
441            break;
442        }
443        out.push(ch);
444    }
445    out.push_str("\n...[truncated]");
446    out
447}