use std::hash::{Hash as StdHash, Hasher};
use serde::{Deserialize, Serialize};
use crate::coroutine::Value;
use crate::instr::Endpoint;
/// A 32-byte digest, as produced by this module's tagged hashing.
///
/// Equality and ordering compare the raw bytes lexicographically.
#[derive(Debug, Clone, Copy, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize)]
pub struct Hash(pub [u8; 32]);
/// Domain-separation tags: the tag's byte is hashed ahead of the payload, so digests
/// computed for different purposes are taken over different inputs.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub enum HashTag {
    /// Not used inside this module; presumably for hashing plain values — confirm with callers.
    Value,
    /// Signature digests built by `sign_value` and checked by `verify_signed_value`.
    SignedValue,
    /// Merkle leaf hashes (inside this module, only the empty-tree root uses it).
    MerkleLeaf,
    /// Internal Merkle nodes, produced by `merge_hash_pair`.
    MerkleNode,
    /// Value commitments.
    Commitment,
    /// Value nullifiers.
    Nullifier,
    /// Key derivation: endpoint signing keys and the verifying keys derived from them.
    SigningKey,
}
impl HashTag {
fn domain_byte(self) -> u8 {
match self {
Self::Value => 0x01,
Self::SignedValue => 0x02,
Self::MerkleLeaf => 0x03,
Self::MerkleNode => 0x04,
Self::Commitment => 0x05,
Self::Nullifier => 0x06,
Self::SigningKey => 0x07,
}
}
}
/// 32 bytes of signing-key material; see `signing_key_for_endpoint` for how keys are made.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct SigningKey(pub [u8; 32]);
/// 32-byte public key derived from a [`SigningKey`] by `DefaultVerificationModel::deriving`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize)]
pub struct VerifyingKey(pub [u8; 32]);
/// A signature over a value, as produced by `sign_value`.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct Signature {
    /// Verifying key of the claimed signer.
    pub signer: VerifyingKey,
    /// `SignedValue`-tagged digest of the signer's key bytes followed by the encoded value.
    pub digest: Hash,
}
/// Commitment to a value: the `Commitment`-tagged hash of its encoding.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize)]
pub struct Commitment(pub Hash);
/// Nullifier for a value: the `Nullifier`-tagged hash of its encoding.
#[derive(Debug, Clone, Copy)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(Serialize, Deserialize)]
pub struct Nullifier(pub Hash);
/// Pluggable primitives for hashing, committing to, and signing runtime [`Value`]s.
///
/// Implementations choose their own representations through the associated types;
/// [`DefaultVerificationModel`] is the implementation provided by this module.
pub trait VerificationModel {
    /// Output of [`VerificationModel::hash`].
    type Hash;
    /// Output of [`VerificationModel::commitment`].
    type Commitment;
    /// Output of [`VerificationModel::nullifier`].
    type Nullifier;
    /// Key material used to sign.
    type SigningKey;
    /// Public key derived from a signing key.
    type VerifyingKey;
    /// Signature over a [`Value`].
    type Signature;

    /// Hashes `bytes` under the domain-separation `tag`.
    fn hash(tag: HashTag, bytes: &[u8]) -> Self::Hash;

    /// Computes a commitment to `payload`.
    fn commitment(payload: &Value) -> Self::Commitment;

    /// Computes the nullifier for `payload`.
    fn nullifier(payload: &Value) -> Self::Nullifier;

    /// Derives the verifying key that corresponds to `signing`.
    fn deriving(signing: &Self::SigningKey) -> Self::VerifyingKey;

    /// Signs `payload` with `key`.
    fn sign_value(payload: &Value, key: &Self::SigningKey) -> Self::Signature;

    /// Returns `true` when `signature` over `payload` is accepted for `key`.
    fn verify_signed_value(
        payload: &Value,
        signature: &Self::Signature,
        key: &Self::VerifyingKey,
    ) -> bool;
}
/// Built-in [`VerificationModel`] backed by this module's `DefaultHasher`-based tagged hashing.
#[derive(Debug, Default, Clone, Copy)]
#[derive(Serialize, Deserialize)]
pub struct DefaultVerificationModel;
/// Expands `bytes` into a 32-byte digest under the domain-separation `tag`.
///
/// The digest is four 8-byte little-endian lanes. Lane `i` is a fresh `DefaultHasher`
/// fed the tag's domain byte, then `i` as a `u64`, then `bytes`.
///
/// NOTE(review): `DefaultHasher` is not designed as a collision-resistant cryptographic hash.
/// The standard library also documents its algorithm as unspecified, so it may change
/// between Rust releases. Digests, and any keys, commitments, or Merkle roots persisted
/// from them, may therefore differ across toolchains.
fn hash_bytes_with_tag(tag: HashTag, bytes: &[u8]) -> Hash {
    let mut out = [0_u8; 32];
    // Zipping with an unbounded `u64` counter keeps the lane index hashed as `u64`,
    // exactly as before. It also removes the old (unreachable) `usize` conversion
    // fallback, which would have returned a partially zeroed digest.
    for (block, lane) in (0_u64..).zip(out.chunks_exact_mut(8)) {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        tag.domain_byte().hash(&mut hasher);
        block.hash(&mut hasher);
        bytes.hash(&mut hasher);
        lane.copy_from_slice(&hasher.finish().to_le_bytes());
    }
    Hash(out)
}
/// Canonical byte encoding of a value for hashing.
///
/// Uses JSON via `serde_json`. If serialization fails, it falls back to the value's
/// `Debug` text, so hashing never fails.
fn encode_value(value: &Value) -> Vec<u8> {
    match serde_json::to_vec(value) {
        Ok(json) => json,
        Err(_) => format!("{value:?}").into_bytes(),
    }
}
impl VerificationModel for DefaultVerificationModel {
type Hash = Hash;
type SigningKey = SigningKey;
type VerifyingKey = VerifyingKey;
type Signature = Signature;
type Commitment = Commitment;
type Nullifier = Nullifier;
fn hash(tag: HashTag, bytes: &[u8]) -> Self::Hash {
hash_bytes_with_tag(tag, bytes)
}
fn deriving(signing: &Self::SigningKey) -> Self::VerifyingKey {
let digest = hash_bytes_with_tag(HashTag::SigningKey, &signing.0);
VerifyingKey(digest.0)
}
fn sign_value(payload: &Value, key: &Self::SigningKey) -> Self::Signature {
crate::verification::sign_value(payload, key)
}
fn verify_signed_value(
payload: &Value,
signature: &Self::Signature,
key: &Self::VerifyingKey,
) -> bool {
verify_signed_value(payload, signature, key)
}
fn commitment(payload: &Value) -> Self::Commitment {
Commitment(hash_bytes_with_tag(
HashTag::Commitment,
&encode_value(payload),
))
}
fn nullifier(payload: &Value) -> Self::Nullifier {
Nullifier(hash_bytes_with_tag(
HashTag::Nullifier,
&encode_value(payload),
))
}
}
#[must_use]
pub fn signing_key_for_endpoint(endpoint: &Endpoint) -> SigningKey {
let mut bytes = endpoint.sid.to_le_bytes().to_vec();
bytes.extend_from_slice(endpoint.role.as_bytes());
let digest = hash_bytes_with_tag(HashTag::SigningKey, &bytes);
SigningKey(digest.0)
}
#[must_use]
pub fn verifying_key_for_endpoint(endpoint: &Endpoint) -> VerifyingKey {
DefaultVerificationModel::deriving(&signing_key_for_endpoint(endpoint))
}
#[must_use]
pub fn sign_value(payload: &Value, key: &SigningKey) -> Signature {
let verifying = DefaultVerificationModel::deriving(key);
let mut bytes = verifying.0.to_vec();
bytes.extend_from_slice(&encode_value(payload));
let digest = hash_bytes_with_tag(HashTag::SignedValue, &bytes);
Signature {
signer: verifying,
digest,
}
}
#[must_use]
pub fn verify_signed_value(payload: &Value, signature: &Signature, key: &VerifyingKey) -> bool {
if signature.signer != *key {
return false;
}
let mut bytes = key.0.to_vec();
bytes.extend_from_slice(&encode_value(payload));
let expected = hash_bytes_with_tag(HashTag::SignedValue, &bytes);
expected == signature.digest
}
/// Hashes an ordered pair of child nodes into their parent under the `MerkleNode` tag.
///
/// The preimage is the 32 bytes of `left` followed by the 32 bytes of `right`.
fn merge_hash_pair(left: Hash, right: Hash) -> Hash {
    let mut preimage = [0_u8; 64];
    let (head, tail) = preimage.split_at_mut(32);
    head.copy_from_slice(&left.0);
    tail.copy_from_slice(&right.0);
    hash_bytes_with_tag(HashTag::MerkleNode, &preimage)
}
/// Merkle authentication path, produced by `AuthTree::prove` and checked by `AuthTree::verify`.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct AuthProof {
    /// Position of the proven leaf in the tree's leaf list.
    pub index: usize,
    /// Sibling hashes from the leaf level upward, one per level below the root.
    pub siblings: Vec<Hash>,
    /// For each entry of `siblings`, whether that sibling is the left input when merging.
    pub sibling_on_left: Vec<bool>,
}
/// Merkle tree over caller-supplied leaf hashes that keeps every level in memory.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
#[derive(Serialize, Deserialize)]
pub struct AuthTree {
    /// Leaves in insertion order.
    leaves: Vec<Hash>,
    /// `levels[0]` mirrors `leaves`, or holds a single sentinel hash when there are no
    /// leaves. Each following level holds the parents of the one before, ending at the root.
    levels: Vec<Vec<Hash>>,
}
impl AuthTree {
    /// Builds a Merkle tree over `leaves`, keeping every level so proofs can be read off.
    ///
    /// `levels[0]` is the leaf level. Each later level merges adjacent pairs with
    /// `merge_hash_pair`, and a level's unpaired last node is merged with itself. The
    /// final level holds the single root.
    ///
    /// An empty `leaves` yields a one-level tree whose root is the `MerkleLeaf`-tagged
    /// hash of the empty byte string.
    #[must_use]
    pub fn new(leaves: Vec<Hash>) -> Self {
        if leaves.is_empty() {
            return Self {
                leaves,
                levels: vec![vec![hash_bytes_with_tag(HashTag::MerkleLeaf, &[])]],
            };
        }
        let mut levels = vec![leaves.clone()];
        // Keep hashing the topmost level until it has collapsed to a single root.
        while let Some(level) = levels.last().filter(|top| top.len() > 1) {
            let next: Vec<Hash> = level
                .chunks(2)
                .map(|pair| merge_hash_pair(pair[0], pair.get(1).copied().unwrap_or(pair[0])))
                .collect();
            levels.push(next);
        }
        Self { leaves, levels }
    }

    /// Appends `leaf`, recomputing only the nodes on its path to the root.
    ///
    /// The result equals `AuthTree::new` over the extended leaf list.
    ///
    /// A tree with no leaves is rebuilt from scratch instead. So is a tree with no stored
    /// levels, which only malformed deserialized data can have; without this, it would
    /// panic on indexing.
    pub fn append_leaf(&mut self, leaf: Hash) {
        if self.leaves.is_empty() || self.levels.is_empty() {
            let mut leaves = std::mem::take(&mut self.leaves);
            leaves.push(leaf);
            *self = Self::new(leaves);
            return;
        }
        self.leaves.push(leaf);
        self.levels[0].push(leaf);
        let mut idx = self.levels[0].len() - 1;
        let mut level_idx = 0;
        loop {
            let level = &self.levels[level_idx];
            // Pair `idx` with its neighbour; an unpaired last node is merged with itself.
            let pair_start = idx & !1;
            let left = level[pair_start];
            let right = level.get(pair_start + 1).copied().unwrap_or(left);
            let parent = merge_hash_pair(left, right);
            let parent_idx = pair_start / 2;
            if self.levels.len() == level_idx + 1 {
                // The topmost stored level now has more than one node: add a new root level.
                self.levels.push(Vec::new());
            }
            let next = &mut self.levels[level_idx + 1];
            if parent_idx < next.len() {
                next[parent_idx] = parent;
            } else {
                next.push(parent);
            }
            if parent_idx == 0 && next.len() == 1 {
                // Just wrote the single-node root level.
                break;
            }
            idx = parent_idx;
            level_idx += 1;
        }
    }

    /// Returns the root hash (for an empty tree, the empty-tree sentinel).
    #[must_use]
    pub fn root(&self) -> Hash {
        self.levels
            .last()
            .and_then(|level| level.first().copied())
            .unwrap_or_else(|| hash_bytes_with_tag(HashTag::MerkleLeaf, &[]))
    }

    /// Returns the authentication path for the leaf at `index`.
    ///
    /// Returns `None` when `index` is not a leaf position in this tree.
    #[must_use]
    pub fn prove(&self, index: usize) -> Option<AuthProof> {
        if index >= self.leaves.len() {
            return None;
        }
        let mut idx = index;
        let mut siblings = Vec::new();
        let mut sibling_on_left = Vec::new();
        // Every level below the root contributes one sibling; the root level needs none.
        for level in self.levels.iter().take_while(|level| level.len() > 1) {
            let pair_index = idx ^ 1;
            // An unpaired last node was merged with itself, so it is its own sibling.
            let sibling = if pair_index < level.len() {
                level[pair_index]
            } else {
                level[idx]
            };
            siblings.push(sibling);
            sibling_on_left.push(pair_index < idx);
            idx /= 2;
        }
        Some(AuthProof {
            index,
            siblings,
            sibling_on_left,
        })
    }

    /// Checks that `leaf`, at position `proof.index`, authenticates to `root`.
    ///
    /// Each step's `sibling_on_left` flag must match the parity of the index at that level.
    /// The proof must also consume the whole index: an `index` with bits beyond the proof's
    /// height is rejected, so one path cannot be replayed under several leaf positions.
    ///
    /// NOTE(review): this function does not know the tree's leaf count. When a level's last
    /// node was paired with itself, a proof for the duplicated position just past the real
    /// leaves also verifies. Callers that care must bound `proof.index` themselves.
    #[must_use]
    pub fn verify(root: Hash, leaf: Hash, proof: &AuthProof) -> bool {
        if proof.siblings.len() != proof.sibling_on_left.len() {
            return false;
        }
        let mut current = leaf;
        let mut index = proof.index;
        for (sibling, &on_left) in proof.siblings.iter().zip(&proof.sibling_on_left) {
            if on_left != (index % 2 == 1) {
                return false;
            }
            current = if on_left {
                merge_hash_pair(*sibling, current)
            } else {
                merge_hash_pair(current, *sibling)
            };
            index /= 2;
        }
        // Leftover high bits mean `proof.index` names a position outside the proven path.
        index == 0 && current == root
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    fn leaf(label: &[u8]) -> Hash {
        hash_bytes_with_tag(HashTag::MerkleLeaf, label)
    }

    #[test]
    fn signature_roundtrip() {
        let ep = Endpoint {
            sid: 9,
            role: "Alice".to_string(),
        };
        let sk = signing_key_for_endpoint(&ep);
        let vk = verifying_key_for_endpoint(&ep);
        let payload = Value::Nat(42);
        let sig = sign_value(&payload, &sk);
        assert!(verify_signed_value(&payload, &sig, &vk));
        assert!(!verify_signed_value(&Value::Nat(7), &sig, &vk));
    }

    #[test]
    fn auth_tree_proof_roundtrip() {
        let leaves = vec![
            hash_bytes_with_tag(HashTag::MerkleLeaf, b"a"),
            hash_bytes_with_tag(HashTag::MerkleLeaf, b"b"),
            hash_bytes_with_tag(HashTag::MerkleLeaf, b"c"),
        ];
        let tree = AuthTree::new(leaves.clone());
        let proof = tree.prove(1).expect("proof for valid index");
        assert!(AuthTree::verify(tree.root(), leaves[1], &proof));
    }

    #[test]
    fn auth_tree_incremental_append_matches_rebuild() {
        let leaves = vec![
            hash_bytes_with_tag(HashTag::MerkleLeaf, b"a"),
            hash_bytes_with_tag(HashTag::MerkleLeaf, b"b"),
            hash_bytes_with_tag(HashTag::MerkleLeaf, b"c"),
            hash_bytes_with_tag(HashTag::MerkleLeaf, b"d"),
            hash_bytes_with_tag(HashTag::MerkleLeaf, b"e"),
        ];
        let mut incremental = AuthTree::new(vec![leaves[0]]);
        for leaf in leaves.iter().skip(1) {
            incremental.append_leaf(*leaf);
        }
        let rebuilt = AuthTree::new(leaves);
        assert_eq!(incremental.root(), rebuilt.root());
    }

    #[test]
    fn auth_tree_every_proof_verifies() {
        assert!(AuthTree::new(Vec::new()).prove(0).is_none());
        for count in 1..=9_u8 {
            let leaves: Vec<Hash> = (0..count).map(|i| leaf(&[i])).collect();
            let tree = AuthTree::new(leaves.clone());
            for (index, leaf_hash) in leaves.iter().enumerate() {
                let proof = tree.prove(index).expect("proof for valid index");
                assert!(AuthTree::verify(tree.root(), *leaf_hash, &proof));
            }
            assert!(tree.prove(leaves.len()).is_none());
        }
    }

    #[test]
    fn auth_tree_verify_rejects_index_outside_proof_height() {
        let leaves = vec![leaf(b"a"), leaf(b"b")];
        let tree = AuthTree::new(leaves.clone());
        let mut proof = tree.prove(0).expect("proof for valid index");
        assert!(AuthTree::verify(tree.root(), leaves[0], &proof));
        // Same path, same parity at every level, but the claimed position is not leaf 0.
        proof.index += 2;
        assert!(!AuthTree::verify(tree.root(), leaves[0], &proof));
    }

    #[test]
    fn auth_tree_incremental_append_matches_rebuild_at_every_size() {
        let leaves: Vec<Hash> = (0..9_u8).map(|i| leaf(&[i])).collect();
        let mut incremental = AuthTree::new(Vec::new());
        for (count, leaf_hash) in leaves.iter().enumerate() {
            incremental.append_leaf(*leaf_hash);
            assert_eq!(incremental, AuthTree::new(leaves[..=count].to_vec()));
        }
    }
}