Skip to main content

car_multi/patterns/
vote.rs

1//! Vote — N agents answer independently, majority wins.
2//!
3//! For factual questions: normalized string matching picks the most common answer.
4//! For open-ended questions: an optional synthesizer picks the best.
5
6use crate::error::MultiError;
7use crate::mailbox::Mailbox;
8use crate::runner::AgentRunner;
9use crate::shared::SharedInfra;
10use crate::types::{AgentOutput, AgentSpec};
11use crate::patterns::swarm::{Swarm, SwarmMode};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct VoteResult {
18    pub task: String,
19    pub votes: Vec<AgentOutput>,
20    pub winner: String,
21    pub agreement_ratio: f64,
22}
23
24pub struct Vote {
25    pub agents: Vec<AgentSpec>,
26    pub synthesizer: Option<AgentSpec>,
27}
28
29impl Vote {
30    pub fn new(agents: Vec<AgentSpec>) -> Self {
31        Self {
32            agents,
33            synthesizer: None,
34        }
35    }
36
37    pub fn with_synthesizer(mut self, spec: AgentSpec) -> Self {
38        self.synthesizer = Some(spec);
39        self
40    }
41
42    pub async fn run(
43        &self,
44        task: &str,
45        runner: &Arc<dyn AgentRunner>,
46        infra: &SharedInfra,
47    ) -> Result<VoteResult, MultiError> {
48        // All agents answer in parallel
49        let swarm = Swarm::new(self.agents.clone(), SwarmMode::Parallel);
50        let swarm_result = swarm.run(task, runner, infra).await?;
51
52        let votes = swarm_result.outputs;
53        let answers: Vec<AgentOutput> = votes.iter().filter(|o| o.succeeded()).cloned().collect();
54
55        if answers.is_empty() {
56            return Ok(VoteResult {
57                task: task.to_string(),
58                votes,
59                winner: String::new(),
60                agreement_ratio: 0.0,
61            });
62        }
63
64        if let Some(synth_spec) = &self.synthesizer {
65            let vote_summary: Vec<String> = answers
66                .iter()
67                .map(|o| format!("- {}: {}", o.name, truncate(&o.answer, 200)))
68                .collect();
69
70            let synth_task = format!(
71                "Task: {}\n\nVotes:\n{}\n\nPick the best answer or synthesize a consensus.",
72                task,
73                vote_summary.join("\n")
74            );
75
76            let mailbox = Mailbox::default();
77            let rt = infra.make_runtime();
78            let winner = runner
79                .run(synth_spec, &synth_task, &rt, &mailbox)
80                .await
81                .map(|o| o.answer)
82                .unwrap_or_else(|_| answers[0].answer.clone());
83
84            return Ok(VoteResult {
85                task: task.to_string(),
86                votes,
87                winner,
88                agreement_ratio: 1.0,
89            });
90        }
91
92        // Simple majority: pick the most common answer (normalized)
93        let mut counter: HashMap<String, usize> = HashMap::new();
94        for a in &answers {
95            let normalized = a.answer.trim().to_lowercase();
96            *counter.entry(normalized).or_insert(0) += 1;
97        }
98
99        let (most_common, count) = counter
100            .iter()
101            .max_by_key(|(_, c)| *c)
102            .map(|(k, c)| (k.clone(), *c))
103            .unwrap_or_default();
104
105        // Find the original (un-normalized) answer
106        let winner = answers
107            .iter()
108            .find(|a| a.answer.trim().to_lowercase() == most_common)
109            .map(|a| a.answer.clone())
110            .unwrap_or_default();
111
112        Ok(VoteResult {
113            task: task.to_string(),
114            votes,
115            winner,
116            agreement_ratio: count as f64 / answers.len() as f64,
117        })
118    }
119}
120
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 character boundary.
///
/// The input is returned unchanged when it already fits.
fn truncate(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        return s;
    }
    // Walk back from `max_len` to the nearest boundary; offset 0 always
    // qualifies, so the search cannot come up empty.
    let cut = (0..=max_len)
        .rev()
        .find(|&pos| s.is_char_boundary(pos))
        .unwrap_or(0);
    &s[..cut]
}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;

    /// Stub runner whose answer depends only on the agent's name: names
    /// ending in `1` give the dissenting phrasing, all others agree.
    struct FixedRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for FixedRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            // With voters 0..3 this makes agents 0 and 2 agree and agent 1 disagree.
            let is_dissenter = spec.name.ends_with('1');
            let text = if is_dissenter {
                "Paris is the capital"
            } else {
                "The capital is Paris"
            };

            let output = AgentOutput {
                name: spec.name.clone(),
                answer: String::from(text),
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
            };
            Ok(output)
        }
    }

    /// Three voters, two of which agree: the majority answer wins with a 2/3 ratio.
    #[tokio::test]
    async fn test_vote_majority() {
        let mut agents: Vec<AgentSpec> = Vec::with_capacity(3);
        for i in 0..3 {
            agents.push(AgentSpec::new(&format!("voter_{}", i), "Answer the question"));
        }

        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(FixedRunner);
        let infra = SharedInfra::new();

        let vote = Vote::new(agents);
        let result = vote
            .run("What is the capital of France?", &runner, &infra)
            .await
            .unwrap();

        assert_eq!(result.votes.len(), 3);
        // 2/3 agree on "The capital is Paris"
        assert_eq!(result.winner, "The capital is Paris");
        let expected_ratio = 2.0 / 3.0;
        assert!((result.agreement_ratio - expected_ratio).abs() < 0.01);
    }
}