1mod factory;
36
37pub use factory::SubagentFactory;
38
39use crate::events::{AgentEvent, AgentEventEnvelope, SequenceCounter};
40use crate::hooks::{AgentHooks, DefaultHooks};
41use crate::llm::LlmProvider;
42use crate::stores::{InMemoryStore, MessageStore, StateStore};
43use crate::tools::{DynamicToolName, Tool, ToolContext, ToolRegistry};
44use crate::types::{AgentConfig, AgentInput, ThreadId, TokenUsage, ToolResult, ToolTier};
45use anyhow::{Context, Result};
46use serde::{Deserialize, Serialize};
47use serde_json::{Value, json};
48use std::sync::Arc;
49use std::time::{Duration, Instant};
50use tokio::sync::mpsc;
51use tokio_util::sync::CancellationToken;
52
53#[derive(Clone, Debug, Serialize, Deserialize)]
55pub struct SubagentConfig {
56 pub name: String,
58 pub system_prompt: String,
60 pub max_turns: Option<usize>,
62 pub timeout_ms: Option<u64>,
64}
65
66impl SubagentConfig {
67 #[must_use]
69 pub fn new(name: impl Into<String>) -> Self {
70 Self {
71 name: name.into(),
72 system_prompt: String::new(),
73 max_turns: None,
74 timeout_ms: None,
75 }
76 }
77
78 #[must_use]
80 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
81 self.system_prompt = prompt.into();
82 self
83 }
84
85 #[must_use]
87 pub const fn with_max_turns(mut self, max: usize) -> Self {
88 self.max_turns = Some(max);
89 self
90 }
91
92 #[must_use]
94 pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
95 self.timeout_ms = Some(timeout);
96 self
97 }
98}
99
100#[derive(Clone, Debug, Serialize, Deserialize)]
102pub struct ToolCallLog {
103 pub name: String,
105 pub display_name: String,
107 pub context: String,
109 pub result: String,
111 pub success: bool,
113 pub duration_ms: Option<u64>,
115}
116
117#[derive(Clone, Debug, Serialize, Deserialize)]
119pub struct SubagentResult {
120 pub name: String,
122 pub final_response: String,
124 pub total_turns: usize,
126 pub tool_count: u32,
128 pub tool_logs: Vec<ToolCallLog>,
130 pub usage: TokenUsage,
132 pub success: bool,
134 pub duration_ms: u64,
136}
137
138pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
154where
155 P: LlmProvider,
156 H: AgentHooks,
157 M: MessageStore,
158 S: StateStore,
159{
160 config: SubagentConfig,
161 provider: Arc<P>,
162 tools: Arc<ToolRegistry<()>>,
163 hooks: Arc<H>,
164 message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
165 state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
166}
167
168impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
169where
170 P: LlmProvider + 'static,
171{
172 #[must_use]
174 pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
175 Self {
176 config,
177 provider,
178 tools,
179 hooks: Arc::new(DefaultHooks),
180 message_store_factory: Arc::new(InMemoryStore::new),
181 state_store_factory: Arc::new(InMemoryStore::new),
182 }
183 }
184}
185
186impl<P, H, M, S> SubagentTool<P, H, M, S>
187where
188 P: LlmProvider + Clone + 'static,
189 H: AgentHooks + Clone + 'static,
190 M: MessageStore + 'static,
191 S: StateStore + 'static,
192{
193 #[must_use]
195 pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
196 self,
197 hooks: Arc<H2>,
198 ) -> SubagentTool<P, H2, M, S> {
199 SubagentTool {
200 config: self.config,
201 provider: self.provider,
202 tools: self.tools,
203 hooks,
204 message_store_factory: self.message_store_factory,
205 state_store_factory: self.state_store_factory,
206 }
207 }
208
209 #[must_use]
211 pub fn with_stores<M2, S2, MF, SF>(
212 self,
213 message_factory: MF,
214 state_factory: SF,
215 ) -> SubagentTool<P, H, M2, S2>
216 where
217 M2: MessageStore + 'static,
218 S2: StateStore + 'static,
219 MF: Fn() -> M2 + Send + Sync + 'static,
220 SF: Fn() -> S2 + Send + Sync + 'static,
221 {
222 SubagentTool {
223 config: self.config,
224 provider: self.provider,
225 tools: self.tools,
226 hooks: self.hooks,
227 message_store_factory: Arc::new(message_factory),
228 state_store_factory: Arc::new(state_factory),
229 }
230 }
231
232 #[must_use]
234 pub const fn config(&self) -> &SubagentConfig {
235 &self.config
236 }
237
238 #[allow(clippy::too_many_lines)]
243 async fn run_subagent(
244 &self,
245 task: &str,
246 subagent_id: String,
247 parent_tx: Option<mpsc::Sender<AgentEventEnvelope>>,
248 parent_seq: Option<SequenceCounter>,
249 ) -> Result<SubagentResult> {
250 use crate::agent_loop::AgentLoop;
251
252 let start = Instant::now();
253 let thread_id = ThreadId::new();
254
255 let message_store = (self.message_store_factory)();
257 let state_store = (self.state_store_factory)();
258
259 let agent_config = AgentConfig {
261 max_turns: self.config.max_turns,
262 system_prompt: self.config.system_prompt.clone(),
263 ..Default::default()
264 };
265
266 let agent = AgentLoop::new(
268 (*self.provider).clone(),
269 (*self.tools).clone(),
270 (*self.hooks).clone(),
271 message_store,
272 state_store,
273 agent_config,
274 );
275
276 let tool_ctx = ToolContext::new(());
278
279 let (mut rx, _final_state) = agent.run(
281 thread_id,
282 AgentInput::Text(task.to_string()),
283 tool_ctx,
284 CancellationToken::new(),
285 );
286
287 let mut final_response = String::new();
288 let mut total_turns = 0;
289 let mut tool_count = 0u32;
290 let mut tool_logs: Vec<ToolCallLog> = Vec::new();
291 let mut pending_tools: std::collections::HashMap<String, (String, String)> =
292 std::collections::HashMap::new();
293 let mut total_usage = TokenUsage::default();
294 let mut success = true;
295
296 let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);
297
298 loop {
299 let recv_result = if let Some(timeout) = timeout_duration {
300 let remaining = timeout.saturating_sub(start.elapsed());
301 if remaining.is_zero() {
302 final_response = "Subagent timed out".to_string();
303 success = false;
304 break;
305 }
306 tokio::time::timeout(remaining, rx.recv()).await
307 } else {
308 Ok(rx.recv().await)
309 };
310
311 match recv_result {
312 Ok(Some(envelope)) => match envelope.event {
313 AgentEvent::Text {
314 message_id: _,
315 text,
316 } => {
317 final_response.push_str(&text);
318 }
319 AgentEvent::ToolCallStart {
320 id, name, input, ..
321 } => {
322 tool_count += 1;
324 let context = extract_tool_context(&name, &input);
325 pending_tools.insert(id, (name.clone(), context.clone()));
326
327 if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
329 let event = AgentEvent::SubagentProgress {
330 subagent_id: subagent_id.clone(),
331 subagent_name: self.config.name.clone(),
332 tool_name: name,
333 tool_context: context,
334 completed: false,
335 success: false,
336 tool_count,
337 total_tokens: u64::from(total_usage.input_tokens)
338 + u64::from(total_usage.output_tokens),
339 };
340 let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
341 }
342 }
343 AgentEvent::ToolCallEnd {
344 id,
345 name,
346 display_name,
347 result,
348 } => {
349 let context = pending_tools
351 .remove(&id)
352 .map(|(_, ctx)| ctx)
353 .unwrap_or_default();
354 let result_summary = summarize_tool_result(&name, &result);
355 let tool_success = result.success;
356 tool_logs.push(ToolCallLog {
357 name: name.clone(),
358 display_name: display_name.clone(),
359 context: context.clone(),
360 result: result_summary,
361 success: tool_success,
362 duration_ms: result.duration_ms,
363 });
364
365 if let (Some(tx), Some(seq)) = (&parent_tx, &parent_seq) {
367 let event = AgentEvent::SubagentProgress {
368 subagent_id: subagent_id.clone(),
369 subagent_name: self.config.name.clone(),
370 tool_name: name,
371 tool_context: context,
372 completed: true,
373 success: tool_success,
374 tool_count,
375 total_tokens: u64::from(total_usage.input_tokens)
376 + u64::from(total_usage.output_tokens),
377 };
378 let _ = tx.send(AgentEventEnvelope::wrap(event, seq)).await;
379 }
380 }
381 AgentEvent::TurnComplete { turn, usage, .. } => {
382 total_turns = turn;
383 total_usage.add(&usage);
384 }
385 AgentEvent::Done {
386 total_turns: turns, ..
387 } => {
388 total_turns = turns;
389 break;
390 }
391 AgentEvent::Error { message, .. } => {
392 final_response = message;
393 success = false;
394 break;
395 }
396 _ => {}
397 },
398 Ok(None) => break,
399 Err(_) => {
400 final_response = "Subagent timed out".to_string();
401 success = false;
402 break;
403 }
404 }
405 }
406
407 Ok(SubagentResult {
408 name: self.config.name.clone(),
409 final_response,
410 total_turns,
411 tool_count,
412 tool_logs,
413 usage: total_usage,
414 success,
415 duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
416 })
417 }
418}
419
420fn extract_tool_context(name: &str, input: &Value) -> String {
422 match name {
423 "read" => input
424 .get("file_path")
425 .or_else(|| input.get("path"))
426 .and_then(Value::as_str)
427 .unwrap_or("")
428 .to_string(),
429 "write" | "edit" => input
430 .get("file_path")
431 .or_else(|| input.get("path"))
432 .and_then(Value::as_str)
433 .unwrap_or("")
434 .to_string(),
435 "bash" => {
436 let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
437 if cmd.len() > 60 {
439 format!("{}...", &cmd[..57])
440 } else {
441 cmd.to_string()
442 }
443 }
444 "glob" | "grep" => input
445 .get("pattern")
446 .and_then(Value::as_str)
447 .unwrap_or("")
448 .to_string(),
449 "web_search" => input
450 .get("query")
451 .and_then(Value::as_str)
452 .unwrap_or("")
453 .to_string(),
454 _ => String::new(),
455 }
456}
457
458fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
460 if !result.success {
461 let first_line = result.output.lines().next().unwrap_or("Error");
462 return if first_line.len() > 50 {
463 format!("{}...", &first_line[..47])
464 } else {
465 first_line.to_string()
466 };
467 }
468
469 match name {
470 "read" => {
471 let line_count = result.output.lines().count();
472 format!("{line_count} lines")
473 }
474 "write" => "wrote file".to_string(),
475 "edit" => "edited".to_string(),
476 "bash" => {
477 let lines: Vec<&str> = result.output.lines().collect();
478 if lines.is_empty() {
479 "done".to_string()
480 } else if lines.len() == 1 {
481 let line = lines[0];
482 if line.len() > 50 {
483 format!("{}...", &line[..47])
484 } else {
485 line.to_string()
486 }
487 } else {
488 format!("{} lines", lines.len())
489 }
490 }
491 "glob" => {
492 let count = result.output.lines().count();
493 format!("{count} files")
494 }
495 "grep" => {
496 let count = result.output.lines().count();
497 format!("{count} matches")
498 }
499 _ => {
500 let line_count = result.output.lines().count();
501 if line_count == 0 {
502 "done".to_string()
503 } else {
504 format!("{line_count} lines")
505 }
506 }
507 }
508}
509
510impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
511where
512 P: LlmProvider + Clone + 'static,
513 H: AgentHooks + Clone + 'static,
514 M: MessageStore + 'static,
515 S: StateStore + 'static,
516{
517 type Name = DynamicToolName;
518
519 fn name(&self) -> DynamicToolName {
520 DynamicToolName::new(format!("subagent_{}", self.config.name))
521 }
522
523 fn display_name(&self) -> &'static str {
524 Box::leak(format!("Subagent: {}", self.config.name).into_boxed_str())
526 }
527
528 fn description(&self) -> &'static str {
529 Box::leak(
530 format!(
531 "Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
532 self.config.name
533 )
534 .into_boxed_str(),
535 )
536 }
537
538 fn input_schema(&self) -> Value {
539 json!({
540 "type": "object",
541 "properties": {
542 "task": {
543 "type": "string",
544 "description": "The task or question for the subagent to handle"
545 }
546 },
547 "required": ["task"]
548 })
549 }
550
551 fn tier(&self) -> ToolTier {
552 ToolTier::Confirm
554 }
555
556 async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
557 let task = input
558 .get("task")
559 .and_then(Value::as_str)
560 .context("Missing 'task' parameter")?;
561
562 let parent_tx = ctx.event_tx();
564 let parent_seq = ctx.event_seq();
565
566 let subagent_id = format!(
568 "{}_{:x}",
569 self.config.name,
570 std::time::SystemTime::now()
571 .duration_since(std::time::UNIX_EPOCH)
572 .unwrap_or_default()
573 .as_nanos()
574 );
575
576 let result = self
577 .run_subagent(task, subagent_id, parent_tx, parent_seq)
578 .await?;
579
580 Ok(ToolResult {
581 success: result.success,
582 output: result.final_response.clone(),
583 data: Some(serde_json::to_value(&result).unwrap_or_default()),
584 documents: Vec::new(),
585 duration_ms: Some(result.duration_ms),
586 })
587 }
588}
589
#[cfg(test)]
mod tests {
    use super::*;

    /// Every builder method stores the value it was given.
    #[test]
    fn test_subagent_config_builder() {
        let config = SubagentConfig::new("test")
            .with_system_prompt("Test prompt")
            .with_max_turns(5)
            .with_timeout_ms(30000);

        assert_eq!(config.name, "test");
        assert_eq!(config.system_prompt, "Test prompt");
        assert_eq!(config.max_turns, Some(5));
        assert_eq!(config.timeout_ms, Some(30000));
    }

    /// `new` sets only the name and leaves everything else empty or unset.
    #[test]
    fn test_subagent_config_defaults() {
        let config = SubagentConfig::new("default");

        assert_eq!(config.name, "default");
        assert!(config.system_prompt.is_empty());
        assert_eq!(config.max_turns, None);
        assert_eq!(config.timeout_ms, None);
    }

    /// The serialized JSON text contains the expected field names and values.
    #[test]
    fn test_subagent_result_serialization() {
        let read_log = ToolCallLog {
            name: "read".to_string(),
            display_name: "Read file".to_string(),
            context: "/tmp/test.rs".to_string(),
            result: "50 lines".to_string(),
            success: true,
            duration_ms: Some(10),
        };
        let grep_log = ToolCallLog {
            name: "grep".to_string(),
            display_name: "Grep TODO".to_string(),
            context: "TODO".to_string(),
            result: "3 matches".to_string(),
            success: true,
            duration_ms: Some(5),
        };
        let result = SubagentResult {
            name: "test".to_string(),
            final_response: "Done".to_string(),
            total_turns: 3,
            tool_count: 5,
            tool_logs: vec![read_log, grep_log],
            usage: TokenUsage::default(),
            success: true,
            duration_ms: 1000,
        };

        let json = serde_json::to_string(&result).expect("serialize");
        for needle in ["test", "Done", "tool_count", "tool_logs", "/tmp/test.rs"] {
            assert!(json.contains(needle));
        }
    }

    /// Individual fields can be read back from the serialized JSON value.
    #[test]
    fn test_subagent_result_field_extraction() {
        let result = SubagentResult {
            name: "explore".to_string(),
            final_response: "Found 3 config files".to_string(),
            total_turns: 2,
            tool_count: 5,
            tool_logs: vec![ToolCallLog {
                name: "glob".to_string(),
                display_name: "Glob config files".to_string(),
                context: "**/*.toml".to_string(),
                result: "3 files".to_string(),
                success: true,
                duration_ms: Some(15),
            }],
            usage: TokenUsage {
                input_tokens: 1500,
                output_tokens: 500,
            },
            success: true,
            duration_ms: 2500,
        };

        let value = serde_json::to_value(&result).expect("serialize to value");

        // Top-level counter.
        assert_eq!(value.get("tool_count").and_then(Value::as_u64), Some(5));

        // Nested token usage.
        let usage = value.get("usage").expect("usage field");
        assert_eq!(
            usage.get("input_tokens").and_then(Value::as_u64),
            Some(1500)
        );
        assert_eq!(
            usage.get("output_tokens").and_then(Value::as_u64),
            Some(500)
        );

        // Tool log array and its single entry.
        let logs = value
            .get("tool_logs")
            .and_then(Value::as_array)
            .expect("tool_logs array");
        assert_eq!(logs.len(), 1);

        let first_log = &logs[0];
        assert_eq!(first_log.get("name").and_then(Value::as_str), Some("glob"));
        assert_eq!(
            first_log.get("context").and_then(Value::as_str),
            Some("**/*.toml")
        );
        assert_eq!(
            first_log.get("result").and_then(Value::as_str),
            Some("3 files")
        );
        assert_eq!(
            first_log.get("success").and_then(Value::as_bool),
            Some(true)
        );
    }
}