batuta/agent/tool/
spawn.rs1use std::sync::Arc;
11
12use async_trait::async_trait;
13use tokio::sync::Mutex;
14
15use crate::agent::capability::Capability;
16use crate::agent::driver::ToolDefinition;
17use crate::agent::manifest::AgentManifest;
18use crate::agent::pool::{AgentPool, SpawnConfig};
19
20use super::{Tool, ToolResult};
21
22pub struct SpawnTool {
25 pool: Arc<Mutex<AgentPool>>,
26 parent_manifest: AgentManifest,
27 current_depth: u32,
28 max_depth: u32,
29}
30
31impl SpawnTool {
32 pub fn new(
34 pool: Arc<Mutex<AgentPool>>,
35 parent_manifest: AgentManifest,
36 current_depth: u32,
37 max_depth: u32,
38 ) -> Self {
39 Self { pool, parent_manifest, current_depth, max_depth }
40 }
41}
42
43#[async_trait]
44impl Tool for SpawnTool {
45 fn name(&self) -> &'static str {
46 "spawn_agent"
47 }
48
49 fn definition(&self) -> ToolDefinition {
50 ToolDefinition {
51 name: "spawn_agent".into(),
52 description: "Spawn a sub-agent to handle a delegated task. \
53 The child agent runs its own perceive-reason-act loop \
54 and returns its final response."
55 .into(),
56 input_schema: serde_json::json!({
57 "type": "object",
58 "properties": {
59 "query": {
60 "type": "string",
61 "description": "The task to delegate to the sub-agent"
62 },
63 "name": {
64 "type": "string",
65 "description": "Optional name for the sub-agent (defaults to parent name + '-sub')"
66 }
67 },
68 "required": ["query"]
69 }),
70 }
71 }
72
73 #[cfg_attr(
74 feature = "agents-contracts",
75 provable_contracts_macros::contract("agent-loop-v1", equation = "spawn_depth_bound")
76 )]
77 async fn execute(&self, input: serde_json::Value) -> ToolResult {
78 if self.current_depth >= self.max_depth {
80 return ToolResult::error(format!(
81 "spawn depth limit reached ({}/{})",
82 self.current_depth, self.max_depth,
83 ));
84 }
85
86 let query = match input.get("query").and_then(|v| v.as_str()) {
87 Some(q) => q.to_string(),
88 None => {
89 return ToolResult::error("missing required field: query");
90 }
91 };
92
93 let name = match input.get("name").and_then(|v| v.as_str()) {
94 Some(n) => n.to_string(),
95 None => format!("{}-sub", self.parent_manifest.name),
96 };
97
98 let mut child_manifest = self.parent_manifest.clone();
100 child_manifest.name = name;
101 child_manifest.resources.max_iterations = child_manifest.resources.max_iterations.min(10);
103
104 let config = SpawnConfig { manifest: child_manifest, query };
105
106 let mut pool = self.pool.lock().await;
108 let id = match pool.spawn(config) {
109 Ok(id) => id,
110 Err(e) => {
111 return ToolResult::error(format!("spawn failed: {e}"));
112 }
113 };
114
115 match pool.join_next().await {
116 Some((completed_id, Ok(result))) if completed_id == id => {
117 ToolResult::success(result.text)
118 }
119 Some((_, Ok(result))) => {
120 ToolResult::success(result.text)
122 }
123 Some((_, Err(e))) => ToolResult::error(format!("sub-agent error: {e}")),
124 None => ToolResult::error("sub-agent produced no result"),
125 }
126 }
127
128 fn required_capability(&self) -> Capability {
129 Capability::Spawn { max_depth: self.max_depth }
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::driver::mock::MockDriver;

    /// Fresh pool whose mock driver always answers "child response"
    /// (pool constructed with capacity argument 4).
    fn make_pool() -> Arc<Mutex<AgentPool>> {
        let pool = AgentPool::new(Arc::new(MockDriver::single_response("child response")), 4);
        Arc::new(Mutex::new(pool))
    }

    /// Spawn tool over a fresh pool, sitting at `depth`, with a depth limit of 3.
    fn tool(manifest: AgentManifest, depth: u32) -> SpawnTool {
        SpawnTool::new(make_pool(), manifest, depth, 3)
    }

    #[test]
    fn test_spawn_tool_definition() {
        let def = tool(AgentManifest::default(), 0).definition();
        assert_eq!(def.name, "spawn_agent");
        assert!(def.description.contains("sub-agent"));
    }

    #[test]
    fn test_spawn_tool_capability() {
        let capability = tool(AgentManifest::default(), 0).required_capability();
        assert_eq!(capability, Capability::Spawn { max_depth: 3 });
    }

    #[tokio::test]
    async fn test_spawn_tool_depth_limit() {
        let outcome = tool(AgentManifest::default(), 3)
            .execute(serde_json::json!({ "query": "hello" }))
            .await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("depth limit"));
    }

    #[tokio::test]
    async fn test_spawn_tool_missing_query() {
        let outcome = tool(AgentManifest::default(), 0)
            .execute(serde_json::json!({}))
            .await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("missing"));
    }

    #[tokio::test]
    async fn test_spawn_tool_executes_child() {
        let input = serde_json::json!({ "query": "do something", "name": "worker" });
        let outcome = tool(AgentManifest::default(), 0).execute(input).await;
        assert!(!outcome.is_error, "error: {}", outcome.content);
        assert_eq!(outcome.content, "child response");
    }

    #[tokio::test]
    async fn test_spawn_tool_default_name() {
        let mut parent = AgentManifest::default();
        parent.name = "parent".into();
        let outcome = tool(parent, 0)
            .execute(serde_json::json!({ "query": "hello" }))
            .await;
        assert!(!outcome.is_error, "error: {}", outcome.content);
    }
}