use bls::{PublicKey, PublicKeySet, SecretKey, SecretKeyShare, Signature};
use std::collections::{BTreeMap, BTreeSet};
use crate::error::{Error, Result};
use crate::knowledge::{Knowledge, KnowledgeFault};
use crate::sdkg::{AckOutcome, Part, PartOutcome, SyncKeyGen};
use crate::vote::{DkgSignedVote, DkgVote, IdAck, IdPart, NodeId};
/// Per-node state of the distributed key generation (DKG) vote protocol.
///
/// Holds our identity and signing key, the key-generation engine, and every vote
/// (ours and other participants') recorded so far.
pub struct DkgState {
    /// Our own node identifier; used as the voter id of the votes we sign.
    id: NodeId,
    /// Our secret key; signs every vote we emit.
    secret_key: SecretKey,
    /// Public key of every participant, used to validate incoming votes.
    pub_keys: BTreeMap<NodeId, PublicKey>,
    /// The part generated for us when the state was created; sent as our first vote.
    our_part: Part,
    /// Key-generation engine fed with the parts and acks carried by the votes.
    keygen: SyncKeyGen<NodeId>,
    /// Every signed vote recorded so far (ours included).
    all_votes: BTreeSet<DkgSignedVote>,
    /// Set once a termination path has produced our secret key share.
    reached_termination: bool,
}
/// What the caller should do after a vote has been handled by `DkgState`.
pub enum VoteResponse {
    /// Nothing to do yet; more votes are needed before the protocol can progress.
    WaitingForMoreVotes,
    /// A vote produced and signed by this node that the caller should broadcast.
    BroadcastVote(Box<DkgSignedVote>),
    /// Votes that the handled vote depends on are missing locally (missing parts or
    /// acks); the caller should request them from peers.
    RequestAntiEntropy,
    /// The DKG completed: the shared public key set and our secret key share.
    DkgComplete(PublicKeySet, SecretKeyShare),
}
/// Snapshot of protocol progress derived from all recorded votes.
///
/// Variants are listed from the fault cases, through part collection and ack
/// collection, to final agreement.
enum DkgCurrentState {
    // --- inconsistent or incomplete knowledge ---
    /// Votes contradict each other (incompatible parts or acks).
    IncompatibleVotes,
    /// Some votes refer to parts we have not received.
    MissingParts,
    /// Some votes refer to acks we have not received.
    MissingAcks,
    // --- collecting parts ---
    /// Not all participants' parts are known yet.
    WaitingForMoreParts,
    /// Every participant's part is known and we can ack them.
    GotAllParts(BTreeSet<IdPart>),
    // --- collecting acks ---
    /// Acks exist but not yet from everybody.
    WaitingForMoreAcks(BTreeSet<IdPart>),
    /// All acks are known; we can cast our `AllAcks` vote.
    GotAllAcks(BTreeMap<IdPart, BTreeSet<IdAck>>),
    // --- agreeing on the final set of acks ---
    /// Some, but not all, participants agreed on the full set of acks.
    WaitingForTotalAgreement(BTreeMap<IdPart, BTreeSet<IdAck>>),
    /// Every participant agreed on the full set of acks.
    Termination(BTreeMap<IdPart, BTreeSet<IdAck>>),
}
impl DkgState {
    /// Creates the DKG state of node `our_id` and generates our own key-generation part.
    ///
    /// `pub_keys` is expected to hold every participant's public key (ours included);
    /// `threshold` is forwarded unchanged to `SyncKeyGen::new`.
    ///
    /// # Errors
    ///
    /// Propagates errors from `SyncKeyGen::new`, and returns `Error::NotInPubKeySet`
    /// when `SyncKeyGen::new` did not produce a part for us (presumably because
    /// `our_id` is not a key of `pub_keys`).
    pub fn new<R: bls::rand::RngCore>(
        our_id: NodeId,
        secret_key: SecretKey,
        pub_keys: BTreeMap<NodeId, PublicKey>,
        threshold: usize,
        mut rng: R,
    ) -> Result<Self> {
        let (sync_key_gen, opt_part) = SyncKeyGen::new(
            our_id,
            secret_key.clone(),
            pub_keys.clone(),
            threshold,
            &mut rng,
        )?;
        Ok(DkgState {
            id: our_id,
            secret_key,
            pub_keys,
            keygen: sync_key_gen,
            all_votes: BTreeSet::new(),
            our_part: opt_part.ok_or(Error::NotInPubKeySet)?,
            reached_termination: false,
        })
    }

    /// Our own node id.
    pub fn id(&self) -> NodeId {
        self.id
    }

    /// Signs our part as a `SinglePart` vote, records it, and returns it for broadcast.
    pub fn first_vote(&mut self) -> Result<DkgSignedVote> {
        let vote = DkgVote::SinglePart(self.our_part.clone());
        let signed_vote = self.signed_vote(vote)?;
        self.all_votes.insert(signed_vote.clone());
        Ok(signed_vote)
    }

    /// Validates `vote` against its voter's public key and returns the inner vote.
    ///
    /// # Errors
    ///
    /// `Error::UnknownSender` if the voter is not in `pub_keys`; otherwise whatever
    /// `DkgSignedVote::get_validated_vote` reports.
    fn get_validated_vote(&self, vote: &DkgSignedVote) -> Result<DkgVote> {
        let sender_id = vote.voter;
        let sender_pub_key = self.pub_keys.get(&sender_id).ok_or(Error::UnknownSender)?;
        let vote = vote.get_validated_vote(sender_pub_key)?;
        Ok(vote)
    }

    /// Re-validates every recorded vote, pairing each with its voter id.
    ///
    /// Fails on the first vote that does not validate. Note that this re-checks every
    /// vote on each call.
    fn all_checked_votes(&self) -> Result<Vec<(DkgVote, NodeId)>> {
        self.all_votes
            .iter()
            .map(|v| Ok((self.get_validated_vote(v)?, v.voter)))
            .collect()
    }

    /// Derives the protocol state from `votes`, checking the most advanced state first.
    fn current_dkg_state(&self, votes: Vec<(DkgVote, NodeId)>) -> DkgCurrentState {
        let knowledge = match Knowledge::from_votes(votes) {
            Err(KnowledgeFault::IncompatibleAcks) | Err(KnowledgeFault::IncompatibleParts) => {
                return DkgCurrentState::IncompatibleVotes;
            }
            Err(KnowledgeFault::MissingParts) => {
                return DkgCurrentState::MissingParts;
            }
            Err(KnowledgeFault::MissingAcks) => {
                return DkgCurrentState::MissingAcks;
            }
            Ok(k) => k,
        };
        let num_participants = self.pub_keys.len();
        // `agreed_with_all_acks` presumably holds the voters whose AllAcks vote matches
        // the known acks; agreement by every participant means termination.
        if knowledge.agreed_with_all_acks.len() == num_participants {
            DkgCurrentState::Termination(knowledge.part_acks)
        } else if !knowledge.agreed_with_all_acks.is_empty() {
            DkgCurrentState::WaitingForTotalAgreement(knowledge.part_acks)
        } else if knowledge.got_all_acks(num_participants) {
            DkgCurrentState::GotAllAcks(knowledge.part_acks)
        } else if !knowledge.part_acks.is_empty() {
            DkgCurrentState::WaitingForMoreAcks(knowledge.parts)
        } else if knowledge.parts.len() == num_participants {
            DkgCurrentState::GotAllParts(knowledge.parts)
        } else {
            DkgCurrentState::WaitingForMoreParts
        }
    }

    /// Whether we have already recorded an `AllAcks` vote cast by ourselves.
    fn we_sent_our_all_acks(&self) -> bool {
        let our_id = self.id();
        self.all_votes
            .iter()
            .any(|v| v.voter == our_id && v.is_all_acks())
    }

    /// Like `current_dkg_state`, adjusted for the kind of `vote` that was just handled.
    fn dkg_state_with_vote(
        &self,
        votes: Vec<(DkgVote, NodeId)>,
        vote: &DkgVote,
    ) -> DkgCurrentState {
        let dkg_state = self.current_dkg_state(votes);
        match dkg_state {
            // Others already sent acks and the vote just handled is a part (presumably
            // the one completing our set): treat it as GotAllParts so we emit our acks.
            DkgCurrentState::WaitingForMoreAcks(parts)
                if matches!(vote, DkgVote::SinglePart(_)) =>
            {
                DkgCurrentState::GotAllParts(parts)
            }
            // Others already agreed but we have not cast our own AllAcks vote yet.
            DkgCurrentState::WaitingForTotalAgreement(part_acks)
                if !self.we_sent_our_all_acks() =>
            {
                DkgCurrentState::GotAllAcks(part_acks)
            }
            // A newly received part/ack does not trigger anti-entropy; keep waiting.
            DkgCurrentState::MissingParts if matches!(vote, DkgVote::SinglePart(_)) => {
                DkgCurrentState::WaitingForMoreParts
            }
            DkgCurrentState::MissingAcks if matches!(vote, DkgVote::SingleAck(_)) => {
                DkgCurrentState::WaitingForMoreAcks(Default::default())
            }
            _ => dkg_state,
        }
    }

    /// Signs the bincode serialization of `vote` with our secret key.
    ///
    /// # Errors
    ///
    /// Fails if `vote` cannot be serialized.
    pub fn sign_vote(&self, vote: &DkgVote) -> Result<Signature> {
        let sig = self.secret_key.sign(bincode::serialize(vote)?);
        Ok(sig)
    }

    /// Wraps `vote` into a `DkgSignedVote` signed by us (not recorded).
    fn signed_vote(&self, vote: DkgVote) -> Result<DkgSignedVote> {
        let sig = self.sign_vote(&vote)?;
        Ok(DkgSignedVote::new(vote, self.id, sig))
    }

    /// Feeds every ack of `all_acks` into the key-generation engine.
    ///
    /// # Errors
    ///
    /// `Error::FaultyVote` on the first ack the engine reports as invalid, or any
    /// error returned by `handle_ack`.
    fn handle_all_acks(&mut self, all_acks: BTreeMap<IdPart, BTreeSet<IdAck>>) -> Result<()> {
        for ((part_id, _part), acks) in all_acks {
            for (sender_id, ack) in acks {
                // `acks` is owned, so each `ack` can be moved into the engine.
                let outcome = self.keygen.handle_ack(&sender_id, ack)?;
                if let AckOutcome::Invalid(fault) = outcome {
                    return Err(Error::FaultyVote(format!(
                        "Ack fault: {fault:?} by {sender_id:?} for part by {part_id:?}"
                    )));
                }
            }
        }
        Ok(())
    }

    /// Hands every part to the key-generation engine and collects our acks into a
    /// `SingleAck` vote.
    ///
    /// # Errors
    ///
    /// `Error::FaultyVote` if a part is invalid or the engine produces no ack for it.
    fn parts_into_acks<R: bls::rand::RngCore>(
        &mut self,
        parts: BTreeSet<IdPart>,
        mut rng: R,
    ) -> Result<DkgVote> {
        let mut acks = BTreeMap::new();
        for (sender_id, part) in parts {
            match self
                .keygen
                .handle_part(&sender_id, part.clone(), &mut rng)?
            {
                PartOutcome::Valid(Some(ack)) => {
                    acks.insert((sender_id, part), ack);
                }
                PartOutcome::Invalid(fault) => {
                    return Err(Error::FaultyVote(format!(
                        "Part fault: {fault:?} by {sender_id:?}"
                    )));
                }
                PartOutcome::Valid(None) => {
                    return Err(Error::FaultyVote("unexpected part outcome, node is faulty or keygen already handled this part".to_string()));
                }
            }
        }
        Ok(DkgVote::SingleAck(acks))
    }

    /// A copy of every vote recorded so far.
    pub fn all_votes(&self) -> Vec<DkgSignedVote> {
        self.all_votes.iter().cloned().collect()
    }

    /// Our key material if the recorded votes reached termination, `None` otherwise.
    ///
    /// This does not feed acks to the engine; it relies on them having been handled
    /// when termination was reached.
    pub fn outcome(&self) -> Result<Option<(PublicKeySet, SecretKeyShare)>> {
        let votes = self.all_checked_votes()?;
        if !matches!(self.current_dkg_state(votes), DkgCurrentState::Termination(_)) {
            return Ok(None);
        }
        let (pubs, opt_sec) = self.keygen.generate()?;
        Ok(opt_sec.map(|sec| (pubs, sec)))
    }

    /// Generates keys from the known acks without waiting for total agreement.
    ///
    /// `reached_termination` is only set when a secret key share was actually
    /// produced, matching `handle_signed_vote`.
    ///
    /// # Errors
    ///
    /// `Error::FailedForceGenerationBecauseMissingAcks` unless all acks are known;
    /// otherwise any error from handling the acks or generating the keys.
    pub fn force_termination(&mut self) -> Result<Option<(PublicKeySet, SecretKeyShare)>> {
        let votes = self.all_checked_votes()?;
        match self.current_dkg_state(votes) {
            DkgCurrentState::Termination(acks)
            | DkgCurrentState::WaitingForTotalAgreement(acks)
            | DkgCurrentState::GotAllAcks(acks) => {
                self.handle_all_acks(acks)?;
                let (pubs, opt_sec) = self.keygen.generate()?;
                let keys = opt_sec.map(|sec| (pubs, sec));
                if keys.is_some() {
                    self.reached_termination = true;
                }
                Ok(keys)
            }
            _ => Err(Error::FailedForceGenerationBecauseMissingAcks),
        }
    }

    /// Whether this node has produced its secret key share.
    pub fn reached_termination(&self) -> Result<bool> {
        Ok(self.reached_termination)
    }

    /// Validates and records `msg`, then returns what the caller should do next.
    ///
    /// Already-recorded votes yield an empty response. When progress lets us emit a
    /// vote of our own, it is returned as `BroadcastVote` and immediately handled
    /// recursively; the follow-up responses are appended (without `WaitingForMoreVotes`).
    ///
    /// # Errors
    ///
    /// Validation errors for `msg`, `Error::FaultyVote` on faulty or incompatible
    /// votes, and `Error::FailedToGenerateSecretKeyShare` if termination produced no
    /// secret key share.
    pub fn handle_signed_vote<R: bls::rand::RngCore>(
        &mut self,
        msg: DkgSignedVote,
        mut rng: R,
    ) -> Result<Vec<VoteResponse>> {
        if self.all_votes.contains(&msg) {
            return Ok(vec![]);
        }
        let last_vote = self.get_validated_vote(&msg)?;
        self.all_votes.insert(msg);
        let votes = self.all_checked_votes()?;
        match self.dkg_state_with_vote(votes, &last_vote) {
            DkgCurrentState::MissingParts | DkgCurrentState::MissingAcks => {
                Ok(vec![VoteResponse::RequestAntiEntropy])
            }
            DkgCurrentState::Termination(acks) => {
                self.handle_all_acks(acks)?;
                let (pubs, opt_sec) = self.keygen.generate()?;
                let sec = opt_sec.ok_or(Error::FailedToGenerateSecretKeyShare)?;
                self.reached_termination = true;
                Ok(vec![VoteResponse::DkgComplete(pubs, sec)])
            }
            DkgCurrentState::GotAllAcks(acks) => {
                self.broadcast_own_vote(DkgVote::AllAcks(acks), rng)
            }
            DkgCurrentState::GotAllParts(parts) => {
                let vote = self.parts_into_acks(parts, &mut rng)?;
                self.broadcast_own_vote(vote, rng)
            }
            DkgCurrentState::WaitingForMoreParts
            | DkgCurrentState::WaitingForMoreAcks(_)
            | DkgCurrentState::WaitingForTotalAgreement(_) => {
                Ok(vec![VoteResponse::WaitingForMoreVotes])
            }
            DkgCurrentState::IncompatibleVotes => {
                Err(Error::FaultyVote("got incompatible votes".to_string()))
            }
        }
    }

    /// Signs `vote`, returns it for broadcast, and handles it ourselves right away,
    /// appending the follow-up responses except `WaitingForMoreVotes`.
    fn broadcast_own_vote<R: bls::rand::RngCore>(
        &mut self,
        vote: DkgVote,
        rng: R,
    ) -> Result<Vec<VoteResponse>> {
        let signed_vote = self.signed_vote(vote)?;
        let mut responses = vec![VoteResponse::BroadcastVote(Box::new(signed_vote.clone()))];
        let follow_ups = self.handle_signed_vote(signed_vote, rng)?;
        responses.extend(
            follow_ups
                .into_iter()
                .filter(|r| !matches!(r, VoteResponse::WaitingForMoreVotes)),
        );
        Ok(responses)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{assert_match, sdkg::tests::verify_threshold, vote::test_utils::*};
    use bls::{
        rand::{rngs::StdRng, seq::IteratorRandom, thread_rng, Rng, RngCore, SeedableRng},
        SignatureShare,
    };
    use eyre::{eyre, Result};
    use std::env;

    /// Drives three nodes through parts, acks and all-acks.
    ///
    /// Checks that `force_termination` fails until all acks are known. Then checks that
    /// force-terminated nodes derive the same public key set, and that any 2 of the 3
    /// secret shares produce a valid combined signature.
    #[test]
    fn test_force_termination() {
        let mut rng = bls::rand::rngs::OsRng;
        let sec_key0: SecretKey = bls::rand::random();
        let sec_key1: SecretKey = bls::rand::random();
        let sec_key2: SecretKey = bls::rand::random();
        let pub_keys: BTreeMap<u8, PublicKey> = BTreeMap::from([
            (0, sec_key0.public_key()),
            (1, sec_key1.public_key()),
            (2, sec_key2.public_key()),
        ]);
        let threshold = 1;
        let mut dkg_state0 = DkgState::new(0, sec_key0, pub_keys.clone(), threshold, &mut rng)
            .expect("Failed to create DKG state");
        let mut dkg_state1 = DkgState::new(1, sec_key1, pub_keys.clone(), threshold, &mut rng)
            .expect("Failed to create DKG state");
        let mut dkg_state2 = DkgState::new(2, sec_key2, pub_keys, threshold, &mut rng)
            .expect("Failed to create DKG state");
        let part0 = dkg_state0.first_vote().expect("Failed to get first vote");
        let part1 = dkg_state1.first_vote().expect("Failed to get first vote");
        let part2 = dkg_state2.first_vote().expect("Failed to get first vote");
        let res = &dkg_state0.force_termination();
        assert!(matches!(
            res,
            Err(Error::FailedForceGenerationBecauseMissingAcks)
        ));
        let _ = &dkg_state0.handle_signed_vote(part1.clone(), &mut rng);
        let res = &dkg_state0.handle_signed_vote(part2.clone(), &mut rng);
        let acks0 =
            assert_match!(res.as_deref(), Ok([VoteResponse::BroadcastVote(acks)]) => *acks.clone());
        let _ = &dkg_state1.handle_signed_vote(part0.clone(), &mut rng);
        let res = &dkg_state1.handle_signed_vote(part2, &mut rng);
        let acks1 =
            assert_match!(res.as_deref(), Ok([VoteResponse::BroadcastVote(acks)]) => *acks.clone());
        let _ = &dkg_state2.handle_signed_vote(part0, &mut rng);
        let res = &dkg_state2.handle_signed_vote(part1, &mut rng);
        let acks2 =
            assert_match!(res.as_deref(), Ok([VoteResponse::BroadcastVote(acks)]) => *acks.clone());
        let res = &dkg_state0.force_termination();
        assert!(matches!(
            res,
            Err(Error::FailedForceGenerationBecauseMissingAcks)
        ));
        let _ = &dkg_state0.handle_signed_vote(acks1.clone(), &mut rng);
        let res = &dkg_state0.force_termination();
        assert!(matches!(
            res,
            Err(Error::FailedForceGenerationBecauseMissingAcks)
        ));
        let res = &dkg_state0.handle_signed_vote(acks2.clone(), &mut rng);
        let all_acks0 = assert_match!(res.as_deref(), Ok([VoteResponse::BroadcastVote(all_acks)]) => *all_acks.clone());
        let _ = &dkg_state1.handle_signed_vote(acks0.clone(), &mut rng);
        let res = &dkg_state1.handle_signed_vote(acks2, &mut rng);
        let all_acks1 = assert_match!(res.as_deref(), Ok([VoteResponse::BroadcastVote(all_acks)]) => *all_acks.clone());
        let _ = &dkg_state2.handle_signed_vote(acks0, &mut rng);
        let res = &dkg_state2.handle_signed_vote(acks1, &mut rng);
        let _all_acks2 = assert_match!(res.as_deref(), Ok([VoteResponse::BroadcastVote(all_acks)]) => *all_acks.clone());
        let res = &dkg_state0.force_termination();
        let (pubs0, sec0) = assert_match!(res, Ok(Some(keypair)) => keypair);
        let res = &dkg_state1.handle_signed_vote(all_acks0.clone(), &mut rng);
        assert!(matches!(
            res.as_deref(),
            Ok([VoteResponse::WaitingForMoreVotes])
        ));
        let res = &dkg_state1.force_termination();
        let (pubs1, sec1) = assert_match!(res, Ok(Some(keypair)) => keypair);
        let res = &dkg_state2.handle_signed_vote(all_acks0, &mut rng);
        assert!(matches!(
            res.as_deref(),
            Ok([VoteResponse::WaitingForMoreVotes])
        ));
        let _ = &dkg_state2.handle_signed_vote(all_acks1, &mut rng);
        let res = &dkg_state2.force_termination();
        let (pubs2, sec2) = assert_match!(res, Ok(Some(keypair)) => keypair);
        assert_eq!(pubs0, pubs1);
        assert_eq!(pubs1, pubs2);
        let msg = "signed message";
        let sig_shares: BTreeMap<usize, SignatureShare> =
            BTreeMap::from([(0, sec0.sign(msg)), (1, sec1.sign(msg))]);
        let sig = pubs2
            .combine_signatures(&sig_shares)
            .expect("Failed to combine signatures");
        assert!(pubs2.public_key().verify(&sig, msg));
        let sig_shares: BTreeMap<usize, SignatureShare> =
            BTreeMap::from([(1, sec1.sign(msg)), (2, sec2.sign(msg))]);
        let sig = pubs0
            .combine_signatures(&sig_shares)
            .expect("Failed to combine signatures");
        assert!(pubs0.public_key().verify(&sig, msg));
    }

    /// With a single participant, handling our own part cascades through the
    /// recursive handling of our own votes. It produces our acks, then our all-acks,
    /// then completion, in one call.
    #[test]
    fn test_recursive_handle_vote() {
        let mut rng = bls::rand::rngs::OsRng;
        let sec_key0: SecretKey = bls::rand::random();
        let pub_keys: BTreeMap<u8, PublicKey> = BTreeMap::from([(0, sec_key0.public_key())]);
        let threshold = 1;
        let mut dkg_state0 = DkgState::new(0, sec_key0, pub_keys, threshold, &mut rng)
            .expect("Failed to create DKG state");
        let part0 = dkg_state0.first_vote().expect("Failed to get first vote");
        // Forget the recorded part so that handling it is not treated as a duplicate.
        dkg_state0.all_votes = BTreeSet::new();
        let res = dkg_state0
            .handle_signed_vote(part0, &mut rng)
            .expect("failed to handle vote");
        assert!(matches!(res[0], VoteResponse::BroadcastVote(_)));
        assert!(matches!(res[1], VoteResponse::BroadcastVote(_)));
        assert!(matches!(res[2], VoteResponse::DkgComplete(_, _)));
        assert_eq!(res.len(), 3);
    }

    /// Replays randomized vote-delivery schedules against real `DkgState` nodes.
    ///
    /// For each run, checks every response against the mock model, then checks that
    /// all nodes derived one public key set and threshold-valid shares. The number of
    /// runs comes from `FUZZ_TEST_COUNT` (default 20). A negative count never reaches
    /// zero, so it effectively runs until failure or interruption.
    #[test]
    fn fuzz_test() -> Result<()> {
        let mut fuzz_count = if let Ok(count) = env::var("FUZZ_TEST_COUNT") {
            count.parse::<isize>().map_err(|err| eyre!("{err}"))?
        } else {
            20
        };
        let mut rng_for_seed = thread_rng();
        let num_nodes = 7;
        let threshold = 4;
        while fuzz_count != 0 {
            let seed = rng_for_seed.gen();
            // Everything below (keys, schedule, responses) derives from this seed.
            println!(" SEED {seed:?} => count_remaining: {fuzz_count}");
            let mut rng = StdRng::seed_from_u64(seed);
            let mut nodes = generate_nodes(num_nodes, threshold, &mut rng)?;
            let mut parts: BTreeMap<usize, DkgSignedVote> = BTreeMap::new();
            let mut acks: BTreeMap<usize, DkgSignedVote> = BTreeMap::new();
            let mut all_acks: BTreeMap<usize, DkgSignedVote> = BTreeMap::new();
            let mut sk_shares: BTreeMap<usize, SecretKeyShare> = BTreeMap::new();
            let mut pk_set: BTreeSet<PublicKeySet> = BTreeSet::new();
            for node in nodes.iter_mut() {
                parts.insert(node.id() as usize, node.first_vote()?);
            }
            for cmd in fuzz_commands(num_nodes, seed) {
                let (to_nodes, vote) = match cmd {
                    SendVote::Parts(from, to_nodes) => (to_nodes, parts[&from].clone()),
                    SendVote::Acks(from, to_nodes) => (to_nodes, acks[&from].clone()),
                    SendVote::AllAcks(from, to_nodes) => (to_nodes, all_acks[&from].clone()),
                };
                for (to, expt_resp) in to_nodes {
                    let actual_resp = nodes[to].handle_signed_vote(vote.clone(), &mut rng)?;
                    assert_eq!(expt_resp.len(), actual_resp.len());
                    expt_resp
                        .into_iter()
                        .zip(actual_resp.into_iter())
                        .for_each(|(exp, actual)| {
                            assert!(exp.match_resp(
                                actual,
                                &mut acks,
                                &mut all_acks,
                                &mut sk_shares,
                                &mut pk_set,
                                to
                            ));
                        })
                }
            }
            assert_eq!(pk_set.len(), 1);
            let pk_set = pk_set
                .into_iter()
                .next()
                .ok_or_else(|| eyre!("no public key set was produced"))?;
            let sk_shares: Vec<_> = sk_shares.into_iter().collect();
            assert!(verify_threshold(threshold, &sk_shares, &pk_set).is_ok());
            fuzz_count -= 1;
        }
        Ok(())
    }

    /// Builds a seed-deterministic schedule of vote deliveries until every mock node
    /// has handled every other node's parts, acks and all-acks.
    ///
    /// Each delivery is annotated with the responses the real `DkgState` is expected
    /// to return. Already-delivered votes may be resent with probability 1/5. The order
    /// of RNG calls defines the schedule, so it must not be reordered.
    fn fuzz_commands(num_nodes: usize, seed: u64) -> Vec<SendVote> {
        let mut rng = StdRng::seed_from_u64(seed);
        let mut nodes = MockNode::new(num_nodes);
        let resend_probability = Some((1, 5));
        let mut active_nodes = MockNode::active_nodes(&nodes);
        let mut commands = Vec::new();
        while !active_nodes.is_empty() {
            let current_node = active_nodes[rng.gen::<usize>() % active_nodes.len()];
            let parts = nodes[current_node].can_send_parts(&nodes, resend_probability, &mut rng);
            let acks = nodes[current_node].can_send_acks(&nodes, resend_probability, &mut rng);
            let all_acks =
                nodes[current_node].can_send_all_acks(&nodes, resend_probability, &mut rng);
            if parts.is_empty() && acks.is_empty() && all_acks.is_empty() {
                continue;
            }
            let mut done = false;
            while !done {
                match rng.gen::<usize>() % 3 {
                    0 if !parts.is_empty() => {
                        let to_nodes = MockNode::sample_nodes(&parts, &mut rng);
                        let to_nodes_resp = to_nodes
                            .into_iter()
                            .map(|to| {
                                let mut resp = Vec::new();
                                // Re-delivery of an already handled vote: no response.
                                if let Some(val) = nodes[to].handled_parts.get(&current_node) {
                                    if *val {
                                        return (to, resp);
                                    }
                                }
                                if let Some(val) =
                                    nodes[to].handled_parts.insert(current_node, true)
                                {
                                    if nodes[to].parts_done() {
                                        resp.push(MockVoteResponse::BroadcastVote(
                                            MockDkgVote::SingleAck,
                                        ));
                                        if nodes[to].acks_done() {
                                            resp.push(MockVoteResponse::BroadcastVote(
                                                MockDkgVote::AllAcks,
                                            ));
                                        }
                                    } else if !val {
                                        resp.push(MockVoteResponse::WaitingForMoreVotes)
                                    }
                                }
                                (to, resp)
                            })
                            .collect();
                        commands.push(SendVote::Parts(current_node, to_nodes_resp));
                        done = true;
                    }
                    1 if !acks.is_empty() => {
                        let to_nodes = MockNode::sample_nodes(&acks, &mut rng);
                        let to_nodes_resp = to_nodes
                            .into_iter()
                            .map(|to| {
                                let mut resp = Vec::new();
                                if let Some(val) = nodes[to].handled_acks.get(&current_node) {
                                    if *val {
                                        return (to, resp);
                                    }
                                }
                                let res = nodes[to].handled_acks.insert(current_node, true);
                                if !nodes[to].parts_done() {
                                    resp.push(MockVoteResponse::RequestAntiEntropy)
                                } else if let Some(val) = res {
                                    if nodes[to].acks_done() {
                                        resp.push(MockVoteResponse::BroadcastVote(
                                            MockDkgVote::AllAcks,
                                        ));
                                        if nodes[to].all_acks_done() {
                                            resp.push(MockVoteResponse::DkgComplete);
                                        }
                                    } else if !val {
                                        resp.push(MockVoteResponse::WaitingForMoreVotes)
                                    }
                                }
                                (to, resp)
                            })
                            .collect();
                        commands.push(SendVote::Acks(current_node, to_nodes_resp));
                        done = true
                    }
                    2 if !all_acks.is_empty() => {
                        let to_nodes = MockNode::sample_nodes(&all_acks, &mut rng);
                        let to_nodes_resp = to_nodes
                            .into_iter()
                            .map(|to| {
                                let mut resp = Vec::new();
                                if let Some(val) = nodes[to].handled_all_acks.get(&current_node)
                                {
                                    if *val {
                                        return (to, resp);
                                    }
                                }
                                let res = nodes[to].handled_all_acks.insert(current_node, true);
                                if !nodes[to].acks_done() {
                                    resp.push(MockVoteResponse::RequestAntiEntropy);
                                } else if let Some(val) = res {
                                    if nodes[to].all_acks_done() {
                                        resp.push(MockVoteResponse::DkgComplete)
                                    } else if !val {
                                        resp.push(MockVoteResponse::WaitingForMoreVotes)
                                    }
                                }
                                (to, resp)
                            })
                            .collect();
                        commands.push(SendVote::AllAcks(current_node, to_nodes_resp));
                        done = true;
                    }
                    _ => {}
                }
            }
            active_nodes = MockNode::active_nodes(&nodes);
        }
        commands
    }

    /// Creates `num_nodes` DKG participants with ids `0..num_nodes` and `threshold`.
    ///
    /// Secret keys are drawn from `rng` (not from global randomness) so that a seeded
    /// fuzz run is fully reproducible from its printed seed.
    fn generate_nodes<R: RngCore>(
        num_nodes: usize,
        threshold: usize,
        mut rng: &mut R,
    ) -> Result<Vec<DkgState>> {
        let secret_keys: Vec<SecretKey> = (0..num_nodes).map(|_| rng.gen()).collect();
        let pub_keys: BTreeMap<_, _> = secret_keys
            .iter()
            .enumerate()
            .map(|(id, sk)| (id as u8, sk.public_key()))
            .collect();
        secret_keys
            .iter()
            .enumerate()
            .map(|(id, sk)| {
                DkgState::new(id as u8, sk.clone(), pub_keys.clone(), threshold, &mut rng)
                    .map_err(|err| eyre!("{err}"))
            })
            .collect()
    }

    /// One scheduled delivery: the sender's vote of a given kind, and for each
    /// recipient the responses it is expected to return.
    #[derive(Debug)]
    enum SendVote {
        Parts(usize, Vec<(usize, Vec<MockVoteResponse>)>),
        Acks(usize, Vec<(usize, Vec<MockVoteResponse>)>),
        AllAcks(usize, Vec<(usize, Vec<MockVoteResponse>)>),
    }

    /// Expected counterpart of `VoteResponse`, without the cryptographic payloads.
    #[derive(Debug)]
    enum MockVoteResponse {
        WaitingForMoreVotes,
        BroadcastVote(MockDkgVote),
        RequestAntiEntropy,
        DkgComplete,
    }

    impl PartialEq<VoteResponse> for MockVoteResponse {
        fn eq(&self, other: &VoteResponse) -> bool {
            match self {
                MockVoteResponse::WaitingForMoreVotes
                    if matches!(other, VoteResponse::WaitingForMoreVotes) =>
                {
                    true
                }
                MockVoteResponse::BroadcastVote(mock_vote) => {
                    if let VoteResponse::BroadcastVote(signed_vote) = other {
                        *mock_vote == **signed_vote
                    } else {
                        false
                    }
                }
                MockVoteResponse::RequestAntiEntropy
                    if matches!(other, VoteResponse::RequestAntiEntropy) =>
                {
                    true
                }
                MockVoteResponse::DkgComplete
                    if matches!(other, VoteResponse::DkgComplete(_, _)) =>
                {
                    true
                }
                _ => false,
            }
        }
    }

    impl MockVoteResponse {
        /// Returns whether `actual_resp` matches this expectation. On a match, it
        /// records node `id`'s broadcast acks/all-acks vote or its DKG result in the
        /// given maps.
        pub fn match_resp(
            &self,
            actual_resp: VoteResponse,
            update_acks: &mut BTreeMap<usize, DkgSignedVote>,
            update_all_acks: &mut BTreeMap<usize, DkgSignedVote>,
            update_sk: &mut BTreeMap<usize, SecretKeyShare>,
            update_pk: &mut BTreeSet<PublicKeySet>,
            id: usize,
        ) -> bool {
            if *self == actual_resp {
                match actual_resp {
                    VoteResponse::BroadcastVote(vote) if MockDkgVote::SingleAck == *vote => {
                        update_acks.insert(id, *vote);
                    }
                    VoteResponse::BroadcastVote(vote) if MockDkgVote::AllAcks == *vote => {
                        update_all_acks.insert(id, *vote);
                    }
                    VoteResponse::DkgComplete(pk, sk) => {
                        update_pk.insert(pk);
                        update_sk.insert(id, sk);
                    }
                    _ => {}
                }
                true
            } else {
                false
            }
        }
    }

    /// Model of a node's progress: for each sender id, whether this node has handled
    /// that sender's part / acks / all-acks vote. A node's own entry starts as `true`.
    #[derive(Debug)]
    struct MockNode {
        id: usize,
        handled_parts: BTreeMap<usize, bool>,
        handled_acks: BTreeMap<usize, bool>,
        handled_all_acks: BTreeMap<usize, bool>,
    }

    impl MockNode {
        /// Creates `num_nodes` fresh mock nodes that have only handled their own votes.
        pub fn new(num_nodes: usize) -> Vec<MockNode> {
            let mut status: BTreeMap<usize, bool> = BTreeMap::new();
            (0..num_nodes).for_each(|id| {
                let _ = status.insert(id, false);
            });
            (0..num_nodes)
                .map(|id| {
                    let mut our_status = status.clone();
                    our_status.insert(id, true);
                    MockNode {
                        id,
                        handled_parts: our_status.clone(),
                        handled_acks: our_status.clone(),
                        handled_all_acks: our_status,
                    }
                })
                .collect()
        }

        /// Nodes that have not handled our part, plus each node that already has it
        /// with probability `resend_probability`.
        pub fn can_send_parts<R: RngCore>(
            &self,
            nodes: &[MockNode],
            resend_probability: Option<(u32, u32)>,
            rng: &mut R,
        ) -> Vec<usize> {
            nodes
                .iter()
                .filter_map(|node| {
                    if !node.handled_parts[&self.id] {
                        Some(node.id)
                    } else if let Some((num, den)) = resend_probability {
                        if rng.gen_ratio(num, den) {
                            Some(node.id)
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                })
                .collect()
        }

        /// Like `can_send_parts` for our acks; empty until we handled every part.
        pub fn can_send_acks<R: RngCore>(
            &self,
            nodes: &[MockNode],
            resend_probability: Option<(u32, u32)>,
            rng: &mut R,
        ) -> Vec<usize> {
            if !self.parts_done() {
                return Vec::new();
            }
            nodes
                .iter()
                .filter_map(|node| {
                    if !node.handled_acks[&self.id] {
                        Some(node.id)
                    } else if let Some((num, den)) = resend_probability {
                        if rng.gen_ratio(num, den) {
                            Some(node.id)
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                })
                .collect()
        }

        /// Like `can_send_parts` for our all-acks vote; empty until we handled every
        /// part and every ack.
        pub fn can_send_all_acks<R: RngCore>(
            &self,
            nodes: &[MockNode],
            resend_probability: Option<(u32, u32)>,
            rng: &mut R,
        ) -> Vec<usize> {
            if !self.parts_done() {
                return Vec::new();
            }
            if !self.acks_done() {
                return Vec::new();
            }
            nodes
                .iter()
                .filter_map(|node| {
                    if !node.handled_all_acks[&self.id] {
                        Some(node.id)
                    } else if let Some((num, den)) = resend_probability {
                        if rng.gen_ratio(num, den) {
                            Some(node.id)
                        } else {
                            None
                        }
                    } else {
                        None
                    }
                })
                .collect()
        }

        /// Whether every entry of `status` other than `own_id`'s is `true`.
        fn others_all_handled(status: &BTreeMap<usize, bool>, own_id: usize) -> bool {
            status
                .iter()
                .filter(|(&id, _)| id != own_id)
                .all(|(_, &val)| val)
        }

        /// Whether we handled every other node's part.
        pub fn parts_done(&self) -> bool {
            Self::others_all_handled(&self.handled_parts, self.id)
        }

        /// Whether we handled every other node's acks.
        pub fn acks_done(&self) -> bool {
            Self::others_all_handled(&self.handled_acks, self.id)
        }

        /// Whether we handled every other node's all-acks vote.
        pub fn all_acks_done(&self) -> bool {
            Self::others_all_handled(&self.handled_all_acks, self.id)
        }

        /// Ids (ascending) of senders that some other node has not yet handled a
        /// vote from.
        pub fn active_nodes(nodes: &[MockNode]) -> Vec<usize> {
            let mut active_nodes = BTreeSet::new();
            nodes.iter().for_each(|node| {
                node.handled_parts.iter().for_each(|(&id, &val)| {
                    if id != node.id && !val {
                        active_nodes.insert(id);
                    };
                });
                node.handled_acks.iter().for_each(|(&id, &val)| {
                    if id != node.id && !val {
                        active_nodes.insert(id);
                    };
                });
                node.handled_all_acks.iter().for_each(|(&id, &val)| {
                    if id != node.id && !val {
                        active_nodes.insert(id);
                    };
                });
            });
            active_nodes.into_iter().collect()
        }

        /// Picks a random non-empty subset of `nodes`. Panics if `nodes` is empty;
        /// callers only pass non-empty lists.
        pub fn sample_nodes<R: RngCore>(nodes: &[usize], rng: &mut R) -> Vec<usize> {
            let sample_n_nodes = (rng.gen::<usize>() % nodes.len()) + 1;
            nodes.iter().copied().choose_multiple(rng, sample_n_nodes)
        }
    }
}