1use blake2::{
2 Blake2b,
3 digest::{Digest, consts::U32},
4};
5use mx_proto::generated::proto::HeaderV3;
6use mx_proto::generated::proto::Message as ConsensusProtoMessage;
7use prost::Message;
8use sha2::Sha256;
9
/// BLAKE2b configured for a 32-byte (256-bit) digest; used to turn round
/// randomness into per-draw `u64` values.
type Blake2b256 = Blake2b<U32>;

/// Upper bound on how many eligible validators `epoch_shuffle` removes from a shard.
const MAX_NODES_TO_SWAP_PER_SHARD: usize = 80;
/// Shard size that `epoch_shuffle` trims toward and refills up to.
const NODES_PER_SHARD: usize = 400;
14
/// A validator entry as consumed by leader selection, consensus-group selection
/// and epoch shuffling.
#[derive(Debug, Clone)]
pub struct EligibleValidator {
    /// Public key bytes; hashed together with the shuffle randomness to order
    /// validators in `shuffle_list`. (Tests use 96-byte keys — presumably BLS; confirm.)
    pub pub_key: Vec<u8>,
    /// Validator index. Not read by the selection or shuffle functions here;
    /// presumably an identifier for callers — TODO confirm.
    pub index: u32,
    /// Selection weight: the validator occupies `max(chances, 1)` slots in the
    /// expanded list used for leader and consensus-group draws.
    pub chances: u32,
}
21
22pub fn select_leader(
23 rand_seed: &[u8],
24 round: u64,
25 eligible_list: &[EligibleValidator],
26) -> Option<usize> {
27 if eligible_list.is_empty() {
28 return None;
29 }
30
31 let expanded_list = build_expanded_list(eligible_list);
32 if expanded_list.is_empty() {
33 return None;
34 }
35
36 let randomness = build_round_randomness(round, rand_seed);
37 let random_u64 = compute_randomness_as_u64(&randomness, 0);
38 let index = random_u64 % (expanded_list.len() as u64);
39
40 Some(expanded_list[index as usize] as usize)
41}
42
43pub fn select_consensus_group(
44 rand_seed: &[u8],
45 round: u64,
46 eligible_list: &[EligibleValidator],
47 size: usize,
48) -> Vec<usize> {
49 if eligible_list.is_empty() || size == 0 {
50 return Vec::new();
51 }
52
53 let expanded_list = build_expanded_list(eligible_list);
54 let len_expanded = expanded_list.len() as i64;
55 if size as i64 > len_expanded {
56 return Vec::new();
57 }
58
59 let randomness = build_round_randomness(round, rand_seed);
60
61 let mut selected = Vec::with_capacity(size);
62 let mut sorted_entries: Vec<(i64, i64)> = Vec::new();
63 let mut total_selected: i64 = 0;
64
65 for i in 0..size {
66 let random_u64 = compute_randomness_as_u64(&randomness, i);
67 let mut index = random_u64 % ((len_expanded - total_selected) as u64);
68
69 index = adjust_index(index, &sorted_entries);
70
71 let validator_idx = expanded_list[index as usize];
72 selected.push(validator_idx as usize);
73
74 let (start_idx, num_appearances) =
75 compute_start_and_appearances(&expanded_list, index as i64);
76 insert_sorted(&mut sorted_entries, start_idx, num_appearances);
77 total_selected += num_appearances;
78 }
79
80 selected
81}
82
83pub fn epoch_shuffle(
84 eligible: &[EligibleValidator],
85 waiting: &[EligibleValidator],
86 shuffle_randomness: &[u8],
87) -> Vec<EligibleValidator> {
88 let num_to_remove = (eligible.len() + waiting.len()).saturating_sub(NODES_PER_SHARD);
89 let actual_to_remove = num_to_remove.min(MAX_NODES_TO_SWAP_PER_SHARD);
90
91 let shuffled = shuffle_list(eligible, shuffle_randomness);
92 let remaining = &shuffled[actual_to_remove..];
93
94 let num_needed = NODES_PER_SHARD
95 .saturating_sub(remaining.len())
96 .min(waiting.len());
97
98 let mut result = Vec::with_capacity(NODES_PER_SHARD);
99 result.extend_from_slice(remaining);
100 result.extend_from_slice(&waiting[..num_needed]);
101
102 result
103}
104
/// Consensus message kinds, keyed by the numeric `msg_type` carried on the wire.
///
/// Codes `0..=6` map onto the named variants in declaration order; every other code
/// is preserved verbatim in [`ConsensusMsgType::Unrecognized`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConsensusMsgType {
    /// Code 0.
    Unknown,
    /// Code 1.
    BlockBodyAndHeader,
    /// Code 2.
    BlockBody,
    /// Code 3.
    BlockHeader,
    /// Code 4.
    Signature,
    /// Code 5.
    BlockHeaderFinalInfo,
    /// Code 6.
    InvalidSigners,
    /// Any code outside `0..=6`, kept as received.
    Unrecognized(i64),
}

impl From<i64> for ConsensusMsgType {
    fn from(v: i64) -> Self {
        // A variant's position in this table is its wire code.
        const BY_CODE: [ConsensusMsgType; 7] = [
            ConsensusMsgType::Unknown,
            ConsensusMsgType::BlockBodyAndHeader,
            ConsensusMsgType::BlockBody,
            ConsensusMsgType::BlockHeader,
            ConsensusMsgType::Signature,
            ConsensusMsgType::BlockHeaderFinalInfo,
            ConsensusMsgType::InvalidSigners,
        ];

        if (0..BY_CODE.len() as i64).contains(&v) {
            BY_CODE[v as usize]
        } else {
            Self::Unrecognized(v)
        }
    }
}
131
/// Block-proposal data extracted by `decode_consensus_signal` from a message that
/// carries a block header (`BlockHeader` or `BlockBodyAndHeader`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ProposalSignal {
    /// Shard the message was received for, as passed by the caller of
    /// `decode_consensus_signal` (not read from the header).
    pub shard_id: u32,
    /// Header nonce.
    pub nonce: u64,
    /// Header round.
    pub round: u64,
    /// Header epoch.
    pub epoch: u32,
    /// Header random seed bytes.
    pub rand_seed: Vec<u8>,
    /// Header previous random seed bytes.
    pub prev_rand_seed: Vec<u8>,
}
141
/// Final-info data extracted by `decode_consensus_signal` from a
/// `BlockHeaderFinalInfo` message; every field except `shard_id` is copied from the
/// message itself.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FinalInfoSignal {
    /// Shard the message was received for, as passed by the caller.
    pub shard_id: u32,
    /// The message's `round_index`. NOTE(review): signed here, whereas
    /// `ProposalSignal::round` is `u64` — confirm whether the mismatch is intended.
    pub round: i64,
    /// Hash of the block header this final info refers to.
    pub block_header_hash: Vec<u8>,
    /// Signers bitmap bytes as carried in the message.
    pub pub_keys_bitmap: Vec<u8>,
    /// Aggregated signature bytes as carried in the message.
    pub aggregate_signature: Vec<u8>,
    /// Leader signature bytes as carried in the message.
    pub leader_signature: Vec<u8>,
}
151
/// The consensus events this module extracts from raw consensus messages.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ConsensusSignal {
    /// A leader's block proposal (decoded from the embedded header).
    Proposal(ProposalSignal),
    /// End-of-round final info for a block header.
    FinalInfo(FinalInfoSignal),
}
157
158pub fn decode_consensus_signal(shard_id: u32, bytes: &[u8]) -> Option<ConsensusSignal> {
159 let message = ConsensusProtoMessage::decode(bytes).ok()?;
160
161 match ConsensusMsgType::from(message.msg_type) {
162 ConsensusMsgType::BlockHeader | ConsensusMsgType::BlockBodyAndHeader => {
163 if message.header.is_empty() {
164 return None;
165 }
166
167 let header = HeaderV3::decode(message.header.as_ref()).ok()?;
168 Some(ConsensusSignal::Proposal(ProposalSignal {
169 shard_id,
170 nonce: header.nonce,
171 round: header.round,
172 epoch: header.epoch,
173 rand_seed: header.rand_seed.to_vec(),
174 prev_rand_seed: header.prev_rand_seed.to_vec(),
175 }))
176 }
177 ConsensusMsgType::BlockHeaderFinalInfo => {
178 Some(ConsensusSignal::FinalInfo(FinalInfoSignal {
179 shard_id,
180 round: message.round_index,
181 block_header_hash: message.block_header_hash.to_vec(),
182 pub_keys_bitmap: message.pub_keys_bitmap.to_vec(),
183 aggregate_signature: message.aggregate_signature.to_vec(),
184 leader_signature: message.leader_signature.to_vec(),
185 }))
186 }
187 ConsensusMsgType::Unknown
188 | ConsensusMsgType::BlockBody
189 | ConsensusMsgType::Signature
190 | ConsensusMsgType::InvalidSigners
191 | ConsensusMsgType::Unrecognized(_) => None,
192 }
193}
194
/// Builds the per-round randomness: the decimal `round`, a `-` separator, then the
/// raw `rand_seed` bytes.
fn build_round_randomness(round: u64, rand_seed: &[u8]) -> Vec<u8> {
    let mut buffer = format!("{round}-").into_bytes();
    buffer.extend_from_slice(rand_seed);
    buffer
}
203
204fn build_expanded_list(eligible_list: &[EligibleValidator]) -> Vec<u32> {
205 let total: usize = eligible_list
206 .iter()
207 .map(|v| v.chances.max(1) as usize)
208 .sum();
209 let mut expanded = Vec::with_capacity(total);
210 for (i, v) in eligible_list.iter().enumerate() {
211 let chances = v.chances.max(1);
212 for _ in 0..chances {
213 expanded.push(i as u32);
214 }
215 }
216 expanded
217}
218
219fn compute_randomness_as_u64(randomness: &[u8], index: usize) -> u64 {
220 let index_bytes = (index as u64).to_be_bytes();
221 let mut hasher = Blake2b256::new();
222 hasher.update(index_bytes);
223 hasher.update(randomness);
224 let hash = hasher.finalize();
225
226 u64::from_be_bytes(hash[0..8].try_into().unwrap())
227}
228
/// Maps a position among the still-free slots back onto the full expanded list.
///
/// `sorted_entries` holds `(start, len)` runs of already-taken slots in ascending
/// `start` order. Each run starting at or before the running position pushes it
/// forward by `len`; scanning stops at the first run that starts beyond it.
fn adjust_index(index: u64, sorted_entries: &[(i64, i64)]) -> u64 {
    sorted_entries
        .iter()
        .try_fold(index, |position, &(run_start, run_len)| {
            if run_start as u64 > position {
                // This run (and every later one) lies beyond `position`: stop here.
                Err(position)
            } else {
                Ok(position + run_len as u64)
            }
        })
        .unwrap_or_else(|stopped_at| stopped_at)
}
238
/// Finds the run of equal values that contains `expanded_list[idx]`.
///
/// Returns `(start, len)`: the index of the first element of that run and how many
/// consecutive elements it spans.
///
/// # Panics
///
/// Panics if `idx` is not a valid index into `expanded_list`.
fn compute_start_and_appearances(expanded_list: &[u32], idx: i64) -> (i64, i64) {
    let pos = idx as usize;
    let target = expanded_list[pos];

    // First index of the run: just after the nearest differing element on the left.
    let run_start = expanded_list[..pos]
        .iter()
        .rposition(|&value| value != target)
        .map_or(0, |differs_at| differs_at + 1);

    // One past the last index of the run: the nearest differing element on the right.
    let run_end = expanded_list[pos + 1..]
        .iter()
        .position(|&value| value != target)
        .map_or(expanded_list.len(), |offset| pos + 1 + offset);

    (run_start as i64, (run_end - run_start) as i64)
}
261
/// Inserts `(start_index, num_appearances)` into `sorted_entries` in front of the
/// first entry whose start is `>= start_index`, or at the end if there is none.
fn insert_sorted(sorted_entries: &mut Vec<(i64, i64)>, start_index: i64, num_appearances: i64) {
    let slot = sorted_entries
        .iter()
        .take_while(|&&(existing_start, _)| existing_start < start_index)
        .count();
    sorted_entries.insert(slot, (start_index, num_appearances));
}
269
270fn shuffle_list(validators: &[EligibleValidator], randomness: &[u8]) -> Vec<EligibleValidator> {
271 let mut keyed: Vec<([u8; 32], EligibleValidator)> = validators
272 .iter()
273 .map(|v| {
274 let mut hasher = Sha256::new();
275 hasher.update(&v.pub_key);
276 hasher.update(randomness);
277 let hash: [u8; 32] = hasher.finalize().into();
278 (hash, v.clone())
279 })
280 .collect();
281
282 keyed.sort_by_key(|a| a.0);
283 keyed.into_iter().map(|(_, v)| v).collect()
284}
285
// Unit tests for leader / consensus-group selection, epoch shuffling and
// consensus-message decoding.
#[cfg(test)]
mod tests {
    use super::*;
    use mx_proto::generated::proto::Message as ConsensusMsg;
    use prost::bytes::Bytes;

    // `count` validators of weight 1. Keys are the byte `i as u8` repeated 96 times,
    // so keys repeat once `count` exceeds 256.
    fn make_validators(count: usize) -> Vec<EligibleValidator> {
        (0..count)
            .map(|i| EligibleValidator {
                pub_key: vec![i as u8; 96],
                index: i as u32,
                chances: 1,
            })
            .collect()
    }

    // One validator of weight 24 whose 96-byte key repeats `pk_byte`.
    fn make_validator(idx: u32, pk_byte: u8) -> EligibleValidator {
        EligibleValidator {
            pub_key: vec![pk_byte; 96],
            index: idx,
            chances: 24,
        }
    }

    // Protobuf-encoded `HeaderV3` with fixed epoch, seeds and leader signature.
    fn make_header(nonce: u64, round: u64, shard_id: u32) -> Vec<u8> {
        let header = HeaderV3 {
            nonce,
            round,
            shard_id,
            epoch: 100,
            rand_seed: Bytes::from(vec![0xAA; 32]),
            prev_rand_seed: Bytes::from(vec![0xDD; 32]),
            leader_signature: Bytes::from(vec![0xBB; 96]),
            ..Default::default()
        };
        header.encode_to_vec()
    }

    // Protobuf-encoded consensus message of type `msg_type` wrapping `header_bytes`.
    fn make_consensus_msg(msg_type: i64, header_bytes: Vec<u8>) -> Vec<u8> {
        let msg = ConsensusMsg {
            msg_type,
            header: Bytes::from(header_bytes),
            pub_key: Bytes::from(vec![0xCC; 96]),
            round_index: 12346,
            ..Default::default()
        };
        msg.encode_to_vec()
    }

    // Same seed and round must always produce the same leader.
    #[test]
    fn test_select_leader_deterministic() {
        let validators = make_validators(400);
        let rand_seed = vec![0xAA; 32];

        let leader1 = select_leader(&rand_seed, 12345, &validators);
        let leader2 = select_leader(&rand_seed, 12345, &validators);

        assert_eq!(leader1, leader2);
        assert!(leader1.is_some());
    }

    // Over 50 consecutive rounds the leader should not be constant.
    #[test]
    fn test_different_rounds_produce_different_leaders() {
        let validators = make_validators(400);
        let rand_seed = vec![0xAA; 32];

        let mut leaders = std::collections::HashSet::new();
        for round in 1000..1050 {
            if let Some(idx) = select_leader(&rand_seed, round, &validators) {
                leaders.insert(idx);
            }
        }

        assert!(leaders.len() >= 2);
    }

    // A 63-member group drawn from 100 validators contains no repeated index.
    #[test]
    fn test_consensus_group_no_duplicates() {
        let validators = make_validators(100);
        let rand_seed = vec![0xBB; 32];

        let group = select_consensus_group(&rand_seed, 5000, &validators, 63);
        assert_eq!(group.len(), 63);

        let mut seen = std::collections::HashSet::new();
        for &idx in &group {
            assert!(seen.insert(idx));
        }
    }

    // Draw 0 of the consensus group must coincide with the leader.
    #[test]
    fn test_consensus_group_leader_is_first() {
        let validators = make_validators(400);
        let rand_seed = vec![0xCC; 32];
        let round = 9999;

        let leader = select_leader(&rand_seed, round, &validators).unwrap();
        let group = select_consensus_group(&rand_seed, round, &validators, 63);

        assert_eq!(group[0], leader);
    }

    // 400 eligible + 80 waiting: 80 eligible are removed and 80 waiting added back.
    #[test]
    fn test_epoch_shuffle_preserves_count() {
        let eligible: Vec<_> = (0..400)
            .map(|i| make_validator(i, (i % 256) as u8))
            .collect();
        let waiting: Vec<_> = (0..80)
            .map(|i| make_validator(400 + i, ((400 + i) % 256) as u8))
            .collect();

        let result = epoch_shuffle(&eligible, &waiting, &[0xCC; 32]);
        assert_eq!(result.len(), 400);
    }

    // Different shuffle randomness should produce a different resulting order.
    #[test]
    fn test_epoch_shuffle_different_randomness() {
        let eligible: Vec<_> = (0..400)
            .map(|i| make_validator(i, (i % 256) as u8))
            .collect();
        let waiting: Vec<_> = (0..80)
            .map(|i| make_validator(400 + i, ((400 + i) % 256) as u8))
            .collect();

        let result1 = epoch_shuffle(&eligible, &waiting, &[0xAA; 32]);
        let result2 = epoch_shuffle(&eligible, &waiting, &[0xBB; 32]);

        let keys1: Vec<u32> = result1.iter().map(|v| v.index).collect();
        let keys2: Vec<u32> = result2.iter().map(|v| v.index).collect();
        assert_ne!(keys1, keys2);
    }

    // Wire codes 0..=6 map to named variants; anything else is `Unrecognized`.
    #[test]
    fn test_consensus_msg_type_from_i64() {
        assert_eq!(ConsensusMsgType::from(0), ConsensusMsgType::Unknown);
        assert_eq!(
            ConsensusMsgType::from(1),
            ConsensusMsgType::BlockBodyAndHeader
        );
        assert_eq!(ConsensusMsgType::from(2), ConsensusMsgType::BlockBody);
        assert_eq!(ConsensusMsgType::from(3), ConsensusMsgType::BlockHeader);
        assert_eq!(ConsensusMsgType::from(4), ConsensusMsgType::Signature);
        assert_eq!(
            ConsensusMsgType::from(5),
            ConsensusMsgType::BlockHeaderFinalInfo
        );
        assert_eq!(ConsensusMsgType::from(6), ConsensusMsgType::InvalidSigners);
        assert_eq!(
            ConsensusMsgType::from(99),
            ConsensusMsgType::Unrecognized(99)
        );
    }

    // Type 3 (BlockHeader) with a valid header decodes to a proposal.
    #[test]
    fn test_decode_consensus_signal_proposal_block_header() {
        let header_bytes = make_header(12345, 12346, 1);
        let msg_bytes = make_consensus_msg(3, header_bytes);

        let signal = decode_consensus_signal(1, &msg_bytes).unwrap();
        let ConsensusSignal::Proposal(proposal) = signal else {
            panic!("expected proposal signal");
        };

        assert_eq!(proposal.shard_id, 1);
        assert_eq!(proposal.nonce, 12345);
        assert_eq!(proposal.round, 12346);
        assert_eq!(proposal.rand_seed.len(), 32);
    }

    // Type 1 (BlockBodyAndHeader) also decodes to a proposal.
    #[test]
    fn test_decode_consensus_signal_body_and_header() {
        let header_bytes = make_header(500, 501, 2);
        let msg_bytes = make_consensus_msg(1, header_bytes);

        let signal = decode_consensus_signal(2, &msg_bytes).unwrap();
        let ConsensusSignal::Proposal(proposal) = signal else {
            panic!("expected proposal signal");
        };

        assert_eq!(proposal.shard_id, 2);
        assert_eq!(proposal.nonce, 500);
        assert_eq!(proposal.round, 501);
    }

    // Type 5 copies the final-info fields straight from the message.
    #[test]
    fn test_decode_consensus_signal_final_info() {
        let msg = ConsensusMsg {
            msg_type: 5,
            round_index: 777,
            block_header_hash: Bytes::from(vec![0x11; 32]),
            pub_keys_bitmap: Bytes::from(vec![0x22; 50]),
            aggregate_signature: Bytes::from(vec![0x33; 48]),
            leader_signature: Bytes::from(vec![0x44; 96]),
            ..Default::default()
        };

        let signal = decode_consensus_signal(0, &msg.encode_to_vec()).unwrap();
        let ConsensusSignal::FinalInfo(final_info) = signal else {
            panic!("expected final-info signal");
        };

        assert_eq!(final_info.shard_id, 0);
        assert_eq!(final_info.round, 777);
        assert_eq!(final_info.block_header_hash, vec![0x11; 32]);
        assert_eq!(final_info.pub_keys_bitmap, vec![0x22; 50]);
        assert_eq!(final_info.aggregate_signature, vec![0x33; 48]);
        assert_eq!(final_info.leader_signature, vec![0x44; 96]);
    }

    // Type 4 (Signature) is not tracked, even when a header is present.
    #[test]
    fn test_decode_consensus_signal_non_header_returns_none() {
        let header_bytes = make_header(100, 200, 0);
        let msg_bytes = make_consensus_msg(4, header_bytes);

        assert!(decode_consensus_signal(0, &msg_bytes).is_none());
    }

    // Unrecognized type codes are ignored.
    #[test]
    fn test_decode_consensus_signal_unrecognized_returns_none() {
        let header_bytes = make_header(100, 200, 0);
        let msg_bytes = make_consensus_msg(42, header_bytes);

        assert!(decode_consensus_signal(0, &msg_bytes).is_none());
    }

    // Bytes that are not a valid protobuf message yield `None`.
    #[test]
    fn test_decode_consensus_signal_invalid_bytes_returns_none() {
        assert!(decode_consensus_signal(0, &[0xFF, 0xFF, 0xFF]).is_none());
    }
}