Skip to main content

popsam_core/
election.rs

1use crate::error::{PopsamError, PopsamResult};
2use crate::model::{
3    CandidateBestResult, CandidateRoundVotes, ElectionResult, EmbeddedText, EmbeddedTextInput,
4    RoundSummary,
5};
6use rand::{Rng, SeedableRng};
7use rand_chacha::ChaCha8Rng;
8use std::cmp::Ordering;
9use std::collections::HashMap;
10
/// Configuration for the representative-text election.
///
/// `run_election` validates these values before doing any work:
/// `report_last_k` must be non-zero and `elimination_fraction` must lie in
/// the half-open interval `(0.0, 1.0]`.
#[derive(Debug, Clone, PartialEq)]
pub struct ElectionConfig {
    /// Number of final candidates to keep and report.
    ///
    /// Values larger than the number of inputs are clamped to the input count.
    pub report_last_k: usize,
    /// Fraction of active candidates to eliminate in early rounds.
    ///
    /// Only applies while more than `report_last_k` candidates are active;
    /// after that, candidates are eliminated one per round.
    pub elimination_fraction: f32,
    /// Seed used for all random tie-breaks.
    pub random_seed: u64,
}
21
22impl Default for ElectionConfig {
23    fn default() -> Self {
24        Self {
25            report_last_k: 10,
26            elimination_fraction: 0.05,
27            random_seed: 42,
28        }
29    }
30}
31
32/// Runs the elimination process on a collection of embedded texts.
33///
34/// The function normalizes all embeddings, repeatedly assigns first/second/third
35/// preference votes by cosine similarity, and eliminates the weakest candidates
36/// until a single winner remains.
37///
38/// The returned [`ElectionResult`] includes:
39/// - the final winner
40/// - the last `k` surviving candidates
41/// - round-by-round vote totals for the reported suffix
42/// - each candidate's best result across the full election
43/// - normalized embeddings for all processed inputs
44pub fn run_election(
45    inputs: Vec<EmbeddedTextInput>,
46    config: ElectionConfig,
47) -> PopsamResult<ElectionResult> {
48    if inputs.is_empty() {
49        return Err(PopsamError::EmptyInput);
50    }
51    if config.report_last_k == 0 {
52        return Err(PopsamError::InvalidReportK);
53    }
54    if !(0.0 < config.elimination_fraction && config.elimination_fraction <= 1.0) {
55        return Err(PopsamError::InvalidEliminationFraction);
56    }
57
58    let dimension = inputs[0].embedding.len();
59    let mut embeddings = Vec::with_capacity(inputs.len());
60    for input in inputs {
61        if input.embedding.is_empty() {
62            return Err(PopsamError::EmptyEmbedding { id: input.id });
63        }
64        if input.embedding.len() != dimension {
65            return Err(PopsamError::DimensionMismatch {
66                id: input.id,
67                expected: dimension,
68                actual: input.embedding.len(),
69            });
70        }
71        let normalized = normalize(&input.id, input.embedding)?;
72        embeddings.push(EmbeddedText {
73            id: input.id,
74            text: input.text,
75            embedding: normalized,
76        });
77    }
78
79    let report_k = config.report_last_k.min(embeddings.len());
80    let mut rng = ChaCha8Rng::seed_from_u64(config.random_seed);
81    let mut active: Vec<usize> = (0..embeddings.len()).collect();
82    let mut eliminated_ranked = Vec::with_capacity(embeddings.len().saturating_sub(1));
83    let mut rounds = Vec::new();
84    let mut representative_ids = Vec::new();
85    let mut candidate_best_results = vec![None; embeddings.len()];
86    let mut full_round_index = 0;
87
88    while active.len() > 1 {
89        full_round_index += 1;
90        let vote_rows = tally_votes(&embeddings, &active, &mut rng);
91        update_best_results(
92            &mut candidate_best_results,
93            &active,
94            &vote_rows,
95            full_round_index,
96        );
97        let eliminate_count =
98            elimination_count(active.len(), report_k, config.elimination_fraction);
99        let weakest = weakest_candidates(&vote_rows, eliminate_count, &mut rng);
100
101        if active.len() <= report_k {
102            if representative_ids.is_empty() {
103                representative_ids = active
104                    .iter()
105                    .map(|&idx| embeddings[idx].id.clone())
106                    .collect();
107            }
108            rounds.push(build_round_summary(
109                &embeddings,
110                &active,
111                &vote_rows,
112                &weakest,
113                rounds.len() + 1,
114            ));
115        }
116
117        let mut removed_flags = vec![false; embeddings.len()];
118        for &idx in &weakest {
119            removed_flags[idx] = true;
120            eliminated_ranked.push(embeddings[idx].id.clone());
121        }
122        active.retain(|idx| !removed_flags[*idx]);
123    }
124
125    let winner_idx = active[0];
126    let winner_id = embeddings[winner_idx].id.clone();
127    if representative_ids.is_empty() {
128        representative_ids.push(winner_id.clone());
129    }
130    if report_k > 0 {
131        full_round_index += 1;
132        let final_votes = tally_votes(&embeddings, &active, &mut rng);
133        update_best_results(
134            &mut candidate_best_results,
135            &active,
136            &final_votes,
137            full_round_index,
138        );
139        rounds.push(build_round_summary(
140            &embeddings,
141            &active,
142            &final_votes,
143            &[],
144            rounds.len() + 1,
145        ));
146    }
147    let mut all_ranked_ids = vec![winner_id.clone()];
148    eliminated_ranked.reverse();
149    all_ranked_ids.extend(eliminated_ranked);
150    let candidate_best_results =
151        order_best_results(candidate_best_results, &embeddings, &all_ranked_ids);
152
153    Ok(ElectionResult {
154        winner_id,
155        representative_ids,
156        all_ranked_ids,
157        rounds,
158        candidate_best_results,
159        embeddings,
160    })
161}
162
163fn normalize(id: &str, embedding: Vec<f32>) -> PopsamResult<Vec<f32>> {
164    let norm = embedding
165        .iter()
166        .map(|value| value * value)
167        .sum::<f32>()
168        .sqrt();
169    if norm == 0.0 {
170        return Err(PopsamError::ZeroNorm { id: id.to_string() });
171    }
172    Ok(embedding.into_iter().map(|value| value / norm).collect())
173}
174
/// Decides how many candidates to drop in the current round.
///
/// While more than `report_k` candidates are active, the batch size is
/// `ceil(active * elimination_fraction)`, kept to at least 1 and to at most
/// `active - report_k` so the field never shrinks below `report_k` in one
/// step. Once `active <= report_k`, exactly one candidate is dropped per round.
fn elimination_count(active: usize, report_k: usize, elimination_fraction: f32) -> usize {
    match active.checked_sub(report_k) {
        // Already at or below the reported suffix: one at a time.
        None | Some(0) => 1,
        Some(headroom) => {
            let proposed = (active as f32 * elimination_fraction).ceil() as usize;
            proposed.clamp(1, headroom)
        }
    }
}
183
184fn tally_votes(
185    embeddings: &[EmbeddedText],
186    active: &[usize],
187    rng: &mut ChaCha8Rng,
188) -> Vec<(usize, CandidateRoundVotes)> {
189    let mut candidate_positions = vec![None; embeddings.len()];
190    for (position, &candidate_idx) in active.iter().enumerate() {
191        candidate_positions[candidate_idx] = Some(position);
192    }
193
194    let mut tallies: Vec<CandidateRoundVotes> = active
195        .iter()
196        .map(|&idx| CandidateRoundVotes {
197            id: embeddings[idx].id.clone(),
198            first_votes: 0,
199            second_votes: 0,
200            third_votes: 0,
201        })
202        .collect();
203
204    for voter in embeddings {
205        let mut ranked = active
206            .iter()
207            .map(|&candidate_idx| {
208                (
209                    candidate_idx,
210                    cosine_similarity(&voter.embedding, &embeddings[candidate_idx].embedding),
211                    rng.random::<u64>(),
212                )
213            })
214            .collect::<Vec<_>>();
215        ranked.sort_by(|left, right| compare_similarity(right, left));
216
217        for (rank, (candidate_idx, _, _)) in ranked.into_iter().take(3).enumerate() {
218            let tally_index = candidate_positions[candidate_idx].expect("candidate must exist");
219            match rank {
220                0 => tallies[tally_index].first_votes += 1,
221                1 => tallies[tally_index].second_votes += 1,
222                2 => tallies[tally_index].third_votes += 1,
223                _ => {}
224            }
225        }
226    }
227
228    active.iter().copied().zip(tallies).collect()
229}
230
231fn weakest_candidates(
232    vote_rows: &[(usize, CandidateRoundVotes)],
233    eliminate_count: usize,
234    rng: &mut ChaCha8Rng,
235) -> Vec<usize> {
236    let mut ranked = vote_rows
237        .iter()
238        .map(|(idx, votes)| (*idx, votes.clone(), rng.random::<u64>()))
239        .collect::<Vec<_>>();
240    ranked.sort_by(|left, right| {
241        left.1
242            .first_votes
243            .cmp(&right.1.first_votes)
244            .then(left.1.second_votes.cmp(&right.1.second_votes))
245            .then(left.1.third_votes.cmp(&right.1.third_votes))
246            .then(left.2.cmp(&right.2))
247    });
248    ranked
249        .into_iter()
250        .take(eliminate_count)
251        .map(|entry| entry.0)
252        .collect()
253}
254
255fn build_round_summary(
256    embeddings: &[EmbeddedText],
257    active: &[usize],
258    vote_rows: &[(usize, CandidateRoundVotes)],
259    eliminated: &[usize],
260    round_index: usize,
261) -> RoundSummary {
262    let votes = sorted_votes(vote_rows);
263
264    RoundSummary {
265        round_index,
266        active_candidates: active.len(),
267        eliminated_candidate_ids: eliminated
268            .iter()
269            .map(|&idx| embeddings[idx].id.clone())
270            .collect(),
271        votes,
272    }
273}
274
275fn update_best_results(
276    candidate_best_results: &mut [Option<CandidateBestResult>],
277    active: &[usize],
278    vote_rows: &[(usize, CandidateRoundVotes)],
279    round_index: usize,
280) {
281    let active_candidates = active.len();
282    for (rank, (candidate_idx, votes)) in sorted_vote_rows(vote_rows).into_iter().enumerate() {
283        let result = CandidateBestResult {
284            id: votes.id.clone(),
285            full_round_index: round_index,
286            active_candidates,
287            rank: rank + 1,
288            first_votes: votes.first_votes,
289            second_votes: votes.second_votes,
290            third_votes: votes.third_votes,
291        };
292
293        let best = &mut candidate_best_results[candidate_idx];
294        if best
295            .as_ref()
296            .is_none_or(|current| is_better_result(&result, current))
297        {
298            *best = Some(result);
299        }
300    }
301}
302
303fn order_best_results(
304    candidate_best_results: Vec<Option<CandidateBestResult>>,
305    embeddings: &[EmbeddedText],
306    all_ranked_ids: &[String],
307) -> Vec<CandidateBestResult> {
308    let mut index_by_id = HashMap::with_capacity(embeddings.len());
309    for (idx, embedding) in embeddings.iter().enumerate() {
310        index_by_id.insert(embedding.id.as_str(), idx);
311    }
312
313    all_ranked_ids
314        .iter()
315        .map(|id| {
316            let idx = index_by_id
317                .get(id.as_str())
318                .expect("ranked candidate must exist in embeddings");
319            candidate_best_results[*idx]
320                .clone()
321                .expect("every candidate is active in at least one round")
322        })
323        .collect()
324}
325
326fn sorted_votes(vote_rows: &[(usize, CandidateRoundVotes)]) -> Vec<CandidateRoundVotes> {
327    sorted_vote_rows(vote_rows)
328        .into_iter()
329        .map(|(_, votes)| votes)
330        .collect()
331}
332
333fn sorted_vote_rows(
334    vote_rows: &[(usize, CandidateRoundVotes)],
335) -> Vec<(usize, CandidateRoundVotes)> {
336    let mut votes = vote_rows
337        .iter()
338        .map(|(idx, votes)| (*idx, votes.clone()))
339        .collect::<Vec<_>>();
340    votes.sort_by(|left, right| compare_round_votes(&left.1, &right.1));
341    votes
342}
343
344fn compare_round_votes(left: &CandidateRoundVotes, right: &CandidateRoundVotes) -> Ordering {
345    right
346        .first_votes
347        .cmp(&left.first_votes)
348        .then(right.second_votes.cmp(&left.second_votes))
349        .then(right.third_votes.cmp(&left.third_votes))
350        .then(left.id.cmp(&right.id))
351}
352
353fn is_better_result(candidate: &CandidateBestResult, current: &CandidateBestResult) -> bool {
354    candidate
355        .rank
356        .cmp(&current.rank)
357        .reverse()
358        .then(candidate.first_votes.cmp(&current.first_votes))
359        .then(candidate.second_votes.cmp(&current.second_votes))
360        .then(candidate.third_votes.cmp(&current.third_votes))
361        .then(candidate.active_candidates.cmp(&current.active_candidates))
362        == Ordering::Greater
363}
364
/// Dot product of two vectors.
///
/// This equals the cosine similarity only when both inputs are unit-length,
/// which holds for the embeddings `run_election` normalizes before voting.
/// If the slices differ in length, the extra trailing elements are ignored.
fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
    let products = left.iter().zip(right).map(|(&a, &b)| a * b);
    products.sum::<f32>()
}
368
/// Orders two `(index, similarity, tiebreak)` entries ascending by similarity
/// (using the total order on `f32`), falling back to the tiebreak value when
/// the similarities are equal. The index component is not consulted.
fn compare_similarity(left: &(usize, f32, u64), right: &(usize, f32, u64)) -> Ordering {
    let (_, left_score, left_tiebreak) = left;
    let (_, right_score, right_tiebreak) = right;

    match left_score.total_cmp(right_score) {
        Ordering::Equal => left_tiebreak.cmp(right_tiebreak),
        decided => decided,
    }
}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a text-less input with the given id and raw embedding.
    fn input(id: &str, embedding: &[f32]) -> EmbeddedTextInput {
        EmbeddedTextInput {
            id: id.into(),
            text: None,
            embedding: embedding.to_vec(),
        }
    }

    /// With four inputs and `report_last_k = 3`, the reported suffix covers
    /// the rounds with 3, 2, and 1 active candidates.
    #[test]
    fn keeps_last_k_rounds_and_winner() {
        let inputs = vec![
            input("a", &[1.0, 0.0]),
            input("b", &[0.9, 0.1]),
            input("c", &[0.8, 0.2]),
            input("d", &[0.0, 1.0]),
        ];
        let config = ElectionConfig {
            report_last_k: 3,
            elimination_fraction: 0.25,
            random_seed: 7,
        };

        let result = run_election(inputs, config).expect("election should succeed");

        assert_eq!(result.representative_ids.len(), 3);
        assert_eq!(result.rounds.len(), 3);
        assert_eq!(result.rounds[0].active_candidates, 3);
        assert_eq!(result.rounds[2].active_candidates, 1);
        assert_eq!(result.winner_id, result.all_ranked_ids[0]);
        assert_eq!(result.candidate_best_results.len(), 4);

        let ranks_in_range = result
            .candidate_best_results
            .iter()
            .all(|best| (1..=best.active_candidates).contains(&best.rank));
        assert!(ranks_in_range);

        let winner_ranked_first = result
            .candidate_best_results
            .iter()
            .any(|best| best.id == result.winner_id && best.rank == 1);
        assert!(winner_ranked_first);
    }

    /// An embedding whose length differs from the first input's is rejected.
    #[test]
    fn rejects_dimension_mismatch() {
        let inputs = vec![input("a", &[1.0, 0.0]), input("b", &[1.0])];

        let error = run_election(inputs, ElectionConfig::default()).expect_err("must fail");

        assert!(matches!(error, PopsamError::DimensionMismatch { .. }));
    }

    /// Best results are tracked over every round, not only the reported ones,
    /// and are returned in final ranking order.
    #[test]
    fn tracks_best_result_for_each_candidate_across_full_election() {
        let inputs = vec![
            input("a", &[1.0, 0.0]),
            input("b", &[0.9, 0.1]),
            input("c", &[0.0, 1.0]),
            input("d", &[0.1, 0.9]),
            input("e", &[0.7, 0.3]),
        ];
        let config = ElectionConfig {
            report_last_k: 2,
            elimination_fraction: 0.4,
            random_seed: 11,
        };

        let result = run_election(inputs, config).expect("election should succeed");

        assert_eq!(result.candidate_best_results.len(), 5);

        let best_ids: Vec<&str> = result
            .candidate_best_results
            .iter()
            .map(|best| best.id.as_str())
            .collect();
        let ranked_ids: Vec<&str> = result.all_ranked_ids.iter().map(String::as_str).collect();
        assert_eq!(best_ids, ranked_ids);

        let first_reported = &result.rounds[0];
        let saw_unreported_round = result
            .candidate_best_results
            .iter()
            .any(|best| best.active_candidates > first_reported.active_candidates);
        assert!(
            first_reported.round_index < result.candidate_best_results[0].full_round_index
                || saw_unreported_round
        );

        assert!(result
            .candidate_best_results
            .iter()
            .all(|best| best.full_round_index >= 1 && best.active_candidates >= 1));
    }
}