1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::PublicKey;
6use crate::view::ViewNumber;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
9pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13 write!(f, "V{}", self.0)
14 }
15}
16
17impl From<u64> for ValidatorId {
18 fn from(v: u64) -> Self {
19 Self(v)
20 }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ValidatorInfo {
25 pub id: ValidatorId,
26 pub public_key: PublicKey,
27 pub power: u64,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ValidatorSet {
32 validators: Vec<ValidatorInfo>,
33 total_power: u64,
34 #[serde(skip)]
36 index_map: HashMap<ValidatorId, usize>,
37}
38
39impl ValidatorSet {
40 pub fn new(validators: Vec<ValidatorInfo>) -> Self {
41 let total_power = validators.iter().map(|v| v.power).sum();
42 let index_map = validators
43 .iter()
44 .enumerate()
45 .map(|(i, v)| (v.id, i))
46 .collect();
47 Self {
48 validators,
49 total_power,
50 index_map,
51 }
52 }
53
54 pub fn rebuild_index(&mut self) {
56 self.index_map = self
57 .validators
58 .iter()
59 .enumerate()
60 .map(|(i, v)| (v.id, i))
61 .collect();
62 }
63
64 pub fn validators(&self) -> &[ValidatorInfo] {
65 &self.validators
66 }
67
68 pub fn total_power(&self) -> u64 {
69 self.total_power
70 }
71
72 pub fn quorum_threshold(&self) -> u64 {
74 (self.total_power * 2).div_ceil(3)
75 }
76
77 pub fn max_faulty_power(&self) -> u64 {
79 self.total_power - self.quorum_threshold()
80 }
81
82 pub fn leader_for_view(&self, view: ViewNumber) -> &ValidatorInfo {
84 let idx = (view.as_u64() as usize) % self.validators.len();
85 &self.validators[idx]
86 }
87
88 pub fn validator_count(&self) -> usize {
89 self.validators.len()
90 }
91
92 pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
94 self.index_map.get(&id).copied()
95 }
96
97 pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
99 self.index_map.get(&id).map(|&idx| &self.validators[idx])
100 }
101
102 pub fn power_of(&self, id: ValidatorId) -> u64 {
103 self.get(id).map_or(0, |v| v.power)
104 }
105
106 pub fn apply_updates(
110 &self,
111 updates: &[crate::validator_update::ValidatorUpdate],
112 ) -> ValidatorSet {
113 let mut infos: Vec<ValidatorInfo> = self.validators.clone();
114
115 for update in updates {
116 if update.power == 0 {
117 infos.retain(|v| v.id != update.id);
118 } else if let Some(existing) = infos.iter_mut().find(|v| v.id == update.id) {
119 existing.power = update.power;
120 existing.public_key = update.public_key.clone();
121 } else {
122 infos.push(ValidatorInfo {
123 id: update.id,
124 public_key: update.public_key.clone(),
125 power: update.power,
126 });
127 }
128 }
129
130 ValidatorSet::new(infos)
131 }
132}
133
134#[cfg(test)]
135mod tests {
136 use super::*;
137
138 fn make_vs(powers: &[u64]) -> ValidatorSet {
139 let validators: Vec<ValidatorInfo> = powers
140 .iter()
141 .enumerate()
142 .map(|(i, &p)| ValidatorInfo {
143 id: ValidatorId(i as u64),
144 public_key: PublicKey(vec![i as u8]),
145 power: p,
146 })
147 .collect();
148 ValidatorSet::new(validators)
149 }
150
151 #[test]
152 fn test_quorum_4_equal() {
153 let vs = make_vs(&[1, 1, 1, 1]);
154 assert_eq!(vs.total_power(), 4);
155 assert_eq!(vs.quorum_threshold(), 3);
156 assert_eq!(vs.max_faulty_power(), 1);
157 }
158
159 #[test]
160 fn test_quorum_3_equal() {
161 let vs = make_vs(&[1, 1, 1]);
162 assert_eq!(vs.quorum_threshold(), 2);
163 assert_eq!(vs.max_faulty_power(), 1);
164 }
165
166 #[test]
167 fn test_quorum_weighted() {
168 let vs = make_vs(&[10, 10, 10, 70]);
169 assert_eq!(vs.quorum_threshold(), 67);
170 assert_eq!(vs.max_faulty_power(), 33);
171 }
172
173 #[test]
174 fn test_quorum_single_validator() {
175 let vs = make_vs(&[1]);
176 assert_eq!(vs.quorum_threshold(), 1);
177 assert_eq!(vs.max_faulty_power(), 0);
178 }
179
180 #[test]
181 fn test_leader_rotation() {
182 let vs = make_vs(&[1, 1, 1, 1]);
183 assert_eq!(vs.leader_for_view(ViewNumber(0)).id, ValidatorId(0));
184 assert_eq!(vs.leader_for_view(ViewNumber(1)).id, ValidatorId(1));
185 assert_eq!(vs.leader_for_view(ViewNumber(4)).id, ValidatorId(0));
186 assert_eq!(vs.leader_for_view(ViewNumber(7)).id, ValidatorId(3));
187 }
188
189 #[test]
190 fn test_index_of_o1() {
191 let vs = make_vs(&[5, 10, 15]);
192 assert_eq!(vs.index_of(ValidatorId(0)), Some(0));
193 assert_eq!(vs.index_of(ValidatorId(1)), Some(1));
194 assert_eq!(vs.index_of(ValidatorId(2)), Some(2));
195 assert_eq!(vs.index_of(ValidatorId(99)), None);
196 }
197
198 #[test]
199 fn test_get_and_power_of() {
200 let vs = make_vs(&[5, 10, 15]);
201 assert_eq!(vs.get(ValidatorId(1)).unwrap().power, 10);
202 assert!(vs.get(ValidatorId(99)).is_none());
203 assert_eq!(vs.power_of(ValidatorId(2)), 15);
204 assert_eq!(vs.power_of(ValidatorId(99)), 0);
205 }
206
207 #[test]
208 fn test_apply_updates_add_validator() {
209 let vs = make_vs(&[1, 1, 1]);
210 let updates = vec![crate::validator_update::ValidatorUpdate {
211 id: ValidatorId(3),
212 public_key: PublicKey(vec![3]),
213 power: 2,
214 }];
215 let new_vs = vs.apply_updates(&updates);
216 assert_eq!(new_vs.validator_count(), 4);
217 assert_eq!(new_vs.power_of(ValidatorId(3)), 2);
218 assert_eq!(new_vs.total_power(), 5);
219 }
220
221 #[test]
222 fn test_apply_updates_remove_validator() {
223 let vs = make_vs(&[1, 1, 1, 1]);
224 let updates = vec![crate::validator_update::ValidatorUpdate {
225 id: ValidatorId(2),
226 public_key: PublicKey(vec![2]),
227 power: 0,
228 }];
229 let new_vs = vs.apply_updates(&updates);
230 assert_eq!(new_vs.validator_count(), 3);
231 assert!(new_vs.get(ValidatorId(2)).is_none());
232 assert_eq!(new_vs.total_power(), 3);
233 }
234
235 #[test]
236 fn test_apply_updates_change_power() {
237 let vs = make_vs(&[1, 1, 1, 1]);
238 let updates = vec![crate::validator_update::ValidatorUpdate {
239 id: ValidatorId(0),
240 public_key: PublicKey(vec![0]),
241 power: 10,
242 }];
243 let new_vs = vs.apply_updates(&updates);
244 assert_eq!(new_vs.validator_count(), 4);
245 assert_eq!(new_vs.power_of(ValidatorId(0)), 10);
246 assert_eq!(new_vs.total_power(), 13);
247 }
248
249 #[test]
250 fn test_serialization_roundtrip() {
251 let vs = make_vs(&[1, 2, 3]);
252 let bytes = serde_cbor_2::to_vec(&vs).unwrap();
253 let mut vs2: ValidatorSet = serde_cbor_2::from_slice(&bytes).unwrap();
254 vs2.rebuild_index();
255 assert_eq!(vs2.validator_count(), 3);
256 assert_eq!(vs2.index_of(ValidatorId(1)), Some(1));
257 assert_eq!(vs2.total_power(), 6);
258 }
259}