1mod factory;
36
37pub use factory::SubagentFactory;
38
39use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
40use crate::hooks::{AgentHooks, DefaultHooks};
41use crate::llm::LlmProvider;
42use crate::stores::{InMemoryStore, MessageStore, StateStore};
43use crate::tools::{DynamicToolName, Tool, ToolContext, ToolRegistry};
44use crate::types::{AgentConfig, AgentInput, ThreadId, TokenUsage, ToolResult, ToolTier};
45use anyhow::{Context, Result};
46use serde::{Deserialize, Serialize};
47use serde_json::{Value, json};
48use std::sync::Arc;
49use std::time::{Duration, Instant};
50use tokio::sync::mpsc;
51use tokio_util::sync::CancellationToken;
52
53#[derive(Clone, Debug, Serialize, Deserialize)]
55pub struct SubagentConfig {
56 pub name: String,
58 pub system_prompt: String,
60 pub max_turns: Option<usize>,
62 pub timeout_ms: Option<u64>,
64}
65
66impl SubagentConfig {
67 #[must_use]
69 pub fn new(name: impl Into<String>) -> Self {
70 Self {
71 name: name.into(),
72 system_prompt: String::new(),
73 max_turns: None,
74 timeout_ms: None,
75 }
76 }
77
78 #[must_use]
80 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
81 self.system_prompt = prompt.into();
82 self
83 }
84
85 #[must_use]
87 pub const fn with_max_turns(mut self, max: usize) -> Self {
88 self.max_turns = Some(max);
89 self
90 }
91
92 #[must_use]
94 pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
95 self.timeout_ms = Some(timeout);
96 self
97 }
98}
99
100#[derive(Clone, Debug, Serialize, Deserialize)]
102pub struct ToolCallLog {
103 pub name: String,
105 pub display_name: String,
107 pub context: String,
109 pub result: String,
111 pub success: bool,
113 pub duration_ms: Option<u64>,
115}
116
117#[derive(Clone, Debug, Serialize, Deserialize)]
119pub struct SubagentResult {
120 pub name: String,
122 pub final_response: String,
124 pub total_turns: usize,
126 pub tool_count: u32,
128 pub tool_logs: Vec<ToolCallLog>,
130 pub usage: TokenUsage,
132 pub success: bool,
134 pub duration_ms: u64,
136}
137
138pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
154where
155 P: LlmProvider,
156 H: AgentHooks,
157 M: MessageStore,
158 S: StateStore,
159{
160 config: SubagentConfig,
161 provider: Arc<P>,
162 tools: Arc<ToolRegistry<()>>,
163 hooks: Arc<H>,
164 message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
165 state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
166 cached_display_name: &'static str,
168 cached_description: &'static str,
170}
171
172impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
173where
174 P: LlmProvider + 'static,
175{
176 #[must_use]
178 pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
179 let cached_display_name = Box::leak(format!("Subagent: {}", config.name).into_boxed_str());
181 let cached_description = Box::leak(
182 format!(
183 "Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
184 config.name
185 )
186 .into_boxed_str(),
187 );
188 Self {
189 config,
190 provider,
191 tools,
192 hooks: Arc::new(DefaultHooks),
193 message_store_factory: Arc::new(InMemoryStore::new),
194 state_store_factory: Arc::new(InMemoryStore::new),
195 cached_display_name,
196 cached_description,
197 }
198 }
199}
200
201impl<P, H, M, S> SubagentTool<P, H, M, S>
202where
203 P: LlmProvider + Clone + 'static,
204 H: AgentHooks + Clone + 'static,
205 M: MessageStore + 'static,
206 S: StateStore + 'static,
207{
208 #[must_use]
210 pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
211 self,
212 hooks: Arc<H2>,
213 ) -> SubagentTool<P, H2, M, S> {
214 SubagentTool {
215 config: self.config,
216 provider: self.provider,
217 tools: self.tools,
218 hooks,
219 message_store_factory: self.message_store_factory,
220 state_store_factory: self.state_store_factory,
221 cached_display_name: self.cached_display_name,
222 cached_description: self.cached_description,
223 }
224 }
225
226 #[must_use]
228 pub fn with_stores<M2, S2, MF, SF>(
229 self,
230 message_factory: MF,
231 state_factory: SF,
232 ) -> SubagentTool<P, H, M2, S2>
233 where
234 M2: MessageStore + 'static,
235 S2: StateStore + 'static,
236 MF: Fn() -> M2 + Send + Sync + 'static,
237 SF: Fn() -> S2 + Send + Sync + 'static,
238 {
239 SubagentTool {
240 config: self.config,
241 provider: self.provider,
242 tools: self.tools,
243 hooks: self.hooks,
244 message_store_factory: Arc::new(message_factory),
245 state_store_factory: Arc::new(state_factory),
246 cached_display_name: self.cached_display_name,
247 cached_description: self.cached_description,
248 }
249 }
250
251 #[must_use]
253 pub const fn config(&self) -> &SubagentConfig {
254 &self.config
255 }
256
257 #[allow(clippy::too_many_lines)]
265 async fn run_subagent(
266 &self,
267 task: &str,
268 subagent_id: String,
269 parent_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
270 parent_seq: Option<SequenceCounter>,
271 parent_cancel: CancellationToken,
272 ) -> Result<SubagentResult> {
273 use crate::agent_loop::AgentLoop;
274
275 let start = Instant::now();
276 let thread_id = ThreadId::new();
277
278 let message_store = (self.message_store_factory)();
280 let state_store = (self.state_store_factory)();
281
282 let agent_config = AgentConfig {
284 max_turns: Some(self.config.max_turns.unwrap_or(100)),
285 system_prompt: self.config.system_prompt.clone(),
286 ..Default::default()
287 };
288
289 let agent = AgentLoop::new(
291 (*self.provider).clone(),
292 (*self.tools).clone(),
293 (*self.hooks).clone(),
294 message_store,
295 state_store,
296 agent_config,
297 );
298
299 let tool_ctx = ToolContext::new(());
301
302 let cancel_token = parent_cancel.child_token();
304 let timeout_cancel = cancel_token.clone();
305 let (mut rx, _final_state) = agent.run(
306 thread_id,
307 AgentInput::Text(task.to_string()),
308 tool_ctx,
309 cancel_token,
310 );
311
312 let mut final_response = String::new();
313 let mut total_turns = 0;
314 let mut tool_count = 0u32;
315 let mut tool_logs: Vec<ToolCallLog> = Vec::new();
316 let mut pending_tools: std::collections::HashMap<String, (String, String)> =
317 std::collections::HashMap::new();
318 let mut total_usage = TokenUsage::default();
319 let mut success = true;
320
321 let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);
322
323 loop {
324 let recv_result = if let Some(timeout) = timeout_duration {
325 let remaining = timeout.saturating_sub(start.elapsed());
326 if remaining.is_zero() {
327 timeout_cancel.cancel(); final_response = "Subagent timed out".to_string();
329 success = false;
330 break;
331 }
332 tokio::time::timeout(remaining, rx.recv()).await
333 } else {
334 Ok(rx.recv().await)
335 };
336
337 match recv_result {
338 Ok(Some(envelope)) => match envelope.event {
339 AgentEvent::Text {
340 message_id: _,
341 text,
342 } => {
343 final_response.push_str(&text);
344 }
345 AgentEvent::ToolCallStart {
346 id, name, input, ..
347 } => {
348 tool_count += 1;
350 let context = extract_tool_context(&name, &input);
351 pending_tools.insert(id, (name.clone(), context.clone()));
352
353 if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
355 let event = AgentEvent::SubagentProgress {
356 subagent_id: subagent_id.clone(),
357 subagent_name: self.config.name.clone(),
358 tool_name: name,
359 tool_context: context,
360 completed: false,
361 success: false,
362 tool_count,
363 total_tokens: u64::from(total_usage.input_tokens)
364 + u64::from(total_usage.output_tokens),
365 };
366 let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
367 }
368 }
369 AgentEvent::ToolCallEnd {
370 id,
371 name,
372 display_name,
373 result,
374 } => {
375 let context = pending_tools
377 .remove(&id)
378 .map(|(_, ctx)| ctx)
379 .unwrap_or_default();
380 let result_summary = summarize_tool_result(&name, &result);
381 let tool_success = result.success;
382 tool_logs.push(ToolCallLog {
383 name: name.clone(),
384 display_name: display_name.clone(),
385 context: context.clone(),
386 result: result_summary,
387 success: tool_success,
388 duration_ms: result.duration_ms,
389 });
390
391 if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
393 let event = AgentEvent::SubagentProgress {
394 subagent_id: subagent_id.clone(),
395 subagent_name: self.config.name.clone(),
396 tool_name: name,
397 tool_context: context,
398 completed: true,
399 success: tool_success,
400 tool_count,
401 total_tokens: u64::from(total_usage.input_tokens)
402 + u64::from(total_usage.output_tokens),
403 };
404 let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
405 }
406 }
407 AgentEvent::TurnComplete { turn, usage, .. } => {
408 total_turns = turn;
409 total_usage.add(&usage);
410 }
411 AgentEvent::Done {
412 total_turns: turns, ..
413 } => {
414 total_turns = turns;
415 break;
416 }
417 AgentEvent::Error { message, .. } => {
418 final_response = message;
419 success = false;
420 break;
421 }
422 _ => {}
423 },
424 Ok(None) => break,
425 Err(_) => {
426 timeout_cancel.cancel(); final_response = "Subagent timed out".to_string();
428 success = false;
429 break;
430 }
431 }
432 }
433
434 Ok(SubagentResult {
435 name: self.config.name.clone(),
436 final_response,
437 total_turns,
438 tool_count,
439 tool_logs,
440 usage: total_usage,
441 success,
442 duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
443 })
444 }
445}
446
447fn extract_tool_context(name: &str, input: &Value) -> String {
449 match name {
450 "read" => input
451 .get("file_path")
452 .or_else(|| input.get("path"))
453 .and_then(Value::as_str)
454 .unwrap_or("")
455 .to_string(),
456 "write" | "edit" => input
457 .get("file_path")
458 .or_else(|| input.get("path"))
459 .and_then(Value::as_str)
460 .unwrap_or("")
461 .to_string(),
462 "bash" => {
463 let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
464 if cmd.len() > 60 {
466 format!("{}...", crate::primitive_tools::truncate_str(cmd, 57))
467 } else {
468 cmd.to_string()
469 }
470 }
471 "glob" | "grep" => input
472 .get("pattern")
473 .and_then(Value::as_str)
474 .unwrap_or("")
475 .to_string(),
476 "web_search" => input
477 .get("query")
478 .and_then(Value::as_str)
479 .unwrap_or("")
480 .to_string(),
481 _ => String::new(),
482 }
483}
484
485fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
487 if !result.success {
488 let first_line = result.output.lines().next().unwrap_or("Error");
489 return if first_line.len() > 50 {
490 format!(
491 "{}...",
492 crate::primitive_tools::truncate_str(first_line, 47)
493 )
494 } else {
495 first_line.to_string()
496 };
497 }
498
499 match name {
500 "read" => {
501 let line_count = result.output.lines().count();
502 format!("{line_count} lines")
503 }
504 "write" => "wrote file".to_string(),
505 "edit" => "edited".to_string(),
506 "bash" => {
507 let lines: Vec<&str> = result.output.lines().collect();
508 if lines.is_empty() {
509 "done".to_string()
510 } else if lines.len() == 1 {
511 let line = lines[0];
512 if line.len() > 50 {
513 format!("{}...", crate::primitive_tools::truncate_str(line, 47))
514 } else {
515 line.to_string()
516 }
517 } else {
518 format!("{} lines", lines.len())
519 }
520 }
521 "glob" => {
522 let count = result.output.lines().count();
523 format!("{count} files")
524 }
525 "grep" => {
526 let count = result.output.lines().count();
527 format!("{count} matches")
528 }
529 _ => {
530 let line_count = result.output.lines().count();
531 if line_count == 0 {
532 "done".to_string()
533 } else {
534 format!("{line_count} lines")
535 }
536 }
537 }
538}
539
540impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
541where
542 P: LlmProvider + Clone + 'static,
543 H: AgentHooks + Clone + 'static,
544 M: MessageStore + 'static,
545 S: StateStore + 'static,
546{
547 type Name = DynamicToolName;
548
549 fn name(&self) -> DynamicToolName {
550 DynamicToolName::new(format!("subagent_{}", self.config.name))
551 }
552
553 fn display_name(&self) -> &'static str {
554 self.cached_display_name
555 }
556
557 fn description(&self) -> &'static str {
558 self.cached_description
559 }
560
561 fn input_schema(&self) -> Value {
562 json!({
563 "type": "object",
564 "properties": {
565 "task": {
566 "type": "string",
567 "description": "The task or question for the subagent to handle"
568 }
569 },
570 "required": ["task"]
571 })
572 }
573
574 fn tier(&self) -> ToolTier {
575 ToolTier::Confirm
577 }
578
579 async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
580 let task = input
581 .get("task")
582 .and_then(Value::as_str)
583 .context("Missing 'task' parameter")?;
584
585 let parent_tx = ctx.event_tx();
587 let parent_seq = ctx.event_seq();
588
589 let subagent_id = format!(
591 "{}_{:x}",
592 self.config.name,
593 std::time::SystemTime::now()
594 .duration_since(std::time::UNIX_EPOCH)
595 .unwrap_or_default()
596 .as_nanos()
597 );
598
599 let cancel_token = ctx.cancel_token().unwrap_or_default();
602
603 let result = self
604 .run_subagent(task, subagent_id, parent_tx, parent_seq, cancel_token)
605 .await?;
606
607 Ok(ToolResult {
608 success: result.success,
609 output: result.final_response.clone(),
610 data: Some(serde_json::to_value(&result).unwrap_or_default()),
611 documents: Vec::new(),
612 duration_ms: Some(result.duration_ms),
613 })
614 }
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful [`ToolCallLog`] fixture.
    fn successful_log(
        name: &str,
        display_name: &str,
        context: &str,
        result: &str,
        duration_ms: u64,
    ) -> ToolCallLog {
        ToolCallLog {
            name: name.to_string(),
            display_name: display_name.to_string(),
            context: context.to_string(),
            result: result.to_string(),
            success: true,
            duration_ms: Some(duration_ms),
        }
    }

    /// Every builder method stores the value it was given.
    #[test]
    fn test_subagent_config_builder() {
        let config = SubagentConfig::new("test")
            .with_system_prompt("Test prompt")
            .with_max_turns(5)
            .with_timeout_ms(30000);

        assert_eq!(config.name, "test");
        assert_eq!(config.system_prompt, "Test prompt");
        assert_eq!(config.max_turns, Some(5));
        assert_eq!(config.timeout_ms, Some(30000));
    }

    /// A freshly created config has an empty prompt and no limits.
    #[test]
    fn test_subagent_config_defaults() {
        let config = SubagentConfig::new("default");

        assert_eq!(config.name, "default");
        assert!(config.system_prompt.is_empty());
        assert_eq!(config.max_turns, None);
        assert_eq!(config.timeout_ms, None);
    }

    /// The serialized JSON string contains the expected keys and values.
    #[test]
    fn test_subagent_result_serialization() {
        let result = SubagentResult {
            name: "test".to_string(),
            final_response: "Done".to_string(),
            total_turns: 3,
            tool_count: 5,
            tool_logs: vec![
                successful_log("read", "Read file", "/tmp/test.rs", "50 lines", 10),
                successful_log("grep", "Grep TODO", "TODO", "3 matches", 5),
            ],
            usage: TokenUsage::default(),
            success: true,
            duration_ms: 1000,
        };

        let json = serde_json::to_string(&result).expect("serialize");
        for needle in ["test", "Done", "tool_count", "tool_logs", "/tmp/test.rs"] {
            assert!(json.contains(needle));
        }
    }

    /// Individual fields can be read back from the serialized JSON value.
    #[test]
    fn test_subagent_result_field_extraction() {
        let result = SubagentResult {
            name: "explore".to_string(),
            final_response: "Found 3 config files".to_string(),
            total_turns: 2,
            tool_count: 5,
            tool_logs: vec![successful_log(
                "glob",
                "Glob config files",
                "**/*.toml",
                "3 files",
                15,
            )],
            usage: TokenUsage {
                input_tokens: 1500,
                output_tokens: 500,
            },
            success: true,
            duration_ms: 2500,
        };

        let value = serde_json::to_value(&result).expect("serialize to value");

        assert_eq!(value.get("tool_count").and_then(Value::as_u64), Some(5));

        let usage = value.get("usage").expect("usage field");
        assert_eq!(usage.get("input_tokens").and_then(Value::as_u64), Some(1500));
        assert_eq!(usage.get("output_tokens").and_then(Value::as_u64), Some(500));

        let logs = value
            .get("tool_logs")
            .and_then(Value::as_array)
            .expect("tool_logs array");
        assert_eq!(logs.len(), 1);

        let first_log = &logs[0];
        assert_eq!(first_log.get("name").and_then(Value::as_str), Some("glob"));
        assert_eq!(
            first_log.get("context").and_then(Value::as_str),
            Some("**/*.toml")
        );
        assert_eq!(
            first_log.get("result").and_then(Value::as_str),
            Some("3 files")
        );
        assert_eq!(
            first_log.get("success").and_then(Value::as_bool),
            Some(true)
        );
    }
}