mermaid_cli/providers/tool/
subagent.rs1use std::sync::Arc;
26use std::time::{Duration, Instant};
27
28use async_trait::async_trait;
29use serde_json::Value;
30use tokio::sync::{Semaphore, mpsc};
31use tokio::time::timeout;
32use tokio_util::sync::CancellationToken;
33
34use crate::domain::{
35 Msg, State, ToolDefinition, ToolMetadata, ToolOutcome, ToolRunMetadata, TurnState, update,
36};
37use crate::effect::{EffectRunner, MSG_CHANNEL_CAPACITY};
38use crate::models::MessageRole;
39use crate::providers::ProviderFactory;
40use crate::providers::ctx::{ExecContext, ProgressEvent, SubagentPhase};
41
42use super::ToolExecutor;
43use super::ToolRegistry;
44
/// Maximum subagent nesting depth. `SubagentTool::execute` refuses to spawn a
/// child once the current task's `SUBAGENT_DEPTH` is at or above this value.
pub const MAX_DEPTH: usize = 3;

/// Number of permits in a `SubagentSpawner`'s semaphore: how many subagents
/// sharing one spawner may run at the same time.
pub const MAX_INFLIGHT: usize = 10;

/// Wall-clock budget for one subagent run, in seconds (twenty minutes).
pub const DEFAULT_TIMEOUT_SECS: u64 = 60 * 20;
58
tokio::task_local! {
    /// Subagent nesting depth of the currently running task.
    ///
    /// Unset outside any subagent; `SubagentTool::execute` reads it with
    /// `try_with(..).unwrap_or(0)` and scopes `depth + 1` around the child's
    /// drive future.
    ///
    /// NOTE(review): tokio task-locals are not inherited by tasks started with
    /// `tokio::spawn`; if the child's `EffectRunner` spawns tool executions, the
    /// depth seen there would be 0 — confirm before relying on nested spawning.
    static SUBAGENT_DEPTH: usize;
}
66
67pub struct SubagentSpawner {
69 providers: Arc<ProviderFactory>,
70 inflight: Arc<Semaphore>,
71}
72
73impl SubagentSpawner {
74 pub fn new(providers: Arc<ProviderFactory>) -> Self {
75 Self {
76 providers,
77 inflight: Arc::new(Semaphore::new(MAX_INFLIGHT)),
78 }
79 }
80}
81
82pub struct SubagentTool {
84 spawner: Arc<SubagentSpawner>,
85}
86
87impl SubagentTool {
88 pub fn new(spawner: Arc<SubagentSpawner>) -> Self {
89 Self { spawner }
90 }
91}
92
93#[async_trait]
94impl ToolExecutor for SubagentTool {
95 fn name(&self) -> &'static str {
96 "agent"
97 }
98
99 fn schema(&self) -> ToolDefinition {
100 ToolDefinition {
101 name: "agent".to_string(),
102 description: format!(
103 "Spawn a child agent with its own context and tool access to work on an \
104 independent sub-task. Useful for parallel fan-out (emit multiple `agent` \
105 calls in the same turn to run them concurrently) or for scoping a noisy \
106 sub-task (the child's tool output doesn't clutter the parent's turn). \
107 Depth-capped at {max_depth}; breadth-capped at {max_breadth} concurrent. \
108 Subagents don't get GUI (screenshot/click/…) access because coordinate \
109 metadata can't be shared cleanly.",
110 max_depth = MAX_DEPTH,
111 max_breadth = MAX_INFLIGHT,
112 ),
113 input_schema: serde_json::json!({
114 "type": "object",
115 "properties": {
116 "prompt": {
117 "type": "string",
118 "description": "The task for the subagent. Self-contained; the subagent has no access to the parent's conversation."
119 },
120 "description": {
121 "type": "string",
122 "description": "Short label shown in the parent's status line (e.g. 'list domain files')."
123 }
124 },
125 "required": ["prompt"]
126 }),
127 }
128 }
129
130 async fn execute(&self, args: Value, ctx: ExecContext) -> ToolOutcome {
131 let started = Instant::now();
132
133 let current_depth = SUBAGENT_DEPTH.try_with(|d| *d).unwrap_or(0);
135 if current_depth >= MAX_DEPTH {
136 return ToolOutcome::error(format!("subagent depth limit {} reached", MAX_DEPTH), 0.0);
137 }
138
139 let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
141 Some(s) if !s.trim().is_empty() => s.to_string(),
142 _ => {
143 return ToolOutcome::error("agent requires non-empty `prompt`", 0.0);
144 },
145 };
146 let description = args
147 .get("description")
148 .and_then(|v| v.as_str())
149 .unwrap_or("subagent")
150 .to_string();
151
152 let permit = tokio::select! {
156 biased;
157 _ = ctx.token.cancelled() => return ToolOutcome::cancelled(),
158 p = self.spawner.inflight.clone().acquire_owned() => match p {
159 Ok(permit) => permit,
160 Err(_) => return ToolOutcome::error(
161 "subagent semaphore closed",
162 started.elapsed().as_secs_f64(),
163 ),
164 },
165 };
166
167 let config = (*ctx.config).clone();
177 let cwd = ctx.workdir.clone();
178 let model_id = if ctx.model_id.is_empty() {
179 default_model_id(&config)
180 } else {
181 ctx.model_id.clone()
182 };
183 let child_model_id = model_id.clone();
184 let child_state = State::new(config.clone(), cwd.clone(), model_id);
185
186 let child_tools = build_child_registry(self.spawner.providers.clone());
187
188 let child_token = ctx.token.child_token();
192 let (child_tx, child_rx) = mpsc::channel(MSG_CHANNEL_CAPACITY);
193 let child_runner =
194 EffectRunner::new_child(child_tx, cwd, self.spawner.providers.clone(), child_tools);
195
196 let drive = drive_child(
199 child_state,
200 child_runner,
201 child_rx,
202 ctx.progress.clone(),
203 prompt,
204 description.clone(),
205 child_token,
206 );
207 let depth_scoped = SUBAGENT_DEPTH.scope(current_depth + 1, drive);
208
209 let result = timeout(Duration::from_secs(DEFAULT_TIMEOUT_SECS), depth_scoped).await;
210 drop(permit);
211
212 let elapsed = started.elapsed().as_secs_f64();
213 match result {
214 Ok(Ok(summary)) => ToolOutcome::success(summary, "subagent completed", elapsed)
215 .with_metadata(subagent_metadata(child_model_id)),
216 Ok(Err(DriveError::Cancelled)) => ToolOutcome::cancelled(),
217 Ok(Err(DriveError::Errored(e))) => {
218 ToolOutcome::error(format!("subagent ({}): {}", description, e), elapsed)
219 .with_metadata(subagent_metadata(child_model_id))
220 },
221 Err(_) => ToolOutcome::error(
222 format!(
223 "subagent ({}) exceeded {}s timeout",
224 description, DEFAULT_TIMEOUT_SECS
225 ),
226 elapsed,
227 )
228 .with_metadata(subagent_metadata(child_model_id)),
229 }
230 }
231}
232
233fn subagent_metadata(model_id: String) -> ToolRunMetadata {
234 ToolRunMetadata {
235 detail: ToolMetadata::Subagent { model_id },
236 ..ToolRunMetadata::default()
237 }
238}
239
/// Why `drive_child` stopped without returning a summary.
#[derive(Debug)]
enum DriveError {
    /// The child's cancellation token fired.
    Cancelled,
    /// The child could not produce a usable result; the message is embedded in
    /// the tool error reported to the parent.
    Errored(String),
}
244
245async fn drive_child(
249 mut state: State,
250 mut runner: EffectRunner,
251 mut msg_rx: mpsc::Receiver<Msg>,
252 parent_progress: mpsc::Sender<ProgressEvent>,
253 prompt: String,
254 description: String,
255 token: CancellationToken,
256) -> Result<String, DriveError> {
257 let _ = parent_progress
259 .send(ProgressEvent::SubagentText(format!(
260 "▶ {} — {}",
261 description,
262 prompt.chars().take(80).collect::<String>()
263 )))
264 .await;
265
266 runner.dispatch(crate::domain::Cmd::RefreshInstructions);
268
269 let seed = Msg::SubmitPrompt {
271 text: prompt,
272 attachment_ids: vec![],
273 };
274 let (new_state, cmds) = update(state, seed);
275 state = new_state;
276 for cmd in cmds {
277 runner.dispatch(cmd);
278 }
279
280 loop {
282 if token.is_cancelled() {
283 runner.shutdown().await;
284 return Err(DriveError::Cancelled);
285 }
286 if matches!(state.turn, TurnState::Idle) && state.ui.queued_messages.is_empty() {
287 break;
288 }
289
290 let msg = tokio::select! {
291 biased;
292 _ = token.cancelled() => {
293 runner.shutdown().await;
294 return Err(DriveError::Cancelled);
295 },
296 recv = msg_rx.recv() => match recv {
297 Some(m) => m,
298 None => {
299 break;
301 },
302 },
303 };
304
305 forward_child_event(&msg, &parent_progress, &state).await;
309
310 let (new_state, cmds) = update(state, msg);
311 state = new_state;
312 for cmd in cmds {
313 runner.dispatch(cmd);
314 }
315 if state.should_exit {
316 break;
317 }
318 }
319
320 runner.shutdown().await;
321
322 let summary = state
324 .session
325 .messages()
326 .iter()
327 .rev()
328 .find(|m| m.role == MessageRole::Assistant)
329 .map(|m| m.content.clone())
330 .unwrap_or_default();
331 if summary.trim().is_empty() {
332 return Err(DriveError::Errored(
333 "subagent produced no assistant output".to_string(),
334 ));
335 }
336 Ok(summary)
337}
338
339async fn forward_child_event(msg: &Msg, progress: &mpsc::Sender<ProgressEvent>, state: &State) {
344 match msg {
345 Msg::ToolStarted {
346 turn: _, call_id, ..
347 } => {
348 let tool_name = lookup_tool_name(state, *call_id).unwrap_or_else(|| "tool".to_string());
349 let _ = progress
350 .send(ProgressEvent::SubagentToolCall {
351 child_call_id: *call_id,
352 tool_name,
353 phase: SubagentPhase::Started,
354 })
355 .await;
356 },
357 Msg::ToolFinished {
358 turn: _,
359 call_id,
360 outcome,
361 } => {
362 let tool_name = lookup_tool_name(state, *call_id).unwrap_or_else(|| "tool".to_string());
363 let phase = if outcome.is_success() {
364 SubagentPhase::Finished
365 } else {
366 SubagentPhase::Errored
367 };
368 let _ = progress
369 .send(ProgressEvent::SubagentToolCall {
370 child_call_id: *call_id,
371 tool_name,
372 phase,
373 })
374 .await;
375 },
376 Msg::StreamText { chunk, .. } => {
377 if !chunk.trim().is_empty() {
380 let snippet: String = chunk.chars().take(120).collect();
381 let _ = progress.send(ProgressEvent::SubagentText(snippet)).await;
382 }
383 },
384 _ => {},
385 }
386}
387
388fn lookup_tool_name(state: &State, call_id: crate::domain::ToolCallId) -> Option<String> {
391 match &state.turn {
392 TurnState::ExecutingTools { calls, .. } => calls
393 .iter()
394 .find(|c| c.call_id == call_id)
395 .map(|c| c.source.function.name.clone()),
396 _ => None,
397 }
398}
399
400fn build_child_registry(providers: Arc<ProviderFactory>) -> Arc<ToolRegistry> {
414 use super::{
415 computer_use, exec, filesystem, mcp,
416 web::{WebFetchTool, WebSearchTool},
417 };
418 let mut r = ToolRegistry::new();
419 r.register(Arc::new(filesystem::ReadFileTool));
420 r.register(Arc::new(filesystem::WriteFileTool));
421 r.register(Arc::new(filesystem::EditFileTool));
422 r.register(Arc::new(filesystem::DeleteFileTool));
423 r.register(Arc::new(filesystem::CreateDirectoryTool));
424 r.register(Arc::new(exec::ExecuteCommandTool));
425 r.register(Arc::new(mcp::McpToolProxy));
426 if let Some(key) = crate::utils::resolve_api_key("OLLAMA_API_KEY", None) {
427 r.register(Arc::new(WebSearchTool::new(key.clone())));
428 r.register(Arc::new(WebFetchTool::new(key)));
429 }
430 let _ = computer_use::probe;
434 let _ = providers;
435 Arc::new(r)
436}
437
438fn default_model_id(config: &crate::app::Config) -> String {
443 if !config.default_model.provider.is_empty() && !config.default_model.name.is_empty() {
444 format!(
445 "{}/{}",
446 config.default_model.provider, config.default_model.name
447 )
448 } else {
449 config.default_model.name.clone()
450 }
451}
452
#[cfg(test)]
mod tests {
    use super::*;
    use crate::domain::{ToolCallId, TurnId};
    use crate::providers::ctx::test_exec_context;
    use std::path::PathBuf;

    /// An `agent` tool backed by a default-config provider factory.
    fn make_tool() -> SubagentTool {
        let factory = Arc::new(ProviderFactory::new(crate::app::Config::default()));
        SubagentTool::new(Arc::new(SubagentSpawner::new(factory)))
    }

    #[tokio::test]
    async fn depth_cap_rejects_when_at_max() {
        let tool = make_tool();
        let (ctx, _rx) = test_exec_context(TurnId(1), ToolCallId(1), PathBuf::from("/tmp"));

        // Simulate already running MAX_DEPTH subagents deep.
        let call = tool.execute(serde_json::json!({"prompt": "hi"}), ctx);
        let outcome = SUBAGENT_DEPTH.scope(MAX_DEPTH, call).await;

        let error = outcome.error_message().expect("expected error");
        assert!(
            error.contains("depth limit"),
            "expected depth-limit error, got: {}",
            error
        );
    }

    #[tokio::test]
    async fn empty_prompt_is_rejected() {
        let tool = make_tool();
        let (ctx, _rx) = test_exec_context(TurnId(1), ToolCallId(1), PathBuf::from("/tmp"));

        let outcome = tool.execute(serde_json::json!({"prompt": " "}), ctx).await;

        assert_eq!(outcome.status, crate::domain::ToolStatus::Error);
    }

    #[test]
    fn default_model_id_reads_config_provider_and_name() {
        let mut config = crate::app::Config::default();
        config.default_model.provider = String::from("ollama");
        config.default_model.name = String::from("qwen3-coder:30b");

        assert_eq!(default_model_id(&config), "ollama/qwen3-coder:30b");
    }

    #[test]
    fn default_model_id_returns_bare_name_when_provider_empty() {
        let mut config = crate::app::Config::default();
        config.default_model.name = String::from("just-a-name");

        assert_eq!(default_model_id(&config), "just-a-name");
    }

    #[test]
    fn build_child_registry_excludes_gui_and_self() {
        let registry =
            build_child_registry(Arc::new(ProviderFactory::new(crate::app::Config::default())));

        // GUI tools and the `agent` tool itself must not reach subagents.
        let excluded = [
            "screenshot",
            "click",
            "type_text",
            "press_key",
            "scroll",
            "mouse_move",
            "list_windows",
            "agent",
        ];
        for name in excluded {
            assert!(registry.get(name).is_none(), "`{}` should not be registered", name);
        }

        // Core filesystem/shell tools must be present.
        for name in ["read_file", "execute_command"] {
            assert!(registry.get(name).is_some(), "`{}` should be registered", name);
        }
    }
}