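//! `mx_core/leader.rs`: leader selection, weighted consensus-group sampling,
//! epoch shuffling, and decoding of consensus messages into signals.
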
1use blake2::{
2    Blake2b,
3    digest::{Digest, consts::U32},
4};
5use mx_proto::generated::proto::HeaderV3;
6use mx_proto::generated::proto::Message as ConsensusProtoMessage;
7use prost::Message;
8use sha2::Sha256;
9
/// BLAKE2b with a 32-byte (256-bit) output; `compute_randomness_as_u64` hashes
/// draw indices and round randomness with it.
type Blake2b256 = Blake2b<U32>;
11
/// Shard size that [`epoch_shuffle`] fills up to from the waiting list.
const NODES_PER_SHARD: usize = 400;
/// Upper bound on how many eligible validators one [`epoch_shuffle`] call rotates out.
const MAX_NODES_TO_SWAP_PER_SHARD: usize = 80;
14
/// A validator that the selection and shuffling routines in this file can draw.
#[derive(Clone, Debug)]
pub struct EligibleValidator {
    /// Public key bytes; `shuffle_list` hashes these together with the epoch
    /// randomness to order validators.
    pub pub_key: Vec<u8>,
    /// Caller-assigned index, carried through unchanged. The selection functions
    /// return positions in the input slice, not this value.
    pub index: u32,
    /// Selection weight: the validator occupies `max(chances, 1)` slots in the
    /// expanded list used for leader and consensus-group draws.
    pub chances: u32,
}
21
22pub fn select_leader(
23    rand_seed: &[u8],
24    round: u64,
25    eligible_list: &[EligibleValidator],
26) -> Option<usize> {
27    if eligible_list.is_empty() {
28        return None;
29    }
30
31    let expanded_list = build_expanded_list(eligible_list);
32    if expanded_list.is_empty() {
33        return None;
34    }
35
36    let randomness = build_round_randomness(round, rand_seed);
37    let random_u64 = compute_randomness_as_u64(&randomness, 0);
38    let index = random_u64 % (expanded_list.len() as u64);
39
40    Some(expanded_list[index as usize] as usize)
41}
42
43pub fn select_consensus_group(
44    rand_seed: &[u8],
45    round: u64,
46    eligible_list: &[EligibleValidator],
47    size: usize,
48) -> Vec<usize> {
49    if eligible_list.is_empty() || size == 0 {
50        return Vec::new();
51    }
52
53    let expanded_list = build_expanded_list(eligible_list);
54    let len_expanded = expanded_list.len() as i64;
55    if size as i64 > len_expanded {
56        return Vec::new();
57    }
58
59    let randomness = build_round_randomness(round, rand_seed);
60
61    let mut selected = Vec::with_capacity(size);
62    let mut sorted_entries: Vec<(i64, i64)> = Vec::new();
63    let mut total_selected: i64 = 0;
64
65    for i in 0..size {
66        let random_u64 = compute_randomness_as_u64(&randomness, i);
67        let mut index = random_u64 % ((len_expanded - total_selected) as u64);
68
69        index = adjust_index(index, &sorted_entries);
70
71        let validator_idx = expanded_list[index as usize];
72        selected.push(validator_idx as usize);
73
74        let (start_idx, num_appearances) =
75            compute_start_and_appearances(&expanded_list, index as i64);
76        insert_sorted(&mut sorted_entries, start_idx, num_appearances);
77        total_selected += num_appearances;
78    }
79
80    selected
81}
82
83pub fn epoch_shuffle(
84    eligible: &[EligibleValidator],
85    waiting: &[EligibleValidator],
86    shuffle_randomness: &[u8],
87) -> Vec<EligibleValidator> {
88    let num_to_remove = (eligible.len() + waiting.len()).saturating_sub(NODES_PER_SHARD);
89    let actual_to_remove = num_to_remove.min(MAX_NODES_TO_SWAP_PER_SHARD);
90
91    let shuffled = shuffle_list(eligible, shuffle_randomness);
92    let remaining = &shuffled[actual_to_remove..];
93
94    let num_needed = NODES_PER_SHARD
95        .saturating_sub(remaining.len())
96        .min(waiting.len());
97
98    let mut result = Vec::with_capacity(NODES_PER_SHARD);
99    result.extend_from_slice(remaining);
100    result.extend_from_slice(&waiting[..num_needed]);
101
102    result
103}
104
/// Kind of a consensus message, decoded from the message's `msg_type` field.
///
/// Wire values `0..=6` map to the named variants in declaration order. Any other value
/// is preserved verbatim in [`ConsensusMsgType::Unrecognized`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ConsensusMsgType {
    /// Wire value `0`.
    Unknown,
    /// Wire value `1`.
    BlockBodyAndHeader,
    /// Wire value `2`.
    BlockBody,
    /// Wire value `3`.
    BlockHeader,
    /// Wire value `4`.
    Signature,
    /// Wire value `5`.
    BlockHeaderFinalInfo,
    /// Wire value `6`.
    InvalidSigners,
    /// Any wire value outside `0..=6`, kept as received.
    Unrecognized(i64),
}

impl From<i64> for ConsensusMsgType {
    /// Maps a raw wire value to its variant; unknown values become `Unrecognized(v)`.
    fn from(v: i64) -> Self {
        // Named variants, indexed by their wire value.
        const BY_WIRE_VALUE: [ConsensusMsgType; 7] = [
            ConsensusMsgType::Unknown,
            ConsensusMsgType::BlockBodyAndHeader,
            ConsensusMsgType::BlockBody,
            ConsensusMsgType::BlockHeader,
            ConsensusMsgType::Signature,
            ConsensusMsgType::BlockHeaderFinalInfo,
            ConsensusMsgType::InvalidSigners,
        ];

        usize::try_from(v)
            .ok()
            .and_then(|slot| BY_WIRE_VALUE.get(slot).copied())
            .unwrap_or(Self::Unrecognized(v))
    }
}
131
/// Proposal details taken from a block-header consensus message.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ProposalSignal {
    /// Shard the message was received for; supplied by the caller of
    /// [`decode_consensus_signal`], not read from the header.
    pub shard_id: u32,
    /// Header nonce.
    pub nonce: u64,
    /// Header round.
    pub round: u64,
    /// Header epoch.
    pub epoch: u32,
    /// Header randomness seed, copied out of the header.
    pub rand_seed: Vec<u8>,
    /// Header's previous randomness seed, copied out of the header.
    pub prev_rand_seed: Vec<u8>,
}

/// Details taken from a `BlockHeaderFinalInfo` consensus message.
///
/// Apart from `shard_id`, each field is copied from the message field of the same
/// name, except `round`, which comes from the message's `round_index`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct FinalInfoSignal {
    /// Shard the message was received for (supplied by the caller).
    pub shard_id: u32,
    /// The message's `round_index`; signed, unlike [`ProposalSignal::round`].
    pub round: i64,
    /// The message's `block_header_hash` bytes.
    pub block_header_hash: Vec<u8>,
    /// The message's `pub_keys_bitmap` bytes.
    pub pub_keys_bitmap: Vec<u8>,
    /// The message's `aggregate_signature` bytes.
    pub aggregate_signature: Vec<u8>,
    /// The message's `leader_signature` bytes.
    pub leader_signature: Vec<u8>,
}

/// A signal extracted from a consensus message by [`decode_consensus_signal`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ConsensusSignal {
    /// Produced by `BlockHeader` and `BlockBodyAndHeader` messages.
    Proposal(ProposalSignal),
    /// Produced by `BlockHeaderFinalInfo` messages.
    FinalInfo(FinalInfoSignal),
}
157
158pub fn decode_consensus_signal(shard_id: u32, bytes: &[u8]) -> Option<ConsensusSignal> {
159    let message = ConsensusProtoMessage::decode(bytes).ok()?;
160
161    match ConsensusMsgType::from(message.msg_type) {
162        ConsensusMsgType::BlockHeader | ConsensusMsgType::BlockBodyAndHeader => {
163            if message.header.is_empty() {
164                return None;
165            }
166
167            let header = HeaderV3::decode(message.header.as_ref()).ok()?;
168            Some(ConsensusSignal::Proposal(ProposalSignal {
169                shard_id,
170                nonce: header.nonce,
171                round: header.round,
172                epoch: header.epoch,
173                rand_seed: header.rand_seed.to_vec(),
174                prev_rand_seed: header.prev_rand_seed.to_vec(),
175            }))
176        }
177        ConsensusMsgType::BlockHeaderFinalInfo => {
178            Some(ConsensusSignal::FinalInfo(FinalInfoSignal {
179                shard_id,
180                round: message.round_index,
181                block_header_hash: message.block_header_hash.to_vec(),
182                pub_keys_bitmap: message.pub_keys_bitmap.to_vec(),
183                aggregate_signature: message.aggregate_signature.to_vec(),
184                leader_signature: message.leader_signature.to_vec(),
185            }))
186        }
187        ConsensusMsgType::Unknown
188        | ConsensusMsgType::BlockBody
189        | ConsensusMsgType::Signature
190        | ConsensusMsgType::InvalidSigners
191        | ConsensusMsgType::Unrecognized(_) => None,
192    }
193}
194
/// Builds the per-round seed material: the decimal text of `round`, a `-` separator,
/// then `rand_seed` verbatim. For example, round `7` with seed `[0xAB]` gives `b"7-\xAB"`.
fn build_round_randomness(round: u64, rand_seed: &[u8]) -> Vec<u8> {
    let mut seed = format!("{round}-").into_bytes();
    seed.extend_from_slice(rand_seed);
    seed
}
203
204fn build_expanded_list(eligible_list: &[EligibleValidator]) -> Vec<u32> {
205    let total: usize = eligible_list
206        .iter()
207        .map(|v| v.chances.max(1) as usize)
208        .sum();
209    let mut expanded = Vec::with_capacity(total);
210    for (i, v) in eligible_list.iter().enumerate() {
211        let chances = v.chances.max(1);
212        for _ in 0..chances {
213            expanded.push(i as u32);
214        }
215    }
216    expanded
217}
218
219fn compute_randomness_as_u64(randomness: &[u8], index: usize) -> u64 {
220    let index_bytes = (index as u64).to_be_bytes();
221    let mut hasher = Blake2b256::new();
222    hasher.update(index_bytes);
223    hasher.update(randomness);
224    let hash = hasher.finalize();
225
226    u64::from_be_bytes(hash[0..8].try_into().unwrap())
227}
228
/// Maps `index`, a position in the expanded list with the selected runs removed, back
/// to the corresponding position in the full expanded list.
///
/// `sorted_entries` holds the removed runs as `(start, length)` in full-list
/// coordinates, ordered by `start`. Walking them in order, each run that starts at or
/// before the position mapped so far shifts it right by that run's length. The walk
/// stops at the first run that starts beyond it.
fn adjust_index(index: u64, sorted_entries: &[(i64, i64)]) -> u64 {
    let walked = sorted_entries
        .iter()
        .try_fold(index, |position, &(run_start, run_len)| {
            if run_start as u64 > position {
                // This run lies beyond `position`; stop shifting.
                Err(position)
            } else {
                Ok(position + run_len as u64)
            }
        });

    match walked {
        Ok(position) | Err(position) => position,
    }
}
238
/// Finds the contiguous run of equal values in `expanded_list` that contains position
/// `idx`, returning `(run_start, run_length)`.
///
/// # Panics
/// Panics if `idx` is not a valid position in `expanded_list`.
fn compute_start_and_appearances(expanded_list: &[u32], idx: i64) -> (i64, i64) {
    let pos = idx as usize;
    let value = expanded_list[pos];

    // Equal neighbours immediately to the left and to the right of `pos`.
    let before = expanded_list[..pos]
        .iter()
        .rev()
        .take_while(|&&v| v == value)
        .count() as i64;
    let after = expanded_list[pos + 1..]
        .iter()
        .take_while(|&&v| v == value)
        .count() as i64;

    (idx - before, before + 1 + after)
}
261
/// Inserts the run `(start_index, num_appearances)` into `sorted_entries`.
///
/// The entry goes in front of the first entry whose start is `>= start_index`, or at
/// the end if there is none. This keeps the vector ordered by start.
fn insert_sorted(sorted_entries: &mut Vec<(i64, i64)>, start_index: i64, num_appearances: i64) {
    let slot = sorted_entries
        .iter()
        .take_while(|&&(existing_start, _)| existing_start < start_index)
        .count();
    sorted_entries.insert(slot, (start_index, num_appearances));
}
269
270fn shuffle_list(validators: &[EligibleValidator], randomness: &[u8]) -> Vec<EligibleValidator> {
271    let mut keyed: Vec<([u8; 32], EligibleValidator)> = validators
272        .iter()
273        .map(|v| {
274            let mut hasher = Sha256::new();
275            hasher.update(&v.pub_key);
276            hasher.update(randomness);
277            let hash: [u8; 32] = hasher.finalize().into();
278            (hash, v.clone())
279        })
280        .collect();
281
282    keyed.sort_by_key(|a| a.0);
283    keyed.into_iter().map(|(_, v)| v).collect()
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use mx_proto::generated::proto::Message as ConsensusMsg;
    use prost::bytes::Bytes;

    // `count` validators of weight 1; validator `i` has public key `i as u8` repeated 96 times.
    fn make_validators(count: usize) -> Vec<EligibleValidator> {
        (0..count)
            .map(|i| EligibleValidator {
                pub_key: vec![i as u8; 96],
                index: i as u32,
                chances: 1,
            })
            .collect()
    }

    // A validator of weight 24 whose public key is `pk_byte` repeated 96 times.
    fn make_validator(idx: u32, pk_byte: u8) -> EligibleValidator {
        EligibleValidator {
            pub_key: vec![pk_byte; 96],
            index: idx,
            chances: 24,
        }
    }

    // Encodes a `HeaderV3` with fixed epoch 100, rand seed `[0xAA; 32]` and
    // previous rand seed `[0xDD; 32]`.
    fn make_header(nonce: u64, round: u64, shard_id: u32) -> Vec<u8> {
        let header = HeaderV3 {
            nonce,
            round,
            shard_id,
            epoch: 100,
            rand_seed: Bytes::from(vec![0xAA; 32]),
            prev_rand_seed: Bytes::from(vec![0xDD; 32]),
            leader_signature: Bytes::from(vec![0xBB; 96]),
            ..Default::default()
        };
        header.encode_to_vec()
    }

    // Encodes a consensus message of `msg_type` wrapping `header_bytes`.
    fn make_consensus_msg(msg_type: i64, header_bytes: Vec<u8>) -> Vec<u8> {
        let msg = ConsensusMsg {
            msg_type,
            header: Bytes::from(header_bytes),
            pub_key: Bytes::from(vec![0xCC; 96]),
            round_index: 12346,
            ..Default::default()
        };
        msg.encode_to_vec()
    }

    #[test]
    fn test_select_leader_deterministic() {
        let validators = make_validators(400);
        let rand_seed = vec![0xAA; 32];

        let leader1 = select_leader(&rand_seed, 12345, &validators);
        let leader2 = select_leader(&rand_seed, 12345, &validators);

        assert_eq!(leader1, leader2);
        assert!(leader1.is_some());
    }

    #[test]
    fn test_select_leader_empty_list_returns_none() {
        assert_eq!(select_leader(&[0x01; 32], 1, &[]), None);
    }

    #[test]
    fn test_different_rounds_produce_different_leaders() {
        let validators = make_validators(400);
        let rand_seed = vec![0xAA; 32];

        let mut leaders = std::collections::HashSet::new();
        for round in 1000..1050 {
            if let Some(idx) = select_leader(&rand_seed, round, &validators) {
                leaders.insert(idx);
            }
        }

        assert!(leaders.len() >= 2);
    }

    #[test]
    fn test_consensus_group_no_duplicates() {
        let validators = make_validators(100);
        let rand_seed = vec![0xBB; 32];

        let group = select_consensus_group(&rand_seed, 5000, &validators, 63);
        assert_eq!(group.len(), 63);

        let mut seen = std::collections::HashSet::new();
        for &idx in &group {
            assert!(seen.insert(idx));
        }
    }

    #[test]
    fn test_consensus_group_weighted_selects_each_validator_once() {
        // Weight 24 each: 240 slots but only 10 distinct validators, so a full-size
        // group must contain every validator exactly once.
        let validators: Vec<_> = (0..10).map(|i| make_validator(i, i as u8)).collect();
        let rand_seed = vec![0xEE; 32];

        let mut group = select_consensus_group(&rand_seed, 42, &validators, 10);
        group.sort_unstable();
        assert_eq!(group, (0..10).collect::<Vec<usize>>());
    }

    #[test]
    fn test_consensus_group_degenerate_sizes_are_empty() {
        let validators = make_validators(5);
        assert!(select_consensus_group(&[0x01; 32], 7, &validators, 6).is_empty());
        assert!(select_consensus_group(&[0x01; 32], 7, &validators, 0).is_empty());
        assert!(select_consensus_group(&[0x01; 32], 7, &[], 3).is_empty());
    }

    #[test]
    fn test_consensus_group_leader_is_first() {
        let validators = make_validators(400);
        let rand_seed = vec![0xCC; 32];
        let round = 9999;

        let leader = select_leader(&rand_seed, round, &validators).unwrap();
        let group = select_consensus_group(&rand_seed, round, &validators, 63);

        assert_eq!(group[0], leader);
    }

    #[test]
    fn test_build_expanded_list_weights() {
        let mut validators = make_validators(3);
        validators[0].chances = 2;
        validators[1].chances = 0;
        validators[2].chances = 3;

        // A zero weight still yields one slot.
        assert_eq!(build_expanded_list(&validators), vec![0, 0, 1, 2, 2, 2]);
    }

    #[test]
    fn test_build_round_randomness_layout() {
        assert_eq!(build_round_randomness(42, &[0xAA, 0x01]), b"42-\xAA\x01".to_vec());
    }

    #[test]
    fn test_adjust_index_skips_removed_runs() {
        assert_eq!(adjust_index(1, &[]), 1);
        assert_eq!(adjust_index(2, &[(0, 3), (4, 2)]), 7);
        assert_eq!(adjust_index(3, &[(0, 2), (6, 1)]), 5);
        assert_eq!(adjust_index(2, &[(5, 3)]), 2);
    }

    #[test]
    fn test_compute_start_and_appearances_finds_run() {
        let runs = [0u32, 0, 1, 1, 1, 2];
        assert_eq!(compute_start_and_appearances(&runs, 0), (0, 2));
        assert_eq!(compute_start_and_appearances(&runs, 3), (2, 3));
        assert_eq!(compute_start_and_appearances(&runs, 5), (5, 1));
    }

    #[test]
    fn test_insert_sorted_keeps_start_order() {
        let mut runs = vec![(0i64, 2i64), (5, 1)];
        insert_sorted(&mut runs, 3, 4);
        insert_sorted(&mut runs, 9, 1);
        assert_eq!(runs, vec![(0, 2), (3, 4), (5, 1), (9, 1)]);
    }

    #[test]
    fn test_epoch_shuffle_preserves_count() {
        let eligible: Vec<_> = (0..400)
            .map(|i| make_validator(i, (i % 256) as u8))
            .collect();
        let waiting: Vec<_> = (0..80)
            .map(|i| make_validator(400 + i, ((400 + i) % 256) as u8))
            .collect();

        let result = epoch_shuffle(&eligible, &waiting, &[0xCC; 32]);
        assert_eq!(result.len(), 400);
    }

    #[test]
    fn test_epoch_shuffle_deterministic_and_appends_waiting() {
        let eligible: Vec<_> = (0..400)
            .map(|i| make_validator(i, (i % 256) as u8))
            .collect();
        let waiting: Vec<_> = (0..80)
            .map(|i| make_validator(400 + i, ((400 + i) % 256) as u8))
            .collect();

        let first = epoch_shuffle(&eligible, &waiting, &[0xAA; 32]);
        let second = epoch_shuffle(&eligible, &waiting, &[0xAA; 32]);
        let first_ids: Vec<u32> = first.iter().map(|v| v.index).collect();
        let second_ids: Vec<u32> = second.iter().map(|v| v.index).collect();
        assert_eq!(first_ids, second_ids);

        // 80 eligible validators rotate out; the whole waiting list follows in order.
        assert_eq!(&first_ids[320..], (400..480).collect::<Vec<u32>>().as_slice());
    }

    #[test]
    fn test_epoch_shuffle_different_randomness() {
        let eligible: Vec<_> = (0..400)
            .map(|i| make_validator(i, (i % 256) as u8))
            .collect();
        let waiting: Vec<_> = (0..80)
            .map(|i| make_validator(400 + i, ((400 + i) % 256) as u8))
            .collect();

        let result1 = epoch_shuffle(&eligible, &waiting, &[0xAA; 32]);
        let result2 = epoch_shuffle(&eligible, &waiting, &[0xBB; 32]);

        let keys1: Vec<u32> = result1.iter().map(|v| v.index).collect();
        let keys2: Vec<u32> = result2.iter().map(|v| v.index).collect();
        assert_ne!(keys1, keys2);
    }

    #[test]
    fn test_consensus_msg_type_from_i64() {
        assert_eq!(ConsensusMsgType::from(0), ConsensusMsgType::Unknown);
        assert_eq!(
            ConsensusMsgType::from(1),
            ConsensusMsgType::BlockBodyAndHeader
        );
        assert_eq!(ConsensusMsgType::from(2), ConsensusMsgType::BlockBody);
        assert_eq!(ConsensusMsgType::from(3), ConsensusMsgType::BlockHeader);
        assert_eq!(ConsensusMsgType::from(4), ConsensusMsgType::Signature);
        assert_eq!(
            ConsensusMsgType::from(5),
            ConsensusMsgType::BlockHeaderFinalInfo
        );
        assert_eq!(ConsensusMsgType::from(6), ConsensusMsgType::InvalidSigners);
        assert_eq!(
            ConsensusMsgType::from(99),
            ConsensusMsgType::Unrecognized(99)
        );
    }

    #[test]
    fn test_decode_consensus_signal_proposal_block_header() {
        let header_bytes = make_header(12345, 12346, 1);
        let msg_bytes = make_consensus_msg(3, header_bytes);

        let signal = decode_consensus_signal(1, &msg_bytes).unwrap();
        let ConsensusSignal::Proposal(proposal) = signal else {
            panic!("expected proposal signal");
        };

        assert_eq!(proposal.shard_id, 1);
        assert_eq!(proposal.nonce, 12345);
        assert_eq!(proposal.round, 12346);
        assert_eq!(proposal.epoch, 100);
        assert_eq!(proposal.rand_seed, vec![0xAA; 32]);
        assert_eq!(proposal.prev_rand_seed, vec![0xDD; 32]);
    }

    #[test]
    fn test_decode_consensus_signal_body_and_header() {
        let header_bytes = make_header(500, 501, 2);
        let msg_bytes = make_consensus_msg(1, header_bytes);

        let signal = decode_consensus_signal(2, &msg_bytes).unwrap();
        let ConsensusSignal::Proposal(proposal) = signal else {
            panic!("expected proposal signal");
        };

        assert_eq!(proposal.shard_id, 2);
        assert_eq!(proposal.nonce, 500);
        assert_eq!(proposal.round, 501);
    }

    #[test]
    fn test_decode_consensus_signal_empty_header_returns_none() {
        let msg_bytes = make_consensus_msg(3, Vec::new());
        assert!(decode_consensus_signal(0, &msg_bytes).is_none());
    }

    #[test]
    fn test_decode_consensus_signal_final_info() {
        let msg = ConsensusMsg {
            msg_type: 5,
            round_index: 777,
            block_header_hash: Bytes::from(vec![0x11; 32]),
            pub_keys_bitmap: Bytes::from(vec![0x22; 50]),
            aggregate_signature: Bytes::from(vec![0x33; 48]),
            leader_signature: Bytes::from(vec![0x44; 96]),
            ..Default::default()
        };

        let signal = decode_consensus_signal(0, &msg.encode_to_vec()).unwrap();
        let ConsensusSignal::FinalInfo(final_info) = signal else {
            panic!("expected final-info signal");
        };

        assert_eq!(final_info.shard_id, 0);
        assert_eq!(final_info.round, 777);
        assert_eq!(final_info.block_header_hash, vec![0x11; 32]);
        assert_eq!(final_info.pub_keys_bitmap, vec![0x22; 50]);
        assert_eq!(final_info.aggregate_signature, vec![0x33; 48]);
        assert_eq!(final_info.leader_signature, vec![0x44; 96]);
    }

    #[test]
    fn test_decode_consensus_signal_non_header_returns_none() {
        let header_bytes = make_header(100, 200, 0);
        let msg_bytes = make_consensus_msg(4, header_bytes);

        assert!(decode_consensus_signal(0, &msg_bytes).is_none());
    }

    #[test]
    fn test_decode_consensus_signal_unrecognized_returns_none() {
        let header_bytes = make_header(100, 200, 0);
        let msg_bytes = make_consensus_msg(42, header_bytes);

        assert!(decode_consensus_signal(0, &msg_bytes).is_none());
    }

    #[test]
    fn test_decode_consensus_signal_invalid_bytes_returns_none() {
        assert!(decode_consensus_signal(0, &[0xFF, 0xFF, 0xFF]).is_none());
    }
}