use crate::messaging::system::{SectionSig, SectionSigShare};
use crate::types::keys::ed25519::Digest256;
use std::collections::{BTreeMap, BTreeSet};
use thiserror::Error;
use tiny_keccak::{Hasher, Sha3};
/// Collects BLS signature shares and combines them into a full [`SectionSig`] once more than
/// `threshold` distinct valid shares for the same payload and section key have been received.
///
/// Pending shares are grouped under a SHA3-256 digest of the payload and the section public key.
#[derive(Default, Debug)]
pub struct SignatureAggregator {
    /// Shares collected so far, keyed by the digest of `payload || public_key bytes`.
    map: BTreeMap<Digest256, BTreeSet<SectionSigShare>>,
}
#[derive(Debug, Error)]
pub enum AggregatorError {
#[error("signature share is invalid")]
InvalidSigShare,
#[error("failed to combine signature shares: {0}")]
FailedToCombineSigShares(#[from] bls::error::Error),
}
impl SignatureAggregator {
    /// Adds `sig_share` to the shares collected for `payload` and, once more than
    /// `threshold` distinct valid shares are held, combines them into a [`SectionSig`].
    ///
    /// Shares are bucketed by the SHA3-256 digest of `payload` followed by the bytes of the
    /// share's section public key, so shares for different payloads or keys never mix. Each
    /// bucket is a set, so submitting an identical share twice counts it once. When a
    /// signature is produced the bucket is dropped; shares that arrive later for the same
    /// payload start a fresh aggregation.
    ///
    /// Returns `Ok(None)` while the threshold has not yet been exceeded.
    ///
    /// # Errors
    ///
    /// - [`AggregatorError::InvalidSigShare`] if `sig_share` does not verify against
    ///   `payload`; the share is not stored.
    /// - [`AggregatorError::FailedToCombineSigShares`] if combining the collected shares
    ///   fails; the collected shares are kept in the bucket.
    pub fn try_aggregate(
        &mut self,
        payload: &[u8],
        sig_share: SectionSigShare,
    ) -> Result<Option<SectionSig>, AggregatorError> {
        if !sig_share.verify(payload) {
            return Err(AggregatorError::InvalidSigShare);
        }

        // Only the key set is needed once the share has been moved into the bucket, so clone
        // just that rather than the whole share.
        let public_key_set = sig_share.public_key_set.clone();
        let public_key = public_key_set.public_key();

        let mut hasher = Sha3::v256();
        let mut hash = Digest256::default();
        hasher.update(payload);
        hasher.update(&public_key.to_bytes());
        hasher.finalize(&mut hash);

        let current_shares = self.map.entry(hash).or_default();
        let _ = current_shares.insert(sig_share);

        // Combining requires strictly more than `threshold` distinct shares.
        if current_shares.len() <= public_key_set.threshold() {
            return Ok(None);
        }

        // `?` converts the bls error via the `#[from]` impl on `FailedToCombineSigShares`.
        let signature = public_key_set.combine_signatures(
            current_shares
                .iter()
                .map(|s| (s.index, s.signature_share.clone())),
        )?;

        // Aggregation for this bucket is complete; release its shares.
        let _ = self.map.remove(&hash);

        Ok(Some(SectionSig {
            public_key,
            signature,
        }))
    }
}
/// Collects BLS signature shares and only combines them into a full [`SectionSig`] once the
/// number of distinct valid shares for the same payload and section key reaches the
/// caller-supplied participant count.
///
/// Pending shares are grouped under a SHA3-256 digest of the payload and the section public key.
#[derive(Default, Debug)]
pub struct TotalParticipationAggregator {
    /// Shares collected so far, keyed by the digest of `payload || public_key bytes`.
    map: BTreeMap<Digest256, BTreeSet<SectionSigShare>>,
}
impl TotalParticipationAggregator {
    /// Adds `sig_share` to the shares collected for `payload` and, once at least
    /// `total_participants` distinct valid shares are held, combines them into a
    /// [`SectionSig`].
    ///
    /// Shares are bucketed by the SHA3-256 digest of `payload` followed by the bytes of the
    /// share's section public key. Each bucket is a set, so an identical share submitted twice
    /// counts once. When a signature is produced the bucket is dropped; shares that arrive
    /// later for the same payload start a fresh aggregation.
    ///
    /// The completion check is `>=` rather than `==`. If `total_participants` is lowered
    /// between calls after a bucket already holds more shares, that bucket is still combined
    /// on the next share instead of being stuck forever.
    ///
    /// Returns `Ok(None)` while fewer than `total_participants` shares are held.
    ///
    /// # Errors
    ///
    /// - [`AggregatorError::InvalidSigShare`] if `sig_share` does not verify against
    ///   `payload`; the share is not stored.
    /// - [`AggregatorError::FailedToCombineSigShares`] if combining the collected shares
    ///   fails (for example when `total_participants` does not exceed the key set's
    ///   threshold); the collected shares are kept in the bucket.
    pub fn try_aggregate(
        &mut self,
        payload: &[u8],
        sig_share: SectionSigShare,
        total_participants: usize,
    ) -> Result<Option<SectionSig>, AggregatorError> {
        if !sig_share.verify(payload) {
            return Err(AggregatorError::InvalidSigShare);
        }

        // Only the key set is needed once the share has been moved into the bucket, so clone
        // just that rather than the whole share.
        let public_key_set = sig_share.public_key_set.clone();
        let public_key = public_key_set.public_key();

        let mut hasher = Sha3::v256();
        let mut hash = Digest256::default();
        hasher.update(payload);
        hasher.update(&public_key.to_bytes());
        hasher.finalize(&mut hash);

        let current_shares = self.map.entry(hash).or_default();
        let _ = current_shares.insert(sig_share);

        if current_shares.len() < total_participants {
            return Ok(None);
        }

        // `?` converts the bls error via the `#[from]` impl on `FailedToCombineSigShares`.
        let signature = public_key_set.combine_signatures(
            current_shares
                .iter()
                .map(|s| (s.index, s.signature_share.clone())),
        )?;

        // Aggregation for this bucket is complete; release its shares.
        let _ = self.map.remove(&hash);

        Ok(Some(SectionSig {
            public_key,
            signature,
        }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::thread_rng;

    #[test]
    fn test_signature_aggregator() -> Result<(), AggregatorError> {
        let mut rng = thread_rng();
        let threshold = 3;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = SignatureAggregator::default();
        let payload = b"hello";

        // Up to `threshold` shares: no signature yet.
        for index in 0..threshold {
            let sig_share = create_sig_share(&sk_set, index, payload);
            let result = aggregator.try_aggregate(payload, sig_share);
            match result {
                Ok(None) => (),
                _ => panic!("unexpected result: {result:?}"),
            }
        }

        // The `threshold + 1`-th share completes the signature.
        let sig_share = create_sig_share(&sk_set, threshold, payload);
        let sig = aggregator.try_aggregate(payload, sig_share)?;
        assert!(sig.expect("some key").verify(payload));

        // A late share starts a fresh aggregation.
        let sig_share = create_sig_share(&sk_set, threshold + 1, payload);
        let result = aggregator.try_aggregate(payload, sig_share);
        match result {
            Ok(None) => Ok(()),
            _ => panic!("unexpected result: {result:?}"),
        }
    }

    #[test]
    fn test_total_participation_aggregator() -> Result<(), AggregatorError> {
        let mut rng = thread_rng();
        let threshold = 3;
        let total_participation = 5;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = TotalParticipationAggregator::default();
        let payload = b"hello";

        // Even past the threshold, nothing is produced until every participant has signed.
        for index in 0..total_participation - 1 {
            let sig_share = create_sig_share(&sk_set, index, payload);
            let result = aggregator.try_aggregate(payload, sig_share, total_participation);
            match result {
                Ok(None) => (),
                _ => panic!("unexpected result: {result:?}"),
            }
        }

        let sig_share = create_sig_share(&sk_set, total_participation, payload);
        let sig = aggregator.try_aggregate(payload, sig_share, total_participation)?;
        assert!(sig.expect("some key").verify(payload));

        // A late share starts a fresh aggregation.
        let sig_share = create_sig_share(&sk_set, total_participation + 1, payload);
        let result = aggregator.try_aggregate(payload, sig_share, total_participation);
        match result {
            Ok(None) => Ok(()),
            _ => panic!("unexpected result: {result:?}"),
        }
    }

    #[test]
    fn invalid_share() -> Result<(), AggregatorError> {
        let mut rng = thread_rng();
        let threshold = 3;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = SignatureAggregator::default();
        let payload = b"good";

        // Setup shares must be accepted without completing a signature.
        for index in 0..threshold {
            let sig_share = create_sig_share(&sk_set, index, payload);
            let result = aggregator.try_aggregate(payload, sig_share);
            match result {
                Ok(None) => (),
                _ => panic!("unexpected result: {result:?}"),
            }
        }

        // A share over a different payload is rejected and must not count.
        let invalid_sig_share = create_sig_share(&sk_set, threshold, b"bad");
        let result = aggregator.try_aggregate(payload, invalid_sig_share);
        match result {
            Err(AggregatorError::InvalidSigShare) => (),
            _ => panic!("unexpected result: {result:?}"),
        }

        let sig_share = create_sig_share(&sk_set, threshold + 1, payload);
        let sig = aggregator.try_aggregate(payload, sig_share)?;
        assert!(sig.expect("some key").verify(payload));
        Ok(())
    }

    #[test]
    fn duplicate_shares_are_counted_once() -> Result<(), AggregatorError> {
        let mut rng = thread_rng();
        let threshold = 3;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = SignatureAggregator::default();
        let payload = b"hello";

        // Re-submitting one share more than `threshold` times must not reach the threshold.
        let sig_share = create_sig_share(&sk_set, 0, payload);
        for _ in 0..=threshold {
            let result = aggregator.try_aggregate(payload, sig_share.clone());
            match result {
                Ok(None) => (),
                _ => panic!("unexpected result: {result:?}"),
            }
        }

        for index in 1..threshold {
            let sig_share = create_sig_share(&sk_set, index, payload);
            let result = aggregator.try_aggregate(payload, sig_share);
            match result {
                Ok(None) => (),
                _ => panic!("unexpected result: {result:?}"),
            }
        }

        // Only now are `threshold + 1` distinct shares held.
        let sig_share = create_sig_share(&sk_set, threshold, payload);
        let sig = aggregator.try_aggregate(payload, sig_share)?;
        assert!(sig.expect("some key").verify(payload));
        Ok(())
    }

    #[test]
    fn repeated_voting() {
        let mut rng = thread_rng();
        let threshold = 3;
        let sk_set = bls::SecretKeySet::random(threshold, &mut rng);
        let mut aggregator = SignatureAggregator::default();
        let payload = b"hello";

        for index in 0..threshold {
            let sig_share = create_sig_share(&sk_set, index, payload);
            assert!(matches!(
                aggregator.try_aggregate(payload, sig_share),
                Ok(None)
            ));
        }

        let sig_share = create_sig_share(&sk_set, threshold, payload);
        assert!(matches!(
            aggregator.try_aggregate(payload, sig_share),
            Ok(Some(_))
        ));

        // After completion the bucket is reset, so previously-used indices count again.
        let offset = 2;
        for index in offset..(threshold + offset) {
            let sig_share = create_sig_share(&sk_set, index, payload);
            assert!(matches!(
                aggregator.try_aggregate(payload, sig_share),
                Ok(None)
            ));
        }

        let sig_share = create_sig_share(&sk_set, threshold + offset + 1, payload);
        assert!(matches!(
            aggregator.try_aggregate(payload, sig_share),
            Ok(Some(_))
        ));
    }

    /// Builds a valid share of `payload` signed with the `index`-th secret key share.
    fn create_sig_share(
        sk_set: &bls::SecretKeySet,
        index: usize,
        payload: &[u8],
    ) -> SectionSigShare {
        let sk_share = sk_set.secret_key_share(index);
        SectionSigShare::new(sk_set.public_keys(), index, &sk_share, payload)
    }
}