Skip to main content

agent_sdk/
subagent.rs

//! Subagent support for spawning child agents.
//!
//! This module provides the ability to spawn subagents from within an agent.
//! Subagents are isolated agent instances that run to completion and return
//! only their final response to the parent agent.
//!
//! # Overview
//!
//! Subagents are useful for:
//! - Delegating complex subtasks to specialized agents
//! - Running parallel investigations
//! - Isolating context for specific operations
//!
//! # Example
//!
//! ```ignore
//! use agent_sdk::subagent::{SubagentTool, SubagentConfig};
//!
//! let config = SubagentConfig::new("researcher")
//!     .with_system_prompt("You are a research specialist...")
//!     .with_max_turns(10);
//!
//! // `provider` is an `Arc<P>` and `tools` an `Arc<ToolRegistry<()>>`.
//! let tool = SubagentTool::new(config, provider, tools);
//! registry.register(tool);
//! ```
//!
//! # Behavior
//!
//! When a subagent runs:
//! 1. A new isolated conversation thread (fresh `ThreadId` and fresh stores) is created
//! 2. The subagent runs until it completes, errors, reaches max turns, or times out
//! 3. Only the subagent's text response is returned to the parent as the tool output
//! 4. The parent model does not see the subagent's intermediate tool calls; they are
//!    surfaced only as `SubagentProgress` events and in the structured result data

35mod factory;
36
37pub use factory::SubagentFactory;
38
39use crate::events::AgentEvent;
40use crate::hooks::{AgentHooks, DefaultHooks};
41use crate::llm::LlmProvider;
42use crate::stores::{InMemoryStore, MessageStore, StateStore};
43use crate::tools::{DynamicToolName, Tool, ToolContext, ToolRegistry};
44use crate::types::{AgentConfig, AgentInput, ThreadId, TokenUsage, ToolResult, ToolTier};
45use anyhow::{Context, Result};
46use serde::{Deserialize, Serialize};
47use serde_json::{Value, json};
48use std::sync::Arc;
49use std::time::{Duration, Instant};
50use tokio::sync::mpsc;
51
52/// Configuration for a subagent.
53#[derive(Clone, Debug, Serialize, Deserialize)]
54pub struct SubagentConfig {
55    /// Name of the subagent (for identification).
56    pub name: String,
57    /// System prompt for the subagent.
58    pub system_prompt: String,
59    /// Maximum number of turns before stopping.
60    pub max_turns: usize,
61    /// Optional timeout in milliseconds.
62    pub timeout_ms: Option<u64>,
63}
64
65impl SubagentConfig {
66    /// Create a new subagent configuration.
67    #[must_use]
68    pub fn new(name: impl Into<String>) -> Self {
69        Self {
70            name: name.into(),
71            system_prompt: String::new(),
72            max_turns: 10,
73            timeout_ms: None,
74        }
75    }
76
77    /// Set the system prompt.
78    #[must_use]
79    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
80        self.system_prompt = prompt.into();
81        self
82    }
83
84    /// Set the maximum number of turns.
85    #[must_use]
86    pub const fn with_max_turns(mut self, max: usize) -> Self {
87        self.max_turns = max;
88        self
89    }
90
91    /// Set the timeout in milliseconds.
92    #[must_use]
93    pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
94        self.timeout_ms = Some(timeout);
95        self
96    }
97}
98
99/// Log entry for a single tool call within a subagent.
100#[derive(Clone, Debug, Serialize, Deserialize)]
101pub struct ToolCallLog {
102    /// Tool name.
103    pub name: String,
104    /// Tool display name.
105    pub display_name: String,
106    /// Brief context/args (e.g., file path, command).
107    pub context: String,
108    /// Brief result summary.
109    pub result: String,
110    /// Whether the tool call succeeded.
111    pub success: bool,
112    /// Duration in milliseconds.
113    pub duration_ms: Option<u64>,
114}
115
116/// Result from a subagent execution.
117#[derive(Clone, Debug, Serialize, Deserialize)]
118pub struct SubagentResult {
119    /// Name of the subagent.
120    pub name: String,
121    /// The final text response (only visible part to parent).
122    pub final_response: String,
123    /// Total number of turns taken.
124    pub total_turns: usize,
125    /// Number of tool calls made by the subagent.
126    pub tool_count: u32,
127    /// Log of tool calls made by the subagent.
128    pub tool_logs: Vec<ToolCallLog>,
129    /// Token usage statistics.
130    pub usage: TokenUsage,
131    /// Whether the subagent completed successfully.
132    pub success: bool,
133    /// Duration in milliseconds.
134    pub duration_ms: u64,
135}
136
137/// Tool for spawning subagents.
138///
139/// This tool allows an agent to spawn a child agent that runs independently
140/// and returns only its final response.
141///
142/// # Example
143///
144/// ```ignore
145/// use agent_sdk::subagent::{SubagentTool, SubagentConfig};
146///
147/// let config = SubagentConfig::new("analyzer")
148///     .with_system_prompt("You analyze code...");
149///
150/// let tool = SubagentTool::new(config, provider.clone(), tools.clone());
151/// ```
152pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
153where
154    P: LlmProvider,
155    H: AgentHooks,
156    M: MessageStore,
157    S: StateStore,
158{
159    config: SubagentConfig,
160    provider: Arc<P>,
161    tools: Arc<ToolRegistry<()>>,
162    hooks: Arc<H>,
163    message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
164    state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
165}
166
167impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
168where
169    P: LlmProvider + 'static,
170{
171    /// Create a new subagent tool with default hooks and in-memory stores.
172    #[must_use]
173    pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
174        Self {
175            config,
176            provider,
177            tools,
178            hooks: Arc::new(DefaultHooks),
179            message_store_factory: Arc::new(InMemoryStore::new),
180            state_store_factory: Arc::new(InMemoryStore::new),
181        }
182    }
183}
184
185impl<P, H, M, S> SubagentTool<P, H, M, S>
186where
187    P: LlmProvider + Clone + 'static,
188    H: AgentHooks + Clone + 'static,
189    M: MessageStore + 'static,
190    S: StateStore + 'static,
191{
192    /// Create with custom hooks.
193    #[must_use]
194    pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
195        self,
196        hooks: Arc<H2>,
197    ) -> SubagentTool<P, H2, M, S> {
198        SubagentTool {
199            config: self.config,
200            provider: self.provider,
201            tools: self.tools,
202            hooks,
203            message_store_factory: self.message_store_factory,
204            state_store_factory: self.state_store_factory,
205        }
206    }
207
208    /// Create with custom store factories.
209    #[must_use]
210    pub fn with_stores<M2, S2, MF, SF>(
211        self,
212        message_factory: MF,
213        state_factory: SF,
214    ) -> SubagentTool<P, H, M2, S2>
215    where
216        M2: MessageStore + 'static,
217        S2: StateStore + 'static,
218        MF: Fn() -> M2 + Send + Sync + 'static,
219        SF: Fn() -> S2 + Send + Sync + 'static,
220    {
221        SubagentTool {
222            config: self.config,
223            provider: self.provider,
224            tools: self.tools,
225            hooks: self.hooks,
226            message_store_factory: Arc::new(message_factory),
227            state_store_factory: Arc::new(state_factory),
228        }
229    }
230
231    /// Get the subagent configuration.
232    #[must_use]
233    pub const fn config(&self) -> &SubagentConfig {
234        &self.config
235    }
236
237    /// Run the subagent with a task.
238    ///
239    /// If `parent_tx` is provided, the subagent will emit `SubagentProgress` events
240    /// to the parent's event channel, allowing the UI to show live progress.
241    #[allow(clippy::too_many_lines)]
242    async fn run_subagent(
243        &self,
244        task: &str,
245        subagent_id: String,
246        parent_tx: Option<mpsc::Sender<AgentEvent>>,
247    ) -> Result<SubagentResult> {
248        use crate::agent_loop::AgentLoop;
249
250        let start = Instant::now();
251        let thread_id = ThreadId::new();
252
253        // Create stores for this subagent run
254        let message_store = (self.message_store_factory)();
255        let state_store = (self.state_store_factory)();
256
257        // Create agent config
258        let agent_config = AgentConfig {
259            max_turns: self.config.max_turns,
260            system_prompt: self.config.system_prompt.clone(),
261            ..Default::default()
262        };
263
264        // Build the subagent
265        let agent = AgentLoop::new(
266            (*self.provider).clone(),
267            (*self.tools).clone(),
268            (*self.hooks).clone(),
269            message_store,
270            state_store,
271            agent_config,
272        );
273
274        // Create tool context
275        let tool_ctx = ToolContext::new(());
276
277        // Run with optional timeout
278        let (mut rx, _final_state) =
279            agent.run(thread_id, AgentInput::Text(task.to_string()), tool_ctx);
280
281        let mut final_response = String::new();
282        let mut total_turns = 0;
283        let mut tool_count = 0u32;
284        let mut tool_logs: Vec<ToolCallLog> = Vec::new();
285        let mut pending_tools: std::collections::HashMap<String, (String, String)> =
286            std::collections::HashMap::new();
287        let mut total_usage = TokenUsage::default();
288        let mut success = true;
289
290        let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);
291
292        loop {
293            let recv_result = if let Some(timeout) = timeout_duration {
294                let remaining = timeout.saturating_sub(start.elapsed());
295                if remaining.is_zero() {
296                    final_response = "Subagent timed out".to_string();
297                    success = false;
298                    break;
299                }
300                tokio::time::timeout(remaining, rx.recv()).await
301            } else {
302                Ok(rx.recv().await)
303            };
304
305            match recv_result {
306                Ok(Some(event)) => match event {
307                    AgentEvent::Text {
308                        message_id: _,
309                        text,
310                    } => {
311                        final_response.push_str(&text);
312                    }
313                    AgentEvent::ToolCallStart {
314                        id, name, input, ..
315                    } => {
316                        // Track tool calls made by the subagent
317                        tool_count += 1;
318                        let context = extract_tool_context(&name, &input);
319                        pending_tools.insert(id, (name.clone(), context.clone()));
320
321                        // Emit progress event to parent
322                        if let Some(ref tx) = parent_tx {
323                            let _ = tx
324                                .send(AgentEvent::SubagentProgress {
325                                    subagent_id: subagent_id.clone(),
326                                    subagent_name: self.config.name.clone(),
327                                    tool_name: name,
328                                    tool_context: context,
329                                    completed: false,
330                                    success: false,
331                                    tool_count,
332                                    total_tokens: u64::from(total_usage.input_tokens)
333                                        + u64::from(total_usage.output_tokens),
334                                })
335                                .await;
336                        }
337                    }
338                    AgentEvent::ToolCallEnd {
339                        id,
340                        name,
341                        display_name,
342                        result,
343                    } => {
344                        // Create log entry when tool completes
345                        let context = pending_tools
346                            .remove(&id)
347                            .map(|(_, ctx)| ctx)
348                            .unwrap_or_default();
349                        let result_summary = summarize_tool_result(&name, &result);
350                        let tool_success = result.success;
351                        tool_logs.push(ToolCallLog {
352                            name: name.clone(),
353                            display_name: display_name.clone(),
354                            context: context.clone(),
355                            result: result_summary,
356                            success: tool_success,
357                            duration_ms: result.duration_ms,
358                        });
359
360                        // Emit progress event to parent
361                        if let Some(ref tx) = parent_tx {
362                            let _ = tx
363                                .send(AgentEvent::SubagentProgress {
364                                    subagent_id: subagent_id.clone(),
365                                    subagent_name: self.config.name.clone(),
366                                    tool_name: name,
367                                    tool_context: context,
368                                    completed: true,
369                                    success: tool_success,
370                                    tool_count,
371                                    total_tokens: u64::from(total_usage.input_tokens)
372                                        + u64::from(total_usage.output_tokens),
373                                })
374                                .await;
375                        }
376                    }
377                    AgentEvent::TurnComplete { turn, usage, .. } => {
378                        total_turns = turn;
379                        total_usage.add(&usage);
380                    }
381                    AgentEvent::Done {
382                        total_turns: turns, ..
383                    } => {
384                        total_turns = turns;
385                        break;
386                    }
387                    AgentEvent::Error { message, .. } => {
388                        final_response = message;
389                        success = false;
390                        break;
391                    }
392                    _ => {}
393                },
394                Ok(None) => break,
395                Err(_) => {
396                    final_response = "Subagent timed out".to_string();
397                    success = false;
398                    break;
399                }
400            }
401        }
402
403        Ok(SubagentResult {
404            name: self.config.name.clone(),
405            final_response,
406            total_turns,
407            tool_count,
408            tool_logs,
409            usage: total_usage,
410            success,
411            duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
412        })
413    }
414}
415
416/// Extracts context information from tool input for display.
417fn extract_tool_context(name: &str, input: &Value) -> String {
418    match name {
419        "read" => input
420            .get("file_path")
421            .or_else(|| input.get("path"))
422            .and_then(Value::as_str)
423            .unwrap_or("")
424            .to_string(),
425        "write" | "edit" => input
426            .get("file_path")
427            .or_else(|| input.get("path"))
428            .and_then(Value::as_str)
429            .unwrap_or("")
430            .to_string(),
431        "bash" => {
432            let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
433            // Truncate long commands
434            if cmd.len() > 60 {
435                format!("{}...", &cmd[..57])
436            } else {
437                cmd.to_string()
438            }
439        }
440        "glob" | "grep" => input
441            .get("pattern")
442            .and_then(Value::as_str)
443            .unwrap_or("")
444            .to_string(),
445        "web_search" => input
446            .get("query")
447            .and_then(Value::as_str)
448            .unwrap_or("")
449            .to_string(),
450        _ => String::new(),
451    }
452}
453
454/// Summarizes tool result for logging.
455fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
456    if !result.success {
457        let first_line = result.output.lines().next().unwrap_or("Error");
458        return if first_line.len() > 50 {
459            format!("{}...", &first_line[..47])
460        } else {
461            first_line.to_string()
462        };
463    }
464
465    match name {
466        "read" => {
467            let line_count = result.output.lines().count();
468            format!("{line_count} lines")
469        }
470        "write" => "wrote file".to_string(),
471        "edit" => "edited".to_string(),
472        "bash" => {
473            let lines: Vec<&str> = result.output.lines().collect();
474            if lines.is_empty() {
475                "done".to_string()
476            } else if lines.len() == 1 {
477                let line = lines[0];
478                if line.len() > 50 {
479                    format!("{}...", &line[..47])
480                } else {
481                    line.to_string()
482                }
483            } else {
484                format!("{} lines", lines.len())
485            }
486        }
487        "glob" => {
488            let count = result.output.lines().count();
489            format!("{count} files")
490        }
491        "grep" => {
492            let count = result.output.lines().count();
493            format!("{count} matches")
494        }
495        _ => {
496            let line_count = result.output.lines().count();
497            if line_count == 0 {
498                "done".to_string()
499            } else {
500                format!("{line_count} lines")
501            }
502        }
503    }
504}
505
506impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
507where
508    P: LlmProvider + Clone + 'static,
509    H: AgentHooks + Clone + 'static,
510    M: MessageStore + 'static,
511    S: StateStore + 'static,
512{
513    type Name = DynamicToolName;
514
515    fn name(&self) -> DynamicToolName {
516        DynamicToolName::new(format!("subagent_{}", self.config.name))
517    }
518
519    fn display_name(&self) -> &'static str {
520        // Leak the name to get 'static lifetime (acceptable for long-lived tools)
521        Box::leak(format!("Subagent: {}", self.config.name).into_boxed_str())
522    }
523
524    fn description(&self) -> &'static str {
525        Box::leak(
526            format!(
527                "Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
528                self.config.name
529            )
530            .into_boxed_str(),
531        )
532    }
533
534    fn input_schema(&self) -> Value {
535        json!({
536            "type": "object",
537            "properties": {
538                "task": {
539                    "type": "string",
540                    "description": "The task or question for the subagent to handle"
541                }
542            },
543            "required": ["task"]
544        })
545    }
546
547    fn tier(&self) -> ToolTier {
548        // Subagent spawning requires confirmation
549        ToolTier::Confirm
550    }
551
552    async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
553        let task = input
554            .get("task")
555            .and_then(Value::as_str)
556            .context("Missing 'task' parameter")?;
557
558        // Get event channel from context for progress updates
559        let parent_tx = ctx.event_tx();
560
561        // Generate a unique ID for this subagent execution
562        let subagent_id = format!(
563            "{}_{:x}",
564            self.config.name,
565            std::time::SystemTime::now()
566                .duration_since(std::time::UNIX_EPOCH)
567                .unwrap_or_default()
568                .as_nanos()
569        );
570
571        let result = self.run_subagent(task, subagent_id, parent_tx).await?;
572
573        Ok(ToolResult {
574            success: result.success,
575            output: result.final_response.clone(),
576            data: Some(serde_json::to_value(&result).unwrap_or_default()),
577            duration_ms: Some(result.duration_ms),
578        })
579    }
580}
581
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_subagent_config_builder() {
        let config = SubagentConfig::new("test")
            .with_system_prompt("Test prompt")
            .with_max_turns(5)
            .with_timeout_ms(30000);

        assert_eq!(config.name, "test");
        assert_eq!(config.system_prompt, "Test prompt");
        assert_eq!(config.max_turns, 5);
        assert_eq!(config.timeout_ms, Some(30000));
    }

    #[test]
    fn test_subagent_config_defaults() {
        let config = SubagentConfig::new("default");

        assert_eq!(config.name, "default");
        assert!(config.system_prompt.is_empty());
        assert_eq!(config.max_turns, 10);
        assert_eq!(config.timeout_ms, None);
    }

    #[test]
    fn test_subagent_result_serialization() {
        let result = SubagentResult {
            name: "test".to_string(),
            final_response: "Done".to_string(),
            total_turns: 3,
            tool_count: 5,
            tool_logs: vec![
                ToolCallLog {
                    name: "read".to_string(),
                    display_name: "Read file".to_string(),
                    context: "/tmp/test.rs".to_string(),
                    result: "50 lines".to_string(),
                    success: true,
                    duration_ms: Some(10),
                },
                ToolCallLog {
                    name: "grep".to_string(),
                    display_name: "Grep TODO".to_string(),
                    context: "TODO".to_string(),
                    result: "3 matches".to_string(),
                    success: true,
                    duration_ms: Some(5),
                },
            ],
            usage: TokenUsage::default(),
            success: true,
            duration_ms: 1000,
        };

        let json = serde_json::to_string(&result).expect("serialize");
        assert!(json.contains("test"));
        assert!(json.contains("Done"));
        assert!(json.contains("tool_count"));
        assert!(json.contains("tool_logs"));
        assert!(json.contains("/tmp/test.rs"));

        // The derived Deserialize must accept what Serialize produced.
        let back: SubagentResult = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(back.name, "test");
        assert_eq!(back.final_response, "Done");
        assert_eq!(back.tool_count, 5);
        assert_eq!(back.tool_logs.len(), 2);
        assert_eq!(back.tool_logs[0].context, "/tmp/test.rs");
        assert_eq!(back.tool_logs[1].duration_ms, Some(5));
    }

    #[test]
    fn test_subagent_result_field_extraction() {
        // Test that verifies the exact JSON structure expected by bip's tui_session.rs
        let result = SubagentResult {
            name: "explore".to_string(),
            final_response: "Found 3 config files".to_string(),
            total_turns: 2,
            tool_count: 5,
            tool_logs: vec![ToolCallLog {
                name: "glob".to_string(),
                display_name: "Glob config files".to_string(),
                context: "**/*.toml".to_string(),
                result: "3 files".to_string(),
                success: true,
                duration_ms: Some(15),
            }],
            usage: TokenUsage {
                input_tokens: 1500,
                output_tokens: 500,
            },
            success: true,
            duration_ms: 2500,
        };

        let value = serde_json::to_value(&result).expect("serialize to value");

        // Test tool_count extraction (as_u64 should work for u32)
        let tool_count = value.get("tool_count").and_then(Value::as_u64);
        assert_eq!(tool_count, Some(5));

        // Test usage extraction
        let usage = value.get("usage").expect("usage field");
        let input_tokens = usage.get("input_tokens").and_then(Value::as_u64);
        let output_tokens = usage.get("output_tokens").and_then(Value::as_u64);
        assert_eq!(input_tokens, Some(1500));
        assert_eq!(output_tokens, Some(500));

        // Test tool_logs extraction
        let logs = value
            .get("tool_logs")
            .and_then(Value::as_array)
            .expect("tool_logs should serialize as an array");
        assert_eq!(logs.len(), 1);

        let first_log = &logs[0];
        assert_eq!(first_log.get("name").and_then(Value::as_str), Some("glob"));
        assert_eq!(
            first_log.get("context").and_then(Value::as_str),
            Some("**/*.toml")
        );
        assert_eq!(
            first_log.get("result").and_then(Value::as_str),
            Some("3 files")
        );
        assert_eq!(
            first_log.get("success").and_then(Value::as_bool),
            Some(true)
        );
    }

    #[test]
    fn test_extract_tool_context_known_tools() {
        assert_eq!(
            extract_tool_context("read", &json!({"file_path": "/a.rs"})),
            "/a.rs"
        );
        assert_eq!(extract_tool_context("edit", &json!({"path": "/b.rs"})), "/b.rs");
        assert_eq!(extract_tool_context("grep", &json!({"pattern": "TODO"})), "TODO");
        assert_eq!(
            extract_tool_context("web_search", &json!({"query": "rust"})),
            "rust"
        );
        assert_eq!(extract_tool_context("unknown", &json!({"x": 1})), "");

        // Long ASCII commands are cut to 57 characters plus "...".
        let long_cmd = "a".repeat(80);
        let shown = extract_tool_context("bash", &json!({ "command": long_cmd }));
        assert_eq!(shown.len(), 60);
        assert!(shown.ends_with("..."));
    }

    #[test]
    fn test_summarize_tool_result() {
        let ok = |output: &str| ToolResult {
            success: true,
            output: output.to_string(),
            data: None,
            duration_ms: None,
        };

        assert_eq!(summarize_tool_result("read", &ok("a\nb\nc")), "3 lines");
        assert_eq!(summarize_tool_result("write", &ok("")), "wrote file");
        assert_eq!(summarize_tool_result("bash", &ok("")), "done");
        assert_eq!(summarize_tool_result("bash", &ok("hello")), "hello");
        assert_eq!(summarize_tool_result("bash", &ok("x\ny")), "2 lines");
        assert_eq!(summarize_tool_result("grep", &ok("x\ny")), "2 matches");
        assert_eq!(summarize_tool_result("other", &ok("")), "done");

        let failed = ToolResult {
            success: false,
            output: "boom\ntrace".to_string(),
            data: None,
            duration_ms: None,
        };
        assert_eq!(summarize_tool_result("bash", &failed), "boom");
    }
}