1use crate::error::{PopsamError, PopsamResult};
2use crate::model::{
3 CandidateBestResult, CandidateRoundVotes, ElectionResult, EmbeddedText, EmbeddedTextInput,
4 RoundSummary,
5};
6use rand::{Rng, SeedableRng};
7use rand_chacha::ChaCha8Rng;
8use std::cmp::Ordering;
9use std::collections::HashMap;
10
/// Tunable parameters for [`run_election`].
#[derive(Clone, Debug)]
pub struct ElectionConfig {
    /// How many candidates may remain before rounds start being reported: a round summary is
    /// recorded only once at most this many candidates are still active. Must be non-zero;
    /// it is capped at the number of inputs.
    pub report_last_k: usize,
    /// Share of the active candidates eliminated per round (rounded up, at least one) while
    /// more than `report_last_k` candidates remain. Must lie in `(0.0, 1.0]`.
    pub elimination_fraction: f32,
    /// Seed for the ChaCha8 generator that supplies the random tie-breakers.
    pub random_seed: u64,
}

impl Default for ElectionConfig {
    /// Reports the last 10 candidates, eliminates 5% per round, and seeds the generator with 42.
    fn default() -> Self {
        Self {
            report_last_k: 10,
            elimination_fraction: 0.05,
            random_seed: 42,
        }
    }
}
31
32pub fn run_election(
45 inputs: Vec<EmbeddedTextInput>,
46 config: ElectionConfig,
47) -> PopsamResult<ElectionResult> {
48 if inputs.is_empty() {
49 return Err(PopsamError::EmptyInput);
50 }
51 if config.report_last_k == 0 {
52 return Err(PopsamError::InvalidReportK);
53 }
54 if !(0.0 < config.elimination_fraction && config.elimination_fraction <= 1.0) {
55 return Err(PopsamError::InvalidEliminationFraction);
56 }
57
58 let dimension = inputs[0].embedding.len();
59 let mut embeddings = Vec::with_capacity(inputs.len());
60 for input in inputs {
61 if input.embedding.is_empty() {
62 return Err(PopsamError::EmptyEmbedding { id: input.id });
63 }
64 if input.embedding.len() != dimension {
65 return Err(PopsamError::DimensionMismatch {
66 id: input.id,
67 expected: dimension,
68 actual: input.embedding.len(),
69 });
70 }
71 let normalized = normalize(&input.id, input.embedding)?;
72 embeddings.push(EmbeddedText {
73 id: input.id,
74 text: input.text,
75 embedding: normalized,
76 });
77 }
78
79 let report_k = config.report_last_k.min(embeddings.len());
80 let mut rng = ChaCha8Rng::seed_from_u64(config.random_seed);
81 let mut active: Vec<usize> = (0..embeddings.len()).collect();
82 let mut eliminated_ranked = Vec::with_capacity(embeddings.len().saturating_sub(1));
83 let mut rounds = Vec::new();
84 let mut representative_ids = Vec::new();
85 let mut candidate_best_results = vec![None; embeddings.len()];
86 let mut full_round_index = 0;
87
88 while active.len() > 1 {
89 full_round_index += 1;
90 let vote_rows = tally_votes(&embeddings, &active, &mut rng);
91 update_best_results(
92 &mut candidate_best_results,
93 &active,
94 &vote_rows,
95 full_round_index,
96 );
97 let eliminate_count =
98 elimination_count(active.len(), report_k, config.elimination_fraction);
99 let weakest = weakest_candidates(&vote_rows, eliminate_count, &mut rng);
100
101 if active.len() <= report_k {
102 if representative_ids.is_empty() {
103 representative_ids = active
104 .iter()
105 .map(|&idx| embeddings[idx].id.clone())
106 .collect();
107 }
108 rounds.push(build_round_summary(
109 &embeddings,
110 &active,
111 &vote_rows,
112 &weakest,
113 rounds.len() + 1,
114 ));
115 }
116
117 let mut removed_flags = vec![false; embeddings.len()];
118 for &idx in &weakest {
119 removed_flags[idx] = true;
120 eliminated_ranked.push(embeddings[idx].id.clone());
121 }
122 active.retain(|idx| !removed_flags[*idx]);
123 }
124
125 let winner_idx = active[0];
126 let winner_id = embeddings[winner_idx].id.clone();
127 if representative_ids.is_empty() {
128 representative_ids.push(winner_id.clone());
129 }
130 if report_k > 0 {
131 full_round_index += 1;
132 let final_votes = tally_votes(&embeddings, &active, &mut rng);
133 update_best_results(
134 &mut candidate_best_results,
135 &active,
136 &final_votes,
137 full_round_index,
138 );
139 rounds.push(build_round_summary(
140 &embeddings,
141 &active,
142 &final_votes,
143 &[],
144 rounds.len() + 1,
145 ));
146 }
147 let mut all_ranked_ids = vec![winner_id.clone()];
148 eliminated_ranked.reverse();
149 all_ranked_ids.extend(eliminated_ranked);
150 let candidate_best_results =
151 order_best_results(candidate_best_results, &embeddings, &all_ranked_ids);
152
153 Ok(ElectionResult {
154 winner_id,
155 representative_ids,
156 all_ranked_ids,
157 rounds,
158 candidate_best_results,
159 embeddings,
160 })
161}
162
163fn normalize(id: &str, embedding: Vec<f32>) -> PopsamResult<Vec<f32>> {
164 let norm = embedding
165 .iter()
166 .map(|value| value * value)
167 .sum::<f32>()
168 .sqrt();
169 if norm == 0.0 {
170 return Err(PopsamError::ZeroNorm { id: id.to_string() });
171 }
172 Ok(embedding.into_iter().map(|value| value / norm).collect())
173}
174
/// Decides how many candidates to eliminate in a round with `active` candidates.
///
/// Once `active` is at or below `report_k`, exactly one candidate is eliminated per round.
/// Above it, `elimination_fraction` of the active candidates (rounded up, at least one) is
/// eliminated, but never so many that fewer than `report_k` would remain.
fn elimination_count(active: usize, report_k: usize, elimination_fraction: f32) -> usize {
    match active.checked_sub(report_k) {
        // At or inside the reporting window.
        None | Some(0) => 1,
        // `headroom` is at least 1 here, so the clamp bounds are always ordered.
        Some(headroom) => {
            let batch = (active as f32 * elimination_fraction).ceil() as usize;
            batch.clamp(1, headroom)
        }
    }
}
183
184fn tally_votes(
185 embeddings: &[EmbeddedText],
186 active: &[usize],
187 rng: &mut ChaCha8Rng,
188) -> Vec<(usize, CandidateRoundVotes)> {
189 let mut candidate_positions = vec![None; embeddings.len()];
190 for (position, &candidate_idx) in active.iter().enumerate() {
191 candidate_positions[candidate_idx] = Some(position);
192 }
193
194 let mut tallies: Vec<CandidateRoundVotes> = active
195 .iter()
196 .map(|&idx| CandidateRoundVotes {
197 id: embeddings[idx].id.clone(),
198 first_votes: 0,
199 second_votes: 0,
200 third_votes: 0,
201 })
202 .collect();
203
204 for voter in embeddings {
205 let mut ranked = active
206 .iter()
207 .map(|&candidate_idx| {
208 (
209 candidate_idx,
210 cosine_similarity(&voter.embedding, &embeddings[candidate_idx].embedding),
211 rng.random::<u64>(),
212 )
213 })
214 .collect::<Vec<_>>();
215 ranked.sort_by(|left, right| compare_similarity(right, left));
216
217 for (rank, (candidate_idx, _, _)) in ranked.into_iter().take(3).enumerate() {
218 let tally_index = candidate_positions[candidate_idx].expect("candidate must exist");
219 match rank {
220 0 => tallies[tally_index].first_votes += 1,
221 1 => tallies[tally_index].second_votes += 1,
222 2 => tallies[tally_index].third_votes += 1,
223 _ => {}
224 }
225 }
226 }
227
228 active.iter().copied().zip(tallies).collect()
229}
230
231fn weakest_candidates(
232 vote_rows: &[(usize, CandidateRoundVotes)],
233 eliminate_count: usize,
234 rng: &mut ChaCha8Rng,
235) -> Vec<usize> {
236 let mut ranked = vote_rows
237 .iter()
238 .map(|(idx, votes)| (*idx, votes.clone(), rng.random::<u64>()))
239 .collect::<Vec<_>>();
240 ranked.sort_by(|left, right| {
241 left.1
242 .first_votes
243 .cmp(&right.1.first_votes)
244 .then(left.1.second_votes.cmp(&right.1.second_votes))
245 .then(left.1.third_votes.cmp(&right.1.third_votes))
246 .then(left.2.cmp(&right.2))
247 });
248 ranked
249 .into_iter()
250 .take(eliminate_count)
251 .map(|entry| entry.0)
252 .collect()
253}
254
255fn build_round_summary(
256 embeddings: &[EmbeddedText],
257 active: &[usize],
258 vote_rows: &[(usize, CandidateRoundVotes)],
259 eliminated: &[usize],
260 round_index: usize,
261) -> RoundSummary {
262 let votes = sorted_votes(vote_rows);
263
264 RoundSummary {
265 round_index,
266 active_candidates: active.len(),
267 eliminated_candidate_ids: eliminated
268 .iter()
269 .map(|&idx| embeddings[idx].id.clone())
270 .collect(),
271 votes,
272 }
273}
274
275fn update_best_results(
276 candidate_best_results: &mut [Option<CandidateBestResult>],
277 active: &[usize],
278 vote_rows: &[(usize, CandidateRoundVotes)],
279 round_index: usize,
280) {
281 let active_candidates = active.len();
282 for (rank, (candidate_idx, votes)) in sorted_vote_rows(vote_rows).into_iter().enumerate() {
283 let result = CandidateBestResult {
284 id: votes.id.clone(),
285 full_round_index: round_index,
286 active_candidates,
287 rank: rank + 1,
288 first_votes: votes.first_votes,
289 second_votes: votes.second_votes,
290 third_votes: votes.third_votes,
291 };
292
293 let best = &mut candidate_best_results[candidate_idx];
294 if best
295 .as_ref()
296 .is_none_or(|current| is_better_result(&result, current))
297 {
298 *best = Some(result);
299 }
300 }
301}
302
303fn order_best_results(
304 candidate_best_results: Vec<Option<CandidateBestResult>>,
305 embeddings: &[EmbeddedText],
306 all_ranked_ids: &[String],
307) -> Vec<CandidateBestResult> {
308 let mut index_by_id = HashMap::with_capacity(embeddings.len());
309 for (idx, embedding) in embeddings.iter().enumerate() {
310 index_by_id.insert(embedding.id.as_str(), idx);
311 }
312
313 all_ranked_ids
314 .iter()
315 .map(|id| {
316 let idx = index_by_id
317 .get(id.as_str())
318 .expect("ranked candidate must exist in embeddings");
319 candidate_best_results[*idx]
320 .clone()
321 .expect("every candidate is active in at least one round")
322 })
323 .collect()
324}
325
326fn sorted_votes(vote_rows: &[(usize, CandidateRoundVotes)]) -> Vec<CandidateRoundVotes> {
327 sorted_vote_rows(vote_rows)
328 .into_iter()
329 .map(|(_, votes)| votes)
330 .collect()
331}
332
333fn sorted_vote_rows(
334 vote_rows: &[(usize, CandidateRoundVotes)],
335) -> Vec<(usize, CandidateRoundVotes)> {
336 let mut votes = vote_rows
337 .iter()
338 .map(|(idx, votes)| (*idx, votes.clone()))
339 .collect::<Vec<_>>();
340 votes.sort_by(|left, right| compare_round_votes(&left.1, &right.1));
341 votes
342}
343
344fn compare_round_votes(left: &CandidateRoundVotes, right: &CandidateRoundVotes) -> Ordering {
345 right
346 .first_votes
347 .cmp(&left.first_votes)
348 .then(right.second_votes.cmp(&left.second_votes))
349 .then(right.third_votes.cmp(&left.third_votes))
350 .then(left.id.cmp(&right.id))
351}
352
353fn is_better_result(candidate: &CandidateBestResult, current: &CandidateBestResult) -> bool {
354 candidate
355 .rank
356 .cmp(¤t.rank)
357 .reverse()
358 .then(candidate.first_votes.cmp(¤t.first_votes))
359 .then(candidate.second_votes.cmp(¤t.second_votes))
360 .then(candidate.third_votes.cmp(¤t.third_votes))
361 .then(candidate.active_candidates.cmp(¤t.active_candidates))
362 == Ordering::Greater
363}
364
/// Dot product of `left` and `right`, pairing components up to the shorter length.
///
/// This equals the cosine similarity only when both vectors have unit length; `run_election`
/// normalizes every embedding before any similarity is computed.
fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
    std::iter::zip(left, right)
        .map(|(l, r)| l * r)
        .sum()
}
368
/// Orders `(index, similarity, tie-breaker)` entries by similarity using the IEEE total order,
/// then by tie-breaker. The index in the first slot does not take part in the comparison.
fn compare_similarity(left: &(usize, f32, u64), right: &(usize, f32, u64)) -> Ordering {
    let (_, left_score, left_tiebreak) = *left;
    let (_, right_score, right_tiebreak) = *right;
    match left_score.total_cmp(&right_score) {
        Ordering::Equal => left_tiebreak.cmp(&right_tiebreak),
        unequal => unequal,
    }
}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text-less input with the given id and raw (not yet normalized) embedding.
    fn input(id: &str, embedding: &[f32]) -> EmbeddedTextInput {
        EmbeddedTextInput {
            id: id.into(),
            text: None,
            embedding: embedding.to_vec(),
        }
    }

    /// With four inputs and `report_last_k = 3`, the rounds with 3, 2 and 1 active candidates
    /// are reported and the winner heads the ranking.
    #[test]
    fn keeps_last_k_rounds_and_winner() {
        let inputs = vec![
            input("a", &[1.0, 0.0]),
            input("b", &[0.9, 0.1]),
            input("c", &[0.8, 0.2]),
            input("d", &[0.0, 1.0]),
        ];
        let config = ElectionConfig {
            report_last_k: 3,
            elimination_fraction: 0.25,
            random_seed: 7,
        };

        let result = run_election(inputs, config).expect("election should succeed");

        assert_eq!(result.representative_ids.len(), 3);
        assert_eq!(result.rounds.len(), 3);
        assert_eq!(result.rounds[0].active_candidates, 3);
        assert_eq!(result.rounds[2].active_candidates, 1);
        assert_eq!(result.winner_id, result.all_ranked_ids[0]);
        assert_eq!(result.candidate_best_results.len(), 4);
        assert!(result
            .candidate_best_results
            .iter()
            .all(|best| best.rank >= 1 && best.rank <= best.active_candidates));
        assert!(result
            .candidate_best_results
            .iter()
            .any(|best| best.id == result.winner_id && best.rank == 1));
    }

    /// An embedding whose length differs from the first input's is rejected.
    #[test]
    fn rejects_dimension_mismatch() {
        let inputs = vec![input("a", &[1.0, 0.0]), input("b", &[1.0])];

        let error = run_election(inputs, ElectionConfig::default()).expect_err("must fail");
        assert!(matches!(error, PopsamError::DimensionMismatch { .. }));
    }

    /// Best results cover every candidate, follow the final ranking order, and draw on rounds
    /// that ran before reporting started.
    #[test]
    fn tracks_best_result_for_each_candidate_across_full_election() {
        let inputs = vec![
            input("a", &[1.0, 0.0]),
            input("b", &[0.9, 0.1]),
            input("c", &[0.0, 1.0]),
            input("d", &[0.1, 0.9]),
            input("e", &[0.7, 0.3]),
        ];
        let config = ElectionConfig {
            report_last_k: 2,
            elimination_fraction: 0.4,
            random_seed: 11,
        };

        let result = run_election(inputs, config).expect("election should succeed");

        assert_eq!(result.candidate_best_results.len(), 5);

        let best_ids: Vec<&str> = result
            .candidate_best_results
            .iter()
            .map(|best| best.id.as_str())
            .collect();
        let ranked_ids: Vec<&str> = result.all_ranked_ids.iter().map(String::as_str).collect();
        assert_eq!(best_ids, ranked_ids);

        let first_reported = &result.rounds[0];
        assert!(
            first_reported.round_index < result.candidate_best_results[0].full_round_index
                || result
                    .candidate_best_results
                    .iter()
                    .any(|best| best.active_candidates > first_reported.active_candidates)
        );
        assert!(result
            .candidate_best_results
            .iter()
            .all(|best| best.full_round_index >= 1 && best.active_candidates >= 1));
    }
}