1use std::collections::HashMap;
8use std::sync::Arc;
9
10use crate::error::Error;
11use crate::llm::LlmProvider;
12use crate::llm::types::TokenUsage;
13use crate::tool::handoff::{
14 HandoffContextMode, HandoffTarget, HandoffTool, parse_handoff_sentinel,
15};
16
17use super::{AgentOutput, AgentRunner};
18
19pub struct HandoffRunner<P: LlmProvider> {
26 agents: HashMap<String, AgentRunner<P>>,
27 initial_agent: String,
28 max_handoffs: usize,
29}
30
31impl<P: LlmProvider> std::fmt::Debug for HandoffRunner<P> {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 f.debug_struct("HandoffRunner")
34 .field("initial_agent", &self.initial_agent)
35 .field("max_handoffs", &self.max_handoffs)
36 .field("agent_count", &self.agents.len())
37 .finish()
38 }
39}
40
41pub struct HandoffRunnerBuilder<P: LlmProvider> {
43 agents: HashMap<String, AgentRunner<P>>,
44 initial_agent: Option<String>,
45 max_handoffs: Option<usize>,
46}
47
48impl<P: LlmProvider> HandoffRunner<P> {
49 pub fn builder() -> HandoffRunnerBuilder<P> {
51 HandoffRunnerBuilder {
52 agents: HashMap::new(),
53 initial_agent: None,
54 max_handoffs: None,
55 }
56 }
57
58 pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
64 let mut current_agent = self.initial_agent.clone();
65 let mut total_usage = TokenUsage::default();
66 let mut total_tool_calls = 0usize;
67 let mut total_cost: Option<f64> = None;
68 let mut effective_task = task.to_string();
69 let mut handoff_count = 0;
70
71 loop {
72 let agent = self.agents.get(¤t_agent).ok_or_else(|| {
73 Error::Agent(format!("handoff target agent '{current_agent}' not found"))
74 })?;
75
76 let output = agent
77 .execute(&effective_task)
78 .await
79 .map_err(|e| e.accumulate_usage(total_usage))?;
80 output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
81
82 if let Some((target, context_mode, reason)) = parse_handoff_sentinel(&output.result) {
84 handoff_count += 1;
85 if handoff_count > self.max_handoffs {
86 let mut final_output = output;
88 final_output.tokens_used = total_usage;
89 final_output.tool_calls_made = total_tool_calls;
90 final_output.estimated_cost_usd = total_cost;
91 final_output.result = format!(
92 "[handoff limit reached after {} handoffs]\n{}",
93 self.max_handoffs, final_output.result
94 );
95 return Ok(final_output);
96 }
97
98 if !self.agents.contains_key(&target) {
100 return Err(Error::Agent(format!(
101 "handoff target '{target}' not found. Available: {}",
102 self.agents.keys().cloned().collect::<Vec<_>>().join(", ")
103 )));
104 }
105
106 effective_task = match context_mode {
108 HandoffContextMode::Full => {
109 format!(
110 "## Handoff from {current_agent}\n\
111 Reason: {reason}\n\n\
112 ## Original task\n{task}\n\n\
113 ## Conversation so far\n{result}",
114 result = output.result,
115 )
116 }
117 HandoffContextMode::Summary => {
118 format!(
119 "## Handoff from {current_agent}\n\
120 Reason: {reason}\n\n\
121 ## Original task\n{task}"
122 )
123 }
124 };
125
126 current_agent = target;
127 } else {
128 let mut final_output = output;
130 final_output.tokens_used = total_usage;
131 final_output.tool_calls_made = total_tool_calls;
132 final_output.estimated_cost_usd = total_cost;
133 return Ok(final_output);
134 }
135 }
136 }
137}
138
139impl<P: LlmProvider> HandoffRunnerBuilder<P> {
140 pub fn agent(mut self, name: impl Into<String>, runner: AgentRunner<P>) -> Self {
145 let name = name.into();
146 self.agents.insert(name, runner);
147 self
148 }
149
150 pub fn initial_agent(mut self, name: impl Into<String>) -> Self {
152 self.initial_agent = Some(name.into());
153 self
154 }
155
156 pub fn max_handoffs(mut self, max: usize) -> Self {
158 self.max_handoffs = Some(max);
159 self
160 }
161
162 pub fn build(self) -> Result<HandoffRunner<P>, Error> {
164 if self.agents.is_empty() {
165 return Err(Error::Config(
166 "HandoffRunner requires at least one agent".into(),
167 ));
168 }
169 let initial_agent = self
170 .initial_agent
171 .ok_or_else(|| Error::Config("HandoffRunner requires initial_agent".into()))?;
172 if !self.agents.contains_key(&initial_agent) {
173 return Err(Error::Config(format!(
174 "initial_agent '{initial_agent}' not found in registered agents"
175 )));
176 }
177 let max_handoffs = self
178 .max_handoffs
179 .ok_or_else(|| Error::Config("HandoffRunner requires max_handoffs".into()))?;
180 if max_handoffs == 0 {
181 return Err(Error::Config(
182 "HandoffRunner max_handoffs must be at least 1".into(),
183 ));
184 }
185
186 Ok(HandoffRunner {
187 agents: self.agents,
188 initial_agent,
189 max_handoffs,
190 })
191 }
192}
193
194pub fn make_handoff_tool(targets: Vec<HandoffTarget>) -> Arc<dyn crate::tool::Tool> {
198 Arc::new(HandoffTool::new(targets))
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};
    use crate::tool::handoff::HANDOFF_SENTINEL;

    // ---- builder validation -------------------------------------------------

    #[test]
    fn builder_rejects_empty_agents() {
        let result = HandoffRunner::<MockProvider>::builder()
            .initial_agent("triage")
            .max_handoffs(3)
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least one"));
    }

    #[test]
    fn builder_rejects_missing_initial_agent() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "x", 1, 1,
        )]));
        let result = HandoffRunner::builder()
            .agent("a", make_agent(provider, "a"))
            .max_handoffs(3)
            .build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("requires initial_agent")
        );
    }

    #[test]
    fn builder_rejects_nonexistent_initial_agent() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "x", 1, 1,
        )]));
        let result = HandoffRunner::builder()
            .agent("a", make_agent(provider, "a"))
            .initial_agent("nonexistent")
            .max_handoffs(3)
            .build();
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("not found"));
        // The message names the offending agent.
        assert!(message.contains("'nonexistent'"));
    }

    #[test]
    fn builder_rejects_zero_max_handoffs() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "x", 1, 1,
        )]));
        let result = HandoffRunner::builder()
            .agent("a", make_agent(provider, "a"))
            .initial_agent("a")
            .max_handoffs(0)
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least 1"));
    }

    #[test]
    fn builder_rejects_missing_max_handoffs() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "x", 1, 1,
        )]));
        let result = HandoffRunner::builder()
            .agent("a", make_agent(provider, "a"))
            .initial_agent("a")
            .build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("requires max_handoffs")
        );
    }

    #[test]
    fn builder_accepts_valid_config() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "done", 10, 5,
        )]));
        let result = HandoffRunner::builder()
            .agent("triage", make_agent(provider, "triage"))
            .initial_agent("triage")
            .max_handoffs(5)
            .build();
        assert!(result.is_ok());
    }

    // ---- execution ------------------------------------------------------------

    #[tokio::test]
    async fn execute_no_handoff() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "Direct answer.",
            100,
            50,
        )]));

        let runner = HandoffRunner::builder()
            .agent("triage", make_agent(provider, "triage"))
            .initial_agent("triage")
            .max_handoffs(5)
            .build()
            .unwrap();

        let output = runner.execute("simple question").await.unwrap();
        assert_eq!(output.result, "Direct answer.");
        assert_eq!(output.tokens_used.input_tokens, 100);
        assert_eq!(output.tokens_used.output_tokens, 50);
    }

    #[tokio::test]
    async fn execute_single_handoff() {
        let triage_provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            &format!("{HANDOFF_SENTINEL}billing:summary:User has billing question"),
            50,
            20,
        )]));
        let billing_provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "Your bill is $42.",
            80,
            30,
        )]));

        let runner = HandoffRunner::builder()
            .agent("triage", make_agent(triage_provider, "triage"))
            .agent("billing", make_agent(billing_provider, "billing"))
            .initial_agent("triage")
            .max_handoffs(5)
            .build()
            .unwrap();

        let output = runner.execute("How much do I owe?").await.unwrap();
        assert_eq!(output.result, "Your bill is $42.");
        // Usage is summed over both agents: 50 + 80 in, 20 + 30 out.
        assert_eq!(output.tokens_used.input_tokens, 130);
        assert_eq!(output.tokens_used.output_tokens, 50);
    }

    #[tokio::test]
    async fn execute_max_handoffs_exceeded() {
        let a_provider = Arc::new(MockProvider::new(vec![
            MockProvider::text_response(&format!("{HANDOFF_SENTINEL}b:summary:need b"), 10, 5),
            MockProvider::text_response(
                &format!("{HANDOFF_SENTINEL}b:summary:need b again"),
                10,
                5,
            ),
        ]));
        let b_provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            &format!("{HANDOFF_SENTINEL}a:summary:need a"),
            10,
            5,
        )]));

        let runner = HandoffRunner::builder()
            .agent("a", make_agent(a_provider, "a"))
            .agent("b", make_agent(b_provider, "b"))
            .initial_agent("a")
            .max_handoffs(2)
            .build()
            .unwrap();

        // a -> b (1), b -> a (2), a -> b (3 > limit of 2): the third request trips the limit.
        let output = runner.execute("ping pong").await.unwrap();
        assert!(output.result.contains("handoff limit reached"));
        assert!(
            output
                .result
                .starts_with("[handoff limit reached after 2 handoffs]\n")
        );
        // The returned output is the last agent's (a's second) response.
        assert!(output.result.ends_with("b:summary:need b again"));
        // All three agent runs are accounted for on the limit path too.
        assert_eq!(output.tokens_used.input_tokens, 30);
        assert_eq!(output.tokens_used.output_tokens, 15);
    }

    #[tokio::test]
    async fn execute_handoff_to_unknown_agent() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            &format!("{HANDOFF_SENTINEL}nonexistent:summary:reason"),
            10,
            5,
        )]));

        let runner = HandoffRunner::builder()
            .agent("a", make_agent(provider, "a"))
            .initial_agent("a")
            .max_handoffs(3)
            .build()
            .unwrap();

        let result = runner.execute("test").await;
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains("not found"));
        // The error names the bad target and lists what is registered.
        assert!(message.contains("handoff target 'nonexistent' not found"));
        assert!(message.contains("Available: a"));
    }

    #[tokio::test]
    async fn execute_full_context_mode() {
        let triage_provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            &format!("{HANDOFF_SENTINEL}support:full:Complex issue needs full context"),
            50,
            20,
        )]));
        let support_provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "I can see the full context. Fixed!",
            80,
            30,
        )]));

        let runner = HandoffRunner::builder()
            .agent("triage", make_agent(triage_provider, "triage"))
            .agent("support", make_agent(support_provider, "support"))
            .initial_agent("triage")
            .max_handoffs(5)
            .build()
            .unwrap();

        let output = runner.execute("Complex problem").await.unwrap();
        assert_eq!(output.result, "I can see the full context. Fixed!");
        // Full-context handoffs aggregate usage the same way summary ones do.
        assert_eq!(output.tokens_used.input_tokens, 130);
        assert_eq!(output.tokens_used.output_tokens, 50);
    }

    // ---- diagnostics ------------------------------------------------------------

    #[test]
    fn debug_impl() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "x", 1, 1,
        )]));
        let runner = HandoffRunner::builder()
            .agent("a", make_agent(provider, "a"))
            .initial_agent("a")
            .max_handoffs(3)
            .build()
            .unwrap();

        let debug = format!("{runner:?}");
        assert!(debug.contains("HandoffRunner"));
        assert!(debug.contains("initial_agent"));
        assert!(debug.contains("max_handoffs: 3"));
        assert!(debug.contains("agent_count: 1"));
    }
}