mod util;
use std::{env, error::Error, process::exit, sync::Arc};
use console::style;
use fhe::{
bfv::{self, Ciphertext, Encoding, Plaintext, PublicKey, SecretKey},
mbfv::{AggregateIter, CommonRandomPoly, DecryptionShare, PublicKeyShare},
};
use fhe_traits::{FheDecoder, FheEncoder, FheEncrypter};
use rand::{
distr::{Distribution, Uniform},
rng,
};
use util::timeit::{timeit, timeit_n};
/// Prints the usage notice for this example and terminates the process.
///
/// The overview, usage line and argument constraints are always printed.
/// When `error` is `Some`, the message is printed after the notice and the
/// process exits with status 1 so that callers (e.g. shell scripts) can detect
/// the failure; with `None` (the `-h`/`--help` case) it exits with status 0.
///
/// This function never returns.
fn print_notice_and_exit(error: Option<String>) -> ! {
    println!(
        "{} Voting with fhe.rs",
        style(" overview:").magenta().bold()
    );
    println!(
        "{} voting [-h] [--help] [--num_voters=<value>] [--num_parties=<value>]",
        style(" usage:").magenta().bold()
    );
    println!(
        "{} {} and {} must be at least 1",
        style("constraints:").magenta().bold(),
        style("num_voters").blue(),
        style("num_parties").blue(),
    );
    // A reported error must surface as a non-zero exit status.
    let status = match error {
        Some(error) => {
            println!("{} {}", style(" error:").red().bold(), error);
            1
        }
        None => 0,
    };
    exit(status)
}
/// Parses the value of a `<flag>=<value>` command-line argument as a `usize`.
///
/// Returns `Some(value)` only when `arg` is exactly `flag`, immediately
/// followed by `=`, followed by text that parses as a `usize`; otherwise `None`.
fn parse_flag_value(arg: &str, flag: &str) -> Option<usize> {
    arg.strip_prefix(flag)?.strip_prefix('=')?.parse().ok()
}

/// Runs a multi-party BFV voting example.
///
/// Flow, as implemented below:
/// 1. parse `--num_voters` / `--num_parties` (defaults: 1000 and 10);
/// 2. build the BFV parameters and a common random polynomial;
/// 3. each party samples a secret-key share and derives a public-key share;
///    the public-key shares are aggregated into a single public key;
/// 4. each voter encrypts a random 0/1 vote under that public key;
/// 5. the vote ciphertexts are summed homomorphically;
/// 6. every party produces a decryption share of the sum, and the shares are
///    aggregated into the plaintext tally;
/// 7. the decrypted tally is checked against the plaintext sum of the votes.
///
/// # Errors
///
/// Propagates any error from parameter construction, key-share generation,
/// encoding/encryption, decryption-share generation, aggregation or decoding.
///
/// # Panics
///
/// Panics if the decrypted tally does not match the plaintext vote count.
fn main() -> Result<(), Box<dyn Error>> {
    let degree = 4096;
    let plaintext_modulus: u64 = 4096;
    let moduli = vec![0xffffee001, 0xffffc4001, 0x1ffffe0001];

    let args: Vec<String> = env::args().skip(1).collect();
    if args.iter().any(|arg| arg == "-h" || arg == "--help") {
        print_notice_and_exit(None)
    }

    let mut num_voters = 1000;
    let mut num_parties = 10;
    for arg in &args {
        if arg.starts_with("--num_voters") {
            match parse_flag_value(arg, "--num_voters") {
                Some(value) => num_voters = value,
                None => {
                    print_notice_and_exit(Some("Invalid `--num_voters` argument".to_string()))
                }
            }
        } else if arg.starts_with("--num_parties") {
            match parse_flag_value(arg, "--num_parties") {
                Some(value) => num_parties = value,
                None => {
                    print_notice_and_exit(Some("Invalid `--num_parties` argument".to_string()))
                }
            }
        } else {
            print_notice_and_exit(Some(format!("Unrecognized argument: {arg}")))
        }
    }

    if num_parties == 0 || num_voters == 0 {
        print_notice_and_exit(Some("Voter and party sizes must be nonzero".to_string()))
    }
    // The tally is read back from a plaintext whose coefficients live modulo
    // `plaintext_modulus`. With `plaintext_modulus` or more voters, the count
    // could wrap around and the final check would fail.
    if num_voters as u64 >= plaintext_modulus {
        print_notice_and_exit(Some(format!(
            "`--num_voters` must be less than the plaintext modulus ({plaintext_modulus})"
        )))
    }

    // `timeit_n!` is given its iteration count as a `u32`. Use checked
    // conversions rather than `as`, which would silently truncate.
    let voter_loops = u32::try_from(num_voters)?;
    let party_loops = u32::try_from(num_parties)?;

    println!("# Voting with fhe.rs");
    println!("\tnum_voters = {num_voters}");
    println!("\tnum_parties = {num_parties}");

    let params = timeit!(
        "Parameters generation",
        bfv::BfvParametersBuilder::new()
            .set_degree(degree)
            .set_plaintext_modulus(plaintext_modulus)
            .set_moduli(&moduli)
            .build_arc()?
    );
    let mut rng = rng();
    let crp = CommonRandomPoly::new(&params, &mut rng)?;

    // Each party keeps its secret-key share and publishes the matching
    // public-key share (derived from the common random polynomial).
    struct Party {
        sk_share: SecretKey,
        pk_share: PublicKeyShare,
    }
    let mut parties = Vec::with_capacity(num_parties);
    timeit_n!("Party setup (per party)", party_loops, {
        let sk_share = SecretKey::random(&params, &mut rng);
        let pk_share = PublicKeyShare::new(&sk_share, crp.clone(), &mut rng)?;
        parties.push(Party { sk_share, pk_share });
    });

    let pk = timeit!("Public key aggregation", {
        let pk: PublicKey = parties.iter().map(|p| p.pk_share.clone()).aggregate()?;
        pk
    });

    let dist = Uniform::new_inclusive(0u64, 1u64).expect("[0, 1] is a valid inclusive range");
    let votes: Vec<u64> = dist.sample_iter(&mut rng).take(num_voters).collect();

    // The body runs `voter_loops` times and pushes exactly one ciphertext per
    // run, so the number of ciphertexts produced so far is the index of the
    // next vote to encrypt.
    let mut votes_encrypted = Vec::with_capacity(num_voters);
    timeit_n!("Vote casting (per voter)", voter_loops, {
        let vote = votes[votes_encrypted.len()];
        let pt = Plaintext::try_encode(&[vote], Encoding::poly(), &params)?;
        let ct = pk.try_encrypt(&pt, &mut rng)?;
        votes_encrypted.push(ct);
    });

    let tally = timeit!("Vote tallying", {
        let mut sum = Ciphertext::zero(&params);
        for ct in &votes_encrypted {
            sum += ct;
        }
        Arc::new(sum)
    });

    // As above, the number of shares produced so far indexes the next party.
    let mut decryption_shares = Vec::with_capacity(num_parties);
    timeit_n!("Decryption (per party)", party_loops, {
        let party = &parties[decryption_shares.len()];
        let sh = DecryptionShare::new(&party.sk_share, &tally, &mut rng)?;
        decryption_shares.push(sh);
    });

    let tally_pt = timeit!("Decryption share aggregation", {
        let pt: Plaintext = decryption_shares.into_iter().aggregate()?;
        pt
    });
    let tally_vec = Vec::<u64>::try_decode(&tally_pt, Encoding::poly())?;
    let tally_result = tally_vec[0];
    println!("Vote result = {tally_result} / {num_voters}");

    let expected_tally: u64 = votes.iter().sum();
    assert_eq!(tally_result, expected_tally);
    Ok(())
}