Skip to main content

heartbit_core/agent/
voting.rs

//! Majority-voting workflow agent.
//!
//! Runs N voter agents concurrently on the same task, extracts a vote from each
//! output, and returns the full output of the first voter (in the order voters
//! were added) whose vote matches the majority. Ties are resolved by an optional
//! `tie_breaker`, which defaults to picking the alphabetically first tied vote.
7
8use std::collections::HashMap;
9use std::sync::Arc;
10
11use tokio::task::JoinSet;
12
13use crate::error::Error;
14use crate::llm::LlmProvider;
15use crate::llm::types::TokenUsage;
16
17use super::{AgentOutput, AgentRunner};
18
/// Turns one voter's raw output text into the vote string it casts.
type VoteExtractor = Box<dyn Fn(&str) -> String + Sync + Send>;

/// Chooses a single winner among the votes that share the highest count.
///
/// The slice holds only the tied votes; the returned string is expected to
/// be one of them.
type TieBreaker = Box<dyn Fn(&[String]) -> String + Sync + Send>;
25
26/// The result of a voting round, including the winning vote, the full tally,
27/// and the output from the first voter that cast the winning vote.
28#[derive(Debug)]
29pub struct VoteResult {
30    /// The vote string that won.
31    pub winner: String,
32    /// Vote string → number of voters that cast it.
33    pub tally: HashMap<String, usize>,
34    /// The full `AgentOutput` from the first voter whose vote matched the winner.
35    pub output: AgentOutput,
36}
37
38/// Orchestrates majority voting across N agents running in parallel.
39pub struct VotingAgent<P: LlmProvider + 'static> {
40    voters: Vec<Arc<AgentRunner<P>>>,
41    vote_extractor: VoteExtractor,
42    tie_breaker: TieBreaker,
43}
44
45impl<P: LlmProvider + 'static> std::fmt::Debug for VotingAgent<P> {
46    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47        f.debug_struct("VotingAgent")
48            .field("voter_count", &self.voters.len())
49            .finish()
50    }
51}
52
53/// Builder for [`VotingAgent`].
54pub struct VotingAgentBuilder<P: LlmProvider + 'static> {
55    voters: Vec<Arc<AgentRunner<P>>>,
56    vote_extractor: Option<VoteExtractor>,
57    tie_breaker: Option<TieBreaker>,
58}
59
60impl<P: LlmProvider + 'static> VotingAgent<P> {
61    /// Create a new [`VotingAgentBuilder`].
62    pub fn builder() -> VotingAgentBuilder<P> {
63        VotingAgentBuilder {
64            voters: Vec::new(),
65            vote_extractor: None,
66            tie_breaker: None,
67        }
68    }
69
70    /// Execute all voters in parallel, tally votes, and return the winning result.
71    pub async fn execute(&self, task: &str) -> Result<VoteResult, Error> {
72        let mut set = JoinSet::new();
73
74        for (idx, voter) in self.voters.iter().enumerate() {
75            let voter = Arc::clone(voter);
76            let task = task.to_string();
77            set.spawn(async move {
78                let result = voter.execute(&task).await;
79                (idx, result)
80            });
81        }
82
83        // Collect results in completion order, tracking accumulated usage for
84        // partial-usage-on-error semantics.
85        let mut outputs: Vec<(usize, AgentOutput)> = Vec::with_capacity(self.voters.len());
86        let mut total_usage = TokenUsage::default();
87
88        while let Some(join_result) = set.join_next().await {
89            let (idx, agent_result) = join_result
90                .map_err(|e| Error::Agent(format!("voting agent task panicked: {e}")))?;
91            let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
92            total_usage += output.tokens_used;
93            outputs.push((idx, output));
94        }
95
96        // Sort by original index for deterministic vote ordering.
97        outputs.sort_by_key(|(idx, _)| *idx);
98
99        // Extract votes and build tally.
100        let votes: Vec<String> = outputs
101            .iter()
102            .map(|(_, output)| (self.vote_extractor)(&output.result))
103            .collect();
104
105        let mut tally: HashMap<String, usize> = HashMap::new();
106        for vote in &votes {
107            *tally.entry(vote.clone()).or_insert(0) += 1;
108        }
109
110        // Find the maximum vote count.
111        let max_count = tally.values().copied().max().unwrap_or(0);
112
113        // Collect all votes tied at the maximum.
114        let mut top_votes: Vec<String> = tally
115            .iter()
116            .filter(|&(_, &count)| count == max_count)
117            .map(|(vote, _)| vote.clone())
118            .collect();
119        top_votes.sort();
120
121        let winner = if top_votes.len() == 1 {
122            top_votes.into_iter().next().expect("at least one vote")
123        } else {
124            (self.tie_breaker)(&top_votes)
125        };
126
127        // Find the first voter (by original index) whose vote matches the winner.
128        let winner_idx = votes
129            .iter()
130            .position(|v| *v == winner)
131            .expect("winner must be among votes");
132
133        let (_, mut winning_output) = outputs.remove(winner_idx);
134
135        // Accumulate tool_calls and cost from all voters (usage already tracked
136        // in `total_usage` during JoinSet collection above).
137        let mut total_tool_calls = 0usize;
138        let mut total_cost: Option<f64> = None;
139        for (_, output) in &outputs {
140            total_tool_calls += output.tool_calls_made;
141            if let Some(cost) = output.estimated_cost_usd {
142                *total_cost.get_or_insert(0.0) += cost;
143            }
144        }
145        total_tool_calls += winning_output.tool_calls_made;
146        if let Some(cost) = winning_output.estimated_cost_usd {
147            *total_cost.get_or_insert(0.0) += cost;
148        }
149
150        winning_output.tokens_used = total_usage;
151        winning_output.tool_calls_made = total_tool_calls;
152        winning_output.estimated_cost_usd = total_cost;
153
154        Ok(VoteResult {
155            winner,
156            tally,
157            output: winning_output,
158        })
159    }
160}
161
162impl<P: LlmProvider + 'static> VotingAgentBuilder<P> {
163    /// Add a voter agent. Wraps it in `Arc` for concurrent sharing.
164    pub fn voter(mut self, agent: AgentRunner<P>) -> Self {
165        self.voters.push(Arc::new(agent));
166        self
167    }
168
169    /// Add multiple voter agents.
170    pub fn voters(mut self, agents: Vec<AgentRunner<P>>) -> Self {
171        self.voters.extend(agents.into_iter().map(Arc::new));
172        self
173    }
174
175    /// Set the vote extractor function.
176    pub fn vote_extractor(mut self, f: impl Fn(&str) -> String + Send + Sync + 'static) -> Self {
177        self.vote_extractor = Some(Box::new(f));
178        self
179    }
180
181    /// Set an optional tie-breaker function.
182    pub fn tie_breaker(mut self, f: impl Fn(&[String]) -> String + Send + Sync + 'static) -> Self {
183        self.tie_breaker = Some(Box::new(f));
184        self
185    }
186
187    /// Build the [`VotingAgent`]. Requires at least 2 voters and a vote extractor.
188    pub fn build(self) -> Result<VotingAgent<P>, Error> {
189        if self.voters.len() < 2 {
190            return Err(Error::Config(
191                "VotingAgent requires at least 2 voters".into(),
192            ));
193        }
194        let vote_extractor = self
195            .vote_extractor
196            .ok_or_else(|| Error::Config("VotingAgent requires a vote_extractor".into()))?;
197        let tie_breaker = self.tie_breaker.unwrap_or_else(|| {
198            Box::new(|votes: &[String]| {
199                // Default: first alphabetically (votes are already sorted).
200                votes[0].clone()
201            })
202        });
203        Ok(VotingAgent {
204            voters: self.voters,
205            vote_extractor,
206            tie_breaker,
207        })
208    }
209}
210
211// ===========================================================================
212// Tests
213// ===========================================================================
214
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    /// Shorthand for a mock provider that answers once with `$text`,
    /// reporting `$input` input tokens and `$output` output tokens.
    macro_rules! provider {
        ($text:expr, $input:expr, $output:expr) => {
            Arc::new(MockProvider::new(vec![MockProvider::text_response(
                $text, $input, $output,
            )]))
        };
    }

    /// Vote "YES" when the output mentions YES anywhere, otherwise "NO".
    fn yes_no_extractor(output: &str) -> String {
        let vote = if output.contains("YES") { "YES" } else { "NO" };
        vote.to_string()
    }

    // -----------------------------------------------------------------------
    // Builder validation tests
    // -----------------------------------------------------------------------

    #[test]
    fn builder_rejects_fewer_than_two_voters() {
        let err = VotingAgent::builder()
            .voter(make_agent(provider!("YES", 10, 5), "only-one"))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("at least 2"));
    }

    #[test]
    fn builder_rejects_zero_voters() {
        let err = VotingAgent::<MockProvider>::builder()
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("at least 2"));
    }

    #[test]
    fn builder_rejects_missing_vote_extractor() {
        let err = VotingAgent::builder()
            .voter(make_agent(provider!("YES", 10, 5), "a"))
            .voter(make_agent(provider!("YES", 10, 5), "b"))
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("vote_extractor"));
    }

    #[test]
    fn builder_accepts_valid_config_without_tie_breaker() {
        let built = VotingAgent::builder()
            .voter(make_agent(provider!("YES", 10, 5), "a"))
            .voter(make_agent(provider!("NO", 10, 5), "b"))
            .vote_extractor(yes_no_extractor)
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn builder_accepts_valid_config_with_tie_breaker() {
        let built = VotingAgent::builder()
            .voter(make_agent(provider!("YES", 10, 5), "a"))
            .voter(make_agent(provider!("NO", 10, 5), "b"))
            .vote_extractor(yes_no_extractor)
            .tie_breaker(|votes| votes.last().unwrap().clone())
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn builder_voters_bulk_method() {
        let agents = vec![
            make_agent(provider!("YES", 10, 5), "a"),
            make_agent(provider!("NO", 10, 5), "b"),
        ];
        let built = VotingAgent::builder()
            .voters(agents)
            .vote_extractor(yes_no_extractor)
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn debug_impl() {
        let voting = VotingAgent::builder()
            .voter(make_agent(provider!("YES", 10, 5), "a"))
            .voter(make_agent(provider!("NO", 10, 5), "b"))
            .vote_extractor(yes_no_extractor)
            .build()
            .expect("valid voting config");

        let rendered = format!("{voting:?}");
        for needle in ["VotingAgent", "voter_count", "2"] {
            assert!(rendered.contains(needle), "missing {needle:?} in {rendered}");
        }
    }

    // -----------------------------------------------------------------------
    // Execution tests
    // -----------------------------------------------------------------------

    #[tokio::test]
    async fn unanimous_vote() {
        let voting = VotingAgent::builder()
            .voter(make_agent(provider!("I vote YES", 100, 50), "v1"))
            .voter(make_agent(provider!("Definitely YES", 200, 80), "v2"))
            .voter(make_agent(provider!("YES please", 150, 60), "v3"))
            .vote_extractor(yes_no_extractor)
            .build()
            .expect("valid voting config");

        let outcome = voting.execute("should we?").await.unwrap();
        assert_eq!(outcome.winner, "YES");
        assert_eq!(outcome.tally["YES"], 3);
        assert!(!outcome.tally.contains_key("NO"));
        // The returned output must come from one of the YES voters.
        assert!(outcome.output.result.contains("YES"));
    }

    #[tokio::test]
    async fn majority_vote_two_of_three() {
        let voting = VotingAgent::builder()
            .voter(make_agent(provider!("I say YES", 100, 50), "v1"))
            .voter(make_agent(provider!("NO way", 200, 80), "v2"))
            .voter(make_agent(provider!("YES definitely", 150, 60), "v3"))
            .vote_extractor(yes_no_extractor)
            .build()
            .expect("valid voting config");

        let outcome = voting.execute("proceed?").await.unwrap();
        assert_eq!(outcome.winner, "YES");
        assert_eq!(outcome.tally["YES"], 2);
        assert_eq!(outcome.tally["NO"], 1);
    }

    #[tokio::test]
    async fn tie_broken_by_default_alphabetical() {
        // One YES and one NO: a tie, settled by the default alphabetical rule.
        let voting = VotingAgent::builder()
            .voter(make_agent(provider!("NO thanks", 100, 50), "v1"))
            .voter(make_agent(provider!("YES sure", 200, 80), "v2"))
            .vote_extractor(yes_no_extractor)
            .build()
            .expect("valid voting config");

        let outcome = voting.execute("tie?").await.unwrap();
        // "NO" sorts before "YES".
        assert_eq!(outcome.winner, "NO");
        assert_eq!(outcome.tally["YES"], 1);
        assert_eq!(outcome.tally["NO"], 1);
    }

    #[tokio::test]
    async fn tie_broken_by_custom_tie_breaker() {
        let voting = VotingAgent::builder()
            .voter(make_agent(provider!("NO thanks", 100, 50), "v1"))
            .voter(make_agent(provider!("YES sure", 200, 80), "v2"))
            .vote_extractor(yes_no_extractor)
            // Pick the alphabetically last tied vote.
            .tie_breaker(|votes| votes.last().unwrap().clone())
            .build()
            .expect("valid voting config");

        let outcome = voting.execute("tie?").await.unwrap();
        assert_eq!(outcome.winner, "YES");
    }

    #[tokio::test]
    async fn token_usage_accumulated_across_all_voters() {
        let voting = VotingAgent::builder()
            .voter(make_agent(provider!("YES", 100, 50), "v1"))
            .voter(make_agent(provider!("YES", 200, 80), "v2"))
            .voter(make_agent(provider!("YES", 150, 60), "v3"))
            .vote_extractor(yes_no_extractor)
            .build()
            .expect("valid voting config");

        let usage = voting.execute("go").await.unwrap().output.tokens_used;
        assert_eq!(usage.input_tokens, 450);
        assert_eq!(usage.output_tokens, 190);
    }

    #[tokio::test]
    async fn error_carries_partial_usage() {
        // The second provider has no scripted responses, so its voter fails.
        let empty = Arc::new(MockProvider::new(vec![]));

        let voting = VotingAgent::builder()
            .voter(make_agent(provider!("YES", 100, 50), "good"))
            .voter(make_agent(empty, "bad"))
            .vote_extractor(yes_no_extractor)
            .build()
            .expect("valid voting config");

        let err = voting.execute("task").await.unwrap_err();
        let partial = err.partial_usage();
        // Completion order is non-deterministic: the successful voter may or
        // may not have been collected before the failure surfaced.
        assert!(
            partial.input_tokens == 0 || partial.input_tokens >= 100,
            "partial usage should be zero or include completed voter"
        );
    }

    #[tokio::test]
    async fn vote_result_contains_correct_tally() {
        let voting = VotingAgent::builder()
            .voter(make_agent(provider!("YES agree", 10, 5), "v1"))
            .voter(make_agent(provider!("NO disagree", 10, 5), "v2"))
            .voter(make_agent(provider!("YES concur", 10, 5), "v3"))
            .voter(make_agent(provider!("NO object", 10, 5), "v4"))
            .voter(make_agent(provider!("YES absolutely", 10, 5), "v5"))
            .vote_extractor(yes_no_extractor)
            .build()
            .expect("valid voting config");

        let outcome = voting.execute("vote").await.unwrap();
        assert_eq!(outcome.winner, "YES");
        assert_eq!(outcome.tally.len(), 2);
        assert_eq!(outcome.tally["YES"], 3);
        assert_eq!(outcome.tally["NO"], 2);
    }
}