Skip to main content

hotmint_types/
validator.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::PublicKey;
6use crate::view::ViewNumber;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
9pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        write!(f, "V{}", self.0)
14    }
15}
16
17impl From<u64> for ValidatorId {
18    fn from(v: u64) -> Self {
19        Self(v)
20    }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ValidatorInfo {
25    pub id: ValidatorId,
26    pub public_key: PublicKey,
27    pub power: u64,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ValidatorSet {
32    validators: Vec<ValidatorInfo>,
33    total_power: u64,
34    /// O(1) lookup: ValidatorId -> index in validators vec
35    #[serde(skip)]
36    index_map: HashMap<ValidatorId, usize>,
37}
38
39impl ValidatorSet {
40    pub fn new(validators: Vec<ValidatorInfo>) -> Self {
41        let total_power = validators.iter().map(|v| v.power).sum();
42        let index_map = validators
43            .iter()
44            .enumerate()
45            .map(|(i, v)| (v.id, i))
46            .collect();
47        Self {
48            validators,
49            total_power,
50            index_map,
51        }
52    }
53
54    /// Rebuild the index map after deserialization
55    pub fn rebuild_index(&mut self) {
56        self.index_map = self
57            .validators
58            .iter()
59            .enumerate()
60            .map(|(i, v)| (v.id, i))
61            .collect();
62    }
63
64    pub fn validators(&self) -> &[ValidatorInfo] {
65        &self.validators
66    }
67
68    pub fn total_power(&self) -> u64 {
69        self.total_power
70    }
71
72    /// Quorum threshold: ceil(2n/3) where n = total_power
73    pub fn quorum_threshold(&self) -> u64 {
74        (self.total_power * 2).div_ceil(3)
75    }
76
77    /// Maximum faulty power: total_power - quorum_threshold
78    pub fn max_faulty_power(&self) -> u64 {
79        self.total_power - self.quorum_threshold()
80    }
81
82    /// Round-robin leader selection: v mod n
83    pub fn leader_for_view(&self, view: ViewNumber) -> &ValidatorInfo {
84        let idx = (view.as_u64() as usize) % self.validators.len();
85        &self.validators[idx]
86    }
87
88    pub fn validator_count(&self) -> usize {
89        self.validators.len()
90    }
91
92    /// O(1) index lookup
93    pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
94        self.index_map.get(&id).copied()
95    }
96
97    /// O(1) validator info lookup
98    pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
99        self.index_map.get(&id).map(|&idx| &self.validators[idx])
100    }
101
102    pub fn power_of(&self, id: ValidatorId) -> u64 {
103        self.get(id).map_or(0, |v| v.power)
104    }
105
106    /// Apply validator updates and return a new ValidatorSet.
107    /// - `power > 0`: update existing validator's power/key, or add new validator
108    /// - `power == 0`: remove validator
109    pub fn apply_updates(
110        &self,
111        updates: &[crate::validator_update::ValidatorUpdate],
112    ) -> ValidatorSet {
113        let mut infos: Vec<ValidatorInfo> = self.validators.clone();
114
115        for update in updates {
116            if update.power == 0 {
117                infos.retain(|v| v.id != update.id);
118            } else if let Some(existing) = infos.iter_mut().find(|v| v.id == update.id) {
119                existing.power = update.power;
120                existing.public_key = update.public_key.clone();
121            } else {
122                infos.push(ValidatorInfo {
123                    id: update.id,
124                    public_key: update.public_key.clone(),
125                    power: update.power,
126                });
127            }
128        }
129
130        ValidatorSet::new(infos)
131    }
132}
133
#[cfg(test)]
mod tests {
    use super::*;
    use crate::validator_update::ValidatorUpdate;

    /// Builds a set where validator `i` has id `i`, key bytes `[i]` and
    /// power `powers[i]`.
    fn make_vs(powers: &[u64]) -> ValidatorSet {
        let mut validators = Vec::with_capacity(powers.len());
        for (i, &power) in powers.iter().enumerate() {
            validators.push(ValidatorInfo {
                id: ValidatorId(i as u64),
                public_key: PublicKey(vec![i as u8]),
                power,
            });
        }
        ValidatorSet::new(validators)
    }

    /// Builds an update whose key bytes mirror its id, matching `make_vs`.
    fn update(id: u64, power: u64) -> ValidatorUpdate {
        ValidatorUpdate {
            id: ValidatorId(id),
            public_key: PublicKey(vec![id as u8]),
            power,
        }
    }

    #[test]
    fn test_quorum_4_equal() {
        let vs = make_vs(&[1, 1, 1, 1]);
        assert_eq!(vs.total_power(), 4);
        assert_eq!(vs.quorum_threshold(), 3);
        assert_eq!(vs.max_faulty_power(), 1);
    }

    #[test]
    fn test_quorum_3_equal() {
        let vs = make_vs(&[1, 1, 1]);
        assert_eq!(vs.quorum_threshold(), 2);
        assert_eq!(vs.max_faulty_power(), 1);
    }

    #[test]
    fn test_quorum_weighted() {
        let vs = make_vs(&[10, 10, 10, 70]);
        assert_eq!(vs.quorum_threshold(), 67);
        assert_eq!(vs.max_faulty_power(), 33);
    }

    #[test]
    fn test_quorum_single_validator() {
        let vs = make_vs(&[1]);
        assert_eq!(vs.quorum_threshold(), 1);
        assert_eq!(vs.max_faulty_power(), 0);
    }

    #[test]
    fn test_leader_rotation() {
        let vs = make_vs(&[1, 1, 1, 1]);
        // (view, expected leader id)
        let expectations = [(0, 0), (1, 1), (4, 0), (7, 3)];
        for (view, leader) in expectations {
            assert_eq!(vs.leader_for_view(ViewNumber(view)).id, ValidatorId(leader));
        }
    }

    #[test]
    fn test_index_of_o1() {
        let vs = make_vs(&[5, 10, 15]);
        for i in 0..3u64 {
            assert_eq!(vs.index_of(ValidatorId(i)), Some(i as usize));
        }
        assert_eq!(vs.index_of(ValidatorId(99)), None);
    }

    #[test]
    fn test_get_and_power_of() {
        let vs = make_vs(&[5, 10, 15]);
        assert_eq!(vs.get(ValidatorId(1)).unwrap().power, 10);
        assert!(vs.get(ValidatorId(99)).is_none());
        assert_eq!(vs.power_of(ValidatorId(2)), 15);
        assert_eq!(vs.power_of(ValidatorId(99)), 0);
    }

    #[test]
    fn test_apply_updates_add_validator() {
        let new_vs = make_vs(&[1, 1, 1]).apply_updates(&[update(3, 2)]);
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(3)), 2);
        assert_eq!(new_vs.total_power(), 5);
    }

    #[test]
    fn test_apply_updates_remove_validator() {
        let new_vs = make_vs(&[1, 1, 1, 1]).apply_updates(&[update(2, 0)]);
        assert_eq!(new_vs.validator_count(), 3);
        assert!(new_vs.get(ValidatorId(2)).is_none());
        assert_eq!(new_vs.total_power(), 3);
    }

    #[test]
    fn test_apply_updates_change_power() {
        let new_vs = make_vs(&[1, 1, 1, 1]).apply_updates(&[update(0, 10)]);
        assert_eq!(new_vs.validator_count(), 4);
        assert_eq!(new_vs.power_of(ValidatorId(0)), 10);
        assert_eq!(new_vs.total_power(), 13);
    }

    #[test]
    fn test_serialization_roundtrip() {
        let original = make_vs(&[1, 2, 3]);
        let encoded = serde_cbor_2::to_vec(&original).unwrap();
        let mut decoded: ValidatorSet = serde_cbor_2::from_slice(&encoded).unwrap();
        decoded.rebuild_index();
        assert_eq!(decoded.validator_count(), 3);
        assert_eq!(decoded.index_of(ValidatorId(1)), Some(1));
        assert_eq!(decoded.total_power(), 6);
    }
}