use hashbrown::{HashMap, HashSet};
use num_traits::{One, Zero};
use polynomial::Polynomial;
use rand_core::{CryptoRng, RngCore};
use tracing::warn;
use crate::{
common::{
check_public_shares, CheckPrivateShares, Nonce, PolyCommitment, PublicNonce, Signature,
SignatureShare,
},
compute,
curve::{
point::{Point, G},
scalar::Scalar,
},
errors::{AggregatorError, DkgError},
schnorr::ID,
taproot::SchnorrProof,
traits,
vss::VSS,
};
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Party {
pub id: u32,
pub public_key: Point,
pub f: Option<Polynomial<Scalar>>,
num_keys: u32,
threshold: u32,
private_key: Scalar,
pub group_key: Point,
nonce: Nonce,
}
impl Party {
pub fn new<RNG: RngCore + CryptoRng>(id: u32, n: u32, t: u32, rng: &mut RNG) -> Self {
Self {
id,
num_keys: n,
threshold: t,
f: Some(VSS::random_poly(t - 1, rng)),
private_key: Scalar::zero(),
public_key: Point::zero(),
group_key: Point::zero(),
nonce: Nonce::zero(),
}
}
pub fn load(id: u32, n: u32, t: u32, group_key: &Point, state: &traits::PartyState) -> Self {
assert_eq!(state.private_keys.len(), 1);
assert_eq!(state.private_keys[0].0, id);
let private_key = state.private_keys[0].1;
Self {
id,
num_keys: n,
threshold: t,
f: state.polynomial.clone(),
public_key: private_key * G,
private_key,
group_key: *group_key,
nonce: state.nonce.clone(),
}
}
pub fn save(&self) -> traits::PartyState {
traits::PartyState {
private_keys: vec![(self.id, self.private_key)],
polynomial: self.f.clone(),
nonce: self.nonce.clone(),
}
}
pub fn gen_nonce<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG) -> PublicNonce {
self.nonce = Nonce::random(rng);
PublicNonce::from(&self.nonce)
}
pub fn get_poly_commitment<RNG: RngCore + CryptoRng>(
&self,
rng: &mut RNG,
) -> Option<PolyCommitment> {
if let Some(poly) = &self.f {
Some(PolyCommitment {
id: ID::new(&self.id(), &poly.data()[0], rng),
poly: (0..poly.data().len())
.map(|i| &poly.data()[i] * G)
.collect(),
})
} else {
warn!("get_poly_commitment called with no polynomial");
None
}
}
pub fn reset_poly<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG) {
self.f = Some(VSS::random_poly(self.threshold - 1, rng));
}
pub fn clear_poly(&mut self) {
self.f = None;
}
pub fn get_shares(&self) -> HashMap<u32, Scalar> {
if let Some(poly) = &self.f {
let mut shares = HashMap::new();
for i in 1..self.num_keys + 1 {
shares.insert(i, poly.eval(compute::id(i)));
}
shares
} else {
warn!("get_shares called with no polynomial");
Default::default()
}
}
pub fn compute_secret(
&mut self,
private_shares: HashMap<u32, Scalar>,
public_shares: &HashMap<u32, PolyCommitment>,
) -> Result<(), DkgError> {
self.private_key = Scalar::zero();
self.group_key = Point::zero();
let threshold: usize = self.threshold.try_into()?;
let mut bad_ids = Vec::new(); for (i, comm) in public_shares.iter() {
if !check_public_shares(comm, threshold) {
bad_ids.push(*i);
} else {
self.group_key += comm.poly[0];
}
}
if !bad_ids.is_empty() {
return Err(DkgError::BadPublicShares(bad_ids));
}
let mut missing_shares = Vec::new();
for i in public_shares.keys() {
if private_shares.get(i).is_none() {
missing_shares.push((self.id, *i));
}
}
if !missing_shares.is_empty() {
return Err(DkgError::MissingPrivateShares(missing_shares));
}
let mut check_shares =
CheckPrivateShares::new(self.id(), &private_shares, public_shares.clone());
if Point::multimult_trait(&mut check_shares)? != Point::zero() {
let mut bad_shares = Vec::new();
for (i, s) in private_shares.iter() {
if let Some(comm) = public_shares.get(i) {
if s * G != compute::poly(&self.id(), &comm.poly)? {
bad_shares.push(*i);
}
} else {
warn!("unable to check private share from {}: no corresponding public share, even though we checked for it above", i);
}
}
return Err(DkgError::BadPrivateShares(bad_shares));
}
self.private_key = private_shares.values().sum();
self.public_key = self.private_key * G;
Ok(())
}
fn id(&self) -> Scalar {
compute::id(self.id)
}
pub fn sign(&self, msg: &[u8], signers: &[u32], nonces: &[PublicNonce]) -> SignatureShare {
let (_, aggregate_nonce) = compute::intermediate(msg, signers, nonces);
let mut z = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg);
z += compute::challenge(&self.group_key, &aggregate_nonce, msg)
* &self.private_key
* compute::lambda(self.id, signers);
SignatureShare {
id: self.id,
z_i: z,
key_ids: vec![self.id],
}
}
pub fn sign_precomputed(
&self,
msg: &[u8],
signers: &[u32],
nonces: &[PublicNonce],
aggregate_nonce: &Point,
) -> SignatureShare {
self.sign_precomputed_with_tweak(msg, signers, nonces, aggregate_nonce, None)
}
pub fn sign_precomputed_with_tweak(
&self,
msg: &[u8],
signers: &[u32],
nonces: &[PublicNonce],
aggregate_nonce: &Point,
tweak: Option<Scalar>,
) -> SignatureShare {
let mut r = &self.nonce.d + &self.nonce.e * compute::binding(&self.id(), nonces, msg);
if tweak.is_some() && !aggregate_nonce.has_even_y() {
r = -r;
}
let mut cx_sign = Scalar::one();
let tweaked_public_key = if let Some(t) = tweak {
if t != Scalar::zero() {
let key = compute::tweaked_public_key_from_tweak(&self.group_key, t);
if key.has_even_y() ^ self.group_key.has_even_y() {
cx_sign = -cx_sign;
}
key
} else {
if !self.group_key.has_even_y() {
cx_sign = -cx_sign;
}
self.group_key
}
} else {
self.group_key
};
let c = compute::challenge(&tweaked_public_key, aggregate_nonce, msg);
let mut cx = c * &self.private_key * compute::lambda(self.id, signers);
cx = cx_sign * cx;
let z = r + cx;
SignatureShare {
id: self.id,
z_i: z,
key_ids: vec![self.id],
}
}
}
#[derive(Clone, Debug, PartialEq)]
pub struct Aggregator {
pub num_keys: u32,
pub threshold: u32,
pub poly: Vec<Point>,
}
impl Aggregator {
#[allow(non_snake_case)]
pub fn sign_with_tweak(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
tweak: Option<Scalar>,
) -> Result<(Point, Signature), AggregatorError> {
if nonces.len() != sig_shares.len() {
return Err(AggregatorError::BadNonceLen(nonces.len(), sig_shares.len()));
}
let signers: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
let (_Rs, R) = compute::intermediate(msg, &signers, nonces);
let mut z = Scalar::zero();
let mut cx_sign = Scalar::one();
let aggregate_public_key = self.poly[0];
let tweaked_public_key = match tweak {
Some(t) if t != Scalar::zero() => {
let key = compute::tweaked_public_key_from_tweak(&aggregate_public_key, t);
if !key.has_even_y() {
cx_sign = -cx_sign;
}
key
}
_ => aggregate_public_key,
};
let c = compute::challenge(&tweaked_public_key, &R, msg);
for sig_share in sig_shares {
z += sig_share.z_i;
}
if let Some(t) = tweak {
z += cx_sign * c * t;
}
let sig = Signature { R, z };
Ok((tweaked_public_key, sig))
}
#[allow(non_snake_case)]
pub fn check_signature_shares(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
tweak: Option<Scalar>,
) -> AggregatorError {
if nonces.len() != sig_shares.len() {
return AggregatorError::BadNonceLen(nonces.len(), sig_shares.len());
}
let signers: Vec<u32> = sig_shares.iter().map(|ss| ss.id).collect();
let (Rs, R) = compute::intermediate(msg, &signers, nonces);
let mut bad_party_keys = Vec::new();
let mut bad_party_sigs = Vec::new();
let aggregate_public_key = self.poly[0];
let tweaked_public_key = match tweak {
Some(t) if t != Scalar::zero() => {
compute::tweaked_public_key_from_tweak(&aggregate_public_key, t)
}
_ => aggregate_public_key,
};
let c = compute::challenge(&tweaked_public_key, &R, msg);
let mut r_sign = Scalar::one();
let mut cx_sign = Scalar::one();
if let Some(t) = tweak {
if !R.has_even_y() {
r_sign = -Scalar::one();
}
if t != Scalar::zero() {
if !tweaked_public_key.has_even_y() ^ !aggregate_public_key.has_even_y() {
cx_sign = -Scalar::one();
}
} else if !aggregate_public_key.has_even_y() {
cx_sign = -Scalar::one();
}
}
for i in 0..sig_shares.len() {
let id = compute::id(sig_shares[i].id);
let public_key = match compute::poly(&id, &self.poly) {
Ok(p) => p,
Err(_) => {
bad_party_keys.push(sig_shares[i].id);
Point::zero()
}
};
let z_i = sig_shares[i].z_i;
if z_i * G
!= r_sign * Rs[i]
+ cx_sign * (compute::lambda(sig_shares[i].id, &signers) * c * public_key)
{
bad_party_sigs.push(sig_shares[i].id);
}
}
if !bad_party_keys.is_empty() {
AggregatorError::BadPartyKeys(bad_party_keys)
} else if !bad_party_sigs.is_empty() {
AggregatorError::BadPartySigs(bad_party_sigs)
} else {
AggregatorError::BadGroupSig
}
}
}
impl traits::Aggregator for Aggregator {
fn new(num_keys: u32, threshold: u32) -> Self {
Self {
num_keys,
threshold,
poly: Default::default(),
}
}
fn init(&mut self, comms: &HashMap<u32, PolyCommitment>) -> Result<(), AggregatorError> {
let threshold = self.threshold.try_into()?;
let mut poly = Vec::with_capacity(threshold);
for i in 0..poly.capacity() {
poly.push(Point::zero());
for (_, p) in comms {
poly[i] += &p.poly[i];
}
}
self.poly = poly;
Ok(())
}
fn sign(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
_key_ids: &[u32],
) -> Result<Signature, AggregatorError> {
let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, None)?;
if sig.verify(&key, msg) {
Ok(sig)
} else {
Err(self.check_signature_shares(msg, nonces, sig_shares, None))
}
}
fn sign_schnorr(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
_key_ids: &[u32],
) -> Result<SchnorrProof, AggregatorError> {
let tweak = Scalar::from(0);
let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, Some(tweak))?;
let proof = SchnorrProof::new(&sig);
if proof.verify(&key.x(), msg) {
Ok(proof)
} else {
Err(self.check_signature_shares(msg, nonces, sig_shares, Some(tweak)))
}
}
fn sign_taproot(
&mut self,
msg: &[u8],
nonces: &[PublicNonce],
sig_shares: &[SignatureShare],
_key_ids: &[u32],
merkle_root: Option<[u8; 32]>,
) -> Result<SchnorrProof, AggregatorError> {
let tweak = compute::tweak(&self.poly[0], merkle_root);
let (key, sig) = self.sign_with_tweak(msg, nonces, sig_shares, Some(tweak))?;
let proof = SchnorrProof::new(&sig);
if proof.verify(&key.x(), msg) {
Ok(proof)
} else {
Err(self.check_signature_shares(msg, nonces, sig_shares, Some(tweak)))
}
}
}
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Signer {
id: u32,
num_keys: u32,
threshold: u32,
group_key: Point,
parties: Vec<Party>,
}
impl Signer {
pub fn new<RNG: RngCore + CryptoRng>(
id: u32,
key_ids: &[u32],
num_keys: u32,
threshold: u32,
rng: &mut RNG,
) -> Self {
let parties = key_ids
.iter()
.map(|id| Party::new(*id, num_keys, threshold, rng))
.collect();
Signer {
id,
num_keys,
threshold,
group_key: Point::zero(),
parties,
}
}
}
impl traits::Signer for Signer {
fn new<RNG: RngCore + CryptoRng>(
party_id: u32,
key_ids: &[u32],
_num_signers: u32,
num_keys: u32,
threshold: u32,
rng: &mut RNG,
) -> Self {
Signer::new(party_id, key_ids, num_keys, threshold, rng)
}
fn load(state: &traits::SignerState) -> Self {
let parties = state
.parties
.iter()
.map(|(id, ps)| Party::load(*id, state.num_keys, state.threshold, &state.group_key, ps))
.collect();
Self {
id: state.id,
num_keys: state.num_keys,
threshold: state.threshold,
group_key: state.group_key,
parties,
}
}
fn save(&self) -> traits::SignerState {
let mut key_ids = Vec::new();
let mut parties = Vec::new();
for party in &self.parties {
key_ids.push(party.id);
parties.push((party.id, party.save()));
}
traits::SignerState {
id: self.id,
key_ids,
num_keys: self.num_keys,
num_parties: self.num_keys,
threshold: self.threshold,
group_key: self.group_key,
parties,
}
}
fn get_id(&self) -> u32 {
self.id
}
fn get_key_ids(&self) -> Vec<u32> {
self.parties.iter().map(|p| p.id).collect()
}
fn get_num_parties(&self) -> u32 {
self.num_keys
}
fn get_poly_commitments<RNG: RngCore + CryptoRng>(&self, rng: &mut RNG) -> Vec<PolyCommitment> {
let mut polys = Vec::new();
for party in &self.parties {
let comm = party.get_poly_commitment(rng);
if let Some(poly) = &comm {
polys.push(poly.clone());
}
}
polys
}
fn reset_polys<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG) {
for party in self.parties.iter_mut() {
party.reset_poly(rng);
}
}
fn clear_polys(&mut self) {
for party in self.parties.iter_mut() {
party.clear_poly();
}
}
fn get_shares(&self) -> HashMap<u32, HashMap<u32, Scalar>> {
let mut shares = HashMap::new();
for party in &self.parties {
shares.insert(party.id, party.get_shares());
}
shares
}
fn compute_secrets(
&mut self,
private_shares: &HashMap<u32, HashMap<u32, Scalar>>,
polys: &HashMap<u32, PolyCommitment>,
) -> Result<(), HashMap<u32, DkgError>> {
let mut dkg_errors = HashMap::new();
for party in &mut self.parties {
let mut key_shares = HashMap::with_capacity(polys.len());
for (party_id, signer_shares) in private_shares.iter() {
if let Some(share) = signer_shares.get(&party.id) {
key_shares.insert(*party_id, *share);
}
}
if let Err(e) = party.compute_secret(key_shares, polys) {
dkg_errors.insert(party.id, e);
}
self.group_key = party.group_key;
}
if dkg_errors.is_empty() {
Ok(())
} else {
Err(dkg_errors)
}
}
fn gen_nonces<RNG: RngCore + CryptoRng>(&mut self, rng: &mut RNG) -> Vec<PublicNonce> {
self.parties.iter_mut().map(|p| p.gen_nonce(rng)).collect()
}
fn compute_intermediate(
msg: &[u8],
_signer_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
) -> (Vec<Point>, Point) {
compute::intermediate(msg, key_ids, nonces)
}
fn validate_party_id(
signer_id: u32,
party_id: u32,
signer_key_ids: &HashMap<u32, HashSet<u32>>,
) -> bool {
match signer_key_ids.get(&signer_id) {
Some(key_ids) => key_ids.contains(&party_id),
None => false,
}
}
fn sign(
&self,
msg: &[u8],
_signer_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
) -> Vec<SignatureShare> {
let aggregate_nonce = compute::aggregate_nonce(msg, key_ids, nonces).unwrap();
self.parties
.iter()
.map(|p| p.sign_precomputed(msg, key_ids, nonces, &aggregate_nonce))
.collect()
}
fn sign_taproot(
&self,
msg: &[u8],
_signer_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
merkle_root: Option<[u8; 32]>,
) -> Vec<SignatureShare> {
let aggregate_nonce = compute::aggregate_nonce(msg, key_ids, nonces).unwrap();
let tweak = compute::tweak(&self.parties[0].group_key, merkle_root);
self.parties
.iter()
.map(|p| {
p.sign_precomputed_with_tweak(msg, key_ids, nonces, &aggregate_nonce, Some(tweak))
})
.collect()
}
fn sign_schnorr(
&self,
msg: &[u8],
_signer_ids: &[u32],
key_ids: &[u32],
nonces: &[PublicNonce],
) -> Vec<SignatureShare> {
let aggregate_nonce = compute::aggregate_nonce(msg, key_ids, nonces).unwrap();
self.parties
.iter()
.map(|p| {
p.sign_precomputed_with_tweak(
msg,
key_ids,
nonces,
&aggregate_nonce,
Some(Scalar::from(0)),
)
})
.collect()
}
}
pub mod test_helpers {
use crate::common::{PolyCommitment, PublicNonce};
use crate::errors::DkgError;
use crate::traits::Signer;
use crate::v1;
use hashbrown::HashMap;
use rand_core::{CryptoRng, RngCore};
pub fn dkg<RNG: RngCore + CryptoRng>(
signers: &mut [v1::Signer],
rng: &mut RNG,
) -> Result<HashMap<u32, PolyCommitment>, HashMap<u32, DkgError>> {
let comms: HashMap<u32, PolyCommitment> = signers
.iter()
.flat_map(|s| s.get_poly_commitments(rng))
.map(|comm| (comm.id.id.get_u32(), comm))
.collect();
let mut private_shares = HashMap::new();
for signer in signers.iter() {
for (signer_id, signer_shares) in signer.get_shares() {
private_shares.insert(signer_id, signer_shares);
}
}
let mut secret_errors = HashMap::new();
for signer in signers.iter_mut() {
if let Err(signer_secret_errors) = signer.compute_secrets(&private_shares, &comms) {
secret_errors.extend(signer_secret_errors.into_iter());
}
}
if secret_errors.is_empty() {
Ok(comms)
} else {
Err(secret_errors)
}
}
pub fn sign<RNG: RngCore + CryptoRng>(
msg: &[u8],
signers: &mut [v1::Signer],
rng: &mut RNG,
) -> (Vec<PublicNonce>, Vec<v1::SignatureShare>) {
let ids: Vec<u32> = signers.iter().flat_map(|s| s.get_key_ids()).collect();
let nonces: Vec<PublicNonce> = signers.iter_mut().flat_map(|s| s.gen_nonces(rng)).collect();
let shares = signers
.iter()
.flat_map(|s| s.sign(msg, &ids, &ids, &nonces))
.collect();
(nonces, shares)
}
}
#[cfg(test)]
mod tests {
use crate::traits;
use crate::traits::test_helpers::run_compute_secrets_missing_private_shares;
use crate::traits::{Aggregator, Signer};
use crate::util::create_rng;
use crate::v1;
use hashbrown::{HashMap, HashSet};
use num_traits::Zero;
#[test]
fn signer_new() {
let mut rng = create_rng();
let id = 1;
let key_ids = [1, 2, 3];
let n: u32 = 10;
let t: u32 = 7;
let signer = v1::Signer::new(id, &key_ids, n, t, &mut rng);
assert_eq!(signer.parties.len(), key_ids.len());
}
#[test]
fn signer_gen_nonces() {
let mut rng = create_rng();
let id = 1;
let key_ids = [1, 2, 3];
let n: u32 = 10;
let t: u32 = 7;
let mut signer = v1::Signer::new(id, &key_ids, n, t, &mut rng);
for party in &signer.parties {
assert!(party.nonce.is_zero());
}
let nonces = signer.gen_nonces(&mut rng);
assert_eq!(nonces.len(), key_ids.len());
for party in &signer.parties {
assert!(!party.nonce.is_zero());
}
}
#[test]
fn signer_save_load() {
let mut rng = create_rng();
let id = 1;
let key_ids = [1, 2, 3];
let n: u32 = 10;
let t: u32 = 7;
let signer = v1::Signer::new(id, &key_ids, n, t, &mut rng);
let state = signer.save();
let loaded = v1::Signer::load(&state);
assert_eq!(signer, loaded);
}
#[test]
fn clear_polys() {
let mut rng = create_rng();
let id = 1;
let key_ids = [1, 2, 3];
let n: u32 = 10;
let t: u32 = 7;
let mut signer = v1::Signer::new(id, &key_ids, n, t, &mut rng);
assert_eq!(signer.get_poly_commitments(&mut rng).len(), key_ids.len());
assert_eq!(signer.get_shares().len(), key_ids.len());
for (_id, shares) in signer.get_shares() {
assert_eq!(shares.len(), n.try_into().unwrap());
}
signer.clear_polys();
assert_eq!(signer.get_poly_commitments(&mut rng).len(), 0);
assert_eq!(signer.get_shares().len(), 3);
for (_id, shares) in signer.get_shares() {
assert_eq!(shares.len(), 0);
}
}
#[allow(non_snake_case)]
#[test]
fn aggregator_sign() {
let mut rng = create_rng();
let msg = "It was many and many a year ago".as_bytes();
let N: u32 = 10;
let T: u32 = 7;
let signer_ids: Vec<Vec<u32>> = [
[1, 2, 3].to_vec(),
[4, 5].to_vec(),
[6, 7, 8].to_vec(),
[9, 10].to_vec(),
]
.to_vec();
let mut signers: Vec<v1::Signer> = signer_ids
.iter()
.enumerate()
.map(|(id, ids)| v1::Signer::new(id.try_into().unwrap(), ids, N, T, &mut rng))
.collect();
let comms = match traits::test_helpers::dkg(&mut signers, &mut rng) {
Ok(comms) => comms,
Err(secret_errors) => {
panic!("Got secret errors from DKG: {:?}", secret_errors);
}
};
{
let mut signers = [signers[0].clone(), signers[1].clone(), signers[3].clone()].to_vec();
let mut sig_agg = v1::Aggregator::new(N, T);
sig_agg.init(&comms).expect("aggregator init failed");
let (nonces, sig_shares) = v1::test_helpers::sign(msg, &mut signers, &mut rng);
if let Err(e) = sig_agg.sign(msg, &nonces, &sig_shares, &[]) {
panic!("Aggregator sign failed: {:?}", e);
}
}
}
#[allow(non_snake_case)]
#[test]
pub fn run_compute_secrets_missing_shares() {
run_compute_secrets_missing_private_shares::<v1::Signer>()
}
#[allow(non_snake_case)]
#[test]
pub fn bad_polynomial_length() {
let gt = |t| t + 1;
let lt = |t| t - 1;
traits::test_helpers::bad_polynomial_length::<v1::Signer, _>(gt);
traits::test_helpers::bad_polynomial_length::<v1::Signer, _>(lt);
}
#[test]
pub fn bad_polynomial_commitment() {
traits::test_helpers::bad_polynomial_commitment::<v1::Signer>();
}
#[test]
fn validate_party_id() {
let mut signer_key_ids = HashMap::new();
let mut key_ids = HashSet::new();
key_ids.insert(1);
signer_key_ids.insert(0, key_ids);
assert!(v1::Signer::validate_party_id(0, 1, &signer_key_ids));
assert!(!v1::Signer::validate_party_id(0, 0, &signer_key_ids));
}
}