use crate::keys::{PublicKey, SecretKey};
use digest::Digest;
use std::{ops::Mul, prelude::v1::Vec};
use thiserror::Error;
/// Upper bound on the number of participants accepted by [`JointKeyBuilder::new`].
pub const MAX_SIGNATURES: usize = 32_768;
#[derive(Clone, Debug, Error, PartialEq, Eq)]
pub enum MuSigError {
#[error("The number of public nonces must match the number of public keys in the joint key")]
MismatchedNonces,
#[error("The number of partial signatures must match the number of public keys in the joint key")]
MismatchedSignatures,
#[error("The aggregate signature did not verify")]
InvalidAggregateSignature,
#[error("A partial signature did not validate: {0}")]
InvalidPartialSignature(usize),
#[error("The participant list must be sorted before making this call")]
NotSorted,
#[error("The participant key is not in the list")]
ParticipantNotFound,
#[error("An attempt was made to perform an invalid MuSig state transition")]
InvalidStateTransition,
#[error("An attempt was made to add a duplicate public key to a MuSig signature")]
DuplicatePubKey,
#[error("There are too many parties in the MuSig signature")]
TooManyParticipants,
#[error("There are too few parties in the MuSig signature")]
NotEnoughParticipants,
#[error("A nonce hash is missing")]
MissingHash,
#[error("The message to be signed can only be set once")]
MessageAlreadySet,
#[error("The message to be signed MUST be set before the final nonce is added to the MuSig ceremony")]
MissingMessage,
#[error("The message to sign is invalid. have you hashed it?")]
InvalidMessage,
#[error("MuSig requires a hash function with a 32 byte digest")]
IncompatibleHashFunction,
}
/// The result of a completed [`JointKeyBuilder`]: the participants' public keys together with
/// the values derived from them for a MuSig ceremony.
pub struct JointKey<P: PublicKey<K = K>, K: SecretKey> {
    /// Participant public keys, in the order they were left in by the builder.
    pub_keys: Vec<P>,
    /// One scalar per participant; entry `i` belongs to `pub_keys[i]`.
    musig_scalars: Vec<K>,
    /// Hash of all participant public keys, interpreted as a secret-key scalar.
    common: K,
    /// Aggregate public key computed from `musig_scalars` and `pub_keys`.
    joint_pub_key: P,
}
/// Accumulates a fixed number of distinct participant public keys and then produces a
/// [`JointKey`].
pub struct JointKeyBuilder<P: PublicKey<K = K>, K: SecretKey> {
    /// Exact number of keys that must be added before the builder can be built.
    num_signers: usize,
    /// Keys added so far, in insertion order.
    pub_keys: Vec<P>,
}
impl<K, P> JointKeyBuilder<P, K>
where
    K: SecretKey + Mul<P, Output = P>,
    P: PublicKey<K = K>,
{
    /// Creates a builder that will collect exactly `n` distinct public keys.
    ///
    /// # Errors
    ///
    /// * [`MuSigError::TooManyParticipants`] if `n` is greater than [`MAX_SIGNATURES`].
    /// * [`MuSigError::NotEnoughParticipants`] if `n` is zero.
    pub fn new(n: usize) -> Result<JointKeyBuilder<P, K>, MuSigError> {
        if n > MAX_SIGNATURES {
            return Err(MuSigError::TooManyParticipants);
        }
        if n == 0 {
            return Err(MuSigError::NotEnoughParticipants);
        }
        Ok(JointKeyBuilder {
            pub_keys: Vec::with_capacity(n),
            num_signers: n,
        })
    }

    /// Returns the number of participants this builder was created for.
    pub fn num_signers(&self) -> usize {
        self.num_signers
    }

    /// Adds one participant's public key and returns how many keys have been collected so far
    /// (including this one).
    ///
    /// # Errors
    ///
    /// * [`MuSigError::DuplicatePubKey`] if `pub_key` was already added (checked first).
    /// * [`MuSigError::TooManyParticipants`] if the builder already holds `num_signers` keys.
    pub fn add_key(&mut self, pub_key: P) -> Result<usize, MuSigError> {
        if self.key_exists(&pub_key) {
            return Err(MuSigError::DuplicatePubKey);
        }
        if self.pub_keys.len() >= self.num_signers {
            return Err(MuSigError::TooManyParticipants);
        }
        self.pub_keys.push(pub_key);
        Ok(self.pub_keys.len())
    }

    /// Returns `true` if `key` has already been added. This is a linear scan over the collected
    /// keys.
    pub fn key_exists(&self, key: &P) -> bool {
        self.pub_keys.iter().any(|v| v == key)
    }

    /// Returns `true` once exactly `num_signers` keys have been collected.
    pub fn is_full(&self) -> bool {
        self.pub_keys.len() == self.num_signers
    }

    /// Adds every key yielded by `keys`, in order, and returns the total number of keys
    /// collected.
    ///
    /// # Errors
    ///
    /// Stops at the first key rejected by [`add_key`](Self::add_key) and returns that error.
    /// Keys accepted before the failure remain in the builder.
    pub fn add_keys<T: IntoIterator<Item = P>>(&mut self, keys: T) -> Result<usize, MuSigError> {
        for k in keys {
            self.add_key(k)?;
        }
        Ok(self.pub_keys.len())
    }

    /// Consumes the builder and computes the MuSig [`JointKey`] using hash function `D`.
    ///
    /// The keys are sorted first, so the result does not depend on insertion order. Then:
    /// * `common` is `D(pk_1 || ... || pk_n)` converted with `K::from_bytes`;
    /// * scalar `i` is `D(common || pk_i)` converted with `K::from_bytes`;
    /// * the joint key is `P::batch_mul(scalars, keys)`.
    ///
    /// # Errors
    ///
    /// * [`MuSigError::NotEnoughParticipants`] if fewer than `num_signers` keys were added.
    /// * [`MuSigError::IncompatibleHashFunction`] if `D` does not produce a 32-byte digest.
    ///
    /// # Panics
    ///
    /// Panics if a 32-byte digest still cannot be converted into a `K` by `K::from_bytes`.
    pub fn build<D: Digest>(mut self) -> Result<JointKey<P, K>, MuSigError> {
        // Digest length required by MuSig (see `MuSigError::IncompatibleHashFunction`).
        const MUSIG_DIGEST_LEN: usize = 32;

        if !self.is_full() {
            return Err(MuSigError::NotEnoughParticipants);
        }
        // Reject an unsuitable hash function here. Otherwise the `expect`s in the helpers below
        // would panic on a caller misconfiguration that this `Result` can report.
        if D::output_size() != MUSIG_DIGEST_LEN {
            return Err(MuSigError::IncompatibleHashFunction);
        }
        self.sort_keys();
        let common = self.calculate_common::<D>();
        let musig_scalars = self.calculate_musig_scalars::<D>(&common);
        let joint_pub_key = Self::calculate_joint_key::<D>(&musig_scalars, &self.pub_keys);
        Ok(JointKey {
            pub_keys: self.pub_keys,
            musig_scalars,
            joint_pub_key,
            common,
        })
    }

    /// Hashes the concatenation of all collected public keys with `D` and converts the digest
    /// into a `K`.
    fn calculate_common<D: Digest>(&self) -> K {
        let digest = self
            .pub_keys
            .iter()
            .fold(D::new(), |hasher, k| hasher.chain(k.as_bytes()))
            .result();
        K::from_bytes(&digest)
            .expect("Could not calculate Scalar from hash value. Your crypto/hash combination might be inconsistent")
    }

    /// Computes `D(common || pubkey)` and converts the digest into a `K`.
    fn calculate_partial_key<D: Digest>(common: &[u8], pubkey: &P) -> K {
        let k = D::new().chain(common).chain(pubkey.as_bytes()).result();
        K::from_bytes(&k)
            .expect("Could not calculate Scalar from hash value. Your crypto/hash combination might be inconsistent")
    }

    /// Sorts the collected keys in place (unstable sort; keys are distinct, so order is unique).
    fn sort_keys(&mut self) {
        self.pub_keys.sort_unstable();
    }

    /// Computes one scalar per collected key, in the same order as `self.pub_keys`.
    fn calculate_musig_scalars<D: Digest>(&self, common: &K) -> Vec<K> {
        self.pub_keys
            .iter()
            .map(|p| Self::calculate_partial_key::<D>(common.as_bytes(), p))
            .collect()
    }

    /// Combines the scalars with their matching keys via `P::batch_mul`.
    /// `D` is not used here; it is kept so existing call sites continue to compile.
    fn calculate_joint_key<D: Digest>(scalars: &[K], pub_keys: &[P]) -> P {
        P::batch_mul(scalars, pub_keys)
    }
}
impl<P, K> JointKey<P, K>
where
    K: SecretKey,
    P: PublicKey<K = K>,
{
    /// Returns the position of `pubkey` among the participant keys.
    ///
    /// Uses a binary search, so it relies on `pub_keys` being sorted. `JointKeyBuilder::build`
    /// sorts the keys before constructing a `JointKey`.
    ///
    /// # Errors
    ///
    /// Returns [`MuSigError::ParticipantNotFound`] if `pubkey` is not a participant.
    pub fn index_of(&self, pubkey: &P) -> Result<usize, MuSigError> {
        self.pub_keys
            .binary_search(pubkey)
            .map_err(|_| MuSigError::ParticipantNotFound)
    }

    /// Number of participants in this joint key.
    #[inline]
    pub fn size(&self) -> usize {
        self.pub_keys.len()
    }

    /// The public key of the participant at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.size()`.
    #[inline]
    pub fn get_pub_keys(&self, index: usize) -> &P {
        let keys = &self.pub_keys;
        &keys[index]
    }

    /// The MuSig scalar belonging to the participant at `index`.
    ///
    /// # Panics
    ///
    /// Panics if `index` is out of range for the stored scalars.
    #[inline]
    pub fn get_musig_scalar(&self, index: usize) -> &K {
        let scalars = &self.musig_scalars;
        &scalars[index]
    }

    /// The shared value derived from all participant keys.
    #[inline]
    pub fn get_common(&self) -> &K {
        &self.common
    }

    /// The aggregate (joint) public key.
    #[inline]
    pub fn get_joint_pubkey(&self) -> &P {
        &self.joint_pub_key
    }
}