use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use rig::completion::ToolDefinition;
use rig::tool::{ToolDyn, ToolError};
use rustc_hash::FxHashMap;
use serde::{Deserialize, Serialize};
use serde_json::{json, Value};
use crate::ast::AgentParams;
use crate::event::{EventKind, EventLog};
use crate::mcp::McpClient;
/// Arguments accepted by the `spawn_agent` tool, deserialized from the JSON
/// argument string supplied by the model.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct SpawnAgentParams {
    /// Identifier of the child task; echoed in the tool result and used as the
    /// child id of the `AgentSpawned` event.
    pub task_id: String,
    /// Goal the child agent should accomplish.
    pub prompt: String,
    /// Optional JSON context from the parent; absent or `null` becomes `None`.
    #[serde(default)]
    pub context: Option<Value>,
    /// Optional turn budget for the child; absent or `null` becomes `None`.
    #[serde(default)]
    pub max_turns: Option<u32>,
}
/// Tool that lets an agent delegate a subtask to a child agent, refusing once
/// the nesting depth limit is reached.
#[derive(Clone)]
pub struct SpawnAgentTool {
    /// Depth of the owning agent; compared against `max_depth` before spawning.
    current_depth: u32,
    /// Spawning is refused once `current_depth` reaches this value.
    max_depth: u32,
    /// Task id of the owning agent, recorded as the parent in spawn events.
    parent_task_id: Arc<str>,
    /// Sink for lifecycle events such as `AgentSpawned`.
    event_log: EventLog,
    /// MCP clients handed to children; when empty, no child agent is actually run.
    mcp_clients: FxHashMap<String, Arc<McpClient>>,
    /// MCP server names placed into the child agent's parameters.
    mcp_names: Vec<String>,
}
impl SpawnAgentTool {
pub fn new(
current_depth: u32,
max_depth: u32,
parent_task_id: Arc<str>,
event_log: EventLog,
) -> Self {
Self {
current_depth,
max_depth,
parent_task_id,
event_log,
mcp_clients: FxHashMap::default(),
mcp_names: Vec::new(),
}
}
pub fn with_mcp(
current_depth: u32,
max_depth: u32,
parent_task_id: Arc<str>,
event_log: EventLog,
mcp_clients: FxHashMap<String, Arc<McpClient>>,
mcp_names: Vec<String>,
) -> Self {
Self {
current_depth,
max_depth,
parent_task_id,
event_log,
mcp_clients,
mcp_names,
}
}
pub fn name(&self) -> &str {
"spawn_agent"
}
pub fn definition(&self) -> ToolDefinition {
ToolDefinition {
name: "spawn_agent".to_string(),
description: "Spawn a sub-agent to handle a delegated subtask. The child agent \
runs independently with max 10 turns and returns its result."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"task_id": {
"type": "string",
"description": "Unique identifier for the child task (e.g., 'subtask-1')"
},
"prompt": {
"type": "string",
"description": "Goal/prompt describing what the child agent should accomplish"
}
},
"required": ["task_id", "prompt"],
"additionalProperties": false
}),
}
}
pub async fn call(&self, args: String) -> Result<String, SpawnAgentError> {
let params: SpawnAgentParams =
serde_json::from_str(&args).map_err(|e| SpawnAgentError::InvalidArgs(e.to_string()))?;
if self.current_depth >= self.max_depth {
return Err(SpawnAgentError::DepthLimitReached {
current: self.current_depth,
max: self.max_depth,
});
}
let child_depth = self.current_depth + 1;
self.event_log.emit(EventKind::AgentSpawned {
parent_task_id: self.parent_task_id.clone(),
child_task_id: Arc::from(params.task_id.as_str()),
depth: child_depth,
});
if self.mcp_clients.is_empty() {
return Ok(json!({
"status": "spawned",
"child_task_id": params.task_id,
"depth": child_depth,
"note": "Child agent execution requires MCP client context"
})
.to_string());
}
let remaining_depth = self.max_depth.saturating_sub(self.current_depth);
let child_params = AgentParams {
prompt: params.prompt,
system: params.context.as_ref().map(|ctx| {
format!(
"Context from parent agent:\n{}",
serde_json::to_string_pretty(ctx).unwrap_or_default()
)
}),
mcp: self.mcp_names.clone(),
max_turns: params.max_turns.or(Some(10)),
depth_limit: Some(remaining_depth),
..Default::default()
};
let mut child_loop = super::RigAgentLoop::new(
params.task_id.clone(),
child_params,
self.event_log.clone(),
self.mcp_clients.clone(),
)
.map_err(|e| SpawnAgentError::ExecutionFailed(e.to_string()))?;
let result = child_loop
.run_auto()
.await
.map_err(|e| SpawnAgentError::ExecutionFailed(e.to_string()))?;
Ok(json!({
"status": "completed",
"child_task_id": params.task_id,
"depth": child_depth,
"result": result.final_output,
"turns": result.turns,
"total_tokens": result.total_tokens
})
.to_string())
}
pub fn can_spawn(&self) -> bool {
self.current_depth < self.max_depth
}
pub fn child_depth(&self) -> u32 {
self.current_depth + 1
}
}
#[derive(Debug, thiserror::Error)]
pub enum SpawnAgentError {
#[error("spawn_agent: depth limit reached (current: {current}, max: {max})")]
DepthLimitReached { current: u32, max: u32 },
#[error("spawn_agent: invalid arguments - {0}")]
InvalidArgs(String),
#[error("spawn_agent: execution failed - {0}")]
ExecutionFailed(String),
}
impl std::fmt::Debug for SpawnAgentTool {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SpawnAgentTool")
.field("current_depth", &self.current_depth)
.field("max_depth", &self.max_depth)
.field("parent_task_id", &self.parent_task_id)
.finish()
}
}
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
impl ToolDyn for SpawnAgentTool {
fn name(&self) -> String {
"spawn_agent".to_string()
}
fn definition(&self, _prompt: String) -> BoxFuture<'_, ToolDefinition> {
let def = ToolDefinition {
name: "spawn_agent".to_string(),
description: "Spawn a sub-agent to handle a delegated subtask. The child agent \
runs independently with max 10 turns and returns its result."
.to_string(),
parameters: json!({
"type": "object",
"properties": {
"task_id": {
"type": "string",
"description": "Unique identifier for the child task (e.g., 'subtask-1')"
},
"prompt": {
"type": "string",
"description": "Goal/prompt describing what the child agent should accomplish"
}
},
"required": ["task_id", "prompt"],
"additionalProperties": false
}),
};
Box::pin(async move { def })
}
fn call(&self, args: String) -> BoxFuture<'_, Result<String, ToolError>> {
Box::pin(async move {
self.call(args).await.map_err(|e| {
ToolError::ToolCallError(Box::new(std::io::Error::other(e.to_string())))
})
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid argument payload shared by the `call` tests.
    fn minimal_args() -> String {
        json!({
            "task_id": "child-1",
            "prompt": "Do something"
        })
        .to_string()
    }

    #[test]
    fn spawn_agent_tool_name() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        assert_eq!(tool.name(), "spawn_agent");
    }

    #[test]
    fn spawn_agent_tool_can_spawn() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        assert!(tool.can_spawn());
        let at_limit = SpawnAgentTool::new(3, 3, "parent".into(), EventLog::new());
        assert!(!at_limit.can_spawn());
    }

    #[test]
    fn spawn_agent_tool_child_depth() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        assert_eq!(tool.child_depth(), 2);
    }

    #[test]
    fn spawn_agent_params_deserializes() {
        let json = json!({
            "task_id": "child-1",
            "prompt": "Do something",
            "context": {"key": "value"},
            "max_turns": 5
        });
        let params: SpawnAgentParams = serde_json::from_value(json).unwrap();
        assert_eq!(params.task_id, "child-1");
        assert_eq!(params.prompt, "Do something");
        assert!(params.context.is_some());
        assert_eq!(params.max_turns, Some(5));
    }

    #[test]
    fn spawn_agent_params_minimal() {
        let json = json!({
            "task_id": "child-1",
            "prompt": "Do something"
        });
        let params: SpawnAgentParams = serde_json::from_value(json).unwrap();
        assert_eq!(params.task_id, "child-1");
        assert!(params.context.is_none());
        assert!(params.max_turns.is_none());
    }

    #[tokio::test]
    async fn spawn_agent_at_max_depth_fails() {
        let tool = SpawnAgentTool::new(3, 3, "parent".into(), EventLog::new());
        let err = tool.call(minimal_args()).await.unwrap_err();
        assert!(err.to_string().contains("depth limit"));
        assert!(matches!(
            err,
            SpawnAgentError::DepthLimitReached { current: 3, max: 3 }
        ));
    }

    #[tokio::test]
    async fn spawn_agent_at_max_depth_emits_no_event() {
        let event_log = EventLog::new();
        let tool = SpawnAgentTool::new(3, 3, "parent".into(), event_log.clone());
        assert!(tool.call(minimal_args()).await.is_err());
        let events = event_log.events();
        assert!(
            !events
                .iter()
                .any(|e| matches!(e.kind, EventKind::AgentSpawned { .. })),
            "a refused spawn must not be logged"
        );
    }

    #[tokio::test]
    async fn spawn_agent_rejects_malformed_args() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        let err = tool.call("not json".to_string()).await.unwrap_err();
        assert!(matches!(err, SpawnAgentError::InvalidArgs(_)));
        let missing_prompt = json!({ "task_id": "child-1" }).to_string();
        let err = tool.call(missing_prompt).await.unwrap_err();
        assert!(matches!(err, SpawnAgentError::InvalidArgs(_)));
    }

    #[tokio::test]
    async fn spawn_agent_below_max_depth_succeeds() {
        let tool = SpawnAgentTool::new(2, 3, "parent".into(), EventLog::new());
        let result = tool
            .call(minimal_args())
            .await
            .expect("spawn below the depth limit should succeed");
        let response: Value = serde_json::from_str(&result).unwrap();
        assert_eq!(response["status"], "spawned");
        assert_eq!(response["child_task_id"], "child-1");
        assert_eq!(response["depth"], 3);
    }

    #[tokio::test]
    async fn spawn_agent_emits_event() {
        let event_log = EventLog::new();
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), event_log.clone());
        let _ = tool.call(minimal_args()).await;
        let events = event_log.events();
        let spawned_events: Vec<_> = events
            .iter()
            .filter(|e| matches!(e.kind, EventKind::AgentSpawned { .. }))
            .collect();
        assert_eq!(spawned_events.len(), 1);
        // `let … else` makes the test fail loudly instead of silently skipping
        // the field checks if the pattern ever stops matching.
        let EventKind::AgentSpawned {
            parent_task_id,
            child_task_id,
            depth,
        } = &spawned_events[0].kind
        else {
            panic!("filtered event should be AgentSpawned");
        };
        assert_eq!(&**parent_task_id, "parent");
        assert_eq!(&**child_task_id, "child-1");
        assert_eq!(*depth, 2);
    }

    #[test]
    fn tool_definition_has_required_params() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        let def = tool.definition();
        let required = def
            .parameters
            .get("required")
            .and_then(|v| v.as_array())
            .expect("required should be an array");
        assert!(required.iter().any(|v| v == "task_id"));
        assert!(required.iter().any(|v| v == "prompt"));
        assert_eq!(required.len(), 2);
        let additional = def
            .parameters
            .get("additionalProperties")
            .expect("additionalProperties should exist");
        assert_eq!(additional, false);
    }

    #[test]
    fn spawn_agent_implements_tool_dyn() {
        use rig::tool::ToolDyn;
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        let name: String = ToolDyn::name(&tool);
        assert_eq!(name, "spawn_agent");
    }

    #[tokio::test]
    async fn spawn_agent_tool_dyn_definition_returns_correct_schema() {
        use rig::tool::ToolDyn;
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        let def = ToolDyn::definition(&tool, "test".to_string()).await;
        assert_eq!(def.name, "spawn_agent");
        assert!(def.description.contains("sub-agent"));
        assert!(def.parameters.get("required").is_some());
    }

    #[tokio::test]
    async fn inherent_and_tool_dyn_definitions_agree() {
        use rig::tool::ToolDyn;
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        let inherent = tool.definition();
        let dynamic = ToolDyn::definition(&tool, String::new()).await;
        assert_eq!(inherent.name, dynamic.name);
        assert_eq!(inherent.description, dynamic.description);
        assert_eq!(inherent.parameters, dynamic.parameters);
    }

    #[tokio::test]
    async fn spawn_agent_tool_dyn_call_enforces_depth_limit() {
        use rig::tool::ToolDyn;
        let tool = SpawnAgentTool::new(3, 3, "parent".into(), EventLog::new());
        let result = ToolDyn::call(&tool, minimal_args()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("depth limit"));
    }

    #[test]
    fn spawn_agent_with_mcp_creates_correctly() {
        let event_log = EventLog::new();
        let mcp_clients = FxHashMap::default();
        let mcp_names = vec!["novanet".to_string()];
        let tool = SpawnAgentTool::with_mcp(
            1,
            3,
            "parent".into(),
            event_log,
            mcp_clients,
            mcp_names,
        );
        assert_eq!(tool.name(), "spawn_agent");
        assert!(tool.can_spawn());
        assert_eq!(tool.child_depth(), 2);
    }

    #[test]
    fn depth_calculation_allows_three_levels() {
        // Each level is modelled as starting at depth 1 with the remaining
        // budget as its max depth.
        let root = SpawnAgentTool::new(1, 3, "root".into(), EventLog::new());
        assert!(root.can_spawn(), "Root should be able to spawn");
        assert_eq!(root.child_depth(), 2);

        let child = SpawnAgentTool::new(1, 2, "child".into(), EventLog::new());
        assert!(child.can_spawn(), "Child should be able to spawn grandchild");
        assert_eq!(child.child_depth(), 2);

        let grandchild = SpawnAgentTool::new(1, 1, "grandchild".into(), EventLog::new());
        assert!(
            !grandchild.can_spawn(),
            "Grandchild should NOT be able to spawn"
        );
    }

    #[test]
    fn remaining_depth_calculation_formula() {
        let root_current = 1_u32;
        let root_max = 3_u32;
        let child_will_receive = root_max.saturating_sub(root_current);
        assert_eq!(child_will_receive, 2, "Child should receive depth_limit=2");

        let child_current = 1_u32;
        let child_max = child_will_receive;
        let grandchild_will_receive = child_max.saturating_sub(child_current);
        assert_eq!(
            grandchild_will_receive, 1,
            "Grandchild should receive depth_limit=1"
        );

        let grandchild_current = 1_u32;
        let grandchild_max = grandchild_will_receive;
        let can_spawn = grandchild_current < grandchild_max;
        assert!(!can_spawn, "Grandchild should not be able to spawn");
    }
}