use entropy_protocol::{KeyParams, PartyId, SessionId, SigningSessionInfo, ValidatorInfo};
use futures::future;
use rand_core::OsRng;
use serial_test::serial;
use sp_core::{sr25519, Pair};
use std::{cmp::min, time::Instant};
use subxt::utils::AccountId32;
use synedrion::{ecdsa::VerifyingKey, AuxInfo, KeyShare, ThresholdKeyShare};
use tokio::{net::TcpListener, runtime::Runtime, sync::oneshot};
use x25519_dalek::StaticSecret;
mod helpers;
use helpers::{server, ProtocolOutput};
use std::collections::BTreeSet;
/// Upper bound applied to `num_cpus::get()` when choosing how many parties and tokio worker
/// threads the timed tests below use.
const MAX_THREADS: usize = 16;
/// Times an n-of-n signing session with one party per available CPU (capped at
/// `MAX_THREADS`), on a runtime with one worker thread per party.
#[test]
#[serial]
fn sign_protocol_with_time_logged() {
    let party_count = min(MAX_THREADS, num_cpus::get());
    let runtime = get_tokio_runtime(party_count);
    runtime.block_on(test_sign_with_parties(party_count));
}
/// Times a reshare (key refresh) session with one party per available CPU (capped at
/// `MAX_THREADS`), on a runtime with one worker thread per party.
#[test]
#[serial]
fn refresh_protocol_with_time_logged() {
    let party_count = min(MAX_THREADS, num_cpus::get());
    let runtime = get_tokio_runtime(party_count);
    runtime.block_on(test_refresh_with_parties(party_count));
}
/// Times a DKG session with one party per available CPU (capped at `MAX_THREADS`), on a
/// runtime with one worker thread per party.
#[test]
#[serial]
fn dkg_protocol_with_time_logged() {
    let party_count = min(MAX_THREADS, num_cpus::get());
    let runtime = get_tokio_runtime(party_count);
    runtime.block_on(test_dkg_with_parties(party_count));
}
/// Runs a threshold DKG with three parties followed by signing with a threshold-sized
/// committee. The runtime gets one worker thread per available CPU (capped at `MAX_THREADS`).
#[test]
#[serial]
fn t_of_n_dkg_and_sign() {
    const PARTY_COUNT: usize = 3;
    let worker_threads = min(MAX_THREADS, num_cpus::get());
    get_tokio_runtime(worker_threads).block_on(test_dkg_and_sign_with_parties(PARTY_COUNT));
}
/// Runs an n-of-n signing session over centrally generated key shares and aux info.
///
/// Every party's output must be a `ProtocolOutput::Sign` whose recoverable signature over
/// `message_hash` recovers the shared verifying key.
///
/// # Panics
///
/// Panics if `num_parties` is 0, if any party returns a non-signing output, or if a
/// recovered key does not match the expected verifying key.
async fn test_sign_with_parties(num_parties: usize) {
    let (pairs, ids) = get_keypairs_and_ids(num_parties);
    let keyshares = KeyShare::<KeyParams, PartyId>::new_centralized(&mut OsRng, &ids, None);
    let aux_infos = AuxInfo::<KeyParams, PartyId>::new_centralized(&mut OsRng, &ids);
    let verifying_key = keyshares[&PartyId::from(pairs[0].public())].verifying_key().unwrap();
    let parties: Vec<_> = pairs
        .iter()
        .map(|pair| {
            let id = PartyId::from(pair.public());
            ValidatorSecretInfo {
                pair: pair.clone(),
                keyshare: Some(keyshares[&id].clone()),
                threshold_keyshare: None,
                aux_info: Some(aux_infos[&id].clone()),
            }
        })
        .collect();
    let message_hash = [0u8; 32];
    let session_id = SessionId::Sign(SigningSessionInfo {
        signature_verifying_key: verifying_key.to_encoded_point(true).as_bytes().to_vec(),
        message_hash,
        request_author: AccountId32([0u8; 32]),
    });
    // n-of-n: every party takes part in signing.
    let threshold = parties.len();
    let outputs = test_protocol_with_parties(parties, session_id, threshold).await;
    assert!(!outputs.is_empty(), "signing protocol produced no outputs");
    // Verify every party's signature, not just one of them.
    for output in outputs {
        let ProtocolOutput::Sign(recoverable_signature) = output else {
            panic!("Unexpected protocol output");
        };
        let recovery_key_from_sig = VerifyingKey::recover_from_prehash(
            &message_hash,
            &recoverable_signature.signature,
            recoverable_signature.recovery_id,
        )
        .unwrap();
        assert_eq!(verifying_key, recovery_key_from_sig);
    }
}
/// Runs a reshare (key refresh) session in which every party starts from a threshold key
/// share converted from a centrally generated `KeyShare`.
///
/// Every party's output must be a `ProtocolOutput::Reshare` whose key share still has the
/// original verifying key.
///
/// # Panics
///
/// Panics if `num_parties` is 0, if any party returns a non-reshare output, or if a
/// reshared key share's verifying key differs from the original.
async fn test_refresh_with_parties(num_parties: usize) {
    let (pairs, ids) = get_keypairs_and_ids(num_parties);
    let keyshares = KeyShare::<KeyParams, PartyId>::new_centralized(&mut OsRng, &ids, None);
    let verifying_key = keyshares[&PartyId::from(pairs[0].public())].verifying_key().unwrap();
    let session_id = SessionId::Reshare {
        verifying_key: verifying_key.to_encoded_point(true).as_bytes().to_vec(),
        block_number: 0,
    };
    let parties: Vec<_> = pairs
        .iter()
        .map(|pair| ValidatorSecretInfo {
            pair: pair.clone(),
            keyshare: None,
            threshold_keyshare: Some(ThresholdKeyShare::from_key_share(
                &keyshares[&PartyId::from(pair.public())],
            )),
            aux_info: None,
        })
        .collect();
    let threshold = parties.len();
    let outputs = test_protocol_with_parties(parties, session_id, threshold).await;
    assert!(!outputs.is_empty(), "reshare protocol produced no outputs");
    // Every party must end up with a share of the same key, not just one of them.
    for output in outputs {
        let ProtocolOutput::Reshare(keyshare) = output else {
            panic!("Unexpected protocol output");
        };
        assert!(
            keyshare.verifying_key() == verifying_key,
            "reshared key share does not match the original verifying key"
        );
    }
}
/// Runs an n-of-n DKG session in which every party starts with only its key pair.
///
/// # Panics
///
/// Panics if the protocol returns no outputs or if any party's output is not
/// `ProtocolOutput::Dkg`.
async fn test_dkg_with_parties(num_parties: usize) {
    let (pairs, _ids) = get_keypairs_and_ids(num_parties);
    let parties: Vec<_> =
        pairs.iter().map(|pair| ValidatorSecretInfo::pair_only(pair.clone())).collect();
    let threshold = parties.len();
    let session_id = SessionId::Dkg { block_number: 0 };
    let outputs = test_protocol_with_parties(parties, session_id, threshold).await;
    assert!(!outputs.is_empty(), "DKG protocol produced no outputs");
    // Check every party's output variant, not just the last one.
    for output in outputs {
        if !matches!(output, ProtocolOutput::Dkg(_)) {
            panic!("Unexpected protocol output");
        }
    }
}
/// Runs a (n-1)-of-n DKG, then signs with a committee made of the first `n - 1` parties.
///
/// After DKG, each committee member's threshold key share is converted into a key share for
/// the signing committee, and only committee members join the signing session. Every
/// signing output must recover the verifying key taken from the first committee member's key
/// share.
///
/// # Panics
///
/// Panics if `num_parties < 3`, since the threshold must be at least 2 and strictly less than
/// the number of parties. Also panics if any protocol output is unexpected or if a recovered
/// key does not match.
async fn test_dkg_and_sign_with_parties(num_parties: usize) {
    // Check before subtracting so that 0 parties reports this message instead of an
    // integer-overflow panic. `num_parties < 3` is the same as `num_parties - 1 < 2`.
    if num_parties < 3 {
        panic!("Not enough parties to test threshold signing");
    }
    let threshold = num_parties - 1;
    let (pairs, ids) = get_keypairs_and_ids(num_parties);
    let dkg_parties =
        pairs.iter().map(|pair| ValidatorSecretInfo::pair_only(pair.clone())).collect();
    let session_id = SessionId::Dkg { block_number: 0 };
    let outputs = test_protocol_with_parties(dkg_parties, session_id, threshold).await;

    // The first `threshold` generated accounts form the signing committee.
    let signing_committee = pairs
        .iter()
        .take(threshold)
        .map(|pair| ids.get(&PartyId::from(pair.public())).unwrap())
        .cloned()
        .collect::<BTreeSet<_>>();

    // Keep only the committee members' DKG outputs, converted to committee key shares.
    let parties: Vec<ValidatorSecretInfo> = outputs
        .into_iter()
        .filter_map(|output| {
            let ProtocolOutput::Dkg((threshold_keyshare, aux_info)) = output else {
                panic!("Unexpected protocol output");
            };
            let keyshare = threshold_keyshare.to_key_share(&signing_committee);
            if !signing_committee.contains(keyshare.owner()) {
                return None;
            }
            let pair = pairs
                .iter()
                .find(|p| keyshare.owner() == &PartyId::new(AccountId32(p.public().0)))
                .unwrap();
            Some(ValidatorSecretInfo {
                pair: pair.clone(),
                keyshare: Some(keyshare),
                threshold_keyshare: None,
                aux_info: Some(aux_info),
            })
        })
        .collect();
    assert_eq!(
        parties.len(),
        threshold,
        "every signing committee member should have exactly one DKG output"
    );

    let verifying_key = parties[0].keyshare.as_ref().unwrap().verifying_key().unwrap();
    let message_hash = [0u8; 32];
    let session_id = SessionId::Sign(SigningSessionInfo {
        signature_verifying_key: verifying_key.to_encoded_point(true).as_bytes().to_vec(),
        message_hash,
        request_author: AccountId32([0u8; 32]),
    });
    let outputs = test_protocol_with_parties(parties, session_id, threshold).await;
    // Verify every committee member's signature, not just one of them.
    for output in outputs {
        let ProtocolOutput::Sign(recoverable_signature) = output else {
            panic!("Unexpected protocol output");
        };
        let recovery_key_from_sig = VerifyingKey::recover_from_prehash(
            &message_hash,
            &recoverable_signature.signature,
            recoverable_signature.recovery_id,
        )
        .unwrap();
        assert_eq!(verifying_key, recovery_key_from_sig);
    }
}
/// Runs one protocol session with every party in `parties` on the current tokio runtime.
///
/// Each party is bound to its own loopback `TcpListener` (port 0, so the OS picks the port)
/// and given a fresh x25519 secret key. The resulting `ValidatorInfo` list is shared with all
/// parties, and one `server` task is spawned per party. The wall-clock time of the protocol
/// run is printed.
///
/// Returns one `ProtocolOutput` per party, in reverse order of `parties`.
///
/// # Panics
///
/// Panics if a listener cannot be bound, if any party's task panics or is cancelled, or if
/// any party's `server` call returns an error.
async fn test_protocol_with_parties(
    parties: Vec<ValidatorSecretInfo>,
    session_id: SessionId,
    threshold: usize,
) -> Vec<ProtocolOutput> {
    let num_parties = parties.len();
    let mut validator_secrets = Vec::with_capacity(num_parties);
    let mut validators_info = Vec::with_capacity(num_parties);
    // Consume `parties` so each party's secrets are moved rather than cloned.
    for party in parties {
        let socket = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = socket.local_addr().unwrap();
        let x25519_secret_key = StaticSecret::random_from_rng(OsRng);
        let x25519_public_key = x25519_dalek::PublicKey::from(&x25519_secret_key).to_bytes();
        validators_info.push(ValidatorInfo {
            tss_account: AccountId32(party.pair.public().0),
            x25519_public_key,
            ip_address: addr.to_string(),
        });
        validator_secrets.push(ValidatorSecretInfoWithSocket::new(
            party,
            x25519_secret_key,
            socket,
        ));
    }

    let now = Instant::now();
    // Spawn in reverse so the outputs come back in reverse party order, as callers expect.
    // Joining the task handles (rather than relaying through channels) means a panicking
    // party surfaces as a `JoinError` that describes the panic.
    let mut handles = Vec::with_capacity(num_parties);
    for secret in validator_secrets.into_iter().rev() {
        let validators_info = validators_info.clone();
        let session_id = session_id.clone();
        handles.push(tokio::spawn(async move {
            server(
                secret.socket,
                validators_info,
                secret.pair,
                secret.x25519_secret_key,
                session_id,
                secret.keyshare,
                secret.threshold_keyshare,
                secret.aux_info,
                threshold,
            )
            .await
        }));
    }
    let results = future::join_all(handles)
        .await
        .into_iter()
        .map(|joined| {
            joined
                .expect("protocol party task panicked or was cancelled")
                .expect("protocol party returned an error")
        })
        .collect();
    println!("Got protocol results with {} parties in {:?}", num_parties, now.elapsed());
    results
}
/// Secret material for one test party, before it is given a socket and an x25519 key.
///
/// Which optional fields are set depends on the test. The signing tests set `keyshare` and
/// `aux_info`, the reshare test sets `threshold_keyshare`, and DKG parties
/// (`pair_only`) set none of them.
#[derive(Clone)]
struct ValidatorSecretInfo {
    /// The party's sr25519 key pair. Its public key is used for the party's `PartyId` and
    /// TSS account.
    pair: sr25519::Pair,
    /// Key share passed to `server` as the party's `KeyShare` (set by the signing tests).
    keyshare: Option<KeyShare<KeyParams, PartyId>>,
    /// Threshold key share passed to `server` (set by the reshare test).
    threshold_keyshare: Option<ThresholdKeyShare<KeyParams, PartyId>>,
    /// Auxiliary info passed to `server` (set by the signing tests).
    aux_info: Option<AuxInfo<KeyParams, PartyId>>,
}
impl ValidatorSecretInfo {
    /// Builds a party that holds only its key pair, with no key share, threshold key share
    /// or aux info. This is the starting state for DKG.
    fn pair_only(pair: sr25519::Pair) -> Self {
        Self {
            pair,
            aux_info: None,
            threshold_keyshare: None,
            keyshare: None,
        }
    }
}
/// A `ValidatorSecretInfo` extended with the per-session pieces needed to start one party's
/// `server` task: its x25519 secret key and its bound TCP listener.
struct ValidatorSecretInfoWithSocket {
    /// The party's sr25519 key pair.
    pair: sr25519::Pair,
    /// Optional key share, copied from the originating `ValidatorSecretInfo`.
    keyshare: Option<KeyShare<KeyParams, PartyId>>,
    /// Optional threshold key share, copied from the originating `ValidatorSecretInfo`.
    threshold_keyshare: Option<ThresholdKeyShare<KeyParams, PartyId>>,
    /// Optional auxiliary info, copied from the originating `ValidatorSecretInfo`.
    aux_info: Option<AuxInfo<KeyParams, PartyId>>,
    /// The party's x25519 secret key. The matching public key is what the other parties
    /// receive in `ValidatorInfo`.
    x25519_secret_key: StaticSecret,
    /// The already-bound listener handed to `server` for this party.
    socket: TcpListener,
}
impl ValidatorSecretInfoWithSocket {
fn new(
secret_info: ValidatorSecretInfo,
x25519_secret_key: StaticSecret,
socket: TcpListener,
) -> Self {
Self {
pair: secret_info.pair,
keyshare: secret_info.keyshare,
threshold_keyshare: secret_info.threshold_keyshare,
aux_info: secret_info.aux_info,
x25519_secret_key,
socket,
}
}
}
/// Builds a multi-threaded tokio runtime with `num_cpus` worker threads and all drivers
/// enabled (`enable_all`).
///
/// # Panics
///
/// Panics if `num_cpus` is 0 or if the runtime cannot be built.
fn get_tokio_runtime(num_cpus: usize) -> Runtime {
    let mut builder = tokio::runtime::Builder::new_multi_thread();
    builder.worker_threads(num_cpus).enable_all();
    builder.build().unwrap()
}
/// Generates `num_parties` fresh random sr25519 key pairs, together with the set of
/// `PartyId`s built from their public keys.
fn get_keypairs_and_ids(num_parties: usize) -> (Vec<sr25519::Pair>, BTreeSet<PartyId>) {
    let mut pairs = Vec::with_capacity(num_parties);
    let mut ids = BTreeSet::new();
    for _ in 0..num_parties {
        let (pair, _seed) = sr25519::Pair::generate();
        ids.insert(PartyId::new(AccountId32(pair.public().0)));
        pairs.push(pair);
    }
    (pairs, ids)
}