use crate::error::MultiError;
use crate::mailbox::Mailbox;
use crate::runner::AgentRunner;
use crate::shared::SharedInfra;
use crate::types::{AgentOutput, AgentSpec};
use crate::patterns::swarm::{Swarm, SwarmMode};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use tracing::instrument;
/// Outcome of a [`Vote`] run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VoteResult {
    /// The task text that was put to the voters.
    pub task: String,
    /// Every output returned by the underlying swarm, including outputs that
    /// did not succeed.
    pub votes: Vec<AgentOutput>,
    /// The selected answer: the synthesizer's answer when a synthesizer is
    /// configured, otherwise the most common successful answer. Empty when no
    /// voter succeeded.
    pub winner: String,
    /// Fraction of successful voters that agreed with the winner. Reported as
    /// `0.0` when no voter succeeded and as `1.0` when a synthesizer is
    /// configured.
    pub agreement_ratio: f64,
}
/// Voting pattern: several agents answer the same task and a single winner is
/// selected from their answers.
pub struct Vote {
    /// Agents that each cast one vote (one answer) for the task.
    pub agents: Vec<AgentSpec>,
    /// Optional agent that is shown the successful votes and asked to pick or
    /// synthesize the final answer. `None` means a majority count is used.
    pub synthesizer: Option<AgentSpec>,
}
impl Vote {
    /// Creates a vote over `agents` with no synthesizer (majority counting).
    pub fn new(agents: Vec<AgentSpec>) -> Self {
        Self {
            agents,
            synthesizer: None,
        }
    }

    /// Sets the agent that picks or synthesizes the final answer from the
    /// successful votes, replacing majority counting.
    pub fn with_synthesizer(mut self, spec: AgentSpec) -> Self {
        self.synthesizer = Some(spec);
        self
    }

    /// Runs every voter on `task` and selects a winner.
    ///
    /// The voters are executed through a [`Swarm`] in [`SwarmMode::Parallel`].
    /// Only outputs for which `succeeded()` is true take part in the decision;
    /// all outputs are still returned in [`VoteResult::votes`].
    ///
    /// * If no voter succeeded, the winner is empty and the ratio is `0.0`.
    /// * If a synthesizer is configured, it is given a summary of the
    ///   successful votes (each answer truncated to 200 bytes) and its answer
    ///   becomes the winner with a ratio of `1.0`. If the synthesizer run
    ///   returns an error, the first successful vote is used instead.
    /// * Otherwise the answer with the most votes wins, comparing answers
    ///   after trimming and lower-casing. Ties are broken in favour of the
    ///   answer that appears first in voter order, so the result is
    ///   deterministic.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the swarm run, if any.
    #[instrument(name = "multi.vote", skip_all)]
    pub async fn run(
        &self,
        task: &str,
        runner: &Arc<dyn AgentRunner>,
        infra: &SharedInfra,
    ) -> Result<VoteResult, MultiError> {
        let swarm = Swarm::new(self.agents.clone(), SwarmMode::Parallel);
        let swarm_result = swarm.run(task, runner, infra).await?;
        let votes = swarm_result.outputs;

        let answers: Vec<AgentOutput> = votes.iter().filter(|o| o.succeeded()).cloned().collect();
        if answers.is_empty() {
            return Ok(VoteResult {
                task: task.to_string(),
                votes,
                winner: String::new(),
                agreement_ratio: 0.0,
            });
        }

        if let Some(synth_spec) = &self.synthesizer {
            let vote_summary: Vec<String> = answers
                .iter()
                .map(|o| format!("- {}: {}", o.name, truncate(&o.answer, 200)))
                .collect();
            let synth_task = format!(
                "Task: {}\n\nVotes:\n{}\n\nPick the best answer or synthesize a consensus.",
                task,
                vote_summary.join("\n")
            );
            let mailbox = Mailbox::default();
            let rt = infra.make_runtime();
            // Best effort: a failing synthesizer must not discard the votes,
            // so fall back to the first successful answer (`answers` is
            // non-empty here).
            let winner = runner
                .run(synth_spec, &synth_task, &rt, &mailbox)
                .await
                .map(|o| o.answer)
                .unwrap_or_else(|_| answers[0].answer.clone());
            return Ok(VoteResult {
                task: task.to_string(),
                votes,
                winner,
                agreement_ratio: 1.0,
            });
        }

        let (winner, count) = Self::majority(&answers)
            .map(|(output, count)| (output.answer.clone(), count))
            .unwrap_or_default();

        Ok(VoteResult {
            task: task.to_string(),
            votes,
            winner,
            agreement_ratio: count as f64 / answers.len() as f64,
        })
    }

    /// Finds the most common answer among `answers`, comparing answers after
    /// trimming and lower-casing.
    ///
    /// Returns the first output (in slice order) carrying the winning answer
    /// together with its vote count, or `None` for an empty slice. When
    /// several answers share the highest count, the one whose first occurrence
    /// comes earliest wins; this avoids depending on `HashMap` iteration
    /// order, which is unspecified.
    fn majority(answers: &[AgentOutput]) -> Option<(&AgentOutput, usize)> {
        let keys: Vec<String> = answers
            .iter()
            .map(|a| a.answer.trim().to_lowercase())
            .collect();

        let mut tally: HashMap<&str, usize> = HashMap::with_capacity(keys.len());
        for key in &keys {
            *tally.entry(key.as_str()).or_insert(0) += 1;
        }

        let mut best: Option<(&AgentOutput, usize)> = None;
        for (output, key) in answers.iter().zip(&keys) {
            let count = tally.get(key.as_str()).copied().unwrap_or(0);
            // Strictly greater: an earlier answer keeps the lead on a tie.
            if best.map_or(true, |(_, top)| count > top) {
                best = Some((output, count));
            }
        }
        best
    }
}
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary.
///
/// `s` is returned unchanged when it already fits within `max_len` bytes.
fn truncate(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        return s;
    }
    // Walk back from the byte limit to the nearest character boundary so the
    // slice below never splits a multi-byte character. Index 0 is always a
    // boundary, so the search cannot come up empty.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;

    /// Runner stub whose answer depends only on the agent's name, so the vote
    /// outcome is predictable.
    struct FixedRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for FixedRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            // Agents whose name ends in '1' give the minority phrasing.
            let is_dissenter = spec.name.ends_with('1');
            let text = if is_dissenter {
                "Paris is the capital"
            } else {
                "The capital is Paris"
            };
            let output = AgentOutput {
                name: spec.name.clone(),
                answer: text.to_string(),
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
                outcome: None,
                tokens: None,
            };
            Ok(output)
        }
    }

    /// Three voters, two of which agree: the shared answer wins with a 2/3
    /// agreement ratio.
    #[tokio::test]
    async fn test_vote_majority() {
        let mut agents: Vec<AgentSpec> = Vec::new();
        for i in 0..3 {
            agents.push(AgentSpec::new(&format!("voter_{}", i), "Answer the question"));
        }
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(FixedRunner);
        let infra = SharedInfra::new();

        let vote = Vote::new(agents);
        let result = vote
            .run("What is the capital of France?", &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.votes.len(), 3);
        assert_eq!(result.winner, "The capital is Paris");
        let expected_ratio = 2.0 / 3.0;
        assert!((result.agreement_ratio - expected_ratio).abs() < 0.01);
    }
}