1use std::sync::Arc;
8
9use tokio::task::JoinSet;
10
11use crate::error::Error;
12use crate::llm::LlmProvider;
13use crate::llm::types::TokenUsage;
14
15use super::{AgentOutput, AgentRunner};
16
/// Boxed predicate that receives a `&str` and returns `true` to request a stop.
///
/// The `Send + Sync` bounds let the owning value be shared across threads.
type StopCondition = Box<dyn Fn(&str) -> bool + Send + Sync>;
20
21pub struct DebateAgent<P: LlmProvider + 'static> {
37 debaters: Vec<Arc<AgentRunner<P>>>,
38 judge: Arc<AgentRunner<P>>,
39 max_rounds: usize,
40 should_stop: Option<StopCondition>,
41}
42
43impl<P: LlmProvider + 'static> std::fmt::Debug for DebateAgent<P> {
44 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45 f.debug_struct("DebateAgent")
46 .field("debater_count", &self.debaters.len())
47 .field("max_rounds", &self.max_rounds)
48 .finish()
49 }
50}
51
52pub struct DebateAgentBuilder<P: LlmProvider + 'static> {
54 debaters: Vec<AgentRunner<P>>,
55 judge: Option<AgentRunner<P>>,
56 max_rounds: Option<usize>,
57 should_stop: Option<StopCondition>,
58}
59
60impl<P: LlmProvider + 'static> DebateAgent<P> {
61 pub fn builder() -> DebateAgentBuilder<P> {
63 DebateAgentBuilder {
64 debaters: Vec::new(),
65 judge: None,
66 max_rounds: None,
67 should_stop: None,
68 }
69 }
70
71 pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
77 let mut total_usage = TokenUsage::default();
78 let mut total_tool_calls = 0usize;
79 let mut total_cost: Option<f64> = None;
80
81 let mut transcript = format!("# Debate Topic\n{task}\n");
82
83 for round in 1..=self.max_rounds {
84 transcript.push_str(&format!("\n### Round {round}\n"));
85
86 let mut set = JoinSet::new();
88 for debater in &self.debaters {
89 let debater = Arc::clone(debater);
90 let input = transcript.clone();
91 set.spawn(async move {
92 let name = debater.name().to_string();
93 let result = debater.execute(&input).await;
94 (name, result)
95 });
96 }
97
98 let mut round_results: Vec<(String, AgentOutput)> =
99 Vec::with_capacity(self.debaters.len());
100
101 while let Some(join_result) = set.join_next().await {
102 let (name, agent_result) = join_result
103 .map_err(|e| Error::Agent(format!("debate agent task panicked: {e}")))?;
104 let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
105 output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
106 round_results.push((name, output));
107 }
108
109 round_results.sort_by(|a, b| a.0.cmp(&b.0));
111
112 for (name, output) in &round_results {
113 transcript.push_str(&format!("\n#### {name}\n{}\n", output.result));
114 }
115
116 if self.should_stop.as_ref().is_some_and(|f| f(&transcript)) {
118 break;
119 }
120 }
121
122 let judge_output = self
124 .judge
125 .execute(&transcript)
126 .await
127 .map_err(|e| e.accumulate_usage(total_usage))?;
128 judge_output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
129
130 Ok(AgentOutput {
131 result: judge_output.result,
132 tool_calls_made: total_tool_calls,
133 tokens_used: total_usage,
134 structured: judge_output.structured,
135 estimated_cost_usd: total_cost,
136 model_name: judge_output.model_name,
137 })
138 }
139}
140
141impl<P: LlmProvider + 'static> DebateAgentBuilder<P> {
142 pub fn debater(mut self, agent: AgentRunner<P>) -> Self {
144 self.debaters.push(agent);
145 self
146 }
147
148 pub fn debaters(mut self, agents: Vec<AgentRunner<P>>) -> Self {
150 self.debaters.extend(agents);
151 self
152 }
153
154 pub fn judge(mut self, agent: AgentRunner<P>) -> Self {
156 self.judge = Some(agent);
157 self
158 }
159
160 pub fn max_rounds(mut self, n: usize) -> Self {
162 self.max_rounds = Some(n);
163 self
164 }
165
166 pub fn should_stop(mut self, f: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
169 self.should_stop = Some(Box::new(f));
170 self
171 }
172
173 pub fn build(self) -> Result<DebateAgent<P>, Error> {
175 if self.debaters.len() < 2 {
176 return Err(Error::Config(
177 "DebateAgent requires at least 2 debaters".into(),
178 ));
179 }
180 let judge = self
181 .judge
182 .ok_or_else(|| Error::Config("DebateAgent requires a judge".into()))?;
183 let max_rounds = self
184 .max_rounds
185 .ok_or_else(|| Error::Config("DebateAgent requires max_rounds".into()))?;
186 if max_rounds == 0 {
187 return Err(Error::Config(
188 "DebateAgent max_rounds must be at least 1".into(),
189 ));
190 }
191 Ok(DebateAgent {
192 debaters: self.debaters.into_iter().map(Arc::new).collect(),
193 judge: Arc::new(judge),
194 max_rounds,
195 should_stop: self.should_stop,
196 })
197 }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    /// Expands to `Arc::new(MockProvider::new(vec![...]))` for the listed responses.
    macro_rules! provider {
        ($($response:expr),* $(,)?) => {
            Arc::new(MockProvider::new(vec![$($response),*]))
        };
    }

    /// Asserts that `outcome` is an error whose message contains `fragment`.
    #[track_caller]
    fn assert_build_error(outcome: Result<DebateAgent<MockProvider>, Error>, fragment: &str) {
        assert!(outcome.is_err());
        let message = outcome.unwrap_err().to_string();
        assert!(message.contains(fragment));
    }

    /// Builder holding debaters `d1`/`d2` and a judge, each backed by a
    /// single-response mock; `max_rounds` is left unset.
    fn minimal_builder() -> DebateAgentBuilder<MockProvider> {
        let first = provider!(MockProvider::text_response("a", 1, 1));
        let second = provider!(MockProvider::text_response("b", 1, 1));
        let judge = provider!(MockProvider::text_response("j", 1, 1));
        DebateAgent::builder()
            .debater(make_agent(first, "d1"))
            .debater(make_agent(second, "d2"))
            .judge(make_agent(judge, "judge"))
    }

    /// A single debater is not enough to build.
    #[test]
    fn builder_rejects_fewer_than_two_debaters() {
        let lone = provider!();
        let judge = provider!();
        let outcome = DebateAgent::builder()
            .debater(make_agent(lone, "only-one"))
            .judge(make_agent(judge, "judge"))
            .max_rounds(3)
            .build();
        assert_build_error(outcome, "at least 2 debaters");
    }

    /// No debaters at all is rejected with the same message.
    #[test]
    fn builder_rejects_zero_debaters() {
        let judge = provider!();
        let outcome = DebateAgent::<MockProvider>::builder()
            .judge(make_agent(judge, "judge"))
            .max_rounds(3)
            .build();
        assert_build_error(outcome, "at least 2 debaters");
    }

    /// Building without a judge fails.
    #[test]
    fn builder_rejects_missing_judge() {
        let first = provider!();
        let second = provider!();
        let outcome = DebateAgent::builder()
            .debater(make_agent(first, "d1"))
            .debater(make_agent(second, "d2"))
            .max_rounds(3)
            .build();
        assert_build_error(outcome, "requires a judge");
    }

    /// Building without `max_rounds` fails.
    #[test]
    fn builder_rejects_missing_max_rounds() {
        let first = provider!();
        let second = provider!();
        let judge = provider!();
        let outcome = DebateAgent::builder()
            .debater(make_agent(first, "d1"))
            .debater(make_agent(second, "d2"))
            .judge(make_agent(judge, "judge"))
            .build();
        assert_build_error(outcome, "requires max_rounds");
    }

    /// A round limit of zero is rejected.
    #[test]
    fn builder_rejects_zero_max_rounds() {
        let first = provider!();
        let second = provider!();
        let judge = provider!();
        let outcome = DebateAgent::builder()
            .debater(make_agent(first, "d1"))
            .debater(make_agent(second, "d2"))
            .judge(make_agent(judge, "judge"))
            .max_rounds(0)
            .build();
        assert_build_error(outcome, "at least 1");
    }

    /// Two debaters, a judge and a round limit are sufficient.
    #[test]
    fn builder_accepts_valid_config_without_should_stop() {
        let outcome = minimal_builder().max_rounds(3).build();
        assert!(outcome.is_ok());
    }

    /// Adding a stop predicate keeps the configuration valid.
    #[test]
    fn builder_accepts_valid_config_with_should_stop() {
        let outcome = minimal_builder()
            .max_rounds(3)
            .should_stop(|t| t.contains("CONSENSUS"))
            .build();
        assert!(outcome.is_ok());
    }

    /// One round: the judge's text is returned and usage is summed.
    #[tokio::test]
    async fn single_round_debate() {
        let side_a = provider!(MockProvider::text_response("I argue for A", 100, 50));
        let side_b = provider!(MockProvider::text_response("I argue for B", 200, 80));
        let judge = provider!(MockProvider::text_response(
            "After deliberation, A wins",
            150,
            70
        ));

        let debate = DebateAgent::builder()
            .debater(make_agent(side_a, "debater-a"))
            .debater(make_agent(side_b, "debater-b"))
            .judge(make_agent(judge, "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let verdict = debate.execute("Which is better?").await.unwrap();
        assert_eq!(verdict.result, "After deliberation, A wins");
        // 100 + 200 + 150 input tokens, 50 + 80 + 70 output tokens.
        assert_eq!(verdict.tokens_used.input_tokens, 450);
        assert_eq!(verdict.tokens_used.output_tokens, 200);
    }

    /// Two rounds: usage of both rounds plus the judge is summed.
    #[tokio::test]
    async fn multi_round_accumulates_usage() {
        let first = provider!(
            MockProvider::text_response("round1-d1", 10, 5),
            MockProvider::text_response("round2-d1", 10, 5),
        );
        let second = provider!(
            MockProvider::text_response("round1-d2", 20, 10),
            MockProvider::text_response("round2-d2", 20, 10),
        );
        let judge = provider!(MockProvider::text_response("final verdict", 30, 15));

        let debate = DebateAgent::builder()
            .debater(make_agent(first, "d1"))
            .debater(make_agent(second, "d2"))
            .judge(make_agent(judge, "judge"))
            .max_rounds(2)
            .build()
            .unwrap();

        let verdict = debate.execute("topic").await.unwrap();
        assert_eq!(verdict.result, "final verdict");
        // 2 * (10 + 20) + 30 input tokens, 2 * (5 + 10) + 15 output tokens.
        assert_eq!(verdict.tokens_used.input_tokens, 90);
        assert_eq!(verdict.tokens_used.output_tokens, 45);
    }

    /// The stop predicate fires after round two even though five are allowed.
    #[tokio::test]
    async fn early_stop_via_should_stop() {
        let side_a = provider!(
            MockProvider::text_response("I disagree", 10, 5),
            MockProvider::text_response("CONSENSUS reached", 10, 5),
        );
        let side_b = provider!(
            MockProvider::text_response("I also disagree", 10, 5),
            MockProvider::text_response("I concur", 10, 5),
        );
        let judge = provider!(MockProvider::text_response("judge says done", 10, 5));

        let debate = DebateAgent::builder()
            .debater(make_agent(side_a, "debater-a"))
            .debater(make_agent(side_b, "debater-b"))
            .judge(make_agent(judge, "judge"))
            .max_rounds(5)
            .should_stop(|transcript| transcript.contains("CONSENSUS"))
            .build()
            .unwrap();

        let verdict = debate.execute("topic").await.unwrap();
        assert_eq!(verdict.result, "judge says done");
        // Four debater calls plus the judge: 5 * 10 input, 5 * 5 output.
        assert_eq!(verdict.tokens_used.input_tokens, 50);
        assert_eq!(verdict.tokens_used.output_tokens, 25);
    }

    /// A debater with no scripted responses makes the debate fail.
    #[tokio::test]
    async fn error_carries_partial_usage() {
        let healthy = provider!(MockProvider::text_response("ok", 100, 50));
        let exhausted = provider!();
        let judge = provider!(MockProvider::text_response("judge", 10, 5));

        let debate = DebateAgent::builder()
            .debater(make_agent(healthy, "good"))
            .debater(make_agent(exhausted, "bad"))
            .judge(make_agent(judge, "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let failure = debate.execute("topic").await.unwrap_err();
        let partial = failure.partial_usage();
        // Which debater finishes first is not fixed, so both outcomes are accepted.
        assert!(
            partial.input_tokens == 0 || partial.input_tokens >= 100,
            "partial usage should be zero or include completed debater"
        );
    }

    /// When the judge fails, the error still reports the debaters' usage.
    #[tokio::test]
    async fn judge_error_carries_debater_usage() {
        let first = provider!(MockProvider::text_response("arg1", 100, 50));
        let second = provider!(MockProvider::text_response("arg2", 200, 80));
        let judge = provider!();

        let debate = DebateAgent::builder()
            .debater(make_agent(first, "d1"))
            .debater(make_agent(second, "d2"))
            .judge(make_agent(judge, "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let failure = debate.execute("topic").await.unwrap_err();
        let partial = failure.partial_usage();
        assert!(partial.input_tokens >= 300);
    }

    /// `Debug` output names the type and shows both summarized fields.
    #[test]
    fn debug_impl() {
        let debate = minimal_builder().max_rounds(3).build().unwrap();

        let rendered = format!("{debate:?}");
        for expected in ["DebateAgent", "debater_count", "2", "max_rounds", "3"] {
            assert!(rendered.contains(expected));
        }
    }

    /// The judge's first message contains the topic, round header, debater
    /// names and both arguments.
    #[tokio::test]
    async fn judge_receives_transcript_with_round_headers_and_names() {
        let alpha = provider!(MockProvider::text_response("position-alpha", 10, 5));
        let beta = provider!(MockProvider::text_response("position-beta", 10, 5));
        let judge = provider!(MockProvider::text_response("verdict", 10, 5));

        let debate = DebateAgent::builder()
            .debater(make_agent(alpha, "alpha"))
            .debater(make_agent(beta, "beta"))
            .judge(make_agent(Arc::clone(&judge), "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let verdict = debate.execute("test topic").await.unwrap();
        assert_eq!(verdict.result, "verdict");

        let captured = judge.captured_requests.lock().unwrap();
        assert_eq!(captured.len(), 1);
        let first_message = &captured[0].messages[0];
        let crate::llm::types::ContentBlock::Text { text: input_text } =
            &first_message.content[0]
        else {
            panic!("expected text content");
        };

        let expectations = [
            ("# Debate Topic", "should have topic header"),
            ("test topic", "should have original topic"),
            ("### Round 1", "should have round header"),
            ("#### alpha", "should have debater name alpha"),
            ("#### beta", "should have debater name beta"),
            ("position-alpha", "should have alpha's argument"),
            ("position-beta", "should have beta's argument"),
        ];
        for (needle, reason) in expectations {
            assert!(input_text.contains(needle), "{reason}");
        }
    }

    /// `debaters` accepts a whole vector of agents.
    #[test]
    fn builder_debaters_bulk_method() {
        let first = provider!(MockProvider::text_response("a", 1, 1));
        let second = provider!(MockProvider::text_response("b", 1, 1));
        let judge = provider!(MockProvider::text_response("j", 1, 1));

        let agents = vec![make_agent(first, "d1"), make_agent(second, "d2")];
        let outcome = DebateAgent::builder()
            .debaters(agents)
            .judge(make_agent(judge, "judge"))
            .max_rounds(1)
            .build();
        assert!(outcome.is_ok());
    }

    /// Three debaters in one round: all usage is summed with the judge's.
    #[tokio::test]
    async fn three_debaters_single_round() {
        let first = provider!(MockProvider::text_response("arg-1", 10, 5));
        let second = provider!(MockProvider::text_response("arg-2", 20, 10));
        let third = provider!(MockProvider::text_response("arg-3", 30, 15));
        let judge = provider!(MockProvider::text_response("three-way verdict", 40, 20));

        let debate = DebateAgent::builder()
            .debater(make_agent(first, "d1"))
            .debater(make_agent(second, "d2"))
            .debater(make_agent(third, "d3"))
            .judge(make_agent(judge, "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let verdict = debate.execute("topic").await.unwrap();
        assert_eq!(verdict.result, "three-way verdict");
        // 10 + 20 + 30 + 40 input tokens, 5 + 10 + 15 + 20 output tokens.
        assert_eq!(verdict.tokens_used.input_tokens, 100);
        assert_eq!(verdict.tokens_used.output_tokens, 50);
    }
}