Skip to main content

hotmint_types/
validator.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::{PublicKey, Signer};
6use crate::view::ViewNumber;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
9pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        write!(f, "V{}", self.0)
14    }
15}
16
17impl From<u64> for ValidatorId {
18    fn from(v: u64) -> Self {
19        Self(v)
20    }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ValidatorInfo {
25    pub id: ValidatorId,
26    pub public_key: PublicKey,
27    pub power: u64,
28}
29
30#[derive(Debug, Clone, Serialize)]
31pub struct ValidatorSet {
32    validators: Vec<ValidatorInfo>,
33    total_power: u64,
34    /// O(1) lookup: ValidatorId -> index in validators vec
35    #[serde(skip)]
36    index_map: HashMap<ValidatorId, usize>,
37}
38
39impl<'de> Deserialize<'de> for ValidatorSet {
40    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
41    where
42        D: serde::Deserializer<'de>,
43    {
44        #[derive(Deserialize)]
45        struct Raw {
46            validators: Vec<ValidatorInfo>,
47            total_power: u64,
48        }
49        let raw = Raw::deserialize(deserializer)?;
50        let index_map = raw
51            .validators
52            .iter()
53            .enumerate()
54            .map(|(i, v)| (v.id, i))
55            .collect();
56        Ok(ValidatorSet {
57            validators: raw.validators,
58            total_power: raw.total_power,
59            index_map,
60        })
61    }
62}
63
64impl ValidatorSet {
65    pub fn new(validators: Vec<ValidatorInfo>) -> Self {
66        let total_power = validators.iter().map(|v| v.power).sum();
67        let index_map = validators
68            .iter()
69            .enumerate()
70            .map(|(i, v)| (v.id, i))
71            .collect();
72        Self {
73            validators,
74            total_power,
75            index_map,
76        }
77    }
78
79    /// Build a `ValidatorSet` from signers with equal power (1 each).
80    pub fn from_signers(signers: &[&dyn Signer]) -> Self {
81        let validators: Vec<ValidatorInfo> = signers
82            .iter()
83            .map(|s| ValidatorInfo {
84                id: s.validator_id(),
85                public_key: s.public_key(),
86                power: 1,
87            })
88            .collect();
89        Self::new(validators)
90    }
91
92    /// Rebuild the index map after deserialization.
93    ///
94    /// NOTE: This is now called automatically during deserialization.
95    /// You only need to call this manually if you modify the validators
96    /// list directly.
97    pub fn rebuild_index(&mut self) {
98        self.index_map = self
99            .validators
100            .iter()
101            .enumerate()
102            .map(|(i, v)| (v.id, i))
103            .collect();
104    }
105
106    pub fn validators(&self) -> &[ValidatorInfo] {
107        &self.validators
108    }
109
110    pub fn total_power(&self) -> u64 {
111        self.total_power
112    }
113
114    /// Quorum threshold: ceil(2n/3) where n = total_power
115    pub fn quorum_threshold(&self) -> u64 {
116        self.total_power
117            .checked_mul(2)
118            .expect("total_power overflow in quorum_threshold")
119            .div_ceil(3)
120    }
121
122    /// Maximum faulty power: total_power - quorum_threshold
123    pub fn max_faulty_power(&self) -> u64 {
124        self.total_power - self.quorum_threshold()
125    }
126
127    /// Round-robin leader selection: v mod n.
128    /// Returns `None` if the validator set is empty.
129    pub fn leader_for_view(&self, view: ViewNumber) -> Option<&ValidatorInfo> {
130        if self.validators.is_empty() {
131            return None;
132        }
133        let idx = (view.as_u64() as usize) % self.validators.len();
134        Some(&self.validators[idx])
135    }
136
137    pub fn validator_count(&self) -> usize {
138        self.validators.len()
139    }
140
141    /// O(1) index lookup
142    pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
143        self.index_map.get(&id).copied()
144    }
145
146    /// O(1) validator info lookup
147    pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
148        self.index_map.get(&id).map(|&idx| &self.validators[idx])
149    }
150
151    pub fn power_of(&self, id: ValidatorId) -> u64 {
152        self.get(id).map_or(0, |v| v.power)
153    }
154
155    /// Apply validator updates and return a new ValidatorSet.
156    /// - `power > 0`: update existing validator's power/key, or add new validator
157    /// - `power == 0`: remove validator
158    pub fn apply_updates(
159        &self,
160        updates: &[crate::validator_update::ValidatorUpdate],
161    ) -> ValidatorSet {
162        let mut infos: Vec<ValidatorInfo> = self.validators.clone();
163
164        for update in updates {
165            if update.power == 0 {
166                infos.retain(|v| v.id != update.id);
167            } else if let Some(existing) = infos.iter_mut().find(|v| v.id == update.id) {
168                existing.power = update.power;
169                existing.public_key = update.public_key.clone();
170            } else {
171                infos.push(ValidatorInfo {
172                    id: update.id,
173                    public_key: update.public_key.clone(),
174                    power: update.power,
175                });
176            }
177        }
178
179        ValidatorSet::new(infos)
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a set whose i-th validator has id `i`, key `[i]`, and power
    /// `powers[i]`.
    fn make_vs(powers: &[u64]) -> ValidatorSet {
        let validators: Vec<ValidatorInfo> = powers
            .iter()
            .enumerate()
            .map(|(i, &p)| ValidatorInfo {
                id: ValidatorId(i as u64),
                public_key: PublicKey(vec![i as u8]),
                power: p,
            })
            .collect();
        ValidatorSet::new(validators)
    }

    /// Shorthand for a `ValidatorUpdate` with a one-byte public key.
    fn update(id: u64, key: u8, power: u64) -> crate::validator_update::ValidatorUpdate {
        crate::validator_update::ValidatorUpdate {
            id: ValidatorId(id),
            public_key: PublicKey(vec![key]),
            power,
        }
    }

    #[test]
    fn test_quorum_4_equal() {
        let vs = make_vs(&[1, 1, 1, 1]);
        assert_eq!(vs.total_power(), 4);
        assert_eq!(vs.quorum_threshold(), 3);
        assert_eq!(vs.max_faulty_power(), 1);
    }

    #[test]
    fn test_quorum_3_equal() {
        let vs = make_vs(&[1, 1, 1]);
        assert_eq!(vs.quorum_threshold(), 2);
        assert_eq!(vs.max_faulty_power(), 1);
    }

    #[test]
    fn test_quorum_weighted() {
        let vs = make_vs(&[10, 10, 10, 70]);
        assert_eq!(vs.quorum_threshold(), 67);
        assert_eq!(vs.max_faulty_power(), 33);
    }

    #[test]
    fn test_quorum_single_validator() {
        let vs = make_vs(&[1]);
        assert_eq!(vs.quorum_threshold(), 1);
        assert_eq!(vs.max_faulty_power(), 0);
    }

    #[test]
    fn test_empty_set() {
        let vs = make_vs(&[]);
        assert_eq!(vs.validator_count(), 0);
        assert_eq!(vs.total_power(), 0);
        assert_eq!(vs.quorum_threshold(), 0);
        assert_eq!(vs.max_faulty_power(), 0);
        assert!(vs.leader_for_view(ViewNumber(0)).is_none());
    }

    #[test]
    fn test_leader_rotation() {
        let vs = make_vs(&[1, 1, 1, 1]);
        assert_eq!(
            vs.leader_for_view(ViewNumber(0)).unwrap().id,
            ValidatorId(0)
        );
        assert_eq!(
            vs.leader_for_view(ViewNumber(1)).unwrap().id,
            ValidatorId(1)
        );
        assert_eq!(
            vs.leader_for_view(ViewNumber(4)).unwrap().id,
            ValidatorId(0)
        );
        assert_eq!(
            vs.leader_for_view(ViewNumber(7)).unwrap().id,
            ValidatorId(3)
        );
    }

    #[test]
    fn test_leader_rotation_large_views() {
        let vs = make_vs(&[1, 1, 1]);
        // (2^32 + 5) mod 3 == 0: the schedule must be computed on the full
        // u64 view so every platform agrees on the leader.
        assert_eq!(
            vs.leader_for_view(ViewNumber((1u64 << 32) + 5)).unwrap().id,
            ValidatorId(0)
        );
        // (2^64 - 1) mod 3 == 0.
        assert_eq!(
            vs.leader_for_view(ViewNumber(u64::MAX)).unwrap().id,
            ValidatorId(0)
        );
    }

    #[test]
    fn test_index_of_o1() {
        let vs = make_vs(&[5, 10, 15]);
        assert_eq!(vs.index_of(ValidatorId(0)), Some(0));
        assert_eq!(vs.index_of(ValidatorId(1)), Some(1));
        assert_eq!(vs.index_of(ValidatorId(2)), Some(2));
        assert_eq!(vs.index_of(ValidatorId(99)), None);
    }

    #[test]
    fn test_get_and_power_of() {
        let vs = make_vs(&[5, 10, 15]);
        assert_eq!(vs.get(ValidatorId(1)).unwrap().power, 10);
        assert!(vs.get(ValidatorId(99)).is_none());
        assert_eq!(vs.power_of(ValidatorId(2)), 15);
        assert_eq!(vs.power_of(ValidatorId(99)), 0);
    }

    #[test]
    fn test_apply_updates_add_validator() {
        let vs = make_vs(&[1, 1, 1]);
        let updates = vec![update(3, 3, 2)];
        let new_vs = vs.apply_updates(&updates);
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(3)), 2);
        assert_eq!(new_vs.total_power(), 5);
    }

    #[test]
    fn test_apply_updates_remove_validator() {
        let vs = make_vs(&[1, 1, 1, 1]);
        let updates = vec![update(2, 2, 0)];
        let new_vs = vs.apply_updates(&updates);
        assert_eq!(new_vs.validator_count(), 3);
        assert!(new_vs.get(ValidatorId(2)).is_none());
        assert_eq!(new_vs.total_power(), 3);
    }

    #[test]
    fn test_apply_updates_remove_unknown_is_noop() {
        let vs = make_vs(&[1, 1, 1]);
        let new_vs = vs.apply_updates(&[update(99, 99, 0)]);
        assert_eq!(new_vs.validator_count(), 3);
        assert_eq!(new_vs.total_power(), 3);
    }

    #[test]
    fn test_apply_updates_change_power() {
        let vs = make_vs(&[1, 1, 1, 1]);
        let updates = vec![update(0, 0, 10)];
        let new_vs = vs.apply_updates(&updates);
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(0)), 10);
        assert_eq!(new_vs.total_power(), 13);
    }

    #[test]
    fn test_apply_updates_replaces_public_key_in_place() {
        let vs = make_vs(&[1, 1, 1]);
        let new_vs = vs.apply_updates(&[update(1, 42, 5)]);
        let info = new_vs.get(ValidatorId(1)).unwrap();
        assert_eq!(info.public_key.0, vec![42]);
        assert_eq!(info.power, 5);
        // Position is kept, so the leader schedule is unchanged.
        assert_eq!(new_vs.index_of(ValidatorId(1)), Some(1));
        // The source set is not modified.
        assert_eq!(vs.get(ValidatorId(1)).unwrap().public_key.0, vec![1]);
    }

    #[test]
    fn test_apply_updates_sequential_same_id() {
        let vs = make_vs(&[1, 1, 1]);
        // Removing and re-adding an id moves it to the end of the set order.
        let new_vs = vs.apply_updates(&[update(0, 0, 0), update(0, 0, 4)]);
        assert_eq!(new_vs.validator_count(), 3);
        assert_eq!(new_vs.index_of(ValidatorId(0)), Some(2));
        assert_eq!(new_vs.power_of(ValidatorId(0)), 4);
        assert_eq!(new_vs.total_power(), 6);
        assert_eq!(
            new_vs.leader_for_view(ViewNumber(0)).unwrap().id,
            ValidatorId(1)
        );
    }

    #[test]
    fn test_serialization_roundtrip() {
        let vs = make_vs(&[1, 2, 3]);
        let bytes = serde_cbor_2::to_vec(&vs).unwrap();
        let vs2: ValidatorSet = serde_cbor_2::from_slice(&bytes).unwrap();
        // index_map is auto-rebuilt during deserialization
        assert_eq!(vs2.validator_count(), 3);
        assert_eq!(vs2.index_of(ValidatorId(1)), Some(1));
        assert_eq!(vs2.index_of(ValidatorId(2)), Some(2));
        assert_eq!(vs2.get(ValidatorId(2)).unwrap().power, 3);
        assert_eq!(vs2.total_power(), 6);
    }
}