mod factory;
pub use factory::SubagentFactory;
use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
use crate::hooks::{AgentHooks, DefaultHooks};
use crate::llm::LlmProvider;
use crate::stores::{InMemoryStore, MessageStore, StateStore};
use crate::tools::{DynamicToolName, Tool, ToolContext, ToolRegistry};
use crate::types::{AgentConfig, AgentInput, ThreadId, TokenUsage, ToolResult, ToolTier};
use anyhow::{Context, Result};
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::mpsc;
use tokio_util::sync::CancellationToken;
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SubagentConfig {
pub name: String,
pub system_prompt: String,
pub max_turns: Option<usize>,
pub timeout_ms: Option<u64>,
}
impl SubagentConfig {
#[must_use]
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
system_prompt: String::new(),
max_turns: None,
timeout_ms: None,
}
}
#[must_use]
pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
self.system_prompt = prompt.into();
self
}
#[must_use]
pub const fn with_max_turns(mut self, max: usize) -> Self {
self.max_turns = Some(max);
self
}
#[must_use]
pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
self.timeout_ms = Some(timeout);
self
}
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolCallLog {
pub name: String,
pub display_name: String,
pub context: String,
pub result: String,
pub success: bool,
pub duration_ms: Option<u64>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SubagentResult {
pub name: String,
pub final_response: String,
pub total_turns: usize,
pub tool_count: u32,
pub tool_logs: Vec<ToolCallLog>,
pub usage: TokenUsage,
pub success: bool,
pub duration_ms: u64,
}
pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
where
P: LlmProvider,
H: AgentHooks,
M: MessageStore,
S: StateStore,
{
config: SubagentConfig,
provider: Arc<P>,
tools: Arc<ToolRegistry<()>>,
hooks: Arc<H>,
message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
cached_display_name: &'static str,
cached_description: &'static str,
}
impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
where
P: LlmProvider + 'static,
{
#[must_use]
pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
let cached_display_name = Box::leak(format!("Subagent: {}", config.name).into_boxed_str());
let cached_description = Box::leak(
format!(
"Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
config.name
)
.into_boxed_str(),
);
Self {
config,
provider,
tools,
hooks: Arc::new(DefaultHooks),
message_store_factory: Arc::new(InMemoryStore::new),
state_store_factory: Arc::new(InMemoryStore::new),
cached_display_name,
cached_description,
}
}
}
impl<P, H, M, S> SubagentTool<P, H, M, S>
where
P: LlmProvider + Clone + 'static,
H: AgentHooks + Clone + 'static,
M: MessageStore + 'static,
S: StateStore + 'static,
{
#[must_use]
pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
self,
hooks: Arc<H2>,
) -> SubagentTool<P, H2, M, S> {
SubagentTool {
config: self.config,
provider: self.provider,
tools: self.tools,
hooks,
message_store_factory: self.message_store_factory,
state_store_factory: self.state_store_factory,
cached_display_name: self.cached_display_name,
cached_description: self.cached_description,
}
}
#[must_use]
pub fn with_stores<M2, S2, MF, SF>(
self,
message_factory: MF,
state_factory: SF,
) -> SubagentTool<P, H, M2, S2>
where
M2: MessageStore + 'static,
S2: StateStore + 'static,
MF: Fn() -> M2 + Send + Sync + 'static,
SF: Fn() -> S2 + Send + Sync + 'static,
{
SubagentTool {
config: self.config,
provider: self.provider,
tools: self.tools,
hooks: self.hooks,
message_store_factory: Arc::new(message_factory),
state_store_factory: Arc::new(state_factory),
cached_display_name: self.cached_display_name,
cached_description: self.cached_description,
}
}
#[must_use]
pub const fn config(&self) -> &SubagentConfig {
&self.config
}
#[allow(clippy::too_many_lines)]
async fn run_subagent(
&self,
task: &str,
subagent_id: String,
parent_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
parent_seq: Option<SequenceCounter>,
parent_cancel: CancellationToken,
) -> Result<SubagentResult> {
use crate::agent_loop::AgentLoop;
let start = Instant::now();
let thread_id = ThreadId::new();
let message_store = (self.message_store_factory)();
let state_store = (self.state_store_factory)();
let agent_config = AgentConfig {
max_turns: Some(self.config.max_turns.unwrap_or(100)),
system_prompt: self.config.system_prompt.clone(),
..Default::default()
};
let agent = AgentLoop::new(
(*self.provider).clone(),
(*self.tools).clone(),
(*self.hooks).clone(),
message_store,
state_store,
agent_config,
);
let tool_ctx = ToolContext::new(());
let cancel_token = parent_cancel.child_token();
let timeout_cancel = cancel_token.clone();
let (mut rx, _final_state) = agent.run(
thread_id,
AgentInput::Text(task.to_string()),
tool_ctx,
cancel_token,
);
let mut final_response = String::new();
let mut total_turns = 0;
let mut tool_count = 0u32;
let mut tool_logs: Vec<ToolCallLog> = Vec::new();
let mut pending_tools: std::collections::HashMap<String, (String, String)> =
std::collections::HashMap::new();
let mut total_usage = TokenUsage::default();
let mut success = true;
let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);
loop {
let recv_result = if let Some(timeout) = timeout_duration {
let remaining = timeout.saturating_sub(start.elapsed());
if remaining.is_zero() {
timeout_cancel.cancel(); final_response = "Subagent timed out".to_string();
success = false;
break;
}
tokio::time::timeout(remaining, rx.recv()).await
} else {
Ok(rx.recv().await)
};
match recv_result {
Ok(Some(envelope)) => match envelope.event {
AgentEvent::Text {
message_id: _,
text,
} => {
final_response.push_str(&text);
}
AgentEvent::ToolCallStart {
id, name, input, ..
} => {
tool_count += 1;
let context = extract_tool_context(&name, &input);
pending_tools.insert(id, (name.clone(), context.clone()));
if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
let event = AgentEvent::SubagentProgress {
subagent_id: subagent_id.clone(),
subagent_name: self.config.name.clone(),
tool_name: name,
tool_context: context,
completed: false,
success: false,
tool_count,
total_tokens: u64::from(total_usage.input_tokens)
+ u64::from(total_usage.output_tokens),
};
let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
}
}
AgentEvent::ToolCallEnd {
id,
name,
display_name,
result,
} => {
let context = pending_tools
.remove(&id)
.map(|(_, ctx)| ctx)
.unwrap_or_default();
let result_summary = summarize_tool_result(&name, &result);
let tool_success = result.success;
tool_logs.push(ToolCallLog {
name: name.clone(),
display_name: display_name.clone(),
context: context.clone(),
result: result_summary,
success: tool_success,
duration_ms: result.duration_ms,
});
if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
let event = AgentEvent::SubagentProgress {
subagent_id: subagent_id.clone(),
subagent_name: self.config.name.clone(),
tool_name: name,
tool_context: context,
completed: true,
success: tool_success,
tool_count,
total_tokens: u64::from(total_usage.input_tokens)
+ u64::from(total_usage.output_tokens),
};
let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
}
}
AgentEvent::TurnComplete { turn, usage, .. } => {
total_turns = turn;
total_usage.add(&usage);
}
AgentEvent::Done {
total_turns: turns, ..
} => {
total_turns = turns;
break;
}
AgentEvent::Error { message, .. } => {
final_response = message;
success = false;
break;
}
_ => {}
},
Ok(None) => break,
Err(_) => {
timeout_cancel.cancel(); final_response = "Subagent timed out".to_string();
success = false;
break;
}
}
}
Ok(SubagentResult {
name: self.config.name.clone(),
final_response,
total_turns,
tool_count,
tool_logs,
usage: total_usage,
success,
duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
})
}
}
fn extract_tool_context(name: &str, input: &Value) -> String {
match name {
"read" => input
.get("file_path")
.or_else(|| input.get("path"))
.and_then(Value::as_str)
.unwrap_or("")
.to_string(),
"write" | "edit" => input
.get("file_path")
.or_else(|| input.get("path"))
.and_then(Value::as_str)
.unwrap_or("")
.to_string(),
"bash" => {
let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
if cmd.len() > 60 {
format!("{}...", crate::primitive_tools::truncate_str(cmd, 57))
} else {
cmd.to_string()
}
}
"glob" | "grep" => input
.get("pattern")
.and_then(Value::as_str)
.unwrap_or("")
.to_string(),
"web_search" => input
.get("query")
.and_then(Value::as_str)
.unwrap_or("")
.to_string(),
_ => String::new(),
}
}
fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
if !result.success {
let first_line = result.output.lines().next().unwrap_or("Error");
return if first_line.len() > 50 {
format!(
"{}...",
crate::primitive_tools::truncate_str(first_line, 47)
)
} else {
first_line.to_string()
};
}
match name {
"read" => {
let line_count = result.output.lines().count();
format!("{line_count} lines")
}
"write" => "wrote file".to_string(),
"edit" => "edited".to_string(),
"bash" => {
let lines: Vec<&str> = result.output.lines().collect();
if lines.is_empty() {
"done".to_string()
} else if lines.len() == 1 {
let line = lines[0];
if line.len() > 50 {
format!("{}...", crate::primitive_tools::truncate_str(line, 47))
} else {
line.to_string()
}
} else {
format!("{} lines", lines.len())
}
}
"glob" => {
let count = result.output.lines().count();
format!("{count} files")
}
"grep" => {
let count = result.output.lines().count();
format!("{count} matches")
}
_ => {
let line_count = result.output.lines().count();
if line_count == 0 {
"done".to_string()
} else {
format!("{line_count} lines")
}
}
}
}
impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
where
P: LlmProvider + Clone + 'static,
H: AgentHooks + Clone + 'static,
M: MessageStore + 'static,
S: StateStore + 'static,
{
type Name = DynamicToolName;
fn name(&self) -> DynamicToolName {
DynamicToolName::new(format!("subagent_{}", self.config.name))
}
fn display_name(&self) -> &'static str {
self.cached_display_name
}
fn description(&self) -> &'static str {
self.cached_description
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"task": {
"type": "string",
"description": "The task or question for the subagent to handle"
}
},
"required": ["task"]
})
}
fn tier(&self) -> ToolTier {
ToolTier::Confirm
}
async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
let task = input
.get("task")
.and_then(Value::as_str)
.context("Missing 'task' parameter")?;
let parent_tx = ctx.event_tx();
let parent_seq = ctx.event_seq();
let subagent_id = format!(
"{}_{:x}",
self.config.name,
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos()
);
let cancel_token = ctx.cancel_token().unwrap_or_default();
let result = self
.run_subagent(task, subagent_id, parent_tx, parent_seq, cancel_token)
.await?;
Ok(ToolResult {
success: result.success,
output: result.final_response.clone(),
data: Some(serde_json::to_value(&result).unwrap_or_default()),
documents: Vec::new(),
duration_ms: Some(result.duration_ms),
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_subagent_config_builder() {
let config = SubagentConfig::new("test")
.with_system_prompt("Test prompt")
.with_max_turns(5)
.with_timeout_ms(30000);
assert_eq!(config.name, "test");
assert_eq!(config.system_prompt, "Test prompt");
assert_eq!(config.max_turns, Some(5));
assert_eq!(config.timeout_ms, Some(30000));
}
#[test]
fn test_subagent_config_defaults() {
let config = SubagentConfig::new("default");
assert_eq!(config.name, "default");
assert!(config.system_prompt.is_empty());
assert_eq!(config.max_turns, None);
assert_eq!(config.timeout_ms, None);
}
#[test]
fn test_subagent_result_serialization() {
let result = SubagentResult {
name: "test".to_string(),
final_response: "Done".to_string(),
total_turns: 3,
tool_count: 5,
tool_logs: vec![
ToolCallLog {
name: "read".to_string(),
display_name: "Read file".to_string(),
context: "/tmp/test.rs".to_string(),
result: "50 lines".to_string(),
success: true,
duration_ms: Some(10),
},
ToolCallLog {
name: "grep".to_string(),
display_name: "Grep TODO".to_string(),
context: "TODO".to_string(),
result: "3 matches".to_string(),
success: true,
duration_ms: Some(5),
},
],
usage: TokenUsage::default(),
success: true,
duration_ms: 1000,
};
let json = serde_json::to_string(&result).expect("serialize");
assert!(json.contains("test"));
assert!(json.contains("Done"));
assert!(json.contains("tool_count"));
assert!(json.contains("tool_logs"));
assert!(json.contains("/tmp/test.rs"));
}
#[test]
fn test_subagent_result_field_extraction() {
let result = SubagentResult {
name: "explore".to_string(),
final_response: "Found 3 config files".to_string(),
total_turns: 2,
tool_count: 5,
tool_logs: vec![ToolCallLog {
name: "glob".to_string(),
display_name: "Glob config files".to_string(),
context: "**/*.toml".to_string(),
result: "3 files".to_string(),
success: true,
duration_ms: Some(15),
}],
usage: TokenUsage {
input_tokens: 1500,
output_tokens: 500,
},
success: true,
duration_ms: 2500,
};
let value = serde_json::to_value(&result).expect("serialize to value");
let tool_count = value.get("tool_count").and_then(Value::as_u64);
assert_eq!(tool_count, Some(5));
let usage = value.get("usage").expect("usage field");
let input_tokens = usage.get("input_tokens").and_then(Value::as_u64);
let output_tokens = usage.get("output_tokens").and_then(Value::as_u64);
assert_eq!(input_tokens, Some(1500));
assert_eq!(output_tokens, Some(500));
let tool_logs = value.get("tool_logs").and_then(Value::as_array);
assert!(tool_logs.is_some());
let logs = tool_logs.unwrap();
assert_eq!(logs.len(), 1);
let first_log = &logs[0];
assert_eq!(first_log.get("name").and_then(Value::as_str), Some("glob"));
assert_eq!(
first_log.get("context").and_then(Value::as_str),
Some("**/*.toml")
);
assert_eq!(
first_log.get("result").and_then(Value::as_str),
Some("3 files")
);
assert_eq!(
first_log.get("success").and_then(Value::as_bool),
Some(true)
);
}
}