1mod factory;
36
37pub use factory::SubagentFactory;
38
39use crate::events::AgentEvent;
40use crate::hooks::{AgentHooks, DefaultHooks};
41use crate::llm::LlmProvider;
42use crate::stores::{InMemoryStore, MessageStore, StateStore};
43use crate::tools::{Tool, ToolContext, ToolRegistry};
44use crate::types::{AgentConfig, ThreadId, TokenUsage, ToolResult, ToolTier};
45use anyhow::{Context, Result};
46use async_trait::async_trait;
47use serde::{Deserialize, Serialize};
48use serde_json::{Value, json};
49use std::sync::Arc;
50use std::time::{Duration, Instant};
51use tokio::sync::mpsc;
52
53#[derive(Clone, Debug, Serialize, Deserialize)]
55pub struct SubagentConfig {
56 pub name: String,
58 pub system_prompt: String,
60 pub max_turns: usize,
62 pub timeout_ms: Option<u64>,
64}
65
66impl SubagentConfig {
67 #[must_use]
69 pub fn new(name: impl Into<String>) -> Self {
70 Self {
71 name: name.into(),
72 system_prompt: String::new(),
73 max_turns: 10,
74 timeout_ms: None,
75 }
76 }
77
78 #[must_use]
80 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
81 self.system_prompt = prompt.into();
82 self
83 }
84
85 #[must_use]
87 pub const fn with_max_turns(mut self, max: usize) -> Self {
88 self.max_turns = max;
89 self
90 }
91
92 #[must_use]
94 pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
95 self.timeout_ms = Some(timeout);
96 self
97 }
98}
99
100#[derive(Clone, Debug, Serialize, Deserialize)]
102pub struct ToolCallLog {
103 pub name: String,
105 pub context: String,
107 pub result: String,
109 pub success: bool,
111 pub duration_ms: Option<u64>,
113}
114
115#[derive(Clone, Debug, Serialize, Deserialize)]
117pub struct SubagentResult {
118 pub name: String,
120 pub final_response: String,
122 pub total_turns: usize,
124 pub tool_count: u32,
126 pub tool_logs: Vec<ToolCallLog>,
128 pub usage: TokenUsage,
130 pub success: bool,
132 pub duration_ms: u64,
134}
135
136pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
152where
153 P: LlmProvider,
154 H: AgentHooks,
155 M: MessageStore,
156 S: StateStore,
157{
158 config: SubagentConfig,
159 provider: Arc<P>,
160 tools: Arc<ToolRegistry<()>>,
161 hooks: Arc<H>,
162 message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
163 state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
164}
165
166impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
167where
168 P: LlmProvider + 'static,
169{
170 #[must_use]
172 pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
173 Self {
174 config,
175 provider,
176 tools,
177 hooks: Arc::new(DefaultHooks),
178 message_store_factory: Arc::new(InMemoryStore::new),
179 state_store_factory: Arc::new(InMemoryStore::new),
180 }
181 }
182}
183
184impl<P, H, M, S> SubagentTool<P, H, M, S>
185where
186 P: LlmProvider + Clone + 'static,
187 H: AgentHooks + Clone + 'static,
188 M: MessageStore + 'static,
189 S: StateStore + 'static,
190{
191 #[must_use]
193 pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
194 self,
195 hooks: Arc<H2>,
196 ) -> SubagentTool<P, H2, M, S> {
197 SubagentTool {
198 config: self.config,
199 provider: self.provider,
200 tools: self.tools,
201 hooks,
202 message_store_factory: self.message_store_factory,
203 state_store_factory: self.state_store_factory,
204 }
205 }
206
207 #[must_use]
209 pub fn with_stores<M2, S2, MF, SF>(
210 self,
211 message_factory: MF,
212 state_factory: SF,
213 ) -> SubagentTool<P, H, M2, S2>
214 where
215 M2: MessageStore + 'static,
216 S2: StateStore + 'static,
217 MF: Fn() -> M2 + Send + Sync + 'static,
218 SF: Fn() -> S2 + Send + Sync + 'static,
219 {
220 SubagentTool {
221 config: self.config,
222 provider: self.provider,
223 tools: self.tools,
224 hooks: self.hooks,
225 message_store_factory: Arc::new(message_factory),
226 state_store_factory: Arc::new(state_factory),
227 }
228 }
229
230 #[must_use]
232 pub const fn config(&self) -> &SubagentConfig {
233 &self.config
234 }
235
236 #[allow(clippy::too_many_lines)]
241 async fn run_subagent(
242 &self,
243 task: &str,
244 subagent_id: String,
245 parent_tx: Option<mpsc::Sender<AgentEvent>>,
246 ) -> Result<SubagentResult> {
247 use crate::agent_loop::AgentLoop;
248
249 let start = Instant::now();
250 let thread_id = ThreadId::new();
251
252 let message_store = (self.message_store_factory)();
254 let state_store = (self.state_store_factory)();
255
256 let agent_config = AgentConfig {
258 max_turns: self.config.max_turns,
259 system_prompt: self.config.system_prompt.clone(),
260 ..Default::default()
261 };
262
263 let agent = AgentLoop::new(
265 (*self.provider).clone(),
266 (*self.tools).clone(),
267 (*self.hooks).clone(),
268 message_store,
269 state_store,
270 agent_config,
271 );
272
273 let tool_ctx = ToolContext::new(());
275
276 let mut rx = agent.run(thread_id, task.to_string(), tool_ctx);
278
279 let mut final_response = String::new();
280 let mut total_turns = 0;
281 let mut tool_count = 0u32;
282 let mut tool_logs: Vec<ToolCallLog> = Vec::new();
283 let mut pending_tools: std::collections::HashMap<String, (String, String)> =
284 std::collections::HashMap::new();
285 let mut total_usage = TokenUsage::default();
286 let mut success = true;
287
288 let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);
289
290 loop {
291 let recv_result = if let Some(timeout) = timeout_duration {
292 let remaining = timeout.saturating_sub(start.elapsed());
293 if remaining.is_zero() {
294 final_response = "Subagent timed out".to_string();
295 success = false;
296 break;
297 }
298 tokio::time::timeout(remaining, rx.recv()).await
299 } else {
300 Ok(rx.recv().await)
301 };
302
303 match recv_result {
304 Ok(Some(event)) => match event {
305 AgentEvent::Text { text } => {
306 final_response.push_str(&text);
307 }
308 AgentEvent::ToolCallStart {
309 id, name, input, ..
310 } => {
311 tool_count += 1;
313 let context = extract_tool_context(&name, &input);
314 pending_tools.insert(id, (name.clone(), context.clone()));
315
316 if let Some(ref tx) = parent_tx {
318 let _ = tx
319 .send(AgentEvent::SubagentProgress {
320 subagent_id: subagent_id.clone(),
321 subagent_name: self.config.name.clone(),
322 tool_name: name,
323 tool_context: context,
324 completed: false,
325 success: false,
326 tool_count,
327 total_tokens: u64::from(total_usage.input_tokens)
328 + u64::from(total_usage.output_tokens),
329 })
330 .await;
331 }
332 }
333 AgentEvent::ToolCallEnd { id, name, result } => {
334 let context = pending_tools
336 .remove(&id)
337 .map(|(_, ctx)| ctx)
338 .unwrap_or_default();
339 let result_summary = summarize_tool_result(&name, &result);
340 let tool_success = result.success;
341 tool_logs.push(ToolCallLog {
342 name: name.clone(),
343 context: context.clone(),
344 result: result_summary,
345 success: tool_success,
346 duration_ms: result.duration_ms,
347 });
348
349 if let Some(ref tx) = parent_tx {
351 let _ = tx
352 .send(AgentEvent::SubagentProgress {
353 subagent_id: subagent_id.clone(),
354 subagent_name: self.config.name.clone(),
355 tool_name: name,
356 tool_context: context,
357 completed: true,
358 success: tool_success,
359 tool_count,
360 total_tokens: u64::from(total_usage.input_tokens)
361 + u64::from(total_usage.output_tokens),
362 })
363 .await;
364 }
365 }
366 AgentEvent::TurnComplete { turn, usage, .. } => {
367 total_turns = turn;
368 total_usage.add(&usage);
369 }
370 AgentEvent::Done {
371 total_turns: turns, ..
372 } => {
373 total_turns = turns;
374 break;
375 }
376 AgentEvent::Error { message, .. } => {
377 final_response = message;
378 success = false;
379 break;
380 }
381 _ => {}
382 },
383 Ok(None) => break,
384 Err(_) => {
385 final_response = "Subagent timed out".to_string();
386 success = false;
387 break;
388 }
389 }
390 }
391
392 Ok(SubagentResult {
393 name: self.config.name.clone(),
394 final_response,
395 total_turns,
396 tool_count,
397 tool_logs,
398 usage: total_usage,
399 success,
400 duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
401 })
402 }
403}
404
405fn extract_tool_context(name: &str, input: &Value) -> String {
407 match name {
408 "read" => input
409 .get("file_path")
410 .or_else(|| input.get("path"))
411 .and_then(Value::as_str)
412 .unwrap_or("")
413 .to_string(),
414 "write" | "edit" => input
415 .get("file_path")
416 .or_else(|| input.get("path"))
417 .and_then(Value::as_str)
418 .unwrap_or("")
419 .to_string(),
420 "bash" => {
421 let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
422 if cmd.len() > 60 {
424 format!("{}...", &cmd[..57])
425 } else {
426 cmd.to_string()
427 }
428 }
429 "glob" | "grep" => input
430 .get("pattern")
431 .and_then(Value::as_str)
432 .unwrap_or("")
433 .to_string(),
434 "web_search" => input
435 .get("query")
436 .and_then(Value::as_str)
437 .unwrap_or("")
438 .to_string(),
439 _ => String::new(),
440 }
441}
442
443fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
445 if !result.success {
446 let first_line = result.output.lines().next().unwrap_or("Error");
447 return if first_line.len() > 50 {
448 format!("{}...", &first_line[..47])
449 } else {
450 first_line.to_string()
451 };
452 }
453
454 match name {
455 "read" => {
456 let line_count = result.output.lines().count();
457 format!("{line_count} lines")
458 }
459 "write" => "wrote file".to_string(),
460 "edit" => "edited".to_string(),
461 "bash" => {
462 let lines: Vec<&str> = result.output.lines().collect();
463 if lines.is_empty() {
464 "done".to_string()
465 } else if lines.len() == 1 {
466 let line = lines[0];
467 if line.len() > 50 {
468 format!("{}...", &line[..47])
469 } else {
470 line.to_string()
471 }
472 } else {
473 format!("{} lines", lines.len())
474 }
475 }
476 "glob" => {
477 let count = result.output.lines().count();
478 format!("{count} files")
479 }
480 "grep" => {
481 let count = result.output.lines().count();
482 format!("{count} matches")
483 }
484 _ => {
485 let line_count = result.output.lines().count();
486 if line_count == 0 {
487 "done".to_string()
488 } else {
489 format!("{line_count} lines")
490 }
491 }
492 }
493}
494
495#[async_trait]
496impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
497where
498 P: LlmProvider + Clone + 'static,
499 H: AgentHooks + Clone + 'static,
500 M: MessageStore + 'static,
501 S: StateStore + 'static,
502{
503 fn name(&self) -> &'static str {
504 Box::leak(format!("subagent_{}", self.config.name).into_boxed_str())
506 }
507
508 fn description(&self) -> &'static str {
509 Box::leak(
510 format!(
511 "Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
512 self.config.name
513 )
514 .into_boxed_str(),
515 )
516 }
517
518 fn input_schema(&self) -> Value {
519 json!({
520 "type": "object",
521 "properties": {
522 "task": {
523 "type": "string",
524 "description": "The task or question for the subagent to handle"
525 }
526 },
527 "required": ["task"]
528 })
529 }
530
531 fn tier(&self) -> ToolTier {
532 ToolTier::Confirm
534 }
535
536 async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
537 let task = input
538 .get("task")
539 .and_then(Value::as_str)
540 .context("Missing 'task' parameter")?;
541
542 let parent_tx = ctx.event_tx();
544
545 let subagent_id = format!(
547 "{}_{:x}",
548 self.config.name,
549 std::time::SystemTime::now()
550 .duration_since(std::time::UNIX_EPOCH)
551 .unwrap_or_default()
552 .as_nanos()
553 );
554
555 let result = self.run_subagent(task, subagent_id, parent_tx).await?;
556
557 Ok(ToolResult {
558 success: result.success,
559 output: result.final_response.clone(),
560 data: Some(serde_json::to_value(&result).unwrap_or_default()),
561 duration_ms: Some(result.duration_ms),
562 })
563 }
564}
565
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method stores the value it was given.
    #[test]
    fn test_subagent_config_builder() {
        let built = SubagentConfig::new("test")
            .with_system_prompt("Test prompt")
            .with_max_turns(5)
            .with_timeout_ms(30000);

        assert_eq!(built.name, "test");
        assert_eq!(built.system_prompt, "Test prompt");
        assert_eq!(built.max_turns, 5);
        assert_eq!(built.timeout_ms, Some(30000));
    }

    /// `new` fills in the documented defaults.
    #[test]
    fn test_subagent_config_defaults() {
        let defaults = SubagentConfig::new("default");

        assert_eq!(defaults.name, "default");
        assert!(defaults.system_prompt.is_empty());
        assert_eq!(defaults.max_turns, 10);
        assert_eq!(defaults.timeout_ms, None);
    }

    /// A result serializes to JSON text containing its key fields.
    #[test]
    fn test_subagent_result_serialization() {
        let read_log = ToolCallLog {
            name: "read".to_string(),
            context: "/tmp/test.rs".to_string(),
            result: "50 lines".to_string(),
            success: true,
            duration_ms: Some(10),
        };
        let grep_log = ToolCallLog {
            name: "grep".to_string(),
            context: "TODO".to_string(),
            result: "3 matches".to_string(),
            success: true,
            duration_ms: Some(5),
        };
        let result = SubagentResult {
            name: "test".to_string(),
            final_response: "Done".to_string(),
            total_turns: 3,
            tool_count: 5,
            tool_logs: vec![read_log, grep_log],
            usage: TokenUsage::default(),
            success: true,
            duration_ms: 1000,
        };

        let json = serde_json::to_string(&result).expect("serialize");
        for needle in ["test", "Done", "tool_count", "tool_logs", "/tmp/test.rs"] {
            assert!(json.contains(needle));
        }
    }

    /// Individual fields can be read back from the serialized JSON value.
    #[test]
    fn test_subagent_result_field_extraction() {
        let glob_log = ToolCallLog {
            name: "glob".to_string(),
            context: "**/*.toml".to_string(),
            result: "3 files".to_string(),
            success: true,
            duration_ms: Some(15),
        };
        let result = SubagentResult {
            name: "explore".to_string(),
            final_response: "Found 3 config files".to_string(),
            total_turns: 2,
            tool_count: 5,
            tool_logs: vec![glob_log],
            usage: TokenUsage {
                input_tokens: 1500,
                output_tokens: 500,
            },
            success: true,
            duration_ms: 2500,
        };

        let value = serde_json::to_value(&result).expect("serialize to value");

        // Scalar field.
        assert_eq!(value.get("tool_count").and_then(Value::as_u64), Some(5));

        // Nested usage object.
        let usage = value.get("usage").expect("usage field");
        assert_eq!(
            usage.get("input_tokens").and_then(Value::as_u64),
            Some(1500)
        );
        assert_eq!(
            usage.get("output_tokens").and_then(Value::as_u64),
            Some(500)
        );

        // Tool log array.
        let logs = value
            .get("tool_logs")
            .and_then(Value::as_array)
            .expect("tool_logs should be an array");
        assert_eq!(logs.len(), 1);

        let entry = &logs[0];
        assert_eq!(entry.get("name").and_then(Value::as_str), Some("glob"));
        assert_eq!(
            entry.get("context").and_then(Value::as_str),
            Some("**/*.toml")
        );
        assert_eq!(entry.get("result").and_then(Value::as_str), Some("3 files"));
        assert_eq!(entry.get("success").and_then(Value::as_bool), Some(true));
    }
}