use super::errors::TallyError;
use super::RankedCandidate;
use super::RankedWinners;
use hashbrown::HashMap;
use num_traits::cast::NumCast;
use num_traits::Num;
use petgraph::algo::tarjan_scc;
use petgraph::graph::NodeIndex;
use petgraph::Graph;
use std::convert::TryInto;
use std::hash::Hash;
use std::ops::AddAssign;
/// A [`CondorcetTally`] whose vote weights and totals are stored as `u64`.
pub type DefaultCondorcetTally<T> = CondorcetTally<T, u64>;
/// Pairwise-preference tally over candidates of type `T`, with vote weights and totals of
/// numeric type `C` (defaults to `u64`).
///
/// Candidates are mapped to internal numeric ids, and every counted vote updates a running
/// total for each ordered pair of candidate ids in which one candidate is ranked ahead of
/// the other.
pub struct CondorcetTally<T, C = u64>
where
T: Eq + Clone + Hash, C: Copy + PartialOrd + AddAssign + Num + NumCast, {
/// Accumulated weight per ordered pair of candidate ids `(a, b)`: the total weight of the
/// votes that gave `a` a strictly smaller rank number than `b`.
pub(crate) running_total: HashMap<(usize, usize), C>,
/// Number of winners handed to `RankedWinners::from_ranked` when winners are computed.
pub(crate) num_winners: usize,
/// Maps each registered candidate to its internal numeric id.
// `check_votes` (declared on the same line, private): when `true`, the vote-adding methods
// validate each vote before counting it.
pub(crate) candidates: HashMap<T, usize>, check_votes: bool,
}
impl<T, C> CondorcetTally<T, C>
where
T: Eq + Clone + Hash, C: Copy + PartialOrd + AddAssign + Num + NumCast, {
pub fn new(num_winners: usize) -> Self {
CondorcetTally {
running_total: HashMap::new(),
num_winners: num_winners,
candidates: HashMap::new(),
check_votes: true,
}
}
/// Creates a tally configured for `num_winners` winners and registers `candidates` up front.
///
/// The internal maps are pre-sized from the number of candidates supplied: one slot per
/// candidate, and one pairwise counter per ordered pair of distinct candidates. These are
/// capacity hints only; they do not limit how many candidates can be added later.
/// Vote checking is enabled.
pub fn with_candidates(num_winners: usize, candidates: Vec<T>) -> Self {
    let candidate_count = candidates.len();
    // n * (n - 1) ordered pairs. The previous `len() ^ 2` was a bitwise XOR, not a square,
    // and produced a meaningless capacity (e.g. 0 for two candidates).
    let pair_count = candidate_count.saturating_mul(candidate_count.saturating_sub(1));
    let mut tally = CondorcetTally {
        running_total: HashMap::with_capacity(pair_count),
        num_winners,
        candidates: HashMap::with_capacity(candidate_count),
        check_votes: true,
    };
    tally.add_candidates(candidates);
    tally
}
pub fn unchecked(mut self) -> Self {
self.check_votes = false;
self
}
/// Registers `candidate`, assigning it the next free internal id.
///
/// Adding a candidate that is already registered is a no-op: the candidate keeps its
/// existing id. Overwriting the id (as a plain `insert` would) orphans every pairwise
/// total recorded under the old id and lets the next new candidate receive an id that is
/// already in use.
pub fn add_candidate(&mut self, candidate: T) {
    // Ids are dense (0..len), so the current length is always the next unused id.
    let next_id = self.candidates.len();
    self.candidates.entry(candidate).or_insert(next_id);
}
/// Registers every candidate in `candidates`, in order, by calling `add_candidate` on each.
pub fn add_candidates(&mut self, candidates: Vec<T>) {
    candidates
        .into_iter()
        .for_each(|candidate| self.add_candidate(candidate));
}
/// Counts an ordered vote with a weight of one.
///
/// `vote` lists candidates from most preferred to least preferred. This is shorthand for
/// `add_weighted(vote, C::one())`.
///
/// # Errors
///
/// Returns whatever error `add_weighted` returns for this vote.
pub fn add(&mut self, vote: &[T]) -> Result<(), TallyError> {
    let unit_weight = C::one();
    self.add_weighted(vote, unit_weight)
}
/// Counts an ordered vote with the given `weight`.
///
/// `vote` lists candidates from most preferred to least preferred; each candidate's rank is
/// its position in the slice. Registered candidates missing from `vote` all share one rank
/// after every listed candidate.
///
/// # Errors
///
/// When vote checking is enabled, returns the error produced by `check_vote` and leaves
/// the tally unchanged.
pub fn add_weighted(&mut self, vote: &[T], weight: C) -> Result<(), TallyError> {
    let validation = if self.check_votes {
        self.check_vote(vote)
    } else {
        Ok(())
    };
    validation?;

    let ranked_ids = self.unranked_mapped_candidates(vote);
    self.add_ranked_candidate_ids(ranked_ids, weight);
    Ok(())
}
/// Counts an explicitly ranked vote with a weight of one.
///
/// Each entry of `vote` pairs a candidate with its rank. This is shorthand for
/// `ranked_add_weighted(vote, C::one())`.
///
/// # Errors
///
/// Returns whatever error `ranked_add_weighted` returns for this vote.
pub fn ranked_add(&mut self, vote: &[(T, u32)]) -> Result<(), TallyError> {
    let unit_weight = C::one();
    self.ranked_add_weighted(vote, unit_weight)
}
/// Counts an explicitly ranked vote with the given `weight`.
///
/// Each entry of `vote` pairs a candidate with its rank; a smaller rank number means the
/// candidate is preferred. Candidates may share a rank. Registered candidates missing from
/// `vote` are all placed one rank below the largest rank used in the vote.
///
/// # Errors
///
/// When vote checking is enabled, returns the error produced by `check_ranked_vote` and
/// leaves the tally unchanged.
pub fn ranked_add_weighted(&mut self, vote: &[(T, u32)], weight: C) -> Result<(), TallyError> {
    let validation = if self.check_votes {
        self.check_ranked_vote(vote)
    } else {
        Ok(())
    };
    validation?;

    let ranked_ids = self.ranked_mapped_candidates(vote);
    self.add_ranked_candidate_ids(ranked_ids, weight);
    Ok(())
}
/// Adds `weight` to the pairwise total of every pair of candidates that `selection` ranks
/// differently.
///
/// `selection` holds `(candidate id, rank)` entries. For each unordered pair of entries,
/// the candidate with the smaller rank number is credited under the key
/// `(preferred id, other id)`; pairs with equal ranks are left untouched.
fn add_ranked_candidate_ids(&mut self, selection: Vec<(usize, u32)>, weight: C) {
    for (position, &(first_id, first_rank)) in selection.iter().enumerate() {
        // Only look at later entries so that each unordered pair is visited exactly once.
        for &(second_id, second_rank) in &selection[position + 1..] {
            let preferred_pair = if first_rank < second_rank {
                (first_id, second_id)
            } else if second_rank < first_rank {
                (second_id, first_id)
            } else {
                continue;
            };
            *self
                .running_total
                .entry(preferred_pair)
                .or_insert_with(C::zero) += weight;
        }
    }
}
pub fn totals(&self) -> Vec<((T, T), C)> {
let mut totals = Vec::<((T, T), C)>::with_capacity(self.running_total.len());
let mut candidates = HashMap::<usize, T>::with_capacity(self.candidates.len());
for (candidate, i) in self.candidates.iter() {
candidates.insert(*i, candidate.clone());
}
for ((candidate1, candidate2), count) in self.running_total.iter() {
let candidate1 = candidates.get(candidate1).unwrap().clone();
let candidate2 = candidates.get(candidate2).unwrap().clone();
totals.push(((candidate1, candidate2), *count));
}
totals
}
/// Returns every candidate together with its rank in the tally.
///
/// The pairwise graph from `build_graph` is split into strongly connected components with
/// `tarjan_scc`. All candidates in one component share a rank, and the rank is the index of
/// that component in the order `tarjan_scc` returns them. petgraph documents that order as
/// a reverse topological sort; because graph edges point from the less-preferred candidate
/// to the more-preferred one, components that no outside candidate beats come first.
pub fn ranked(&self) -> Vec<RankedCandidate<T>> {
    let graph = self.build_graph();
    let components = tarjan_scc(&graph);

    // Candidates are read straight from the graph's node weights, so no separate
    // id -> candidate lookup table is needed.
    let mut ranked = Vec::<RankedCandidate<T>>::with_capacity(self.candidates.len());
    for (rank, component) in components.iter().enumerate() {
        ranked.extend(component.iter().map(|node| RankedCandidate {
            candidate: graph.node_weight(*node).unwrap().clone(),
            rank,
        }));
    }
    ranked
}
/// Computes the current ranking and wraps it in a `RankedWinners` built with
/// `RankedWinners::from_ranked` and the configured number of winners.
pub fn winners(&self) -> RankedWinners<T> {
    let ranking = self.ranked();
    RankedWinners::from_ranked(ranking, self.num_winners)
}
/// Builds the directed pairwise-preference graph for the current totals.
///
/// Every registered candidate becomes a node. For each recorded ordered pair `(a, b)` whose
/// total is at least as large as the total recorded for `(b, a)` (zero if none is
/// recorded), an edge is added from `b` to `a`, weighted with `(total for a over b, total
/// for b over a)`. When both directions have equal recorded totals, an edge is added each
/// way.
///
/// # Panics
///
/// Panics if a recorded total refers to a candidate id that is not registered.
pub fn build_graph(&self) -> Graph<T, (C, C)> {
    let node_count = self.candidates.len();
    // Each recorded pair contributes at most one edge, so the number of recorded pairs is
    // an exact upper bound. The previous `len() ^ 2` was a bitwise XOR, not a square.
    let max_edges = self.running_total.len();
    let mut graph = Graph::<T, (C, C)>::with_capacity(node_count, max_edges);

    // Internal candidate id -> node index in `graph`.
    let mut graph_ids = HashMap::<usize, NodeIndex>::with_capacity(node_count);
    for (candidate, candidate_id) in self.candidates.iter() {
        graph_ids.insert(*candidate_id, graph.add_node(candidate.clone()));
    }

    let zero = C::zero();
    for (&(candidate_1, candidate_2), &votecount_1) in self.running_total.iter() {
        let votecount_2 = *self
            .running_total
            .get(&(candidate_2, candidate_1))
            .unwrap_or(&zero);
        if votecount_1 >= votecount_2 {
            let candidate_1_id = *graph_ids.get(&candidate_1).unwrap();
            let candidate_2_id = *graph_ids.get(&candidate_2).unwrap();
            // Edges point from the less-preferred candidate to the more-preferred one.
            graph.add_edge(candidate_2_id, candidate_1_id, (votecount_1, votecount_2));
        }
    }
    graph
}
/// Returns a clone of every registered candidate.
///
/// The order follows the internal hash map and is therefore unspecified.
pub fn candidates(&self) -> Vec<T> {
    self.candidates.keys().cloned().collect()
}
/// Validates an ordered vote without counting it.
///
/// # Errors
///
/// Returns `TallyError::UnknownCandidate` if any candidate in `vote` is not registered.
/// Otherwise propagates any error from `crate::util::check_duplicates_transitive_vote`.
pub fn check_vote(&self, vote: &[T]) -> Result<(), TallyError> {
    let all_known = vote
        .iter()
        .all(|candidate| self.candidates.contains_key(candidate));
    if !all_known {
        return Err(TallyError::UnknownCandidate);
    }

    crate::util::check_duplicates_transitive_vote(vote)?;
    Ok(())
}
/// Validates an explicitly ranked vote without counting it.
///
/// # Errors
///
/// Returns `TallyError::UnknownCandidate` if any candidate in `vote` is not registered.
/// Otherwise propagates any error from `crate::util::check_duplicates_ranked_vote`.
pub fn check_ranked_vote(&self, vote: &[(T, u32)]) -> Result<(), TallyError> {
    let has_unknown = vote
        .iter()
        .any(|(candidate, _rank)| !self.candidates.contains_key(candidate));
    if has_unknown {
        return Err(TallyError::UnknownCandidate);
    }

    crate::util::check_duplicates_ranked_vote(vote)?;
    Ok(())
}
/// Converts an ordered vote into `(candidate id, rank)` entries covering every registered
/// candidate.
///
/// A candidate's rank is the position of its first occurrence in `selection`. Registered
/// candidates that do not appear in `selection` all receive the rank `selection.len()`,
/// placing them after every listed candidate. Entries of `selection` that are not
/// registered candidates are ignored.
///
/// Takes `&self` because the tally is only read.
///
/// # Panics
///
/// Panics if a rank does not fit in a `u32`, i.e. the vote has more than `u32::MAX` entries.
fn unranked_mapped_candidates(&self, selection: &[T]) -> Vec<(usize, u32)> {
    let unlisted_rank = selection.len();
    self.candidates
        .iter()
        .map(|(candidate, candidate_id)| {
            let position = selection
                .iter()
                .position(|listed| listed == candidate)
                .unwrap_or(unlisted_rank);
            let rank: u32 = position
                .try_into()
                .expect("a vote must not list more than u32::MAX candidates");
            (*candidate_id, rank)
        })
        .collect()
}
fn ranked_mapped_candidates(&mut self, selection: &[(T, u32)]) -> Vec<(usize, u32)> {
let mut mapped = Vec::<(usize, u32)>::new();
let mut trailing_candidates = Vec::<usize>::new();
let mut max_rank = 0;
for (candidate, candidate_id) in self.candidates.iter() {
let ranked_candidate = selection.iter().find(|ref r| &(r.0) == candidate);
match ranked_candidate {
Some(rc) => {
max_rank = std::cmp::max(max_rank, rc.1);
mapped.push((*candidate_id, rc.1));
}
None => trailing_candidates.push(*candidate_id),
};
}
for candidate_id in trailing_candidates {
mapped.push((candidate_id, max_rank + 1));
}
mapped
}
}
#[cfg(test)]
mod tests {
use super::*;
use maplit::hashset;
use std::collections::HashSet;
use std::iter::FromIterator;
#[test]
fn condorcet_basic() -> Result<(), TallyError> {
    // Three identical votes: the pairwise totals and the winners follow the vote order.
    let mut tally = DefaultCondorcetTally::with_candidates(2, vec!["Alice", "Bob", "Carol"]);
    for _ in 0..3 {
        tally.add(&["Alice", "Bob", "Carol"])?;
    }

    let totals = HashSet::from_iter(tally.totals());
    let expected_totals = hashset![
        (("Alice", "Bob"), 3),
        (("Bob", "Carol"), 3),
        (("Alice", "Carol"), 3)
    ];
    assert_eq!(totals, expected_totals);
    assert_eq!(tally.winners().into_vec(), vec![("Alice", 0), ("Bob", 1)]);

    // A three-way preference cycle: every candidate ends up tied at rank 0.
    let mut tally = DefaultCondorcetTally::with_candidates(2, vec!["Alice", "Bob", "Carol"]);
    let cyclic_votes = [
        ["Alice", "Bob", "Carol"],
        ["Bob", "Carol", "Alice"],
        ["Carol", "Alice", "Bob"],
    ];
    for vote in cyclic_votes.iter() {
        tally.add(vote)?;
    }

    let winners = tally.winners();
    assert!(!winners.is_empty());
    assert!(winners.check_overflow());
    assert_eq!(winners.all().len(), 3);
    assert_eq!(winners.overflow().unwrap().len(), 3);
    for name in ["Alice", "Bob", "Carol"].iter() {
        assert_eq!(winners.rank(name).unwrap(), 0);
    }
    Ok(())
}
#[test]
fn condorcet_wikipedia() -> Result<(), TallyError> {
    // Four cities, four weighted votes.
    let cities = vec!["Memphis", "Nashville", "Chattanooga", "Knoxville"];
    let mut tally = DefaultCondorcetTally::with_candidates(4, cities);

    let weighted_votes: [([&str; 4], u64); 4] = [
        (["Memphis", "Nashville", "Chattanooga", "Knoxville"], 42),
        (["Nashville", "Chattanooga", "Knoxville", "Memphis"], 26),
        (["Chattanooga", "Knoxville", "Nashville", "Memphis"], 15),
        (["Knoxville", "Chattanooga", "Nashville", "Memphis"], 17),
    ];
    for (vote, weight) in weighted_votes.iter() {
        tally.add_weighted(vote, *weight)?;
    }

    let candidates = HashSet::from_iter(tally.candidates());
    assert_eq!(candidates, hashset!["Memphis", "Nashville", "Chattanooga", "Knoxville"]);

    let totals = HashSet::from_iter(tally.totals());
    let expected_totals = hashset![
        (("Memphis", "Nashville"), 42),
        (("Nashville", "Memphis"), 58),
        (("Memphis", "Chattanooga"), 42),
        (("Chattanooga", "Memphis"), 58),
        (("Memphis", "Knoxville"), 42),
        (("Knoxville", "Memphis"), 58),
        (("Nashville", "Chattanooga"), 68),
        (("Chattanooga", "Nashville"), 32),
        (("Nashville", "Knoxville"), 68),
        (("Knoxville", "Nashville"), 32),
        (("Chattanooga", "Knoxville"), 83),
        (("Knoxville", "Chattanooga"), 17),
    ];
    assert_eq!(totals, expected_totals);

    let expected_winners = vec![
        ("Nashville", 0),
        ("Chattanooga", 1),
        ("Knoxville", 2),
        ("Memphis", 3),
    ];
    assert_eq!(tally.winners().into_vec(), expected_winners);
    Ok(())
}
#[test]
fn condorcet_graph() -> Result<(), TallyError> {
    let mut tally = DefaultCondorcetTally::with_candidates(1, vec!["a", "b", "c", "d"]);
    let weighted_votes: [([&str; 4], u64); 5] = [
        (["a", "c", "d", "b"], 8),
        (["b", "a", "d", "c"], 2),
        (["c", "d", "b", "a"], 4),
        (["d", "b", "a", "c"], 4),
        (["d", "c", "b", "a"], 3),
    ];
    for (vote, weight) in weighted_votes.iter() {
        tally.add_weighted(vote, *weight)?;
    }

    let graph = tally.build_graph();
    assert_eq!(graph.node_count(), 4);
    assert_eq!(graph.edge_count(), 6);

    // Every outgoing edge of a node must carry one of the pairwise results that involve
    // that node's candidate.
    for index in graph.node_indices() {
        let candidate = *graph.node_weight(index).unwrap();
        for edge in graph.edges(index).map(|e| e.weight()) {
            let allowed: [(u64, u64); 3] = match candidate {
                "a" => [(13, 8), (11, 10), (14, 7)],
                "b" => [(13, 8), (15, 6), (19, 2)],
                "c" => [(12, 9), (15, 6), (14, 7)],
                "d" => [(12, 9), (11, 10), (19, 2)],
                _ => panic!("Invalid candidate"),
            };
            assert!(allowed.contains(edge));
        }
    }
    Ok(())
}
}