1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::{PublicKey, Signer};
6use crate::view::ViewNumber;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
9pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13 write!(f, "V{}", self.0)
14 }
15}
16
17impl From<u64> for ValidatorId {
18 fn from(v: u64) -> Self {
19 Self(v)
20 }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ValidatorInfo {
25 pub id: ValidatorId,
26 pub public_key: PublicKey,
27 pub power: u64,
28}
29
30#[derive(Debug, Clone, Serialize)]
31pub struct ValidatorSet {
32 validators: Vec<ValidatorInfo>,
33 total_power: u64,
34 #[serde(skip)]
36 index_map: HashMap<ValidatorId, usize>,
37}
38
39impl<'de> Deserialize<'de> for ValidatorSet {
40 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
41 where
42 D: serde::Deserializer<'de>,
43 {
44 #[derive(Deserialize)]
45 struct Raw {
46 validators: Vec<ValidatorInfo>,
47 total_power: u64,
48 }
49 let raw = Raw::deserialize(deserializer)?;
50 let index_map = raw
51 .validators
52 .iter()
53 .enumerate()
54 .map(|(i, v)| (v.id, i))
55 .collect();
56 Ok(ValidatorSet {
57 validators: raw.validators,
58 total_power: raw.total_power,
59 index_map,
60 })
61 }
62}
63
64impl ValidatorSet {
65 pub fn new(validators: Vec<ValidatorInfo>) -> Self {
66 let total_power = validators.iter().map(|v| v.power).sum();
67 let index_map = validators
68 .iter()
69 .enumerate()
70 .map(|(i, v)| (v.id, i))
71 .collect();
72 Self {
73 validators,
74 total_power,
75 index_map,
76 }
77 }
78
79 pub fn from_signers(signers: &[&dyn Signer]) -> Self {
81 let validators: Vec<ValidatorInfo> = signers
82 .iter()
83 .map(|s| ValidatorInfo {
84 id: s.validator_id(),
85 public_key: s.public_key(),
86 power: 1,
87 })
88 .collect();
89 Self::new(validators)
90 }
91
92 pub fn rebuild_index(&mut self) {
98 self.index_map = self
99 .validators
100 .iter()
101 .enumerate()
102 .map(|(i, v)| (v.id, i))
103 .collect();
104 }
105
106 pub fn validators(&self) -> &[ValidatorInfo] {
107 &self.validators
108 }
109
110 pub fn total_power(&self) -> u64 {
111 self.total_power
112 }
113
114 pub fn quorum_threshold(&self) -> u64 {
116 self.total_power
117 .checked_mul(2)
118 .expect("total_power overflow in quorum_threshold")
119 .div_ceil(3)
120 }
121
122 pub fn max_faulty_power(&self) -> u64 {
124 self.total_power - self.quorum_threshold()
125 }
126
127 pub fn leader_for_view(&self, view: ViewNumber) -> Option<&ValidatorInfo> {
135 if self.validators.is_empty() || self.total_power == 0 {
136 return None;
137 }
138 let slot = view.as_u64() % self.total_power;
139 let mut cumulative = 0u64;
140 for vi in &self.validators {
141 cumulative += vi.power;
142 if slot < cumulative {
143 return Some(vi);
144 }
145 }
146 self.validators.last()
148 }
149
150 pub fn validator_count(&self) -> usize {
151 self.validators.len()
152 }
153
154 pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
156 self.index_map.get(&id).copied()
157 }
158
159 pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
161 self.index_map.get(&id).map(|&idx| &self.validators[idx])
162 }
163
164 pub fn power_of(&self, id: ValidatorId) -> u64 {
165 self.get(id).map_or(0, |v| v.power)
166 }
167
168 pub fn apply_updates(
172 &self,
173 updates: &[crate::validator_update::ValidatorUpdate],
174 ) -> ValidatorSet {
175 let mut infos: Vec<ValidatorInfo> = self.validators.clone();
176
177 for update in updates {
178 if update.power == 0 {
179 infos.retain(|v| v.id != update.id);
180 } else if let Some(existing) = infos.iter_mut().find(|v| v.id == update.id) {
181 existing.power = update.power;
182 existing.public_key = update.public_key.clone();
183 } else {
184 infos.push(ValidatorInfo {
185 id: update.id,
186 public_key: update.public_key.clone(),
187 power: update.power,
188 });
189 }
190 }
191
192 ValidatorSet::new(infos)
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validator_update::ValidatorUpdate;

    /// Builds a set whose validator `i` has id `i`, key `[i]`, and power `powers[i]`.
    fn make_vs(powers: &[u64]) -> ValidatorSet {
        let validators: Vec<ValidatorInfo> = powers
            .iter()
            .enumerate()
            .map(|(i, &p)| ValidatorInfo {
                id: ValidatorId(i as u64),
                public_key: PublicKey(vec![i as u8]),
                power: p,
            })
            .collect();
        ValidatorSet::new(validators)
    }

    /// Update for validator `id` with power `power` and key `[id]`.
    fn update(id: u64, power: u64) -> ValidatorUpdate {
        ValidatorUpdate {
            id: ValidatorId(id),
            public_key: PublicKey(vec![id as u8]),
            power,
        }
    }

    /// Id of the leader for `view`; panics if the set has no leader.
    fn leader(vs: &ValidatorSet, view: ViewNumber) -> ValidatorId {
        vs.leader_for_view(view)
            .expect("set should have a leader")
            .id
    }

    #[test]
    fn test_quorum_4_equal() {
        let vs = make_vs(&[1, 1, 1, 1]);
        assert_eq!(vs.total_power(), 4);
        assert_eq!(vs.quorum_threshold(), 3);
        assert_eq!(vs.max_faulty_power(), 1);
    }

    #[test]
    fn test_quorum_3_equal() {
        let vs = make_vs(&[1, 1, 1]);
        assert_eq!(vs.quorum_threshold(), 2);
        assert_eq!(vs.max_faulty_power(), 1);
    }

    #[test]
    fn test_quorum_weighted() {
        let vs = make_vs(&[10, 10, 10, 70]);
        assert_eq!(vs.quorum_threshold(), 67);
        assert_eq!(vs.max_faulty_power(), 33);
    }

    #[test]
    fn test_quorum_single_validator() {
        let vs = make_vs(&[1]);
        assert_eq!(vs.quorum_threshold(), 1);
        assert_eq!(vs.max_faulty_power(), 0);
    }

    #[test]
    fn test_quorum_formula_over_range() {
        for t in 0u64..=30 {
            let vs = make_vs(&[t]);
            assert_eq!(vs.quorum_threshold(), (2 * t + 2) / 3, "t = {t}");
            assert_eq!(vs.quorum_threshold() + vs.max_faulty_power(), t, "t = {t}");
        }
    }

    #[test]
    fn test_empty_set() {
        let vs = ValidatorSet::new(Vec::new());
        assert_eq!(vs.validator_count(), 0);
        assert_eq!(vs.total_power(), 0);
        assert_eq!(vs.quorum_threshold(), 0);
        assert_eq!(vs.max_faulty_power(), 0);
        assert!(vs.leader_for_view(ViewNumber(0)).is_none());
        assert_eq!(vs.index_of(ValidatorId(0)), None);
        assert_eq!(vs.power_of(ValidatorId(0)), 0);
    }

    #[test]
    fn test_leader_rotation() {
        let vs = make_vs(&[1, 1, 1, 1]);
        assert_eq!(leader(&vs, ViewNumber(0)), ValidatorId(0));
        assert_eq!(leader(&vs, ViewNumber(1)), ValidatorId(1));
        assert_eq!(leader(&vs, ViewNumber(4)), ValidatorId(0));
        assert_eq!(leader(&vs, ViewNumber(7)), ValidatorId(3));
    }

    #[test]
    fn test_leader_weighted() {
        let vs = make_vs(&[10, 10, 10, 70]);
        assert_eq!(leader(&vs, ViewNumber(0)), ValidatorId(0));
        assert_eq!(leader(&vs, ViewNumber(9)), ValidatorId(0));
        assert_eq!(leader(&vs, ViewNumber(10)), ValidatorId(1));
        assert_eq!(leader(&vs, ViewNumber(20)), ValidatorId(2));
        assert_eq!(leader(&vs, ViewNumber(30)), ValidatorId(3));
        assert_eq!(leader(&vs, ViewNumber(99)), ValidatorId(3));
        // The schedule wraps after total_power views.
        assert_eq!(leader(&vs, ViewNumber(100)), ValidatorId(0));
    }

    #[test]
    fn test_zero_power_validators_never_lead() {
        let vs = make_vs(&[0, 1, 0]);
        for view in 0..5 {
            assert_eq!(leader(&vs, ViewNumber(view)), ValidatorId(1));
        }
    }

    #[test]
    fn test_all_zero_power_has_no_leader() {
        let vs = make_vs(&[0, 0]);
        assert!(vs.leader_for_view(ViewNumber(0)).is_none());
    }

    #[test]
    fn test_index_of_o1() {
        let vs = make_vs(&[5, 10, 15]);
        assert_eq!(vs.index_of(ValidatorId(0)), Some(0));
        assert_eq!(vs.index_of(ValidatorId(1)), Some(1));
        assert_eq!(vs.index_of(ValidatorId(2)), Some(2));
        assert_eq!(vs.index_of(ValidatorId(99)), None);
    }

    #[test]
    fn test_get_and_power_of() {
        let vs = make_vs(&[5, 10, 15]);
        assert_eq!(vs.get(ValidatorId(1)).unwrap().power, 10);
        assert!(vs.get(ValidatorId(99)).is_none());
        assert_eq!(vs.power_of(ValidatorId(2)), 15);
        assert_eq!(vs.power_of(ValidatorId(99)), 0);
    }

    #[test]
    fn test_validator_id_display_and_from() {
        assert_eq!(ValidatorId(7).to_string(), "V7");
        assert_eq!(ValidatorId::from(3), ValidatorId(3));
    }

    #[test]
    fn test_apply_updates_add_validator() {
        let vs = make_vs(&[1, 1, 1]);
        let new_vs = vs.apply_updates(&[update(3, 2)]);
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(3)), 2);
        assert_eq!(new_vs.total_power(), 5);
    }

    #[test]
    fn test_apply_updates_remove_validator() {
        let vs = make_vs(&[1, 1, 1, 1]);
        let new_vs = vs.apply_updates(&[update(2, 0)]);
        assert_eq!(new_vs.validator_count(), 3);
        assert!(new_vs.get(ValidatorId(2)).is_none());
        assert_eq!(new_vs.total_power(), 3);
    }

    #[test]
    fn test_apply_updates_change_power() {
        let vs = make_vs(&[1, 1, 1, 1]);
        let new_vs = vs.apply_updates(&[update(0, 10)]);
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(0)), 10);
        assert_eq!(new_vs.total_power(), 13);
    }

    #[test]
    fn test_apply_updates_remove_unknown_is_noop() {
        let vs = make_vs(&[1, 2, 3]);
        let new_vs = vs.apply_updates(&[update(42, 0)]);
        assert_eq!(new_vs.validator_count(), 3);
        assert_eq!(new_vs.total_power(), 6);
    }

    #[test]
    fn test_apply_updates_preserves_order_and_reindexes() {
        let vs = make_vs(&[1, 1, 1, 1]);
        let new_vs = vs.apply_updates(&[update(1, 0), update(9, 5)]);
        assert_eq!(new_vs.index_of(ValidatorId(0)), Some(0));
        assert_eq!(new_vs.index_of(ValidatorId(2)), Some(1));
        assert_eq!(new_vs.index_of(ValidatorId(3)), Some(2));
        assert_eq!(new_vs.index_of(ValidatorId(9)), Some(3));
        assert_eq!(new_vs.index_of(ValidatorId(1)), None);
    }

    #[test]
    fn test_apply_updates_is_sequential() {
        let vs = make_vs(&[1, 1]);
        let added_then_removed = vs.apply_updates(&[update(5, 3), update(5, 0)]);
        assert_eq!(added_then_removed.validator_count(), 2);
        assert!(added_then_removed.get(ValidatorId(5)).is_none());

        let last_power_wins = vs.apply_updates(&[update(1, 7), update(1, 9)]);
        assert_eq!(last_power_wins.power_of(ValidatorId(1)), 9);
        assert_eq!(last_power_wins.total_power(), 10);
    }

    #[test]
    fn test_apply_updates_does_not_mutate_original() {
        let vs = make_vs(&[1, 1]);
        let _ = vs.apply_updates(&[update(0, 0), update(5, 4)]);
        assert_eq!(vs.validator_count(), 2);
        assert_eq!(vs.total_power(), 2);
        assert_eq!(vs.power_of(ValidatorId(0)), 1);
        assert!(vs.get(ValidatorId(5)).is_none());
    }

    #[test]
    fn test_serialization_roundtrip() {
        let vs = make_vs(&[1, 2, 3]);
        let bytes = postcard::to_allocvec(&vs).unwrap();
        let vs2: ValidatorSet = postcard::from_bytes(&bytes).unwrap();
        assert_eq!(vs2.validator_count(), 3);
        // The index is not serialized; it must be rebuilt on deserialization.
        assert_eq!(vs2.index_of(ValidatorId(0)), Some(0));
        assert_eq!(vs2.index_of(ValidatorId(1)), Some(1));
        assert_eq!(vs2.index_of(ValidatorId(2)), Some(2));
        assert_eq!(vs2.total_power(), 6);
        for view in 0..6 {
            assert_eq!(
                leader(&vs2, ViewNumber(view)),
                leader(&vs, ViewNumber(view))
            );
        }
    }
}