hotmint_types/
validator.rs1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::PublicKey;
6use crate::view::ViewNumber;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
9pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13 write!(f, "V{}", self.0)
14 }
15}
16
17impl From<u64> for ValidatorId {
18 fn from(v: u64) -> Self {
19 Self(v)
20 }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ValidatorInfo {
25 pub id: ValidatorId,
26 pub public_key: PublicKey,
27 pub power: u64,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ValidatorSet {
32 validators: Vec<ValidatorInfo>,
33 total_power: u64,
34 #[serde(skip)]
36 index_map: HashMap<ValidatorId, usize>,
37}
38
39impl ValidatorSet {
40 pub fn new(validators: Vec<ValidatorInfo>) -> Self {
41 let total_power = validators.iter().map(|v| v.power).sum();
42 let index_map = validators
43 .iter()
44 .enumerate()
45 .map(|(i, v)| (v.id, i))
46 .collect();
47 Self {
48 validators,
49 total_power,
50 index_map,
51 }
52 }
53
54 pub fn rebuild_index(&mut self) {
56 self.index_map = self
57 .validators
58 .iter()
59 .enumerate()
60 .map(|(i, v)| (v.id, i))
61 .collect();
62 }
63
64 pub fn validators(&self) -> &[ValidatorInfo] {
65 &self.validators
66 }
67
68 pub fn total_power(&self) -> u64 {
69 self.total_power
70 }
71
72 pub fn quorum_threshold(&self) -> u64 {
74 (self.total_power * 2).div_ceil(3)
75 }
76
77 pub fn max_faulty_power(&self) -> u64 {
79 self.total_power - self.quorum_threshold()
80 }
81
82 pub fn leader_for_view(&self, view: ViewNumber) -> &ValidatorInfo {
84 let idx = (view.as_u64() as usize) % self.validators.len();
85 &self.validators[idx]
86 }
87
88 pub fn validator_count(&self) -> usize {
89 self.validators.len()
90 }
91
92 pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
94 self.index_map.get(&id).copied()
95 }
96
97 pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
99 self.index_map.get(&id).map(|&idx| &self.validators[idx])
100 }
101
102 pub fn power_of(&self, id: ValidatorId) -> u64 {
103 self.get(id).map_or(0, |v| v.power)
104 }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a set whose i-th validator has id `i`, a one-byte key `[i]`, and
    /// `powers[i]` as its voting power.
    fn make_vs(powers: &[u64]) -> ValidatorSet {
        let mut members = Vec::with_capacity(powers.len());
        for (pos, power) in powers.iter().copied().enumerate() {
            members.push(ValidatorInfo {
                id: ValidatorId(pos as u64),
                public_key: PublicKey(vec![pos as u8]),
                power,
            });
        }
        ValidatorSet::new(members)
    }

    /// Checks the quorum threshold and fault tolerance derived from `powers`.
    fn assert_thresholds(powers: &[u64], quorum: u64, faulty: u64) {
        let set = make_vs(powers);
        assert_eq!(set.quorum_threshold(), quorum);
        assert_eq!(set.max_faulty_power(), faulty);
    }

    #[test]
    fn test_quorum_4_equal() {
        assert_eq!(make_vs(&[1, 1, 1, 1]).total_power(), 4);
        assert_thresholds(&[1, 1, 1, 1], 3, 1);
    }

    #[test]
    fn test_quorum_3_equal() {
        assert_thresholds(&[1, 1, 1], 2, 1);
    }

    #[test]
    fn test_quorum_weighted() {
        assert_thresholds(&[10, 10, 10, 70], 67, 33);
    }

    #[test]
    fn test_quorum_single_validator() {
        assert_thresholds(&[1], 1, 0);
    }

    #[test]
    fn test_leader_rotation() {
        let set = make_vs(&[1, 1, 1, 1]);
        for (view, leader) in [(0, 0), (1, 1), (4, 0), (7, 3)] {
            assert_eq!(set.leader_for_view(ViewNumber(view)).id, ValidatorId(leader));
        }
    }

    #[test]
    fn test_index_of_o1() {
        let set = make_vs(&[5, 10, 15]);
        for raw in 0..3u64 {
            assert_eq!(set.index_of(ValidatorId(raw)), Some(raw as usize));
        }
        assert_eq!(set.index_of(ValidatorId(99)), None);
    }

    #[test]
    fn test_get_and_power_of() {
        let set = make_vs(&[5, 10, 15]);
        let second = set.get(ValidatorId(1)).expect("validator 1 is present");
        assert_eq!(second.power, 10);
        assert!(set.get(ValidatorId(99)).is_none());
        assert_eq!(set.power_of(ValidatorId(2)), 15);
        assert_eq!(set.power_of(ValidatorId(99)), 0);
    }

    /// The lookup index is `#[serde(skip)]`, so it is rebuilt explicitly
    /// before querying the decoded set.
    #[test]
    fn test_serialization_roundtrip() {
        let original = make_vs(&[1, 2, 3]);
        let encoded = rmp_serde::to_vec(&original).unwrap();
        let mut decoded: ValidatorSet = rmp_serde::from_slice(&encoded).unwrap();
        decoded.rebuild_index();
        assert_eq!(decoded.validator_count(), 3);
        assert_eq!(decoded.index_of(ValidatorId(1)), Some(1));
        assert_eq!(decoded.total_power(), 6);
    }
}