use hashbrown::{HashMap, HashSet};
use num_traits::{One, Zero};
use polynomial::Polynomial;
use rand_core::{CryptoRng, RngCore};
use tracing::warn;
use crate::{
common::{check_public_shares, Nonce, PolyCommitment, PublicNonce, Signature, SignatureShare},
compute,
curve::{
point::{Point, G},
scalar::Scalar,
},
errors::{AggregatorError, DkgError},
schnorr::ID,
taproot::SchnorrProof,
traits,
vss::VSS,
};
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Party {
pub party_id: u32,
pub key_ids: Vec<u32>,
num_keys: u32,
num_parties: u32,
threshold: u32,
f: Option<Polynomial<Scalar>>,
private_keys: HashMap<u32, Scalar>,
group_key: Point,
nonce: Nonce,
}
impl Party {
pub fn new<RNG: RngCore + CryptoRng>(
party_id: u32,
key_ids: &[u32],
num_parties: u32,
num_keys: u32,
threshold: u32,
rng: &mut RNG,
) -> Self {
Self {
party_id,
key_ids: key_ids.to_vec(),
num_keys,
num_parties,
threshold,
f: Some(VSS::random_poly(threshold - 1, rng)),
private_keys: Default::default(),
group_key: Point::zero(),
nonce: Nonce::zero(),
}
}
pub fn gen_nonce<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG) -> PublicNonce {
self.nonce = Nonce::random(rng);
PublicNonce::from(&self.nonce)
}
pub fn get_poly_commitment<RNG: RngCore + CryptoRng>(
&self,
rng: &mut RNG,
) -> Option<PolyCommitment> {
if let Some(poly) = &self.f {
Some(PolyCommitment {
id: ID::new(&self.id(), &poly.data()[0], rng),
poly: (0..poly.data().len())
.map(|i| &poly.data()[i] * G)
.collect(),
})
} else {
warn!("get_poly_commitment called with no polynomial");
None
}
}
pub fn get_shares(&self) -> HashMap<u32, Scalar> {
let mut shares = HashMap::new();
if let Some(poly) = &self.f {
for i in 1..self.num_keys + 1 {
shares.insert(i, poly.eval(compute::id(i)));
}
} else {
warn!("get_poly_commitment called with no polynomial");
}
shares
}
pub fn compute_secret(
&mut self,
private_shares: &HashMap<u32, HashMap<u32, Scalar>>,
public_shares: &HashMap<u32, PolyCommitment>,
) -> Result<(), DkgError> {
self.private_keys.clear();
self.group_key = Point::zero();
let threshold: usize = self.threshold.try_into()?;
let mut bad_ids = Vec::new();
for (i, comm) in public_shares.iter() {
if !check_public_shares(comm, threshold) {
bad_ids.push(*i);
} else {
self.group_key += comm.poly[0];
}
}
if !bad_ids.is_empty() {
return Err(DkgError::BadPublicShares(bad_ids));
}
let mut missing_shares = Vec::new();
for dst_key_id in &self.key_ids {
for src_key_id in public_shares.keys() {
match private_shares.get(dst_key_id) {
Some(shares) => {
if shares.get(src_key_id).is_none() {
missing_shares.push((*dst_key_id, *src_key_id));
}
}
None => {
missing_shares.push((*dst_key_id, *src_key_id));
}
}
}
}
if !missing_shares.is_empty() {
return Err(DkgError::MissingPrivateShares(missing_shares));
}
let mut bad_shares = Vec::new();
for key_id in &self.key_ids {
if let Some(shares) = private_shares.get(key_id) {
for (sender, s) in shares {
if let Some(comm) = public_shares.get(sender) {
if s * G != compute::poly(&compute::id(*key_id), &comm.poly)? {
bad_shares.push(*sender);
}
} else {
warn!("unable to check private share from {}: no corresponding public share, even though we checked for it above", sender);
}
}
} else {
warn!(
"no private shares for key_id {}, even though we checked for it above",
key_id
);
}
}
if !bad_shares.is_empty() {
return Err(DkgError::BadPrivateShares(bad_shares));
}
for key_id in &self.key_ids {
self.private_keys.insert(*key_id, Scalar::zero());
if let Some(shares) = private_shares.get(key_id) {
let secret = shares.values().sum();
self.private_keys.insert(*key_id, secret);
} else {
warn!(
"no private shares for key_id {}, even though we checked for it above",
key_id
);
}
}
Ok(())
}
pub fn id(&self) -> Scalar {
compute::id(self.party_id)
}
pub fn sign(
&self,
msg: &[u8],
party_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
) -> SignatureShare {
self.sign_with_tweak(msg, party_ids, key_ids, nonces, None)
}
#[allow(non_snake_case)]
pub fn sign_with_tweak(
&self,
msg: &[u8],
party_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
tweak: Option<Scalar>,
) -> SignatureShare {
let mut cx_sign = Scalar::one();
let tweaked_public_key = if let Some(t) = tweak {
if t != Scalar::zero() {
let key = compute::tweaked_public_key_from_tweak(&self.group_key, t);
if key.has_even_y() ^ self.group_key.has_even_y() {
cx_sign = -cx_sign;
}
key
} else {
if !self.group_key.has_even_y() {
cx_sign = -cx_sign;
}
self.group_key
}
} else {
self.group_key
};
let (_, R) = compute::intermediate(msg, party_ids, nonces);
let c = compute::challenge(&tweaked_public_key, &R, msg);
let mut r = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg);
if tweak.is_some() && !R.has_even_y() {
r = -r;
}
let mut cx = Scalar::zero();
for key_id in self.key_ids.iter() {
cx += c * &self.private_keys[key_id] * compute::lambda(*key_id, key_ids);
}
cx = cx_sign * cx;
let z = r + cx;
SignatureShare {
id: self.party_id,
z_i: z,
key_ids: self.key_ids.clone(),
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub struct Aggregator {
pub num_keys: u32,
pub threshold: u32,
pub poly: Vec<Point>,
}
impl Aggregator {
#[allow(non_snake_case)]
pub fn sign_with_tweak(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
_key_ids: &[u32],
tweak: Option<Scalar>,
) -> Result<(Point, Signature), AggregatorError> {
if nonces.len() != sig_shares.len() {
return Err(AggregatorError::BadNonceLen(nonces.len(), sig_shares.len()));
}
let party_ids: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
let (_Rs, R) = compute::intermediate(msg, &party_ids, nonces);
let mut z = Scalar::zero();
let mut cx_sign = Scalar::one();
let aggregate_public_key = self.poly[0];
let tweaked_public_key = if let Some(t) = tweak {
if t != Scalar::zero() {
let key = compute::tweaked_public_key_from_tweak(&aggregate_public_key, t);
if !key.has_even_y() {
cx_sign = -cx_sign;
}
key
} else {
aggregate_public_key
}
} else {
aggregate_public_key
};
let c = compute::challenge(&tweaked_public_key, &R, msg);
for sig_share in sig_shares {
z += sig_share.z_i;
}
if let Some(t) = tweak {
z += cx_sign * c * t;
}
let sig = Signature { R, z };
Ok((tweaked_public_key, sig))
}
#[allow(non_snake_case)]
pub fn check_signature_shares(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
key_ids: &[u32],
tweak: Option<Scalar>,
) -> AggregatorError {
if nonces.len() != sig_shares.len() {
return AggregatorError::BadNonceLen(nonces.len(), sig_shares.len());
}
let party_ids: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
let (Rs, R) = compute::intermediate(msg, &party_ids, nonces);
let mut bad_party_keys = Vec::new();
let mut bad_party_sigs = Vec::new();
let aggregate_public_key = self.poly[0];
let tweaked_public_key = if let Some(t) = tweak {
if t != Scalar::zero() {
compute::tweaked_public_key_from_tweak(&aggregate_public_key, t)
} else {
aggregate_public_key
}
} else {
aggregate_public_key
};
let c = compute::challenge(&tweaked_public_key, &R, msg);
let mut r_sign = Scalar::one();
let mut cx_sign = Scalar::one();
if let Some(t) = tweak {
if !R.has_even_y() {
r_sign = -Scalar::one();
}
if t != Scalar::zero() {
if !tweaked_public_key.has_even_y() ^ !aggregate_public_key.has_even_y() {
cx_sign = -Scalar::one();
}
} else if !aggregate_public_key.has_even_y() {
cx_sign = -Scalar::one();
}
}
for i in 0..sig_shares.len() {
let z_i = sig_shares[i].z_i;
let mut cx = Point::zero();
for key_id in &sig_shares[i].key_ids {
let kid = compute::id(*key_id);
let public_key = match compute::poly(&kid, &self.poly) {
Ok(p) => p,
Err(_) => {
bad_party_keys.push(sig_shares[i].id);
Point::zero()
}
};
cx += compute::lambda(*key_id, key_ids) * c * public_key;
}
if z_i * G != (r_sign * Rs[i] + cx_sign * cx) {
bad_party_sigs.push(sig_shares[i].id);
}
}
if !bad_party_keys.is_empty() {
AggregatorError::BadPartyKeys(bad_party_keys)
} else if !bad_party_sigs.is_empty() {
AggregatorError::BadPartySigs(bad_party_sigs)
} else {
AggregatorError::BadGroupSig
}
}
}
impl traits::Aggregator for Aggregator {
fn new(num_keys: u32, threshold: u32) -> Self {
Self {
num_keys,
threshold,
poly: Default::default(),
}
}
fn init(&mut self, comms: &HashMap<u32, PolyCommitment>) -> Result<(), AggregatorError> {
let threshold: usize = self.threshold.try_into()?;
let mut poly = Vec::with_capacity(threshold);
for i in 0..poly.capacity() {
poly.push(Point::zero());
for (_, comm) in comms {
poly[i] += &comm.poly[i];
}
}
self.poly = poly;
Ok(())
}
fn sign(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
key_ids: &[u32],
) -> Result<Signature, AggregatorError> {
let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, key_ids, None)?;
if sig.verify(&key, msg) {
Ok(sig)
} else {
Err(self.check_signature_shares(msg, nonces, sig_shares, key_ids, None))
}
}
fn sign_schnorr(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
key_ids: &[u32],
) -> Result<SchnorrProof, AggregatorError> {
let tweak = Scalar::from(0);
let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, key_ids, Some(tweak))?;
let proof = SchnorrProof::new(&sig);
if proof.verify(&key.x(), msg) {
Ok(proof)
} else {
Err(self.check_signature_shares(msg, nonces, sig_shares, key_ids, Some(tweak)))
}
}
fn sign_taproot(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
key_ids: &[u32],
merkle_root: Option<[u8; 32]>,
) -> Result<SchnorrProof, AggregatorError> {
let tweak = compute::tweak(&self.poly[0], merkle_root);
let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, key_ids, Some(tweak))?;
let proof = SchnorrProof::new(&sig);
if proof.verify(&key.x(), msg) {
Ok(proof)
} else {
Err(self.check_signature_shares(msg, nonces, sig_shares, key_ids, Some(tweak)))
}
}
}
pub type Signer = Party;
impl traits::Signer for Party {
fn new<RNG: RngCore + CryptoRng>(
party_id: u32,
key_ids: &[u32],
num_signers: u32,
num_keys: u32,
threshold: u32,
rng: &mut RNG,
) -> Self {
Party::new(party_id, key_ids, num_signers, num_keys, threshold, rng)
}
fn load(state: &traits::SignerState) -> Self {
assert_eq!(state.parties.len(), 1);
let party_state = &state.parties[0].1;
Self {
party_id: state.id,
key_ids: state.key_ids.clone(),
num_keys: state.num_keys,
num_parties: state.num_parties,
threshold: state.threshold,
f: party_state.polynomial.clone(),
private_keys: party_state
.private_keys
.iter()
.map(|(k, v)| (*k, *v))
.collect(),
group_key: state.group_key,
nonce: party_state.nonce.clone(),
}
}
fn save(&self) -> traits::SignerState {
let party_state = traits::PartyState {
polynomial: self.f.clone(),
private_keys: self.private_keys.iter().map(|(k, v)| (*k, *v)).collect(),
nonce: self.nonce.clone(),
};
traits::SignerState {
id: self.party_id,
key_ids: self.key_ids.clone(),
num_keys: self.num_keys,
num_parties: self.num_parties,
threshold: self.threshold,
group_key: self.group_key,
parties: vec![(self.party_id, party_state)],
}
}
fn get_id(&self) -> u32 {
self.party_id
}
fn get_key_ids(&self) -> Vec<u32> {
self.key_ids.clone()
}
fn get_num_parties(&self) -> u32 {
self.num_parties
}
fn get_poly_commitments<RNG: RngCore + CryptoRng>(&self, rng: &mut RNG) -> Vec<PolyCommitment> {
if let Some(poly) = self.get_poly_commitment(rng) {
vec![poly.clone()]
} else {
vec![]
}
}
fn reset_polys<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG) {
self.f = Some(VSS::random_poly(self.threshold - 1, rng));
}
fn clear_polys(&mut self) {
self.f = None;
}
fn get_shares(&self) -> HashMap<u32, HashMap<u32, Scalar>> {
let mut shares = HashMap::new();
shares.insert(self.party_id, self.get_shares());
shares
}
fn compute_secrets(
&mut self,
private_shares: &HashMap<u32, HashMap<u32, Scalar>>,
polys: &HashMap<u32, PolyCommitment>,
) -> Result<(), HashMap<u32, DkgError>> {
let mut key_shares = HashMap::new();
for dest_key_id in self.get_key_ids() {
let mut shares = HashMap::new();
for (src_party_id, signer_shares) in private_shares.iter() {
if let Some(signer_share) = signer_shares.get(&dest_key_id) {
shares.insert(*src_party_id, *signer_share);
}
}
key_shares.insert(dest_key_id, shares);
}
match self.compute_secret(&key_shares, polys) {
Ok(()) => Ok(()),
Err(dkg_error) => {
let mut dkg_errors = HashMap::new();
dkg_errors.insert(self.party_id, dkg_error);
Err(dkg_errors)
}
}
}
fn gen_nonces<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG) -> Vec<PublicNonce> {
vec![self.gen_nonce(rng)]
}
fn compute_intermediate(
msg: &[u8],
signer_ids: &[u32],
_key_ids: &[u32],
nonces: &[PublicNonce],
) -> (Vec<Point>, Point) {
compute::intermediate(msg, signer_ids, nonces)
}
fn validate_party_id(
signer_id: u32,
party_id: u32,
_signer_key_ids: &HashMap<u32, HashSet<u32>>,
) -> bool {
signer_id == party_id
}
fn sign(
&self,
msg: &[u8],
signer_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
) -> Vec<SignatureShare> {
vec![self.sign(msg, signer_ids, key_ids, nonces)]
}
fn sign_schnorr(
&self,
msg: &[u8],
signer_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
) -> Vec<SignatureShare> {
vec![self.sign_with_tweak(msg, signer_ids, key_ids, nonces, Some(Scalar::from(0)))]
}
fn sign_taproot(
&self,
msg: &[u8],
signer_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
merkle_root: Option<[u8; 32]>,
) -> Vec<SignatureShare> {
let tweak = compute::tweak(&self.group_key, merkle_root);
vec![self.sign_with_tweak(msg, signer_ids, key_ids, nonces, Some(tweak))]
}
}
pub mod test_helpers {
use crate::common::{PolyCommitment, PublicNonce};
use crate::errors::DkgError;
use crate::traits::Signer;
use crate::v2;
use crate::v2::SignatureShare;
use hashbrown::HashMap;
use rand_core::{CryptoRng, RngCore};
pub fn dkg<RNG: RngCore + CryptoRng>(
signers: &mut [v2::Party],
rng: &mut RNG,
) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
let mut polys: HashMap<u32, PolyCommitment> = Default::default();
for signer in signers.iter() {
if let Some(poly) = signer.get_poly_commitment(rng) {
polys.insert(signer.get_id(), poly);
}
}
let mut broadcast_shares = Vec::new();
for party in signers.iter() {
broadcast_shares.push((party.party_id, party.get_shares()));
}
let mut secret_errors = HashMap::new();
for party in signers.iter_mut() {
let mut party_shares = HashMap::new();
for key_id in party.key_ids.clone() {
let mut key_shares = HashMap::new();
for (id, shares) in &broadcast_shares {
if let Some(share) = shares.get(&key_id) {
key_shares.insert(*id, *share);
}
}
party_shares.insert(key_id, key_shares);
}
if let Err(secret_error) = party.compute_secret(&party_shares, &polys) {
secret_errors.insert(party.party_id, secret_error);
}
}
if secret_errors.is_empty() {
Ok(polys)
} else {
Err(secret_errors)
}
}
pub fn sign<RNG: RngCore + CryptoRng>(
msg: &[u8],
signers: &mut [v2::Party],
rng: &mut RNG,
) -> (Vec<PublicNonce>, Vec<SignatureShare>, Vec<u32>) {
let party_ids: Vec<u32> = signers.iter().map(|s| s.party_id).collect();
let key_ids: Vec<u32> = signers.iter().flat_map(|s| s.key_ids.clone()).collect();
let nonces: Vec<PublicNonce> = signers.iter_mut().map(|s| s.gen_nonce(rng)).collect();
let shares = signers
.iter()
.map(|s| s.sign(msg, &party_ids, &key_ids, &nonces))
.collect();
(nonces, shares, key_ids)
}
}
#[cfg(test)]
mod tests {
use hashbrown::{HashMap, HashSet};
use crate::util::create_rng;
use crate::{
traits::{
self, test_helpers::run_compute_secrets_missing_private_shares, Aggregator, Signer,
},
v2,
};
#[test]
fn party_save_load() {
let mut rng = create_rng();
let key_ids = [1, 2, 3];
let n: u32 = 10;
let t: u32 = 7;
let signer = v2::Party::new(0, &key_ids, 1, n, t, &mut rng);
let state = signer.save();
let loaded = v2::Party::load(&state);
assert_eq!(signer, loaded);
}
#[test]
fn clear_polys() {
let mut rng = create_rng();
let key_ids = [1, 2, 3];
let n: u32 = 10;
let t: u32 = 7;
let mut signer = v2::Party::new(0, &key_ids, 1, n, t, &mut rng);
assert_eq!(signer.get_poly_commitments(&mut rng).len(), 1);
assert_eq!(signer.get_shares().len(), usize::try_from(n).unwrap());
signer.clear_polys();
assert_eq!(signer.get_poly_commitments(&mut rng).len(), 0);
assert_eq!(signer.get_shares().len(), 0);
}
#[test]
fn aggregator_sign() {
let mut rng = create_rng();
let msg = "It was many and many a year ago".as_bytes();
let n_k: u32 = 10;
let t: u32 = 7;
let party_key_ids: Vec<Vec<u32>> = [
[1, 2, 3].to_vec(),
[4, 5].to_vec(),
[6, 7, 8].to_vec(),
[9, 10].to_vec(),
]
.to_vec();
let n_p = party_key_ids.len().try_into().unwrap();
let mut signers: Vec<v2::Party> = party_key_ids
.iter()
.enumerate()
.map(|(pid, pkids)| {
v2::Party::new(pid.try_into().unwrap(), pkids, n_p, n_k, t, &mut rng)
})
.collect();
let comms = match traits::test_helpers::dkg(&mut signers, &mut rng) {
Ok(comms) => comms,
Err(secret_errors) => {
panic!("Got secret errors from DKG: {:?}", secret_errors);
}
};
{
let mut signers = [signers[0].clone(), signers[1].clone(), signers[3].clone()].to_vec();
let mut sig_agg = v2::Aggregator::new(n_k, t);
sig_agg.init(&comms).expect("aggregator init failed");
let (nonces, sig_shares, key_ids) = v2::test_helpers::sign(msg, &mut signers, &mut rng);
if let Err(e) = sig_agg.sign(msg, &nonces, &sig_shares, &key_ids) {
panic!("Aggregator sign failed: {:?}", e);
}
}
}
#[test]
pub fn run_compute_secrets_missing_shares() {
run_compute_secrets_missing_private_shares::<v2::Signer>()
}
#[test]
pub fn bad_polynomial_length() {
let gt = |t| t + 1;
let lt = |t| t - 1;
traits::test_helpers::bad_polynomial_length::<v2::Signer, _>(gt);
traits::test_helpers::bad_polynomial_length::<v2::Signer, _>(lt);
}
#[test]
pub fn bad_polynomial_commitment() {
traits::test_helpers::bad_polynomial_commitment::<v2::Signer>();
}
#[test]
fn validate_party_id() {
let mut signer_key_ids = HashMap::new();
let mut key_ids = HashSet::new();
key_ids.insert(1);
signer_key_ids.insert(0, key_ids);
assert!(v2::Signer::validate_party_id(0, 0, &signer_key_ids));
assert!(!v2::Signer::validate_party_id(0, 1, &signer_key_ids));
}
}