Skip to main content

hotmint_types/
validator.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::fmt;
4
5use crate::crypto::PublicKey;
6use crate::view::ViewNumber;
7
8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
9pub struct ValidatorId(pub u64);
10
11impl fmt::Display for ValidatorId {
12    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
13        write!(f, "V{}", self.0)
14    }
15}
16
17impl From<u64> for ValidatorId {
18    fn from(v: u64) -> Self {
19        Self(v)
20    }
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ValidatorInfo {
25    pub id: ValidatorId,
26    pub public_key: PublicKey,
27    pub power: u64,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct ValidatorSet {
32    validators: Vec<ValidatorInfo>,
33    total_power: u64,
34    /// O(1) lookup: ValidatorId -> index in validators vec
35    #[serde(skip)]
36    index_map: HashMap<ValidatorId, usize>,
37}
38
39impl ValidatorSet {
40    pub fn new(validators: Vec<ValidatorInfo>) -> Self {
41        let total_power = validators.iter().map(|v| v.power).sum();
42        let index_map = validators
43            .iter()
44            .enumerate()
45            .map(|(i, v)| (v.id, i))
46            .collect();
47        Self {
48            validators,
49            total_power,
50            index_map,
51        }
52    }
53
54    /// Rebuild the index map after deserialization
55    pub fn rebuild_index(&mut self) {
56        self.index_map = self
57            .validators
58            .iter()
59            .enumerate()
60            .map(|(i, v)| (v.id, i))
61            .collect();
62    }
63
64    pub fn validators(&self) -> &[ValidatorInfo] {
65        &self.validators
66    }
67
68    pub fn total_power(&self) -> u64 {
69        self.total_power
70    }
71
72    /// Quorum threshold: ceil(2n/3) where n = total_power
73    pub fn quorum_threshold(&self) -> u64 {
74        (self.total_power * 2).div_ceil(3)
75    }
76
77    /// Maximum faulty power: total_power - quorum_threshold
78    pub fn max_faulty_power(&self) -> u64 {
79        self.total_power - self.quorum_threshold()
80    }
81
82    /// Round-robin leader selection: v mod n
83    pub fn leader_for_view(&self, view: ViewNumber) -> &ValidatorInfo {
84        let idx = (view.as_u64() as usize) % self.validators.len();
85        &self.validators[idx]
86    }
87
88    pub fn validator_count(&self) -> usize {
89        self.validators.len()
90    }
91
92    /// O(1) index lookup
93    pub fn index_of(&self, id: ValidatorId) -> Option<usize> {
94        self.index_map.get(&id).copied()
95    }
96
97    /// O(1) validator info lookup
98    pub fn get(&self, id: ValidatorId) -> Option<&ValidatorInfo> {
99        self.index_map.get(&id).map(|&idx| &self.validators[idx])
100    }
101
102    pub fn power_of(&self, id: ValidatorId) -> u64 {
103        self.get(id).map_or(0, |v| v.power)
104    }
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a set whose i-th validator has id `i`, the one-byte key `[i]`,
    /// and power `powers[i]`.
    fn make_vs(powers: &[u64]) -> ValidatorSet {
        let mut members = Vec::with_capacity(powers.len());
        for (position, power) in powers.iter().copied().enumerate() {
            members.push(ValidatorInfo {
                id: ValidatorId(position as u64),
                public_key: PublicKey(vec![position as u8]),
                power,
            });
        }
        ValidatorSet::new(members)
    }

    /// Four equal validators: quorum is three, one may be faulty.
    #[test]
    fn test_quorum_4_equal() {
        let set = make_vs(&[1, 1, 1, 1]);
        assert_eq!(set.total_power(), 4);
        assert_eq!(set.quorum_threshold(), 3);
        assert_eq!(set.max_faulty_power(), 1);
    }

    /// Three equal validators: quorum is two.
    #[test]
    fn test_quorum_3_equal() {
        let set = make_vs(&[1, 1, 1]);
        assert_eq!(set.quorum_threshold(), 2);
        assert_eq!(set.max_faulty_power(), 1);
    }

    /// Thresholds follow total power, not validator count.
    #[test]
    fn test_quorum_weighted() {
        let set = make_vs(&[10, 10, 10, 70]);
        assert_eq!(set.quorum_threshold(), 67);
        assert_eq!(set.max_faulty_power(), 33);
    }

    /// A lone validator is its own quorum.
    #[test]
    fn test_quorum_single_validator() {
        let set = make_vs(&[1]);
        assert_eq!(set.quorum_threshold(), 1);
        assert_eq!(set.max_faulty_power(), 0);
    }

    /// Leaders cycle through the validators in stored order.
    #[test]
    fn test_leader_rotation() {
        let set = make_vs(&[1, 1, 1, 1]);
        let expectations: &[(u64, u64)] = &[(0, 0), (1, 1), (4, 0), (7, 3)];
        for &(view, leader) in expectations {
            assert_eq!(set.leader_for_view(ViewNumber(view)).id, ValidatorId(leader));
        }
    }

    /// Known ids map to their position; unknown ids map to `None`.
    #[test]
    fn test_index_of_o1() {
        let set = make_vs(&[5, 10, 15]);
        let expectations: &[(u64, Option<usize>)] =
            &[(0, Some(0)), (1, Some(1)), (2, Some(2)), (99, None)];
        for &(raw_id, position) in expectations {
            assert_eq!(set.index_of(ValidatorId(raw_id)), position);
        }
    }

    /// `get` and `power_of` agree with the constructed powers and handle
    /// unknown ids.
    #[test]
    fn test_get_and_power_of() {
        let set = make_vs(&[5, 10, 15]);

        let second = set.get(ValidatorId(1)).unwrap();
        assert_eq!(second.power, 10);
        assert!(set.get(ValidatorId(99)).is_none());

        assert_eq!(set.power_of(ValidatorId(2)), 15);
        assert_eq!(set.power_of(ValidatorId(99)), 0);
    }

    /// A set survives a MessagePack round trip once its index is rebuilt.
    #[test]
    fn test_serialization_roundtrip() {
        let original = make_vs(&[1, 2, 3]);
        let encoded = rmp_serde::to_vec(&original).unwrap();

        let mut decoded: ValidatorSet = rmp_serde::from_slice(&encoded).unwrap();
        decoded.rebuild_index();

        assert_eq!(decoded.validator_count(), 3);
        assert_eq!(decoded.index_of(ValidatorId(1)), Some(1));
        assert_eq!(decoded.total_power(), 6);
    }
}