use crate::agent::coder::CoderAgent;
use crate::agent::Agent;
use crate::config::AgentConfig;
use crate::llm::Message;
use crate::orchestrator::executor::TaskExecutor;
use crate::orchestrator::graph::{TaskGraph, TaskNode};
use crate::tools::{Tool, ToolContext, ToolResult};
use anyhow::Result;
use async_trait::async_trait;
use std::sync::Arc;
/// Tool that lets an agent delegate work to sub-agents: either a single
/// ad-hoc task handed to one `CoderAgent`, or a batch of tasks executed
/// with dependency ordering by the `TaskExecutor`. Stateless unit struct;
/// all configuration arrives via the tool-call arguments and `ToolContext`.
pub struct SpawnTaskTool;
#[async_trait]
impl Tool for SpawnTaskTool {
    /// Tool identifier as exposed to the model.
    fn name(&self) -> &str {
        "spawn_task"
    }

    /// Usage summary shown to the model (kept verbatim — it is part of the prompt).
    fn description(&self) -> &str {
        "Delegate work to one or more sub-agent(s). \
         In single-task mode, provide 'description' to run a CoderAgent on that task \
         and get back its final message. \
         In multi-task mode, provide a 'tasks' array with id/description/depends_on \
         fields to run tasks in parallel with dependency ordering and get back a \
         summary report. Nesting depth is capped at 3 levels."
    }

    /// JSON schema for the tool arguments. `description` and `tasks` are
    /// mutually-exclusive modes; `parallel` / `max_concurrent` only affect
    /// multi-task mode.
    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "description": {
                    "type": "string",
                    "description": "Single-task mode: the task description to give to a CoderAgent."
                },
                "tasks": {
                    "type": "array",
                    "description": "Multi-task mode: list of tasks with dependency ordering.",
                    "items": {
                        "type": "object",
                        "properties": {
                            "id": {
                                "type": "string",
                                "description": "Unique identifier for this task (e.g. 'write-tests')."
                            },
                            "description": {
                                "type": "string",
                                "description": "What this sub-agent should do."
                            },
                            "depends_on": {
                                "type": "array",
                                "items": { "type": "string" },
                                "description": "IDs of tasks that must complete before this one starts.",
                                "default": []
                            }
                        },
                        "required": ["id", "description"]
                    }
                },
                "parallel": {
                    "type": "boolean",
                    "description": "Run independent tasks in parallel (default: true). \
                    Set false to force serial execution.",
                    "default": true
                },
                "max_concurrent": {
                    "type": "integer",
                    "description": "Maximum tasks running at the same time (default: 4).",
                    "default": 4
                }
            }
        })
    }

    /// Execute the tool.
    ///
    /// Dispatch:
    /// * `tasks` present  → multi-task mode (dependency graph, optional
    ///   parallelism); a `description` passed alongside `tasks` is ignored.
    /// * `description` present → single-task mode (one `CoderAgent` run).
    /// * neither → usage error (`is_error: true`).
    ///
    /// Infrastructure failures (executor / agent) propagate as `Err`; bad
    /// user input is reported in-band via `ToolResult { is_error: true, .. }`.
    async fn execute(&self, args: serde_json::Value, ctx: &ToolContext) -> Result<ToolResult> {
        // Hard cap on recursive spawning: a sub-agent at depth 3 may not
        // spawn further sub-agents.
        if ctx.nesting_depth >= 3 {
            return Ok(ToolResult {
                output: format!(
                    "spawn_task refused: maximum nesting depth (3) reached. \
                     Current depth: {}. Cannot spawn further sub-agents.",
                    ctx.nesting_depth
                ),
                is_error: true,
            });
        }

        // Sub-agents run one level deeper, so the cap above eventually trips.
        let mut sub_ctx = ctx.clone();
        sub_ctx.nesting_depth += 1;

        let tasks_val = args.get("tasks");
        let description_val = args.get("description").and_then(|v| v.as_str());

        match (tasks_val, description_val) {
            // Multi-task mode. Note: 'tasks' takes precedence over 'description'.
            (Some(tasks_arr), _) => {
                let tasks = match tasks_arr.as_array() {
                    Some(arr) => arr,
                    None => {
                        return Ok(ToolResult {
                            output: "spawn_task: 'tasks' must be a JSON array, not a scalar."
                                .to_string(),
                            is_error: true,
                        });
                    }
                };
                if tasks.is_empty() {
                    // An empty plan is a no-op, not an error.
                    return Ok(ToolResult {
                        output: "spawn_task: 'tasks' array is empty — nothing to execute."
                            .to_string(),
                        is_error: false,
                    });
                }

                let parallel = args
                    .get("parallel")
                    .and_then(|v| v.as_bool())
                    .unwrap_or(true);
                // Clamp to at least 1 so an explicit `"max_concurrent": 0`
                // cannot leave the executor with zero runnable slots.
                let max_concurrent = args
                    .get("max_concurrent")
                    .and_then(|v| v.as_u64())
                    .unwrap_or(4)
                    .max(1) as usize;

                // Build the dependency graph, validating each task as we go.
                let mut graph = TaskGraph::new();
                for task_val in tasks {
                    let id = match task_val.get("id").and_then(|v| v.as_str()) {
                        Some(s) => s,
                        None => {
                            return Ok(ToolResult {
                                output: "spawn_task: each task must have an 'id' string field."
                                    .to_string(),
                                is_error: true,
                            });
                        }
                    };
                    let desc = match task_val.get("description").and_then(|v| v.as_str()) {
                        Some(s) => s,
                        None => {
                            return Ok(ToolResult {
                                output: format!(
                                    "spawn_task: task '{}' must have a 'description' string field.",
                                    id
                                ),
                                is_error: true,
                            });
                        }
                    };
                    let mut node = TaskNode::new(id, desc);
                    if let Some(deps) = task_val.get("depends_on").and_then(|v| v.as_array()) {
                        // Non-string entries in depends_on are silently skipped
                        // (preserved pre-existing behavior).
                        for dep_val in deps {
                            if let Some(dep_id) = dep_val.as_str() {
                                node = node.with_dependency(dep_id);
                            }
                        }
                    }
                    // NOTE(review): presumably add_task rejects duplicate ids /
                    // unknown dependencies — confirm against TaskGraph.
                    graph.add_task(node)?;
                }

                // `parallel: false` is implemented as a concurrency limit of 1.
                let concurrency = if parallel { max_concurrent } else { 1 };
                let executor = TaskExecutor::new(graph).with_max_concurrent(concurrency);
                let report = executor
                    .run(
                        Arc::clone(&ctx.llm),
                        Arc::clone(&ctx.tools),
                        sub_ctx,
                        Arc::clone(&ctx.io),
                    )
                    .await?;

                // --- Render a markdown summary of the run. ---
                let completed_count = report.task_results.len();
                let failed_count = report.failed.len();
                let cancelled_count = report.cancelled.len();
                let mut output = format!(
                    "spawn_task completed: {} succeeded, {} failed, {} cancelled\n\
                     Duration: {:.1}s\n\n",
                    completed_count,
                    failed_count,
                    cancelled_count,
                    report.total_duration.as_secs_f64(),
                );
                if !report.task_results.is_empty() {
                    output.push_str("## Completed tasks\n");
                    // Sort by id for deterministic output ordering.
                    let mut results: Vec<_> = report.task_results.iter().collect();
                    results.sort_by_key(|(id, _)| id.as_str());
                    for (id, result) in results {
                        // Truncate previews at 200 *characters*. The previous
                        // check used byte length, which appended a spurious
                        // ellipsis to multi-byte messages under 200 chars.
                        let mut chars = result.final_message.chars();
                        let preview: String = chars.by_ref().take(200).collect();
                        let ellipsis = if chars.next().is_some() { "…" } else { "" };
                        output.push_str(&format!(
                            "- **{}**: {}{} (iters={}, tools={})\n",
                            id, preview, ellipsis, result.iterations, result.tool_calls_total,
                        ));
                    }
                    output.push('\n');
                }
                if !report.failed.is_empty() {
                    output.push_str(&format!(
                        "## Failed tasks\n{}\n\n",
                        report.failed.join(", ")
                    ));
                }
                if !report.cancelled.is_empty() {
                    output.push_str(&format!(
                        "## Cancelled tasks\n{}\n\n",
                        report.cancelled.join(", ")
                    ));
                }
                // Any failed task marks the whole tool call as an error;
                // cancellations alone do not.
                Ok(ToolResult {
                    output,
                    is_error: !report.failed.is_empty(),
                })
            }
            // Single-task mode: run one CoderAgent to completion and return
            // its final message verbatim.
            (None, Some(description)) => {
                let agent = CoderAgent::new(AgentConfig::default());
                let mut messages = vec![
                    Message::system(agent.system_prompt().as_str()),
                    Message::user(description),
                ];
                let result = agent
                    .run(&mut messages, ctx.tools.as_ref(), ctx.llm.as_ref(), &sub_ctx)
                    .await?;
                Ok(ToolResult {
                    output: result.final_message,
                    is_error: false,
                })
            }
            // Neither argument supplied: tell the caller how to invoke us.
            (None, None) => Ok(ToolResult {
                output: "spawn_task: provide either 'description' (single-task mode) \
                         or a 'tasks' array (multi-task mode)."
                    .to_string(),
                is_error: true,
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::io::NullIO;
    use crate::llm::NullLlmProvider;
    use std::path::PathBuf;
    use tokio::sync::Mutex;

    /// Builds a minimal `ToolContext` for exercising the tool without a real
    /// LLM, real I/O, or sandboxing.
    fn make_ctx() -> ToolContext {
        ToolContext {
            working_dir: PathBuf::from("/tmp"),
            sandbox_enabled: false,
            io: Arc::new(NullIO),
            compact_mode: false,
            lsp_client: Arc::new(Mutex::new(None)),
            mcp_client: None,
            nesting_depth: 0,
            llm: Arc::new(NullLlmProvider),
            tools: Arc::new(crate::tools::ToolRegistry::new()),
            permissions: vec![],
            formatters: std::collections::HashMap::new(),
        }
    }

    #[test]
    fn test_spawn_task_name() {
        let tool = SpawnTaskTool;
        assert_eq!(tool.name(), "spawn_task");
    }

    #[test]
    fn test_spawn_task_description_non_empty() {
        let desc = SpawnTaskTool.description();
        assert!(!desc.is_empty());
    }

    #[test]
    fn test_spawn_task_schema_has_required_properties() {
        let schema = SpawnTaskTool.parameters_schema();
        let props = &schema["properties"];
        // Table-driven check: every top-level property must be present.
        for (key, msg) in [
            ("description", "schema missing 'description'"),
            ("tasks", "schema missing 'tasks'"),
            ("parallel", "schema missing 'parallel'"),
            ("max_concurrent", "schema missing 'max_concurrent'"),
        ] {
            assert!(props[key].is_object(), "{}", msg);
        }
    }

    #[tokio::test]
    async fn test_nesting_depth_at_limit_is_refused() {
        let mut ctx = make_ctx();
        ctx.nesting_depth = 3;
        let args = serde_json::json!({ "description": "test" });
        let result = SpawnTaskTool.execute(args, &ctx).await.unwrap();
        assert!(result.is_error, "should be an error at depth 3");
        assert!(
            result.output.contains("maximum nesting depth"),
            "error message should mention nesting depth: {}",
            result.output
        );
    }

    #[tokio::test]
    async fn test_nesting_depth_below_limit_is_allowed_to_attempt() {
        let mut ctx = make_ctx();
        ctx.nesting_depth = 2;
        let args = serde_json::json!({ "tasks": [] });
        let result = SpawnTaskTool.execute(args, &ctx).await.unwrap();
        assert!(!result.is_error);
        assert!(result.output.contains("empty"));
    }

    #[tokio::test]
    async fn test_no_args_returns_usage_error() {
        let ctx = make_ctx();
        let result = SpawnTaskTool
            .execute(serde_json::json!({}), &ctx)
            .await
            .unwrap();
        assert!(result.is_error, "missing args should be an error");
        let msg = result.output.to_lowercase();
        assert!(
            msg.contains("description") || msg.contains("tasks"),
            "error message should mention required fields: {}",
            result.output
        );
    }

    #[tokio::test]
    async fn test_empty_tasks_array_succeeds() {
        let ctx = make_ctx();
        let args = serde_json::json!({ "tasks": [] });
        let result = SpawnTaskTool.execute(args, &ctx).await.unwrap();
        assert!(!result.is_error);
        assert!(result.output.contains("empty"));
    }

    #[tokio::test]
    async fn test_tasks_not_array_returns_error() {
        let ctx = make_ctx();
        let args = serde_json::json!({ "tasks": "not-an-array" });
        let result = SpawnTaskTool.execute(args, &ctx).await.unwrap();
        assert!(result.is_error, "non-array tasks should be an error");
    }

    #[tokio::test]
    async fn test_task_missing_id_returns_error() {
        let ctx = make_ctx();
        let args = serde_json::json!({ "tasks": [{ "description": "no id here" }] });
        let result = SpawnTaskTool.execute(args, &ctx).await.unwrap();
        assert!(result.is_error, "task without 'id' should be an error");
    }

    #[tokio::test]
    async fn test_task_missing_description_returns_error() {
        let ctx = make_ctx();
        let args = serde_json::json!({ "tasks": [{ "id": "t1" }] });
        let result = SpawnTaskTool.execute(args, &ctx).await.unwrap();
        assert!(
            result.is_error,
            "task without 'description' should be an error"
        );
    }

    #[tokio::test]
    async fn test_unknown_dependency_returns_error() {
        let ctx = make_ctx();
        let args = serde_json::json!({
            "tasks": [
                {
                    "id": "t1",
                    "description": "first task",
                    "depends_on": ["does-not-exist"]
                }
            ]
        });
        let outcome = SpawnTaskTool.execute(args, &ctx).await;
        // Accept either a hard Err or an in-band tool error.
        let failed = match outcome {
            Err(_) => true,
            Ok(result) => result.is_error,
        };
        assert!(failed, "dependency on unknown task should fail");
    }

    #[tokio::test]
    async fn test_parallel_false_accepted_as_serial() {
        let ctx = make_ctx();
        let args = serde_json::json!({ "tasks": [], "parallel": false });
        let result = SpawnTaskTool.execute(args, &ctx).await.unwrap();
        assert!(!result.is_error);
    }

    #[tokio::test]
    async fn test_max_concurrent_respected() {
        let ctx = make_ctx();
        let args = serde_json::json!({ "tasks": [], "max_concurrent": 1 });
        let result = SpawnTaskTool.execute(args, &ctx).await.unwrap();
        assert!(!result.is_error);
    }
}