1use crate::error::{PopsamError, PopsamResult};
2use crate::model::{CandidateRoundVotes, EmbeddedText, EmbeddedTextInput, ElectionResult, RoundSummary};
3use rand::{Rng, SeedableRng};
4use rand_chacha::ChaCha8Rng;
5use std::cmp::Ordering;
6
/// Tunable parameters for [`run_election`].
#[derive(Debug, Clone)]
pub struct ElectionConfig {
    /// Size of the final field that gets reported.
    ///
    /// Must be non-zero; the election clamps it to the number of inputs. Once
    /// the number of active candidates is at or below this value, every round
    /// is recorded and the candidates active at that moment become the
    /// representative ids.
    pub report_last_k: usize,
    /// Share of the active candidates eliminated per round while more than
    /// `report_last_k` remain.
    ///
    /// Must satisfy `0.0 < fraction <= 1.0`. The batch is rounded up, is never
    /// smaller than one, and never cuts the field below `report_last_k`.
    pub elimination_fraction: f32,
    /// Seed for the ChaCha8 generator that supplies tie-break values.
    pub random_seed: u64,
}
17
18impl Default for ElectionConfig {
19 fn default() -> Self {
20 Self {
21 report_last_k: 10,
22 elimination_fraction: 0.05,
23 random_seed: 42,
24 }
25 }
26}
27
28pub fn run_election(
40 inputs: Vec<EmbeddedTextInput>,
41 config: ElectionConfig,
42) -> PopsamResult<ElectionResult> {
43 if inputs.is_empty() {
44 return Err(PopsamError::EmptyInput);
45 }
46 if config.report_last_k == 0 {
47 return Err(PopsamError::InvalidReportK);
48 }
49 if !(0.0 < config.elimination_fraction && config.elimination_fraction <= 1.0) {
50 return Err(PopsamError::InvalidEliminationFraction);
51 }
52
53 let dimension = inputs[0].embedding.len();
54 let mut embeddings = Vec::with_capacity(inputs.len());
55 for input in inputs {
56 if input.embedding.is_empty() {
57 return Err(PopsamError::EmptyEmbedding { id: input.id });
58 }
59 if input.embedding.len() != dimension {
60 return Err(PopsamError::DimensionMismatch {
61 id: input.id,
62 expected: dimension,
63 actual: input.embedding.len(),
64 });
65 }
66 let normalized = normalize(&input.id, input.embedding)?;
67 embeddings.push(EmbeddedText {
68 id: input.id,
69 text: input.text,
70 embedding: normalized,
71 });
72 }
73
74 let report_k = config.report_last_k.min(embeddings.len());
75 let mut rng = ChaCha8Rng::seed_from_u64(config.random_seed);
76 let mut active: Vec<usize> = (0..embeddings.len()).collect();
77 let mut eliminated_ranked = Vec::with_capacity(embeddings.len().saturating_sub(1));
78 let mut rounds = Vec::new();
79 let mut representative_ids = Vec::new();
80
81 while active.len() > 1 {
82 let vote_rows = tally_votes(&embeddings, &active, &mut rng);
83 let eliminate_count = elimination_count(active.len(), report_k, config.elimination_fraction);
84 let weakest = weakest_candidates(&vote_rows, eliminate_count, &mut rng);
85
86 if active.len() <= report_k {
87 if representative_ids.is_empty() {
88 representative_ids = active
89 .iter()
90 .map(|&idx| embeddings[idx].id.clone())
91 .collect();
92 }
93 rounds.push(build_round_summary(&embeddings, &active, &vote_rows, &weakest, rounds.len() + 1));
94 }
95
96 let mut removed_flags = vec![false; embeddings.len()];
97 for &idx in &weakest {
98 removed_flags[idx] = true;
99 eliminated_ranked.push(embeddings[idx].id.clone());
100 }
101 active.retain(|idx| !removed_flags[*idx]);
102 }
103
104 let winner_idx = active[0];
105 let winner_id = embeddings[winner_idx].id.clone();
106 if representative_ids.is_empty() {
107 representative_ids.push(winner_id.clone());
108 }
109 if report_k > 0 {
110 let final_votes = tally_votes(&embeddings, &active, &mut rng);
111 rounds.push(build_round_summary(
112 &embeddings,
113 &active,
114 &final_votes,
115 &[],
116 rounds.len() + 1,
117 ));
118 }
119 let mut all_ranked_ids = vec![winner_id.clone()];
120 eliminated_ranked.reverse();
121 all_ranked_ids.extend(eliminated_ranked);
122
123 Ok(ElectionResult {
124 winner_id,
125 representative_ids,
126 all_ranked_ids,
127 rounds,
128 embeddings,
129 })
130}
131
132fn normalize(id: &str, embedding: Vec<f32>) -> PopsamResult<Vec<f32>> {
133 let norm = embedding
134 .iter()
135 .map(|value| value * value)
136 .sum::<f32>()
137 .sqrt();
138 if norm == 0.0 {
139 return Err(PopsamError::ZeroNorm {
140 id: id.to_string(),
141 });
142 }
143 Ok(embedding.into_iter().map(|value| value / norm).collect())
144}
145
/// Decides how many candidates to eliminate this round.
///
/// With `active <= report_k` exactly one candidate goes. Otherwise the batch is
/// `ceil(active * elimination_fraction)`, raised to at least one and capped so
/// that no fewer than `report_k` candidates remain afterwards.
fn elimination_count(active: usize, report_k: usize, elimination_fraction: f32) -> usize {
    match active.checked_sub(report_k) {
        // Already at or below the reporting threshold: one at a time.
        None | Some(0) => 1,
        Some(headroom) => {
            let proportional = (active as f32 * elimination_fraction).ceil() as usize;
            // `headroom >= 1` in this arm, so the clamp bounds are well ordered.
            proportional.clamp(1, headroom)
        }
    }
}
154
155fn tally_votes(
156 embeddings: &[EmbeddedText],
157 active: &[usize],
158 rng: &mut ChaCha8Rng,
159) -> Vec<(usize, CandidateRoundVotes)> {
160 let mut candidate_positions = vec![None; embeddings.len()];
161 for (position, &candidate_idx) in active.iter().enumerate() {
162 candidate_positions[candidate_idx] = Some(position);
163 }
164
165 let mut tallies: Vec<CandidateRoundVotes> = active
166 .iter()
167 .map(|&idx| CandidateRoundVotes {
168 id: embeddings[idx].id.clone(),
169 first_votes: 0,
170 second_votes: 0,
171 third_votes: 0,
172 })
173 .collect();
174
175 for voter in embeddings {
176 let mut ranked = active
177 .iter()
178 .map(|&candidate_idx| {
179 (
180 candidate_idx,
181 cosine_similarity(&voter.embedding, &embeddings[candidate_idx].embedding),
182 rng.random::<u64>(),
183 )
184 })
185 .collect::<Vec<_>>();
186 ranked.sort_by(|left, right| compare_similarity(right, left));
187
188 for (rank, (candidate_idx, _, _)) in ranked.into_iter().take(3).enumerate() {
189 let tally_index = candidate_positions[candidate_idx].expect("candidate must exist");
190 match rank {
191 0 => tallies[tally_index].first_votes += 1,
192 1 => tallies[tally_index].second_votes += 1,
193 2 => tallies[tally_index].third_votes += 1,
194 _ => {}
195 }
196 }
197 }
198
199 active.iter().copied().zip(tallies).collect()
200}
201
202fn weakest_candidates(
203 vote_rows: &[(usize, CandidateRoundVotes)],
204 eliminate_count: usize,
205 rng: &mut ChaCha8Rng,
206) -> Vec<usize> {
207 let mut ranked = vote_rows
208 .iter()
209 .map(|(idx, votes)| (*idx, votes.clone(), rng.random::<u64>()))
210 .collect::<Vec<_>>();
211 ranked.sort_by(|left, right| {
212 left.1
213 .first_votes
214 .cmp(&right.1.first_votes)
215 .then(left.1.second_votes.cmp(&right.1.second_votes))
216 .then(left.1.third_votes.cmp(&right.1.third_votes))
217 .then(left.2.cmp(&right.2))
218 });
219 ranked
220 .into_iter()
221 .take(eliminate_count)
222 .map(|entry| entry.0)
223 .collect()
224}
225
226fn build_round_summary(
227 embeddings: &[EmbeddedText],
228 active: &[usize],
229 vote_rows: &[(usize, CandidateRoundVotes)],
230 eliminated: &[usize],
231 round_index: usize,
232) -> RoundSummary {
233 let mut votes = vote_rows.iter().map(|(_, votes)| votes.clone()).collect::<Vec<_>>();
234 votes.sort_by(|left, right| {
235 right
236 .first_votes
237 .cmp(&left.first_votes)
238 .then(right.second_votes.cmp(&left.second_votes))
239 .then(right.third_votes.cmp(&left.third_votes))
240 .then(left.id.cmp(&right.id))
241 });
242
243 RoundSummary {
244 round_index,
245 active_candidates: active.len(),
246 eliminated_candidate_ids: eliminated
247 .iter()
248 .map(|&idx| embeddings[idx].id.clone())
249 .collect(),
250 votes,
251 }
252}
253
/// Returns the dot product of `left` and `right`.
///
/// No normalization happens here, so the value is a cosine similarity only
/// when both inputs already have unit length. If the slices differ in length,
/// the extra trailing components of the longer one are ignored.
fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
    let products = left.iter().zip(right).map(|(&l, &r)| l * r);
    products.sum()
}
257
/// Orders two `(index, similarity, tie_break)` entries.
///
/// Similarities are compared with `f32::total_cmp`; when they are equal under
/// that ordering, the `u64` tie-break values decide. The leading index does
/// not take part in the comparison.
fn compare_similarity(
    left: &(usize, f32, u64),
    right: &(usize, f32, u64),
) -> Ordering {
    let (_, left_score, left_tie_break) = left;
    let (_, right_score, right_tie_break) = right;
    match left_score.total_cmp(right_score) {
        Ordering::Equal => left_tie_break.cmp(right_tie_break),
        decided => decided,
    }
}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text-less input with the given id and raw embedding.
    fn candidate(id: &str, embedding: &[f32]) -> EmbeddedTextInput {
        EmbeddedTextInput {
            id: id.into(),
            text: None,
            embedding: embedding.to_vec(),
        }
    }

    /// Four candidates with `report_last_k = 3`: the first round trims the
    /// field to three, after which rounds with 3, 2 and 1 active candidates
    /// are reported.
    #[test]
    fn keeps_last_k_rounds_and_winner() {
        let inputs = vec![
            candidate("a", &[1.0, 0.0]),
            candidate("b", &[0.9, 0.1]),
            candidate("c", &[0.8, 0.2]),
            candidate("d", &[0.0, 1.0]),
        ];
        let config = ElectionConfig {
            report_last_k: 3,
            elimination_fraction: 0.25,
            random_seed: 7,
        };

        let result = run_election(inputs, config).expect("election should succeed");

        assert_eq!(result.representative_ids.len(), 3);
        assert_eq!(result.rounds.len(), 3);
        assert_eq!(result.rounds[0].active_candidates, 3);
        assert_eq!(result.rounds[2].active_candidates, 1);
        assert_eq!(result.winner_id, result.all_ranked_ids[0]);
    }

    /// An input whose embedding length differs from the first input's is
    /// rejected.
    #[test]
    fn rejects_dimension_mismatch() {
        let inputs = vec![candidate("a", &[1.0, 0.0]), candidate("b", &[1.0])];

        let error = run_election(inputs, ElectionConfig::default()).expect_err("must fail");
        assert!(matches!(error, PopsamError::DimensionMismatch { .. }));
    }
}