use serde::{Deserialize, Serialize};
use std::{
collections::HashSet,
iter, mem,
ops::{Bound, RangeBounds},
};
use thiserror::Error;
/// Chain of BLS public keys.
///
/// The chain always holds at least one key (`head`). Every later key is stored in a block
/// together with a signature; `push` only accepts a block whose signature over the new key
/// verifies against the key preceding it.
#[derive(Debug, Eq, PartialEq, Clone, Hash, Serialize, Deserialize)]
pub struct SectionProofChain {
// First key of the chain; it has no accompanying signature.
head: bls::PublicKey,
// Remaining keys in order, each paired with its signature.
tail: Vec<Block>,
}
#[allow(clippy::len_without_is_empty)]
impl SectionProofChain {
/// Creates a chain consisting of the single key `first` and no further blocks.
pub fn new(first: bls::PublicKey) -> Self {
    let tail = Vec::new();
    Self { head: first, tail }
}
/// Appends `key` to the end of the chain.
///
/// The key is accepted only if it is not already present in the chain and `signature` is a
/// valid signature of the bincode-serialized `key` made by the current last key.
///
/// Returns `true` if the key was appended, `false` otherwise.
pub(crate) fn push(&mut self, key: bls::PublicKey, signature: bls::Signature) -> bool {
    if self.has_key(&key) {
        trace!("already has key {:?}", key);
        return false;
    }

    // Build the block first so the signature check is shared with `Block::verify`.
    let candidate = Block { key, signature };
    if !candidate.verify(self.last_key()) {
        error!(
            "invalid SectionProofChain block signature (new key: {:?}, last key: {:?})",
            candidate.key,
            self.last_key()
        );
        return false;
    }

    self.tail.push(candidate);
    true
}
/// Test-only: appends a block without checking its signature or whether the key is
/// already present, which allows building deliberately invalid chains.
#[cfg(test)]
pub(crate) fn push_without_validation(
    &mut self,
    key: bls::PublicKey,
    signature: bls::Signature,
) {
    let block = Block { key, signature };
    self.tail.push(block);
}
pub fn first_key(&self) -> &bls::PublicKey {
&self.head
}
/// Returns the last key of the chain; for a chain without blocks this is the first key.
pub fn last_key(&self) -> &bls::PublicKey {
    match self.tail.last() {
        Some(block) => &block.key,
        None => &self.head,
    }
}
/// Returns an iterator over all keys of the chain, from the first key to the last.
pub fn keys(&self) -> impl DoubleEndedIterator<Item = &bls::PublicKey> {
    let block_keys = self.tail.iter().map(|block| &block.key);
    iter::once(&self.head).chain(block_keys)
}
/// Returns whether `key` is one of the keys of this chain.
pub fn has_key(&self, key: &bls::PublicKey) -> bool {
    self.index_of(key).is_some()
}
/// Returns the position of `key` within the chain (the first key has index 0), or `None`
/// if the chain does not contain it.
pub fn index_of(&self, key: &bls::PublicKey) -> Option<u64> {
    let position = self.keys().position(|candidate| candidate == key)?;
    Some(position as u64)
}
/// Returns a sub-chain containing the keys whose indices fall within `range`.
///
/// Indices refer to positions in `keys()` (the first key has index 0). Bounds that lie
/// past the end of the chain are clamped to it, and the result always contains at least
/// one key: an empty range starting at `i` yields the single key at (clamped) index `i`.
///
/// Bounds are clamped before being converted to `usize`, so extreme values such as
/// `..=u64::MAX` neither overflow nor get truncated on targets where `usize` is narrower
/// than `u64`.
pub fn slice<B: RangeBounds<u64>>(&self, range: B) -> Self {
    // Index of the last key; the chain holds `last_index + 1` keys (head plus tail blocks).
    let last_index = self.tail.len() as u64;

    // Inclusive start and exclusive end, still expressed in `u64`.
    let start = match range.start_bound() {
        Bound::Included(index) => *index,
        Bound::Excluded(index) => (*index).saturating_add(1),
        Bound::Unbounded => 0,
    };
    let end = match range.end_bound() {
        Bound::Included(index) => (*index).saturating_add(1),
        Bound::Excluded(index) => *index,
        Bound::Unbounded => last_index + 1,
    };

    // After clamping, both values are bounded by `tail.len() + 1`, so the casts are lossless.
    let start = start.min(last_index) as usize;
    let end = (end.min(last_index + 1) as usize).max(start + 1);

    // The key at index `i > 0` lives in `tail[i - 1]`; the blocks that follow the new head
    // therefore start at `tail[start]`.
    let head = if start == 0 {
        self.head
    } else {
        self.tail[start - 1].key
    };

    Self {
        head,
        tail: self.tail[start..end - 1].to_vec(),
    }
}
/// Returns the number of keys in the chain, which is always at least one.
pub fn len(&self) -> usize {
    self.tail.len() + 1
}
/// Returns the index of the last key, i.e. the number of keys minus one.
pub fn last_key_index(&self) -> u64 {
    let block_count = self.tail.len();
    block_count as u64
}
/// Checks that every block's signature verifies against the key preceding it, starting
/// from the first key. Returns `false` at the first block that fails.
pub fn self_verify(&self) -> bool {
    // Pairing `keys()` with the blocks lines each block up with its predecessor's key:
    // (head, tail[0]), (tail[0].key, tail[1]), ...
    self.keys()
        .zip(&self.tail)
        .all(|(signer, block)| block.verify(signer))
}
pub fn check_trust<'a, I>(&self, trusted_keys: I) -> TrustStatus
where
I: IntoIterator<Item = &'a bls::PublicKey>,
{
if let Some((index, mut trusted_key)) = self.latest_trusted_key(trusted_keys) {
for block in &self.tail[index..] {
if !block.verify(trusted_key) {
return TrustStatus::Invalid;
}
trusted_key = &block.key;
}
TrustStatus::Trusted
} else if self.self_verify() {
TrustStatus::Unknown
} else {
TrustStatus::Invalid
}
}
/// Replaces this chain with the part of `full_chain` that starts at `new_first_key` and
/// ends at this chain's current last key.
///
/// # Errors
///
/// - `AlreadySufficient` if this chain already contains `new_first_key`.
/// - `InvalidFirstKey` if `full_chain` does not contain `new_first_key`, or contains it
///   after this chain's last key.
/// - `InvalidLastKey` if `full_chain` does not contain this chain's last key.
pub(crate) fn extend(
    &mut self,
    new_first_key: &bls::PublicKey,
    full_chain: &Self,
) -> Result<(), ExtendError> {
    if self.has_key(new_first_key) {
        return Err(ExtendError::AlreadySufficient);
    }

    let index_from = match full_chain.index_of(new_first_key) {
        Some(index) => index,
        None => return Err(ExtendError::InvalidFirstKey),
    };
    let index_to = match full_chain.index_of(self.last_key()) {
        Some(index) => index,
        None => return Err(ExtendError::InvalidLastKey),
    };

    if index_to < index_from {
        return Err(ExtendError::InvalidFirstKey);
    }

    *self = full_chain.slice(index_from..=index_to);
    Ok(())
}
pub(crate) fn merge(&mut self, other: Self) -> Result<(), MergeError> {
fn check_same_keys<'a>(
a: impl IntoIterator<Item = &'a bls::PublicKey>,
b: impl IntoIterator<Item = &'a bls::PublicKey>,
) -> Result<(), MergeError> {
if a.into_iter().zip(b).all(|(a, b)| a == b) {
Ok(())
} else {
Err(MergeError)
}
}
if let Some(first) = self.index_of(other.first_key()) {
check_same_keys(self.keys().skip(first as usize + 1), other.keys().skip(1))?;
if self.has_key(other.last_key()) {
Ok(())
} else {
self.tail = mem::take(&mut self.tail)
.into_iter()
.take(first as usize)
.chain(other.tail)
.collect();
Ok(())
}
} else if let Some(first) = other.index_of(self.first_key()) {
check_same_keys(self.keys().skip(1), other.keys().skip(first as usize + 1))?;
if other.has_key(self.last_key()) {
self.head = other.head;
self.tail = other.tail;
Ok(())
} else {
self.head = other.head;
self.tail = other
.tail
.into_iter()
.take(first as usize)
.chain(mem::take(&mut self.tail))
.collect();
Ok(())
}
} else {
Err(MergeError)
}
}
/// Finds the key closest to the end of the chain that is contained in `trusted_keys`.
///
/// Returns that key together with its index in `keys()`, or `None` if the chain contains
/// none of the trusted keys.
fn latest_trusted_key<'a, 'b, I>(
    &'a self,
    trusted_keys: I,
) -> Option<(usize, &'a bls::PublicKey)>
where
    I: IntoIterator<Item = &'b bls::PublicKey>,
{
    let trusted_keys: HashSet<_> = trusted_keys.into_iter().collect();

    // Walk the keys from last to first, counting the index down alongside.
    let mut index = self.len();
    for key in self.keys().rev() {
        index -= 1;
        if trusted_keys.contains(&key) {
            return Some((index, key));
        }
    }
    None
}
}
/// Result of checking a proof chain against a set of trusted keys
/// (see `SectionProofChain::check_trust`).
#[derive(Debug, Eq, PartialEq)]
pub enum TrustStatus {
/// The chain contains a trusted key and all blocks after the latest such key verify.
Trusted,
/// A block signature in the chain failed verification.
Invalid,
/// The chain contains none of the trusted keys, but verifies from its own first key.
Unknown,
}
/// Error returned by `SectionProofChain::extend`.
#[derive(Debug, Error)]
pub enum ExtendError {
/// The requested first key is not in the full chain, or it comes after the last key of
/// the chain being extended.
#[error("invalid first key")]
InvalidFirstKey,
/// The last key of the chain being extended is not in the full chain.
#[error("invalid last key")]
InvalidLastKey,
/// The chain being extended already contains the requested first key.
#[error("proof chain already sufficient")]
AlreadySufficient,
}
/// Error returned by `SectionProofChain::merge` when neither chain contains the first key
/// of the other, or when the keys in their overlapping part differ.
#[derive(Debug, Error, Eq, PartialEq)]
#[error("incompatible chains cannot be merged")]
pub struct MergeError;
/// One link of the proof chain: a public key together with a signature over it.
#[derive(Debug, Eq, PartialEq, Clone, Hash, Serialize, Deserialize)]
struct Block {
// The key this block adds to the chain.
key: bls::PublicKey,
// Signature over the bincode-serialized `key`, checked by `Block::verify`.
signature: bls::Signature,
}
impl Block {
    /// Returns whether `signature` is a valid signature of the bincode-serialized `key`
    /// under `public_key`. A serialization failure counts as a failed verification.
    fn verify(&self, public_key: &bls::PublicKey) -> bool {
        match bincode::serialize(&self.key) {
            Ok(bytes) => public_key.verify(&self.signature, &bytes),
            Err(_) => false,
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::{iter, ops::Range};
// A valid chain is trusted no matter which single one of its keys the caller trusts.
#[test]
fn check_trust_trusted() {
    let (chain, _) = gen_chain(4);
    chain.keys().for_each(|key| {
        let status = chain.check_trust(iter::once(key));
        assert_eq!(status, TrustStatus::Trusted);
    });
}
// Builds a chain whose third block is signed by an unrelated key. Trusting a key before
// the bad block must report `Invalid`; trusting a key at or after it skips the bad
// signature and reports `Trusted`.
#[test]
fn check_trust_invalid() {
    let (mut chain, _) = gen_chain(2);

    // Third key: signed by a secret key that is not part of the chain.
    let (_, unrelated_secret_key) = gen_keys();
    let (bad_key, bad_signature, bad_block_secret_key) = gen_block(&unrelated_secret_key);
    chain.push_without_validation(bad_key, bad_signature);

    // Fourth key: correctly signed by the third key.
    let (next_key, next_signature, _) = gen_block(&bad_block_secret_key);
    let _ = chain.push(next_key, next_signature);

    for (index, key) in chain.keys().enumerate() {
        let expected = if index < 2 {
            TrustStatus::Invalid
        } else {
            TrustStatus::Trusted
        };
        assert_eq!(chain.check_trust(iter::once(key)), expected)
    }
}
// A valid chain that contains none of the trusted keys is reported as `Unknown`.
#[test]
fn check_trust_unknown() {
    let (chain, _) = gen_chain(2);
    let (unrelated_key, _) = gen_keys();

    let status = chain.check_trust(iter::once(&unrelated_key));
    assert_eq!(status, TrustStatus::Unknown);
}
// Exercises `SectionProofChain::slice` on a three-key chain with every kind of range
// bound. The expectations show that out-of-range bounds are clamped to the chain and that
// an empty range still yields exactly one key.
#[test]
#[allow(clippy::reversed_empty_ranges)]
fn slice() {
let (chain, _) = gen_chain(3);
let keys: Vec<_> = chain.keys().collect();
// Compares the keys of a sliced chain with the expected sub-slice of the original keys.
let assert_keys_eq = |chain: SectionProofChain, expected: &[_]| {
let actual: Vec<_> = chain.keys().collect();
assert_eq!(&actual[..], expected)
};
// Unbounded end.
assert_keys_eq(chain.slice(..), &keys[0..3]);
assert_keys_eq(chain.slice(0..), &keys[0..3]);
assert_keys_eq(chain.slice(1..), &keys[1..3]);
assert_keys_eq(chain.slice(2..), &keys[2..3]);
assert_keys_eq(chain.slice(3..), &keys[2..3]);
// Unbounded start, exclusive end.
assert_keys_eq(chain.slice(..0), &keys[0..1]);
assert_keys_eq(chain.slice(..1), &keys[0..1]);
assert_keys_eq(chain.slice(..2), &keys[0..2]);
assert_keys_eq(chain.slice(..3), &keys[0..3]);
assert_keys_eq(chain.slice(..4), &keys[0..3]);
// Unbounded start, inclusive end.
assert_keys_eq(chain.slice(..=0), &keys[0..1]);
assert_keys_eq(chain.slice(..=1), &keys[0..2]);
assert_keys_eq(chain.slice(..=2), &keys[0..3]);
assert_keys_eq(chain.slice(..=3), &keys[0..3]);
// Both bounds given, exclusive end.
assert_keys_eq(chain.slice(0..0), &keys[0..1]);
assert_keys_eq(chain.slice(0..1), &keys[0..1]);
assert_keys_eq(chain.slice(0..2), &keys[0..2]);
assert_keys_eq(chain.slice(0..3), &keys[0..3]);
assert_keys_eq(chain.slice(0..4), &keys[0..3]);
assert_keys_eq(chain.slice(1..1), &keys[1..2]);
assert_keys_eq(chain.slice(1..2), &keys[1..2]);
assert_keys_eq(chain.slice(1..3), &keys[1..3]);
assert_keys_eq(chain.slice(2..2), &keys[2..3]);
assert_keys_eq(chain.slice(2..3), &keys[2..3]);
// Both bounds given, inclusive end.
assert_keys_eq(chain.slice(0..=0), &keys[0..1]);
assert_keys_eq(chain.slice(0..=1), &keys[0..2]);
assert_keys_eq(chain.slice(0..=2), &keys[0..3]);
assert_keys_eq(chain.slice(0..=3), &keys[0..3]);
}
// Merges pairs of slices of one four-key chain and compares the outcome with the expected
// slice of that chain (or the expected error for slices that do not overlap).
#[test]
fn merge() {
    let (chain, _) = gen_chain(4);

    // (range of the receiving chain, range of the merged-in chain, expected result)
    let cases: Vec<(Range<u64>, Range<u64>, Result<Range<u64>, MergeError>)> = vec![
        (0..1, 0..1, Ok(0..1)),
        (0..1, 0..2, Ok(0..2)),
        (0..2, 0..1, Ok(0..2)),
        (1..2, 0..2, Ok(0..2)),
        (0..2, 1..2, Ok(0..2)),
        (0..2, 0..2, Ok(0..2)),
        (0..2, 1..3, Ok(0..3)),
        (1..3, 0..2, Ok(0..3)),
        (0..3, 1..2, Ok(0..3)),
        (1..2, 0..3, Ok(0..3)),
        (0..1, 1..2, Err(MergeError)),
        (1..2, 0..1, Err(MergeError)),
    ];

    for (receiver_range, other_range, expected) in cases {
        let mut receiver = chain.slice(receiver_range);
        let outcome = receiver.merge(chain.slice(other_range));

        match expected {
            Ok(merged_range) => {
                assert_eq!(outcome, Ok(()));
                assert_eq!(receiver, chain.slice(merged_range));
            }
            Err(error) => assert_eq!(outcome, Err(error)),
        }
    }
}
// Two chains share their first three keys and then each appends a different fourth key
// signed by the same secret key. Such forked chains must not merge.
#[test]
fn merge_fork() {
    let (mut left, shared_secret_key) = gen_chain(3);
    let mut right = left.clone();

    let (left_key, left_signature, _) = gen_block(&shared_secret_key);
    let _ = left.push(left_key, left_signature);

    let (right_key, right_signature, _) = gen_block(&shared_secret_key);
    let _ = right.push(right_key, right_signature);

    assert_eq!(left.merge(right), Err(MergeError));
}
// Like `merge_fork`, but after diverging by one key both chains append the same final key
// (each signed by its own fork key). The chains still differ in the middle and must not
// merge.
#[test]
fn merge_fork_join() {
    let (mut left, shared_secret_key) = gen_chain(3);
    let mut right = left.clone();

    // Diverge: each chain gets its own fourth key.
    let (left_key, left_signature, left_secret_key) = gen_block(&shared_secret_key);
    let _ = left.push(left_key, left_signature);
    let (right_key, right_signature, right_secret_key) = gen_block(&shared_secret_key);
    let _ = right.push(right_key, right_signature);

    // Rejoin: both chains append the same fifth key.
    let (joined_key, _) = gen_keys();
    let joined_bytes = bincode::serialize(&joined_key).unwrap();
    let _ = left.push(joined_key, left_secret_key.sign(&joined_bytes));
    let _ = right.push(joined_key, right_secret_key.sign(&joined_bytes));

    assert_eq!(left.merge(right), Err(MergeError));
}
// Generates a random secret key and returns it together with its public key.
fn gen_keys() -> (bls::PublicKey, bls::SecretKey) {
    let secret_key = bls::SecretKey::random();
    let public_key = secret_key.public_key();
    (public_key, secret_key)
}
// Generates a fresh key pair and signs the bincode-serialized new public key with
// `prev_secret_key`. Returns the new public key, that signature and the new secret key.
fn gen_block(
    prev_secret_key: &bls::SecretKey,
) -> (bls::PublicKey, bls::Signature, bls::SecretKey) {
    let (public_key, secret_key) = gen_keys();
    let serialized_key = bincode::serialize(&public_key).unwrap();
    let signature = prev_secret_key.sign(&serialized_key);
    (public_key, signature, secret_key)
}
// Builds a chain of `len` keys (at least one key, even for `len == 0`) in which every
// block is signed by the secret key of the preceding key. Returns the chain together with
// the secret key belonging to its last key, so callers can append further blocks.
fn gen_chain(len: usize) -> (SectionProofChain, bls::SecretKey) {
    let (key, mut current_secret_key) = gen_keys();
    let mut chain = SectionProofChain::new(key);
    for _ in 1..len {
        // Each new block is signed by the key currently at the end of the chain.
        let (new_public_key, new_signature, new_secret_key) = gen_block(&current_secret_key);
        let _ = chain.push(new_public_key, new_signature);
        current_secret_key = new_secret_key;
    }
    (chain, current_secret_key)
}
}