agent_sdk/
subagent.rs

1//! Subagent support for spawning child agents.
2//!
3//! This module provides the ability to spawn subagents from within an agent.
4//! Subagents are isolated agent instances that run to completion and return
5//! only their final response to the parent agent.
6//!
7//! # Overview
8//!
9//! Subagents are useful for:
10//! - Delegating complex subtasks to specialized agents
11//! - Running parallel investigations
12//! - Isolating context for specific operations
13//!
14//! # Example
15//!
16//! ```ignore
17//! use agent_sdk::subagent::{SubagentTool, SubagentConfig};
18//!
19//! let config = SubagentConfig::new("researcher")
20//!     .with_system_prompt("You are a research specialist...")
21//!     .with_max_turns(10);
22//!
23//! let tool = SubagentTool::new(config, provider, tools);
24//! registry.register(tool);
25//! ```
26//!
27//! # Behavior
28//!
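//! When a subagent runs:
//! 1. A new isolated conversation thread (fresh `ThreadId`, fresh message/state stores) is created
//! 2. The subagent runs until it finishes, reports an error, or exceeds its optional timeout
//!    (the turn budget is passed to the agent loop as `max_turns`)
//! 3. Only the subagent's text output is returned to the parent as the tool output
//! 4. The parent's conversation does not contain the subagent's intermediate tool calls; they
//!    surface only as `SubagentProgress` events and as summaries in the result's `tool_logs`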
34
35mod factory;
36
37pub use factory::SubagentFactory;
38
39use crate::events::AgentEvent;
40use crate::hooks::{AgentHooks, DefaultHooks};
41use crate::llm::LlmProvider;
42use crate::stores::{InMemoryStore, MessageStore, StateStore};
43use crate::tools::{Tool, ToolContext, ToolRegistry};
44use crate::types::{AgentConfig, ThreadId, TokenUsage, ToolResult, ToolTier};
45use anyhow::{Context, Result};
46use async_trait::async_trait;
47use serde::{Deserialize, Serialize};
48use serde_json::{Value, json};
49use std::sync::Arc;
50use std::time::{Duration, Instant};
51use tokio::sync::mpsc;
52
53/// Configuration for a subagent.
54#[derive(Clone, Debug, Serialize, Deserialize)]
55pub struct SubagentConfig {
56    /// Name of the subagent (for identification).
57    pub name: String,
58    /// System prompt for the subagent.
59    pub system_prompt: String,
60    /// Maximum number of turns before stopping.
61    pub max_turns: usize,
62    /// Optional timeout in milliseconds.
63    pub timeout_ms: Option<u64>,
64}
65
66impl SubagentConfig {
67    /// Create a new subagent configuration.
68    #[must_use]
69    pub fn new(name: impl Into<String>) -> Self {
70        Self {
71            name: name.into(),
72            system_prompt: String::new(),
73            max_turns: 10,
74            timeout_ms: None,
75        }
76    }
77
78    /// Set the system prompt.
79    #[must_use]
80    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
81        self.system_prompt = prompt.into();
82        self
83    }
84
85    /// Set the maximum number of turns.
86    #[must_use]
87    pub const fn with_max_turns(mut self, max: usize) -> Self {
88        self.max_turns = max;
89        self
90    }
91
92    /// Set the timeout in milliseconds.
93    #[must_use]
94    pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
95        self.timeout_ms = Some(timeout);
96        self
97    }
98}
99
100/// Log entry for a single tool call within a subagent.
101#[derive(Clone, Debug, Serialize, Deserialize)]
102pub struct ToolCallLog {
103    /// Tool name.
104    pub name: String,
105    /// Brief context/args (e.g., file path, command).
106    pub context: String,
107    /// Brief result summary.
108    pub result: String,
109    /// Whether the tool call succeeded.
110    pub success: bool,
111    /// Duration in milliseconds.
112    pub duration_ms: Option<u64>,
113}
114
115/// Result from a subagent execution.
116#[derive(Clone, Debug, Serialize, Deserialize)]
117pub struct SubagentResult {
118    /// Name of the subagent.
119    pub name: String,
120    /// The final text response (only visible part to parent).
121    pub final_response: String,
122    /// Total number of turns taken.
123    pub total_turns: usize,
124    /// Number of tool calls made by the subagent.
125    pub tool_count: u32,
126    /// Log of tool calls made by the subagent.
127    pub tool_logs: Vec<ToolCallLog>,
128    /// Token usage statistics.
129    pub usage: TokenUsage,
130    /// Whether the subagent completed successfully.
131    pub success: bool,
132    /// Duration in milliseconds.
133    pub duration_ms: u64,
134}
135
136/// Tool for spawning subagents.
137///
138/// This tool allows an agent to spawn a child agent that runs independently
139/// and returns only its final response.
140///
141/// # Example
142///
143/// ```ignore
144/// use agent_sdk::subagent::{SubagentTool, SubagentConfig};
145///
146/// let config = SubagentConfig::new("analyzer")
147///     .with_system_prompt("You analyze code...");
148///
149/// let tool = SubagentTool::new(config, provider.clone(), tools.clone());
150/// ```
151pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
152where
153    P: LlmProvider,
154    H: AgentHooks,
155    M: MessageStore,
156    S: StateStore,
157{
158    config: SubagentConfig,
159    provider: Arc<P>,
160    tools: Arc<ToolRegistry<()>>,
161    hooks: Arc<H>,
162    message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
163    state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
164}
165
166impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
167where
168    P: LlmProvider + 'static,
169{
170    /// Create a new subagent tool with default hooks and in-memory stores.
171    #[must_use]
172    pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
173        Self {
174            config,
175            provider,
176            tools,
177            hooks: Arc::new(DefaultHooks),
178            message_store_factory: Arc::new(InMemoryStore::new),
179            state_store_factory: Arc::new(InMemoryStore::new),
180        }
181    }
182}
183
184impl<P, H, M, S> SubagentTool<P, H, M, S>
185where
186    P: LlmProvider + Clone + 'static,
187    H: AgentHooks + Clone + 'static,
188    M: MessageStore + 'static,
189    S: StateStore + 'static,
190{
191    /// Create with custom hooks.
192    #[must_use]
193    pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
194        self,
195        hooks: Arc<H2>,
196    ) -> SubagentTool<P, H2, M, S> {
197        SubagentTool {
198            config: self.config,
199            provider: self.provider,
200            tools: self.tools,
201            hooks,
202            message_store_factory: self.message_store_factory,
203            state_store_factory: self.state_store_factory,
204        }
205    }
206
207    /// Create with custom store factories.
208    #[must_use]
209    pub fn with_stores<M2, S2, MF, SF>(
210        self,
211        message_factory: MF,
212        state_factory: SF,
213    ) -> SubagentTool<P, H, M2, S2>
214    where
215        M2: MessageStore + 'static,
216        S2: StateStore + 'static,
217        MF: Fn() -> M2 + Send + Sync + 'static,
218        SF: Fn() -> S2 + Send + Sync + 'static,
219    {
220        SubagentTool {
221            config: self.config,
222            provider: self.provider,
223            tools: self.tools,
224            hooks: self.hooks,
225            message_store_factory: Arc::new(message_factory),
226            state_store_factory: Arc::new(state_factory),
227        }
228    }
229
230    /// Get the subagent configuration.
231    #[must_use]
232    pub const fn config(&self) -> &SubagentConfig {
233        &self.config
234    }
235
236    /// Run the subagent with a task.
237    ///
238    /// If `parent_tx` is provided, the subagent will emit `SubagentProgress` events
239    /// to the parent's event channel, allowing the UI to show live progress.
240    #[allow(clippy::too_many_lines)]
241    async fn run_subagent(
242        &self,
243        task: &str,
244        subagent_id: String,
245        parent_tx: Option<mpsc::Sender<AgentEvent>>,
246    ) -> Result<SubagentResult> {
247        use crate::agent_loop::AgentLoop;
248
249        let start = Instant::now();
250        let thread_id = ThreadId::new();
251
252        // Create stores for this subagent run
253        let message_store = (self.message_store_factory)();
254        let state_store = (self.state_store_factory)();
255
256        // Create agent config
257        let agent_config = AgentConfig {
258            max_turns: self.config.max_turns,
259            system_prompt: self.config.system_prompt.clone(),
260            ..Default::default()
261        };
262
263        // Build the subagent
264        let agent = AgentLoop::new(
265            (*self.provider).clone(),
266            (*self.tools).clone(),
267            (*self.hooks).clone(),
268            message_store,
269            state_store,
270            agent_config,
271        );
272
273        // Create tool context
274        let tool_ctx = ToolContext::new(());
275
276        // Run with optional timeout
277        let mut rx = agent.run(thread_id, task.to_string(), tool_ctx);
278
279        let mut final_response = String::new();
280        let mut total_turns = 0;
281        let mut tool_count = 0u32;
282        let mut tool_logs: Vec<ToolCallLog> = Vec::new();
283        let mut pending_tools: std::collections::HashMap<String, (String, String)> =
284            std::collections::HashMap::new();
285        let mut total_usage = TokenUsage::default();
286        let mut success = true;
287
288        let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);
289
290        loop {
291            let recv_result = if let Some(timeout) = timeout_duration {
292                let remaining = timeout.saturating_sub(start.elapsed());
293                if remaining.is_zero() {
294                    final_response = "Subagent timed out".to_string();
295                    success = false;
296                    break;
297                }
298                tokio::time::timeout(remaining, rx.recv()).await
299            } else {
300                Ok(rx.recv().await)
301            };
302
303            match recv_result {
304                Ok(Some(event)) => match event {
305                    AgentEvent::Text { text } => {
306                        final_response.push_str(&text);
307                    }
308                    AgentEvent::ToolCallStart {
309                        id, name, input, ..
310                    } => {
311                        // Track tool calls made by the subagent
312                        tool_count += 1;
313                        let context = extract_tool_context(&name, &input);
314                        pending_tools.insert(id, (name.clone(), context.clone()));
315
316                        // Emit progress event to parent
317                        if let Some(ref tx) = parent_tx {
318                            let _ = tx
319                                .send(AgentEvent::SubagentProgress {
320                                    subagent_id: subagent_id.clone(),
321                                    subagent_name: self.config.name.clone(),
322                                    tool_name: name,
323                                    tool_context: context,
324                                    completed: false,
325                                    success: false,
326                                    tool_count,
327                                    total_tokens: u64::from(total_usage.input_tokens)
328                                        + u64::from(total_usage.output_tokens),
329                                })
330                                .await;
331                        }
332                    }
333                    AgentEvent::ToolCallEnd { id, name, result } => {
334                        // Create log entry when tool completes
335                        let context = pending_tools
336                            .remove(&id)
337                            .map(|(_, ctx)| ctx)
338                            .unwrap_or_default();
339                        let result_summary = summarize_tool_result(&name, &result);
340                        let tool_success = result.success;
341                        tool_logs.push(ToolCallLog {
342                            name: name.clone(),
343                            context: context.clone(),
344                            result: result_summary,
345                            success: tool_success,
346                            duration_ms: result.duration_ms,
347                        });
348
349                        // Emit progress event to parent
350                        if let Some(ref tx) = parent_tx {
351                            let _ = tx
352                                .send(AgentEvent::SubagentProgress {
353                                    subagent_id: subagent_id.clone(),
354                                    subagent_name: self.config.name.clone(),
355                                    tool_name: name,
356                                    tool_context: context,
357                                    completed: true,
358                                    success: tool_success,
359                                    tool_count,
360                                    total_tokens: u64::from(total_usage.input_tokens)
361                                        + u64::from(total_usage.output_tokens),
362                                })
363                                .await;
364                        }
365                    }
366                    AgentEvent::TurnComplete { turn, usage, .. } => {
367                        total_turns = turn;
368                        total_usage.add(&usage);
369                    }
370                    AgentEvent::Done {
371                        total_turns: turns, ..
372                    } => {
373                        total_turns = turns;
374                        break;
375                    }
376                    AgentEvent::Error { message, .. } => {
377                        final_response = message;
378                        success = false;
379                        break;
380                    }
381                    _ => {}
382                },
383                Ok(None) => break,
384                Err(_) => {
385                    final_response = "Subagent timed out".to_string();
386                    success = false;
387                    break;
388                }
389            }
390        }
391
392        Ok(SubagentResult {
393            name: self.config.name.clone(),
394            final_response,
395            total_turns,
396            tool_count,
397            tool_logs,
398            usage: total_usage,
399            success,
400            duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
401        })
402    }
403}
404
405/// Extracts context information from tool input for display.
406fn extract_tool_context(name: &str, input: &Value) -> String {
407    match name {
408        "read" => input
409            .get("file_path")
410            .or_else(|| input.get("path"))
411            .and_then(Value::as_str)
412            .unwrap_or("")
413            .to_string(),
414        "write" | "edit" => input
415            .get("file_path")
416            .or_else(|| input.get("path"))
417            .and_then(Value::as_str)
418            .unwrap_or("")
419            .to_string(),
420        "bash" => {
421            let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
422            // Truncate long commands
423            if cmd.len() > 60 {
424                format!("{}...", &cmd[..57])
425            } else {
426                cmd.to_string()
427            }
428        }
429        "glob" | "grep" => input
430            .get("pattern")
431            .and_then(Value::as_str)
432            .unwrap_or("")
433            .to_string(),
434        "web_search" => input
435            .get("query")
436            .and_then(Value::as_str)
437            .unwrap_or("")
438            .to_string(),
439        _ => String::new(),
440    }
441}
442
443/// Summarizes tool result for logging.
444fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
445    if !result.success {
446        let first_line = result.output.lines().next().unwrap_or("Error");
447        return if first_line.len() > 50 {
448            format!("{}...", &first_line[..47])
449        } else {
450            first_line.to_string()
451        };
452    }
453
454    match name {
455        "read" => {
456            let line_count = result.output.lines().count();
457            format!("{line_count} lines")
458        }
459        "write" => "wrote file".to_string(),
460        "edit" => "edited".to_string(),
461        "bash" => {
462            let lines: Vec<&str> = result.output.lines().collect();
463            if lines.is_empty() {
464                "done".to_string()
465            } else if lines.len() == 1 {
466                let line = lines[0];
467                if line.len() > 50 {
468                    format!("{}...", &line[..47])
469                } else {
470                    line.to_string()
471                }
472            } else {
473                format!("{} lines", lines.len())
474            }
475        }
476        "glob" => {
477            let count = result.output.lines().count();
478            format!("{count} files")
479        }
480        "grep" => {
481            let count = result.output.lines().count();
482            format!("{count} matches")
483        }
484        _ => {
485            let line_count = result.output.lines().count();
486            if line_count == 0 {
487                "done".to_string()
488            } else {
489                format!("{line_count} lines")
490            }
491        }
492    }
493}
494
495#[async_trait]
496impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
497where
498    P: LlmProvider + Clone + 'static,
499    H: AgentHooks + Clone + 'static,
500    M: MessageStore + 'static,
501    S: StateStore + 'static,
502{
503    fn name(&self) -> &'static str {
504        // Leak the name to get 'static lifetime (acceptable for long-lived tools)
505        Box::leak(format!("subagent_{}", self.config.name).into_boxed_str())
506    }
507
508    fn description(&self) -> &'static str {
509        Box::leak(
510            format!(
511                "Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
512                self.config.name
513            )
514            .into_boxed_str(),
515        )
516    }
517
518    fn input_schema(&self) -> Value {
519        json!({
520            "type": "object",
521            "properties": {
522                "task": {
523                    "type": "string",
524                    "description": "The task or question for the subagent to handle"
525                }
526            },
527            "required": ["task"]
528        })
529    }
530
531    fn tier(&self) -> ToolTier {
532        // Subagent spawning requires confirmation
533        ToolTier::Confirm
534    }
535
536    async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
537        let task = input
538            .get("task")
539            .and_then(Value::as_str)
540            .context("Missing 'task' parameter")?;
541
542        // Get event channel from context for progress updates
543        let parent_tx = ctx.event_tx();
544
545        // Generate a unique ID for this subagent execution
546        let subagent_id = format!(
547            "{}_{:x}",
548            self.config.name,
549            std::time::SystemTime::now()
550                .duration_since(std::time::UNIX_EPOCH)
551                .unwrap_or_default()
552                .as_nanos()
553        );
554
555        let result = self.run_subagent(task, subagent_id, parent_tx).await?;
556
557        Ok(ToolResult {
558            success: result.success,
559            output: result.final_response.clone(),
560            data: Some(serde_json::to_value(&result).unwrap_or_default()),
561            duration_ms: Some(result.duration_ms),
562        })
563    }
564}
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful tool-call log entry for the serialization fixtures.
    fn ok_log(name: &str, context: &str, result: &str, duration_ms: u64) -> ToolCallLog {
        ToolCallLog {
            name: name.to_string(),
            context: context.to_string(),
            result: result.to_string(),
            success: true,
            duration_ms: Some(duration_ms),
        }
    }

    #[test]
    fn test_subagent_config_builder() {
        let SubagentConfig {
            name,
            system_prompt,
            max_turns,
            timeout_ms,
        } = SubagentConfig::new("test")
            .with_system_prompt("Test prompt")
            .with_max_turns(5)
            .with_timeout_ms(30000);

        assert_eq!(name, "test");
        assert_eq!(system_prompt, "Test prompt");
        assert_eq!(max_turns, 5);
        assert_eq!(timeout_ms, Some(30000));
    }

    #[test]
    fn test_subagent_config_defaults() {
        let config = SubagentConfig::new("default");

        assert_eq!(config.name, "default");
        assert!(config.system_prompt.is_empty());
        assert_eq!(config.max_turns, 10);
        assert!(config.timeout_ms.is_none());
    }

    #[test]
    fn test_subagent_result_serialization() {
        let result = SubagentResult {
            name: "test".to_string(),
            final_response: "Done".to_string(),
            total_turns: 3,
            tool_count: 5,
            tool_logs: vec![
                ok_log("read", "/tmp/test.rs", "50 lines", 10),
                ok_log("grep", "TODO", "3 matches", 5),
            ],
            usage: TokenUsage::default(),
            success: true,
            duration_ms: 1000,
        };

        let json = serde_json::to_string(&result).expect("serialize");
        for needle in ["test", "Done", "tool_count", "tool_logs", "/tmp/test.rs"] {
            assert!(json.contains(needle), "serialized JSON lacks {needle:?}: {json}");
        }
    }

    #[test]
    fn test_subagent_result_field_extraction() {
        // Pins the JSON shape read by a downstream consumer (bip's tui_session.rs).
        let result = SubagentResult {
            name: "explore".to_string(),
            final_response: "Found 3 config files".to_string(),
            total_turns: 2,
            tool_count: 5,
            tool_logs: vec![ok_log("glob", "**/*.toml", "3 files", 15)],
            usage: TokenUsage {
                input_tokens: 1500,
                output_tokens: 500,
            },
            success: true,
            duration_ms: 2500,
        };

        let value = serde_json::to_value(&result).expect("serialize to value");

        // `tool_count` is a u32 but must still read back through `as_u64`.
        assert_eq!(value["tool_count"].as_u64(), Some(5));

        let usage = &value["usage"];
        assert_eq!(usage["input_tokens"].as_u64(), Some(1500));
        assert_eq!(usage["output_tokens"].as_u64(), Some(500));

        let logs = value["tool_logs"]
            .as_array()
            .expect("tool_logs should serialize as an array");
        assert_eq!(logs.len(), 1);

        let first = &logs[0];
        assert_eq!(first["name"].as_str(), Some("glob"));
        assert_eq!(first["context"].as_str(), Some("**/*.toml"));
        assert_eq!(first["result"].as_str(), Some("3 files"));
        assert_eq!(first["success"].as_bool(), Some(true));
    }
}