Skip to main content

hotmint_types/
validator.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::{PublicKey, Signer};
6use crate::view::ViewNumber;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
9pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        write!(f, "V{}", self.0)
14    }
15}
16
17impl From<u64> for ValidatorId {
18    fn from(v: u64) -> Self {
19        Self(v)
20    }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ValidatorInfo {
25    pub id: ValidatorId,
26    pub public_key: PublicKey,
27    pub power: u64,
28}
29
30#[derive(Debug, Clone, Serialize)]
31pub struct ValidatorSet {
32    validators: Vec<ValidatorInfo>,
33    total_power: u64,
34    /// O(1) lookup: ValidatorId -> index in validators vec
35    #[serde(skip)]
36    index_map: HashMap<ValidatorId, usize>,
37}
38
39impl<'de> Deserialize<'de> for ValidatorSet {
40    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
41    where
42        D: serde::Deserializer<'de>,
43    {
44        #[derive(Deserialize)]
45        struct Raw {
46            validators: Vec<ValidatorInfo>,
47            total_power: u64,
48        }
49        let raw = Raw::deserialize(deserializer)?;
50        let index_map = raw
51            .validators
52            .iter()
53            .enumerate()
54            .map(|(i, v)| (v.id, i))
55            .collect();
56        Ok(ValidatorSet {
57            validators: raw.validators,
58            total_power: raw.total_power,
59            index_map,
60        })
61    }
62}
63
64impl ValidatorSet {
65    pub fn new(validators: Vec<ValidatorInfo>) -> Self {
66        let total_power = validators.iter().map(|v| v.power).sum();
67        let index_map = validators
68            .iter()
69            .enumerate()
70            .map(|(i, v)| (v.id, i))
71            .collect();
72        Self {
73            validators,
74            total_power,
75            index_map,
76        }
77    }
78
79    /// Build a `ValidatorSet` from signers with equal power (1 each).
80    pub fn from_signers(signers: &[&dyn Signer]) -> Self {
81        let validators: Vec<ValidatorInfo> = signers
82            .iter()
83            .map(|s| ValidatorInfo {
84                id: s.validator_id(),
85                public_key: s.public_key(),
86                power: 1,
87            })
88            .collect();
89        Self::new(validators)
90    }
91
92    /// Rebuild the index map after deserialization.
93    ///
94    /// NOTE: This is now called automatically during deserialization.
95    /// You only need to call this manually if you modify the validators
96    /// list directly.
97    pub fn rebuild_index(&mut self) {
98        self.index_map = self
99            .validators
100            .iter()
101            .enumerate()
102            .map(|(i, v)| (v.id, i))
103            .collect();
104    }
105
106    pub fn validators(&self) -> &[ValidatorInfo] {
107        &self.validators
108    }
109
110    pub fn total_power(&self) -> u64 {
111        self.total_power
112    }
113
114    /// Quorum threshold: ceil(2n/3) where n = total_power
115    pub fn quorum_threshold(&self) -> u64 {
116        (self.total_power * 2).div_ceil(3)
117    }
118
119    /// Maximum faulty power: total_power - quorum_threshold
120    pub fn max_faulty_power(&self) -> u64 {
121        self.total_power - self.quorum_threshold()
122    }
123
124    /// Round-robin leader selection: v mod n
125    pub fn leader_for_view(&self, view: ViewNumber) -> &ValidatorInfo {
126        let idx = (view.as_u64() as usize) % self.validators.len();
127        &self.validators[idx]
128    }
129
130    pub fn validator_count(&self) -> usize {
131        self.validators.len()
132    }
133
134    /// O(1) index lookup
135    pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
136        self.index_map.get(&id).copied()
137    }
138
139    /// O(1) validator info lookup
140    pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
141        self.index_map.get(&id).map(|&idx| &self.validators[idx])
142    }
143
144    pub fn power_of(&self, id: ValidatorId) -> u64 {
145        self.get(id).map_or(0, |v| v.power)
146    }
147
148    /// Apply validator updates and return a new ValidatorSet.
149    /// - `power > 0`: update existing validator's power/key, or add new validator
150    /// - `power == 0`: remove validator
151    pub fn apply_updates(
152        &self,
153        updates: &[crate::validator_update::ValidatorUpdate],
154    ) -> ValidatorSet {
155        let mut infos: Vec<ValidatorInfo> = self.validators.clone();
156
157        for update in updates {
158            if update.power == 0 {
159                infos.retain(|v| v.id != update.id);
160            } else if let Some(existing) = infos.iter_mut().find(|v| v.id == update.id) {
161                existing.power = update.power;
162                existing.public_key = update.public_key.clone();
163            } else {
164                infos.push(ValidatorInfo {
165                    id: update.id,
166                    public_key: update.public_key.clone(),
167                    power: update.power,
168                });
169            }
170        }
171
172        ValidatorSet::new(infos)
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validator_update::ValidatorUpdate;

    /// Validator `i` gets id `i`, a one-byte public key `[i]` and power `powers[i]`.
    fn set_with_powers(powers: &[u64]) -> ValidatorSet {
        let mut infos = Vec::with_capacity(powers.len());
        for (i, &power) in powers.iter().enumerate() {
            infos.push(ValidatorInfo {
                id: ValidatorId(i as u64),
                public_key: PublicKey(vec![i as u8]),
                power,
            });
        }
        ValidatorSet::new(infos)
    }

    /// An update for validator `id`, using the same key scheme as `set_with_powers`.
    fn update(id: u64, power: u64) -> ValidatorUpdate {
        ValidatorUpdate {
            id: ValidatorId(id),
            public_key: PublicKey(vec![id as u8]),
            power,
        }
    }

    #[test]
    fn test_quorum_4_equal() {
        let set = set_with_powers(&[1, 1, 1, 1]);
        assert_eq!(set.total_power(), 4);
        assert_eq!((set.quorum_threshold(), set.max_faulty_power()), (3, 1));
    }

    #[test]
    fn test_quorum_3_equal() {
        let set = set_with_powers(&[1, 1, 1]);
        assert_eq!((set.quorum_threshold(), set.max_faulty_power()), (2, 1));
    }

    #[test]
    fn test_quorum_weighted() {
        let set = set_with_powers(&[10, 10, 10, 70]);
        assert_eq!((set.quorum_threshold(), set.max_faulty_power()), (67, 33));
    }

    #[test]
    fn test_quorum_single_validator() {
        let set = set_with_powers(&[1]);
        assert_eq!((set.quorum_threshold(), set.max_faulty_power()), (1, 0));
    }

    #[test]
    fn test_leader_rotation() {
        let set = set_with_powers(&[1, 1, 1, 1]);
        for (view, leader) in [(0, 0), (1, 1), (4, 0), (7, 3)] {
            assert_eq!(set.leader_for_view(ViewNumber(view)).id, ValidatorId(leader));
        }
    }

    #[test]
    fn test_index_of_o1() {
        let set = set_with_powers(&[5, 10, 15]);
        let cases = [(0, Some(0)), (1, Some(1)), (2, Some(2)), (99, None)];
        for (id, expected) in cases {
            assert_eq!(set.index_of(ValidatorId(id)), expected);
        }
    }

    #[test]
    fn test_get_and_power_of() {
        let set = set_with_powers(&[5, 10, 15]);
        let second = set.get(ValidatorId(1)).expect("validator 1 exists");
        assert_eq!(second.power, 10);
        assert!(set.get(ValidatorId(99)).is_none());
        assert_eq!(set.power_of(ValidatorId(2)), 15);
        assert_eq!(set.power_of(ValidatorId(99)), 0);
    }

    #[test]
    fn test_apply_updates_add_validator() {
        let next = set_with_powers(&[1, 1, 1]).apply_updates(&[update(3, 2)]);
        assert_eq!(next.validator_count(), 4);
        assert_eq!(next.power_of(ValidatorId(3)), 2);
        assert_eq!(next.total_power(), 5);
    }

    #[test]
    fn test_apply_updates_remove_validator() {
        let next = set_with_powers(&[1, 1, 1, 1]).apply_updates(&[update(2, 0)]);
        assert_eq!(next.validator_count(), 3);
        assert!(next.get(ValidatorId(2)).is_none());
        assert_eq!(next.total_power(), 3);
    }

    #[test]
    fn test_apply_updates_change_power() {
        let next = set_with_powers(&[1, 1, 1, 1]).apply_updates(&[update(0, 10)]);
        assert_eq!(next.validator_count(), 4);
        assert_eq!(next.power_of(ValidatorId(0)), 10);
        assert_eq!(next.total_power(), 13);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = set_with_powers(&[1, 2, 3]);
        let encoded = serde_cbor_2::to_vec(&original).unwrap();
        let decoded: ValidatorSet = serde_cbor_2::from_slice(&encoded).unwrap();
        // The skipped index map is rebuilt on deserialization, so lookups work at once.
        assert_eq!(decoded.validator_count(), 3);
        assert_eq!(decoded.index_of(ValidatorId(1)), Some(1));
        assert_eq!(decoded.total_power(), 6);
    }
}
301}