#![deny(unused_must_use)]
use std::collections::{BTreeMap, BTreeSet};
use super::{dkg_threshold, KeyGen, PartOutcome};
use crate::mock::PeerId;
use crate::{
dev_utils::{Environment, RngChoice},
serialise,
};
use rand::Rng;
/// RNG choice handed to `Environment::new` by every key-generation test in this file.
/// NOTE(review): `RngChoice::Random` presumably picks a fresh seed on each run; to reproduce a
/// failing run, swap in a fixed-seed `RngChoice` variant (confirm which variants `dev_utils`
/// provides).
static SEED: RngChoice = RngChoice::Random;
/// Runs one distributed key generation among `node_num` peers and checks the resulting keys.
///
/// Only the parts proposed by the first `threshold + 1` peers are distributed. Only the acks
/// produced by the first `2 * threshold + 1` peers are delivered. The test asserts that no node
/// is ready while those acks are still arriving, and that every node is ready afterwards.
/// It then checks the generated keys:
/// - every node derives the same public key set;
/// - every secret key share signs a message that verifies against its public key share;
/// - two selections of `threshold + 1` shares combine into signatures with identical
///   serialised forms.
///
/// # Panics
///
/// Panics (failing the test) on any `KeyGen` error, on an invalid or ack-less part outcome, on
/// a readiness or signature assertion failure, or if `threshold >= node_num` (the proposal
/// slice `..=threshold` would be out of bounds).
fn test_key_gen_with(threshold: usize, node_num: usize) {
    let mut env = Environment::new(SEED);

    // One peer id per participant, plus the set of all of them given to every `KeyGen`.
    let peer_ids: Vec<PeerId> = (0..node_num)
        .map(|index| unwrap!(PeerId::from_index(index)))
        .collect();
    let all_peers: BTreeSet<PeerId> = peer_ids.iter().cloned().collect();

    // Build every `KeyGen` and its initial proposal. The shared RNG is drawn from in peer order.
    let mut nodes = Vec::with_capacity(node_num);
    let mut proposals = Vec::with_capacity(node_num);
    for peer_id in &peer_ids {
        let (key_gen, proposal) = KeyGen::new(peer_id, all_peers.clone(), threshold, &mut env.rng)
            .unwrap_or_else(|_err| panic!("Failed to create `KeyGen` instance {:?}", peer_id));
        nodes.push(key_gen);
        proposals.push(proposal);
    }

    // Send the first `threshold + 1` proposals to every node. Each node must accept each part
    // and answer with an ack. Only acks from nodes `0..ack_senders` are kept for delivery.
    let ack_senders = 2 * threshold + 1;
    let mut acks = Vec::new();
    for (sender_idx, proposal) in proposals[..=threshold].iter().enumerate() {
        for (receiver_idx, node) in nodes.iter_mut().enumerate() {
            let part = proposal.clone().expect("proposal");
            let outcome = node
                .handle_part(&peer_ids[receiver_idx], &peer_ids[sender_idx], part)
                .expect("failed to handle part");
            let ack = match outcome {
                PartOutcome::Valid(Some(ack)) => ack,
                PartOutcome::Valid(None) => panic!("missing ack message"),
                PartOutcome::Invalid(fault) => panic!("invalid proposal: {:?}", fault),
            };
            if receiver_idx < ack_senders {
                acks.push((receiver_idx, ack));
            }
        }
    }

    // Broadcast every kept ack. A node must still report "not ready" each time one of these
    // acks reaches it.
    for (sender_idx, ack) in acks {
        for (receiver_idx, node) in nodes.iter_mut().enumerate() {
            assert!(!node.is_ready());
            let _ = node
                .handle_ack(&peer_ids[receiver_idx], &peer_ids[sender_idx], ack.clone())
                .expect("error handling ack");
        }
    }

    // Node #0's public key set is the reference every other node must reproduce.
    let msg = "Help I'm trapped in a unit test factory";
    let pub_key_set = nodes[0]
        .generate()
        .expect("Failed to generate `PublicKeySet` for node #0")
        .1
        .public_key_set;

    // Each node signs `msg` with its secret key share. Each signature share must verify against
    // the matching public key share.
    let mut sig_shares = BTreeMap::new();
    for (idx, node) in nodes.iter().enumerate() {
        assert!(node.is_ready());
        let dkg_result = node
            .generate()
            .unwrap_or_else(|_| {
                panic!(
                    "Failed to generate `PublicKeySet` and `SecretKeyShare` for node #{}",
                    idx
                )
            })
            .1;
        let secret_share = dkg_result.secret_key_share.expect("new secret key");
        let share_pks = dkg_result.public_key_set;
        assert_eq!(share_pks, pub_key_set);
        let sig_share = secret_share.sign(msg);
        assert!(share_pks.public_key_share(idx).verify(&sig_share, msg));
        let _ = sig_shares.insert(idx, sig_share);
    }

    // Combine the lowest-indexed `threshold + 1` shares into a full signature.
    let sig = pub_key_set
        .combine_signatures(sig_shares.iter().take(threshold + 1))
        .expect("signature shares match");
    assert!(pub_key_set.public_key().verify(&sig, msg));

    // Combine a second window of `threshold + 1` shares. It starts at a random offset and wraps
    // around. The result must serialise identically to the first signature.
    let start = env.rng.gen_range(1, sig_shares.len().max(2));
    let rotated_shares = sig_shares.iter().cycle().skip(start).take(threshold + 1);
    let sig2 = pub_key_set
        .combine_signatures(rotated_shares)
        .expect("signature shares match");
    assert!(pub_key_set.public_key().verify(&sig2, msg));
    assert_eq!(serialise(&sig), serialise(&sig2));
}
/// Runs `test_key_gen_with` for `node_num` peers, using the threshold that `dkg_threshold`
/// returns for that peer count.
fn test_key_gen(node_num: usize) {
    let threshold = dkg_threshold(node_num);
    test_key_gen_with(threshold, node_num)
}
/// Expands each `name => node_num` pair into a `#[test] fn name()` that runs
/// `test_key_gen(node_num)`. This keeps the per-size test names visible to the test harness
/// without repeating the same three-line body.
macro_rules! key_gen_tests {
    ($($name:ident => $node_num:expr),* $(,)*) => {
        $(
            #[test]
            fn $name() {
                test_key_gen($node_num);
            }
        )*
    };
}

key_gen_tests! {
    test_key_gen_1 => 1,
    test_key_gen_2 => 2,
    test_key_gen_3 => 3,
    test_key_gen_4 => 4,
    test_key_gen_8 => 8,
    test_key_gen_15 => 15,
}