1use std::collections::HashMap;
9use std::sync::Arc;
10
11use tokio::task::JoinSet;
12
13use crate::error::Error;
14use crate::llm::LlmProvider;
15use crate::llm::types::TokenUsage;
16
17use super::{AgentOutput, AgentRunner};
18
/// Boxed callback that turns a voter's raw text output into the string counted as its vote.
///
/// The callback must be `Send + Sync` so the agent that owns it can be shared across threads.
type VoteExtractor = Box<dyn Fn(&str) -> String + Sync + Send>;
21
/// Boxed callback that receives the tied vote values and returns the one that should win.
///
/// The callback must be `Send + Sync` so the agent that owns it can be shared across threads.
type TieBreaker = Box<dyn Fn(&[String]) -> String + Sync + Send>;
25
26#[derive(Debug)]
29pub struct VoteResult {
30 pub winner: String,
32 pub tally: HashMap<String, usize>,
34 pub output: AgentOutput,
36}
37
38pub struct VotingAgent<P: LlmProvider + 'static> {
40 voters: Vec<Arc<AgentRunner<P>>>,
41 vote_extractor: VoteExtractor,
42 tie_breaker: TieBreaker,
43}
44
45impl<P: LlmProvider + 'static> std::fmt::Debug for VotingAgent<P> {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 f.debug_struct("VotingAgent")
48 .field("voter_count", &self.voters.len())
49 .finish()
50 }
51}
52
53pub struct VotingAgentBuilder<P: LlmProvider + 'static> {
55 voters: Vec<Arc<AgentRunner<P>>>,
56 vote_extractor: Option<VoteExtractor>,
57 tie_breaker: Option<TieBreaker>,
58}
59
60impl<P: LlmProvider + 'static> VotingAgent<P> {
61 pub fn builder() -> VotingAgentBuilder<P> {
63 VotingAgentBuilder {
64 voters: Vec::new(),
65 vote_extractor: None,
66 tie_breaker: None,
67 }
68 }
69
70 pub async fn execute(&self, task: &str) -> Result<VoteResult, Error> {
72 let mut set = JoinSet::new();
73
74 for (idx, voter) in self.voters.iter().enumerate() {
75 let voter = Arc::clone(voter);
76 let task = task.to_string();
77 set.spawn(async move {
78 let result = voter.execute(&task).await;
79 (idx, result)
80 });
81 }
82
83 let mut outputs: Vec<(usize, AgentOutput)> = Vec::with_capacity(self.voters.len());
86 let mut total_usage = TokenUsage::default();
87
88 while let Some(join_result) = set.join_next().await {
89 let (idx, agent_result) = join_result
90 .map_err(|e| Error::Agent(format!("voting agent task panicked: {e}")))?;
91 let output = agent_result.map_err(|e| e.accumulate_usage(total_usage))?;
92 total_usage += output.tokens_used;
93 outputs.push((idx, output));
94 }
95
96 outputs.sort_by_key(|(idx, _)| *idx);
98
99 let votes: Vec<String> = outputs
101 .iter()
102 .map(|(_, output)| (self.vote_extractor)(&output.result))
103 .collect();
104
105 let mut tally: HashMap<String, usize> = HashMap::new();
106 for vote in &votes {
107 *tally.entry(vote.clone()).or_insert(0) += 1;
108 }
109
110 let max_count = tally.values().copied().max().unwrap_or(0);
112
113 let mut top_votes: Vec<String> = tally
115 .iter()
116 .filter(|&(_, &count)| count == max_count)
117 .map(|(vote, _)| vote.clone())
118 .collect();
119 top_votes.sort();
120
121 let winner = if top_votes.len() == 1 {
122 top_votes.into_iter().next().expect("at least one vote")
123 } else {
124 (self.tie_breaker)(&top_votes)
125 };
126
127 let winner_idx = votes
129 .iter()
130 .position(|v| *v == winner)
131 .expect("winner must be among votes");
132
133 let (_, mut winning_output) = outputs.remove(winner_idx);
134
135 let mut total_tool_calls = 0usize;
138 let mut total_cost: Option<f64> = None;
139 for (_, output) in &outputs {
140 total_tool_calls += output.tool_calls_made;
141 if let Some(cost) = output.estimated_cost_usd {
142 *total_cost.get_or_insert(0.0) += cost;
143 }
144 }
145 total_tool_calls += winning_output.tool_calls_made;
146 if let Some(cost) = winning_output.estimated_cost_usd {
147 *total_cost.get_or_insert(0.0) += cost;
148 }
149
150 winning_output.tokens_used = total_usage;
151 winning_output.tool_calls_made = total_tool_calls;
152 winning_output.estimated_cost_usd = total_cost;
153
154 Ok(VoteResult {
155 winner,
156 tally,
157 output: winning_output,
158 })
159 }
160}
161
162impl<P: LlmProvider + 'static> VotingAgentBuilder<P> {
163 pub fn voter(mut self, agent: AgentRunner<P>) -> Self {
165 self.voters.push(Arc::new(agent));
166 self
167 }
168
169 pub fn voters(mut self, agents: Vec<AgentRunner<P>>) -> Self {
171 self.voters.extend(agents.into_iter().map(Arc::new));
172 self
173 }
174
175 pub fn vote_extractor(mut self, f: impl Fn(&str) -> String + Send + Sync + 'static) -> Self {
177 self.vote_extractor = Some(Box::new(f));
178 self
179 }
180
181 pub fn tie_breaker(mut self, f: impl Fn(&[String]) -> String + Send + Sync + 'static) -> Self {
183 self.tie_breaker = Some(Box::new(f));
184 self
185 }
186
187 pub fn build(self) -> Result<VotingAgent<P>, Error> {
189 if self.voters.len() < 2 {
190 return Err(Error::Config(
191 "VotingAgent requires at least 2 voters".into(),
192 ));
193 }
194 let vote_extractor = self
195 .vote_extractor
196 .ok_or_else(|| Error::Config("VotingAgent requires a vote_extractor".into()))?;
197 let tie_breaker = self.tie_breaker.unwrap_or_else(|| {
198 Box::new(|votes: &[String]| {
199 votes[0].clone()
201 })
202 });
203 Ok(VotingAgent {
204 voters: self.voters,
205 vote_extractor,
206 tie_breaker,
207 })
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::test_helpers::{MockProvider, make_agent};

    /// Builds a voter named `$name` backed by a mock provider that yields exactly
    /// one text response with the given input/output token counts.
    ///
    /// A macro rather than a function so the argument types stay inferred from
    /// `MockProvider::text_response` and `make_agent`.
    macro_rules! scripted_voter {
        ($name:expr, $reply:expr, $input_tokens:expr, $output_tokens:expr) => {
            make_agent(
                Arc::new(MockProvider::new(vec![MockProvider::text_response(
                    $reply,
                    $input_tokens,
                    $output_tokens,
                )])),
                $name,
            )
        };
    }

    /// Votes "YES" when the output mentions `YES`, otherwise "NO".
    fn yes_no_extractor(output: &str) -> String {
        let verdict = if output.contains("YES") { "YES" } else { "NO" };
        verdict.to_string()
    }

    // ---- builder validation ----

    #[test]
    fn builder_rejects_fewer_than_two_voters() {
        let built = VotingAgent::builder()
            .voter(scripted_voter!("only-one", "YES", 10, 5))
            .vote_extractor(yes_no_extractor)
            .build();

        assert!(built.is_err());
        let message = built.unwrap_err().to_string();
        assert!(message.contains("at least 2"));
    }

    #[test]
    fn builder_rejects_zero_voters() {
        let built = VotingAgent::<MockProvider>::builder()
            .vote_extractor(yes_no_extractor)
            .build();

        assert!(built.is_err());
        let message = built.unwrap_err().to_string();
        assert!(message.contains("at least 2"));
    }

    #[test]
    fn builder_rejects_missing_vote_extractor() {
        let built = VotingAgent::builder()
            .voter(scripted_voter!("a", "YES", 10, 5))
            .voter(scripted_voter!("b", "YES", 10, 5))
            .build();

        assert!(built.is_err());
        let message = built.unwrap_err().to_string();
        assert!(message.contains("vote_extractor"));
    }

    #[test]
    fn builder_accepts_valid_config_without_tie_breaker() {
        let built = VotingAgent::builder()
            .voter(scripted_voter!("a", "YES", 10, 5))
            .voter(scripted_voter!("b", "NO", 10, 5))
            .vote_extractor(yes_no_extractor)
            .build();

        assert!(built.is_ok());
    }

    #[test]
    fn builder_accepts_valid_config_with_tie_breaker() {
        let built = VotingAgent::builder()
            .voter(scripted_voter!("a", "YES", 10, 5))
            .voter(scripted_voter!("b", "NO", 10, 5))
            .vote_extractor(yes_no_extractor)
            .tie_breaker(|votes| votes.last().unwrap().clone())
            .build();

        assert!(built.is_ok());
    }

    #[test]
    fn builder_voters_bulk_method() {
        let agents = vec![
            scripted_voter!("a", "YES", 10, 5),
            scripted_voter!("b", "NO", 10, 5),
        ];

        let built = VotingAgent::builder()
            .voters(agents)
            .vote_extractor(yes_no_extractor)
            .build();

        assert!(built.is_ok());
    }

    // ---- voting outcomes ----

    #[tokio::test]
    async fn unanimous_vote() {
        let voting = VotingAgent::builder()
            .voter(scripted_voter!("v1", "I vote YES", 100, 50))
            .voter(scripted_voter!("v2", "Definitely YES", 200, 80))
            .voter(scripted_voter!("v3", "YES please", 150, 60))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();

        let outcome = voting.execute("should we?").await.unwrap();

        assert_eq!(outcome.winner, "YES");
        assert_eq!(outcome.tally.get("YES"), Some(&3));
        assert!(!outcome.tally.contains_key("NO"));
        assert!(outcome.output.result.contains("YES"));
    }

    #[tokio::test]
    async fn majority_vote_two_of_three() {
        let voting = VotingAgent::builder()
            .voter(scripted_voter!("v1", "I say YES", 100, 50))
            .voter(scripted_voter!("v2", "NO way", 200, 80))
            .voter(scripted_voter!("v3", "YES definitely", 150, 60))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();

        let outcome = voting.execute("proceed?").await.unwrap();

        assert_eq!(outcome.winner, "YES");
        assert_eq!(outcome.tally.get("YES"), Some(&2));
        assert_eq!(outcome.tally.get("NO"), Some(&1));
    }

    #[tokio::test]
    async fn tie_broken_by_default_alphabetical() {
        let voting = VotingAgent::builder()
            .voter(scripted_voter!("v1", "NO thanks", 100, 50))
            .voter(scripted_voter!("v2", "YES sure", 200, 80))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();

        let outcome = voting.execute("tie?").await.unwrap();

        // "NO" sorts before "YES", and the default tie breaker takes the first candidate.
        assert_eq!(outcome.winner, "NO");
        assert_eq!(outcome.tally.get("YES"), Some(&1));
        assert_eq!(outcome.tally.get("NO"), Some(&1));
    }

    #[tokio::test]
    async fn tie_broken_by_custom_tie_breaker() {
        let voting = VotingAgent::builder()
            .voter(scripted_voter!("v1", "NO thanks", 100, 50))
            .voter(scripted_voter!("v2", "YES sure", 200, 80))
            .vote_extractor(yes_no_extractor)
            .tie_breaker(|votes| votes.last().unwrap().clone())
            .build()
            .unwrap();

        let outcome = voting.execute("tie?").await.unwrap();

        // The custom tie breaker takes the last of the sorted candidates.
        assert_eq!(outcome.winner, "YES");
    }

    // ---- usage accounting ----

    #[tokio::test]
    async fn token_usage_accumulated_across_all_voters() {
        let voting = VotingAgent::builder()
            .voter(scripted_voter!("v1", "YES", 100, 50))
            .voter(scripted_voter!("v2", "YES", 200, 80))
            .voter(scripted_voter!("v3", "YES", 150, 60))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();

        let outcome = voting.execute("go").await.unwrap();
        let usage = outcome.output.tokens_used;

        assert_eq!(usage.input_tokens, 450);
        assert_eq!(usage.output_tokens, 190);
    }

    #[tokio::test]
    async fn error_carries_partial_usage() {
        // The second voter's provider has no scripted responses, so that voter fails.
        let failing = make_agent(Arc::new(MockProvider::new(vec![])), "bad");

        let voting = VotingAgent::builder()
            .voter(scripted_voter!("good", "YES", 100, 50))
            .voter(failing)
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();

        let err = voting.execute("task").await.unwrap_err();
        let partial = err.partial_usage();

        // Completion order is not fixed: the failure may be seen before or after
        // the successful voter has been accounted for.
        assert!(
            partial.input_tokens == 0 || partial.input_tokens >= 100,
            "partial usage should be zero or include completed voter"
        );
    }

    // ---- misc ----

    #[test]
    fn debug_impl() {
        let voting = VotingAgent::builder()
            .voter(scripted_voter!("a", "YES", 10, 5))
            .voter(scripted_voter!("b", "NO", 10, 5))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();

        let rendered = format!("{voting:?}");

        for needle in ["VotingAgent", "voter_count", "2"] {
            assert!(rendered.contains(needle));
        }
    }

    #[tokio::test]
    async fn vote_result_contains_correct_tally() {
        let voting = VotingAgent::builder()
            .voter(scripted_voter!("v1", "YES agree", 10, 5))
            .voter(scripted_voter!("v2", "NO disagree", 10, 5))
            .voter(scripted_voter!("v3", "YES concur", 10, 5))
            .voter(scripted_voter!("v4", "NO object", 10, 5))
            .voter(scripted_voter!("v5", "YES absolutely", 10, 5))
            .vote_extractor(yes_no_extractor)
            .build()
            .unwrap();

        let outcome = voting.execute("vote").await.unwrap();

        assert_eq!(outcome.winner, "YES");
        assert_eq!(outcome.tally.len(), 2);
        assert_eq!(outcome.tally.get("YES"), Some(&3));
        assert_eq!(outcome.tally.get("NO"), Some(&2));
    }
}