Skip to main content

agent_sdk/
subagent.rs

1//! Subagent support for spawning child agents.
2//!
3//! This module provides the ability to spawn subagents from within an agent.
4//! Subagents are isolated agent instances that run to completion and return
5//! only their final response to the parent agent.
6//!
7//! # Overview
8//!
9//! Subagents are useful for:
10//! - Delegating complex subtasks to specialized agents
11//! - Running parallel investigations
12//! - Isolating context for specific operations
13//!
14//! # Example
15//!
16//! ```ignore
17//! use agent_sdk::subagent::{SubagentTool, SubagentConfig};
18//!
19//! let config = SubagentConfig::new("researcher")
20//!     .with_system_prompt("You are a research specialist...")
21//!     .with_max_turns(10);
22//!
23//! let tool = SubagentTool::new(config, provider, tools);
24//! registry.register(tool);
25//! ```
26//!
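//! # Behavior
//!
//! When a subagent runs:
//! 1. A new isolated thread is created (a fresh `ThreadId` with its own stores)
//! 2. The subagent runs until completion, max turns, or the configured timeout
//! 3. Only the final text response is returned to the parent as the tool output
//! 4. The parent's conversation does not include the subagent's intermediate
//!    tool calls; they surface only as progress events and in the result's `tool_logs`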
34
35mod factory;
36
37pub use factory::SubagentFactory;
38
39use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
40use crate::hooks::{AgentHooks, DefaultHooks};
41use crate::llm::LlmProvider;
42use crate::stores::{InMemoryStore, MessageStore, StateStore};
43use crate::tools::{DynamicToolName, Tool, ToolContext, ToolRegistry};
44use crate::types::{AgentConfig, AgentInput, ThreadId, TokenUsage, ToolResult, ToolTier};
45use anyhow::{Context, Result};
46use serde::{Deserialize, Serialize};
47use serde_json::{Value, json};
48use std::sync::Arc;
49use std::time::{Duration, Instant};
50use tokio::sync::mpsc;
51use tokio_util::sync::CancellationToken;
52
53/// Configuration for a subagent.
54#[derive(Clone, Debug, Serialize, Deserialize)]
55pub struct SubagentConfig {
56    /// Name of the subagent (for identification).
57    pub name: String,
58    /// System prompt for the subagent.
59    pub system_prompt: String,
60    /// Maximum number of turns before stopping.
61    pub max_turns: Option<usize>,
62    /// Optional timeout in milliseconds.
63    pub timeout_ms: Option<u64>,
64}
65
66impl SubagentConfig {
67    /// Create a new subagent configuration.
68    #[must_use]
69    pub fn new(name: impl Into<String>) -> Self {
70        Self {
71            name: name.into(),
72            system_prompt: String::new(),
73            max_turns: None,
74            timeout_ms: None,
75        }
76    }
77
78    /// Set the system prompt.
79    #[must_use]
80    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
81        self.system_prompt = prompt.into();
82        self
83    }
84
85    /// Set the maximum number of turns.
86    #[must_use]
87    pub const fn with_max_turns(mut self, max: usize) -> Self {
88        self.max_turns = Some(max);
89        self
90    }
91
92    /// Set the timeout in milliseconds.
93    #[must_use]
94    pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
95        self.timeout_ms = Some(timeout);
96        self
97    }
98}
99
100/// Log entry for a single tool call within a subagent.
101#[derive(Clone, Debug, Serialize, Deserialize)]
102pub struct ToolCallLog {
103    /// Tool name.
104    pub name: String,
105    /// Tool display name.
106    pub display_name: String,
107    /// Brief context/args (e.g., file path, command).
108    pub context: String,
109    /// Brief result summary.
110    pub result: String,
111    /// Whether the tool call succeeded.
112    pub success: bool,
113    /// Duration in milliseconds.
114    pub duration_ms: Option<u64>,
115}
116
117/// Result from a subagent execution.
118#[derive(Clone, Debug, Serialize, Deserialize)]
119pub struct SubagentResult {
120    /// Name of the subagent.
121    pub name: String,
122    /// The final text response (only visible part to parent).
123    pub final_response: String,
124    /// Total number of turns taken.
125    pub total_turns: usize,
126    /// Number of tool calls made by the subagent.
127    pub tool_count: u32,
128    /// Log of tool calls made by the subagent.
129    pub tool_logs: Vec<ToolCallLog>,
130    /// Token usage statistics.
131    pub usage: TokenUsage,
132    /// Whether the subagent completed successfully.
133    pub success: bool,
134    /// Duration in milliseconds.
135    pub duration_ms: u64,
136}
137
138/// Tool for spawning subagents.
139///
140/// This tool allows an agent to spawn a child agent that runs independently
141/// and returns only its final response.
142///
143/// # Example
144///
145/// ```ignore
146/// use agent_sdk::subagent::{SubagentTool, SubagentConfig};
147///
148/// let config = SubagentConfig::new("analyzer")
149///     .with_system_prompt("You analyze code...");
150///
151/// let tool = SubagentTool::new(config, provider.clone(), tools.clone());
152/// ```
153pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
154where
155    P: LlmProvider,
156    H: AgentHooks,
157    M: MessageStore,
158    S: StateStore,
159{
160    config: SubagentConfig,
161    provider: Arc<P>,
162    tools: Arc<ToolRegistry<()>>,
163    hooks: Arc<H>,
164    message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
165    state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
166    /// Cached display name to avoid `Box::leak` on every call.
167    cached_display_name: &'static str,
168    /// Cached description to avoid `Box::leak` on every call.
169    cached_description: &'static str,
170}
171
172impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
173where
174    P: LlmProvider + 'static,
175{
176    /// Create a new subagent tool with default hooks and in-memory stores.
177    #[must_use]
178    pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
179        // Cache leaked strings at construction time (bounded by number of tools)
180        let cached_display_name = Box::leak(format!("Subagent: {}", config.name).into_boxed_str());
181        let cached_description = Box::leak(
182            format!(
183                "Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
184                config.name
185            )
186            .into_boxed_str(),
187        );
188        Self {
189            config,
190            provider,
191            tools,
192            hooks: Arc::new(DefaultHooks),
193            message_store_factory: Arc::new(InMemoryStore::new),
194            state_store_factory: Arc::new(InMemoryStore::new),
195            cached_display_name,
196            cached_description,
197        }
198    }
199}
200
201impl<P, H, M, S> SubagentTool<P, H, M, S>
202where
203    P: LlmProvider + Clone + 'static,
204    H: AgentHooks + Clone + 'static,
205    M: MessageStore + 'static,
206    S: StateStore + 'static,
207{
208    /// Create with custom hooks.
209    #[must_use]
210    pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
211        self,
212        hooks: Arc<H2>,
213    ) -> SubagentTool<P, H2, M, S> {
214        SubagentTool {
215            config: self.config,
216            provider: self.provider,
217            tools: self.tools,
218            hooks,
219            message_store_factory: self.message_store_factory,
220            state_store_factory: self.state_store_factory,
221            cached_display_name: self.cached_display_name,
222            cached_description: self.cached_description,
223        }
224    }
225
226    /// Create with custom store factories.
227    #[must_use]
228    pub fn with_stores<M2, S2, MF, SF>(
229        self,
230        message_factory: MF,
231        state_factory: SF,
232    ) -> SubagentTool<P, H, M2, S2>
233    where
234        M2: MessageStore + 'static,
235        S2: StateStore + 'static,
236        MF: Fn() -> M2 + Send + Sync + 'static,
237        SF: Fn() -> S2 + Send + Sync + 'static,
238    {
239        SubagentTool {
240            config: self.config,
241            provider: self.provider,
242            tools: self.tools,
243            hooks: self.hooks,
244            message_store_factory: Arc::new(message_factory),
245            state_store_factory: Arc::new(state_factory),
246            cached_display_name: self.cached_display_name,
247            cached_description: self.cached_description,
248        }
249    }
250
251    /// Get the subagent configuration.
252    #[must_use]
253    pub const fn config(&self) -> &SubagentConfig {
254        &self.config
255    }
256
    /// Run the subagent with a task.
    ///
    /// Builds a fresh `AgentLoop` (new `ThreadId`, stores from the configured
    /// factories, clones of the provider, tool registry and hooks), feeds it
    /// `task` as text input and drains its event stream until the run ends.
    ///
    /// If both `parent_tx` and `parent_seq` are provided, the subagent will emit
    /// `SubagentProgress` events to the parent's event channel at the start and
    /// end of every tool call, allowing the UI to show live progress. Failures
    /// to send these events are ignored.
    ///
    /// The `parent_cancel` token links the subagent's lifecycle to its parent.
    /// The subagent runs under a child of that token, so cancelling the parent
    /// token will also cancel the subagent.
    ///
    /// # Returns
    ///
    /// This body never returns `Err`; failures are reported in-band:
    /// - on an agent `Error` event, `success` is `false` and `final_response`
    ///   is replaced by the error message;
    /// - on timeout, the child token is cancelled, `success` is `false` and
    ///   `final_response` is replaced by `"Subagent timed out"`;
    /// - otherwise `final_response` is the concatenation of every `Text` event
    ///   received during the run.
    #[allow(clippy::too_many_lines)]
    async fn run_subagent(
        &self,
        task: &str,
        subagent_id: String,
        parent_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
        parent_seq: Option<SequenceCounter>,
        parent_cancel: CancellationToken,
    ) -> Result<SubagentResult> {
        use crate::agent_loop::AgentLoop;

        // `start` anchors both the timeout deadline and the reported duration.
        let start = Instant::now();
        let thread_id = ThreadId::new();

        // Create stores for this subagent run
        let message_store = (self.message_store_factory)();
        let state_store = (self.state_store_factory)();

        // Create agent config with a default max_turns to prevent unbounded execution
        // (100 turns when the subagent config does not set a limit).
        let agent_config = AgentConfig {
            max_turns: Some(self.config.max_turns.unwrap_or(100)),
            system_prompt: self.config.system_prompt.clone(),
            ..Default::default()
        };

        // Build the subagent from clones of the shared provider, tools and hooks
        let agent = AgentLoop::new(
            (*self.provider).clone(),
            (*self.tools).clone(),
            (*self.hooks).clone(),
            message_store,
            state_store,
            agent_config,
        );

        // Create tool context
        let tool_ctx = ToolContext::new(());

        // Run with a child cancellation token so parent cancellation propagates.
        // A clone is kept so this function can cancel the child itself on timeout.
        let cancel_token = parent_cancel.child_token();
        let timeout_cancel = cancel_token.clone();
        // Only the event receiver is consumed here; the second value returned by
        // `run` is not used.
        let (mut rx, _final_state) = agent.run(
            thread_id,
            AgentInput::Text(task.to_string()),
            tool_ctx,
            cancel_token,
        );

        // Accumulators filled in while draining the event stream.
        let mut final_response = String::new();
        let mut total_turns = 0;
        let mut tool_count = 0u32;
        let mut tool_logs: Vec<ToolCallLog> = Vec::new();
        // Tool calls that have started but not yet ended, keyed by call id,
        // holding (tool name, extracted context).
        let mut pending_tools: std::collections::HashMap<String, (String, String)> =
            std::collections::HashMap::new();
        let mut total_usage = TokenUsage::default();
        let mut success = true;

        let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);

        loop {
            // Wait for the next event. With a timeout configured, wait only for
            // the time left until the whole-run deadline (measured from `start`),
            // not for a fresh per-event timeout.
            let recv_result = if let Some(timeout) = timeout_duration {
                let remaining = timeout.saturating_sub(start.elapsed());
                if remaining.is_zero() {
                    timeout_cancel.cancel(); // Cancel the child agent on timeout
                    final_response = "Subagent timed out".to_string();
                    success = false;
                    break;
                }
                tokio::time::timeout(remaining, rx.recv()).await
            } else {
                Ok(rx.recv().await)
            };

            match recv_result {
                Ok(Some(envelope)) => match envelope.event {
                    AgentEvent::Text {
                        message_id: _,
                        text,
                    } => {
                        // Text from every turn is appended to the response.
                        final_response.push_str(&text);
                    }
                    AgentEvent::ToolCallStart {
                        id, name, input, ..
                    } => {
                        // Track tool calls made by the subagent
                        tool_count += 1;
                        let context = extract_tool_context(&name, &input);
                        pending_tools.insert(id, (name.clone(), context.clone()));

                        // Emit progress event to parent (requires both channel and counter)
                        if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
                            let event = AgentEvent::SubagentProgress {
                                subagent_id: subagent_id.clone(),
                                subagent_name: self.config.name.clone(),
                                tool_name: name,
                                tool_context: context,
                                completed: false,
                                success: false,
                                tool_count,
                                total_tokens: u64::from(total_usage.input_tokens)
                                    + u64::from(total_usage.output_tokens),
                            };
                            // Progress is best-effort: a failed send is ignored.
                            let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
                        }
                    }
                    AgentEvent::ToolCallEnd {
                        id,
                        name,
                        display_name,
                        result,
                    } => {
                        // Create log entry when tool completes, reusing the context
                        // captured at start (empty if the start was never seen).
                        let context = pending_tools
                            .remove(&id)
                            .map(|(_, ctx)| ctx)
                            .unwrap_or_default();
                        let result_summary = summarize_tool_result(&name, &result);
                        let tool_success = result.success;
                        tool_logs.push(ToolCallLog {
                            name: name.clone(),
                            display_name: display_name.clone(),
                            context: context.clone(),
                            result: result_summary,
                            success: tool_success,
                            duration_ms: result.duration_ms,
                        });

                        // Emit progress event to parent (requires both channel and counter)
                        if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
                            let event = AgentEvent::SubagentProgress {
                                subagent_id: subagent_id.clone(),
                                subagent_name: self.config.name.clone(),
                                tool_name: name,
                                tool_context: context,
                                completed: true,
                                success: tool_success,
                                tool_count,
                                total_tokens: u64::from(total_usage.input_tokens)
                                    + u64::from(total_usage.output_tokens),
                            };
                            // Progress is best-effort: a failed send is ignored.
                            let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
                        }
                    }
                    AgentEvent::TurnComplete { turn, usage, .. } => {
                        total_turns = turn;
                        total_usage.add(&usage);
                    }
                    AgentEvent::Done {
                        total_turns: turns, ..
                    } => {
                        // Normal completion: take the turn count reported by the agent.
                        total_turns = turns;
                        break;
                    }
                    AgentEvent::Error { message, .. } => {
                        // The error message replaces any text collected so far.
                        final_response = message;
                        success = false;
                        break;
                    }
                    // All other event kinds are ignored.
                    _ => {}
                },
                // Channel closed without a `Done` or `Error` event: stop with
                // whatever was collected; `success` is left unchanged.
                Ok(None) => break,
                // The deadline elapsed while waiting for the next event.
                Err(_) => {
                    timeout_cancel.cancel(); // Cancel the child agent on timeout
                    final_response = "Subagent timed out".to_string();
                    success = false;
                    break;
                }
            }
        }

        Ok(SubagentResult {
            name: self.config.name.clone(),
            final_response,
            total_turns,
            tool_count,
            tool_logs,
            usage: total_usage,
            success,
            // Saturate rather than truncate if the millisecond count exceeds u64.
            duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
        })
    }
445}
446
447/// Extracts context information from tool input for display.
448fn extract_tool_context(name: &str, input: &Value) -> String {
449    match name {
450        "read" => input
451            .get("file_path")
452            .or_else(|| input.get("path"))
453            .and_then(Value::as_str)
454            .unwrap_or("")
455            .to_string(),
456        "write" | "edit" => input
457            .get("file_path")
458            .or_else(|| input.get("path"))
459            .and_then(Value::as_str)
460            .unwrap_or("")
461            .to_string(),
462        "bash" => {
463            let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
464            // Truncate long commands (UTF-8 safe)
465            if cmd.len() > 60 {
466                format!("{}...", crate::primitive_tools::truncate_str(cmd, 57))
467            } else {
468                cmd.to_string()
469            }
470        }
471        "glob" | "grep" => input
472            .get("pattern")
473            .and_then(Value::as_str)
474            .unwrap_or("")
475            .to_string(),
476        "web_search" => input
477            .get("query")
478            .and_then(Value::as_str)
479            .unwrap_or("")
480            .to_string(),
481        _ => String::new(),
482    }
483}
484
485/// Summarizes tool result for logging.
486fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
487    if !result.success {
488        let first_line = result.output.lines().next().unwrap_or("Error");
489        return if first_line.len() > 50 {
490            format!(
491                "{}...",
492                crate::primitive_tools::truncate_str(first_line, 47)
493            )
494        } else {
495            first_line.to_string()
496        };
497    }
498
499    match name {
500        "read" => {
501            let line_count = result.output.lines().count();
502            format!("{line_count} lines")
503        }
504        "write" => "wrote file".to_string(),
505        "edit" => "edited".to_string(),
506        "bash" => {
507            let lines: Vec<&str> = result.output.lines().collect();
508            if lines.is_empty() {
509                "done".to_string()
510            } else if lines.len() == 1 {
511                let line = lines[0];
512                if line.len() > 50 {
513                    format!("{}...", crate::primitive_tools::truncate_str(line, 47))
514                } else {
515                    line.to_string()
516                }
517            } else {
518                format!("{} lines", lines.len())
519            }
520        }
521        "glob" => {
522            let count = result.output.lines().count();
523            format!("{count} files")
524        }
525        "grep" => {
526            let count = result.output.lines().count();
527            format!("{count} matches")
528        }
529        _ => {
530            let line_count = result.output.lines().count();
531            if line_count == 0 {
532                "done".to_string()
533            } else {
534                format!("{line_count} lines")
535            }
536        }
537    }
538}
539
540impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
541where
542    P: LlmProvider + Clone + 'static,
543    H: AgentHooks + Clone + 'static,
544    M: MessageStore + 'static,
545    S: StateStore + 'static,
546{
547    type Name = DynamicToolName;
548
549    fn name(&self) -> DynamicToolName {
550        DynamicToolName::new(format!("subagent_{}", self.config.name))
551    }
552
553    fn display_name(&self) -> &'static str {
554        self.cached_display_name
555    }
556
557    fn description(&self) -> &'static str {
558        self.cached_description
559    }
560
561    fn input_schema(&self) -> Value {
562        json!({
563            "type": "object",
564            "properties": {
565                "task": {
566                    "type": "string",
567                    "description": "The task or question for the subagent to handle"
568                }
569            },
570            "required": ["task"]
571        })
572    }
573
574    fn tier(&self) -> ToolTier {
575        // Subagent spawning requires confirmation
576        ToolTier::Confirm
577    }
578
579    async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
580        let task = input
581            .get("task")
582            .and_then(Value::as_str)
583            .context("Missing 'task' parameter")?;
584
585        // Get event channel and sequence counter from context for progress updates
586        let parent_tx = ctx.event_tx();
587        let parent_seq = ctx.event_seq();
588
589        // Generate a unique ID for this subagent execution
590        let subagent_id = format!(
591            "{}_{:x}",
592            self.config.name,
593            std::time::SystemTime::now()
594                .duration_since(std::time::UNIX_EPOCH)
595                .unwrap_or_default()
596                .as_nanos()
597        );
598
599        // Use the context's cancellation token if available, otherwise create a standalone one.
600        // This ensures that when a parent agent is cancelled, subagents are also cancelled.
601        let cancel_token = ctx.cancel_token().unwrap_or_default();
602
603        let result = self
604            .run_subagent(task, subagent_id, parent_tx, parent_seq, cancel_token)
605            .await?;
606
607        Ok(ToolResult {
608            success: result.success,
609            output: result.final_response.clone(),
610            data: Some(serde_json::to_value(&result).unwrap_or_default()),
611            documents: Vec::new(),
612            duration_ms: Some(result.duration_ms),
613        })
614    }
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder setter is reflected in the resulting configuration.
    #[test]
    fn test_subagent_config_builder() {
        let SubagentConfig {
            name,
            system_prompt,
            max_turns,
            timeout_ms,
        } = SubagentConfig::new("test")
            .with_system_prompt("Test prompt")
            .with_max_turns(5)
            .with_timeout_ms(30000);

        assert_eq!(name, "test");
        assert_eq!(system_prompt, "Test prompt");
        assert_eq!(max_turns, Some(5));
        assert_eq!(timeout_ms, Some(30000));
    }

    /// A freshly created configuration has an empty prompt and no limits.
    #[test]
    fn test_subagent_config_defaults() {
        let fresh = SubagentConfig::new("default");

        assert_eq!(fresh.name, "default");
        assert!(fresh.system_prompt.is_empty());
        assert_eq!(fresh.max_turns, None);
        assert_eq!(fresh.timeout_ms, None);
    }

    /// The serialized result carries its field names and nested log values.
    #[test]
    fn test_subagent_result_serialization() {
        let read_log = ToolCallLog {
            name: "read".to_string(),
            display_name: "Read file".to_string(),
            context: "/tmp/test.rs".to_string(),
            result: "50 lines".to_string(),
            success: true,
            duration_ms: Some(10),
        };
        let grep_log = ToolCallLog {
            name: "grep".to_string(),
            display_name: "Grep TODO".to_string(),
            context: "TODO".to_string(),
            result: "3 matches".to_string(),
            success: true,
            duration_ms: Some(5),
        };
        let result = SubagentResult {
            name: "test".to_string(),
            final_response: "Done".to_string(),
            total_turns: 3,
            tool_count: 5,
            tool_logs: vec![read_log, grep_log],
            usage: TokenUsage::default(),
            success: true,
            duration_ms: 1000,
        };

        let json = serde_json::to_string(&result).expect("serialize");
        for needle in ["test", "Done", "tool_count", "tool_logs", "/tmp/test.rs"] {
            assert!(json.contains(needle));
        }
    }

    /// Pins the exact JSON structure expected by bip's tui_session.rs.
    #[test]
    fn test_subagent_result_field_extraction() {
        let glob_log = ToolCallLog {
            name: "glob".to_string(),
            display_name: "Glob config files".to_string(),
            context: "**/*.toml".to_string(),
            result: "3 files".to_string(),
            success: true,
            duration_ms: Some(15),
        };
        let result = SubagentResult {
            name: "explore".to_string(),
            final_response: "Found 3 config files".to_string(),
            total_turns: 2,
            tool_count: 5,
            tool_logs: vec![glob_log],
            usage: TokenUsage {
                input_tokens: 1500,
                output_tokens: 500,
            },
            success: true,
            duration_ms: 2500,
        };

        let value = serde_json::to_value(&result).expect("serialize to value");

        // tool_count is a u32 but must be readable through as_u64
        assert_eq!(value.get("tool_count").and_then(Value::as_u64), Some(5));

        // usage is a nested object with numeric token counts
        let usage = value.get("usage").expect("usage field");
        assert_eq!(usage.get("input_tokens").and_then(Value::as_u64), Some(1500));
        assert_eq!(usage.get("output_tokens").and_then(Value::as_u64), Some(500));

        // tool_logs is an array with one object per logged call
        let logs = value
            .get("tool_logs")
            .and_then(Value::as_array)
            .expect("tool_logs should be an array");
        assert_eq!(logs.len(), 1);

        let entry = &logs[0];
        assert_eq!(entry.get("name").and_then(Value::as_str), Some("glob"));
        assert_eq!(entry.get("context").and_then(Value::as_str), Some("**/*.toml"));
        assert_eq!(entry.get("result").and_then(Value::as_str), Some("3 files"));
        assert_eq!(entry.get("success").and_then(Value::as_bool), Some(true));
    }
}