use std::collections::HashMap;
use std::sync::Arc;
use tokio::task::JoinSet;
use crate::error::Error;
use crate::llm::LlmProvider;
use crate::llm::types::TokenUsage;
use super::{AgentOutput, AgentRunner};
/// Maps one voter's raw output text to the vote that voter casts.
type VoteExtractor = Box<dyn Fn(&str) -> String + Sync + Send>;
/// Picks a winner from the votes tied for the highest count (given sorted ascending).
type TieBreaker = Box<dyn Fn(&[String]) -> String + Sync + Send>;
/// Outcome of a [`VotingAgent::execute`] run.
#[derive(Debug)]
pub struct VoteResult {
/// The winning vote: the one with the highest count, or the tie breaker's
/// choice when several votes share that count.
pub winner: String,
/// How many voters cast each distinct vote.
pub tally: HashMap<String, usize>,
/// Output of the first voter (in registration order) that cast `winner`,
/// with token usage, tool-call count and estimated cost summed over all voters.
pub output: AgentOutput,
}
/// Runs several agents on the same task concurrently and returns the majority vote.
///
/// Construct one with [`VotingAgent::builder`].
pub struct VotingAgent<P: LlmProvider + 'static> {
/// Voters in registration order; each is `Arc`-cloned into its own spawned task.
voters: Vec<Arc<AgentRunner<P>>>,
/// Maps a voter's output text to the vote it casts.
vote_extractor: VoteExtractor,
/// Chooses a winner among votes tied for the highest count.
tie_breaker: TieBreaker,
}
impl<P: LlmProvider + 'static> std::fmt::Debug for VotingAgent<P> {
    // Only the voter count is printed: the extractor and tie-breaker closures
    // have no `Debug` representation, so the output marks them as omitted (`..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("VotingAgent")
            .field("voter_count", &self.voters.len())
            .finish_non_exhaustive()
    }
}
/// Builder for [`VotingAgent`]; obtain one from [`VotingAgent::builder`].
pub struct VotingAgentBuilder<P: LlmProvider + 'static> {
/// Voters added so far, in registration order.
voters: Vec<Arc<AgentRunner<P>>>,
/// Required: `build` fails while this is `None`.
vote_extractor: Option<VoteExtractor>,
/// Optional: `build` substitutes a default that picks the first tied vote.
tie_breaker: Option<TieBreaker>,
}
impl<P: LlmProvider + 'static> VotingAgent<P> {
pub fn builder() -> VotingAgentBuilder<P> {
VotingAgentBuilder {
voters: Vec::new(),
vote_extractor: None,
tie_breaker: None,
}
}
pub async fn execute(&self, task: &str) -> Result<VoteResult, Error> {
let mut set = JoinSet::new();
for (idx, voter) in self.voters.iter().enumerate() {
let voter = Arc::clone(voter);
let task = task.to_string();
set.spawn(async move {
let result = voter.execute(&task).await;
(idx, result)
});
}
let mut outputs: Vec<(usize, AgentOutput)> = Vec::with_capacity(self.voters.len());
let mut total_usage = TokenUsage::default();
while let Some(join_result) = set.join_next().await {
let (idx, agent_result) = join_result
.map_err(|e| Error::Agent(format!("voting agent task panicked: {e}")))?;
let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
total_usage += output.tokens_used;
outputs.push((idx, output));
}
outputs.sort_by_key(|(idx, _)| *idx);
let votes: Vec<String> = outputs
.iter()
.map(|(_, output)| (self.vote_extractor)(&output.result))
.collect();
let mut tally: HashMap<String, usize> = HashMap::new();
for vote in &votes {
*tally.entry(vote.clone()).or_insert(0) += 1;
}
let max_count = tally.values().copied().max().unwrap_or(0);
let mut top_votes: Vec<String> = tally
.iter()
.filter(|&(_, &count)| count == max_count)
.map(|(vote, _)| vote.clone())
.collect();
top_votes.sort();
let winner = if top_votes.len() == 1 {
top_votes.into_iter().next().expect("at least one vote")
} else {
(self.tie_breaker)(&top_votes)
};
let winner_idx = votes
.iter()
.position(|v| *v == winner)
.expect("winner must be among votes");
let (_, mut winning_output) = outputs.remove(winner_idx);
let mut total_tool_calls = 0usize;
let mut total_cost: Option<f64> = None;
for (_, output) in &outputs {
total_tool_calls += output.tool_calls_made;
if let Some(cost) = output.estimated_cost_usd {
*total_cost.get_or_insert(0.0) += cost;
}
}
total_tool_calls += winning_output.tool_calls_made;
if let Some(cost) = winning_output.estimated_cost_usd {
*total_cost.get_or_insert(0.0) += cost;
}
winning_output.tokens_used = total_usage;
winning_output.tool_calls_made = total_tool_calls;
winning_output.estimated_cost_usd = total_cost;
Ok(VoteResult {
winner,
tally,
output: winning_output,
})
}
}
impl<P: LlmProvider + 'static> VotingAgentBuilder<P> {
    /// Add one voter.
    ///
    /// Registration order decides whose output is returned when several voters
    /// cast the winning vote: the earliest one wins.
    pub fn voter(mut self, agent: AgentRunner<P>) -> Self {
        self.voters.push(Arc::new(agent));
        self
    }

    /// Add several voters, in iteration order.
    ///
    /// Accepts any iterable of agents: a `Vec`, an array, or an iterator.
    pub fn voters(mut self, agents: impl IntoIterator<Item = AgentRunner<P>>) -> Self {
        self.voters.extend(agents.into_iter().map(Arc::new));
        self
    }

    /// Set the function that maps a voter's output text to its vote (required).
    pub fn vote_extractor(mut self, f: impl Fn(&str) -> String + Send + Sync + 'static) -> Self {
        self.vote_extractor = Some(Box::new(f));
        self
    }

    /// Set the function that picks a winner among votes tied for the highest count.
    ///
    /// It receives the tied votes sorted ascending and should return one of
    /// them. If unset, the first (lexicographically smallest) tied vote wins.
    pub fn tie_breaker(mut self, f: impl Fn(&[String]) -> String + Send + Sync + 'static) -> Self {
        self.tie_breaker = Some(Box::new(f));
        self
    }

    /// Validate the configuration and build the [`VotingAgent`].
    ///
    /// # Errors
    ///
    /// [`Error::Config`] if fewer than two voters were added or no vote
    /// extractor was set.
    pub fn build(self) -> Result<VotingAgent<P>, Error> {
        if self.voters.len() < 2 {
            return Err(Error::Config(
                "VotingAgent requires at least 2 voters".into(),
            ));
        }
        let vote_extractor = self
            .vote_extractor
            .ok_or_else(|| Error::Config("VotingAgent requires a vote_extractor".into()))?;
        // The tie breaker is only consulted with two or more tied votes, which
        // arrive sorted, so `votes[0]` is the lexicographically smallest one.
        let tie_breaker = self.tie_breaker.unwrap_or_else(|| {
            Box::new(|votes: &[String]| {
                votes[0].clone()
            })
        });
        Ok(VotingAgent {
            voters: self.voters,
            vote_extractor,
            tie_breaker,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    /// Classifies any output containing "YES" as a YES vote, everything else as NO.
    fn yes_no_extractor(output: &str) -> String {
        if output.contains("YES") {
            "YES".to_string()
        } else {
            "NO".to_string()
        }
    }

    /// Returns the first of ALPHA/BETA/GAMMA found in the output, else OTHER.
    fn label_extractor(output: &str) -> String {
        for label in ["ALPHA", "BETA", "GAMMA"] {
            if output.contains(label) {
                return label.to_string();
            }
        }
        "OTHER".to_string()
    }

    #[test]
    fn builder_rejects_fewer_than_two_voters() {
        let provider = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 10, 5,
        )]));
        let result = VotingAgent::builder()
            .voter(make_agent(provider, "only-one"))
            .vote_extractor(yes_no_extractor)
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least 2"));
    }

    #[test]
    fn builder_rejects_zero_voters() {
        let result = VotingAgent::<MockProvider>::builder()
            .vote_extractor(yes_no_extractor)
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("at least 2"));
    }

    #[test]
    fn builder_rejects_missing_vote_extractor() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 10, 5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 10, 5,
        )]));
        let result = VotingAgent::builder()
            .voter(make_agent(p1, "a"))
            .voter(make_agent(p2, "b"))
            .build();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("vote_extractor"));
    }

    #[test]
    fn builder_accepts_valid_config_without_tie_breaker() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 10, 5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "NO", 10, 5,
        )]));
        let result = VotingAgent::builder()
            .voter(make_agent(p1, "a"))
            .voter(make_agent(p2, "b"))
            .vote_extractor(yes_no_extractor)
            .build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_accepts_valid_config_with_tie_breaker() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 10, 5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "NO", 10, 5,
        )]));
        let result = VotingAgent::builder()
            .voter(make_agent(p1, "a"))
            .voter(make_agent(p2, "b"))
            .vote_extractor(yes_no_extractor)
            .tie_breaker(|votes| votes.last().unwrap().clone())
            .build();
        assert!(result.is_ok());
    }

    #[test]
    fn builder_voters_bulk_method() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 10, 5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "NO", 10, 5,
        )]));
        let agents = vec![make_agent(p1, "a"), make_agent(p2, "b")];
        let result = VotingAgent::builder()
            .voters(agents)
            .vote_extractor(yes_no_extractor)
            .build();
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn unanimous_vote() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "I vote YES",
            100,
            50,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "Definitely YES",
            200,
            80,
        )]));
        let p3 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES please",
            150,
            60,
        )]));
        let voting = VotingAgent::builder()
            .voter(make_agent(p1, "v1"))
            .voter(make_agent(p2, "v2"))
            .voter(make_agent(p3, "v3"))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();
        let result = voting.execute("should we?").await.unwrap();
        assert_eq!(result.winner, "YES");
        assert_eq!(result.tally["YES"], 3);
        assert!(!result.tally.contains_key("NO"));
        assert!(result.output.result.contains("YES"));
    }

    #[tokio::test]
    async fn majority_vote_two_of_three() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "I say YES",
            100,
            50,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "NO way", 200, 80,
        )]));
        let p3 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES definitely",
            150,
            60,
        )]));
        let voting = VotingAgent::builder()
            .voter(make_agent(p1, "v1"))
            .voter(make_agent(p2, "v2"))
            .voter(make_agent(p3, "v3"))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();
        let result = voting.execute("proceed?").await.unwrap();
        assert_eq!(result.winner, "YES");
        assert_eq!(result.tally["YES"], 2);
        assert_eq!(result.tally["NO"], 1);
    }

    /// Without a custom tie breaker, the lexicographically smallest tied vote wins.
    #[tokio::test]
    async fn tie_broken_by_default_alphabetical() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "NO thanks",
            100,
            50,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES sure", 200, 80,
        )]));
        let voting = VotingAgent::builder()
            .voter(make_agent(p1, "v1"))
            .voter(make_agent(p2, "v2"))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();
        let result = voting.execute("tie?").await.unwrap();
        assert_eq!(result.winner, "NO");
        assert_eq!(result.tally["YES"], 1);
        assert_eq!(result.tally["NO"], 1);
    }

    #[tokio::test]
    async fn tie_broken_by_custom_tie_breaker() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "NO thanks",
            100,
            50,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES sure", 200, 80,
        )]));
        let voting = VotingAgent::builder()
            .voter(make_agent(p1, "v1"))
            .voter(make_agent(p2, "v2"))
            .vote_extractor(yes_no_extractor)
            .tie_breaker(|votes| votes.last().unwrap().clone())
            .build()
            .unwrap();
        let result = voting.execute("tie?").await.unwrap();
        assert_eq!(result.winner, "YES");
    }

    /// The tie breaker sees only the votes tied for first place, sorted, and
    /// the returned output is that of the first voter who cast the winner.
    #[tokio::test]
    async fn tie_breaker_receives_only_tied_votes_in_sorted_order() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "BETA first", 10, 5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ALPHA first", 10, 5,
        )]));
        let p3 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "GAMMA only", 10, 5,
        )]));
        let p4 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "BETA second", 10, 5,
        )]));
        let p5 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "ALPHA second", 10, 5,
        )]));
        let voting = VotingAgent::builder()
            .voter(make_agent(p1, "v1"))
            .voter(make_agent(p2, "v2"))
            .voter(make_agent(p3, "v3"))
            .voter(make_agent(p4, "v4"))
            .voter(make_agent(p5, "v5"))
            .vote_extractor(label_extractor)
            .tie_breaker(|votes| {
                assert_eq!(
                    votes.to_vec(),
                    vec!["ALPHA".to_string(), "BETA".to_string()]
                );
                votes[1].clone()
            })
            .build()
            .unwrap();
        let result = voting.execute("pick").await.unwrap();
        assert_eq!(result.winner, "BETA");
        assert_eq!(result.tally.len(), 3);
        assert_eq!(result.tally["ALPHA"], 2);
        assert_eq!(result.tally["BETA"], 2);
        assert_eq!(result.tally["GAMMA"], 1);
        assert!(result.output.result.contains("BETA first"));
    }

    #[tokio::test]
    async fn token_usage_accumulated_across_all_voters() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 100, 50,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 200, 80,
        )]));
        let p3 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 150, 60,
        )]));
        let voting = VotingAgent::builder()
            .voter(make_agent(p1, "v1"))
            .voter(make_agent(p2, "v2"))
            .voter(make_agent(p3, "v3"))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();
        let result = voting.execute("go").await.unwrap();
        assert_eq!(result.output.tokens_used.input_tokens, 450);
        assert_eq!(result.output.tokens_used.output_tokens, 190);
    }

    #[tokio::test]
    async fn error_carries_partial_usage() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 100, 50,
        )]));
        // No scripted responses: this voter fails.
        let p2 = Arc::new(MockProvider::new(vec![]));
        let voting = VotingAgent::builder()
            .voter(make_agent(p1, "good"))
            .voter(make_agent(p2, "bad"))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();
        let err = voting.execute("task").await.unwrap_err();
        let partial = err.partial_usage();
        // Voters run concurrently, so the failing voter may finish before or
        // after the good one; both outcomes are valid.
        assert!(
            partial.input_tokens == 0 || partial.input_tokens >= 100,
            "partial usage should be zero or include completed voter"
        );
    }

    #[test]
    fn debug_impl() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES", 10, 5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "NO", 10, 5,
        )]));
        let voting = VotingAgent::builder()
            .voter(make_agent(p1, "a"))
            .voter(make_agent(p2, "b"))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();
        let debug = format!("{voting:?}");
        assert!(debug.contains("VotingAgent"));
        assert!(debug.contains("voter_count: 2"));
    }

    #[tokio::test]
    async fn vote_result_contains_correct_tally() {
        let p1 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES agree",
            10,
            5,
        )]));
        let p2 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "NO disagree",
            10,
            5,
        )]));
        let p3 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES concur",
            10,
            5,
        )]));
        let p4 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "NO object",
            10,
            5,
        )]));
        let p5 = Arc::new(MockProvider::new(vec![MockProvider::text_response(
            "YES absolutely",
            10,
            5,
        )]));
        let voting = VotingAgent::builder()
            .voter(make_agent(p1, "v1"))
            .voter(make_agent(p2, "v2"))
            .voter(make_agent(p3, "v3"))
            .voter(make_agent(p4, "v4"))
            .voter(make_agent(p5, "v5"))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();
        let result = voting.execute("vote").await.unwrap();
        assert_eq!(result.winner, "YES");
        assert_eq!(result.tally.len(), 2);
        assert_eq!(result.tally["YES"], 3);
        assert_eq!(result.tally["NO"], 2);
    }
}