use super::signed::{Signed, SignedShare};
use std::{
collections::HashMap,
fmt::Debug,
time::{Duration, Instant},
};
use thiserror::Error;
use threshold_crypto as bls;
use tiny_keccak::{Hasher, Sha3};
/// Default time a pending aggregation entry may go unmodified before it is discarded.
pub const DEFAULT_EXPIRATION: Duration = Duration::from_secs(2 * 60);
/// SHA3-256 digest used as the key of a pending aggregation entry.
type Digest256 = [u8; 32];
/// Accumulates BLS signature shares and combines them into a full signature once enough
/// distinct shares for the same payload and master public key have been collected.
///
/// Pending entries that have not been modified for at least the configured expiration are
/// discarded at the start of every [`SignatureAggregator::add`] call.
pub struct SignatureAggregator {
    map: HashMap<Digest256, State>,
    expiration: Duration,
}

impl SignatureAggregator {
    /// Creates an aggregator that uses [`DEFAULT_EXPIRATION`].
    pub fn new() -> Self {
        Self::with_expiration(DEFAULT_EXPIRATION)
    }

    /// Creates an aggregator whose pending entries expire after `expiration` without changes.
    pub fn with_expiration(expiration: Duration) -> Self {
        Self {
            map: HashMap::new(),
            expiration,
        }
    }

    /// Adds `signed_share`, a share of a signature over `payload`.
    ///
    /// Shares are grouped by the SHA3-256 digest of `payload` followed by the bytes of the
    /// master public key. Once a group holds more than `threshold` distinct shares they are
    /// combined, the group's shares are cleared, and the resulting [`Signed`] is returned.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidShare`] – the share does not verify against `payload`.
    /// * [`Error::NotEnoughShares`] – the share was stored, or repeated an index already held,
    ///   but the threshold has not been exceeded yet.
    /// * [`Error::Combine`] – combining the collected shares failed.
    pub fn add(&mut self, payload: &[u8], signed_share: SignedShare) -> Result<Signed, Error> {
        self.remove_expired();

        if !signed_share.verify(payload) {
            return Err(Error::InvalidShare);
        }

        let public_key = signed_share.public_key_set.public_key();
        let key = Self::digest(payload, &public_key.to_bytes());
        let signature = self
            .map
            .entry(key)
            .or_insert_with(State::new)
            .add(signed_share)?;

        Ok(Signed {
            public_key,
            signature,
        })
    }

    /// Computes the key under which shares of `payload` for the given public key are grouped:
    /// SHA3-256 over `payload` followed by `public_key_bytes`.
    fn digest(payload: &[u8], public_key_bytes: &[u8]) -> Digest256 {
        let mut sha3 = Sha3::v256();
        sha3.update(payload);
        sha3.update(public_key_bytes);

        let mut output = Digest256::default();
        sha3.finalize(&mut output);
        output
    }

    /// Drops every pending entry whose last modification is at least `expiration` old.
    fn remove_expired(&mut self) {
        // Destructure to borrow the map mutably and the expiration immutably at the same time.
        let SignatureAggregator { map, expiration } = self;
        map.retain(|_, state| state.modified.elapsed() < *expiration);
    }
}

impl Default for SignatureAggregator {
    /// Equivalent to [`SignatureAggregator::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[derive(Debug, Error)]
pub enum Error {
#[error("not enough signature shares")]
NotEnoughShares,
#[error("signature share is invalid")]
InvalidShare,
#[error("failed to combine signature shares: {0}")]
Combine(bls::error::Error),
}
struct State {
shares: HashMap<usize, bls::SignatureShare>,
modified: Instant,
}
impl State {
fn new() -> Self {
Self {
shares: Default::default(),
modified: Instant::now(),
}
}
fn add(&mut self, signed_share: SignedShare) -> Result<bls::Signature, Error> {
if self
.shares
.insert(signed_share.index, signed_share.signature_share)
.is_none()
{
self.modified = Instant::now();
} else {
return Err(Error::NotEnoughShares);
}
if self.shares.len() > signed_share.public_key_set.threshold() {
let signature = signed_share
.public_key_set
.combine_signatures(self.shares.iter().map(|(&index, share)| (index, share)))
.map_err(Error::Combine)?;
self.shares.clear();
Ok(signature)
} else {
Err(Error::NotEnoughShares)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;
    use std::thread::sleep;

    #[test]
    fn smoke() {
        let mut rng = thread_rng();
        let threshold = 3;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = SignatureAggregator::default();
        let payload = b"hello";

        // Up to `threshold` shares are not enough to combine.
        for index in 0..threshold {
            let signed_share = create_signed_share(&sk_set, index, payload);
            assert_not_enough_shares(aggregator.add(payload, signed_share));
        }

        // The (threshold + 1)-th distinct share completes the signature.
        let signed_share = create_signed_share(&sk_set, threshold, payload);
        let signed = aggregator.add(payload, signed_share).unwrap();
        assert!(signed.verify(payload));

        // A share arriving after completion starts a fresh round.
        let signed_share = create_signed_share(&sk_set, threshold + 1, payload);
        assert_not_enough_shares(aggregator.add(payload, signed_share));
    }

    #[test]
    fn invalid_share() {
        let mut rng = thread_rng();
        let threshold = 3;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = SignatureAggregator::new();
        let payload = b"good";

        for index in 0..threshold {
            let signed_share = create_signed_share(&sk_set, index, payload);
            assert_not_enough_shares(aggregator.add(payload, signed_share));
        }

        // A share signed over a different payload is rejected and does not count.
        let invalid_signed_share = create_signed_share(&sk_set, threshold, b"bad");
        let result = aggregator.add(payload, invalid_signed_share);
        match result {
            Err(Error::InvalidShare) => (),
            _ => panic!("unexpected result: {:?}", result),
        }

        let signed_share = create_signed_share(&sk_set, threshold + 1, payload);
        let signed = aggregator.add(payload, signed_share).unwrap();
        assert!(signed.verify(payload))
    }

    #[test]
    fn duplicate_share() {
        let mut rng = thread_rng();
        let threshold = 3;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = SignatureAggregator::new();
        let payload = b"hello";

        // Re-submitting the same index must never count towards the threshold.
        for _ in 0..=threshold {
            let signed_share = create_signed_share(&sk_set, 0, payload);
            assert_not_enough_shares(aggregator.add(payload, signed_share));
        }
    }

    #[test]
    fn expiration() {
        let mut rng = thread_rng();
        let threshold = 3;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = SignatureAggregator::with_expiration(Duration::from_millis(500));
        let payload = b"hello";

        for index in 0..threshold {
            let signed_share = create_signed_share(&sk_set, index, payload);
            assert_not_enough_shares(aggregator.add(payload, signed_share));
        }

        // After the entry expires, the earlier shares are gone and one more is not enough.
        sleep(Duration::from_secs(1));
        let signed_share = create_signed_share(&sk_set, threshold, payload);
        assert_not_enough_shares(aggregator.add(payload, signed_share));
    }

    #[test]
    fn repeated_voting() {
        let mut rng = thread_rng();
        let threshold = 3;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = SignatureAggregator::new();
        let payload = b"hello";

        // First round: indices 0..=threshold.
        for index in 0..threshold {
            let signed_share = create_signed_share(&sk_set, index, payload);
            assert!(aggregator.add(payload, signed_share).is_err());
        }
        let signed_share = create_signed_share(&sk_set, threshold, payload);
        assert!(aggregator.add(payload, signed_share).is_ok());

        // Second round reuses some indices from the first round; it must still aggregate.
        let offset = 2;
        for index in offset..(threshold + offset) {
            let signed_share = create_signed_share(&sk_set, index, payload);
            assert!(aggregator.add(payload, signed_share).is_err());
        }
        let signed_share = create_signed_share(&sk_set, threshold + offset + 1, payload);
        assert!(aggregator.add(payload, signed_share).is_ok());
    }

    /// Panics unless `result` is `Err(Error::NotEnoughShares)`.
    fn assert_not_enough_shares(result: Result<Signed, Error>) {
        match result {
            Err(Error::NotEnoughShares) => (),
            _ => panic!("unexpected result: {:?}", result),
        }
    }

    fn create_signed_share(
        sk_set: &bls::SecretKeySet,
        index: usize,
        payload: &[u8],
    ) -> SignedShare {
        let sk_share = sk_set.secret_key_share(index);
        SignedShare::new(sk_set.public_keys(), index, &sk_share, &payload)
    }
}