1use crate::session::types::{SubAgentFailureReason, SubAgentLifecycleStatus};
2use crate::session::{SessionEvent, SessionStore, event_id};
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::future::Future;
6use std::pin::Pin;
7use std::sync::Arc;
8use std::time::{SystemTime, UNIX_EPOCH};
9use tokio::sync::{Mutex, Semaphore};
10use tokio::time::{Duration, sleep};
11use uuid::Uuid;
12
/// Upper bound, in bytes, for free-form text embedded in session events
/// (prompts, error output) before truncation via `bounded_text`.
const MAX_EVENT_CONTENT_BYTES: usize = 16 * 1024;
/// Upper bound, in bytes, for the summary stored on a node and surfaced to
/// the parent session.
const MAX_PARENT_SUMMARY_BYTES: usize = 2048;
15
/// Lifecycle state of a sub-agent task as tracked by [`SubagentManager`].
///
/// Serialized in `snake_case` (e.g. `"pending"`, `"cancelled"`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum SubagentStatus {
    /// Accepted and queued; waiting for a semaphore permit.
    Pending,
    /// Currently executing.
    Running,
    /// Finished successfully.
    Completed,
    /// Finished with an error.
    Failed,
    /// Stopped before completion.
    Cancelled,
}
25
26impl SubagentStatus {
27 pub fn is_terminal(&self) -> bool {
28 matches!(
29 self,
30 SubagentStatus::Completed | SubagentStatus::Failed | SubagentStatus::Cancelled
31 )
32 }
33
34 pub fn label(&self) -> &'static str {
35 match self {
36 SubagentStatus::Pending => "queued",
37 SubagentStatus::Running => "running",
38 SubagentStatus::Completed => "done",
39 SubagentStatus::Failed => "error",
40 SubagentStatus::Cancelled => "cancelled",
41 }
42 }
43
44 fn as_lifecycle_status(&self) -> SubAgentLifecycleStatus {
45 match self {
46 Self::Pending => SubAgentLifecycleStatus::Pending,
47 Self::Running => SubAgentLifecycleStatus::Running,
48 Self::Completed => SubAgentLifecycleStatus::Completed,
49 Self::Failed => SubAgentLifecycleStatus::Failed,
50 Self::Cancelled => SubAgentLifecycleStatus::Cancelled,
51 }
52 }
53}
54
/// In-memory record of one sub-agent task, keyed by `task_id` in
/// `SubagentManagerState`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubagentNode {
    /// Unique id for this task (UUIDv7, so roughly time-ordered).
    pub task_id: String,
    /// Display name supplied by the caller.
    pub name: String,
    /// Task id of the spawning sub-agent, if this is a nested task.
    pub parent_task_id: Option<String>,
    /// Session that owns this task; ownership is checked on resume/wait.
    pub parent_session_id: String,
    /// Which agent type executes the task (from `subagent_type`).
    pub agent_name: String,
    /// Prompt the sub-agent was started with.
    pub prompt: String,
    /// Nesting depth (parent depth + 1); bounded by the manager's `max_depth`.
    pub depth: usize,
    /// Session id allocated for the child agent's own session.
    pub session_id: String,
    /// Current lifecycle status.
    pub status: SubagentStatus,
    /// Unix seconds when the task was first created.
    pub started_at: u64,
    /// Unix seconds of the most recent status change.
    pub updated_at: u64,
    /// Bounded result summary, set when the task reaches a terminal status.
    pub summary: Option<String>,
    /// Bounded error text, set on failure.
    pub error: Option<String>,
    /// Structured failure classification, set on failure.
    pub failure_reason: Option<SubAgentFailureReason>,
    /// Monotonic counter for progress events emitted on this task.
    pub progress_seq: u64,
}
73
/// Caller-facing request to start a new sub-agent or resume a terminal one.
#[derive(Debug, Clone)]
pub struct SubagentRequest {
    /// Display name for the task.
    pub name: String,
    /// Short description of the work, forwarded to the executor.
    pub description: String,
    /// Prompt the sub-agent should run with.
    pub prompt: String,
    /// Agent type selecting which sub-agent implementation to run.
    pub subagent_type: String,
    /// When set, resume this existing task instead of creating a new one.
    pub resume_task_id: Option<String>,
    /// Session issuing the request; recorded as the task's owner.
    pub parent_session_id: String,
    /// Task id of the requesting sub-agent, if the caller is itself nested.
    pub parent_task_id: Option<String>,
    /// Caller's depth; the child runs at `depth + 1`.
    pub depth: usize,
}
85
/// Immediate receipt returned by `start_or_resume`; execution continues
/// asynchronously after this is returned.
#[derive(Debug, Clone, Serialize)]
pub struct SubagentAcceptance {
    /// Id to use for later resume/wait calls.
    pub task_id: String,
    /// Human-readable status label (see `SubagentStatus::label`).
    pub status: String,
    /// Short acceptance message.
    pub message: String,
}
92
/// Fully-resolved work order handed to the [`SubagentExecutor`] once a task
/// has been accepted and a child session id allocated.
#[derive(Debug, Clone)]
pub struct SubagentExecutionRequest {
    /// Id of the task being executed.
    pub task_id: String,
    /// Display name for the task.
    pub name: String,
    /// Owning parent session.
    pub parent_session_id: String,
    /// Task id of the spawning sub-agent, if nested.
    pub parent_task_id: Option<String>,
    /// Short description of the work.
    pub description: String,
    /// Prompt to run with.
    pub prompt: String,
    /// Agent type to execute.
    pub subagent_type: String,
    /// Session id allocated for the child agent.
    pub child_session_id: String,
    /// Nesting depth of the child task.
    pub depth: usize,
}
105
/// Outcome reported by the executor; normalized to a terminal status by
/// `finish_task` if the executor returns a non-terminal one.
#[derive(Debug, Clone)]
pub struct SubagentExecutionResult {
    /// Final status claimed by the executor.
    pub status: SubagentStatus,
    /// Result summary (truncated before storage).
    pub summary: String,
    /// Error text, if the run failed.
    pub error: Option<String>,
    /// Structured failure classification, if the run failed.
    pub failure_reason: Option<SubAgentFailureReason>,
}
113
/// Boxed future produced by a [`SubagentExecutor`] invocation.
type SubagentExecutionFuture = Pin<Box<dyn Future<Output = SubagentExecutionResult> + Send>>;
/// Pluggable execution backend: given a resolved request, returns a future
/// that performs the sub-agent run and yields its result.
pub type SubagentExecutor =
    Arc<dyn Fn(SubagentExecutionRequest) -> SubagentExecutionFuture + Send + Sync>;
117
/// Coordinates sub-agent tasks: accepts/resumes requests, bounds concurrency
/// and nesting depth, and records lifecycle events on the parent session.
///
/// Cheap to clone; all clones share the same state and semaphore.
#[derive(Clone)]
pub struct SubagentManager {
    /// Shared task registry, guarded by an async mutex.
    inner: Arc<Mutex<SubagentManagerState>>,
    /// Semaphore limiting how many sub-agents run concurrently.
    queue: Arc<Semaphore>,
    /// Maximum allowed nesting depth for child tasks.
    max_depth: usize,
    /// Backend that actually performs each sub-agent run.
    executor: SubagentExecutor,
}
125
/// Mutable registry behind [`SubagentManager`]'s mutex.
#[derive(Default)]
struct SubagentManagerState {
    /// Every known task, keyed by task id.
    by_task_id: HashMap<String, SubagentNode>,
    /// Task ids grouped by owning parent session id, in creation order.
    children_by_parent: HashMap<String, Vec<String>>,
}
131
132impl SubagentManager {
133 pub fn new(max_parallel: usize, max_depth: usize, executor: SubagentExecutor) -> Self {
134 Self {
135 inner: Arc::new(Mutex::new(SubagentManagerState::default())),
136 queue: Arc::new(Semaphore::new(max_parallel.max(1))),
137 max_depth,
138 executor,
139 }
140 }
141
142 pub async fn start_or_resume(
143 &self,
144 request: SubagentRequest,
145 parent_session: SessionStore,
146 ) -> anyhow::Result<SubagentAcceptance> {
147 let child_depth = request.depth.saturating_add(1);
148 if child_depth > self.max_depth {
149 anyhow::bail!(
150 "sub-agent depth {} exceeds configured limit {}",
151 child_depth,
152 self.max_depth
153 );
154 }
155
156 let now = now_secs();
157 let mut state = self.inner.lock().await;
158
159 let (task_id, child_session_id, should_spawn) =
160 if let Some(task_id) = request.resume_task_id.as_ref() {
161 let Some(existing) = state.by_task_id.get_mut(task_id) else {
162 anyhow::bail!("unknown task_id '{}'", task_id);
163 };
164 if existing.parent_session_id != request.parent_session_id {
165 anyhow::bail!(
166 "task_id '{}' is not owned by current parent session",
167 task_id
168 );
169 }
170
171 if matches!(
172 existing.status,
173 SubagentStatus::Pending | SubagentStatus::Running
174 ) {
175 return Ok(SubagentAcceptance {
176 task_id: task_id.clone(),
177 status: existing.status.label().to_string(),
178 message: "sub-agent is already active".to_string(),
179 });
180 }
181
182 existing.status = SubagentStatus::Pending;
183 existing.updated_at = now;
184 existing.name = request.name.clone();
185 existing.summary = None;
186 existing.error = None;
187 existing.failure_reason = None;
188
189 (task_id.clone(), existing.session_id.clone(), true)
190 } else {
191 let task_id = Uuid::now_v7().to_string();
192 let child_session_id = Uuid::new_v4().to_string();
193 let node = SubagentNode {
194 task_id: task_id.clone(),
195 name: request.name.clone(),
196 parent_task_id: request.parent_task_id.clone(),
197 parent_session_id: request.parent_session_id.clone(),
198 agent_name: request.subagent_type.clone(),
199 prompt: request.prompt.clone(),
200 depth: child_depth,
201 session_id: child_session_id.clone(),
202 status: SubagentStatus::Pending,
203 started_at: now,
204 updated_at: now,
205 summary: None,
206 error: None,
207 failure_reason: None,
208 progress_seq: 0,
209 };
210 state.by_task_id.insert(task_id.clone(), node);
211 state
212 .children_by_parent
213 .entry(request.parent_session_id.clone())
214 .or_default()
215 .push(task_id.clone());
216 (task_id, child_session_id, true)
217 };
218
219 drop(state);
220
221 parent_session.append(&SessionEvent::SubAgentStart {
222 id: event_id(),
223 task_id: Some(task_id.clone()),
224 name: Some(request.name.clone()),
225 parent_id: request.parent_task_id.clone(),
226 parent_session_id: Some(request.parent_session_id.clone()),
227 agent_name: Some(request.subagent_type.clone()),
228 session_id: Some(child_session_id.clone()),
229 status: SubAgentLifecycleStatus::Pending,
230 created_at: now,
231 updated_at: now,
232 prompt: bounded_text(&request.prompt, MAX_EVENT_CONTENT_BYTES),
233 depth: child_depth,
234 })?;
235
236 if should_spawn {
237 let execution = SubagentExecutionRequest {
238 task_id: task_id.clone(),
239 name: request.name,
240 parent_session_id: request.parent_session_id,
241 parent_task_id: request.parent_task_id,
242 description: request.description,
243 prompt: request.prompt,
244 subagent_type: request.subagent_type,
245 child_session_id,
246 depth: child_depth,
247 };
248 self.spawn_task(parent_session, execution);
249 }
250
251 Ok(SubagentAcceptance {
252 task_id,
253 status: SubagentStatus::Pending.label().to_string(),
254 message: "sub-agent accepted".to_string(),
255 })
256 }
257
258 pub async fn list_for_parent(&self, parent_session_id: &str) -> Vec<SubagentNode> {
259 let state = self.inner.lock().await;
260 let mut nodes = state
261 .children_by_parent
262 .get(parent_session_id)
263 .into_iter()
264 .flat_map(|task_ids| task_ids.iter())
265 .filter_map(|task_id| state.by_task_id.get(task_id))
266 .cloned()
267 .collect::<Vec<_>>();
268 nodes.sort_by(|a, b| {
269 a.started_at
270 .cmp(&b.started_at)
271 .then(a.task_id.cmp(&b.task_id))
272 });
273 nodes
274 }
275
276 pub async fn wait_for_terminal(
277 &self,
278 parent_session_id: &str,
279 task_id: &str,
280 ) -> anyhow::Result<SubagentNode> {
281 loop {
282 let maybe_node = {
283 let state = self.inner.lock().await;
284 let Some(node) = state.by_task_id.get(task_id) else {
285 anyhow::bail!("unknown task_id '{task_id}'");
286 };
287 if node.parent_session_id != parent_session_id {
288 anyhow::bail!(
289 "task_id '{}' is not owned by current parent session",
290 task_id
291 );
292 }
293 if node.status.is_terminal() {
294 Some(node.clone())
295 } else {
296 None
297 }
298 };
299
300 if let Some(node) = maybe_node {
301 return Ok(node);
302 }
303
304 sleep(Duration::from_millis(50)).await;
305 }
306 }
307
308 pub async fn wait_for_all(&self, parent_session_id: &str) {
309 loop {
310 let pending = {
311 let state = self.inner.lock().await;
312 state
313 .children_by_parent
314 .get(parent_session_id)
315 .into_iter()
316 .flat_map(|task_ids| task_ids.iter())
317 .filter_map(|task_id| state.by_task_id.get(task_id))
318 .any(|node| !node.status.is_terminal())
319 };
320 if !pending {
321 return;
322 }
323 sleep(Duration::from_millis(50)).await;
324 }
325 }
326
327 fn spawn_task(&self, parent_session: SessionStore, execution: SubagentExecutionRequest) {
328 let queue = Arc::clone(&self.queue);
329 let manager = self.clone();
330 let executor = Arc::clone(&self.executor);
331 tokio::spawn(async move {
332 let task_id = execution.task_id.clone();
333 let permit = match queue.acquire_owned().await {
334 Ok(permit) => permit,
335 Err(_) => {
336 manager
337 .finish_task(
338 &parent_session,
339 &task_id,
340 SubagentExecutionResult {
341 status: SubagentStatus::Failed,
342 summary: "sub-agent queue is unavailable".to_string(),
343 error: Some("queue unavailable".to_string()),
344 failure_reason: Some(SubAgentFailureReason::RuntimeError),
345 },
346 )
347 .await;
348 return;
349 }
350 };
351
352 manager.mark_running(&parent_session, &task_id).await;
353 let result = executor(execution).await;
354 manager.finish_task(&parent_session, &task_id, result).await;
355 drop(permit);
356 });
357 }
358
359 async fn mark_running(&self, parent_session: &SessionStore, task_id: &str) {
360 let mut state = self.inner.lock().await;
361 let Some(node) = state.by_task_id.get_mut(task_id) else {
362 return;
363 };
364 if node.status.is_terminal() {
365 return;
366 }
367 node.status = SubagentStatus::Running;
368 node.updated_at = now_secs();
369 node.progress_seq = node.progress_seq.saturating_add(1);
370 let seq = node.progress_seq;
371 let _ = parent_session.append(&SessionEvent::SubAgentProgress {
372 id: event_id(),
373 task_id: Some(task_id.to_string()),
374 seq,
375 content: "sub-agent execution started".to_string(),
376 });
377 }
378
379 async fn finish_task(
380 &self,
381 parent_session: &SessionStore,
382 task_id: &str,
383 mut result: SubagentExecutionResult,
384 ) {
385 let mut state = self.inner.lock().await;
386 let Some(node) = state.by_task_id.get_mut(task_id) else {
387 return;
388 };
389
390 if node.status.is_terminal() {
391 return;
392 }
393
394 if !result.status.is_terminal() {
395 result.status = if result.error.is_some() {
396 SubagentStatus::Failed
397 } else {
398 SubagentStatus::Completed
399 };
400 }
401
402 node.status = result.status.clone();
403 node.updated_at = now_secs();
404 node.summary = Some(bounded_text(&result.summary, MAX_PARENT_SUMMARY_BYTES));
405 node.error = result
406 .error
407 .as_ref()
408 .map(|text| bounded_text(text, MAX_EVENT_CONTENT_BYTES));
409 node.failure_reason = result.failure_reason.clone();
410
411 let output = node.error.clone().unwrap_or_else(|| result.summary.clone());
412 let _ = parent_session.append(&SessionEvent::SubAgentResult {
413 id: event_id(),
414 task_id: Some(task_id.to_string()),
415 status: node.status.as_lifecycle_status(),
416 summary: node.summary.clone(),
417 failure_reason: node.failure_reason.clone(),
418 is_error: matches!(
419 node.status,
420 SubagentStatus::Failed | SubagentStatus::Cancelled
421 ),
422 output: bounded_text(&output, MAX_EVENT_CONTENT_BYTES),
423 });
424 }
425}
426
/// Current Unix time in whole seconds; returns 0 if the system clock reads
/// before the epoch.
fn now_secs() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
432
/// Returns `input` unchanged when it fits in `max_bytes`; otherwise returns
/// the longest prefix that fits on a UTF-8 char boundary with a
/// `"\n...[truncated]"` marker appended (the marker may exceed the budget).
fn bounded_text(input: &str, max_bytes: usize) -> String {
    if input.len() <= max_bytes {
        return input.to_string();
    }

    // Walk back from max_bytes to the nearest char boundary instead of
    // scanning every char from the start; at most 3 steps for valid UTF-8.
    let mut cut = max_bytes;
    while !input.is_char_boundary(cut) {
        cut -= 1;
    }

    let mut out = String::with_capacity(cut + 32);
    out.push_str(&input[..cut]);
    out.push_str("\n...[truncated]");
    out
}