Skip to main content

hotmint_types/
validator.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::{PublicKey, Signer};
6use crate::view::ViewNumber;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
9pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        write!(f, "V{}", self.0)
14    }
15}
16
17impl From<u64> for ValidatorId {
18    fn from(v: u64) -> Self {
19        Self(v)
20    }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ValidatorInfo {
25    pub id: ValidatorId,
26    pub public_key: PublicKey,
27    pub power: u64,
28}
29
30#[derive(Debug, Clone, Serialize)]
31pub struct ValidatorSet {
32    validators: Vec<ValidatorInfo>,
33    total_power: u64,
34    /// O(1) lookup: ValidatorId -> index in validators vec
35    #[serde(skip)]
36    index_map: HashMap<ValidatorId, usize>,
37}
38
39impl<'de> Deserialize<'de> for ValidatorSet {
40    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
41    where
42        D: serde::Deserializer<'de>,
43    {
44        #[derive(Deserialize)]
45        struct Raw {
46            validators: Vec<ValidatorInfo>,
47            total_power: u64,
48        }
49        let raw = Raw::deserialize(deserializer)?;
50        let index_map = raw
51            .validators
52            .iter()
53            .enumerate()
54            .map(|(i, v)| (v.id, i))
55            .collect();
56        Ok(ValidatorSet {
57            validators: raw.validators,
58            total_power: raw.total_power,
59            index_map,
60        })
61    }
62}
63
64impl ValidatorSet {
65    pub fn new(validators: Vec<ValidatorInfo>) -> Self {
66        let total_power = validators.iter().map(|v| v.power).sum();
67        let index_map = validators
68            .iter()
69            .enumerate()
70            .map(|(i, v)| (v.id, i))
71            .collect();
72        Self {
73            validators,
74            total_power,
75            index_map,
76        }
77    }
78
79    /// Build a `ValidatorSet` from signers with equal power (1 each).
80    pub fn from_signers(signers: &[&dyn Signer]) -> Self {
81        let validators: Vec<ValidatorInfo> = signers
82            .iter()
83            .map(|s| ValidatorInfo {
84                id: s.validator_id(),
85                public_key: s.public_key(),
86                power: 1,
87            })
88            .collect();
89        Self::new(validators)
90    }
91
92    /// Rebuild the index map after deserialization.
93    ///
94    /// NOTE: This is now called automatically during deserialization.
95    /// You only need to call this manually if you modify the validators
96    /// list directly.
97    pub fn rebuild_index(&mut self) {
98        self.index_map = self
99            .validators
100            .iter()
101            .enumerate()
102            .map(|(i, v)| (v.id, i))
103            .collect();
104    }
105
106    pub fn validators(&self) -> &[ValidatorInfo] {
107        &self.validators
108    }
109
110    pub fn total_power(&self) -> u64 {
111        self.total_power
112    }
113
114    /// Quorum threshold: ceil(2n/3) where n = total_power
115    pub fn quorum_threshold(&self) -> u64 {
116        self.total_power
117            .checked_mul(2)
118            .expect("total_power overflow in quorum_threshold")
119            .div_ceil(3)
120    }
121
122    /// Maximum faulty power: total_power - quorum_threshold
123    pub fn max_faulty_power(&self) -> u64 {
124        self.total_power - self.quorum_threshold()
125    }
126
127    /// Power-weighted leader selection.
128    ///
129    /// Maps `view` into the range `[0, total_power)` and picks the validator
130    /// whose cumulative power range contains that slot.  A validator with
131    /// twice the power of another will be selected roughly twice as often.
132    ///
133    /// Returns `None` if the validator set is empty.
134    pub fn leader_for_view(&self, view: ViewNumber) -> Option<&ValidatorInfo> {
135        if self.validators.is_empty() || self.total_power == 0 {
136            return None;
137        }
138        let slot = view.as_u64() % self.total_power;
139        let mut cumulative = 0u64;
140        for vi in &self.validators {
141            cumulative += vi.power;
142            if slot < cumulative {
143                return Some(vi);
144            }
145        }
146        // Fallback (should be unreachable with valid total_power)
147        self.validators.last()
148    }
149
150    pub fn validator_count(&self) -> usize {
151        self.validators.len()
152    }
153
154    /// O(1) index lookup
155    pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
156        self.index_map.get(&id).copied()
157    }
158
159    /// O(1) validator info lookup
160    pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
161        self.index_map.get(&id).map(|&idx| &self.validators[idx])
162    }
163
164    pub fn power_of(&self, id: ValidatorId) -> u64 {
165        self.get(id).map_or(0, |v| v.power)
166    }
167
168    /// Apply validator updates and return a new ValidatorSet.
169    /// - `power > 0`: update existing validator's power/key, or add new validator
170    /// - `power == 0`: remove validator
171    pub fn apply_updates(
172        &self,
173        updates: &[crate::validator_update::ValidatorUpdate],
174    ) -> ValidatorSet {
175        let mut infos: Vec<ValidatorInfo> = self.validators.clone();
176
177        for update in updates {
178            if update.power == 0 {
179                infos.retain(|v| v.id != update.id);
180            } else if let Some(existing) = infos.iter_mut().find(|v| v.id == update.id) {
181                existing.power = update.power;
182                existing.public_key = update.public_key.clone();
183            } else {
184                infos.push(ValidatorInfo {
185                    id: update.id,
186                    public_key: update.public_key.clone(),
187                    power: update.power,
188                });
189            }
190        }
191
192        ValidatorSet::new(infos)
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validator_update::ValidatorUpdate;

    /// Validator `i` gets id `i`, the one-byte key `[i]`, and power `powers[i]`.
    fn make_vs(powers: &[u64]) -> ValidatorSet {
        let mut validators = Vec::with_capacity(powers.len());
        for (i, &power) in powers.iter().enumerate() {
            validators.push(ValidatorInfo {
                id: ValidatorId(i as u64),
                public_key: PublicKey(vec![i as u8]),
                power,
            });
        }
        ValidatorSet::new(validators)
    }

    /// A single update for validator `id`, whose key is the byte `[id]`.
    fn update(id: u64, power: u64) -> ValidatorUpdate {
        ValidatorUpdate {
            id: ValidatorId(id),
            public_key: PublicKey(vec![id as u8]),
            power,
        }
    }

    /// Assert the quorum threshold and faulty power for a power layout.
    fn check_quorum(powers: &[u64], quorum: u64, faulty: u64) -> ValidatorSet {
        let vs = make_vs(powers);
        assert_eq!(vs.quorum_threshold(), quorum, "quorum for {powers:?}");
        assert_eq!(vs.max_faulty_power(), faulty, "faulty power for {powers:?}");
        vs
    }

    /// Assert that every `(view, expected leader id)` pair holds for `vs`.
    fn check_leaders(vs: &ValidatorSet, expected: Vec<(ViewNumber, u64)>) {
        for (case, (view, leader)) in expected.into_iter().enumerate() {
            let chosen = vs
                .leader_for_view(view)
                .expect("non-empty set always has a leader");
            assert_eq!(chosen.id, ValidatorId(leader), "case #{case}");
        }
    }

    #[test]
    fn test_quorum_4_equal() {
        let vs = check_quorum(&[1, 1, 1, 1], 3, 1);
        assert_eq!(vs.total_power(), 4);
    }

    #[test]
    fn test_quorum_3_equal() {
        check_quorum(&[1, 1, 1], 2, 1);
    }

    #[test]
    fn test_quorum_weighted() {
        check_quorum(&[10, 10, 10, 70], 67, 33);
    }

    #[test]
    fn test_quorum_single_validator() {
        check_quorum(&[1], 1, 0);
    }

    #[test]
    fn test_leader_rotation() {
        let vs = make_vs(&[1, 1, 1, 1]);
        check_leaders(
            &vs,
            vec![
                (ViewNumber(0), 0),
                (ViewNumber(1), 1),
                (ViewNumber(4), 0),
                (ViewNumber(7), 3),
            ],
        );
    }

    #[test]
    fn test_index_of_o1() {
        let vs = make_vs(&[5, 10, 15]);
        for i in 0..3u64 {
            assert_eq!(vs.index_of(ValidatorId(i)), Some(i as usize));
        }
        assert_eq!(vs.index_of(ValidatorId(99)), None);
    }

    #[test]
    fn test_get_and_power_of() {
        let vs = make_vs(&[5, 10, 15]);
        let second = vs.get(ValidatorId(1)).expect("V1 is present");
        assert_eq!(second.power, 10);
        assert!(vs.get(ValidatorId(99)).is_none());
        assert_eq!(vs.power_of(ValidatorId(2)), 15);
        assert_eq!(vs.power_of(ValidatorId(99)), 0);
    }

    #[test]
    fn test_apply_updates_add_validator() {
        let new_vs = make_vs(&[1, 1, 1]).apply_updates(&[update(3, 2)]);
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(3)), 2);
        assert_eq!(new_vs.total_power(), 5);
    }

    #[test]
    fn test_apply_updates_remove_validator() {
        let new_vs = make_vs(&[1, 1, 1, 1]).apply_updates(&[update(2, 0)]);
        assert_eq!(new_vs.validator_count(), 3);
        assert!(new_vs.get(ValidatorId(2)).is_none());
        assert_eq!(new_vs.total_power(), 3);
    }

    #[test]
    fn test_apply_updates_change_power() {
        let new_vs = make_vs(&[1, 1, 1, 1]).apply_updates(&[update(0, 10)]);
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(0)), 10);
        assert_eq!(new_vs.total_power(), 13);
    }

    #[test]
    fn test_leader_weighted() {
        // Powers: V0=10, V1=10, V2=10, V3=70 → total=100, so V3 leads ~70% of views.
        // Slots 0..9 → V0, 10..19 → V1, 20..29 → V2, 30..99 → V3;
        // view 100 wraps back to slot 0 → V0.
        let vs = make_vs(&[10, 10, 10, 70]);
        check_leaders(
            &vs,
            vec![
                (ViewNumber(0), 0),
                (ViewNumber(9), 0),
                (ViewNumber(10), 1),
                (ViewNumber(20), 2),
                (ViewNumber(30), 3),
                (ViewNumber(99), 3),
                (ViewNumber(100), 0),
            ],
        );
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = make_vs(&[1, 2, 3]);
        let encoded = postcard::to_allocvec(&original).unwrap();
        let decoded: ValidatorSet = postcard::from_bytes(&encoded).unwrap();
        // The skipped index_map must have been rebuilt during deserialization.
        assert_eq!(decoded.validator_count(), 3);
        assert_eq!(decoded.index_of(ValidatorId(1)), Some(1));
        assert_eq!(decoded.total_power(), 6);
    }
}