Skip to main content

agent_sdk/
subagent.rs

//! Subagent support for spawning child agents.
//!
//! This module provides the ability to spawn subagents from within an agent.
//! Subagents are isolated agent instances that run to completion and hand back only
//! their final text response to the parent agent's conversation.
//!
7//! # Overview
8//!
9//! Subagents are useful for:
10//! - Delegating complex subtasks to specialized agents
11//! - Running parallel investigations
12//! - Isolating context for specific operations
13//!
14//! # Example
15//!
16//! ```ignore
17//! use agent_sdk::subagent::{SubagentTool, SubagentConfig};
18//!
19//! let config = SubagentConfig::new("researcher")
20//!     .with_system_prompt("You are a research specialist...")
21//!     .with_max_turns(10);
22//!
23//! let tool = SubagentTool::new(config, provider, tools);
24//! registry.register(tool);
25//! ```
26//!
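//! # Behavior
//!
//! When a subagent runs:
//! 1. A new isolated conversation thread is created, with a fresh `ThreadId` and freshly
//!    created message and state stores
//! 2. The subagent runs until it finishes, reports an error, reaches `max_turns`, or
//!    exceeds the optional `timeout_ms`
//! 3. The subagent's text output becomes the tool output seen by the parent model; a
//!    structured [`SubagentResult`] (turns, token usage, per-tool logs) is attached as the
//!    tool result's `data`
//! 4. The subagent's intermediate tool calls are not added to the parent's conversation,
//!    although `SubagentProgress` events summarizing them are forwarded to the parent's
//!    event channel when one is available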
34
35mod factory;
36
37pub use factory::SubagentFactory;
38
39use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
40use crate::hooks::{AgentHooks, DefaultHooks};
41use crate::llm::LlmProvider;
42use crate::stores::{InMemoryStore, MessageStore, StateStore};
43use crate::tools::{DynamicToolName, Tool, ToolContext, ToolRegistry};
44use crate::types::{AgentConfig, AgentInput, ThreadId, TokenUsage, ToolResult, ToolTier};
45use anyhow::{Context, Result};
46use serde::{Deserialize, Serialize};
47use serde_json::{Value, json};
48use std::sync::Arc;
49use std::time::{Duration, Instant};
50use tokio::sync::mpsc;
51use tokio_util::sync::CancellationToken;
52
53/// Configuration for a subagent.
54#[derive(Clone, Debug, Serialize, Deserialize)]
55pub struct SubagentConfig {
56    /// Name of the subagent (for identification).
57    pub name: String,
58    /// System prompt for the subagent.
59    pub system_prompt: String,
60    /// Maximum number of turns before stopping.
61    pub max_turns: Option<usize>,
62    /// Optional timeout in milliseconds.
63    pub timeout_ms: Option<u64>,
64}
65
66impl SubagentConfig {
67    /// Create a new subagent configuration.
68    #[must_use]
69    pub fn new(name: impl Into<String>) -> Self {
70        Self {
71            name: name.into(),
72            system_prompt: String::new(),
73            max_turns: None,
74            timeout_ms: None,
75        }
76    }
77
78    /// Set the system prompt.
79    #[must_use]
80    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
81        self.system_prompt = prompt.into();
82        self
83    }
84
85    /// Set the maximum number of turns.
86    #[must_use]
87    pub const fn with_max_turns(mut self, max: usize) -> Self {
88        self.max_turns = Some(max);
89        self
90    }
91
92    /// Set the timeout in milliseconds.
93    #[must_use]
94    pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
95        self.timeout_ms = Some(timeout);
96        self
97    }
98}
99
100/// Log entry for a single tool call within a subagent.
101#[derive(Clone, Debug, Serialize, Deserialize)]
102pub struct ToolCallLog {
103    /// Tool name.
104    pub name: String,
105    /// Tool display name.
106    pub display_name: String,
107    /// Brief context/args (e.g., file path, command).
108    pub context: String,
109    /// Brief result summary.
110    pub result: String,
111    /// Whether the tool call succeeded.
112    pub success: bool,
113    /// Duration in milliseconds.
114    pub duration_ms: Option<u64>,
115}
116
117/// Result from a subagent execution.
118#[derive(Clone, Debug, Serialize, Deserialize)]
119pub struct SubagentResult {
120    /// Name of the subagent.
121    pub name: String,
122    /// The final text response (only visible part to parent).
123    pub final_response: String,
124    /// Total number of turns taken.
125    pub total_turns: usize,
126    /// Number of tool calls made by the subagent.
127    pub tool_count: u32,
128    /// Log of tool calls made by the subagent.
129    pub tool_logs: Vec<ToolCallLog>,
130    /// Token usage statistics.
131    pub usage: TokenUsage,
132    /// Whether the subagent completed successfully.
133    pub success: bool,
134    /// Duration in milliseconds.
135    pub duration_ms: u64,
136}
137
138/// Tool for spawning subagents.
139///
140/// This tool allows an agent to spawn a child agent that runs independently
141/// and returns only its final response.
142///
143/// # Example
144///
145/// ```ignore
146/// use agent_sdk::subagent::{SubagentTool, SubagentConfig};
147///
148/// let config = SubagentConfig::new("analyzer")
149///     .with_system_prompt("You analyze code...");
150///
151/// let tool = SubagentTool::new(config, provider.clone(), tools.clone());
152/// ```
153pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
154where
155    P: LlmProvider,
156    H: AgentHooks,
157    M: MessageStore,
158    S: StateStore,
159{
160    config: SubagentConfig,
161    provider: Arc<P>,
162    tools: Arc<ToolRegistry<()>>,
163    hooks: Arc<H>,
164    message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
165    state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
166}
167
168impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
169where
170    P: LlmProvider + 'static,
171{
172    /// Create a new subagent tool with default hooks and in-memory stores.
173    #[must_use]
174    pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
175        Self {
176            config,
177            provider,
178            tools,
179            hooks: Arc::new(DefaultHooks),
180            message_store_factory: Arc::new(InMemoryStore::new),
181            state_store_factory: Arc::new(InMemoryStore::new),
182        }
183    }
184}
185
186impl<P, H, M, S> SubagentTool<P, H, M, S>
187where
188    P: LlmProvider + Clone + 'static,
189    H: AgentHooks + Clone + 'static,
190    M: MessageStore + 'static,
191    S: StateStore + 'static,
192{
193    /// Create with custom hooks.
194    #[must_use]
195    pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
196        self,
197        hooks: Arc<H2>,
198    ) -> SubagentTool<P, H2, M, S> {
199        SubagentTool {
200            config: self.config,
201            provider: self.provider,
202            tools: self.tools,
203            hooks,
204            message_store_factory: self.message_store_factory,
205            state_store_factory: self.state_store_factory,
206        }
207    }
208
209    /// Create with custom store factories.
210    #[must_use]
211    pub fn with_stores<M2, S2, MF, SF>(
212        self,
213        message_factory: MF,
214        state_factory: SF,
215    ) -> SubagentTool<P, H, M2, S2>
216    where
217        M2: MessageStore + 'static,
218        S2: StateStore + 'static,
219        MF: Fn() -> M2 + Send + Sync + 'static,
220        SF: Fn() -> S2 + Send + Sync + 'static,
221    {
222        SubagentTool {
223            config: self.config,
224            provider: self.provider,
225            tools: self.tools,
226            hooks: self.hooks,
227            message_store_factory: Arc::new(message_factory),
228            state_store_factory: Arc::new(state_factory),
229        }
230    }
231
232    /// Get the subagent configuration.
233    #[must_use]
234    pub const fn config(&self) -> &SubagentConfig {
235        &self.config
236    }
237
238    /// Run the subagent with a task.
239    ///
240    /// If `parent_tx` is provided, the subagent will emit `SubagentProgress` events
241    /// to the parent's event channel, allowing the UI to show live progress.
242    #[allow(clippy::too_many_lines)]
243    async fn run_subagent(
244        &self,
245        task: &str,
246        subagent_id: String,
247        parent_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
248        parent_seq: Option<SequenceCounter>,
249    ) -> Result<SubagentResult> {
250        use crate::agent_loop::AgentLoop;
251
252        let start = Instant::now();
253        let thread_id = ThreadId::new();
254
255        // Create stores for this subagent run
256        let message_store = (self.message_store_factory)();
257        let state_store = (self.state_store_factory)();
258
259        // Create agent config
260        let agent_config = AgentConfig {
261            max_turns: self.config.max_turns,
262            system_prompt: self.config.system_prompt.clone(),
263            ..Default::default()
264        };
265
266        // Build the subagent
267        let agent = AgentLoop::new(
268            (*self.provider).clone(),
269            (*self.tools).clone(),
270            (*self.hooks).clone(),
271            message_store,
272            state_store,
273            agent_config,
274        );
275
276        // Create tool context
277        let tool_ctx = ToolContext::new(());
278
279        // Run with optional timeout
280        let (mut rx, _final_state) = agent.run(
281            thread_id,
282            AgentInput::Text(task.to_string()),
283            tool_ctx,
284            CancellationToken::new(),
285        );
286
287        let mut final_response = String::new();
288        let mut total_turns = 0;
289        let mut tool_count = 0u32;
290        let mut tool_logs: Vec<ToolCallLog> = Vec::new();
291        let mut pending_tools: std::collections::HashMap<String, (String, String)> =
292            std::collections::HashMap::new();
293        let mut total_usage = TokenUsage::default();
294        let mut success = true;
295
296        let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);
297
298        loop {
299            let recv_result = if let Some(timeout) = timeout_duration {
300                let remaining = timeout.saturating_sub(start.elapsed());
301                if remaining.is_zero() {
302                    final_response = "Subagent timed out".to_string();
303                    success = false;
304                    break;
305                }
306                tokio::time::timeout(remaining, rx.recv()).await
307            } else {
308                Ok(rx.recv().await)
309            };
310
311            match recv_result {
312                Ok(Some(envelope)) => match envelope.event {
313                    AgentEvent::Text {
314                        message_id: _,
315                        text,
316                    } => {
317                        final_response.push_str(&text);
318                    }
319                    AgentEvent::ToolCallStart {
320                        id, name, input, ..
321                    } => {
322                        // Track tool calls made by the subagent
323                        tool_count += 1;
324                        let context = extract_tool_context(&name, &input);
325                        pending_tools.insert(id, (name.clone(), context.clone()));
326
327                        // Emit progress event to parent
328                        if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
329                            let event = AgentEvent::SubagentProgress {
330                                subagent_id: subagent_id.clone(),
331                                subagent_name: self.config.name.clone(),
332                                tool_name: name,
333                                tool_context: context,
334                                completed: false,
335                                success: false,
336                                tool_count,
337                                total_tokens: u64::from(total_usage.input_tokens)
338                                    + u64::from(total_usage.output_tokens),
339                            };
340                            let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
341                        }
342                    }
343                    AgentEvent::ToolCallEnd {
344                        id,
345                        name,
346                        display_name,
347                        result,
348                    } => {
349                        // Create log entry when tool completes
350                        let context = pending_tools
351                            .remove(&id)
352                            .map(|(_, ctx)| ctx)
353                            .unwrap_or_default();
354                        let result_summary = summarize_tool_result(&name, &result);
355                        let tool_success = result.success;
356                        tool_logs.push(ToolCallLog {
357                            name: name.clone(),
358                            display_name: display_name.clone(),
359                            context: context.clone(),
360                            result: result_summary,
361                            success: tool_success,
362                            duration_ms: result.duration_ms,
363                        });
364
365                        // Emit progress event to parent
366                        if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
367                            let event = AgentEvent::SubagentProgress {
368                                subagent_id: subagent_id.clone(),
369                                subagent_name: self.config.name.clone(),
370                                tool_name: name,
371                                tool_context: context,
372                                completed: true,
373                                success: tool_success,
374                                tool_count,
375                                total_tokens: u64::from(total_usage.input_tokens)
376                                    + u64::from(total_usage.output_tokens),
377                            };
378                            let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
379                        }
380                    }
381                    AgentEvent::TurnComplete { turn, usage, .. } => {
382                        total_turns = turn;
383                        total_usage.add(&usage);
384                    }
385                    AgentEvent::Done {
386                        total_turns: turns, ..
387                    } => {
388                        total_turns = turns;
389                        break;
390                    }
391                    AgentEvent::Error { message, .. } => {
392                        final_response = message;
393                        success = false;
394                        break;
395                    }
396                    _ => {}
397                },
398                Ok(None) => break,
399                Err(_) => {
400                    final_response = "Subagent timed out".to_string();
401                    success = false;
402                    break;
403                }
404            }
405        }
406
407        Ok(SubagentResult {
408            name: self.config.name.clone(),
409            final_response,
410            total_turns,
411            tool_count,
412            tool_logs,
413            usage: total_usage,
414            success,
415            duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
416        })
417    }
418}
419
420/// Extracts context information from tool input for display.
421fn extract_tool_context(name: &str, input: &Value) -> String {
422    match name {
423        "read" => input
424            .get("file_path")
425            .or_else(|| input.get("path"))
426            .and_then(Value::as_str)
427            .unwrap_or("")
428            .to_string(),
429        "write" | "edit" => input
430            .get("file_path")
431            .or_else(|| input.get("path"))
432            .and_then(Value::as_str)
433            .unwrap_or("")
434            .to_string(),
435        "bash" => {
436            let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
437            // Truncate long commands
438            if cmd.len() > 60 {
439                format!("{}...", &cmd[..57])
440            } else {
441                cmd.to_string()
442            }
443        }
444        "glob" | "grep" => input
445            .get("pattern")
446            .and_then(Value::as_str)
447            .unwrap_or("")
448            .to_string(),
449        "web_search" => input
450            .get("query")
451            .and_then(Value::as_str)
452            .unwrap_or("")
453            .to_string(),
454        _ => String::new(),
455    }
456}
457
458/// Summarizes tool result for logging.
459fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
460    if !result.success {
461        let first_line = result.output.lines().next().unwrap_or("Error");
462        return if first_line.len() > 50 {
463            format!("{}...", &first_line[..47])
464        } else {
465            first_line.to_string()
466        };
467    }
468
469    match name {
470        "read" => {
471            let line_count = result.output.lines().count();
472            format!("{line_count} lines")
473        }
474        "write" => "wrote file".to_string(),
475        "edit" => "edited".to_string(),
476        "bash" => {
477            let lines: Vec<&str> = result.output.lines().collect();
478            if lines.is_empty() {
479                "done".to_string()
480            } else if lines.len() == 1 {
481                let line = lines[0];
482                if line.len() > 50 {
483                    format!("{}...", &line[..47])
484                } else {
485                    line.to_string()
486                }
487            } else {
488                format!("{} lines", lines.len())
489            }
490        }
491        "glob" => {
492            let count = result.output.lines().count();
493            format!("{count} files")
494        }
495        "grep" => {
496            let count = result.output.lines().count();
497            format!("{count} matches")
498        }
499        _ => {
500            let line_count = result.output.lines().count();
501            if line_count == 0 {
502                "done".to_string()
503            } else {
504                format!("{line_count} lines")
505            }
506        }
507    }
508}
509
510impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
511where
512    P: LlmProvider + Clone + 'static,
513    H: AgentHooks + Clone + 'static,
514    M: MessageStore + 'static,
515    S: StateStore + 'static,
516{
517    type Name = DynamicToolName;
518
519    fn name(&self) -> DynamicToolName {
520        DynamicToolName::new(format!("subagent_{}", self.config.name))
521    }
522
523    fn display_name(&self) -> &'static str {
524        // Leak the name to get 'static lifetime (acceptable for long-lived tools)
525        Box::leak(format!("Subagent: {}", self.config.name).into_boxed_str())
526    }
527
528    fn description(&self) -> &'static str {
529        Box::leak(
530            format!(
531                "Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
532                self.config.name
533            )
534            .into_boxed_str(),
535        )
536    }
537
538    fn input_schema(&self) -> Value {
539        json!({
540            "type": "object",
541            "properties": {
542                "task": {
543                    "type": "string",
544                    "description": "The task or question for the subagent to handle"
545                }
546            },
547            "required": ["task"]
548        })
549    }
550
551    fn tier(&self) -> ToolTier {
552        // Subagent spawning requires confirmation
553        ToolTier::Confirm
554    }
555
556    async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
557        let task = input
558            .get("task")
559            .and_then(Value::as_str)
560            .context("Missing 'task' parameter")?;
561
562        // Get event channel and sequence counter from context for progress updates
563        let parent_tx = ctx.event_tx();
564        let parent_seq = ctx.event_seq();
565
566        // Generate a unique ID for this subagent execution
567        let subagent_id = format!(
568            "{}_{:x}",
569            self.config.name,
570            std::time::SystemTime::now()
571                .duration_since(std::time::UNIX_EPOCH)
572                .unwrap_or_default()
573                .as_nanos()
574        );
575
576        let result = self
577            .run_subagent(task, subagent_id, parent_tx, parent_seq)
578            .await?;
579
580        Ok(ToolResult {
581            success: result.success,
582            output: result.final_response.clone(),
583            data: Some(serde_json::to_value(&result).unwrap_or_default()),
584            documents: Vec::new(),
585            duration_ms: Some(result.duration_ms),
586        })
587    }
588}
589
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a minimal `ToolResult` with the given outcome and output text.
    fn tool_result(success: bool, output: &str) -> ToolResult {
        ToolResult {
            success,
            output: output.to_string(),
            data: None,
            documents: Vec::new(),
            duration_ms: None,
        }
    }

    #[test]
    fn test_subagent_config_builder() {
        let config = SubagentConfig::new("test")
            .with_system_prompt("Test prompt")
            .with_max_turns(5)
            .with_timeout_ms(30000);

        assert_eq!(config.name, "test");
        assert_eq!(config.system_prompt, "Test prompt");
        assert_eq!(config.max_turns, Some(5));
        assert_eq!(config.timeout_ms, Some(30000));
    }

    #[test]
    fn test_subagent_config_defaults() {
        let config = SubagentConfig::new("default");

        assert_eq!(config.name, "default");
        assert!(config.system_prompt.is_empty());
        assert_eq!(config.max_turns, None);
        assert_eq!(config.timeout_ms, None);
    }

    #[test]
    fn test_extract_tool_context_known_tools() {
        assert_eq!(
            extract_tool_context("read", &json!({"file_path": "/a.rs"})),
            "/a.rs"
        );
        assert_eq!(extract_tool_context("edit", &json!({"path": "/b.rs"})), "/b.rs");
        assert_eq!(extract_tool_context("grep", &json!({"pattern": "TODO"})), "TODO");
        assert_eq!(
            extract_tool_context("web_search", &json!({"query": "rust"})),
            "rust"
        );
        assert_eq!(extract_tool_context("read", &json!({})), "");
        assert_eq!(extract_tool_context("unknown", &json!({"path": "/c"})), "");
    }

    #[test]
    fn test_extract_tool_context_truncates_long_bash_commands() {
        assert_eq!(
            extract_tool_context("bash", &json!({"command": "ls -la"})),
            "ls -la"
        );

        let exactly_sixty = "y".repeat(60);
        assert_eq!(
            extract_tool_context("bash", &json!({"command": exactly_sixty.clone()})),
            exactly_sixty
        );

        let long = "x".repeat(70);
        assert_eq!(
            extract_tool_context("bash", &json!({"command": long})),
            format!("{}...", "x".repeat(57))
        );
    }

    #[test]
    fn test_summarize_tool_result() {
        assert_eq!(summarize_tool_result("read", &tool_result(true, "a\nb\nc")), "3 lines");
        assert_eq!(summarize_tool_result("write", &tool_result(true, "")), "wrote file");
        assert_eq!(summarize_tool_result("edit", &tool_result(true, "")), "edited");
        assert_eq!(summarize_tool_result("bash", &tool_result(true, "")), "done");
        assert_eq!(summarize_tool_result("bash", &tool_result(true, "ok")), "ok");
        assert_eq!(summarize_tool_result("bash", &tool_result(true, "a\nb")), "2 lines");
        assert_eq!(summarize_tool_result("glob", &tool_result(true, "a\nb")), "2 files");
        assert_eq!(summarize_tool_result("grep", &tool_result(true, "a")), "1 matches");
        assert_eq!(summarize_tool_result("other", &tool_result(true, "")), "done");
        assert_eq!(summarize_tool_result("other", &tool_result(true, "x\ny")), "2 lines");
    }

    #[test]
    fn test_summarize_failed_tool_result() {
        assert_eq!(
            summarize_tool_result("read", &tool_result(false, "boom\ntrace")),
            "boom"
        );
        assert_eq!(summarize_tool_result("read", &tool_result(false, "")), "Error");

        let long_error = "e".repeat(60);
        assert_eq!(
            summarize_tool_result("bash", &tool_result(false, &long_error)),
            format!("{}...", "e".repeat(47))
        );
    }

    #[test]
    fn test_subagent_result_serialization() {
        let result = SubagentResult {
            name: "test".to_string(),
            final_response: "Done".to_string(),
            total_turns: 3,
            tool_count: 5,
            tool_logs: vec![
                ToolCallLog {
                    name: "read".to_string(),
                    display_name: "Read file".to_string(),
                    context: "/tmp/test.rs".to_string(),
                    result: "50 lines".to_string(),
                    success: true,
                    duration_ms: Some(10),
                },
                ToolCallLog {
                    name: "grep".to_string(),
                    display_name: "Grep TODO".to_string(),
                    context: "TODO".to_string(),
                    result: "3 matches".to_string(),
                    success: true,
                    duration_ms: Some(5),
                },
            ],
            usage: TokenUsage::default(),
            success: true,
            duration_ms: 1000,
        };

        let json = serde_json::to_string(&result).expect("serialize");
        assert!(json.contains("test"));
        assert!(json.contains("Done"));
        assert!(json.contains("tool_count"));
        assert!(json.contains("tool_logs"));
        assert!(json.contains("/tmp/test.rs"));
    }

    #[test]
    fn test_subagent_result_field_extraction() {
        // Verifies the exact JSON structure expected by bip's tui_session.rs
        let result = SubagentResult {
            name: "explore".to_string(),
            final_response: "Found 3 config files".to_string(),
            total_turns: 2,
            tool_count: 5,
            tool_logs: vec![ToolCallLog {
                name: "glob".to_string(),
                display_name: "Glob config files".to_string(),
                context: "**/*.toml".to_string(),
                result: "3 files".to_string(),
                success: true,
                duration_ms: Some(15),
            }],
            usage: TokenUsage {
                input_tokens: 1500,
                output_tokens: 500,
            },
            success: true,
            duration_ms: 2500,
        };

        let value = serde_json::to_value(&result).expect("serialize to value");

        // tool_count extraction (as_u64 should work for u32)
        let tool_count = value.get("tool_count").and_then(Value::as_u64);
        assert_eq!(tool_count, Some(5));

        // usage extraction
        let usage = value.get("usage").expect("usage field");
        let input_tokens = usage.get("input_tokens").and_then(Value::as_u64);
        let output_tokens = usage.get("output_tokens").and_then(Value::as_u64);
        assert_eq!(input_tokens, Some(1500));
        assert_eq!(output_tokens, Some(500));

        // tool_logs extraction
        let logs = value
            .get("tool_logs")
            .and_then(Value::as_array)
            .expect("tool_logs should serialize as an array");
        assert_eq!(logs.len(), 1);

        let first_log = &logs[0];
        assert_eq!(first_log.get("name").and_then(Value::as_str), Some("glob"));
        assert_eq!(
            first_log.get("context").and_then(Value::as_str),
            Some("**/*.toml")
        );
        assert_eq!(
            first_log.get("result").and_then(Value::as_str),
            Some("3 files")
        );
        assert_eq!(
            first_log.get("success").and_then(Value::as_bool),
            Some(true)
        );
    }
}