use alloy_signer::{Signature, Signer};
use prost::Message;
use sha2::{Digest, Sha256};
use std::{
collections::HashMap,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use uuid::Uuid;
use crate::{
error::ConsensusError,
protos::consensus::v1::{Proposal, Vote},
};
/// Exact byte length a vote signature must have: signatures are converted into a
/// `[u8; SIGNATURE_LENGTH]` array before being parsed with `Signature::from_raw_array`.
// NOTE(review): presumably the 65 bytes are r (32) || s (32) || v (1) — confirm against the alloy docs.
const SIGNATURE_LENGTH: usize = 65;
/// Folds a 128-bit value into 32 bits by XOR-ing its four 32-bit words together.
///
/// The mapping is lossy by construction: many distinct inputs share the same output.
fn fold_u128_to_u32(n: u128) -> u32 {
    // Word 0 is the least significant 32 bits, word 3 the most significant; the `as u32`
    // cast intentionally keeps only the low 32 bits of each shifted value.
    (0u32..4).fold(0u32, |folded, word| folded ^ (n >> (32 * word)) as u32)
}
/// Produces a 32-bit identifier by generating a v4 UUID and XOR-folding its 128 bits
/// down to 32 with `fold_u128_to_u32`.
pub(crate) fn generate_id() -> u32 {
    fold_u128_to_u32(Uuid::new_v4().as_u128())
}
/// Computes the SHA-256 digest that identifies a vote.
///
/// The digest covers, in this order: `vote_id`, `vote_owner`, `proposal_id` and `timestamp`
/// (integers as little-endian bytes), the `vote` flag as a single byte, `parent_hash` and
/// `received_hash`. The `vote_hash` and `signature` fields are not part of the input.
pub fn compute_vote_hash(vote: &Vote) -> Vec<u8> {
    let vote_id = vote.vote_id.to_le_bytes();
    let proposal_id = vote.proposal_id.to_le_bytes();
    let timestamp = vote.timestamp.to_le_bytes();
    let choice = [vote.vote as u8];

    // The order of this table defines the hash; do not reorder.
    let fields: [&[u8]; 7] = [
        &vote_id[..],
        &vote.vote_owner[..],
        &proposal_id[..],
        &timestamp[..],
        &choice[..],
        &vote.parent_hash[..],
        &vote.received_hash[..],
    ];

    let mut hasher = Sha256::new();
    for field in fields {
        hasher.update(field);
    }
    hasher.finalize().to_vec()
}
/// Builds and signs a new vote on `proposal` on behalf of `signer`.
///
/// The vote is linked into the proposal's existing votes:
/// - `received_hash` is the hash of the last vote currently in `proposal.votes`
///   (empty when there are none);
/// - `parent_hash` is the hash of the most recent vote in `proposal.votes` whose owner is
///   the signer's address (empty when the signer has not voted yet).
///
/// The vote hash is computed with `compute_vote_hash`, then the protobuf encoding of the
/// vote (hash set, signature still empty) is signed and the signature bytes are stored.
///
/// # Errors
/// Propagates failures from `current_timestamp` and from the signer's `sign_message`.
pub async fn build_vote<S: Signer + Sync>(
    proposal: &Proposal,
    user_vote: bool,
    signer: S,
) -> Result<Vote, ConsensusError> {
    let timestamp = current_timestamp()?;
    let vote_owner = signer.address().as_slice().to_vec();

    // Hash of the newest vote seen so far, whoever cast it.
    let received_hash = proposal
        .votes
        .last()
        .map(|latest| latest.vote_hash.clone())
        .unwrap_or_default();

    // Hash of this voter's own most recent vote, searching from the newest backwards.
    let parent_hash = proposal
        .votes
        .iter()
        .rev()
        .find(|earlier| earlier.vote_owner == vote_owner)
        .map(|own| own.vote_hash.clone())
        .unwrap_or_default();

    let mut vote = Vote {
        vote_id: generate_id(),
        vote_owner,
        proposal_id: proposal.proposal_id,
        timestamp,
        vote: user_vote,
        parent_hash,
        received_hash,
        vote_hash: Vec::new(),
        signature: Vec::new(),
    };
    vote.vote_hash = compute_vote_hash(&vote);

    // The signed payload is the encoded vote with its hash filled in and an empty signature.
    let unsigned_bytes = vote.encode_to_vec();
    let signed = signer.sign_message(&unsigned_bytes).await?;
    vote.signature = signed.as_bytes().to_vec();

    Ok(vote)
}
/// Checks that `signature` over `message` was produced by the address given in `public_key`.
///
/// Despite the parameter name, `public_key` is compared against the *address* recovered
/// from the signature, so it must hold the raw address bytes.
///
/// Returns `Ok(true)` when the recovered address equals `public_key`, `Ok(false)` otherwise.
///
/// # Errors
/// - `ConsensusError::MismatchedLength` when `signature` is not `SIGNATURE_LENGTH` bytes long.
/// - Whatever parsing the signature or recovering the address reports.
fn verify_vote_hash(
    signature: &[u8],
    public_key: &[u8],
    message: &[u8],
) -> Result<bool, ConsensusError> {
    let Ok(raw_signature) = <[u8; SIGNATURE_LENGTH]>::try_from(signature) else {
        return Err(ConsensusError::MismatchedLength {
            expect: SIGNATURE_LENGTH,
            actual: signature.len(),
        });
    };

    let recovered =
        Signature::from_raw_array(&raw_signature)?.recover_address_from_msg(message)?;
    Ok(recovered.as_slice() == public_key)
}
/// Validates a proposal together with every vote attached to it.
///
/// Checks, in order: that the proposal has not expired, that each vote references this
/// proposal's id and passes `validate_vote`, and finally that the votes form a consistent
/// chain according to `validate_vote_chain`.
///
/// # Errors
/// Returns the first failure encountered, including
/// `ConsensusError::VoteProposalIdMismatch` for a vote that targets a different proposal.
pub fn validate_proposal(proposal: &Proposal) -> Result<(), ConsensusError> {
    validate_proposal_timestamp(proposal.expiration_timestamp)?;

    proposal.votes.iter().try_for_each(|vote| {
        if vote.proposal_id == proposal.proposal_id {
            validate_vote(vote, proposal.expiration_timestamp, proposal.timestamp)
        } else {
            Err(ConsensusError::VoteProposalIdMismatch)
        }
    })?;

    validate_vote_chain(&proposal.votes)
}
pub(crate) fn validate_vote(
vote: &Vote,
expiration_timestamp: u64,
creation_time: u64,
) -> Result<(), ConsensusError> {
if vote.vote_owner.is_empty() {
return Err(ConsensusError::EmptyVoteOwner);
}
if vote.vote_hash.is_empty() {
return Err(ConsensusError::EmptyVoteHash);
}
if vote.signature.is_empty() {
return Err(ConsensusError::EmptySignature);
}
if vote.signature.len() != SIGNATURE_LENGTH {
return Err(ConsensusError::MismatchedLength {
expect: SIGNATURE_LENGTH,
actual: vote.signature.len(),
});
}
let expected_hash = compute_vote_hash(vote);
if vote.vote_hash != expected_hash {
return Err(ConsensusError::InvalidVoteHash);
}
let mut vote_copy = vote.clone();
vote_copy.signature = Vec::new();
let vote_copy_bytes = vote_copy.encode_to_vec();
let verified = verify_vote_hash(&vote.signature, &vote.vote_owner, &vote_copy_bytes)?;
if !verified {
return Err(ConsensusError::InvalidVoteSignature);
}
let now = current_timestamp()?;
if vote.timestamp < creation_time {
return Err(ConsensusError::TimestampOlderThanCreationTime);
}
if vote.timestamp > expiration_timestamp || now > expiration_timestamp {
return Err(ConsensusError::VoteExpired);
}
Ok(())
}
/// Validates the hash links between the votes of a proposal, in slice order.
///
/// Zero or one vote is trivially valid. Otherwise, for each vote:
/// - if it is not the first vote and its `received_hash` is non-empty, that hash must equal
///   the `vote_hash` of the vote immediately before it, and that previous vote must not
///   have a later timestamp;
/// - if its `parent_hash` is non-empty, a vote with that hash must exist in `votes`, belong
///   to the same owner, have a timestamp no later than this vote's, and appear earlier in
///   the slice.
///
/// # Errors
/// `ConsensusError::ReceivedHashMismatch` or `ConsensusError::ParentHashMismatch` for the
/// first vote that violates the corresponding rule.
pub(crate) fn validate_vote_chain(votes: &[Vote]) -> Result<(), ConsensusError> {
    if votes.len() <= 1 {
        return Ok(());
    }

    // vote hash -> (owner, timestamp, position); a later vote with the same hash replaces
    // an earlier entry.
    let hash_index: HashMap<&[u8], (&[u8], u64, usize)> = votes
        .iter()
        .enumerate()
        .map(|(position, vote)| {
            (
                vote.vote_hash.as_slice(),
                (vote.vote_owner.as_slice(), vote.timestamp, position),
            )
        })
        .collect();

    for (position, vote) in votes.iter().enumerate() {
        // Link to the immediately preceding vote (absent for the first vote).
        let predecessor = position.checked_sub(1).map(|before| &votes[before]);
        if let Some(previous) = predecessor {
            if !vote.received_hash.is_empty()
                && (vote.received_hash != previous.vote_hash
                    || previous.timestamp > vote.timestamp)
            {
                return Err(ConsensusError::ReceivedHashMismatch);
            }
        }

        // Link to the same owner's earlier vote.
        if vote.parent_hash.is_empty() {
            continue;
        }
        let parent_is_valid = hash_index
            .get(&vote.parent_hash.as_slice())
            .is_some_and(|&(owner, parent_timestamp, parent_position)| {
                owner == vote.vote_owner.as_slice()
                    && parent_timestamp <= vote.timestamp
                    && parent_position < position
            });
        if !parent_is_valid {
            return Err(ConsensusError::ParentHashMismatch);
        }
    }

    Ok(())
}
/// Decides the outcome of a proposal from the votes collected so far.
///
/// Returns `Some(true)` / `Some(false)` once a result is determined, or `None` while it is
/// still undecided.
///
/// * With `expected_voters <= 2`, every expected voter must have voted; the result is then
///   `true` only if the number of yes votes equals `expected_voters`.
/// * Otherwise a quorum is required first: the number of votes cast (or `expected_voters`
///   itself when `is_timeout` is set) must reach `calculate_required_votes`.
/// * Voters who have not voted ("silent") are added to the yes side when
///   `liveness_criteria_yes` is true and to the no side otherwise.
/// * A side wins when its weight reaches the threshold-based count and strictly exceeds the
///   other side's weight.
/// * If every expected voter has voted and the weights are tied, the result is
///   `liveness_criteria_yes`.
pub fn calculate_consensus_result(
    votes: &HashMap<Vec<u8>, Vote>,
    expected_voters: u32,
    consensus_threshold: f64,
    liveness_criteria_yes: bool,
    is_timeout: bool,
) -> Option<bool> {
    let total_votes = votes.len() as u32;
    let yes_votes = votes.values().filter(|ballot| ballot.vote).count() as u32;
    let no_votes = total_votes.saturating_sub(yes_votes);
    let silent_votes = expected_voters.saturating_sub(total_votes);

    // Tiny groups need unanimity from everyone expected.
    if expected_voters <= 2 {
        return if total_votes < expected_voters {
            None
        } else {
            Some(yes_votes == expected_voters)
        };
    }

    // Quorum: on timeout all expected voters count as having participated.
    let required_votes = calculate_required_votes(expected_voters, consensus_threshold);
    let participation = if is_timeout {
        expected_voters
    } else {
        total_votes
    };
    if participation < required_votes {
        return None;
    }

    let required_choice_votes =
        calculate_threshold_based_value(expected_voters, consensus_threshold);

    // Silent voters are attributed to whichever side the liveness criteria favours.
    let (yes_weight, no_weight) = if liveness_criteria_yes {
        (yes_votes + silent_votes, no_votes)
    } else {
        (yes_votes, no_votes + silent_votes)
    };

    let yes_wins = yes_weight >= required_choice_votes && yes_weight > no_weight;
    let no_wins = no_weight >= required_choice_votes && no_weight > yes_weight;
    let full_tie = total_votes == expected_voters && yes_weight == no_weight;

    if yes_wins {
        Some(true)
    } else if no_wins {
        Some(false)
    } else if full_tie {
        Some(liveness_criteria_yes)
    } else {
        None
    }
}
/// Returns how many votes are needed before a result may be computed.
///
/// For up to two expected voters all of them are required; for larger groups the count is
/// derived from `consensus_threshold` via `calculate_threshold_based_value`.
fn calculate_required_votes(expected_voters: u32, consensus_threshold: f64) -> u32 {
    match expected_voters {
        0..=2 => expected_voters,
        _ => calculate_threshold_based_value(expected_voters, consensus_threshold),
    }
}
/// Returns the maximum number of rounds for a proposal.
///
/// The value is exactly what `calculate_threshold_based_value` yields for the same
/// `expected_voters` and `consensus_threshold`.
pub(crate) fn calculate_max_rounds(expected_voters: u32, consensus_threshold: f64) -> u32 {
calculate_threshold_based_value(expected_voters, consensus_threshold)
}
/// Returns `ceil(expected_voters * consensus_threshold)` as a vote count.
///
/// A threshold within `f64::EPSILON` of 2/3 is handled with exact integer arithmetic,
/// `ceil(2 * expected_voters / 3)`; any other threshold goes through `f64`.
///
/// The integer path widens to `u64` before doubling, so it stays correct for
/// `expected_voters` above `u32::MAX / 2` (a plain `2 * expected_voters` in `u32` would
/// overflow there).
// NOTE(review): in the general branch the product is computed in f64, so representation
// error may push a mathematically exact product slightly above an integer and `ceil` then
// rounds up by one — confirm this is acceptable for the thresholds used in practice.
fn calculate_threshold_based_value(expected_voters: u32, consensus_threshold: f64) -> u32 {
    if (consensus_threshold - (2.0 / 3.0)).abs() < f64::EPSILON {
        let two_thirds = (2 * u64::from(expected_voters)).div_ceil(3);
        // ceil(2n/3) never exceeds n, so narrowing back to u32 cannot truncate.
        two_thirds as u32
    } else {
        (f64::from(expected_voters) * consensus_threshold).ceil() as u32
    }
}
/// Returns the current system time as whole seconds since the Unix epoch.
///
/// # Errors
/// Fails when the system clock reports a time before the Unix epoch.
pub(crate) fn current_timestamp() -> Result<u64, ConsensusError> {
    let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH)?;
    Ok(since_epoch.as_secs())
}
/// Ensures a proposal is still open: the current time must be strictly before
/// `expiration_timestamp`.
///
/// # Errors
/// `ConsensusError::ProposalExpired` once the current time reaches or passes the
/// expiration, or any error from `current_timestamp`.
pub(crate) fn validate_proposal_timestamp(expiration_timestamp: u64) -> Result<(), ConsensusError> {
    if current_timestamp()? < expiration_timestamp {
        Ok(())
    } else {
        Err(ConsensusError::ProposalExpired)
    }
}
/// Ensures a consensus threshold lies in the closed interval `[0.0, 1.0]`.
///
/// # Errors
/// `ConsensusError::InvalidConsensusThreshold` for any value outside that interval; NaN is
/// rejected as well because it is not contained in the range.
pub(crate) fn validate_threshold(threshold: f64) -> Result<(), ConsensusError> {
    let within_unit_interval = (0.0..=1.0).contains(&threshold);
    if within_unit_interval {
        Ok(())
    } else {
        Err(ConsensusError::InvalidConsensusThreshold)
    }
}
/// Ensures a timeout is non-zero.
///
/// # Errors
/// `ConsensusError::InvalidTimeout` when `timeout` is zero.
pub(crate) fn validate_timeout(timeout: Duration) -> Result<(), ConsensusError> {
    if timeout.is_zero() {
        Err(ConsensusError::InvalidTimeout)
    } else {
        Ok(())
    }
}
/// Ensures at least one voter is expected.
///
/// # Errors
/// `ConsensusError::InvalidExpectedVotersCount` when `expected_voters_count` is zero.
pub(crate) fn validate_expected_voters_count(
    expected_voters_count: u32,
) -> Result<(), ConsensusError> {
    match expected_voters_count {
        0 => Err(ConsensusError::InvalidExpectedVotersCount),
        _ => Ok(()),
    }
}
/// Reports whether `total_votes` reaches the number of votes required for
/// `expected_voters` at the given `consensus_threshold`, as computed by
/// `calculate_required_votes`.
pub fn has_sufficient_votes(
    total_votes: u32,
    expected_voters: u32,
    consensus_threshold: f64,
) -> bool {
    total_votes >= calculate_required_votes(expected_voters, consensus_threshold)
}
#[cfg(test)]
mod tests {
    use super::fold_u128_to_u32;
    use uuid::Uuid;

    /// Two UUID values that share their lowest 32 bits but differ in the next 32-bit word
    /// must fold to different ids.
    #[test]
    fn id_generation_should_not_collapse_distinct_128bit_values() {
        const SHARED_LOW_WORD: u128 = 0xDEADBEEF;

        // Builds a UUID from the given second word plus the shared low word, then folds it.
        let fold_with_second_word = |second_word: u128| {
            let uuid = Uuid::from_u128((second_word << 32) | SHARED_LOW_WORD);
            fold_u128_to_u32(uuid.as_u128())
        };

        assert_ne!(
            fold_with_second_word(0x00000001),
            fold_with_second_word(0xABCDEF01),
            "distinct 128-bit values should not collapse to the same 32-bit id"
        );
    }
}