use blake2::{
Blake2b,
digest::{Digest, consts::U32},
};
use mx_proto::generated::proto::HeaderV3;
use mx_proto::generated::proto::Message as ConsensusProtoMessage;
use prost::Message;
use sha2::Sha256;
/// BLAKE2b with a 32-byte (256-bit) digest; used to derive per-draw selection randomness.
type Blake2b256 = Blake2b<U32>;
/// Upper bound on how many shuffled eligible validators `epoch_shuffle` drops per call.
const MAX_NODES_TO_SWAP_PER_SHARD: usize = 80;
/// Target shard size: `epoch_shuffle` tops the surviving eligible validators up to this count
/// from the waiting list, and removals are computed from how far eligible + waiting exceed it.
const NODES_PER_SHARD: usize = 400;
/// A validator candidate for leader / consensus-group selection and epoch shuffling.
#[derive(Debug, Clone)]
pub struct EligibleValidator {
    /// Public key bytes; `shuffle_list` hashes this together with the epoch randomness to
    /// order validators.
    pub pub_key: Vec<u8>,
    /// Caller-assigned identifier. The selection functions in this file do not read it: they
    /// return positions in the input slice, not this value.
    pub index: u32,
    /// Selection weight. The validator occupies `max(chances, 1)` slots in the weighted list,
    /// so a value of 0 behaves like 1.
    pub chances: u32,
}
/// Picks the leader for `round`, weighting each validator by `max(chances, 1)`.
///
/// The draw is `compute_randomness_as_u64(build_round_randomness(round, rand_seed), 0)`
/// reduced modulo the total weight. Validator `i` owns a contiguous run of `max(chances, 1)`
/// slots, ordered by position in `eligible_list`. This is the same draw and slot layout that
/// `select_consensus_group` uses for its first seat, so for identical inputs the leader equals
/// `select_consensus_group(..)[0]`.
///
/// Returns the position of the chosen validator in `eligible_list` (not its `index` field), or
/// `None` if `eligible_list` is empty.
pub fn select_leader(
    rand_seed: &[u8],
    round: u64,
    eligible_list: &[EligibleValidator],
) -> Option<usize> {
    if eligible_list.is_empty() {
        return None;
    }
    // Walk cumulative weights instead of materialising one list entry per chance: same owner
    // for every slot, but no O(sum of chances) allocation on each call. Summing in u64 cannot
    // overflow for any realistic validator count (each weight is at most u32::MAX).
    let weight = |v: &EligibleValidator| u64::from(v.chances.max(1));
    let total_slots: u64 = eligible_list.iter().map(weight).sum();
    let randomness = build_round_randomness(round, rand_seed);
    // `total_slots >= 1` because the list is non-empty and every weight is at least 1.
    let mut slot = compute_randomness_as_u64(&randomness, 0) % total_slots;
    for (position, validator) in eligible_list.iter().enumerate() {
        let width = weight(validator);
        if slot < width {
            return Some(position);
        }
        slot -= width;
    }
    // Unreachable: `slot < total_slots` guarantees the loop returns. Kept so the function stays
    // total without a panic path.
    None
}
/// Selects `size` distinct validators for the consensus group of `round`, without replacement.
///
/// Each validator is weighted by `max(chances, 1)`. Draw `i` uses
/// `compute_randomness_as_u64(build_round_randomness(round, rand_seed), i)`, reduced modulo the
/// number of slots not yet taken. After a validator is picked, its whole run of slots is
/// excluded from later draws.
///
/// Returns positions into `eligible_list` in selection order. Returns an empty `Vec` if
/// `eligible_list` is empty, `size == 0`, or `size` exceeds the number of validators (a
/// without-replacement draw cannot fill more seats than there are validators).
pub fn select_consensus_group(
    rand_seed: &[u8],
    round: u64,
    eligible_list: &[EligibleValidator],
    size: usize,
) -> Vec<usize> {
    // Bounding `size` by the expanded-list length alone is not enough. When some validators
    // have `chances > 1`, every slot is consumed before `size` picks and the modulus below
    // becomes zero, which panics. The distinct-validator count is the real upper bound (and it
    // is never larger than the expanded length, since every validator owns at least one slot).
    if eligible_list.is_empty() || size == 0 || size > eligible_list.len() {
        return Vec::new();
    }
    let expanded_list = build_expanded_list(eligible_list);
    let len_expanded = expanded_list.len() as i64;
    let randomness = build_round_randomness(round, rand_seed);
    let mut selected = Vec::with_capacity(size);
    // Runs of slots already taken, as (start index, run length), kept sorted by start index
    // for `adjust_index`.
    let mut sorted_entries: Vec<(i64, i64)> = Vec::with_capacity(size);
    let mut total_selected: i64 = 0;
    for i in 0..size {
        // After `i` picks at least `size - i >= 1` validators remain, each owning at least one
        // slot, so `free_slots` is never zero.
        let free_slots = (len_expanded - total_selected) as u64;
        let draw = compute_randomness_as_u64(&randomness, i) % free_slots;
        // Map the draw from the compacted free-slot space back to a real expanded-list index.
        let index = adjust_index(draw, &sorted_entries);
        selected.push(expanded_list[index as usize] as usize);
        // Exclude the chosen validator's entire run so it cannot be drawn again.
        let (start_idx, num_appearances) =
            compute_start_and_appearances(&expanded_list, index as i64);
        insert_sorted(&mut sorted_entries, start_idx, num_appearances);
        total_selected += num_appearances;
    }
    selected
}
/// Rotates a shard's validator set for a new epoch.
///
/// `eligible` is deterministically shuffled with `shuffle_randomness`, and then the first
/// `num_to_remove` shuffled validators are dropped. `num_to_remove` is how far
/// `eligible.len() + waiting.len()` exceeds `NODES_PER_SHARD`, capped at
/// `MAX_NODES_TO_SWAP_PER_SHARD` and at `eligible.len()`. The survivors are then topped up to
/// `NODES_PER_SHARD` with validators taken in order from the front of `waiting`.
///
/// The result is the survivors (in shuffled order) followed by the promoted waiting validators.
/// It can exceed `NODES_PER_SHARD` when `eligible` alone is larger than the target plus the
/// per-epoch swap cap.
pub fn epoch_shuffle(
    eligible: &[EligibleValidator],
    waiting: &[EligibleValidator],
    shuffle_randomness: &[u8],
) -> Vec<EligibleValidator> {
    let shuffled = shuffle_list(eligible, shuffle_randomness);
    // Clamp to the eligible count as well. A short eligible list with a long waiting list
    // (e.g. 10 eligible + 500 waiting) would otherwise slice past the end and panic.
    let num_to_remove = (eligible.len() + waiting.len())
        .saturating_sub(NODES_PER_SHARD)
        .min(MAX_NODES_TO_SWAP_PER_SHARD)
        .min(shuffled.len());
    let remaining = &shuffled[num_to_remove..];
    let num_needed = NODES_PER_SHARD
        .saturating_sub(remaining.len())
        .min(waiting.len());
    let mut result = Vec::with_capacity(remaining.len() + num_needed);
    result.extend_from_slice(remaining);
    result.extend_from_slice(&waiting[..num_needed]);
    result
}
/// Kind of a consensus message, decoded from the protobuf `msg_type` field.
///
/// Wire values 0 through 6 map to the named variants in declaration order. Any other value is
/// kept verbatim in [`ConsensusMsgType::Unrecognized`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConsensusMsgType {
    /// Wire value 0.
    Unknown,
    /// Wire value 1; `decode_consensus_signal` decodes its header into a proposal.
    BlockBodyAndHeader,
    /// Wire value 2.
    BlockBody,
    /// Wire value 3; `decode_consensus_signal` decodes its header into a proposal.
    BlockHeader,
    /// Wire value 4.
    Signature,
    /// Wire value 5; `decode_consensus_signal` turns it into a final-info signal.
    BlockHeaderFinalInfo,
    /// Wire value 6.
    InvalidSigners,
    /// Any wire value outside 0..=6, preserved as received.
    Unrecognized(i64),
}

impl From<i64> for ConsensusMsgType {
    /// Maps a raw wire value to its variant via a table indexed by the value. Negative or
    /// out-of-range values fall through to `Unrecognized`.
    fn from(raw: i64) -> Self {
        // Position in this table is the wire value.
        const BY_WIRE_VALUE: [ConsensusMsgType; 7] = [
            ConsensusMsgType::Unknown,
            ConsensusMsgType::BlockBodyAndHeader,
            ConsensusMsgType::BlockBody,
            ConsensusMsgType::BlockHeader,
            ConsensusMsgType::Signature,
            ConsensusMsgType::BlockHeaderFinalInfo,
            ConsensusMsgType::InvalidSigners,
        ];
        usize::try_from(raw)
            .ok()
            .and_then(|slot| BY_WIRE_VALUE.get(slot).copied())
            .unwrap_or(Self::Unrecognized(raw))
    }
}
/// Block proposal observed on the consensus topic, built by `decode_consensus_signal` from a
/// `HeaderV3`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProposalSignal {
    /// Shard the caller attributed the message to. This is copied from the decoder's argument,
    /// not from the header's own `shard_id` field.
    pub shard_id: u32,
    /// Header `nonce`.
    pub nonce: u64,
    /// Header `round`.
    pub round: u64,
    /// Header `epoch`.
    pub epoch: u32,
    /// Header `rand_seed`, copied out as owned bytes.
    pub rand_seed: Vec<u8>,
    /// Header `prev_rand_seed`, copied out as owned bytes.
    pub prev_rand_seed: Vec<u8>,
}
/// Final-info (commit) data for a block, built by `decode_consensus_signal` from a
/// `BlockHeaderFinalInfo` message. Each byte field is an owned copy of the message field with
/// the same name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FinalInfoSignal {
    /// Shard the caller attributed the message to (the decoder's argument).
    pub shard_id: u32,
    /// Message `round_index`. NOTE(review): signed, unlike `ProposalSignal::round` (u64);
    /// presumably mirrors the proto field type — confirm before comparing the two.
    pub round: i64,
    /// Message `block_header_hash`.
    pub block_header_hash: Vec<u8>,
    /// Message `pub_keys_bitmap`.
    pub pub_keys_bitmap: Vec<u8>,
    /// Message `aggregate_signature`.
    pub aggregate_signature: Vec<u8>,
    /// Message `leader_signature`.
    pub leader_signature: Vec<u8>,
}
/// The consensus events `decode_consensus_signal` extracts; all other message kinds are
/// dropped.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConsensusSignal {
    /// From a `BlockHeader` or `BlockBodyAndHeader` message.
    Proposal(ProposalSignal),
    /// From a `BlockHeaderFinalInfo` message.
    FinalInfo(FinalInfoSignal),
}
/// Decodes a raw consensus payload into the subset of signals this module tracks.
///
/// `shard_id` is copied into the signal unchanged; the header's own shard field is not checked.
///
/// Returns `None` in any of these cases:
/// - the payload is not a valid protobuf consensus `Message`;
/// - a header-carrying message (`BlockHeader` / `BlockBodyAndHeader`) has an empty or
///   undecodable `HeaderV3`;
/// - the message is of any other kind than those two or `BlockHeaderFinalInfo`.
pub fn decode_consensus_signal(shard_id: u32, bytes: &[u8]) -> Option<ConsensusSignal> {
    let decoded = ConsensusProtoMessage::decode(bytes).ok()?;
    let kind = ConsensusMsgType::from(decoded.msg_type);
    match kind {
        ConsensusMsgType::BlockHeader | ConsensusMsgType::BlockBodyAndHeader => {
            let header_bytes: &[u8] = decoded.header.as_ref();
            if header_bytes.is_empty() {
                // No header, so there is nothing to report as a proposal.
                return None;
            }
            let header = HeaderV3::decode(header_bytes).ok()?;
            let proposal = ProposalSignal {
                shard_id,
                nonce: header.nonce,
                round: header.round,
                epoch: header.epoch,
                rand_seed: header.rand_seed.to_vec(),
                prev_rand_seed: header.prev_rand_seed.to_vec(),
            };
            Some(ConsensusSignal::Proposal(proposal))
        }
        ConsensusMsgType::BlockHeaderFinalInfo => {
            let final_info = FinalInfoSignal {
                shard_id,
                round: decoded.round_index,
                block_header_hash: decoded.block_header_hash.to_vec(),
                pub_keys_bitmap: decoded.pub_keys_bitmap.to_vec(),
                aggregate_signature: decoded.aggregate_signature.to_vec(),
                leader_signature: decoded.leader_signature.to_vec(),
            };
            Some(ConsensusSignal::FinalInfo(final_info))
        }
        // Listed explicitly (no `_`) so a new variant forces a decision here.
        ConsensusMsgType::Unknown
        | ConsensusMsgType::BlockBody
        | ConsensusMsgType::Signature
        | ConsensusMsgType::InvalidSigners
        | ConsensusMsgType::Unrecognized(_) => None,
    }
}
/// Builds the per-round seed material: the decimal ASCII form of `round`, a `-` separator,
/// then `rand_seed` verbatim (e.g. round 7 with seed `[0xAA]` gives `b"7-\xAA"`).
fn build_round_randomness(round: u64, rand_seed: &[u8]) -> Vec<u8> {
    let prefix = format!("{round}-");
    [prefix.as_bytes(), rand_seed].concat()
}
fn build_expanded_list(eligible_list: &[EligibleValidator]) -> Vec<u32> {
let total: usize = eligible_list
.iter()
.map(|v| v.chances.max(1) as usize)
.sum();
let mut expanded = Vec::with_capacity(total);
for (i, v) in eligible_list.iter().enumerate() {
let chances = v.chances.max(1);
for _ in 0..chances {
expanded.push(i as u32);
}
}
expanded
}
/// Derives the pseudo-random `u64` for draw number `index`.
///
/// Hashes `BE64(index) || randomness` with BLAKE2b-256 and reads the first 8 digest bytes as a
/// big-endian integer.
fn compute_randomness_as_u64(randomness: &[u8], index: usize) -> u64 {
    let digest = Blake2b256::new()
        .chain_update((index as u64).to_be_bytes())
        .chain_update(randomness)
        .finalize();
    let mut leading = [0u8; 8];
    leading.copy_from_slice(&digest[..8]);
    u64::from_be_bytes(leading)
}
/// Maps an index drawn from the compacted list of free slots back to a real expanded-list
/// index.
///
/// `sorted_entries` holds already-taken runs as `(start_index, run_length)`, sorted by start.
/// Walking them in order, every run that starts at or before the current position pushes the
/// position forward by its length. The walk stops at the first run that starts after the
/// position.
fn adjust_index(index: u64, sorted_entries: &[(i64, i64)]) -> u64 {
    let walk = sorted_entries
        .iter()
        .try_fold(index, |shifted, &(start, count)| {
            if shifted < start as u64 {
                // This run (and every later one) lies beyond the position: stop here.
                Err(shifted)
            } else {
                Ok(shifted + count as u64)
            }
        });
    match walk {
        Ok(shifted) | Err(shifted) => shifted,
    }
}
/// Finds the run of equal values containing position `idx` in `expanded_list`.
///
/// Returns `(run_start, run_length)`. Panics if `idx` is negative or out of bounds (the index
/// operation fails).
fn compute_start_and_appearances(expanded_list: &[u32], idx: i64) -> (i64, i64) {
    let pos = idx as usize;
    let owner = expanded_list[pos];
    // The run starts one past the nearest earlier slot with a different value.
    let run_start = expanded_list[..pos]
        .iter()
        .rposition(|&slot| slot != owner)
        .map_or(0, |before| before + 1);
    // The run ends (exclusive) at the nearest later slot with a different value.
    let run_end = expanded_list[pos + 1..]
        .iter()
        .position(|&slot| slot != owner)
        .map_or(expanded_list.len(), |offset| pos + 1 + offset);
    (run_start as i64, (run_end - run_start) as i64)
}
/// Inserts `(start_index, num_appearances)` ahead of the first entry whose start is
/// `>= start_index`, or appends it if there is no such entry.
fn insert_sorted(sorted_entries: &mut Vec<(i64, i64)>, start_index: i64, num_appearances: i64) {
    // The insertion point is the number of leading entries that start strictly earlier.
    let slot = sorted_entries
        .iter()
        .take_while(|&&(existing_start, _)| existing_start < start_index)
        .count();
    sorted_entries.insert(slot, (start_index, num_appearances));
}
/// Deterministically permutes `validators` by sorting on `SHA-256(pub_key || randomness)`.
///
/// Each validator's sort key is computed once. The sort is stable, so validators with equal
/// keys (possible only when public keys are identical) keep their input order.
fn shuffle_list(validators: &[EligibleValidator], randomness: &[u8]) -> Vec<EligibleValidator> {
    let mut ordered = validators.to_vec();
    ordered.sort_by_cached_key(|validator| {
        let mut hasher = Sha256::new();
        hasher.update(&validator.pub_key);
        hasher.update(randomness);
        let sort_key: [u8; 32] = hasher.finalize().into();
        sort_key
    });
    ordered
}
// Unit tests covering leader/group selection, epoch shuffling, and consensus-message decoding.
#[cfg(test)]
mod tests {
    use super::*;
    use mx_proto::generated::proto::Message as ConsensusMsg;
    use prost::bytes::Bytes;

    // Builds `count` weight-1 validators whose `index` is their position. `pub_key` is
    // `i as u8` repeated 96 times, so keys repeat every 256 validators.
    fn make_validators(count: usize) -> Vec<EligibleValidator> {
        (0..count)
            .map(|i| EligibleValidator {
                pub_key: vec![i as u8; 96],
                index: i as u32,
                chances: 1,
            })
            .collect()
    }

    // Builds one validator with the given `index`, a 96-byte key filled with `pk_byte`, and
    // weight 24.
    fn make_validator(idx: u32, pk_byte: u8) -> EligibleValidator {
        EligibleValidator {
            pub_key: vec![pk_byte; 96],
            index: idx,
            chances: 24,
        }
    }

    // Encodes a `HeaderV3` with the given nonce/round/shard, epoch 100, and fixed-pattern
    // seeds and leader signature.
    fn make_header(nonce: u64, round: u64, shard_id: u32) -> Vec<u8> {
        let header = HeaderV3 {
            nonce,
            round,
            shard_id,
            epoch: 100,
            rand_seed: Bytes::from(vec![0xAA; 32]),
            prev_rand_seed: Bytes::from(vec![0xDD; 32]),
            leader_signature: Bytes::from(vec![0xBB; 96]),
            ..Default::default()
        };
        header.encode_to_vec()
    }

    // Encodes a consensus `Message` of type `msg_type` carrying `header_bytes`.
    fn make_consensus_msg(msg_type: i64, header_bytes: Vec<u8>) -> Vec<u8> {
        let msg = ConsensusMsg {
            msg_type,
            header: Bytes::from(header_bytes),
            pub_key: Bytes::from(vec![0xCC; 96]),
            round_index: 12346,
            ..Default::default()
        };
        msg.encode_to_vec()
    }

    // Same inputs must produce the same leader.
    #[test]
    fn test_select_leader_deterministic() {
        let validators = make_validators(400);
        let rand_seed = vec![0xAA; 32];
        let leader1 = select_leader(&rand_seed, 12345, &validators);
        let leader2 = select_leader(&rand_seed, 12345, &validators);
        assert_eq!(leader1, leader2);
        assert!(leader1.is_some());
    }

    // Over 50 consecutive rounds, more than one validator should lead.
    #[test]
    fn test_different_rounds_produce_different_leaders() {
        let validators = make_validators(400);
        let rand_seed = vec![0xAA; 32];
        let mut leaders = std::collections::HashSet::new();
        for round in 1000..1050 {
            if let Some(idx) = select_leader(&rand_seed, round, &validators) {
                leaders.insert(idx);
            }
        }
        assert!(leaders.len() >= 2);
    }

    // A group is drawn without replacement: full size, no repeated positions.
    #[test]
    fn test_consensus_group_no_duplicates() {
        let validators = make_validators(100);
        let rand_seed = vec![0xBB; 32];
        let group = select_consensus_group(&rand_seed, 5000, &validators, 63);
        assert_eq!(group.len(), 63);
        let mut seen = std::collections::HashSet::new();
        for &idx in &group {
            assert!(seen.insert(idx));
        }
    }

    // The leader and the group's first seat both come from draw 0, so they must agree.
    #[test]
    fn test_consensus_group_leader_is_first() {
        let validators = make_validators(400);
        let rand_seed = vec![0xCC; 32];
        let round = 9999;
        let leader = select_leader(&rand_seed, round, &validators).unwrap();
        let group = select_consensus_group(&rand_seed, round, &validators, 63);
        assert_eq!(group[0], leader);
    }

    // 400 eligible + 80 waiting: 80 are dropped and 80 promoted, keeping 400.
    // NOTE(review): keys are `i % 256`, so several eligible/waiting validators share a key.
    #[test]
    fn test_epoch_shuffle_preserves_count() {
        let eligible: Vec<_> = (0..400)
            .map(|i| make_validator(i, (i % 256) as u8))
            .collect();
        let waiting: Vec<_> = (0..80)
            .map(|i| make_validator(400 + i, ((400 + i) % 256) as u8))
            .collect();
        let result = epoch_shuffle(&eligible, &waiting, &[0xCC; 32]);
        assert_eq!(result.len(), 400);
    }

    // Different shuffle randomness must change the resulting order.
    #[test]
    fn test_epoch_shuffle_different_randomness() {
        let eligible: Vec<_> = (0..400)
            .map(|i| make_validator(i, (i % 256) as u8))
            .collect();
        let waiting: Vec<_> = (0..80)
            .map(|i| make_validator(400 + i, ((400 + i) % 256) as u8))
            .collect();
        let result1 = epoch_shuffle(&eligible, &waiting, &[0xAA; 32]);
        let result2 = epoch_shuffle(&eligible, &waiting, &[0xBB; 32]);
        let keys1: Vec<u32> = result1.iter().map(|v| v.index).collect();
        let keys2: Vec<u32> = result2.iter().map(|v| v.index).collect();
        assert_ne!(keys1, keys2);
    }

    // Every known wire value maps to its variant; unknown values are preserved.
    #[test]
    fn test_consensus_msg_type_from_i64() {
        assert_eq!(ConsensusMsgType::from(0), ConsensusMsgType::Unknown);
        assert_eq!(
            ConsensusMsgType::from(1),
            ConsensusMsgType::BlockBodyAndHeader
        );
        assert_eq!(ConsensusMsgType::from(2), ConsensusMsgType::BlockBody);
        assert_eq!(ConsensusMsgType::from(3), ConsensusMsgType::BlockHeader);
        assert_eq!(ConsensusMsgType::from(4), ConsensusMsgType::Signature);
        assert_eq!(
            ConsensusMsgType::from(5),
            ConsensusMsgType::BlockHeaderFinalInfo
        );
        assert_eq!(ConsensusMsgType::from(6), ConsensusMsgType::InvalidSigners);
        assert_eq!(
            ConsensusMsgType::from(99),
            ConsensusMsgType::Unrecognized(99)
        );
    }

    // Type 3 (BlockHeader) yields a proposal carrying the header fields.
    #[test]
    fn test_decode_consensus_signal_proposal_block_header() {
        let header_bytes = make_header(12345, 12346, 1);
        let msg_bytes = make_consensus_msg(3, header_bytes);
        let signal = decode_consensus_signal(1, &msg_bytes).unwrap();
        let ConsensusSignal::Proposal(proposal) = signal else {
            panic!("expected proposal signal");
        };
        assert_eq!(proposal.shard_id, 1);
        assert_eq!(proposal.nonce, 12345);
        assert_eq!(proposal.round, 12346);
        assert_eq!(proposal.rand_seed.len(), 32);
    }

    // Type 1 (BlockBodyAndHeader) is also decoded as a proposal.
    #[test]
    fn test_decode_consensus_signal_body_and_header() {
        let header_bytes = make_header(500, 501, 2);
        let msg_bytes = make_consensus_msg(1, header_bytes);
        let signal = decode_consensus_signal(2, &msg_bytes).unwrap();
        let ConsensusSignal::Proposal(proposal) = signal else {
            panic!("expected proposal signal");
        };
        assert_eq!(proposal.shard_id, 2);
        assert_eq!(proposal.nonce, 500);
        assert_eq!(proposal.round, 501);
    }

    // Type 5 (BlockHeaderFinalInfo) copies the message's commit fields.
    #[test]
    fn test_decode_consensus_signal_final_info() {
        let msg = ConsensusMsg {
            msg_type: 5,
            round_index: 777,
            block_header_hash: Bytes::from(vec![0x11; 32]),
            pub_keys_bitmap: Bytes::from(vec![0x22; 50]),
            aggregate_signature: Bytes::from(vec![0x33; 48]),
            leader_signature: Bytes::from(vec![0x44; 96]),
            ..Default::default()
        };
        let signal = decode_consensus_signal(0, &msg.encode_to_vec()).unwrap();
        let ConsensusSignal::FinalInfo(final_info) = signal else {
            panic!("expected final-info signal");
        };
        assert_eq!(final_info.shard_id, 0);
        assert_eq!(final_info.round, 777);
        assert_eq!(final_info.block_header_hash, vec![0x11; 32]);
        assert_eq!(final_info.pub_keys_bitmap, vec![0x22; 50]);
        assert_eq!(final_info.aggregate_signature, vec![0x33; 48]);
        assert_eq!(final_info.leader_signature, vec![0x44; 96]);
    }

    // Type 4 (Signature) is ignored even when a header is present.
    #[test]
    fn test_decode_consensus_signal_non_header_returns_none() {
        let header_bytes = make_header(100, 200, 0);
        let msg_bytes = make_consensus_msg(4, header_bytes);
        assert!(decode_consensus_signal(0, &msg_bytes).is_none());
    }

    // Unrecognized message types are ignored.
    #[test]
    fn test_decode_consensus_signal_unrecognized_returns_none() {
        let header_bytes = make_header(100, 200, 0);
        let msg_bytes = make_consensus_msg(42, header_bytes);
        assert!(decode_consensus_signal(0, &msg_bytes).is_none());
    }

    // Bytes that are not a valid protobuf message decode to nothing.
    #[test]
    fn test_decode_consensus_signal_invalid_bytes_returns_none() {
        assert!(decode_consensus_signal(0, &[0xFF, 0xFF, 0xFF]).is_none());
    }
}