Skip to main content

car_multi/patterns/
vote.rs

1//! Vote — N agents answer independently, majority wins.
2//!
3//! For factual questions: normalized string matching picks the most common answer.
4//! For open-ended questions: an optional synthesizer picks the best.
5
6use crate::error::MultiError;
7use crate::mailbox::Mailbox;
8use crate::patterns::swarm::{Swarm, SwarmMode};
9use crate::runner::AgentRunner;
10use crate::shared::SharedInfra;
11use crate::types::{AgentOutput, AgentSpec};
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::sync::Arc;
15use tracing::instrument;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct VoteResult {
19    pub task: String,
20    pub votes: Vec<AgentOutput>,
21    pub winner: String,
22    pub agreement_ratio: f64,
23}
24
25pub struct Vote {
26    pub agents: Vec<AgentSpec>,
27    pub synthesizer: Option<AgentSpec>,
28}
29
30impl Vote {
31    pub fn new(agents: Vec<AgentSpec>) -> Self {
32        Self {
33            agents,
34            synthesizer: None,
35        }
36    }
37
38    pub fn with_synthesizer(mut self, spec: AgentSpec) -> Self {
39        self.synthesizer = Some(spec);
40        self
41    }
42
43    #[instrument(name = "multi.vote", skip_all)]
44    pub async fn run(
45        &self,
46        task: &str,
47        runner: &Arc<dyn AgentRunner>,
48        infra: &SharedInfra,
49    ) -> Result<VoteResult, MultiError> {
50        // All agents answer in parallel
51        let swarm = Swarm::new(self.agents.clone(), SwarmMode::Parallel);
52        let swarm_result = swarm.run(task, runner, infra).await?;
53
54        let votes = swarm_result.outputs;
55        let answers: Vec<AgentOutput> = votes.iter().filter(|o| o.succeeded()).cloned().collect();
56
57        if answers.is_empty() {
58            return Ok(VoteResult {
59                task: task.to_string(),
60                votes,
61                winner: String::new(),
62                agreement_ratio: 0.0,
63            });
64        }
65
66        if let Some(synth_spec) = &self.synthesizer {
67            let vote_summary: Vec<String> = answers
68                .iter()
69                .map(|o| format!("- {}: {}", o.name, truncate(&o.answer, 200)))
70                .collect();
71
72            let synth_task = format!(
73                "Task: {}\n\nVotes:\n{}\n\nPick the best answer or synthesize a consensus.",
74                task,
75                vote_summary.join("\n")
76            );
77
78            // Gate the synthesizer on the budget; on denial fall back to the
79            // first successful vote rather than failing.
80            let winner = if infra.begin_agent().is_ok() {
81                let mailbox = Mailbox::default();
82                let rt = infra.make_runtime();
83                match runner.run(synth_spec, &synth_task, &rt, &mailbox).await {
84                    Ok(o) => {
85                        infra.record_output(&o);
86                        o.answer
87                    }
88                    Err(_) => answers[0].answer.clone(),
89                }
90            } else {
91                answers[0].answer.clone()
92            };
93
94            return Ok(VoteResult {
95                task: task.to_string(),
96                votes,
97                winner,
98                agreement_ratio: 1.0,
99            });
100        }
101
102        // Simple majority: pick the most common answer (normalized)
103        let mut counter: HashMap<String, usize> = HashMap::new();
104        for a in &answers {
105            let normalized = a.answer.trim().to_lowercase();
106            *counter.entry(normalized).or_insert(0) += 1;
107        }
108
109        let (most_common, count) = counter
110            .iter()
111            .max_by_key(|(_, c)| *c)
112            .map(|(k, c)| (k.clone(), *c))
113            .unwrap_or_default();
114
115        // Find the original (un-normalized) answer
116        let winner = answers
117            .iter()
118            .find(|a| a.answer.trim().to_lowercase() == most_common)
119            .map(|a| a.answer.clone())
120            .unwrap_or_default();
121
122        Ok(VoteResult {
123            task: task.to_string(),
124            votes,
125            winner,
126            agreement_ratio: count as f64 / answers.len() as f64,
127        })
128    }
129}
130
/// Returns the longest prefix of `s` that is at most `max_len` bytes long and
/// ends on a UTF-8 char boundary (so multi-byte characters are never split).
fn truncate(s: &str, max_len: usize) -> &str {
    if max_len >= s.len() {
        return s;
    }
    // Walk backwards from `max_len`; index 0 is always a char boundary, so the
    // search always succeeds and the fallback is never used.
    let cut = (0..=max_len)
        .rev()
        .find(|&idx| s.is_char_boundary(idx))
        .unwrap_or(0);
    &s[..cut]
}
141
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{AgentOutput, AgentSpec};
    use car_engine::Runtime;

    /// Deterministic runner: the agent whose name ends in '1' phrases its
    /// answer differently; every other agent returns the same answer.
    struct FixedRunner;

    #[async_trait::async_trait]
    impl crate::runner::AgentRunner for FixedRunner {
        async fn run(
            &self,
            spec: &AgentSpec,
            _task: &str,
            _runtime: &Runtime,
            _mailbox: &Mailbox,
        ) -> Result<AgentOutput, MultiError> {
            const MAJORITY: &str = "The capital is Paris";
            const DISSENT: &str = "Paris is the capital";

            // voter_0 and voter_2 agree; voter_1 dissents.
            let dissents = spec.name.ends_with('1');
            let answer = if dissents { DISSENT } else { MAJORITY };

            let output = AgentOutput {
                name: spec.name.clone(),
                answer: answer.to_string(),
                turns: 1,
                tool_calls: 0,
                duration_ms: 5.0,
                error: None,
                outcome: None,
                tokens: None,
            };
            Ok(output)
        }
    }

    #[tokio::test]
    async fn test_vote_majority() {
        let voters: Vec<AgentSpec> = (0..3)
            .map(|i| AgentSpec::new(&format!("voter_{}", i), "Answer the question"))
            .collect();
        let runner: Arc<dyn crate::runner::AgentRunner> = Arc::new(FixedRunner);
        let infra = SharedInfra::new();

        let outcome = Vote::new(voters)
            .run("What is the capital of France?", &runner, &infra)
            .await
            .unwrap();

        // Every voter is reported, and the two agreeing voters carry the day.
        assert_eq!(outcome.votes.len(), 3);
        assert_eq!(outcome.winner, "The capital is Paris");
        let expected_ratio = 2.0 / 3.0;
        assert!((outcome.agreement_ratio - expected_ratio).abs() < 0.01);
    }
}
193        // 2/3 agree on "The capital is Paris"
194        assert_eq!(result.winner, "The capital is Paris");
195        assert!((result.agreement_ratio - 2.0 / 3.0).abs() < 0.01);
196    }
197}