1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::{PublicKey, Signer};
6use crate::view::ViewNumber;
7
/// Identifier of a validator.
///
/// A thin newtype over a raw `u64`. Its `Display` form is `V<n>` (e.g. `V3`),
/// and it can be built from a bare `u64` via `From`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13 write!(f, "V{}", self.0)
14 }
15}
16
17impl From<u64> for ValidatorId {
18 fn from(v: u64) -> Self {
19 Self(v)
20 }
21}
22
/// Membership record for one validator: its id, public key and voting power.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidatorInfo {
    /// Identifier used to look this validator up within a validator set.
    pub id: ValidatorId,
    /// The validator's public key (type from `crate::crypto`).
    pub public_key: PublicKey,
    /// Voting weight; quorum thresholds are derived from the sum of these.
    pub power: u64,
}
29
/// An ordered collection of validators together with their total voting power.
///
/// Only `validators` and `total_power` are serialized; the id → position
/// index is derived data and is skipped by serde.
#[derive(Debug, Clone, Serialize)]
pub struct ValidatorSet {
    /// Validators in insertion order; this order drives index positions and
    /// round-robin leader selection.
    validators: Vec<ValidatorInfo>,
    /// Total voting power of the set.
    total_power: u64,
    /// Lookup table from validator id to its position in `validators`.
    #[serde(skip)]
    index_map: HashMap<ValidatorId, usize>,
}
38
39impl<'de> Deserialize<'de> for ValidatorSet {
40 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
41 where
42 D: serde::Deserializer<'de>,
43 {
44 #[derive(Deserialize)]
45 struct Raw {
46 validators: Vec<ValidatorInfo>,
47 total_power: u64,
48 }
49 let raw = Raw::deserialize(deserializer)?;
50 let index_map = raw
51 .validators
52 .iter()
53 .enumerate()
54 .map(|(i, v)| (v.id, i))
55 .collect();
56 Ok(ValidatorSet {
57 validators: raw.validators,
58 total_power: raw.total_power,
59 index_map,
60 })
61 }
62}
63
64impl ValidatorSet {
65 pub fn new(validators: Vec<ValidatorInfo>) -> Self {
66 let total_power = validators.iter().map(|v| v.power).sum();
67 let index_map = validators
68 .iter()
69 .enumerate()
70 .map(|(i, v)| (v.id, i))
71 .collect();
72 Self {
73 validators,
74 total_power,
75 index_map,
76 }
77 }
78
79 pub fn from_signers(signers: &[&dyn Signer]) -> Self {
81 let validators: Vec<ValidatorInfo> = signers
82 .iter()
83 .map(|s| ValidatorInfo {
84 id: s.validator_id(),
85 public_key: s.public_key(),
86 power: 1,
87 })
88 .collect();
89 Self::new(validators)
90 }
91
92 pub fn rebuild_index(&mut self) {
98 self.index_map = self
99 .validators
100 .iter()
101 .enumerate()
102 .map(|(i, v)| (v.id, i))
103 .collect();
104 }
105
106 pub fn validators(&self) -> &[ValidatorInfo] {
107 &self.validators
108 }
109
110 pub fn total_power(&self) -> u64 {
111 self.total_power
112 }
113
114 pub fn quorum_threshold(&self) -> u64 {
116 (self.total_power * 2).div_ceil(3)
117 }
118
119 pub fn max_faulty_power(&self) -> u64 {
121 self.total_power - self.quorum_threshold()
122 }
123
124 pub fn leader_for_view(&self, view: ViewNumber) -> &ValidatorInfo {
126 let idx = (view.as_u64() as usize) % self.validators.len();
127 &self.validators[idx]
128 }
129
130 pub fn validator_count(&self) -> usize {
131 self.validators.len()
132 }
133
134 pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
136 self.index_map.get(&id).copied()
137 }
138
139 pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
141 self.index_map.get(&id).map(|&idx| &self.validators[idx])
142 }
143
144 pub fn power_of(&self, id: ValidatorId) -> u64 {
145 self.get(id).map_or(0, |v| v.power)
146 }
147
148 pub fn apply_updates(
152 &self,
153 updates: &[crate::validator_update::ValidatorUpdate],
154 ) -> ValidatorSet {
155 let mut infos: Vec<ValidatorInfo> = self.validators.clone();
156
157 for update in updates {
158 if update.power == 0 {
159 infos.retain(|v| v.id != update.id);
160 } else if let Some(existing) = infos.iter_mut().find(|v| v.id == update.id) {
161 existing.power = update.power;
162 existing.public_key = update.public_key.clone();
163 } else {
164 infos.push(ValidatorInfo {
165 id: update.id,
166 public_key: update.public_key.clone(),
167 power: update.power,
168 });
169 }
170 }
171
172 ValidatorSet::new(infos)
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a set where validator `i` has id `i`, key `[i]` and power `powers[i]`.
    fn make_vs(powers: &[u64]) -> ValidatorSet {
        let mut validators = Vec::with_capacity(powers.len());
        for (i, &power) in powers.iter().enumerate() {
            validators.push(ValidatorInfo {
                id: ValidatorId(i as u64),
                public_key: PublicKey(vec![i as u8]),
                power,
            });
        }
        ValidatorSet::new(validators)
    }

    /// A one-element update list for validator `id` (key `[id]`) at `power`.
    fn single_update(id: u64, power: u64) -> Vec<crate::validator_update::ValidatorUpdate> {
        vec![crate::validator_update::ValidatorUpdate {
            id: ValidatorId(id),
            public_key: PublicKey(vec![id as u8]),
            power,
        }]
    }

    /// Checks the quorum threshold and faulty-power bound of a set built from `powers`.
    fn assert_quorum(powers: &[u64], quorum: u64, faulty: u64) {
        let vs = make_vs(powers);
        assert_eq!(vs.quorum_threshold(), quorum);
        assert_eq!(vs.max_faulty_power(), faulty);
    }

    #[test]
    fn test_quorum_4_equal() {
        assert_eq!(make_vs(&[1, 1, 1, 1]).total_power(), 4);
        assert_quorum(&[1, 1, 1, 1], 3, 1);
    }

    #[test]
    fn test_quorum_3_equal() {
        assert_quorum(&[1, 1, 1], 2, 1);
    }

    #[test]
    fn test_quorum_weighted() {
        assert_quorum(&[10, 10, 10, 70], 67, 33);
    }

    #[test]
    fn test_quorum_single_validator() {
        assert_quorum(&[1], 1, 0);
    }

    #[test]
    fn test_leader_rotation() {
        let vs = make_vs(&[1, 1, 1, 1]);
        for (view, expected) in [(0, 0), (1, 1), (4, 0), (7, 3)] {
            assert_eq!(vs.leader_for_view(ViewNumber(view)).id, ValidatorId(expected));
        }
    }

    #[test]
    fn test_index_of_o1() {
        let vs = make_vs(&[5, 10, 15]);
        for i in 0..3u64 {
            assert_eq!(vs.index_of(ValidatorId(i)), Some(i as usize));
        }
        assert_eq!(vs.index_of(ValidatorId(99)), None);
    }

    #[test]
    fn test_get_and_power_of() {
        let vs = make_vs(&[5, 10, 15]);
        let second = vs.get(ValidatorId(1)).unwrap();
        assert_eq!(second.power, 10);
        assert!(vs.get(ValidatorId(99)).is_none());
        assert_eq!(vs.power_of(ValidatorId(2)), 15);
        assert_eq!(vs.power_of(ValidatorId(99)), 0);
    }

    #[test]
    fn test_apply_updates_add_validator() {
        let new_vs = make_vs(&[1, 1, 1]).apply_updates(&single_update(3, 2));
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(3)), 2);
        assert_eq!(new_vs.total_power(), 5);
    }

    #[test]
    fn test_apply_updates_remove_validator() {
        let new_vs = make_vs(&[1, 1, 1, 1]).apply_updates(&single_update(2, 0));
        assert_eq!(new_vs.validator_count(), 3);
        assert!(new_vs.get(ValidatorId(2)).is_none());
        assert_eq!(new_vs.total_power(), 3);
    }

    #[test]
    fn test_apply_updates_change_power() {
        let new_vs = make_vs(&[1, 1, 1, 1]).apply_updates(&single_update(0, 10));
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(0)), 10);
        assert_eq!(new_vs.total_power(), 13);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = make_vs(&[1, 2, 3]);
        let encoded = serde_cbor_2::to_vec(&original).unwrap();
        let decoded: ValidatorSet = serde_cbor_2::from_slice(&encoded).unwrap();
        assert_eq!(decoded.validator_count(), 3);
        assert_eq!(decoded.index_of(ValidatorId(1)), Some(1));
        assert_eq!(decoded.total_power(), 6);
    }
}