use crate::circuits::halo2_ivc::helpers::merkle_tree::MerkleTreeCommitment;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
use std::fmt::Display;
/// Name of one part of a [`ProtocolMessage`].
///
/// Each serde `rename` below is the same string that the `Display` impl writes.
/// `ProtocolMessage::get_preimage` feeds that `Display` string into the hashed
/// preimage, so these names are part of the hashed message format.
///
/// The derived `PartialOrd`/`Ord` order variants by declaration order, and
/// `ProtocolMessage` stores parts in a `BTreeMap` keyed by this enum. That order
/// therefore fixes the sequence of parts in the preimage. Reordering existing
/// variants can change existing hashes, so add new variants at the end.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, PartialOrd, Ord)]
pub enum ProtocolMessagePartKey {
#[serde(rename = "digest")]
Digest,
#[serde(rename = "snapshot_digest")]
SnapshotDigest,
#[serde(rename = "cardano_transactions_merkle_root")]
CardanoTransactionsMerkleRoot,
#[serde(rename = "next_aggregate_verification_key")]
NextAggregateVerificationKey,
#[serde(rename = "next_protocol_parameters")]
NextProtocolParameters,
#[serde(rename = "current_epoch")]
CurrentEpoch,
#[serde(rename = "latest_block_number")]
LatestBlockNumber,
#[serde(rename = "cardano_stake_distribution_epoch")]
CardanoStakeDistributionEpoch,
#[serde(rename = "cardano_stake_distribution_merkle_root")]
CardanoStakeDistributionMerkleRoot,
#[serde(rename = "cardano_database_merkle_root")]
CardanoDatabaseMerkleRoot,
}
impl Display for ProtocolMessagePartKey {
    /// Writes the key's snake_case name.
    ///
    /// The names match the serde `rename` values on the enum byte for byte.
    /// They are also the bytes that `ProtocolMessage::get_preimage` writes for
    /// each key. Width, fill and alignment flags are ignored.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::Digest => "digest",
            Self::SnapshotDigest => "snapshot_digest",
            Self::CardanoTransactionsMerkleRoot => "cardano_transactions_merkle_root",
            Self::NextAggregateVerificationKey => "next_aggregate_verification_key",
            Self::NextProtocolParameters => "next_protocol_parameters",
            Self::CurrentEpoch => "current_epoch",
            Self::LatestBlockNumber => "latest_block_number",
            Self::CardanoStakeDistributionEpoch => "cardano_stake_distribution_epoch",
            Self::CardanoStakeDistributionMerkleRoot => "cardano_stake_distribution_merkle_root",
            Self::CardanoDatabaseMerkleRoot => "cardano_database_merkle_root",
        };
        f.write_str(name)
    }
}
/// Raw bytes stored as the value of one protocol message part.
pub type ProtocolMessagePartValue = Vec<u8>;
/// A collection of named protocol message parts, kept sorted by key.
///
/// Parts live in a `BTreeMap`, so iteration visits them in the derived `Ord`
/// order of `ProtocolMessagePartKey`. `get_preimage` relies on that order.
#[derive(Clone, Debug, PartialEq, Eq, Default, Serialize, Deserialize)]
pub struct ProtocolMessage {
/// Raw value bytes for each part that has been set, keyed by part name.
pub message_parts: BTreeMap<ProtocolMessagePartKey, ProtocolMessagePartValue>,
}
impl ProtocolMessage {
pub fn new() -> ProtocolMessage {
ProtocolMessage {
message_parts: BTreeMap::new(),
}
}
pub fn set_message_part(
&mut self,
key: ProtocolMessagePartKey,
value: ProtocolMessagePartValue,
) -> Option<ProtocolMessagePartValue> {
self.message_parts.insert(key, value)
}
pub fn get_message_part(
&self,
key: &ProtocolMessagePartKey,
) -> Option<&ProtocolMessagePartValue> {
self.message_parts.get(key)
}
pub fn get_preimage(&self) -> Vec<u8> {
let mut preimage = Vec::new();
self.message_parts.iter().for_each(|(k, v)| {
preimage.extend_from_slice(k.to_string().as_bytes());
preimage.extend_from_slice(v);
});
preimage
}
pub fn compute_hash(&self) -> Vec<u8> {
let preimage = self.get_preimage();
Sha256::digest(&preimage).to_vec()
}
}
/// Aggregate verification key: a Merkle tree commitment plus the total stake.
///
/// The byte encoding used by the `From`/`TryFrom` impls is the commitment's own
/// byte encoding (36 bytes, as `TryFrom` assumes), then `total_stake` as 8
/// little-endian bytes. That is 44 bytes in total.
#[derive(Debug, Clone)]
pub struct AggregateVerificationKey {
    mt_commit: MerkleTreeCommitment,
    total_stake: u64,
}

impl AggregateVerificationKey {
    /// Length in bytes of the encoded `MerkleTreeCommitment` prefix.
    const MT_COMMIT_LEN: usize = 36;
    /// Length in bytes of the little-endian `total_stake` suffix.
    const TOTAL_STAKE_LEN: usize = 8;
    /// Exact input length accepted by `TryFrom<&[u8]>`.
    const ENCODED_LEN: usize = Self::MT_COMMIT_LEN + Self::TOTAL_STAKE_LEN;

    /// Builds a key from a Merkle tree commitment and a total stake.
    pub fn new(mt_commit: MerkleTreeCommitment, total_stake: u64) -> Self {
        Self {
            mt_commit,
            total_stake,
        }
    }
}

impl From<AggregateVerificationKey> for Vec<u8> {
    /// Serializes as `Vec::from(mt_commit)` followed by `total_stake.to_le_bytes()`.
    ///
    /// NOTE(review): this is the inverse of `TryFrom<&[u8]>` only if
    /// `Vec::from(MerkleTreeCommitment)` yields exactly 36 bytes. Confirm this
    /// in `merkle_tree`.
    fn from(avk: AggregateVerificationKey) -> Vec<u8> {
        let mut bytes = Vec::from(avk.mt_commit);
        bytes.extend_from_slice(&avk.total_stake.to_le_bytes());
        bytes
    }
}

impl TryFrom<&[u8]> for AggregateVerificationKey {
    type Error = &'static str;

    /// Parses the 44-byte encoding: 36 commitment bytes, then an 8-byte
    /// little-endian stake.
    ///
    /// # Errors
    /// - The input is not exactly 44 bytes long.
    /// - The first 36 bytes are rejected by `MerkleTreeCommitment`'s
    ///   `TryFrom<&[u8]>` conversion.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        if bytes.len() != Self::ENCODED_LEN {
            return Err("Invalid byte length for AggregateVerificationKey");
        }
        let (commit_bytes, stake_bytes) = bytes.split_at(Self::MT_COMMIT_LEN);
        // The commitment conversion is fallible, and these bytes may come from
        // untrusted input, so report a rejection as an error instead of panicking.
        let mt_commit: MerkleTreeCommitment = commit_bytes
            .try_into()
            .map_err(|_| "Invalid MerkleTreeCommitment bytes for AggregateVerificationKey")?;
        // Length was checked above, so this cannot fail; map_err keeps the path panic-free.
        let stake_le: [u8; 8] = stake_bytes
            .try_into()
            .map_err(|_| "Invalid byte length for AggregateVerificationKey")?;
        Ok(AggregateVerificationKey {
            mt_commit,
            total_stake: u64::from_le_bytes(stake_le),
        })
    }
}
/// An epoch number, encoded by the `From`/`TryFrom` impls as 8 little-endian bytes.
#[derive(Debug, Clone)]
pub struct Epoch(pub u64);

impl From<Epoch> for Vec<u8> {
    /// Encodes the epoch as its 8-byte little-endian representation.
    fn from(epoch: Epoch) -> Vec<u8> {
        epoch.0.to_le_bytes().to_vec()
    }
}

impl TryFrom<&[u8]> for Epoch {
    type Error = &'static str;

    /// Decodes an epoch from exactly 8 little-endian bytes.
    ///
    /// # Errors
    /// Returns `Err("Invalid byte length for Epoch")` if `bytes` is not exactly
    /// 8 bytes long.
    fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
        // Converting the slice to a fixed-size array does the length check and
        // the copy in one fallible step, so no panicking `unwrap` is needed.
        let le_bytes: [u8; 8] = bytes
            .try_into()
            .map_err(|_| "Invalid byte length for Epoch")?;
        Ok(Epoch(u64::from_le_bytes(le_bytes)))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every key paired with its expected `Display` name, in declaration order.
    fn all_keys() -> [(ProtocolMessagePartKey, &'static str); 10] {
        [
            (ProtocolMessagePartKey::Digest, "digest"),
            (ProtocolMessagePartKey::SnapshotDigest, "snapshot_digest"),
            (
                ProtocolMessagePartKey::CardanoTransactionsMerkleRoot,
                "cardano_transactions_merkle_root",
            ),
            (
                ProtocolMessagePartKey::NextAggregateVerificationKey,
                "next_aggregate_verification_key",
            ),
            (
                ProtocolMessagePartKey::NextProtocolParameters,
                "next_protocol_parameters",
            ),
            (ProtocolMessagePartKey::CurrentEpoch, "current_epoch"),
            (ProtocolMessagePartKey::LatestBlockNumber, "latest_block_number"),
            (
                ProtocolMessagePartKey::CardanoStakeDistributionEpoch,
                "cardano_stake_distribution_epoch",
            ),
            (
                ProtocolMessagePartKey::CardanoStakeDistributionMerkleRoot,
                "cardano_stake_distribution_merkle_root",
            ),
            (
                ProtocolMessagePartKey::CardanoDatabaseMerkleRoot,
                "cardano_database_merkle_root",
            ),
        ]
    }

    /// Reference message with four zero-filled parts of fixed sizes.
    fn build_protocol_message_reference() -> ProtocolMessage {
        let mut protocol_message = ProtocolMessage::new();
        protocol_message.set_message_part(ProtocolMessagePartKey::SnapshotDigest, vec![0u8; 32]);
        protocol_message.set_message_part(
            ProtocolMessagePartKey::NextAggregateVerificationKey,
            vec![0u8; 44],
        );
        protocol_message.set_message_part(
            ProtocolMessagePartKey::NextProtocolParameters,
            vec![0u8; 32],
        );
        protocol_message.set_message_part(ProtocolMessagePartKey::CurrentEpoch, vec![0u8; 8]);
        protocol_message
    }

    #[test]
    fn test_protocol_message_hash() {
        let protocol_message = build_protocol_message_reference();
        let hash = protocol_message.compute_hash();
        assert_eq!(hash.len(), 32);
        // Hashing must be deterministic.
        assert_eq!(hash, protocol_message.compute_hash());
        let mut protocol_message_modified = protocol_message.clone();
        protocol_message_modified.set_message_part(
            ProtocolMessagePartKey::NextAggregateVerificationKey,
            vec![3u8; 44],
        );
        assert_ne!(hash, protocol_message_modified.compute_hash());
    }

    #[test]
    fn test_protocol_message_keys() {
        // Exact strings, not just lengths: a same-length typo would change the hash.
        for &(key, expected) in all_keys().iter() {
            assert_eq!(key.to_string(), expected);
        }
    }

    #[test]
    fn test_protocol_message_key_order_is_declaration_order() {
        // The BTreeMap iteration order (and thus the preimage) depends on this.
        assert!(all_keys().windows(2).all(|pair| pair[0].0 < pair[1].0));
    }

    #[test]
    fn test_protocol_message_preimage_layout() {
        let protocol_message = build_protocol_message_reference();
        let mut expected = Vec::new();
        expected.extend_from_slice(b"snapshot_digest");
        expected.extend_from_slice(&[0u8; 32]);
        expected.extend_from_slice(b"next_aggregate_verification_key");
        expected.extend_from_slice(&[0u8; 44]);
        expected.extend_from_slice(b"next_protocol_parameters");
        expected.extend_from_slice(&[0u8; 32]);
        expected.extend_from_slice(b"current_epoch");
        expected.extend_from_slice(&[0u8; 8]);
        assert_eq!(protocol_message.get_preimage(), expected);
        assert_eq!(
            protocol_message.compute_hash(),
            Sha256::digest(&expected).to_vec()
        );
    }

    #[test]
    fn test_set_message_part_returns_previous_value() {
        let mut protocol_message = ProtocolMessage::new();
        assert_eq!(
            protocol_message.set_message_part(ProtocolMessagePartKey::CurrentEpoch, vec![1]),
            None
        );
        assert_eq!(
            protocol_message.set_message_part(ProtocolMessagePartKey::CurrentEpoch, vec![2]),
            Some(vec![1])
        );
        assert_eq!(
            protocol_message.get_message_part(&ProtocolMessagePartKey::CurrentEpoch),
            Some(&vec![2])
        );
        assert_eq!(
            protocol_message.get_message_part(&ProtocolMessagePartKey::Digest),
            None
        );
    }

    #[test]
    fn test_epoch_round_trip() {
        let bytes: Vec<u8> = Epoch(42).into();
        assert_eq!(bytes, 42u64.to_le_bytes().to_vec());
        assert_eq!(Epoch::try_from(bytes.as_slice()).unwrap().0, 42);
        assert!(Epoch::try_from(&bytes[..7]).is_err());
    }
}