Skip to main content

popsam_core/
election.rs

1use crate::error::{PopsamError, PopsamResult};
2use crate::model::{CandidateRoundVotes, EmbeddedText, EmbeddedTextInput, ElectionResult, RoundSummary};
3use rand::{Rng, SeedableRng};
4use rand_chacha::ChaCha8Rng;
5use std::cmp::Ordering;
6
/// Tuning knobs for [`run_election`].
///
/// `run_election` rejects a `report_last_k` of zero and an `elimination_fraction`
/// outside the half-open interval `(0, 1]`.
#[derive(Debug, Clone)]
pub struct ElectionConfig {
    /// How many of the final surviving candidates are reported, capped at the
    /// number of inputs. The same number of trailing rounds (ending with the
    /// final single-candidate round) receive a [`RoundSummary`].
    pub report_last_k: usize,
    /// Share of the active field removed per round while more than
    /// `report_last_k` candidates remain. The batch is rounded up, is at least
    /// one, and never cuts the field below `report_last_k`; afterwards exactly
    /// one candidate is removed per round.
    pub elimination_fraction: f32,
    /// Seed for the deterministic RNG behind every random tie-break.
    pub random_seed: u64,
}

impl Default for ElectionConfig {
    /// Reports the last 10 candidates, trims 5% of the field per early round,
    /// and seeds the tie-break RNG with 42.
    fn default() -> Self {
        ElectionConfig {
            report_last_k: 10,
            elimination_fraction: 0.05,
            random_seed: 42,
        }
    }
}
27
28/// Runs the elimination process on a collection of embedded texts.
29///
30/// The function normalizes all embeddings, repeatedly assigns first/second/third
31/// preference votes by cosine similarity, and eliminates the weakest candidates
32/// until a single winner remains.
33///
34/// The returned [`ElectionResult`] includes:
35/// - the final winner
36/// - the last `k` surviving candidates
37/// - round-by-round vote totals for the reported suffix
38/// - normalized embeddings for all processed inputs
39pub fn run_election(
40    inputs: Vec<EmbeddedTextInput>,
41    config: ElectionConfig,
42) -> PopsamResult<ElectionResult> {
43    if inputs.is_empty() {
44        return Err(PopsamError::EmptyInput);
45    }
46    if config.report_last_k == 0 {
47        return Err(PopsamError::InvalidReportK);
48    }
49    if !(0.0 < config.elimination_fraction && config.elimination_fraction <= 1.0) {
50        return Err(PopsamError::InvalidEliminationFraction);
51    }
52
53    let dimension = inputs[0].embedding.len();
54    let mut embeddings = Vec::with_capacity(inputs.len());
55    for input in inputs {
56        if input.embedding.is_empty() {
57            return Err(PopsamError::EmptyEmbedding { id: input.id });
58        }
59        if input.embedding.len() != dimension {
60            return Err(PopsamError::DimensionMismatch {
61                id: input.id,
62                expected: dimension,
63                actual: input.embedding.len(),
64            });
65        }
66        let normalized = normalize(&input.id, input.embedding)?;
67        embeddings.push(EmbeddedText {
68            id: input.id,
69            text: input.text,
70            embedding: normalized,
71        });
72    }
73
74    let report_k = config.report_last_k.min(embeddings.len());
75    let mut rng = ChaCha8Rng::seed_from_u64(config.random_seed);
76    let mut active: Vec<usize> = (0..embeddings.len()).collect();
77    let mut eliminated_ranked = Vec::with_capacity(embeddings.len().saturating_sub(1));
78    let mut rounds = Vec::new();
79    let mut representative_ids = Vec::new();
80
81    while active.len() > 1 {
82        let vote_rows = tally_votes(&embeddings, &active, &mut rng);
83        let eliminate_count = elimination_count(active.len(), report_k, config.elimination_fraction);
84        let weakest = weakest_candidates(&vote_rows, eliminate_count, &mut rng);
85
86        if active.len() <= report_k {
87            if representative_ids.is_empty() {
88                representative_ids = active
89                    .iter()
90                    .map(|&idx| embeddings[idx].id.clone())
91                    .collect();
92            }
93            rounds.push(build_round_summary(&embeddings, &active, &vote_rows, &weakest, rounds.len() + 1));
94        }
95
96        let mut removed_flags = vec![false; embeddings.len()];
97        for &idx in &weakest {
98            removed_flags[idx] = true;
99            eliminated_ranked.push(embeddings[idx].id.clone());
100        }
101        active.retain(|idx| !removed_flags[*idx]);
102    }
103
104    let winner_idx = active[0];
105    let winner_id = embeddings[winner_idx].id.clone();
106    if representative_ids.is_empty() {
107        representative_ids.push(winner_id.clone());
108    }
109    if report_k > 0 {
110        let final_votes = tally_votes(&embeddings, &active, &mut rng);
111        rounds.push(build_round_summary(
112            &embeddings,
113            &active,
114            &final_votes,
115            &[],
116            rounds.len() + 1,
117        ));
118    }
119    let mut all_ranked_ids = vec![winner_id.clone()];
120    eliminated_ranked.reverse();
121    all_ranked_ids.extend(eliminated_ranked);
122
123    Ok(ElectionResult {
124        winner_id,
125        representative_ids,
126        all_ranked_ids,
127        rounds,
128        embeddings,
129    })
130}
131
132fn normalize(id: &str, embedding: Vec<f32>) -> PopsamResult<Vec<f32>> {
133    let norm = embedding
134        .iter()
135        .map(|value| value * value)
136        .sum::<f32>()
137        .sqrt();
138    if norm == 0.0 {
139        return Err(PopsamError::ZeroNorm {
140            id: id.to_string(),
141        });
142    }
143    Ok(embedding.into_iter().map(|value| value / norm).collect())
144}
145
/// Decides how many candidates the current round removes.
///
/// Once the field has shrunk to `report_k` or fewer, exactly one candidate goes
/// per round. Before that, `ceil(active * elimination_fraction)` candidates go,
/// with at least one, but never so many that fewer than `report_k` would remain.
fn elimination_count(active: usize, report_k: usize, elimination_fraction: f32) -> usize {
    let surplus = active.saturating_sub(report_k);
    if surplus == 0 {
        return 1;
    }
    let proportional = ((active as f32) * elimination_fraction).ceil() as usize;
    proportional.clamp(1, surplus)
}
154
155fn tally_votes(
156    embeddings: &[EmbeddedText],
157    active: &[usize],
158    rng: &mut ChaCha8Rng,
159) -> Vec<(usize, CandidateRoundVotes)> {
160    let mut candidate_positions = vec![None; embeddings.len()];
161    for (position, &candidate_idx) in active.iter().enumerate() {
162        candidate_positions[candidate_idx] = Some(position);
163    }
164
165    let mut tallies: Vec<CandidateRoundVotes> = active
166        .iter()
167        .map(|&idx| CandidateRoundVotes {
168            id: embeddings[idx].id.clone(),
169            first_votes: 0,
170            second_votes: 0,
171            third_votes: 0,
172        })
173        .collect();
174
175    for voter in embeddings {
176        let mut ranked = active
177            .iter()
178            .map(|&candidate_idx| {
179                (
180                    candidate_idx,
181                    cosine_similarity(&voter.embedding, &embeddings[candidate_idx].embedding),
182                    rng.random::<u64>(),
183                )
184            })
185            .collect::<Vec<_>>();
186        ranked.sort_by(|left, right| compare_similarity(right, left));
187
188        for (rank, (candidate_idx, _, _)) in ranked.into_iter().take(3).enumerate() {
189            let tally_index = candidate_positions[candidate_idx].expect("candidate must exist");
190            match rank {
191                0 => tallies[tally_index].first_votes += 1,
192                1 => tallies[tally_index].second_votes += 1,
193                2 => tallies[tally_index].third_votes += 1,
194                _ => {}
195            }
196        }
197    }
198
199    active.iter().copied().zip(tallies).collect()
200}
201
202fn weakest_candidates(
203    vote_rows: &[(usize, CandidateRoundVotes)],
204    eliminate_count: usize,
205    rng: &mut ChaCha8Rng,
206) -> Vec<usize> {
207    let mut ranked = vote_rows
208        .iter()
209        .map(|(idx, votes)| (*idx, votes.clone(), rng.random::<u64>()))
210        .collect::<Vec<_>>();
211    ranked.sort_by(|left, right| {
212        left.1
213            .first_votes
214            .cmp(&right.1.first_votes)
215            .then(left.1.second_votes.cmp(&right.1.second_votes))
216            .then(left.1.third_votes.cmp(&right.1.third_votes))
217            .then(left.2.cmp(&right.2))
218    });
219    ranked
220        .into_iter()
221        .take(eliminate_count)
222        .map(|entry| entry.0)
223        .collect()
224}
225
226fn build_round_summary(
227    embeddings: &[EmbeddedText],
228    active: &[usize],
229    vote_rows: &[(usize, CandidateRoundVotes)],
230    eliminated: &[usize],
231    round_index: usize,
232) -> RoundSummary {
233    let mut votes = vote_rows.iter().map(|(_, votes)| votes.clone()).collect::<Vec<_>>();
234    votes.sort_by(|left, right| {
235        right
236            .first_votes
237            .cmp(&left.first_votes)
238            .then(right.second_votes.cmp(&left.second_votes))
239            .then(right.third_votes.cmp(&left.third_votes))
240            .then(left.id.cmp(&right.id))
241    });
242
243    RoundSummary {
244        round_index,
245        active_candidates: active.len(),
246        eliminated_candidate_ids: eliminated
247            .iter()
248            .map(|&idx| embeddings[idx].id.clone())
249            .collect(),
250        votes,
251    }
252}
253
/// Dot product of two embeddings. This equals their cosine similarity when both
/// are unit length, which `run_election` ensures by normalizing every input first.
///
/// Components beyond the end of the shorter slice are ignored.
fn cosine_similarity(left: &[f32], right: &[f32]) -> f32 {
    let products = left.iter().zip(right).map(|(&a, &b)| a * b);
    products.sum()
}
257
/// Orders two `(candidate, similarity, tie_break)` entries by similarity, and
/// falls back to the random tie-break key when the similarities compare equal.
///
/// Uses [`f32::total_cmp`], so NaN and signed zeros still yield a total order.
/// The candidate field is not consulted.
fn compare_similarity(
    left: &(usize, f32, u64),
    right: &(usize, f32, u64),
) -> Ordering {
    let &(_, left_similarity, left_tie) = left;
    let &(_, right_similarity, right_tie) = right;
    match left_similarity.total_cmp(&right_similarity) {
        Ordering::Equal => left_tie.cmp(&right_tie),
        decided => decided,
    }
}
266
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an input without source text.
    fn input(id: &'static str, embedding: Vec<f32>) -> EmbeddedTextInput {
        EmbeddedTextInput {
            id: id.into(),
            text: None,
            embedding,
        }
    }

    /// Three nearby vectors plus one outlier.
    fn clustered_inputs() -> Vec<EmbeddedTextInput> {
        vec![
            input("a", vec![1.0, 0.0]),
            input("b", vec![0.9, 0.1]),
            input("c", vec![0.8, 0.2]),
            input("d", vec![0.0, 1.0]),
        ]
    }

    #[test]
    fn keeps_last_k_rounds_and_winner() {
        let result = run_election(
            clustered_inputs(),
            ElectionConfig {
                report_last_k: 3,
                elimination_fraction: 0.25,
                random_seed: 7,
            },
        )
        .expect("election should succeed");

        assert_eq!(result.representative_ids.len(), 3);
        assert_eq!(result.rounds.len(), 3);
        assert_eq!(result.rounds[0].active_candidates, 3);
        assert_eq!(result.rounds[0].eliminated_candidate_ids.len(), 1);
        assert_eq!(result.rounds[1].active_candidates, 2);
        assert_eq!(result.rounds[2].active_candidates, 1);
        assert!(result.rounds[2].eliminated_candidate_ids.is_empty());
        let indices: Vec<usize> = result.rounds.iter().map(|round| round.round_index).collect();
        assert_eq!(indices, vec![1, 2, 3]);
        assert_eq!(result.winner_id, result.all_ranked_ids[0]);
        assert_eq!(result.all_ranked_ids.len(), 4);
    }

    #[test]
    fn rejects_dimension_mismatch() {
        let inputs = vec![input("a", vec![1.0, 0.0]), input("b", vec![1.0])];

        let error = run_election(inputs, ElectionConfig::default()).expect_err("must fail");
        assert!(matches!(error, PopsamError::DimensionMismatch { .. }));
    }

    #[test]
    fn single_input_wins_with_only_a_final_round() {
        let result = run_election(vec![input("solo", vec![0.5, 0.5])], ElectionConfig::default())
            .expect("a single candidate should win");

        assert_eq!(result.winner_id, result.embeddings[0].id);
        assert_eq!(result.all_ranked_ids, vec![result.winner_id.clone()]);
        assert_eq!(result.representative_ids, vec![result.winner_id.clone()]);
        assert_eq!(result.rounds.len(), 1);
        assert_eq!(result.rounds[0].active_candidates, 1);
        assert!(result.rounds[0].eliminated_candidate_ids.is_empty());
    }

    #[test]
    fn rejects_empty_input_before_checking_config() {
        let config = ElectionConfig {
            report_last_k: 0,
            ..ElectionConfig::default()
        };
        let error = run_election(Vec::new(), config).expect_err("must fail");
        assert!(matches!(error, PopsamError::EmptyInput));
    }

    #[test]
    fn rejects_zero_report_k() {
        let config = ElectionConfig {
            report_last_k: 0,
            ..ElectionConfig::default()
        };
        let error = run_election(clustered_inputs(), config).expect_err("must fail");
        assert!(matches!(error, PopsamError::InvalidReportK));
    }

    #[test]
    fn rejects_elimination_fraction_outside_unit_interval() {
        for fraction in [0.0, -0.25, 1.5, f32::NAN, f32::INFINITY] {
            let config = ElectionConfig {
                elimination_fraction: fraction,
                ..ElectionConfig::default()
            };
            let error = run_election(clustered_inputs(), config)
                .expect_err("fraction outside (0, 1] must be rejected");
            assert!(matches!(error, PopsamError::InvalidEliminationFraction));
        }
    }

    #[test]
    fn accepts_elimination_fraction_of_one() {
        let config = ElectionConfig {
            elimination_fraction: 1.0,
            ..ElectionConfig::default()
        };
        let result = run_election(clustered_inputs(), config).expect("1.0 is inside (0, 1]");
        assert_eq!(result.all_ranked_ids.len(), 4);
    }

    #[test]
    fn rejects_empty_and_zero_embeddings() {
        let empty = run_election(vec![input("e", vec![])], ElectionConfig::default())
            .expect_err("empty embedding must fail");
        assert!(matches!(empty, PopsamError::EmptyEmbedding { .. }));

        let zero = run_election(vec![input("z", vec![0.0, 0.0])], ElectionConfig::default())
            .expect_err("zero vector must fail");
        assert!(matches!(zero, PopsamError::ZeroNorm { .. }));
    }

    #[test]
    fn ranking_covers_every_input_and_is_reproducible() {
        let make_inputs = || {
            vec![
                input("a", vec![1.0, 0.0]),
                input("b", vec![0.9, 0.1]),
                input("c", vec![0.8, 0.2]),
                input("d", vec![0.0, 1.0]),
                input("e", vec![0.1, 0.9]),
            ]
        };
        let config = ElectionConfig {
            report_last_k: 2,
            elimination_fraction: 0.5,
            random_seed: 11,
        };
        let first = run_election(make_inputs(), config.clone()).expect("election should succeed");
        let second = run_election(make_inputs(), config).expect("election should succeed");

        assert_eq!(first.winner_id, second.winner_id);
        assert_eq!(first.all_ranked_ids, second.all_ranked_ids);
        // One round with the two representatives, then the final winner-only round.
        assert_eq!(first.rounds.len(), 2);

        let mut ranked = first.all_ranked_ids.clone();
        ranked.sort();
        let mut expected: Vec<_> = first.embeddings.iter().map(|text| text.id.clone()).collect();
        expected.sort();
        assert_eq!(ranked, expected);

        // The representatives are exactly the top two of the ranking.
        let mut representatives = first.representative_ids.clone();
        representatives.sort();
        let mut top = first.all_ranked_ids[..2].to_vec();
        top.sort();
        assert_eq!(representatives, top);
    }
}