use std::borrow::Borrow;
use thiserror::Error;
use crate::cbor::{DecodeError, decode_cbor, encode_cbor};
use crate::extensions::{Extension, Extensions};
use crate::hash::Hash;
use crate::identity::{Signature, SigningKey, VerifyingKey};
use crate::logs::SeqNum;
use crate::timestamp::Timestamp;
use crate::traits::Digest;
/// A pair of an owned byte buffer and an optional second owned byte buffer.
///
/// NOTE(review): presumably the encoded header bytes and the optional body bytes of an
/// operation — confirm against the code that produces and consumes this alias.
pub type RawOperation = (Vec<u8>, Option<Vec<u8>>);
/// A header together with an optional body and a hash identifying the operation.
///
/// `PartialEq`, `Ord` and `Digest::hash` for this type look only at the `hash` field.
/// Nothing in this type enforces that `hash` equals `header.hash()`. The tests in this module
/// build it that way; NOTE(review): presumably all constructors must do the same — confirm.
#[derive(Clone, Debug)]
pub struct Operation<E = ()> {
/// Identifier of this operation; the sole input to equality, ordering and `Digest::hash`.
pub hash: Hash,
/// Header describing this operation (author, signature, payload commitment, log position).
pub header: Header<E>,
/// Payload bytes, if attached. `validate_operation` checks them against the header's
/// `payload_hash` and `payload_size`.
pub body: Option<Body>,
}
impl<E> Operation<E>
where
    E: Extensions,
{
    /// Borrows the header of this operation.
    pub fn header(&self) -> &Header<E> {
        let Self { header, .. } = self;
        header
    }

    /// Borrows the body of this operation, or returns `None` when no body is attached.
    pub fn body(&self) -> Option<&Body> {
        Option::as_ref(&self.body)
    }
}
impl<E> PartialEq for Operation<E> {
    /// Two operations are equal exactly when their stored `hash` fields are equal; headers
    /// and bodies are not compared.
    fn eq(&self, other: &Self) -> bool {
        self.hash == other.hash
    }
}

/// Hash equality is a full equivalence relation, so `Operation` is `Eq` for every `E`.
impl<E> Eq for Operation<E> {}
impl<E> Borrow<Header<E>> for Operation<E> {
    /// Exposes the header so an `Operation` can be passed wherever an
    /// `impl Borrow<Header<E>>` is accepted (e.g. `validate_backlink`).
    fn borrow(&self) -> &Header<E> {
        let Self { header, .. } = self;
        header
    }
}
/// Operations are ordered solely by their stored `hash`.
///
/// `partial_cmp` delegates to `Ord::cmp` (the canonical form), which keeps the two impls from
/// drifting apart and removes the need to silence `clippy::non_canonical_partial_ord_impl`.
impl<E> PartialOrd for Operation<E> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl<E> Ord for Operation<E> {
    /// Compares the stored `hash` fields; headers and bodies do not take part.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.hash.cmp(&other.hash)
    }
}
impl<E> Digest<Hash> for Operation<E> {
    /// Returns a copy of the stored `hash` field; the header is not re-encoded or rehashed.
    fn hash(&self) -> Hash {
        let Self { hash, .. } = self;
        *hash
    }
}
/// Metadata of an operation: author key, signature, payload commitment, position in the
/// author's log and custom extensions.
///
/// `Header::sign` / `Header::verify` produce and check `signature`; `validate_header` and
/// `validate_backlink` enforce the structural rules noted on the individual fields.
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub struct Header<E = ()> {
/// Format version; `validate_header` only accepts `1`.
pub version: u64,
/// Author's public key; `Header::verify` checks `signature` against it.
pub verifying_key: VerifyingKey,
/// Signature over the encoding of this header with `signature` set to `None`.
pub signature: Option<Signature>,
/// Size of the body in bytes; `0` means no payload. `validate_header` requires it to be
/// non-zero exactly when `payload_hash` is `Some`.
pub payload_size: u64,
/// Hash of the body bytes (compared with `Body::hash` by `validate_operation`).
pub payload_hash: Option<Hash>,
/// Timestamp of the header; `Header::default` sets it to `Timestamp::now()`.
pub timestamp: Timestamp,
/// Position in the author's log. `0` must have no backlink; any later value must have one.
pub seq_num: SeqNum,
/// Hash of the previous header in the same log (`validate_backlink` compares it with
/// `past_header.hash()`).
pub backlink: Option<Hash>,
/// Custom extension data; typed values are read through `Header::extension`.
pub extensions: E,
}
impl<E> Default for Header<E>
where
    E: Default,
{
    /// Builds an unsigned version-1 header with no payload, no backlink and sequence number 0.
    ///
    /// The verifying key and extensions are their `Default` values, and the timestamp is
    /// taken from `Timestamp::now()` at call time.
    fn default() -> Self {
        // Bound in the same order the fields were originally initialised.
        let verifying_key = VerifyingKey::default();
        let timestamp = Timestamp::now();
        let extensions = E::default();

        Self {
            version: 1,
            verifying_key,
            signature: None,
            payload_size: 0,
            payload_hash: None,
            timestamp,
            seq_num: 0,
            backlink: None,
            extensions,
        }
    }
}
impl<E> Header<E>
where
E: Extensions,
{
pub fn to_bytes(&self) -> Vec<u8> {
encode_cbor(self)
.expect("CBOR encoder failed due to an critical IO error")
}
pub fn sign(&mut self, signing_key: &SigningKey) {
self.signature = None;
let bytes = self.to_bytes();
self.signature = Some(signing_key.sign(&bytes));
}
pub fn verify(&self) -> bool {
match self.signature {
Some(claimed_signature) => {
let mut unsigned_header = self.clone();
unsigned_header.signature = None;
let unsigned_bytes = unsigned_header.to_bytes();
self.verifying_key
.verify(&unsigned_bytes, &claimed_signature)
}
None => false,
}
}
pub fn hash(&self) -> Hash {
Hash::digest(self.to_bytes())
}
pub fn extension<T>(&self) -> Option<T>
where
E: Extension<T>,
{
E::extract(self)
}
}
impl<E> Header<E> {
    /// Number of header fields that are present.
    ///
    /// This is the six non-`Option` fields of `Header`, plus one for each of `signature`,
    /// `payload_hash` and `backlink` that is `Some`. NOTE(review): presumably consumed by
    /// the CBOR serializer to size its container — confirm at the call site.
    pub(crate) fn field_count(&self) -> usize {
        const ALWAYS_PRESENT: usize = 6;

        let optional_present = [
            self.signature.is_some(),
            self.payload_hash.is_some(),
            self.backlink.is_some(),
        ]
        .into_iter()
        .filter(|present| *present)
        .count();

        ALWAYS_PRESENT + optional_present
    }
}
impl TryFrom<&[u8]> for Header {
    type Error = DecodeError;

    /// Decodes a CBOR-encoded `Header<()>` from `value`, propagating any `DecodeError`.
    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
        let header: Self = decode_cbor(value)?;
        Ok(header)
    }
}
/// Raw payload bytes of an operation.
#[derive(Debug, Clone, PartialEq)]
pub struct Body(pub(super) Vec<u8>);

impl Body {
    /// Copies `bytes` into a new owned body.
    pub fn new(bytes: &[u8]) -> Self {
        Self(Vec::from(bytes))
    }

    /// Returns an owned copy of the body bytes.
    pub fn to_bytes(&self) -> Vec<u8> {
        self.as_bytes().to_vec()
    }

    /// Borrows the body bytes.
    pub fn as_bytes(&self) -> &[u8] {
        self.0.as_slice()
    }

    /// Hashes the body bytes; a header's `payload_hash` must equal this value.
    pub fn hash(&self) -> Hash {
        let Self(bytes) = self;
        Hash::digest(bytes)
    }

    /// Length of the body in bytes; a header's `payload_size` must equal this value.
    pub fn size(&self) -> u64 {
        let byte_count = self.as_bytes().len();
        byte_count as u64
    }
}
impl From<&[u8]> for Body {
fn from(value: &[u8]) -> Self {
Body::new(value)
}
}
impl From<Vec<u8>> for Body {
fn from(value: Vec<u8>) -> Self {
Body(value)
}
}
#[derive(Clone, Debug, Error)]
pub enum OperationError {
#[error("operation version {0} is not supported, needs to be <= {1}")]
UnsupportedVersion(u64, u64),
#[error("operation needs to be signed")]
MissingSignature,
#[error("signature does not match claimed public key")]
SignatureMismatch,
#[error("sequence number can't be 0 when backlink is given")]
SeqNumMismatch,
#[error("payload hash and -size need to be defined together")]
InconsistentPayloadInfo,
#[error("needs payload hash in header when body is given")]
MissingPayloadHash,
#[error("payload hash and size do not match given body")]
PayloadMismatch,
#[error("logs can not contain operations of different authors")]
TooManyAuthors,
#[error("expected sequence number {0} but found {1}")]
SeqNumNonIncremental(u64, u64),
#[error("expected backlink but none was given")]
BacklinkMissing,
#[error("given backlink did not match previous operation")]
BacklinkMismatch,
}
pub fn validate_operation<E>(operation: impl Borrow<Operation<E>>) -> Result<(), OperationError>
where
E: Extensions,
{
let operation = operation.borrow();
validate_header(&operation.header)?;
let claimed_payload_size = operation.header.payload_size;
let claimed_payload_hash: Option<Hash> = match claimed_payload_size {
0 => None,
_ => {
let hash = operation
.header
.payload_hash
.ok_or(OperationError::MissingPayloadHash)?;
Some(hash)
}
};
if let Some(body) = &operation.body
&& (claimed_payload_hash != Some(body.hash()) || claimed_payload_size != body.size())
{
return Err(OperationError::PayloadMismatch);
}
Ok(())
}
/// Checks a single header on its own: signature, version, payload metadata and the
/// relationship between `seq_num` and `backlink`.
///
/// The signature is checked first, so an unsigned or forged header always reports
/// `SignatureMismatch`, even if other fields are also invalid.
///
/// # Errors
///
/// - `SignatureMismatch` if `Header::verify` fails (including when unsigned).
/// - `UnsupportedVersion(version, 1)` if `version` is not `1`.
/// - `InconsistentPayloadInfo` if `payload_hash` is present but `payload_size` is `0`, or
///   vice versa.
/// - `SeqNumMismatch` if a backlink is given with `seq_num == 0`.
/// - `BacklinkMissing` if `seq_num > 0` but no backlink is given.
pub fn validate_header<E>(header: &Header<E>) -> Result<(), OperationError>
where
    E: Extensions,
{
    if !header.verify() {
        return Err(OperationError::SignatureMismatch);
    }

    const SUPPORTED_VERSION: u64 = 1;
    if header.version != SUPPORTED_VERSION {
        return Err(OperationError::UnsupportedVersion(
            header.version,
            SUPPORTED_VERSION,
        ));
    }

    // A payload hash must be present exactly when the payload size is non-zero.
    let has_payload_hash = header.payload_hash.is_some();
    let has_payload = header.payload_size > 0;
    if has_payload_hash != has_payload {
        return Err(OperationError::InconsistentPayloadInfo);
    }

    match (&header.backlink, header.seq_num) {
        (Some(_), 0) => Err(OperationError::SeqNumMismatch),
        (None, seq_num) if seq_num > 0 => Err(OperationError::BacklinkMissing),
        _ => Ok(()),
    }
}
/// Checks that `header` directly follows `past_header` in the same author's log.
///
/// Neither header's signature is checked here; run `validate_header` on each for that.
///
/// # Errors
///
/// - `TooManyAuthors` if the two headers have different verifying keys.
/// - `SeqNumNonIncremental(expected, found)` if `header.seq_num` is not exactly
///   `past_header.seq_num + 1`. When `past_header.seq_num` is `u64::MAX` no successor
///   exists; the error is still returned, with `expected` saturated at `u64::MAX`.
/// - `BacklinkMissing` if `header.backlink` is `None`.
/// - `BacklinkMismatch` if `header.backlink` is not `past_header.hash()`.
pub fn validate_backlink<E>(
    past_header: impl Borrow<Header<E>>,
    header: impl Borrow<Header<E>>,
) -> Result<(), OperationError>
where
    E: Extensions,
{
    let past_header = past_header.borrow();
    let header = header.borrow();

    if past_header.verifying_key != header.verifying_key {
        return Err(OperationError::TooManyAuthors);
    }

    // `seq_num` comes from untrusted, decoded headers. A plain `+ 1` on `u64::MAX` panics in
    // debug builds and wraps to 0 in release builds; the wrap would accept a seq-0 successor
    // here. `checked_add` turns both cases into a validation error.
    if past_header.seq_num.checked_add(1) != Some(header.seq_num) {
        return Err(OperationError::SeqNumNonIncremental(
            past_header.seq_num.saturating_add(1),
            header.seq_num,
        ));
    }

    let Some(backlink) = header.backlink else {
        return Err(OperationError::BacklinkMissing);
    };
    if past_header.hash() != backlink {
        return Err(OperationError::BacklinkMismatch);
    }

    Ok(())
}
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use crate::{Extension, SigningKey};

    use super::*;

    #[test]
    fn simple_extension_type_parameter() {
        let signing_key = SigningKey::generate();
        let body = Body::new("Hello, Sloth!".as_bytes());
        let mut header = Header {
            version: 1,
            verifying_key: signing_key.verifying_key(),
            signature: None,
            payload_size: body.size(),
            payload_hash: Some(body.hash()),
            timestamp: Timestamp::now(),
            seq_num: 0,
            backlink: None,
            extensions: (),
        };
        header.sign(&signing_key);

        // A header with `()` extensions must still sign, verify and validate normally.
        assert!(header.verify());
        assert!(validate_header(&header).is_ok());
    }

    #[test]
    fn sign_and_verify() {
        let signing_key = SigningKey::generate();
        let body = Body::new("Hello, Sloth!".as_bytes());

        type CustomExtensions = ();

        let mut header = Header {
            version: 1,
            verifying_key: signing_key.verifying_key(),
            signature: None,
            payload_size: body.size(),
            payload_hash: Some(body.hash()),
            timestamp: Timestamp::now(),
            seq_num: 0,
            backlink: None,
            extensions: None::<CustomExtensions>,
        };
        assert!(!header.verify());
        header.sign(&signing_key);
        assert!(header.verify());

        let operation = Operation {
            hash: header.hash(),
            header,
            body: Some(body),
        };
        assert!(validate_operation(&operation).is_ok());
    }

    #[test]
    fn valid_backlink_header() {
        let signing_key = SigningKey::generate();
        let mut header_0 = Header::<()> {
            version: 1,
            verifying_key: signing_key.verifying_key(),
            signature: None,
            payload_size: 0,
            payload_hash: None,
            timestamp: Timestamp::now(),
            seq_num: 0,
            backlink: None,
            extensions: (),
        };
        header_0.sign(&signing_key);
        assert!(validate_header(&header_0).is_ok());

        let mut header_1 = Header::<()> {
            version: 1,
            verifying_key: signing_key.verifying_key(),
            signature: None,
            payload_size: 0,
            payload_hash: None,
            timestamp: Timestamp::now(),
            seq_num: 1,
            backlink: Some(header_0.hash()),
            extensions: (),
        };
        header_1.sign(&signing_key);
        assert!(validate_header(&header_1).is_ok());
        assert!(validate_backlink(&header_0, &header_1).is_ok());
    }

    #[test]
    fn invalid_backlinks() {
        let signing_key = SigningKey::generate();
        let mut header_0 = Header::<()> {
            version: 1,
            verifying_key: signing_key.verifying_key(),
            signature: None,
            payload_size: 0,
            payload_hash: None,
            timestamp: 0.into(),
            seq_num: 0,
            backlink: None,
            extensions: (),
        };
        header_0.sign(&signing_key);

        let header_1_base = Header::<()> {
            version: 1,
            verifying_key: signing_key.verifying_key(),
            signature: None,
            payload_size: 0,
            payload_hash: None,
            timestamp: 0.into(),
            seq_num: 1,
            backlink: Some(header_0.hash()),
            extensions: (),
        };

        // Successor signed by a different author.
        let other_key = SigningKey::generate();
        let mut header = header_1_base.clone();
        header.verifying_key = other_key.verifying_key();
        header.sign(&other_key);
        assert!(matches!(
            validate_backlink(&header_0, &header),
            Err(OperationError::TooManyAuthors)
        ));

        // Sequence number skips ahead.
        let mut header = header_1_base.clone();
        header.seq_num = 2;
        header.sign(&signing_key);
        assert!(matches!(
            validate_backlink(&header_0, &header),
            Err(OperationError::SeqNumNonIncremental(1, 2))
        ));

        // Backlink points at some other header.
        let mut header = header_1_base.clone();
        header.backlink = Some(Hash::digest(vec![1, 2, 3]));
        header.sign(&signing_key);
        assert!(matches!(
            validate_backlink(&header_0, &header),
            Err(OperationError::BacklinkMismatch)
        ));

        // No backlink at all.
        let mut header = header_1_base.clone();
        header.backlink = None;
        header.sign(&signing_key);
        assert!(matches!(
            validate_backlink(&header_0, &header),
            Err(OperationError::BacklinkMissing)
        ));
    }

    #[test]
    fn invalid_operations() {
        let signing_key = SigningKey::generate();
        let body: Body = Body::new("Hello, Sloth!".as_bytes());

        let header_base = Header::<()> {
            version: 1,
            verifying_key: signing_key.verifying_key(),
            signature: None,
            payload_size: body.size(),
            payload_hash: Some(body.hash()),
            timestamp: 0.into(),
            seq_num: 0,
            backlink: None,
            extensions: (),
        };

        // Unsigned headers are rejected before any other check runs.
        assert!(matches!(
            validate_header(&header_base),
            Err(OperationError::SignatureMismatch)
        ));

        let mut header = header_base.clone();
        header.version = 0;
        header.sign(&signing_key);
        assert!(matches!(
            validate_header(&header),
            Err(OperationError::UnsupportedVersion(0, 1))
        ));

        let mut header = header_base.clone();
        header.verifying_key = SigningKey::generate().verifying_key();
        header.sign(&signing_key);
        assert!(matches!(
            validate_header(&header),
            Err(OperationError::SignatureMismatch)
        ));

        let mut header = header_base.clone();
        header.seq_num = 1;
        header.sign(&signing_key);
        assert!(matches!(
            validate_header(&header),
            Err(OperationError::BacklinkMissing)
        ));

        let mut header = header_base.clone();
        header.backlink = Some(Hash::digest(vec![4, 5, 6]));
        header.sign(&signing_key);
        assert!(matches!(
            validate_header(&header),
            Err(OperationError::SeqNumMismatch)
        ));

        // Payload size without a payload hash.
        let mut header = header_base.clone();
        header.payload_hash = None;
        header.sign(&signing_key);
        assert!(matches!(
            validate_header(&header),
            Err(OperationError::InconsistentPayloadInfo)
        ));

        // Payload hash without a payload size.
        let mut header = header_base.clone();
        header.payload_size = 0;
        header.sign(&signing_key);
        assert!(matches!(
            validate_header(&header),
            Err(OperationError::InconsistentPayloadInfo)
        ));

        let mut header = header_base.clone();
        header.payload_size = 11;
        header.sign(&signing_key);
        assert!(matches!(
            validate_operation(&Operation {
                hash: header.hash(),
                header,
                body: Some(body.clone()),
            }),
            Err(OperationError::PayloadMismatch)
        ));

        let mut header = header_base.clone();
        header.payload_hash = Some(Hash::digest(vec![4, 5, 6]));
        header.sign(&signing_key);
        assert!(matches!(
            validate_operation(&Operation {
                hash: header.hash(),
                header,
                body: Some(body.clone()),
            }),
            Err(OperationError::PayloadMismatch)
        ));
    }

    #[test]
    fn extensions() {
        #[derive(Clone, Debug, Serialize, Deserialize)]
        struct LogId(Hash);

        #[derive(Clone, Debug, Serialize, Deserialize)]
        struct Expiry(u64);

        #[derive(Clone, Debug, Serialize, Deserialize)]
        struct CustomExtensions {
            log_id: Option<LogId>,
            expires: Expiry,
        }

        impl Extension<LogId> for CustomExtensions {
            fn extract(header: &Header<Self>) -> Option<LogId> {
                if header.seq_num == 0 {
                    return Some(LogId(header.hash()));
                };

                header.extensions.log_id.clone()
            }
        }

        impl Extension<Expiry> for CustomExtensions {
            fn extract(header: &Header<Self>) -> Option<Expiry> {
                Some(header.extensions.expires.clone())
            }
        }

        // No leading zero: Rust integer literals are decimal even with one (there is no
        // implicit octal), so `0123456` only misleads and trips `clippy::zero_prefixed_literal`.
        let extensions = CustomExtensions {
            log_id: None,
            expires: Expiry(123_456),
        };

        let signing_key = SigningKey::generate();
        let body: Body = Body::new("Hello, Sloth!".as_bytes());

        let mut header = Header {
            version: 1,
            verifying_key: signing_key.verifying_key(),
            signature: None,
            payload_size: body.size(),
            payload_hash: Some(body.hash()),
            timestamp: 0.into(),
            seq_num: 0,
            backlink: None,
            extensions: extensions.clone(),
        };
        header.sign(&signing_key);

        let log_id: LogId = header.extension().unwrap();
        let expiry: Expiry = header.extension().unwrap();

        assert_eq!(header.hash(), log_id.0);
        assert_eq!(extensions.expires.0, expiry.0);
    }
}