Skip to main content

nika_engine/runtime/
spawn.rs

1//! Spawn Agent Tool (MVP 8 Phase 2)
2//!
3//! Internal tool for recursive agent spawning. Allows an agent to delegate
4//! subtasks to child agents, enabling hierarchical task decomposition.
5//!
6//! ## Depth Limit
7//!
8//! To prevent infinite recursion, each spawn tracks depth and enforces limits:
9//! - Default limit: 3 levels
10//! - Maximum limit: 10 levels
11//! - Spawning at max depth returns an error
12//!
13//! ## Events
14//!
15//! Spawning emits an `AgentSpawned` event with:
16//! - `parent_task_id`: The spawning agent's task ID
17//! - `child_task_id`: The new agent's task ID
18//! - `depth`: Current recursion depth
19//!
20//! ## rig::ToolDyn Integration
21//!
22//! SpawnAgentTool implements `rig::tool::ToolDyn` for seamless integration
23//! with `RigAgentLoop`. When an agent has `depth_limit > current_depth`,
24//! the spawn_agent tool is automatically added to its tool list.
25//!
26//! ## Example
27//!
28//! ```json
29//! {
30//!   "task_id": "subtask-1",
31//!   "prompt": "Generate the header section",
32//!   "context": {"entity": "qr-code"},
33//!   "max_turns": 5
34//! }
35//! ```
36
37use std::future::Future;
38use std::pin::Pin;
39use std::sync::Arc;
40
41use rig::completion::ToolDefinition;
42use rig::tool::{ToolDyn, ToolError};
43use rustc_hash::FxHashMap;
44use serde::{Deserialize, Serialize};
45use serde_json::{json, Value};
46use tokio_util::sync::CancellationToken;
47
48use crate::ast::AgentParams;
49use crate::event::{EventKind, EventLog};
50use crate::mcp::McpClient;
51
52/// Parameters for spawning a child agent
53#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct SpawnAgentParams {
55    /// Unique identifier for the child task
56    pub task_id: String,
57    /// Prompt/goal for the child agent
58    pub prompt: String,
59    /// Optional context data to pass to child
60    #[serde(default)]
61    pub context: Option<Value>,
62    /// Optional max turns override for child
63    #[serde(default)]
64    pub max_turns: Option<u32>,
65}
66
67/// Internal tool for spawning sub-agents
68///
69/// This tool is automatically added to agents that have depth_limit > current_depth.
70/// It allows recursive task decomposition with safety limits.
71///
72/// Implements `rig::tool::ToolDyn` for integration with RigAgentLoop.
73#[derive(Clone)]
74pub struct SpawnAgentTool {
75    /// Current recursion depth (1 = root agent)
76    current_depth: u32,
77    /// Maximum allowed depth
78    max_depth: u32,
79    /// Parent task ID for event linking
80    parent_task_id: Arc<str>,
81    /// Event log for emitting AgentSpawned events
82    event_log: EventLog,
83    /// MCP clients for child agent tool access
84    mcp_clients: FxHashMap<String, Arc<McpClient>>,
85    /// MCP server names for child agents (from parent AgentParams.mcp)
86    mcp_names: Vec<String>,
87    /// Cancellation token from parent — child agent races against this
88    cancel_token: CancellationToken,
89    /// Parent's model name — propagated to child agents
90    parent_model: Option<String>,
91    /// Parent's provider name — propagated to child agents
92    parent_provider: Option<String>,
93    /// Parent's temperature setting — propagated to child agents
94    parent_temperature: Option<f32>,
95    /// Parent's tools list — propagated to child agents
96    parent_tools: Vec<String>,
97}
98
99impl SpawnAgentTool {
100    /// Create a new SpawnAgentTool (minimal - for testing without MCP)
101    ///
102    /// # Arguments
103    /// * `current_depth` - Current recursion depth (starts at 1 for root)
104    /// * `max_depth` - Maximum allowed depth (default 3)
105    /// * `parent_task_id` - ID of the parent task
106    /// * `event_log` - Shared event log for observability
107    pub fn new(
108        current_depth: u32,
109        max_depth: u32,
110        parent_task_id: Arc<str>,
111        event_log: EventLog,
112    ) -> Self {
113        Self {
114            current_depth,
115            max_depth,
116            parent_task_id,
117            event_log,
118            mcp_clients: FxHashMap::default(),
119            mcp_names: Vec::new(),
120            cancel_token: CancellationToken::new(),
121            parent_model: None,
122            parent_provider: None,
123            parent_temperature: None,
124            parent_tools: Vec::new(),
125        }
126    }
127
128    /// Create a new SpawnAgentTool with MCP clients (for production use)
129    ///
130    /// # Arguments
131    /// * `current_depth` - Current recursion depth (starts at 1 for root)
132    /// * `max_depth` - Maximum allowed depth (default 3)
133    /// * `parent_task_id` - ID of the parent task
134    /// * `event_log` - Shared event log for observability
135    /// * `mcp_clients` - Connected MCP clients for child agent tools
136    /// * `mcp_names` - MCP server names to pass to child agents
137    /// * `cancel_token` - Parent's cancellation token for cooperative shutdown
138    pub fn with_mcp(
139        current_depth: u32,
140        max_depth: u32,
141        parent_task_id: Arc<str>,
142        event_log: EventLog,
143        mcp_clients: FxHashMap<String, Arc<McpClient>>,
144        mcp_names: Vec<String>,
145        cancel_token: CancellationToken,
146    ) -> Self {
147        Self {
148            current_depth,
149            max_depth,
150            parent_task_id,
151            event_log,
152            mcp_clients,
153            mcp_names,
154            cancel_token,
155            parent_model: None,
156            parent_provider: None,
157            parent_temperature: None,
158            parent_tools: Vec::new(),
159        }
160    }
161
162    /// Set parent configuration to propagate to child agents.
163    ///
164    /// Without this, child agents would lose model, provider, temperature,
165    /// and tools configuration from the parent — running with defaults.
166    pub fn with_parent_config(
167        mut self,
168        model: Option<String>,
169        provider: Option<String>,
170        temperature: Option<f32>,
171        tools: Vec<String>,
172    ) -> Self {
173        self.parent_model = model;
174        self.parent_provider = provider;
175        self.parent_temperature = temperature;
176        self.parent_tools = tools;
177        self
178    }
179
180    /// Get the tool name
181    pub fn name(&self) -> &str {
182        "spawn_agent"
183    }
184
185    /// Get the JSON Schema definition for this tool
186    ///
187    /// Note: OpenAI's strict mode requires ALL properties in `required` and
188    /// `additionalProperties: false`. We keep the schema simple with only
189    /// required fields - context and max_turns use sensible defaults.
190    pub fn definition(&self) -> ToolDefinition {
191        ToolDefinition {
192            name: "spawn_agent".to_string(),
193            description: "Spawn a sub-agent to handle a delegated subtask. The child agent \
194                         runs independently with max 10 turns and returns its result."
195                .to_string(),
196            parameters: json!({
197                "type": "object",
198                "properties": {
199                    "task_id": {
200                        "type": "string",
201                        "description": "Unique identifier for the child task (e.g., 'subtask-1')"
202                    },
203                    "prompt": {
204                        "type": "string",
205                        "description": "Goal/prompt describing what the child agent should accomplish"
206                    }
207                },
208                "required": ["task_id", "prompt"],
209                "additionalProperties": false
210            }),
211        }
212    }
213
214    /// Execute the spawn_agent tool
215    ///
216    /// Creates and runs a child `RigAgentLoop` with inherited MCP clients.
217    /// The child agent runs to completion and its result is returned.
218    ///
219    /// # Errors
220    /// Returns an error if:
221    /// - Current depth >= max depth (depth limit reached)
222    /// - Invalid arguments
223    /// - Child agent execution fails
224    pub async fn call(&self, args: String) -> Result<String, SpawnAgentError> {
225        // Parse arguments
226        let params: SpawnAgentParams =
227            serde_json::from_str(&args).map_err(|e| SpawnAgentError::InvalidArgs(e.to_string()))?;
228
229        // Check depth limit
230        if self.current_depth >= self.max_depth {
231            return Err(SpawnAgentError::DepthLimitReached {
232                current: self.current_depth,
233                max: self.max_depth,
234            });
235        }
236
237        // Emit AgentSpawned event
238        let child_depth = self.current_depth + 1;
239        self.event_log.emit(EventKind::AgentSpawned {
240            parent_task_id: self.parent_task_id.clone(),
241            child_task_id: Arc::from(params.task_id.as_str()),
242            depth: child_depth,
243        });
244
245        // If no MCP clients, return placeholder (for compatibility with tests)
246        if self.mcp_clients.is_empty() {
247            return Ok(json!({
248                "status": "spawned",
249                "child_task_id": params.task_id,
250                "depth": child_depth,
251                "note": "Child agent execution requires MCP client context"
252            })
253            .to_string());
254        }
255
256        // Build child AgentParams
257        // Calculate remaining depth from PARENT's current_depth, not child_depth.
258        // With depth_limit=3:
259        //   - Root (depth=1, max=3): remaining = 3-1 = 2 → child gets max=2
260        //   - Child (depth=1, max=2): remaining = 2-1 = 1 → grandchild gets max=1
261        //   - Grandchild (depth=1, max=1): 1 >= 1, cannot spawn ✓
262        let remaining_depth = self.max_depth.saturating_sub(self.current_depth);
263        let child_params = AgentParams {
264            prompt: params.prompt,
265            system: params.context.as_ref().map(|ctx| {
266                format!(
267                    "Context from parent agent:\n{}",
268                    serde_json::to_string_pretty(ctx).unwrap_or_default()
269                )
270            }),
271            mcp: self.mcp_names.clone(),
272            max_turns: params.max_turns.or(Some(10)),
273            depth_limit: Some(remaining_depth),
274            // Propagate parent config so child agents use the same
275            // model, provider, temperature, and tools as the parent
276            model: self.parent_model.clone(),
277            provider: self.parent_provider.clone(),
278            temperature: self.parent_temperature,
279            tools: self.parent_tools.clone(),
280            ..Default::default()
281        };
282
283        // Create child RigAgentLoop
284        let mut child_loop = super::RigAgentLoop::new(
285            params.task_id.clone(),
286            child_params,
287            self.event_log.clone(),
288            self.mcp_clients.clone(),
289        )
290        .map_err(|e| SpawnAgentError::ExecutionFailed(e.to_string()))?;
291
292        // Execute child agent with auto-detected provider (production mode)
293        // Uses ANTHROPIC_API_KEY or OPENAI_API_KEY from environment
294        // Race against parent's cancellation token for cooperative shutdown
295        let result = tokio::select! {
296            res = child_loop.run_auto() => {
297                res.map_err(|e| SpawnAgentError::ExecutionFailed(e.to_string()))?
298            }
299            _ = self.cancel_token.cancelled() => {
300                return Err(SpawnAgentError::ExecutionFailed(
301                    "parent agent was cancelled".to_string(),
302                ));
303            }
304        };
305
306        // Return child's result
307        Ok(json!({
308            "status": "completed",
309            "child_task_id": params.task_id,
310            "depth": child_depth,
311            "result": result.final_output,
312            "turns": result.turns,
313            "total_tokens": result.total_tokens
314        })
315        .to_string())
316    }
317
318    /// Check if spawning is allowed at current depth
319    pub fn can_spawn(&self) -> bool {
320        self.current_depth < self.max_depth
321    }
322
323    /// Get the depth that child agents would have
324    pub fn child_depth(&self) -> u32 {
325        self.current_depth + 1
326    }
327}
328
329/// Errors that can occur when spawning agents
330#[derive(Debug, thiserror::Error)]
331pub enum SpawnAgentError {
332    #[error("spawn_agent: depth limit reached (current: {current}, max: {max})")]
333    DepthLimitReached { current: u32, max: u32 },
334
335    #[error("spawn_agent: invalid arguments - {0}")]
336    InvalidArgs(String),
337
338    #[error("spawn_agent: execution failed - {0}")]
339    ExecutionFailed(String),
340}
341
342impl std::fmt::Debug for SpawnAgentTool {
343    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344        f.debug_struct("SpawnAgentTool")
345            .field("current_depth", &self.current_depth)
346            .field("max_depth", &self.max_depth)
347            .field("parent_task_id", &self.parent_task_id)
348            .finish()
349    }
350}
351
352// ═══════════════════════════════════════════════════════════════════════════════
353// rig::ToolDyn implementation (for integration with RigAgentLoop)
354// ═══════════════════════════════════════════════════════════════════════════════
355
356/// Type alias for boxed future (required by ToolDyn)
357type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
358
359impl ToolDyn for SpawnAgentTool {
360    fn name(&self) -> String {
361        "spawn_agent".to_string()
362    }
363
364    fn definition(&self, _prompt: String) -> BoxFuture<'_, ToolDefinition> {
365        // Note: OpenAI's strict mode requires ALL properties in `required` and
366        // `additionalProperties: false`. Keep schema simple with only required fields.
367        let def = ToolDefinition {
368            name: "spawn_agent".to_string(),
369            description: "Spawn a sub-agent to handle a delegated subtask. The child agent \
370                         runs independently with max 10 turns and returns its result."
371                .to_string(),
372            parameters: json!({
373                "type": "object",
374                "properties": {
375                    "task_id": {
376                        "type": "string",
377                        "description": "Unique identifier for the child task (e.g., 'subtask-1')"
378                    },
379                    "prompt": {
380                        "type": "string",
381                        "description": "Goal/prompt describing what the child agent should accomplish"
382                    }
383                },
384                "required": ["task_id", "prompt"],
385                "additionalProperties": false
386            }),
387        };
388        Box::pin(async move { def })
389    }
390
391    fn call(&self, args: String) -> BoxFuture<'_, Result<String, ToolError>> {
392        Box::pin(async move {
393            self.call(args).await.map_err(|e| {
394                ToolError::ToolCallError(Box::new(std::io::Error::other(e.to_string())))
395            })
396        })
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an MCP-less tool owned by task `"parent"` with a throwaway event log.
    fn tool_at(current_depth: u32, max_depth: u32) -> SpawnAgentTool {
        SpawnAgentTool::new(current_depth, max_depth, "parent".into(), EventLog::new())
    }

    /// Minimal valid `spawn_agent` arguments, serialized to a JSON string.
    fn child_args() -> String {
        json!({
            "task_id": "child-1",
            "prompt": "Do something"
        })
        .to_string()
    }

    #[test]
    fn spawn_agent_tool_name() {
        let tool = tool_at(1, 3);
        assert_eq!(tool.name(), "spawn_agent");
    }

    #[test]
    fn spawn_agent_tool_can_spawn() {
        let tool = tool_at(1, 3);
        assert!(tool.can_spawn());

        let at_limit = tool_at(3, 3);
        assert!(!at_limit.can_spawn());
    }

    #[test]
    fn spawn_agent_tool_child_depth() {
        let tool = tool_at(1, 3);
        assert_eq!(tool.child_depth(), 2);
    }

    #[test]
    fn spawn_agent_params_deserializes() {
        let json = json!({
            "task_id": "child-1",
            "prompt": "Do something",
            "context": {"key": "value"},
            "max_turns": 5
        });

        let params: SpawnAgentParams = serde_json::from_value(json).unwrap();
        assert_eq!(params.task_id, "child-1");
        assert_eq!(params.prompt, "Do something");
        assert!(params.context.is_some());
        assert_eq!(params.max_turns, Some(5));
    }

    #[test]
    fn spawn_agent_params_minimal() {
        let json = json!({
            "task_id": "child-1",
            "prompt": "Do something"
        });

        let params: SpawnAgentParams = serde_json::from_value(json).unwrap();
        assert_eq!(params.task_id, "child-1");
        assert!(params.context.is_none());
        assert!(params.max_turns.is_none());
    }

    #[tokio::test]
    async fn spawn_agent_at_max_depth_fails() {
        let tool = tool_at(3, 3);

        let result = tool.call(child_args()).await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(err.to_string().contains("depth limit"));
    }

    #[tokio::test]
    async fn spawn_agent_below_max_depth_succeeds() {
        let tool = tool_at(2, 3);

        let result = tool.call(child_args()).await;
        assert!(result.is_ok());

        let response: Value = serde_json::from_str(&result.unwrap()).unwrap();
        assert_eq!(response["status"], "spawned");
        assert_eq!(response["child_task_id"], "child-1");
        assert_eq!(response["depth"], 3);
    }

    #[tokio::test]
    async fn spawn_agent_emits_event() {
        let event_log = EventLog::new();
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), event_log.clone());

        // Without MCP clients the call takes the placeholder path, which must succeed.
        tool.call(child_args())
            .await
            .expect("placeholder spawn should succeed");

        // Check that AgentSpawned event was emitted
        let events = event_log.events();
        let spawned_events: Vec<_> = events
            .iter()
            .filter(|e| matches!(e.kind, EventKind::AgentSpawned { .. }))
            .collect();

        assert_eq!(spawned_events.len(), 1);

        // Fail loudly instead of silently skipping the assertions if the
        // event ever has an unexpected shape.
        let EventKind::AgentSpawned {
            parent_task_id,
            child_task_id,
            depth,
        } = &spawned_events[0].kind
        else {
            panic!("filtered event should be AgentSpawned");
        };

        assert_eq!(&**parent_task_id, "parent");
        assert_eq!(&**child_task_id, "child-1");
        assert_eq!(*depth, 2);
    }

    #[test]
    fn tool_definition_has_required_params() {
        let tool = tool_at(1, 3);
        let def = tool.definition();

        let required = def
            .parameters
            .get("required")
            .and_then(|v| v.as_array())
            .expect("required should be an array");

        // OpenAI strict mode: all properties in required + additionalProperties: false
        // We keep schema simple with only task_id and prompt (context/max_turns use defaults)
        assert!(required.iter().any(|v| v == "task_id"));
        assert!(required.iter().any(|v| v == "prompt"));
        assert_eq!(required.len(), 2);

        // Check additionalProperties is false (required for OpenAI strict mode)
        let additional = def
            .parameters
            .get("additionalProperties")
            .expect("additionalProperties should exist");
        assert_eq!(additional, false);
    }

    // =========================================================================
    // rig::ToolDyn implementation tests
    // =========================================================================
    //
    // `ToolDyn` is in scope through `use super::*`; the fully qualified
    // `ToolDyn::method(&tool, ..)` form selects the trait method rather than
    // the inherent one of the same name.

    #[test]
    fn spawn_agent_implements_tool_dyn() {
        let tool = tool_at(1, 3);

        // Test ToolDyn::name()
        let name: String = ToolDyn::name(&tool);
        assert_eq!(name, "spawn_agent");
    }

    #[tokio::test]
    async fn spawn_agent_tool_dyn_definition_returns_correct_schema() {
        let tool = tool_at(1, 3);

        // Test ToolDyn::definition()
        let def = ToolDyn::definition(&tool, "test".to_string()).await;

        assert_eq!(def.name, "spawn_agent");
        assert!(def.description.contains("sub-agent"));
        assert!(def.parameters.get("required").is_some());
    }

    #[tokio::test]
    async fn spawn_agent_tool_dyn_call_enforces_depth_limit() {
        let tool = tool_at(3, 3);

        // Test ToolDyn::call() - should fail at max depth
        let result = ToolDyn::call(&tool, child_args()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("depth limit"));
    }

    #[test]
    fn spawn_agent_with_mcp_creates_correctly() {
        let tool = SpawnAgentTool::with_mcp(
            1,
            3,
            "parent".into(),
            EventLog::new(),
            FxHashMap::default(),
            vec!["novanet".to_string()],
            CancellationToken::new(),
        );

        assert_eq!(tool.name(), "spawn_agent");
        assert!(tool.can_spawn());
        assert_eq!(tool.child_depth(), 2);
    }

    // =========================================================================
    // Depth calculation regression tests
    // =========================================================================

    #[test]
    fn depth_calculation_allows_three_levels() {
        // With depth_limit=3, we should allow:
        // - Root (depth=1) → can spawn child
        // - Child (depth=2) → can spawn grandchild
        // - Grandchild (depth=3) → cannot spawn

        // Root agent: current=1, max=3
        let root = SpawnAgentTool::new(1, 3, "root".into(), EventLog::new());
        assert!(root.can_spawn(), "Root should be able to spawn");
        assert_eq!(root.child_depth(), 2);

        // Simulate what child receives: remaining = max - current = 3 - 1 = 2
        // So child sees: current=1, max=2
        let child = SpawnAgentTool::new(1, 2, "child".into(), EventLog::new());
        assert!(
            child.can_spawn(),
            "Child should be able to spawn grandchild"
        );
        assert_eq!(child.child_depth(), 2);

        // Simulate what grandchild receives: remaining = max - current = 2 - 1 = 1
        // So grandchild sees: current=1, max=1
        let grandchild = SpawnAgentTool::new(1, 1, "grandchild".into(), EventLog::new());
        assert!(
            !grandchild.can_spawn(),
            "Grandchild should NOT be able to spawn"
        );
    }

    #[test]
    fn remaining_depth_calculation_formula() {
        // Verify the formula: remaining_depth = max_depth - current_depth
        // NOT: max_depth - child_depth (old buggy formula)

        // Root: current=1, max=3
        let root_current = 1_u32;
        let root_max = 3_u32;
        let child_will_receive = root_max.saturating_sub(root_current); // 3-1=2
        assert_eq!(child_will_receive, 2, "Child should receive depth_limit=2");

        // Child (simulated): current=1, max=2
        let child_current = 1_u32;
        let child_max = child_will_receive; // 2
        let grandchild_will_receive = child_max.saturating_sub(child_current); // 2-1=1
        assert_eq!(
            grandchild_will_receive, 1,
            "Grandchild should receive depth_limit=1"
        );

        // Grandchild (simulated): current=1, max=1
        let grandchild_current = 1_u32;
        let grandchild_max = grandchild_will_receive; // 1
        let can_spawn = grandchild_current < grandchild_max; // 1 < 1 = false
        assert!(!can_spawn, "Grandchild should not be able to spawn");
    }
}