1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::{PublicKey, Signer};
6use crate::view::ViewNumber;
7
/// Numeric identifier of a validator.
///
/// Used as the lookup key in [`ValidatorSet`] (`index_of`, `get`, `power_of`)
/// and displayed as `V<n>`, e.g. `V7`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13 write!(f, "V{}", self.0)
14 }
15}
16
17impl From<u64> for ValidatorId {
18 fn from(v: u64) -> Self {
19 Self(v)
20 }
21}
22
/// One member of a [`ValidatorSet`]: its id, public key and voting power.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidatorInfo {
    /// Identifier used to look the validator up in a `ValidatorSet`.
    /// NOTE(review): expected to be unique within a set — the set's index
    /// keeps only one position per id.
    pub id: ValidatorId,
    /// The validator's public key (taken from `Signer::public_key` in
    /// `ValidatorSet::from_signers`).
    pub public_key: PublicKey,
    /// Voting weight; summed into `ValidatorSet::total_power`.
    pub power: u64,
}
29
/// Ordered collection of validators with a cached total voting power and an
/// id -> position index.
///
/// Order is significant: `leader_for_view` rotates through `validators` by
/// position.
#[derive(Debug, Clone, Serialize)]
pub struct ValidatorSet {
    /// Members in rotation order.
    validators: Vec<ValidatorInfo>,
    /// Sum of every member's `power`, computed when the set is built.
    total_power: u64,
    /// Maps each `ValidatorId` to its position in `validators`.
    /// Not serialized; the hand-written `Deserialize` impl rebuilds it.
    #[serde(skip)]
    index_map: HashMap<ValidatorId, usize>,
}
38
39impl<'de> Deserialize<'de> for ValidatorSet {
40 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
41 where
42 D: serde::Deserializer<'de>,
43 {
44 #[derive(Deserialize)]
45 struct Raw {
46 validators: Vec<ValidatorInfo>,
47 total_power: u64,
48 }
49 let raw = Raw::deserialize(deserializer)?;
50 let index_map = raw
51 .validators
52 .iter()
53 .enumerate()
54 .map(|(i, v)| (v.id, i))
55 .collect();
56 Ok(ValidatorSet {
57 validators: raw.validators,
58 total_power: raw.total_power,
59 index_map,
60 })
61 }
62}
63
64impl ValidatorSet {
65 pub fn new(validators: Vec<ValidatorInfo>) -> Self {
66 let total_power = validators.iter().map(|v| v.power).sum();
67 let index_map = validators
68 .iter()
69 .enumerate()
70 .map(|(i, v)| (v.id, i))
71 .collect();
72 Self {
73 validators,
74 total_power,
75 index_map,
76 }
77 }
78
79 pub fn from_signers(signers: &[&dyn Signer]) -> Self {
81 let validators: Vec<ValidatorInfo> = signers
82 .iter()
83 .map(|s| ValidatorInfo {
84 id: s.validator_id(),
85 public_key: s.public_key(),
86 power: 1,
87 })
88 .collect();
89 Self::new(validators)
90 }
91
92 pub fn rebuild_index(&mut self) {
98 self.index_map = self
99 .validators
100 .iter()
101 .enumerate()
102 .map(|(i, v)| (v.id, i))
103 .collect();
104 }
105
106 pub fn validators(&self) -> &[ValidatorInfo] {
107 &self.validators
108 }
109
110 pub fn total_power(&self) -> u64 {
111 self.total_power
112 }
113
114 pub fn quorum_threshold(&self) -> u64 {
116 self.total_power
117 .checked_mul(2)
118 .expect("total_power overflow in quorum_threshold")
119 .div_ceil(3)
120 }
121
122 pub fn max_faulty_power(&self) -> u64 {
124 self.total_power - self.quorum_threshold()
125 }
126
127 pub fn leader_for_view(&self, view: ViewNumber) -> Option<&ValidatorInfo> {
130 if self.validators.is_empty() {
131 return None;
132 }
133 let idx = (view.as_u64() as usize) % self.validators.len();
134 Some(&self.validators[idx])
135 }
136
137 pub fn validator_count(&self) -> usize {
138 self.validators.len()
139 }
140
141 pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
143 self.index_map.get(&id).copied()
144 }
145
146 pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
148 self.index_map.get(&id).map(|&idx| &self.validators[idx])
149 }
150
151 pub fn power_of(&self, id: ValidatorId) -> u64 {
152 self.get(id).map_or(0, |v| v.power)
153 }
154
155 pub fn apply_updates(
159 &self,
160 updates: &[crate::validator_update::ValidatorUpdate],
161 ) -> ValidatorSet {
162 let mut infos: Vec<ValidatorInfo> = self.validators.clone();
163
164 for update in updates {
165 if update.power == 0 {
166 infos.retain(|v| v.id != update.id);
167 } else if let Some(existing) = infos.iter_mut().find(|v| v.id == update.id) {
168 existing.power = update.power;
169 existing.public_key = update.public_key.clone();
170 } else {
171 infos.push(ValidatorInfo {
172 id: update.id,
173 public_key: update.public_key.clone(),
174 power: update.power,
175 });
176 }
177 }
178
179 ValidatorSet::new(infos)
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validator_update::ValidatorUpdate;

    /// Builds a set where validator `i` has id `i`, public key `[i]` and
    /// voting power `powers[i]`.
    fn set_with_powers(powers: &[u64]) -> ValidatorSet {
        let mut members = Vec::with_capacity(powers.len());
        for (i, &power) in powers.iter().enumerate() {
            members.push(ValidatorInfo {
                id: ValidatorId(i as u64),
                public_key: PublicKey(vec![i as u8]),
                power,
            });
        }
        ValidatorSet::new(members)
    }

    /// Update for validator `id` with public key `[id]` and the given power
    /// (a power of 0 requests removal).
    fn update(id: u64, power: u64) -> ValidatorUpdate {
        ValidatorUpdate {
            id: ValidatorId(id),
            public_key: PublicKey(vec![id as u8]),
            power,
        }
    }

    #[test]
    fn test_quorum_4_equal() {
        let set = set_with_powers(&[1, 1, 1, 1]);
        assert_eq!(set.total_power(), 4);
        assert_eq!((set.quorum_threshold(), set.max_faulty_power()), (3, 1));
    }

    #[test]
    fn test_quorum_3_equal() {
        let set = set_with_powers(&[1, 1, 1]);
        assert_eq!((set.quorum_threshold(), set.max_faulty_power()), (2, 1));
    }

    #[test]
    fn test_quorum_weighted() {
        let set = set_with_powers(&[10, 10, 10, 70]);
        assert_eq!((set.quorum_threshold(), set.max_faulty_power()), (67, 33));
    }

    #[test]
    fn test_quorum_single_validator() {
        let set = set_with_powers(&[1]);
        assert_eq!((set.quorum_threshold(), set.max_faulty_power()), (1, 0));
    }

    #[test]
    fn test_leader_rotation() {
        let set = set_with_powers(&[1, 1, 1, 1]);
        for (view, expected) in [(0, 0), (1, 1), (4, 0), (7, 3)] {
            let leader = set.leader_for_view(ViewNumber(view)).unwrap();
            assert_eq!(leader.id, ValidatorId(expected), "leader for view {view}");
        }
    }

    #[test]
    fn test_index_of_o1() {
        let set = set_with_powers(&[5, 10, 15]);
        for position in 0..3usize {
            assert_eq!(set.index_of(ValidatorId(position as u64)), Some(position));
        }
        assert_eq!(set.index_of(ValidatorId(99)), None);
    }

    #[test]
    fn test_get_and_power_of() {
        let set = set_with_powers(&[5, 10, 15]);
        assert_eq!(set.get(ValidatorId(1)).map(|v| v.power), Some(10));
        assert!(set.get(ValidatorId(99)).is_none());
        assert_eq!(set.power_of(ValidatorId(2)), 15);
        assert_eq!(set.power_of(ValidatorId(99)), 0);
    }

    #[test]
    fn test_apply_updates_add_validator() {
        let next = set_with_powers(&[1, 1, 1]).apply_updates(&[update(3, 2)]);
        assert_eq!(next.validator_count(), 4);
        assert_eq!(next.power_of(ValidatorId(3)), 2);
        assert_eq!(next.total_power(), 5);
    }

    #[test]
    fn test_apply_updates_remove_validator() {
        let next = set_with_powers(&[1, 1, 1, 1]).apply_updates(&[update(2, 0)]);
        assert_eq!(next.validator_count(), 3);
        assert!(next.get(ValidatorId(2)).is_none());
        assert_eq!(next.total_power(), 3);
    }

    #[test]
    fn test_apply_updates_change_power() {
        let next = set_with_powers(&[1, 1, 1, 1]).apply_updates(&[update(0, 10)]);
        assert_eq!(next.validator_count(), 4);
        assert_eq!(next.power_of(ValidatorId(0)), 10);
        assert_eq!(next.total_power(), 13);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = set_with_powers(&[1, 2, 3]);
        let encoded = serde_cbor_2::to_vec(&original).unwrap();
        let decoded: ValidatorSet = serde_cbor_2::from_slice(&encoded).unwrap();
        assert_eq!(decoded.validator_count(), 3);
        assert_eq!(decoded.index_of(ValidatorId(1)), Some(1));
        assert_eq!(decoded.total_power(), 6);
    }
}
320}