1mod factory;
36
37pub use factory::SubagentFactory;
38
39use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
40use crate::hooks::{AgentHooks, DefaultHooks};
41use crate::llm::LlmProvider;
42use crate::stores::{InMemoryStore, MessageStore, StateStore};
43use crate::tools::{DynamicToolName, Tool, ToolContext, ToolRegistry};
44use crate::types::{AgentConfig, AgentInput, ThreadId, TokenUsage, ToolResult, ToolTier};
45use anyhow::{Context, Result};
46use serde::{Deserialize, Serialize};
47use serde_json::{Value, json};
48use std::sync::Arc;
49use std::time::{Duration, Instant};
50use tokio::sync::mpsc;
51
/// Settings that describe a single subagent.
///
/// Build one with [`SubagentConfig::new`] and refine it with the `with_*`
/// builder methods.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SubagentConfig {
    /// Name of the subagent; it is embedded in the tool name, display name
    /// and description exposed by `SubagentTool`.
    pub name: String,
    /// System prompt copied into the subagent's `AgentConfig` for each run.
    pub system_prompt: String,
    /// Turn limit copied into the subagent's `AgentConfig` for each run.
    pub max_turns: usize,
    /// Optional wall-clock budget for a whole run, in milliseconds.
    /// `None` means the run is not bounded by a timeout.
    pub timeout_ms: Option<u64>,
}
64
65impl SubagentConfig {
66 #[must_use]
68 pub fn new(name: impl Into<String>) -> Self {
69 Self {
70 name: name.into(),
71 system_prompt: String::new(),
72 max_turns: 10,
73 timeout_ms: None,
74 }
75 }
76
77 #[must_use]
79 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
80 self.system_prompt = prompt.into();
81 self
82 }
83
84 #[must_use]
86 pub const fn with_max_turns(mut self, max: usize) -> Self {
87 self.max_turns = max;
88 self
89 }
90
91 #[must_use]
93 pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
94 self.timeout_ms = Some(timeout);
95 self
96 }
97}
98
/// Record of one finished tool call made by a subagent.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ToolCallLog {
    /// Tool name as reported by the tool-call end event.
    pub name: String,
    /// Display name as reported by the tool-call end event.
    pub display_name: String,
    /// Short description of the call's input (for example a path, command or
    /// pattern) captured when the call started; empty if none was recorded.
    pub context: String,
    /// Short summary of the tool's output, not the full output.
    pub result: String,
    /// Whether the tool reported success.
    pub success: bool,
    /// Duration reported by the tool result, in milliseconds, if any.
    pub duration_ms: Option<u64>,
}
115
/// Outcome of one subagent run.
///
/// The value is also serialized to JSON and attached as the `data` of the
/// `ToolResult` returned by the subagent tool.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SubagentResult {
    /// Name of the subagent that produced this result.
    pub name: String,
    /// Text emitted by the subagent, or an error/timeout message when the run
    /// did not succeed.
    pub final_response: String,
    /// Number of turns reported by the subagent's event stream.
    pub total_turns: usize,
    /// Number of tool calls the subagent started.
    pub tool_count: u32,
    /// One entry per tool call that finished, in completion order.
    pub tool_logs: Vec<ToolCallLog>,
    /// Token usage summed over the completed turns.
    pub usage: TokenUsage,
    /// `false` if the run ended with an error event or timed out.
    pub success: bool,
    /// Wall-clock duration of the run in milliseconds.
    pub duration_ms: u64,
}
136
/// A tool that delegates a task to a nested agent ("subagent").
///
/// Type parameters:
/// - `P`: LLM provider used by the subagent.
/// - `H`: hooks installed on the subagent (defaults to [`DefaultHooks`]).
/// - `M`: message store type created for each run (defaults to [`InMemoryStore`]).
/// - `S`: state store type created for each run (defaults to [`InMemoryStore`]).
pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
where
    P: LlmProvider,
    H: AgentHooks,
    M: MessageStore,
    S: StateStore,
{
    /// Name, prompt, turn limit and timeout for the subagent.
    config: SubagentConfig,
    /// Shared provider; cloned out of the `Arc` for every run.
    provider: Arc<P>,
    /// Tools made available to the subagent; cloned for every run.
    tools: Arc<ToolRegistry<()>>,
    /// Hooks for the subagent; cloned for every run.
    hooks: Arc<H>,
    /// Called once per run to create that run's message store.
    message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
    /// Called once per run to create that run's state store.
    state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
}
166
167impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
168where
169 P: LlmProvider + 'static,
170{
171 #[must_use]
173 pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
174 Self {
175 config,
176 provider,
177 tools,
178 hooks: Arc::new(DefaultHooks),
179 message_store_factory: Arc::new(InMemoryStore::new),
180 state_store_factory: Arc::new(InMemoryStore::new),
181 }
182 }
183}
184
impl<P, H, M, S> SubagentTool<P, H, M, S>
where
    P: LlmProvider + Clone + 'static,
    H: AgentHooks + Clone + 'static,
    M: MessageStore + 'static,
    S: StateStore + 'static,
{
    /// Replaces the hooks used by the subagent, changing the tool's hook type
    /// to `H2`. All other settings are carried over unchanged.
    #[must_use]
    pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
        self,
        hooks: Arc<H2>,
    ) -> SubagentTool<P, H2, M, S> {
        SubagentTool {
            config: self.config,
            provider: self.provider,
            tools: self.tools,
            hooks,
            message_store_factory: self.message_store_factory,
            state_store_factory: self.state_store_factory,
        }
    }

    /// Replaces both store factories, changing the tool's store types to
    /// `M2` and `S2`.
    ///
    /// `message_factory` and `state_factory` are each called once per
    /// subagent run to create that run's stores.
    #[must_use]
    pub fn with_stores<M2, S2, MF, SF>(
        self,
        message_factory: MF,
        state_factory: SF,
    ) -> SubagentTool<P, H, M2, S2>
    where
        M2: MessageStore + 'static,
        S2: StateStore + 'static,
        MF: Fn() -> M2 + Send + Sync + 'static,
        SF: Fn() -> S2 + Send + Sync + 'static,
    {
        SubagentTool {
            config: self.config,
            provider: self.provider,
            tools: self.tools,
            hooks: self.hooks,
            message_store_factory: Arc::new(message_factory),
            state_store_factory: Arc::new(state_factory),
        }
    }

    /// Returns the subagent's configuration.
    #[must_use]
    pub const fn config(&self) -> &SubagentConfig {
        &self.config
    }

    /// Runs the subagent on `task` and summarizes the run as a
    /// [`SubagentResult`].
    ///
    /// A new agent loop is built for every call from clones of the provider,
    /// tools and hooks plus freshly created stores, and its event stream is
    /// consumed until it finishes, fails, closes, or the configured timeout
    /// elapses.
    ///
    /// When both `parent_tx` and `parent_seq` are `Some`, a
    /// `SubagentProgress` event tagged with `subagent_id` is sent to the
    /// parent for every tool-call start and end.
    ///
    /// In the code below, failures and timeouts are reported through
    /// `success: false` in the returned value rather than through `Err`.
    #[allow(clippy::too_many_lines)]
    async fn run_subagent(
        &self,
        task: &str,
        subagent_id: String,
        parent_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
        parent_seq: Option<SequenceCounter>,
    ) -> Result<SubagentResult> {
        use crate::agent_loop::AgentLoop;

        // `start` drives both the timeout budget and the reported duration.
        let start = Instant::now();
        let thread_id = ThreadId::new();

        // Fresh stores for this run only.
        let message_store = (self.message_store_factory)();
        let state_store = (self.state_store_factory)();

        let agent_config = AgentConfig {
            max_turns: self.config.max_turns,
            system_prompt: self.config.system_prompt.clone(),
            ..Default::default()
        };

        let agent = AgentLoop::new(
            (*self.provider).clone(),
            (*self.tools).clone(),
            (*self.hooks).clone(),
            message_store,
            state_store,
            agent_config,
        );

        let tool_ctx = ToolContext::new(());

        // Only the event receiver is used; the second value returned by `run`
        // is ignored here.
        let (mut rx, _final_state) =
            agent.run(thread_id, AgentInput::Text(task.to_string()), tool_ctx);

        let mut final_response = String::new();
        let mut total_turns = 0;
        let mut tool_count = 0u32;
        let mut tool_logs: Vec<ToolCallLog> = Vec::new();
        // Tool-call id -> (tool name, context) for calls that started but have
        // not ended yet.
        let mut pending_tools: std::collections::HashMap<String, (String, String)> =
            std::collections::HashMap::new();
        let mut total_usage = TokenUsage::default();
        let mut success = true;

        let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);

        loop {
            // With a timeout configured, wait for the next event for at most
            // the time left in the overall budget; otherwise wait indefinitely.
            let recv_result = if let Some(timeout) = timeout_duration {
                let remaining = timeout.saturating_sub(start.elapsed());
                if remaining.is_zero() {
                    final_response = "Subagent timed out".to_string();
                    success = false;
                    break;
                }
                tokio::time::timeout(remaining, rx.recv()).await
            } else {
                Ok(rx.recv().await)
            };

            match recv_result {
                Ok(Some(envelope)) => match envelope.event {
                    // Text from every `Text` event is concatenated.
                    AgentEvent::Text {
                        message_id: _,
                        text,
                    } => {
                        final_response.push_str(&text);
                    }
                    AgentEvent::ToolCallStart {
                        id, name, input, ..
                    } => {
                        tool_count += 1;
                        // Remember the context so the matching end event can
                        // reuse it.
                        let context = extract_tool_context(&name, &input);
                        pending_tools.insert(id, (name.clone(), context.clone()));

                        if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
                            // Report "tool started" to the parent.
                            let event = AgentEvent::SubagentProgress {
                                subagent_id: subagent_id.clone(),
                                subagent_name: self.config.name.clone(),
                                tool_name: name,
                                tool_context: context,
                                completed: false,
                                success: false,
                                tool_count,
                                total_tokens: u64::from(total_usage.input_tokens)
                                    + u64::from(total_usage.output_tokens),
                            };
                            // A failed send is deliberately ignored.
                            let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
                        }
                    }
                    AgentEvent::ToolCallEnd {
                        id,
                        name,
                        display_name,
                        result,
                    } => {
                        // Context recorded at start; empty if no start event
                        // was seen for this id.
                        let context = pending_tools
                            .remove(&id)
                            .map(|(_, ctx)| ctx)
                            .unwrap_or_default();
                        let result_summary = summarize_tool_result(&name, &result);
                        let tool_success = result.success;
                        tool_logs.push(ToolCallLog {
                            name: name.clone(),
                            display_name: display_name.clone(),
                            context: context.clone(),
                            result: result_summary,
                            success: tool_success,
                            duration_ms: result.duration_ms,
                        });

                        if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
                            // Report "tool finished" to the parent.
                            let event = AgentEvent::SubagentProgress {
                                subagent_id: subagent_id.clone(),
                                subagent_name: self.config.name.clone(),
                                tool_name: name,
                                tool_context: context,
                                completed: true,
                                success: tool_success,
                                tool_count,
                                total_tokens: u64::from(total_usage.input_tokens)
                                    + u64::from(total_usage.output_tokens),
                            };
                            // A failed send is deliberately ignored.
                            let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
                        }
                    }
                    AgentEvent::TurnComplete { turn, usage, .. } => {
                        total_turns = turn;
                        total_usage.add(&usage);
                    }
                    // Normal completion.
                    AgentEvent::Done {
                        total_turns: turns, ..
                    } => {
                        total_turns = turns;
                        break;
                    }
                    // The error message replaces any text collected so far.
                    AgentEvent::Error { message, .. } => {
                        final_response = message;
                        success = false;
                        break;
                    }
                    // All other events are ignored.
                    _ => {}
                },
                // Channel closed; `success` keeps its current value.
                Ok(None) => break,
                // The timeout elapsed while waiting for the next event.
                Err(_) => {
                    final_response = "Subagent timed out".to_string();
                    success = false;
                    break;
                }
            }
        }

        Ok(SubagentResult {
            name: self.config.name.clone(),
            final_response,
            total_turns,
            tool_count,
            tool_logs,
            usage: total_usage,
            success,
            // Saturate instead of failing if the millisecond count exceeds u64.
            duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
        })
    }
}
414
415fn extract_tool_context(name: &str, input: &Value) -> String {
417 match name {
418 "read" => input
419 .get("file_path")
420 .or_else(|| input.get("path"))
421 .and_then(Value::as_str)
422 .unwrap_or("")
423 .to_string(),
424 "write" | "edit" => input
425 .get("file_path")
426 .or_else(|| input.get("path"))
427 .and_then(Value::as_str)
428 .unwrap_or("")
429 .to_string(),
430 "bash" => {
431 let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
432 if cmd.len() > 60 {
434 format!("{}...", &cmd[..57])
435 } else {
436 cmd.to_string()
437 }
438 }
439 "glob" | "grep" => input
440 .get("pattern")
441 .and_then(Value::as_str)
442 .unwrap_or("")
443 .to_string(),
444 "web_search" => input
445 .get("query")
446 .and_then(Value::as_str)
447 .unwrap_or("")
448 .to_string(),
449 _ => String::new(),
450 }
451}
452
453fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
455 if !result.success {
456 let first_line = result.output.lines().next().unwrap_or("Error");
457 return if first_line.len() > 50 {
458 format!("{}...", &first_line[..47])
459 } else {
460 first_line.to_string()
461 };
462 }
463
464 match name {
465 "read" => {
466 let line_count = result.output.lines().count();
467 format!("{line_count} lines")
468 }
469 "write" => "wrote file".to_string(),
470 "edit" => "edited".to_string(),
471 "bash" => {
472 let lines: Vec<&str> = result.output.lines().collect();
473 if lines.is_empty() {
474 "done".to_string()
475 } else if lines.len() == 1 {
476 let line = lines[0];
477 if line.len() > 50 {
478 format!("{}...", &line[..47])
479 } else {
480 line.to_string()
481 }
482 } else {
483 format!("{} lines", lines.len())
484 }
485 }
486 "glob" => {
487 let count = result.output.lines().count();
488 format!("{count} files")
489 }
490 "grep" => {
491 let count = result.output.lines().count();
492 format!("{count} matches")
493 }
494 _ => {
495 let line_count = result.output.lines().count();
496 if line_count == 0 {
497 "done".to_string()
498 } else {
499 format!("{line_count} lines")
500 }
501 }
502 }
503}
504
505impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
506where
507 P: LlmProvider + Clone + 'static,
508 H: AgentHooks + Clone + 'static,
509 M: MessageStore + 'static,
510 S: StateStore + 'static,
511{
512 type Name = DynamicToolName;
513
514 fn name(&self) -> DynamicToolName {
515 DynamicToolName::new(format!("subagent_{}", self.config.name))
516 }
517
518 fn display_name(&self) -> &'static str {
519 Box::leak(format!("Subagent: {}", self.config.name).into_boxed_str())
521 }
522
523 fn description(&self) -> &'static str {
524 Box::leak(
525 format!(
526 "Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
527 self.config.name
528 )
529 .into_boxed_str(),
530 )
531 }
532
533 fn input_schema(&self) -> Value {
534 json!({
535 "type": "object",
536 "properties": {
537 "task": {
538 "type": "string",
539 "description": "The task or question for the subagent to handle"
540 }
541 },
542 "required": ["task"]
543 })
544 }
545
546 fn tier(&self) -> ToolTier {
547 ToolTier::Confirm
549 }
550
551 async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
552 let task = input
553 .get("task")
554 .and_then(Value::as_str)
555 .context("Missing 'task' parameter")?;
556
557 let parent_tx = ctx.event_tx();
559 let parent_seq = ctx.event_seq();
560
561 let subagent_id = format!(
563 "{}_{:x}",
564 self.config.name,
565 std::time::SystemTime::now()
566 .duration_since(std::time::UNIX_EPOCH)
567 .unwrap_or_default()
568 .as_nanos()
569 );
570
571 let result = self
572 .run_subagent(task, subagent_id, parent_tx, parent_seq)
573 .await?;
574
575 Ok(ToolResult {
576 success: result.success,
577 output: result.final_response.clone(),
578 data: Some(serde_json::to_value(&result).unwrap_or_default()),
579 documents: Vec::new(),
580 duration_ms: Some(result.duration_ms),
581 })
582 }
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a successful [`ToolCallLog`] fixture.
    fn ok_log(
        name: &str,
        display_name: &str,
        context: &str,
        result: &str,
        duration_ms: u64,
    ) -> ToolCallLog {
        ToolCallLog {
            name: name.to_string(),
            display_name: display_name.to_string(),
            context: context.to_string(),
            result: result.to_string(),
            success: true,
            duration_ms: Some(duration_ms),
        }
    }

    /// Every builder method stores the value it was given.
    #[test]
    fn test_subagent_config_builder() {
        let built = SubagentConfig::new("test")
            .with_system_prompt("Test prompt")
            .with_max_turns(5)
            .with_timeout_ms(30000);

        assert_eq!(built.name, "test");
        assert_eq!(built.system_prompt, "Test prompt");
        assert_eq!(built.max_turns, 5);
        assert_eq!(built.timeout_ms, Some(30000));
    }

    /// `new` fills in the documented defaults.
    #[test]
    fn test_subagent_config_defaults() {
        let fresh = SubagentConfig::new("default");

        assert_eq!(fresh.name, "default");
        assert!(fresh.system_prompt.is_empty());
        assert_eq!(fresh.max_turns, 10);
        assert_eq!(fresh.timeout_ms, None);
    }

    /// A result serializes to JSON text containing its key fields and values.
    #[test]
    fn test_subagent_result_serialization() {
        let outcome = SubagentResult {
            name: "test".to_string(),
            final_response: "Done".to_string(),
            total_turns: 3,
            tool_count: 5,
            tool_logs: vec![
                ok_log("read", "Read file", "/tmp/test.rs", "50 lines", 10),
                ok_log("grep", "Grep TODO", "TODO", "3 matches", 5),
            ],
            usage: TokenUsage::default(),
            success: true,
            duration_ms: 1000,
        };

        let serialized = serde_json::to_string(&outcome).expect("serialize");
        for fragment in ["test", "Done", "tool_count", "tool_logs", "/tmp/test.rs"] {
            assert!(
                serialized.contains(fragment),
                "serialized result should contain {fragment:?}"
            );
        }
    }

    /// Individual fields can be read back from the serialized JSON value.
    #[test]
    fn test_subagent_result_field_extraction() {
        let outcome = SubagentResult {
            name: "explore".to_string(),
            final_response: "Found 3 config files".to_string(),
            total_turns: 2,
            tool_count: 5,
            tool_logs: vec![ok_log("glob", "Glob config files", "**/*.toml", "3 files", 15)],
            usage: TokenUsage {
                input_tokens: 1500,
                output_tokens: 500,
            },
            success: true,
            duration_ms: 2500,
        };

        let doc = serde_json::to_value(&outcome).expect("serialize to value");

        // Top-level counter.
        assert_eq!(doc.get("tool_count").and_then(Value::as_u64), Some(5));

        // Nested token usage.
        let usage = doc.get("usage").expect("usage field");
        assert_eq!(usage.get("input_tokens").and_then(Value::as_u64), Some(1500));
        assert_eq!(usage.get("output_tokens").and_then(Value::as_u64), Some(500));

        // Tool log array and its single entry.
        let logs = doc
            .get("tool_logs")
            .and_then(Value::as_array)
            .expect("tool_logs array");
        assert_eq!(logs.len(), 1);

        let entry = &logs[0];
        assert_eq!(entry.get("name").and_then(Value::as_str), Some("glob"));
        assert_eq!(entry.get("context").and_then(Value::as_str), Some("**/*.toml"));
        assert_eq!(entry.get("result").and_then(Value::as_str), Some("3 files"));
        assert_eq!(entry.get("success").and_then(Value::as_bool), Some(true));
    }
}