Skip to main content

heartbit_core/agent/
debate.rs

1//! Debate workflow agent.
2//!
3//! Orchestrates multi-round debates between N debater agents with a judge
4//! agent that synthesizes the final answer from the complete transcript.
5//! Debaters run in parallel each round via `tokio::JoinSet`.
6
7use std::sync::Arc;
8
9use tokio::task::JoinSet;
10
11use crate::error::Error;
12use crate::llm::LlmProvider;
13use crate::llm::types::TokenUsage;
14
15use super::{AgentOutput, AgentRunner};
16
/// Predicate that may end a debate before `max_rounds` is reached.
///
/// It is handed the entire transcript accumulated so far; returning `true`
/// stops the debate and hands the transcript to the judge.
type StopCondition = Box<dyn Fn(&str) -> bool + Send + Sync + 'static>;
20
21// ---------------------------------------------------------------------------
22// DebateAgent
23// ---------------------------------------------------------------------------
24
25/// Orchestrates a multi-round debate between N debater agents, then asks a
26/// judge agent to synthesize the final answer from the complete transcript.
27///
28/// Each round, all debaters run in parallel and receive the full debate
29/// history as input. After all rounds (or early stop), the judge produces
30/// the final output.
31///
32/// **Note on transcript growth**: The full transcript is passed to every
33/// debater each round, so input tokens grow quadratically with round count
34/// and debater count. Keep `max_rounds` small (2–4) or use `should_stop`
35/// for early termination.
36pub struct DebateAgent<P: LlmProvider + 'static> {
37    debaters: Vec<Arc<AgentRunner<P>>>,
38    judge: Arc<AgentRunner<P>>,
39    max_rounds: usize,
40    should_stop: Option<StopCondition>,
41}
42
43impl<P: LlmProvider + 'static> std::fmt::Debug for DebateAgent<P> {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("DebateAgent")
46            .field("debater_count", &self.debaters.len())
47            .field("max_rounds", &self.max_rounds)
48            .finish()
49    }
50}
51
52/// Builder for [`DebateAgent`].
53pub struct DebateAgentBuilder<P: LlmProvider + 'static> {
54    debaters: Vec<AgentRunner<P>>,
55    judge: Option<AgentRunner<P>>,
56    max_rounds: Option<usize>,
57    should_stop: Option<StopCondition>,
58}
59
60impl<P: LlmProvider + 'static> DebateAgent<P> {
61    /// Create a new [`DebateAgentBuilder`] for constructing a debate agent.
62    pub fn builder() -> DebateAgentBuilder<P> {
63        DebateAgentBuilder {
64            debaters: Vec::new(),
65            judge: None,
66            max_rounds: None,
67            should_stop: None,
68        }
69    }
70
71    /// Execute the debate.
72    ///
73    /// Each round, all debaters receive the full transcript and run in
74    /// parallel. After `max_rounds` (or early stop), the judge synthesizes
75    /// the final answer.
76    pub async fn execute(&self, task: &str) -> Result<AgentOutput, Error> {
77        let mut total_usage = TokenUsage::default();
78        let mut total_tool_calls = 0usize;
79        let mut total_cost: Option<f64> = None;
80
81        let mut transcript = format!("# Debate Topic\n{task}\n");
82
83        for round in 1..=self.max_rounds {
84            transcript.push_str(&format!("\n### Round {round}\n"));
85
86            // Run all debaters in parallel
87            let mut set = JoinSet::new();
88            for debater in &self.debaters {
89                let debater = Arc::clone(debater);
90                let input = transcript.clone();
91                set.spawn(async move {
92                    let name = debater.name().to_string();
93                    let result = debater.execute(&input).await;
94                    (name, result)
95                });
96            }
97
98            let mut round_results: Vec<(String, AgentOutput)> =
99                Vec::with_capacity(self.debaters.len());
100
101            while let Some(join_result) = set.join_next().await {
102                let (name, agent_result) = join_result
103                    .map_err(|e| Error::Agent(format!("debate agent task panicked: {e}")))?;
104                let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
105                output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
106                round_results.push((name, output));
107            }
108
109            // Sort by name for deterministic transcript ordering
110            round_results.sort_by(|a, b| a.0.cmp(&b.0));
111
112            for (name, output) in &round_results {
113                transcript.push_str(&format!("\n#### {name}\n{}\n", output.result));
114            }
115
116            // Check early stop
117            if self.should_stop.as_ref().is_some_and(|f| f(&transcript)) {
118                break;
119            }
120        }
121
122        // Judge synthesizes the final answer
123        let judge_output = self
124            .judge
125            .execute(&transcript)
126            .await
127            .map_err(|e| e.accumulate_usage(total_usage))?;
128        judge_output.accumulate_into(&mut total_usage, &mut total_tool_calls, &mut total_cost);
129
130        Ok(AgentOutput {
131            result: judge_output.result,
132            tool_calls_made: total_tool_calls,
133            tokens_used: total_usage,
134            structured: judge_output.structured,
135            estimated_cost_usd: total_cost,
136            model_name: judge_output.model_name,
137        })
138    }
139}
140
141impl<P: LlmProvider + 'static> DebateAgentBuilder<P> {
142    /// Add a debater agent.
143    pub fn debater(mut self, agent: AgentRunner<P>) -> Self {
144        self.debaters.push(agent);
145        self
146    }
147
148    /// Add multiple debater agents.
149    pub fn debaters(mut self, agents: Vec<AgentRunner<P>>) -> Self {
150        self.debaters.extend(agents);
151        self
152    }
153
154    /// Set the judge agent that synthesizes the final answer.
155    pub fn judge(mut self, agent: AgentRunner<P>) -> Self {
156        self.judge = Some(agent);
157        self
158    }
159
160    /// Set the maximum number of debate rounds (must be >= 1).
161    pub fn max_rounds(mut self, n: usize) -> Self {
162        self.max_rounds = Some(n);
163        self
164    }
165
166    /// Set an optional early-stop predicate. The closure receives the full
167    /// transcript so far and returns `true` to end the debate early.
168    pub fn should_stop(mut self, f: impl Fn(&str) -> bool + Send + Sync + 'static) -> Self {
169        self.should_stop = Some(Box::new(f));
170        self
171    }
172
173    /// Build the [`DebateAgent`].
174    pub fn build(self) -> Result<DebateAgent<P>, Error> {
175        if self.debaters.len() < 2 {
176            return Err(Error::Config(
177                "DebateAgent requires at least 2 debaters".into(),
178            ));
179        }
180        let judge = self
181            .judge
182            .ok_or_else(|| Error::Config("DebateAgent requires a judge".into()))?;
183        let max_rounds = self
184            .max_rounds
185            .ok_or_else(|| Error::Config("DebateAgent requires max_rounds".into()))?;
186        if max_rounds == 0 {
187            return Err(Error::Config(
188                "DebateAgent max_rounds must be at least 1".into(),
189            ));
190        }
191        Ok(DebateAgent {
192            debaters: self.debaters.into_iter().map(Arc::new).collect(),
193            judge: Arc::new(judge),
194            max_rounds,
195            should_stop: self.should_stop,
196        })
197    }
198}
199
200// ===========================================================================
201// Tests
202// ===========================================================================
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    /// Text of the first content block of the first message in the first
    /// request captured by `provider`. The debate passes the transcript to the
    /// judge as that message.
    fn first_request_text(provider: &MockProvider) -> String {
        let requests = provider.captured_requests.lock().unwrap();
        match &requests[0].messages[0].content[0] {
            crate::llm::types::ContentBlock::Text { text } => text.to_string(),
            _ => panic!("expected text content"),
        }
    }

    // -----------------------------------------------------------------------
    // Builder validation tests
    // -----------------------------------------------------------------------

    #[test]
    fn builder_rejects_fewer_than_two_debaters() {
        let p = Arc::new(MockProvider::new(vec![]));
        let judge_p = Arc::new(MockProvider::new(vec![]));
        let result = DebateAgent::builder()
            .debater(make_agent(p, "only-one"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(3)
            .build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("at least 2 debaters")
        );
    }

    #[test]
    fn builder_rejects_zero_debaters() {
        let judge_p = Arc::new(MockProvider::new(vec![]));
        let result = DebateAgent::<MockProvider>::builder()
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(3)
            .build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("at least 2 debaters")
        );
    }

    #[test]
    fn builder_rejects_missing_judge() {
        let p1 = Arc::new(MockProvider::new(vec![]));
        let p2 = Arc::new(MockProvider::new(vec![]));
        let result = DebateAgent::builder()
            .debater(make_agent(p1, "d1"))
            .debater(make_agent(p2, "d2"))
            .max_rounds(3)
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("requires a judge"));
    }

    #[test]
    fn builder_rejects_missing_max_rounds() {
        let p1 = Arc::new(MockProvider::new(vec![]));
        let p2 = Arc::new(MockProvider::new(vec![]));
        let judge_p = Arc::new(MockProvider::new(vec![]));
        let result = DebateAgent::builder()
            .debater(make_agent(p1, "d1"))
            .debater(make_agent(p2, "d2"))
            .judge(make_agent(judge_p, "judge"))
            .build();
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("requires max_rounds")
        );
    }

    #[test]
    fn builder_rejects_zero_max_rounds() {
        let p1 = Arc::new(MockProvider::new(vec![]));
        let p2 = Arc::new(MockProvider::new(vec![]));
        let judge_p = Arc::new(MockProvider::new(vec![]));
        let result = DebateAgent::builder()
            .debater(make_agent(p1, "d1"))
            .debater(make_agent(p2, "d2"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(0)
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least 1"));
    }

    #[test]
    fn builder_accepts_valid_config_without_should_stop() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "b", 1, 1,
        )]));
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "j", 1, 1,
        )]));
        let result = DebateAgent::builder()
            .debater(make_agent(p1, "d1"))
            .debater(make_agent(p2, "d2"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(3)
            .build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_accepts_valid_config_with_should_stop() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "b", 1, 1,
        )]));
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "j", 1, 1,
        )]));
        let result = DebateAgent::builder()
            .debater(make_agent(p1, "d1"))
            .debater(make_agent(p2, "d2"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(3)
            .should_stop(|t| t.contains("CONSENSUS"))
            .build();
        assert!(result.is_ok());
    }

    // -----------------------------------------------------------------------
    // Execution tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn single_round_debate() {
        // 2 debaters + 1 judge, 1 round
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "I argue for A",
            100,
            50,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "I argue for B",
            200,
            80,
        )]));
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "After deliberation, A wins",
            150,
            70,
        )]));

        let debate = DebateAgent::builder()
            .debater(make_agent(p1, "debater-a"))
            .debater(make_agent(p2, "debater-b"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let output = debate.execute("Which is better?").await.unwrap();
        assert_eq!(output.result, "After deliberation, A wins");
        // 100+200+150 = 450 input, 50+80+70 = 200 output
        assert_eq!(output.tokens_used.input_tokens, 450);
        assert_eq!(output.tokens_used.output_tokens, 200);
    }

    #[tokio::test]
    async fn multi_round_accumulates_usage() {
        // 2 debaters, 2 rounds, then judge
        // Each debater needs 2 responses (one per round)
        let p1 = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("round1-d1", 10, 5),
            MockProvider::text_response("round2-d1", 10, 5),
        ]));
        let p2 = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("round1-d2", 20, 10),
            MockProvider::text_response("round2-d2", 20, 10),
        ]));
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "final verdict",
            30,
            15,
        )]));

        let debate = DebateAgent::builder()
            .debater(make_agent(p1, "d1"))
            .debater(make_agent(p2, "d2"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(2)
            .build()
            .unwrap();

        let output = debate.execute("topic").await.unwrap();
        assert_eq!(output.result, "final verdict");
        // (10+20)*2 + 30 = 90 input, (5+10)*2 + 15 = 45 output
        assert_eq!(output.tokens_used.input_tokens, 90);
        assert_eq!(output.tokens_used.output_tokens, 45);
    }

    #[tokio::test]
    async fn early_stop_via_should_stop() {
        // 2 debaters, max 5 rounds, but stop early when transcript contains "CONSENSUS"
        // Round 1: normal responses
        // Round 2: debater-a says "CONSENSUS reached" -> should stop
        let p1 = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("I disagree", 10, 5),
            MockProvider::text_response("CONSENSUS reached", 10, 5),
        ]));
        let p2 = Arc::new(MockProvider::new(vec![
            MockProvider::text_response("I also disagree", 10, 5),
            MockProvider::text_response("I concur", 10, 5),
        ]));
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "judge says done",
            10,
            5,
        )]));

        let debate = DebateAgent::builder()
            .debater(make_agent(p1, "debater-a"))
            .debater(make_agent(p2, "debater-b"))
            .judge(make_agent(Arc::clone(&judge_p), "judge"))
            .max_rounds(5)
            .should_stop(|transcript| transcript.contains("CONSENSUS"))
            .build()
            .unwrap();

        let output = debate.execute("topic").await.unwrap();
        assert_eq!(output.result, "judge says done");
        // 2 rounds * 2 debaters * 10 + judge 10 = 50 input
        assert_eq!(output.tokens_used.input_tokens, 50);
        // 2 rounds * 2 debaters * 5 + judge 5 = 25 output
        assert_eq!(output.tokens_used.output_tokens, 25);

        // The judge must see exactly the rounds that actually ran.
        let input_text = first_request_text(&judge_p);
        assert!(input_text.contains("### Round 2"), "round 2 should have run");
        assert!(
            !input_text.contains("### Round 3"),
            "debate should have stopped after round 2"
        );
    }

    #[tokio::test]
    async fn error_carries_partial_usage() {
        // Debater 1 succeeds, debater 2 errors in round 1
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ok", 100, 50,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![])); // will error
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "judge", 10, 5,
        )]));

        let debate = DebateAgent::builder()
            .debater(make_agent(p1, "good"))
            .debater(make_agent(p2, "bad"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let err = debate.execute("topic").await.unwrap_err();
        let partial = err.partial_usage();
        // The exact value depends on JoinSet completion order. If the failing
        // debater is joined first, the successful debater's usage has not been
        // accumulated yet (0). Otherwise that usage is included (>= 100).
        assert!(
            partial.input_tokens == 0 || partial.input_tokens >= 100,
            "partial usage should be zero or include completed debater"
        );
    }

    #[tokio::test]
    async fn judge_error_carries_debater_usage() {
        // Both debaters succeed, judge errors
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "arg1", 100, 50,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "arg2", 200, 80,
        )]));
        let judge_p = Arc::new(MockProvider::new(vec![])); // will error

        let debate = DebateAgent::builder()
            .debater(make_agent(p1, "d1"))
            .debater(make_agent(p2, "d2"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let err = debate.execute("topic").await.unwrap_err();
        let partial = err.partial_usage();
        // Both debaters' usage should be in partial
        assert!(partial.input_tokens >= 300);
    }

    // -----------------------------------------------------------------------
    // Debug and transcript tests
    // -----------------------------------------------------------------------

    #[test]
    fn debug_impl() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "b", 1, 1,
        )]));
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "j", 1, 1,
        )]));
        let debate = DebateAgent::builder()
            .debater(make_agent(p1, "d1"))
            .debater(make_agent(p2, "d2"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(3)
            .build()
            .unwrap();

        let debug = format!("{debate:?}");
        assert!(debug.starts_with("DebateAgent"), "{debug}");
        // Pin field/value pairs, not just loose digits anywhere in the output.
        assert!(debug.contains("debater_count: 2"), "{debug}");
        assert!(debug.contains("max_rounds: 3"), "{debug}");
    }

    #[tokio::test]
    async fn judge_receives_transcript_with_round_headers_and_names() {
        // Use captured_requests to inspect what the judge actually received.
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "position-alpha",
            10,
            5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "position-beta",
            10,
            5,
        )]));
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "verdict", 10, 5,
        )]));

        let debate = DebateAgent::builder()
            .debater(make_agent(Arc::clone(&p1), "alpha"))
            .debater(make_agent(Arc::clone(&p2), "beta"))
            .judge(make_agent(Arc::clone(&judge_p), "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let output = debate.execute("test topic").await.unwrap();
        assert_eq!(output.result, "verdict");

        assert_eq!(judge_p.captured_requests.lock().unwrap().len(), 1);
        let input_text = first_request_text(&judge_p);
        assert!(
            input_text.contains("# Debate Topic"),
            "should have topic header"
        );
        assert!(
            input_text.contains("test topic"),
            "should have original topic"
        );
        assert!(
            input_text.contains("### Round 1"),
            "should have round header"
        );
        assert!(
            input_text.contains("position-alpha"),
            "should have alpha's argument"
        );
        assert!(
            input_text.contains("position-beta"),
            "should have beta's argument"
        );

        // Sections are ordered by debater name regardless of completion order.
        let alpha_pos = input_text
            .find("#### alpha")
            .expect("should have debater name alpha");
        let beta_pos = input_text
            .find("#### beta")
            .expect("should have debater name beta");
        assert!(alpha_pos < beta_pos, "alpha should precede beta");
    }

    #[test]
    fn builder_debaters_bulk_method() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "a", 1, 1,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "b", 1, 1,
        )]));
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "j", 1, 1,
        )]));
        let agents = vec![make_agent(p1, "d1"), make_agent(p2, "d2")];
        let result = DebateAgent::builder()
            .debaters(agents)
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(1)
            .build();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn three_debaters_single_round() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "arg-1", 10, 5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "arg-2", 20, 10,
        )]));
        let p3 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "arg-3", 30, 15,
        )]));
        let judge_p = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "three-way verdict",
            40,
            20,
        )]));

        let debate = DebateAgent::builder()
            .debater(make_agent(p1, "d1"))
            .debater(make_agent(p2, "d2"))
            .debater(make_agent(p3, "d3"))
            .judge(make_agent(judge_p, "judge"))
            .max_rounds(1)
            .build()
            .unwrap();

        let output = debate.execute("topic").await.unwrap();
        assert_eq!(output.result, "three-way verdict");
        // 10+20+30+40 = 100 input, 5+10+15+20 = 50 output
        assert_eq!(output.tokens_used.input_tokens, 100);
        assert_eq!(output.tokens_used.output_tokens, 50);
    }
}