1use std::future::Future;
38use std::pin::Pin;
39use std::sync::Arc;
40
41use rig::completion::ToolDefinition;
42use rig::tool::{ToolDyn, ToolError};
43use rustc_hash::FxHashMap;
44use serde::{Deserialize, Serialize};
45use serde_json::{json, Value};
46use tokio_util::sync::CancellationToken;
47
48use crate::ast::AgentParams;
49use crate::event::{EventKind, EventLog};
50use crate::mcp::McpClient;
51
52#[derive(Debug, Clone, Serialize, Deserialize)]
54pub struct SpawnAgentParams {
55 pub task_id: String,
57 pub prompt: String,
59 #[serde(default)]
61 pub context: Option<Value>,
62 #[serde(default)]
64 pub max_turns: Option<u32>,
65}
66
67#[derive(Clone)]
74pub struct SpawnAgentTool {
75 current_depth: u32,
77 max_depth: u32,
79 parent_task_id: Arc<str>,
81 event_log: EventLog,
83 mcp_clients: FxHashMap<String, Arc<McpClient>>,
85 mcp_names: Vec<String>,
87 cancel_token: CancellationToken,
89 parent_model: Option<String>,
91 parent_provider: Option<String>,
93 parent_temperature: Option<f32>,
95 parent_tools: Vec<String>,
97}
98
99impl SpawnAgentTool {
100 pub fn new(
108 current_depth: u32,
109 max_depth: u32,
110 parent_task_id: Arc<str>,
111 event_log: EventLog,
112 ) -> Self {
113 Self {
114 current_depth,
115 max_depth,
116 parent_task_id,
117 event_log,
118 mcp_clients: FxHashMap::default(),
119 mcp_names: Vec::new(),
120 cancel_token: CancellationToken::new(),
121 parent_model: None,
122 parent_provider: None,
123 parent_temperature: None,
124 parent_tools: Vec::new(),
125 }
126 }
127
128 pub fn with_mcp(
139 current_depth: u32,
140 max_depth: u32,
141 parent_task_id: Arc<str>,
142 event_log: EventLog,
143 mcp_clients: FxHashMap<String, Arc<McpClient>>,
144 mcp_names: Vec<String>,
145 cancel_token: CancellationToken,
146 ) -> Self {
147 Self {
148 current_depth,
149 max_depth,
150 parent_task_id,
151 event_log,
152 mcp_clients,
153 mcp_names,
154 cancel_token,
155 parent_model: None,
156 parent_provider: None,
157 parent_temperature: None,
158 parent_tools: Vec::new(),
159 }
160 }
161
162 pub fn with_parent_config(
167 mut self,
168 model: Option<String>,
169 provider: Option<String>,
170 temperature: Option<f32>,
171 tools: Vec<String>,
172 ) -> Self {
173 self.parent_model = model;
174 self.parent_provider = provider;
175 self.parent_temperature = temperature;
176 self.parent_tools = tools;
177 self
178 }
179
180 pub fn name(&self) -> &str {
182 "spawn_agent"
183 }
184
185 pub fn definition(&self) -> ToolDefinition {
191 ToolDefinition {
192 name: "spawn_agent".to_string(),
193 description: "Spawn a sub-agent to handle a delegated subtask. The child agent \
194 runs independently with max 10 turns and returns its result."
195 .to_string(),
196 parameters: json!({
197 "type": "object",
198 "properties": {
199 "task_id": {
200 "type": "string",
201 "description": "Unique identifier for the child task (e.g., 'subtask-1')"
202 },
203 "prompt": {
204 "type": "string",
205 "description": "Goal/prompt describing what the child agent should accomplish"
206 }
207 },
208 "required": ["task_id", "prompt"],
209 "additionalProperties": false
210 }),
211 }
212 }
213
214 pub async fn call(&self, args: String) -> Result<String, SpawnAgentError> {
225 let params: SpawnAgentParams =
227 serde_json::from_str(&args).map_err(|e| SpawnAgentError::InvalidArgs(e.to_string()))?;
228
229 if self.current_depth >= self.max_depth {
231 return Err(SpawnAgentError::DepthLimitReached {
232 current: self.current_depth,
233 max: self.max_depth,
234 });
235 }
236
237 let child_depth = self.current_depth + 1;
239 self.event_log.emit(EventKind::AgentSpawned {
240 parent_task_id: self.parent_task_id.clone(),
241 child_task_id: Arc::from(params.task_id.as_str()),
242 depth: child_depth,
243 });
244
245 if self.mcp_clients.is_empty() {
247 return Ok(json!({
248 "status": "spawned",
249 "child_task_id": params.task_id,
250 "depth": child_depth,
251 "note": "Child agent execution requires MCP client context"
252 })
253 .to_string());
254 }
255
256 let remaining_depth = self.max_depth.saturating_sub(self.current_depth);
263 let child_params = AgentParams {
264 prompt: params.prompt,
265 system: params.context.as_ref().map(|ctx| {
266 format!(
267 "Context from parent agent:\n{}",
268 serde_json::to_string_pretty(ctx).unwrap_or_default()
269 )
270 }),
271 mcp: self.mcp_names.clone(),
272 max_turns: params.max_turns.or(Some(10)),
273 depth_limit: Some(remaining_depth),
274 model: self.parent_model.clone(),
277 provider: self.parent_provider.clone(),
278 temperature: self.parent_temperature,
279 tools: self.parent_tools.clone(),
280 ..Default::default()
281 };
282
283 let mut child_loop = super::RigAgentLoop::new(
285 params.task_id.clone(),
286 child_params,
287 self.event_log.clone(),
288 self.mcp_clients.clone(),
289 )
290 .map_err(|e| SpawnAgentError::ExecutionFailed(e.to_string()))?;
291
292 let result = tokio::select! {
296 res = child_loop.run_auto() => {
297 res.map_err(|e| SpawnAgentError::ExecutionFailed(e.to_string()))?
298 }
299 _ = self.cancel_token.cancelled() => {
300 return Err(SpawnAgentError::ExecutionFailed(
301 "parent agent was cancelled".to_string(),
302 ));
303 }
304 };
305
306 Ok(json!({
308 "status": "completed",
309 "child_task_id": params.task_id,
310 "depth": child_depth,
311 "result": result.final_output,
312 "turns": result.turns,
313 "total_tokens": result.total_tokens
314 })
315 .to_string())
316 }
317
318 pub fn can_spawn(&self) -> bool {
320 self.current_depth < self.max_depth
321 }
322
323 pub fn child_depth(&self) -> u32 {
325 self.current_depth + 1
326 }
327}
328
329#[derive(Debug, thiserror::Error)]
331pub enum SpawnAgentError {
332 #[error("spawn_agent: depth limit reached (current: {current}, max: {max})")]
333 DepthLimitReached { current: u32, max: u32 },
334
335 #[error("spawn_agent: invalid arguments - {0}")]
336 InvalidArgs(String),
337
338 #[error("spawn_agent: execution failed - {0}")]
339 ExecutionFailed(String),
340}
341
342impl std::fmt::Debug for SpawnAgentTool {
343 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
344 f.debug_struct("SpawnAgentTool")
345 .field("current_depth", &self.current_depth)
346 .field("max_depth", &self.max_depth)
347 .field("parent_task_id", &self.parent_task_id)
348 .finish()
349 }
350}
351
/// Heap-allocated, pinned, `Send` future borrowing for `'a`; the return shape
/// used by the `ToolDyn` methods below.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
358
359impl ToolDyn for SpawnAgentTool {
360 fn name(&self) -> String {
361 "spawn_agent".to_string()
362 }
363
364 fn definition(&self, _prompt: String) -> BoxFuture<'_, ToolDefinition> {
365 let def = ToolDefinition {
368 name: "spawn_agent".to_string(),
369 description: "Spawn a sub-agent to handle a delegated subtask. The child agent \
370 runs independently with max 10 turns and returns its result."
371 .to_string(),
372 parameters: json!({
373 "type": "object",
374 "properties": {
375 "task_id": {
376 "type": "string",
377 "description": "Unique identifier for the child task (e.g., 'subtask-1')"
378 },
379 "prompt": {
380 "type": "string",
381 "description": "Goal/prompt describing what the child agent should accomplish"
382 }
383 },
384 "required": ["task_id", "prompt"],
385 "additionalProperties": false
386 }),
387 };
388 Box::pin(async move { def })
389 }
390
391 fn call(&self, args: String) -> BoxFuture<'_, Result<String, ToolError>> {
392 Box::pin(async move {
393 self.call(args).await.map_err(|e| {
394 ToolError::ToolCallError(Box::new(std::io::Error::other(e.to_string())))
395 })
396 })
397 }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal valid `spawn_agent` argument payload shared by several tests.
    fn minimal_args() -> String {
        json!({
            "task_id": "child-1",
            "prompt": "Do something"
        })
        .to_string()
    }

    #[test]
    fn spawn_agent_tool_name() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        assert_eq!(tool.name(), "spawn_agent");
    }

    #[test]
    fn spawn_agent_tool_can_spawn() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        assert!(tool.can_spawn());

        let at_limit = SpawnAgentTool::new(3, 3, "parent".into(), EventLog::new());
        assert!(!at_limit.can_spawn());
    }

    #[test]
    fn spawn_agent_tool_child_depth() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        assert_eq!(tool.child_depth(), 2);
    }

    #[test]
    fn spawn_agent_params_deserializes() {
        let json = json!({
            "task_id": "child-1",
            "prompt": "Do something",
            "context": {"key": "value"},
            "max_turns": 5
        });

        let params: SpawnAgentParams = serde_json::from_value(json).unwrap();
        assert_eq!(params.task_id, "child-1");
        assert_eq!(params.prompt, "Do something");
        assert_eq!(params.context, Some(json!({"key": "value"})));
        assert_eq!(params.max_turns, Some(5));
    }

    #[test]
    fn spawn_agent_params_minimal() {
        let json = json!({
            "task_id": "child-1",
            "prompt": "Do something"
        });

        let params: SpawnAgentParams = serde_json::from_value(json).unwrap();
        assert_eq!(params.task_id, "child-1");
        assert_eq!(params.prompt, "Do something");
        assert!(params.context.is_none());
        assert!(params.max_turns.is_none());
    }

    #[test]
    fn spawn_agent_params_requires_prompt() {
        let json = json!({ "task_id": "child-1" });
        assert!(serde_json::from_value::<SpawnAgentParams>(json).is_err());
    }

    #[tokio::test]
    async fn spawn_agent_rejects_malformed_json() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());

        let err = tool.call("not json".to_string()).await.unwrap_err();
        assert!(matches!(err, SpawnAgentError::InvalidArgs(_)));
    }

    #[tokio::test]
    async fn spawn_agent_at_max_depth_fails() {
        let tool = SpawnAgentTool::new(3, 3, "parent".into(), EventLog::new());

        let result = tool.call(minimal_args()).await;
        assert!(result.is_err());

        let err = result.unwrap_err();
        assert!(err.to_string().contains("depth limit"));
    }

    #[tokio::test]
    async fn spawn_agent_at_max_depth_emits_no_event() {
        let event_log = EventLog::new();
        let tool = SpawnAgentTool::new(3, 3, "parent".into(), event_log.clone());

        let _ = tool.call(minimal_args()).await;

        let events = event_log.events();
        assert!(!events
            .iter()
            .any(|e| matches!(e.kind, EventKind::AgentSpawned { .. })));
    }

    #[tokio::test]
    async fn spawn_agent_below_max_depth_succeeds() {
        let tool = SpawnAgentTool::new(2, 3, "parent".into(), EventLog::new());

        let result = tool.call(minimal_args()).await;
        assert!(result.is_ok());

        let response: Value = serde_json::from_str(&result.unwrap()).unwrap();
        assert_eq!(response["status"], "spawned");
        assert_eq!(response["child_task_id"], "child-1");
        assert_eq!(response["depth"], 3);
    }

    #[tokio::test]
    async fn spawn_agent_emits_event() {
        let event_log = EventLog::new();
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), event_log.clone());

        let _ = tool.call(minimal_args()).await;

        let events = event_log.events();
        let spawned_events: Vec<_> = events
            .iter()
            .filter(|e| matches!(e.kind, EventKind::AgentSpawned { .. }))
            .collect();

        assert_eq!(spawned_events.len(), 1);

        // `let ... else` (not `if let`) so a mismatch fails loudly instead of
        // silently skipping the field assertions.
        let EventKind::AgentSpawned {
            parent_task_id,
            child_task_id,
            depth,
        } = &spawned_events[0].kind
        else {
            panic!("filtered event must be AgentSpawned");
        };
        assert_eq!(&**parent_task_id, "parent");
        assert_eq!(&**child_task_id, "child-1");
        assert_eq!(*depth, 2);
    }

    #[test]
    fn tool_definition_has_required_params() {
        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());
        let def = tool.definition();

        let required = def
            .parameters
            .get("required")
            .and_then(|v| v.as_array())
            .expect("required should be an array");

        assert!(required.iter().any(|v| v == "task_id"));
        assert!(required.iter().any(|v| v == "prompt"));
        assert_eq!(required.len(), 2);

        let additional = def
            .parameters
            .get("additionalProperties")
            .expect("additionalProperties should exist");
        assert_eq!(additional, false);
    }

    #[test]
    fn spawn_agent_implements_tool_dyn() {
        use rig::tool::ToolDyn;

        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());

        let name: String = ToolDyn::name(&tool);
        assert_eq!(name, "spawn_agent");
    }

    #[tokio::test]
    async fn spawn_agent_tool_dyn_definition_returns_correct_schema() {
        use rig::tool::ToolDyn;

        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());

        let def = ToolDyn::definition(&tool, "test".to_string()).await;

        assert_eq!(def.name, "spawn_agent");
        assert!(def.description.contains("sub-agent"));
        assert!(def.parameters.get("required").is_some());
    }

    #[tokio::test]
    async fn spawn_agent_tool_dyn_definition_matches_inherent() {
        use rig::tool::ToolDyn;

        let tool = SpawnAgentTool::new(1, 3, "parent".into(), EventLog::new());

        let via_trait = ToolDyn::definition(&tool, "ignored".to_string()).await;
        let inherent = SpawnAgentTool::definition(&tool);

        assert_eq!(via_trait.name, inherent.name);
        assert_eq!(via_trait.description, inherent.description);
        assert_eq!(via_trait.parameters, inherent.parameters);
    }

    #[tokio::test]
    async fn spawn_agent_tool_dyn_call_enforces_depth_limit() {
        use rig::tool::ToolDyn;

        let tool = SpawnAgentTool::new(3, 3, "parent".into(), EventLog::new());

        let result = ToolDyn::call(&tool, minimal_args()).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("depth limit"));
    }

    #[test]
    fn spawn_agent_with_mcp_creates_correctly() {
        let event_log = EventLog::new();
        let mcp_clients = FxHashMap::default();
        let mcp_names = vec!["novanet".to_string()];

        let tool = SpawnAgentTool::with_mcp(
            1,
            3,
            "parent".into(),
            event_log,
            mcp_clients,
            mcp_names.clone(),
            CancellationToken::new(),
        );

        assert_eq!(tool.name(), "spawn_agent");
        assert!(tool.can_spawn());
        assert_eq!(tool.child_depth(), 2);
    }

    #[test]
    fn depth_calculation_allows_three_levels() {
        // Each level sees itself at depth 1 with max = the depth_limit it was
        // handed by its parent.
        let root = SpawnAgentTool::new(1, 3, "root".into(), EventLog::new());
        assert!(root.can_spawn(), "Root should be able to spawn");
        assert_eq!(root.child_depth(), 2);

        let child = SpawnAgentTool::new(1, 2, "child".into(), EventLog::new());
        assert!(
            child.can_spawn(),
            "Child should be able to spawn grandchild"
        );
        assert_eq!(child.child_depth(), 2);

        let grandchild = SpawnAgentTool::new(1, 1, "grandchild".into(), EventLog::new());
        assert!(
            !grandchild.can_spawn(),
            "Grandchild should NOT be able to spawn"
        );
    }

    #[test]
    fn remaining_depth_calculation_formula() {
        let root_current = 1_u32;
        let root_max = 3_u32;
        let child_will_receive = root_max.saturating_sub(root_current);
        assert_eq!(child_will_receive, 2, "Child should receive depth_limit=2");

        let child_current = 1_u32;
        let child_max = child_will_receive;
        let grandchild_will_receive = child_max.saturating_sub(child_current);
        assert_eq!(
            grandchild_will_receive, 1,
            "Grandchild should receive depth_limit=1"
        );

        let grandchild_current = 1_u32;
        let grandchild_max = grandchild_will_receive;
        let can_spawn = grandchild_current < grandchild_max;
        assert!(!can_spawn, "Grandchild should not be able to spawn");
    }
}