1use crate::events::AgentEvent;
36use crate::hooks::{AgentHooks, DefaultHooks};
37use crate::llm::LlmProvider;
38use crate::stores::{InMemoryStore, MessageStore, StateStore};
39use crate::tools::{Tool, ToolContext, ToolRegistry};
40use crate::types::{AgentConfig, ThreadId, TokenUsage, ToolResult, ToolTier};
41use anyhow::{Context, Result};
42use async_trait::async_trait;
43use serde::{Deserialize, Serialize};
44use serde_json::{Value, json};
45use std::sync::Arc;
46use std::time::{Duration, Instant};
47use tokio::sync::mpsc;
48
49#[derive(Clone, Debug, Serialize, Deserialize)]
51pub struct SubagentConfig {
52 pub name: String,
54 pub system_prompt: String,
56 pub max_turns: usize,
58 pub timeout_ms: Option<u64>,
60}
61
62impl SubagentConfig {
63 #[must_use]
65 pub fn new(name: impl Into<String>) -> Self {
66 Self {
67 name: name.into(),
68 system_prompt: String::new(),
69 max_turns: 10,
70 timeout_ms: None,
71 }
72 }
73
74 #[must_use]
76 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
77 self.system_prompt = prompt.into();
78 self
79 }
80
81 #[must_use]
83 pub const fn with_max_turns(mut self, max: usize) -> Self {
84 self.max_turns = max;
85 self
86 }
87
88 #[must_use]
90 pub const fn with_timeout_ms(mut self, timeout: u64) -> Self {
91 self.timeout_ms = Some(timeout);
92 self
93 }
94}
95
96#[derive(Clone, Debug, Serialize, Deserialize)]
98pub struct ToolCallLog {
99 pub name: String,
101 pub context: String,
103 pub result: String,
105 pub success: bool,
107 pub duration_ms: Option<u64>,
109}
110
111#[derive(Clone, Debug, Serialize, Deserialize)]
113pub struct SubagentResult {
114 pub name: String,
116 pub final_response: String,
118 pub total_turns: usize,
120 pub tool_count: u32,
122 pub tool_logs: Vec<ToolCallLog>,
124 pub usage: TokenUsage,
126 pub success: bool,
128 pub duration_ms: u64,
130}
131
132pub struct SubagentTool<P, H = DefaultHooks, M = InMemoryStore, S = InMemoryStore>
148where
149 P: LlmProvider,
150 H: AgentHooks,
151 M: MessageStore,
152 S: StateStore,
153{
154 config: SubagentConfig,
155 provider: Arc<P>,
156 tools: Arc<ToolRegistry<()>>,
157 hooks: Arc<H>,
158 message_store_factory: Arc<dyn Fn() -> M + Send + Sync>,
159 state_store_factory: Arc<dyn Fn() -> S + Send + Sync>,
160}
161
162impl<P> SubagentTool<P, DefaultHooks, InMemoryStore, InMemoryStore>
163where
164 P: LlmProvider + 'static,
165{
166 #[must_use]
168 pub fn new(config: SubagentConfig, provider: Arc<P>, tools: Arc<ToolRegistry<()>>) -> Self {
169 Self {
170 config,
171 provider,
172 tools,
173 hooks: Arc::new(DefaultHooks),
174 message_store_factory: Arc::new(InMemoryStore::new),
175 state_store_factory: Arc::new(InMemoryStore::new),
176 }
177 }
178}
179
180impl<P, H, M, S> SubagentTool<P, H, M, S>
181where
182 P: LlmProvider + Clone + 'static,
183 H: AgentHooks + Clone + 'static,
184 M: MessageStore + 'static,
185 S: StateStore + 'static,
186{
187 #[must_use]
189 pub fn with_hooks<H2: AgentHooks + Clone + 'static>(
190 self,
191 hooks: Arc<H2>,
192 ) -> SubagentTool<P, H2, M, S> {
193 SubagentTool {
194 config: self.config,
195 provider: self.provider,
196 tools: self.tools,
197 hooks,
198 message_store_factory: self.message_store_factory,
199 state_store_factory: self.state_store_factory,
200 }
201 }
202
203 #[must_use]
205 pub fn with_stores<M2, S2, MF, SF>(
206 self,
207 message_factory: MF,
208 state_factory: SF,
209 ) -> SubagentTool<P, H, M2, S2>
210 where
211 M2: MessageStore + 'static,
212 S2: StateStore + 'static,
213 MF: Fn() -> M2 + Send + Sync + 'static,
214 SF: Fn() -> S2 + Send + Sync + 'static,
215 {
216 SubagentTool {
217 config: self.config,
218 provider: self.provider,
219 tools: self.tools,
220 hooks: self.hooks,
221 message_store_factory: Arc::new(message_factory),
222 state_store_factory: Arc::new(state_factory),
223 }
224 }
225
226 #[must_use]
228 pub const fn config(&self) -> &SubagentConfig {
229 &self.config
230 }
231
232 #[allow(clippy::too_many_lines)]
237 async fn run_subagent(
238 &self,
239 task: &str,
240 subagent_id: String,
241 parent_tx: Option<mpsc::Sender<AgentEvent>>,
242 ) -> Result<SubagentResult> {
243 use crate::agent_loop::AgentLoop;
244
245 let start = Instant::now();
246 let thread_id = ThreadId::new();
247
248 let message_store = (self.message_store_factory)();
250 let state_store = (self.state_store_factory)();
251
252 let agent_config = AgentConfig {
254 max_turns: self.config.max_turns,
255 system_prompt: self.config.system_prompt.clone(),
256 ..Default::default()
257 };
258
259 let agent = AgentLoop::new(
261 (*self.provider).clone(),
262 (*self.tools).clone(),
263 (*self.hooks).clone(),
264 message_store,
265 state_store,
266 agent_config,
267 );
268
269 let tool_ctx = ToolContext::new(());
271
272 let mut rx = agent.run(thread_id, task.to_string(), tool_ctx);
274
275 let mut final_response = String::new();
276 let mut total_turns = 0;
277 let mut tool_count = 0u32;
278 let mut tool_logs: Vec<ToolCallLog> = Vec::new();
279 let mut pending_tools: std::collections::HashMap<String, (String, String)> =
280 std::collections::HashMap::new();
281 let mut total_usage = TokenUsage::default();
282 let mut success = true;
283
284 let timeout_duration = self.config.timeout_ms.map(Duration::from_millis);
285
286 loop {
287 let recv_result = if let Some(timeout) = timeout_duration {
288 let remaining = timeout.saturating_sub(start.elapsed());
289 if remaining.is_zero() {
290 final_response = "Subagent timed out".to_string();
291 success = false;
292 break;
293 }
294 tokio::time::timeout(remaining, rx.recv()).await
295 } else {
296 Ok(rx.recv().await)
297 };
298
299 match recv_result {
300 Ok(Some(event)) => match event {
301 AgentEvent::Text { text } => {
302 final_response.push_str(&text);
303 }
304 AgentEvent::ToolCallStart {
305 id, name, input, ..
306 } => {
307 tool_count += 1;
309 let context = extract_tool_context(&name, &input);
310 pending_tools.insert(id, (name.clone(), context.clone()));
311
312 if let Some(ref tx) = parent_tx {
314 let _ = tx
315 .send(AgentEvent::SubagentProgress {
316 subagent_id: subagent_id.clone(),
317 subagent_name: self.config.name.clone(),
318 tool_name: name,
319 tool_context: context,
320 completed: false,
321 success: false,
322 tool_count,
323 total_tokens: u64::from(total_usage.input_tokens)
324 + u64::from(total_usage.output_tokens),
325 })
326 .await;
327 }
328 }
329 AgentEvent::ToolCallEnd { id, name, result } => {
330 let context = pending_tools
332 .remove(&id)
333 .map(|(_, ctx)| ctx)
334 .unwrap_or_default();
335 let result_summary = summarize_tool_result(&name, &result);
336 let tool_success = result.success;
337 tool_logs.push(ToolCallLog {
338 name: name.clone(),
339 context: context.clone(),
340 result: result_summary,
341 success: tool_success,
342 duration_ms: result.duration_ms,
343 });
344
345 if let Some(ref tx) = parent_tx {
347 let _ = tx
348 .send(AgentEvent::SubagentProgress {
349 subagent_id: subagent_id.clone(),
350 subagent_name: self.config.name.clone(),
351 tool_name: name,
352 tool_context: context,
353 completed: true,
354 success: tool_success,
355 tool_count,
356 total_tokens: u64::from(total_usage.input_tokens)
357 + u64::from(total_usage.output_tokens),
358 })
359 .await;
360 }
361 }
362 AgentEvent::TurnComplete { turn, usage, .. } => {
363 total_turns = turn;
364 total_usage.add(&usage);
365 }
366 AgentEvent::Done {
367 total_turns: turns, ..
368 } => {
369 total_turns = turns;
370 break;
371 }
372 AgentEvent::Error { message, .. } => {
373 final_response = message;
374 success = false;
375 break;
376 }
377 _ => {}
378 },
379 Ok(None) => break,
380 Err(_) => {
381 final_response = "Subagent timed out".to_string();
382 success = false;
383 break;
384 }
385 }
386 }
387
388 Ok(SubagentResult {
389 name: self.config.name.clone(),
390 final_response,
391 total_turns,
392 tool_count,
393 tool_logs,
394 usage: total_usage,
395 success,
396 duration_ms: u64::try_from(start.elapsed().as_millis()).unwrap_or(u64::MAX),
397 })
398 }
399}
400
401fn extract_tool_context(name: &str, input: &Value) -> String {
403 match name {
404 "read" => input
405 .get("file_path")
406 .or_else(|| input.get("path"))
407 .and_then(Value::as_str)
408 .unwrap_or("")
409 .to_string(),
410 "write" | "edit" => input
411 .get("file_path")
412 .or_else(|| input.get("path"))
413 .and_then(Value::as_str)
414 .unwrap_or("")
415 .to_string(),
416 "bash" => {
417 let cmd = input.get("command").and_then(Value::as_str).unwrap_or("");
418 if cmd.len() > 60 {
420 format!("{}...", &cmd[..57])
421 } else {
422 cmd.to_string()
423 }
424 }
425 "glob" | "grep" => input
426 .get("pattern")
427 .and_then(Value::as_str)
428 .unwrap_or("")
429 .to_string(),
430 "web_search" => input
431 .get("query")
432 .and_then(Value::as_str)
433 .unwrap_or("")
434 .to_string(),
435 _ => String::new(),
436 }
437}
438
439fn summarize_tool_result(name: &str, result: &ToolResult) -> String {
441 if !result.success {
442 let first_line = result.output.lines().next().unwrap_or("Error");
443 return if first_line.len() > 50 {
444 format!("{}...", &first_line[..47])
445 } else {
446 first_line.to_string()
447 };
448 }
449
450 match name {
451 "read" => {
452 let line_count = result.output.lines().count();
453 format!("{line_count} lines")
454 }
455 "write" => "wrote file".to_string(),
456 "edit" => "edited".to_string(),
457 "bash" => {
458 let lines: Vec<&str> = result.output.lines().collect();
459 if lines.is_empty() {
460 "done".to_string()
461 } else if lines.len() == 1 {
462 let line = lines[0];
463 if line.len() > 50 {
464 format!("{}...", &line[..47])
465 } else {
466 line.to_string()
467 }
468 } else {
469 format!("{} lines", lines.len())
470 }
471 }
472 "glob" => {
473 let count = result.output.lines().count();
474 format!("{count} files")
475 }
476 "grep" => {
477 let count = result.output.lines().count();
478 format!("{count} matches")
479 }
480 _ => {
481 let line_count = result.output.lines().count();
482 if line_count == 0 {
483 "done".to_string()
484 } else {
485 format!("{line_count} lines")
486 }
487 }
488 }
489}
490
491#[async_trait]
492impl<P, H, M, S> Tool<()> for SubagentTool<P, H, M, S>
493where
494 P: LlmProvider + Clone + 'static,
495 H: AgentHooks + Clone + 'static,
496 M: MessageStore + 'static,
497 S: StateStore + 'static,
498{
499 fn name(&self) -> &'static str {
500 Box::leak(format!("subagent_{}", self.config.name).into_boxed_str())
502 }
503
504 fn description(&self) -> &'static str {
505 Box::leak(
506 format!(
507 "Spawn a subagent named '{}' to handle a task. The subagent will work independently and return only its final response.",
508 self.config.name
509 )
510 .into_boxed_str(),
511 )
512 }
513
514 fn input_schema(&self) -> Value {
515 json!({
516 "type": "object",
517 "properties": {
518 "task": {
519 "type": "string",
520 "description": "The task or question for the subagent to handle"
521 }
522 },
523 "required": ["task"]
524 })
525 }
526
527 fn tier(&self) -> ToolTier {
528 ToolTier::Confirm
530 }
531
532 async fn execute(&self, ctx: &ToolContext<()>, input: Value) -> Result<ToolResult> {
533 let task = input
534 .get("task")
535 .and_then(Value::as_str)
536 .context("Missing 'task' parameter")?;
537
538 let parent_tx = ctx.event_tx();
540
541 let subagent_id = format!(
543 "{}_{:x}",
544 self.config.name,
545 std::time::SystemTime::now()
546 .duration_since(std::time::UNIX_EPOCH)
547 .unwrap_or_default()
548 .as_nanos()
549 );
550
551 let result = self.run_subagent(task, subagent_id, parent_tx).await?;
552
553 Ok(ToolResult {
554 success: result.success,
555 output: result.final_response.clone(),
556 data: Some(serde_json::to_value(&result).unwrap_or_default()),
557 duration_ms: Some(result.duration_ms),
558 })
559 }
560}
561
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolResult` with no structured data or timing.
    fn tool_result(success: bool, output: &str) -> ToolResult {
        ToolResult {
            success,
            output: output.to_string(),
            data: None,
            duration_ms: None,
        }
    }

    #[test]
    fn test_subagent_config_builder() {
        let config = SubagentConfig::new("test")
            .with_system_prompt("Test prompt")
            .with_max_turns(5)
            .with_timeout_ms(30000);

        assert_eq!(config.name, "test");
        assert_eq!(config.system_prompt, "Test prompt");
        assert_eq!(config.max_turns, 5);
        assert_eq!(config.timeout_ms, Some(30000));
    }

    #[test]
    fn test_subagent_config_defaults() {
        let config = SubagentConfig::new("default");

        assert_eq!(config.name, "default");
        assert!(config.system_prompt.is_empty());
        assert_eq!(config.max_turns, 10);
        assert_eq!(config.timeout_ms, None);
    }

    #[test]
    fn test_subagent_config_serde_round_trip() {
        let config = SubagentConfig::new("rt")
            .with_system_prompt("prompt")
            .with_timeout_ms(5);

        let json = serde_json::to_string(&config).expect("serialize");
        let back: SubagentConfig = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(back.name, "rt");
        assert_eq!(back.system_prompt, "prompt");
        assert_eq!(back.max_turns, 10);
        assert_eq!(back.timeout_ms, Some(5));
    }

    #[test]
    fn test_subagent_result_serialization() {
        let result = SubagentResult {
            name: "test".to_string(),
            final_response: "Done".to_string(),
            total_turns: 3,
            tool_count: 5,
            tool_logs: vec![
                ToolCallLog {
                    name: "read".to_string(),
                    context: "/tmp/test.rs".to_string(),
                    result: "50 lines".to_string(),
                    success: true,
                    duration_ms: Some(10),
                },
                ToolCallLog {
                    name: "grep".to_string(),
                    context: "TODO".to_string(),
                    result: "3 matches".to_string(),
                    success: true,
                    duration_ms: Some(5),
                },
            ],
            usage: TokenUsage::default(),
            success: true,
            duration_ms: 1000,
        };

        let json = serde_json::to_string(&result).expect("serialize");
        let back: SubagentResult = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(back.name, "test");
        assert_eq!(back.final_response, "Done");
        assert_eq!(back.total_turns, 3);
        assert_eq!(back.tool_count, 5);
        assert!(back.success);
        assert_eq!(back.duration_ms, 1000);
        assert_eq!(back.tool_logs.len(), 2);
        assert_eq!(back.tool_logs[0].context, "/tmp/test.rs");
        assert_eq!(back.tool_logs[1].result, "3 matches");
        assert_eq!(back.tool_logs[1].duration_ms, Some(5));
    }

    #[test]
    fn test_subagent_result_field_extraction() {
        let result = SubagentResult {
            name: "explore".to_string(),
            final_response: "Found 3 config files".to_string(),
            total_turns: 2,
            tool_count: 5,
            tool_logs: vec![ToolCallLog {
                name: "glob".to_string(),
                context: "**/*.toml".to_string(),
                result: "3 files".to_string(),
                success: true,
                duration_ms: Some(15),
            }],
            usage: TokenUsage {
                input_tokens: 1500,
                output_tokens: 500,
            },
            success: true,
            duration_ms: 2500,
        };

        let value = serde_json::to_value(&result).expect("serialize to value");

        assert_eq!(value.get("tool_count").and_then(Value::as_u64), Some(5));

        let usage = value.get("usage").expect("usage field");
        assert_eq!(usage.get("input_tokens").and_then(Value::as_u64), Some(1500));
        assert_eq!(usage.get("output_tokens").and_then(Value::as_u64), Some(500));

        let logs = value
            .get("tool_logs")
            .and_then(Value::as_array)
            .expect("tool_logs array");
        assert_eq!(logs.len(), 1);

        let first_log = &logs[0];
        assert_eq!(first_log.get("name").and_then(Value::as_str), Some("glob"));
        assert_eq!(
            first_log.get("context").and_then(Value::as_str),
            Some("**/*.toml")
        );
        assert_eq!(
            first_log.get("result").and_then(Value::as_str),
            Some("3 files")
        );
        assert_eq!(
            first_log.get("success").and_then(Value::as_bool),
            Some(true)
        );
    }

    #[test]
    fn test_extract_tool_context_fields() {
        assert_eq!(
            extract_tool_context("read", &json!({"file_path": "/a.rs"})),
            "/a.rs"
        );
        assert_eq!(extract_tool_context("edit", &json!({"path": "/b.rs"})), "/b.rs");
        assert_eq!(extract_tool_context("glob", &json!({"pattern": "*.rs"})), "*.rs");
        assert_eq!(
            extract_tool_context("web_search", &json!({"query": "rust"})),
            "rust"
        );
        assert_eq!(extract_tool_context("unknown", &json!({"x": 1})), "");
        assert_eq!(extract_tool_context("read", &json!({})), "");
    }

    #[test]
    fn test_extract_tool_context_truncates_long_bash_commands() {
        let exactly_60 = "a".repeat(60);
        assert_eq!(
            extract_tool_context("bash", &json!({ "command": exactly_60 })),
            exactly_60
        );

        let long = "b".repeat(61);
        let context = extract_tool_context("bash", &json!({ "command": long }));
        assert_eq!(context, format!("{}...", "b".repeat(57)));
        assert_eq!(context.len(), 60);
    }

    #[test]
    fn test_summarize_tool_result_success() {
        assert_eq!(summarize_tool_result("read", &tool_result(true, "a\nb\nc")), "3 lines");
        assert_eq!(summarize_tool_result("write", &tool_result(true, "")), "wrote file");
        assert_eq!(summarize_tool_result("edit", &tool_result(true, "")), "edited");
        assert_eq!(summarize_tool_result("bash", &tool_result(true, "")), "done");
        assert_eq!(summarize_tool_result("bash", &tool_result(true, "ok")), "ok");
        assert_eq!(summarize_tool_result("bash", &tool_result(true, "a\nb")), "2 lines");
        assert_eq!(summarize_tool_result("glob", &tool_result(true, "x\ny")), "2 files");
        assert_eq!(summarize_tool_result("grep", &tool_result(true, "m")), "1 matches");
        assert_eq!(summarize_tool_result("other", &tool_result(true, "")), "done");
        assert_eq!(summarize_tool_result("other", &tool_result(true, "1\n2")), "2 lines");
    }

    #[test]
    fn test_summarize_tool_result_failure() {
        assert_eq!(
            summarize_tool_result("bash", &tool_result(false, "boom\nmore")),
            "boom"
        );
        assert_eq!(summarize_tool_result("read", &tool_result(false, "")), "Error");

        let long = "x".repeat(60);
        assert_eq!(
            summarize_tool_result("bash", &tool_result(false, &long)),
            format!("{}...", "x".repeat(47))
        );
    }
}