use crate::error::{blst_err_to_atms, AtmsError};
use blst::min_pk::{
AggregatePublicKey, AggregateSignature, PublicKey as BlstPk, SecretKey as BlstSk,
Signature as BlstSig,
};
use blst::BLST_ERROR;
use rand_core::{CryptoRng, RngCore};
use std::{
cmp::Ordering,
fmt::Debug,
hash::{Hash, Hasher},
iter::Sum,
ops::Sub,
};
/// BLS secret key (wrapper around blst's min-pk `SecretKey`).
///
/// `Debug` is implemented by hand so that formatting a `SigningKey` (for example in a log line
/// or a panic message) never prints the secret scalar.
pub struct SigningKey(BlstSk);

impl Debug for SigningKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("SigningKey(<redacted>)")
    }
}

/// BLS public key: a G1 point in blst's min-pk variant (48 bytes compressed).
#[derive(Clone, Copy, Debug)]
pub struct PublicKey(pub(crate) BlstPk);

/// Proof of possession: the secret key's signature over the fixed message `b"PoP"`.
#[derive(Clone, Copy, Debug)]
pub struct ProofOfPossession(BlstSig);

/// A public key bundled with its proof of possession.
#[derive(Clone, Copy, Debug)]
pub struct PublicKeyPoP(pub(crate) PublicKey, pub(crate) ProofOfPossession);

/// BLS signature: a G2 point in blst's min-pk variant (96 bytes compressed).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Signature(pub(crate) BlstSig);
impl SigningKey {
pub fn gen<R: CryptoRng + RngCore>(rng: &mut R) -> Self {
let mut ikm = [0u8; 32];
rng.fill_bytes(&mut ikm);
Self(
BlstSk::key_gen(&ikm, &[])
.expect("Error occurs when the length of ikm < 32. This will not happen here."),
)
}
pub fn sign(&self, msg: &[u8]) -> Signature {
Signature(self.0.sign(msg, &[], &[]))
}
pub fn to_bytes(&self) -> [u8; 32] {
self.0.to_bytes()
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, AtmsError> {
match BlstSk::from_bytes(&bytes[..32]) {
Ok(sk) => Ok(Self(sk)),
Err(e) => Err(blst_err_to_atms(e)
.expect_err("If deserialisation is not successful, blst returns and error different to SUCCESS."))
}
}
}
impl From<&SigningKey> for PublicKey {
fn from(sk: &SigningKey) -> Self {
Self(sk.0.sk_to_pk())
}
}
impl From<&SigningKey> for ProofOfPossession {
fn from(sk: &SigningKey) -> Self {
ProofOfPossession(sk.0.sign(b"PoP", &[], &[]))
}
}
impl From<&SigningKey> for PublicKeyPoP {
fn from(sk: &SigningKey) -> Self {
Self(PublicKey(sk.0.sk_to_pk()), sk.into())
}
}
impl PublicKeyPoP {
pub fn verify(&self) -> Result<PublicKey, AtmsError> {
if self.1 .0.verify(false, b"PoP", &[], &[], &self.0 .0, false) == BLST_ERROR::BLST_SUCCESS
{
return Ok(self.0);
}
Err(AtmsError::InvalidPoP)
}
pub fn to_bytes(&self) -> [u8; 144] {
let mut pkpop_bytes = [0u8; 144];
pkpop_bytes[..48].copy_from_slice(&self.0.to_bytes());
pkpop_bytes[48..].copy_from_slice(&self.1 .0.to_bytes());
pkpop_bytes
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, AtmsError> {
let pk = match BlstPk::from_bytes(&bytes[..48]) {
Ok(key) => PublicKey(key),
Err(e) => {
return Err(blst_err_to_atms(e)
.expect_err("If it passed, blst returns and error different to SUCCESS."))
}
};
let pop = match BlstSig::from_bytes(&bytes[48..]) {
Ok(proof) => ProofOfPossession(proof),
Err(e) => {
return Err(blst_err_to_atms(e)
.expect_err("If it passed, blst returns and error different to SUCCESS."))
}
};
Ok(Self(pk, pop))
}
}
impl PublicKey {
pub fn to_bytes(&self) -> [u8; 48] {
self.0.to_bytes()
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, AtmsError> {
match BlstPk::from_bytes(&bytes[..48]) {
Ok(pk) => Ok(Self(pk)),
Err(e) => Err(blst_err_to_atms(e)
.expect_err("If deserialisation is not successful, blst returns and error different to SUCCESS."))
}
}
fn cmp_msp_mvk(&self, other: &PublicKey) -> Ordering {
let self_bytes = self.to_bytes();
let other_bytes = other.to_bytes();
let mut result = Ordering::Equal;
for (i, j) in self_bytes.iter().zip(other_bytes.iter()) {
result = i.cmp(j);
if result != Ordering::Equal {
return result;
}
}
result
}
}
impl Hash for PublicKey {
    /// Feeds the key's compressed encoding into `state` byte by byte (no length prefix).
    fn hash<H: Hasher>(&self, state: &mut H) {
        let compressed = self.0.compress();
        u8::hash_slice(&compressed, state);
    }
}
impl PartialEq for PublicKey {
    /// Keys are equal exactly when their underlying blst points are equal.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.0, &other.0);
        lhs == rhs
    }
}

impl Eq for PublicKey {}
impl PartialOrd for PublicKey {
    /// Delegates to [`Ord::cmp`], the canonical form for a totally ordered type, so the partial
    /// and total orders cannot drift apart.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for PublicKey {
fn cmp(&self, other: &Self) -> Ordering {
self.cmp_msp_mvk(other)
}
}
impl<'a> Sum<&'a Self> for PublicKey {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = &'a Self>,
{
let mut aggregate_key = BlstPk::default();
let keys: Vec<&BlstPk> = iter.map(|x| &x.0).collect();
if !keys.is_empty() {
aggregate_key = AggregatePublicKey::aggregate(&keys, false)
.expect("It is assumed that public keys are checked. If this aggregation failed is due to invalid keys.")
.to_public_key();
}
Self(aggregate_key)
}
}
impl Sub for PublicKey {
    type Output = Self;

    /// Returns `self - rhs`, computed as `self + (-rhs)`.
    ///
    /// `-rhs` is built by negating the y-coordinate held in bytes `[48..96)` of `rhs`'s 96-byte
    /// serialisation and deserialising the result.
    ///
    /// # Panics
    /// Panics if blst rejects the negated point or the aggregation (not expected for valid keys).
    fn sub(self, rhs: Self) -> PublicKey {
        use blst::{blst_bendian_from_fp, blst_fp, blst_fp_cneg, blst_fp_from_bendian};

        let mut rhs_bytes = rhs.0.serialize();
        let y_bytes = &mut rhs_bytes[48..];
        let mut y = blst_fp::default();
        let mut neg_y = blst_fp::default();
        // SAFETY: `y_bytes` is the 48-byte tail of the 96-byte serialisation, exactly the size of
        // a big-endian field element read by `blst_fp_from_bendian` and written by
        // `blst_bendian_from_fp`. The raw pointers come from the whole slice (not from a single
        // `&u8`), so they are valid for all 48 bytes. `y` and `neg_y` are distinct,
        // initialised locals.
        unsafe {
            blst_fp_from_bendian(&mut y, y_bytes.as_ptr());
            blst_fp_cneg(&mut neg_y, &y, true);
            blst_bendian_from_fp(y_bytes.as_mut_ptr(), &neg_y);
        }
        let neg_rhs = BlstPk::deserialize(&rhs_bytes)
            .expect("The negative of a valid point is also a valid point.");
        PublicKey(
            AggregatePublicKey::aggregate(&[&neg_rhs, &self.0], false)
                .expect("Points are valid")
                .to_public_key(),
        )
    }
}
impl Signature {
pub fn verify(&self, pk: &PublicKey, msg: &[u8]) -> Result<(), AtmsError> {
blst_err_to_atms(self.0.verify(false, msg, &[], &[], &pk.0, false))
}
pub fn to_bytes(&self) -> [u8; 96] {
self.0.to_bytes()
}
pub fn from_bytes(bytes: &[u8]) -> Result<Self, AtmsError> {
match BlstSig::from_bytes(&bytes[..96]) {
Ok(sig) => Ok(Self(sig)),
Err(e) => Err(blst_err_to_atms(e)
.expect_err("If deserialisation is not successful, blst returns and error different to SUCCESS."))
}
}
fn cmp_msp_sig(&self, other: &Self) -> Ordering {
let self_bytes = self.to_bytes();
let other_bytes = other.to_bytes();
let mut result = Ordering::Equal;
for (i, j) in self_bytes.iter().zip(other_bytes.iter()) {
result = i.cmp(j);
if result != Ordering::Equal {
return result;
}
}
result
}
}
impl PartialOrd for Signature {
    /// Delegates to [`Ord::cmp`], the canonical form for a totally ordered type, so the partial
    /// and total orders cannot drift apart.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Signature {
fn cmp(&self, other: &Self) -> Ordering {
self.cmp_msp_sig(other)
}
}
impl<'a> Sum<&'a Self> for Signature {
fn sum<I>(iter: I) -> Self
where
I: Iterator<Item = &'a Self>,
{
let signatures: Vec<&BlstSig> = iter.map(|x| &x.0).collect();
let aggregate = AggregateSignature::aggregate(&signatures, false).expect("Signatures are assumed verified before aggregation. If signatures are invalid, they should not be aggregated.");
Self(aggregate.to_signature())
}
}
// Property tests for the BLS wrappers above. Several properties cross-check results against
// blst's raw FFI: `blst_p1_*` for G1 public keys and `blst_p2_*` for G2 signatures.
//
// NOTE(review): the unsafe blocks pass `&buf[0]` to FFI routines that read or write the whole
// buffer. `buf.as_ptr()` / `buf.as_mut_ptr()` would carry provenance for the full buffer
// (Miri may flag the current form).
#[cfg(test)]
mod tests {
use super::*;
use blst::{
blst_p1, blst_p1_add, blst_p1_add_affine, blst_p1_affine, blst_p1_cneg,
blst_p1_deserialize, blst_p1_from_affine, blst_p1_serialize, blst_p1_uncompress, blst_p2,
blst_p2_add_affine, blst_p2_affine, blst_p2_deserialize, blst_p2_serialize,
blst_p2_uncompress, blst_scalar, blst_scalar_fr_check, blst_scalar_from_bendian,
};
use proptest::prelude::*;
use rand_chacha::ChaCha20Rng;
use rand_core::SeedableRng;
proptest! {
#![proptest_config(ProptestConfig::with_cases(1000))]
// Each property below runs 1000 generated cases.

// A freshly generated key's proof of possession verifies.
#[test]
fn test_gen(seed in any::<[u8; 32]>(),
) {
let mut rng = ChaCha20Rng::from_seed(seed);
let sk = SigningKey::gen(&mut rng);
let pkpop = PublicKeyPoP::from(&sk);
assert!(pkpop.verify().is_ok());
}

// A signature verifies under the matching public key and message.
#[test]
fn test_sig(
msg in prop::collection::vec(any::<u8>(), 1..128),
seed in any::<[u8;32]>(),
) {
let sk = SigningKey::gen(&mut ChaCha20Rng::from_seed(seed));
let pk = PublicKey::from(&sk);
let sig = sk.sign(&msg);
assert!(sig.verify(&pk, &msg).is_ok());
}

// Verification fails with `InvalidSignature` both for a signature by a different key and for
// a different message. The fixed message is longer than any generated `msg` (at most 127
// bytes), so the two can never coincide.
#[test]
fn test_invalid_sig(msg in prop::collection::vec(any::<u8>(), 1..128),
seed in any::<[u8;32]>(),
) {
let mut rng = ChaCha20Rng::from_seed(seed);
let sk = SigningKey::gen(&mut rng);
let pk = PublicKey::from(&sk);
let sig = sk.sign(&msg);
let invalid_sk = SigningKey::gen(&mut rng);
let invalid_sig = invalid_sk.sign(&msg);
assert_eq!(
invalid_sig.verify(&pk, &msg).unwrap_err(),
AtmsError::InvalidSignature
);
assert_eq!(
sig.verify(
&pk,
b"We are just going to take a message long enough to make sure that \
the test is never going to fall in it. Therefore, the test should fail."
).unwrap_err(),
AtmsError::InvalidSignature)
}

// `Sum` over public keys matches adding the same points one by one with raw blst G1 ops.
#[test]
fn addition_of_pks(nr_parties in 1..10usize,
seed in any::<[u8;32]>())
{
let mut rng = ChaCha20Rng::from_seed(seed);
let mut pks = Vec::with_capacity(nr_parties);
let mut underlying_points = Vec::with_capacity(nr_parties);
for _ in 0..nr_parties {
let sk = SigningKey::gen(&mut rng);
let pk = PublicKey::from(&sk);
pks.push(pk);
underlying_points.push(pk.0.serialize());
}
let aggr_pk: PublicKey = pks.iter().sum();
unsafe {
let mut aggr_point = blst_p1::default();
let mut temp_point = blst_p1_affine::default();
for point in underlying_points.iter() {
blst_p1_deserialize(&mut temp_point, &point[0]);
// Accumulates in place: the output pointer aliases the first input. NOTE(review):
// assumes blst permits this aliasing — confirm.
blst_p1_add_affine(&mut aggr_point, &aggr_point, &temp_point);
}
let mut bytes_res = [0u8; 96];
blst_p1_serialize(&mut bytes_res[0], &aggr_point);
assert_eq!(aggr_pk.0.serialize(), bytes_res);
}
}

// `pk_1 - pk_2` matches `pk_1 + (-pk_2)` computed with raw blst G1 ops (`blst_p1_cneg`).
#[test]
fn subtraction_of_pks(seed in any::<[u8;32]>())
{
let mut rng = ChaCha20Rng::from_seed(seed);
let sk_1 = SigningKey::gen(&mut rng);
let sk_2 = SigningKey::gen(&mut rng);
let pk_1 = PublicKey::from(&sk_1);
let pk_2 = PublicKey::from(&sk_2);
let point_1 = pk_1.0.serialize();
let point_2 = pk_2.0.serialize();
let negation = pk_1 - pk_2;
unsafe {
let mut raw_negation = blst_p1::default();
let mut raw_point_1 = blst_p1::default();
let mut raw_point_2 = blst_p1::default();
let mut raw_point_1_affine = blst_p1_affine::default();
let mut raw_point_2_affine = blst_p1_affine::default();
blst_p1_deserialize(&mut raw_point_1_affine, &point_1[0]);
blst_p1_deserialize(&mut raw_point_2_affine, &point_2[0]);
blst_p1_from_affine(&mut raw_point_1, &raw_point_1_affine);
blst_p1_from_affine(&mut raw_point_2, &raw_point_2_affine);
blst_p1_cneg(&mut raw_point_2, true);
blst_p1_add(&mut raw_negation, &raw_point_1, &raw_point_2);
let mut bytes_res = [0u8; 96];
blst_p1_serialize(&mut bytes_res[0], &raw_negation);
assert_eq!(negation.0.serialize(), bytes_res);
}
}

// `Sum` over signatures matches adding the same points one by one with raw blst G2 ops.
#[test]
fn addition_of_sigs(nr_parties in 1..10usize,
seed in any::<[u8;32]>())
{
let mut rng = ChaCha20Rng::from_seed(seed);
let mut sigs = Vec::with_capacity(nr_parties);
let mut underlying_points = Vec::with_capacity(nr_parties);
for _ in 0..nr_parties {
let sk = SigningKey::gen(&mut rng);
let sig = sk.sign(b"dummy message");
sigs.push(sig);
underlying_points.push(sig.0.serialize());
}
let aggr_sig: Signature = sigs.iter().sum();
unsafe {
let mut aggr_point = blst_p2::default();
let mut temp_point = blst_p2_affine::default();
for point in underlying_points.iter() {
blst_p2_deserialize(&mut temp_point, &point[0]);
// In-place accumulation, same aliasing pattern as in `addition_of_pks`.
blst_p2_add_affine(&mut aggr_point, &aggr_point, &temp_point);
}
let mut bytes_res = [0u8; 192];
blst_p2_serialize(&mut bytes_res[0], &aggr_point);
assert_eq!(aggr_sig.0.serialize(), bytes_res);
}
}

// `Ord` on public keys agrees with a manual lexicographic comparison of their byte encodings.
#[test]
fn pk_ordering(seed in any::<[u8;32]>()) {
let mut rng = ChaCha20Rng::from_seed(seed);
let sk_1 = SigningKey::gen(&mut rng);
let sk_2 = SigningKey::gen(&mut rng);
let pk_1 = PublicKey::from(&sk_1);
let pk_2 = PublicKey::from(&sk_2);
let pk_1_bytes = pk_1.to_bytes();
let pk_2_bytes = pk_2.to_bytes();
let mut result = Ordering::Equal;
for (i, j) in pk_1_bytes.iter().zip(pk_2_bytes.iter()) {
result = i.cmp(j);
if result != Ordering::Equal {
break;
}
}
assert_eq!(result, pk_1.cmp(&pk_2));
}

// `SigningKey::from_bytes` accepts random 32 bytes exactly when blst's `blst_scalar_fr_check`
// accepts them. NOTE(review): `blst_scalar_fr_check` may accept the all-zero scalar, whereas
// secret-key deserialisation likely rejects it — confirm; a zero seed is practically never
// generated.
#[test]
fn serde_sk(sk in any::<[u8;32]>()) {
let mut raw_scalar = blst_scalar::default();
unsafe {
match SigningKey::from_bytes(&sk) {
Ok(_) => {
blst_scalar_from_bendian(&mut raw_scalar, &sk[0]);
assert!(blst_scalar_fr_check(&raw_scalar));
}
Err(_) => {
blst_scalar_from_bendian(&mut raw_scalar, &sk[0]);
assert!(!blst_scalar_fr_check(&raw_scalar));
},
};
}
}

// `PublicKey::from_bytes` on random 48 bytes succeeds exactly when `blst_p1_uncompress`
// does. On failure, blst reports a bad encoding or an off-curve point.
#[test]
fn serde_pk(seed in any::<[u8; 32]>()) {
let mut random_bytes = [0u8; 48];
ChaCha20Rng::from_seed(seed).fill_bytes(&mut random_bytes);
let mut raw_pk = blst_p1_affine::default();
unsafe{
match PublicKey::from_bytes(&random_bytes) {
Ok(_) => {
assert_eq!(blst_p1_uncompress(&mut raw_pk, &random_bytes[0]), BLST_ERROR::BLST_SUCCESS);
}
Err(_) => {
let error = blst_p1_uncompress(&mut raw_pk, &random_bytes[0]);
assert!(error == BLST_ERROR::BLST_BAD_ENCODING || error == BLST_ERROR::BLST_POINT_NOT_ON_CURVE);
},
};
}
}

// Same as `serde_pk`, but for 96-byte signatures and `blst_p2_uncompress`.
#[test]
fn serde_sig(seed in any::<[u8; 32]>()) {
let mut random_bytes = [0u8; 96];
ChaCha20Rng::from_seed(seed).fill_bytes(&mut random_bytes);
let mut raw_sig = blst_p2_affine::default();
unsafe {
match Signature::from_bytes(&random_bytes) {
Ok(_) => {
assert_eq!(blst_p2_uncompress(&mut raw_sig, &random_bytes[0]), BLST_ERROR::BLST_SUCCESS);
}
Err(_) => {
let error = blst_p2_uncompress(&mut raw_sig, &random_bytes[0]);
assert!(error == BLST_ERROR::BLST_BAD_ENCODING || error == BLST_ERROR::BLST_POINT_NOT_ON_CURVE);
},
};
}
}
}
}