Skip to main content

batuta/agent/tool/
spawn.rs

1//! Sub-agent spawning tool.
2//!
3//! Allows an agent to delegate work to a child agent running
4//! its own perceive-reason-act loop. The child shares the parent's
5//! LLM driver and memory substrate but gets its own loop guard.
6//!
7//! Requires `Capability::Spawn { max_depth }` — recursion is
8//! bounded by depth tracking (Jidoka: stop on runaway spawning).
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use tokio::sync::Mutex;
14
15use crate::agent::capability::Capability;
16use crate::agent::driver::ToolDefinition;
17use crate::agent::manifest::AgentManifest;
18use crate::agent::pool::{AgentPool, SpawnConfig};
19
20use super::{Tool, ToolResult};
21
22/// Tool that spawns a sub-agent, waits for completion, and
23/// returns the child's response as the tool result.
24pub struct SpawnTool {
25    pool: Arc<Mutex<AgentPool>>,
26    parent_manifest: AgentManifest,
27    current_depth: u32,
28    max_depth: u32,
29}
30
31impl SpawnTool {
32    /// Create a spawn tool with depth tracking.
33    pub fn new(
34        pool: Arc<Mutex<AgentPool>>,
35        parent_manifest: AgentManifest,
36        current_depth: u32,
37        max_depth: u32,
38    ) -> Self {
39        Self { pool, parent_manifest, current_depth, max_depth }
40    }
41}
42
43#[async_trait]
44impl Tool for SpawnTool {
45    fn name(&self) -> &'static str {
46        "spawn_agent"
47    }
48
49    fn definition(&self) -> ToolDefinition {
50        ToolDefinition {
51            name: "spawn_agent".into(),
52            description: "Spawn a sub-agent to handle a delegated task. \
53                The child agent runs its own perceive-reason-act loop \
54                and returns its final response."
55                .into(),
56            input_schema: serde_json::json!({
57                "type": "object",
58                "properties": {
59                    "query": {
60                        "type": "string",
61                        "description": "The task to delegate to the sub-agent"
62                    },
63                    "name": {
64                        "type": "string",
65                        "description": "Optional name for the sub-agent (defaults to parent name + '-sub')"
66                    }
67                },
68                "required": ["query"]
69            }),
70        }
71    }
72
73    #[cfg_attr(
74        feature = "agents-contracts",
75        provable_contracts_macros::contract("agent-loop-v1", equation = "spawn_depth_bound")
76    )]
77    async fn execute(&self, input: serde_json::Value) -> ToolResult {
78        // Jidoka: depth guard
79        if self.current_depth >= self.max_depth {
80            return ToolResult::error(format!(
81                "spawn depth limit reached ({}/{})",
82                self.current_depth, self.max_depth,
83            ));
84        }
85
86        let query = match input.get("query").and_then(|v| v.as_str()) {
87            Some(q) => q.to_string(),
88            None => {
89                return ToolResult::error("missing required field: query");
90            }
91        };
92
93        let name = match input.get("name").and_then(|v| v.as_str()) {
94            Some(n) => n.to_string(),
95            None => format!("{}-sub", self.parent_manifest.name),
96        };
97
98        // Build child manifest (inherits parent config, new name)
99        let mut child_manifest = self.parent_manifest.clone();
100        child_manifest.name = name;
101        // Reduce child iterations to prevent runaway
102        child_manifest.resources.max_iterations = child_manifest.resources.max_iterations.min(10);
103
104        let config = SpawnConfig { manifest: child_manifest, query };
105
106        // Spawn and await
107        let mut pool = self.pool.lock().await;
108        let id = match pool.spawn(config) {
109            Ok(id) => id,
110            Err(e) => {
111                return ToolResult::error(format!("spawn failed: {e}"));
112            }
113        };
114
115        match pool.join_next().await {
116            Some((completed_id, Ok(result))) if completed_id == id => {
117                ToolResult::success(result.text)
118            }
119            Some((_, Ok(result))) => {
120                // Different agent finished first — still return it
121                ToolResult::success(result.text)
122            }
123            Some((_, Err(e))) => ToolResult::error(format!("sub-agent error: {e}")),
124            None => ToolResult::error("sub-agent produced no result"),
125        }
126    }
127
128    fn required_capability(&self) -> Capability {
129        Capability::Spawn { max_depth: self.max_depth }
130    }
131}
132
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::driver::mock::MockDriver;

    /// Shared pool whose mock driver always answers "child response".
    fn make_pool() -> Arc<Mutex<AgentPool>> {
        let driver = Arc::new(MockDriver::single_response("child response"));
        Arc::new(Mutex::new(AgentPool::new(driver, 4)))
    }

    /// Spawn tool over a fresh mock pool with the given manifest and depths.
    fn tool_with(manifest: AgentManifest, current_depth: u32, max_depth: u32) -> SpawnTool {
        SpawnTool::new(make_pool(), manifest, current_depth, max_depth)
    }

    #[test]
    fn test_spawn_tool_definition() {
        let def = tool_with(AgentManifest::default(), 0, 3).definition();
        assert_eq!(def.name, "spawn_agent");
        assert!(def.description.contains("sub-agent"));
    }

    #[test]
    fn test_spawn_tool_capability() {
        let tool = tool_with(AgentManifest::default(), 0, 3);
        assert_eq!(tool.required_capability(), Capability::Spawn { max_depth: 3 });
    }

    #[tokio::test]
    async fn test_spawn_tool_depth_limit() {
        // current_depth == max_depth → blocked
        let blocked = tool_with(AgentManifest::default(), 3, 3);
        let outcome = blocked.execute(serde_json::json!({ "query": "hello" })).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("depth limit"));
    }

    #[tokio::test]
    async fn test_spawn_tool_missing_query() {
        let tool = tool_with(AgentManifest::default(), 0, 3);
        let outcome = tool.execute(serde_json::json!({})).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("missing"));
    }

    #[tokio::test]
    async fn test_spawn_tool_executes_child() {
        let tool = tool_with(AgentManifest::default(), 0, 3);
        let input = serde_json::json!({
            "query": "do something",
            "name": "worker"
        });
        let outcome = tool.execute(input).await;
        assert!(!outcome.is_error, "error: {}", outcome.content);
        assert_eq!(outcome.content, "child response");
    }

    #[tokio::test]
    async fn test_spawn_tool_default_name() {
        let mut manifest = AgentManifest::default();
        manifest.name = "parent".into();
        let tool = tool_with(manifest, 0, 3);
        let outcome = tool.execute(serde_json::json!({ "query": "hello" })).await;
        assert!(!outcome.is_error, "error: {}", outcome.content);
    }
}