// aleph_bft_crypto/node.rs

use codec::{Decode, Encode, Error, Input, Output};
use derive_more::{Add, AddAssign, From, Into, Sub, SubAssign, Sum};
use std::{
    collections::HashMap,
    convert::TryFrom,
    fmt,
    hash::Hash,
    ops::{Div, Index as StdIndex, Mul},
    vec,
};
10
11/// The index of a node
12#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, From, Into)]
13pub struct NodeIndex(pub usize);
14
15impl Encode for NodeIndex {
16    fn encode_to<T: Output + ?Sized>(&self, dest: &mut T) {
17        let val = self.0 as u64;
18        let bytes = val.to_le_bytes();
19        dest.write(&bytes);
20    }
21}
22
23impl Decode for NodeIndex {
24    fn decode<I: Input>(value: &mut I) -> Result<Self, Error> {
25        let mut arr = [0u8; 8];
26        value.read(&mut arr)?;
27        let val: u64 = u64::from_le_bytes(arr);
28        Ok(NodeIndex(val as usize))
29    }
30}
31
32/// Indicates that an implementor has been assigned some index.
33pub trait Index {
34    fn index(&self) -> NodeIndex;
35}
36
37/// Node count. Right now it doubles as node weight in many places in the code, in the future we
38/// might need a new type for that.
39#[derive(
40    Copy,
41    Clone,
42    Eq,
43    PartialEq,
44    Ord,
45    PartialOrd,
46    Hash,
47    Debug,
48    Default,
49    Add,
50    AddAssign,
51    From,
52    Into,
53    Sub,
54    SubAssign,
55    Sum,
56)]
57pub struct NodeCount(pub usize);
58
59// deriving Mul and Div is somehow cumbersome
60impl Mul<usize> for NodeCount {
61    type Output = Self;
62    fn mul(self, rhs: usize) -> Self::Output {
63        NodeCount(self.0 * rhs)
64    }
65}
66
67impl Div<usize> for NodeCount {
68    type Output = Self;
69    fn div(self, rhs: usize) -> Self::Output {
70        NodeCount(self.0 / rhs)
71    }
72}
73
74impl NodeCount {
75    pub fn into_range(self) -> core::ops::Range<NodeIndex> {
76        core::ops::Range {
77            start: 0.into(),
78            end: self.0.into(),
79        }
80    }
81
82    pub fn into_iterator(self) -> impl Iterator<Item = NodeIndex> {
83        (0..self.0).map(NodeIndex)
84    }
85
86    /// If this is the total node count, what number of nodes is required for secure consensus.
87    pub fn consensus_threshold(&self) -> NodeCount {
88        (*self * 2) / 3 + NodeCount(1)
89    }
90}
91
92/// A container keeping items indexed by NodeIndex.
93#[derive(Clone, Eq, PartialEq, Hash, Debug, Default, Decode, Encode, From)]
94pub struct NodeMap<T>(Vec<Option<T>>);
95
96impl<T> NodeMap<T> {
97    /// Constructs a new node map with a given length.
98    pub fn with_size(len: NodeCount) -> Self
99    where
100        T: Clone,
101    {
102        let v = vec![None; len.into()];
103        NodeMap(v)
104    }
105
106    pub fn from_hashmap(len: NodeCount, hashmap: HashMap<NodeIndex, T>) -> Self
107    where
108        T: Clone,
109    {
110        let v = vec![None; len.into()];
111        let mut nm = NodeMap(v);
112        for (id, item) in hashmap.into_iter() {
113            nm.insert(id, item);
114        }
115        nm
116    }
117
118    pub fn size(&self) -> NodeCount {
119        self.0.len().into()
120    }
121
122    pub fn iter(&self) -> impl Iterator<Item = (NodeIndex, &T)> {
123        self.0
124            .iter()
125            .enumerate()
126            .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value.as_ref()?)))
127    }
128
129    pub fn iter_mut(&mut self) -> impl Iterator<Item = (NodeIndex, &mut T)> {
130        self.0
131            .iter_mut()
132            .enumerate()
133            .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value.as_mut()?)))
134    }
135
136    fn into_iter(self) -> impl Iterator<Item = (NodeIndex, T)>
137    where
138        T: 'static,
139    {
140        self.0
141            .into_iter()
142            .enumerate()
143            .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value?)))
144    }
145
146    pub fn values(&self) -> impl Iterator<Item = &T> {
147        self.iter().map(|(_, value)| value)
148    }
149
150    pub fn into_values(self) -> impl Iterator<Item = T>
151    where
152        T: 'static,
153    {
154        self.into_iter().map(|(_, value)| value)
155    }
156
157    pub fn get(&self, node_id: NodeIndex) -> Option<&T> {
158        self.0[node_id.0].as_ref()
159    }
160
161    pub fn get_mut(&mut self, node_id: NodeIndex) -> Option<&mut T> {
162        self.0[node_id.0].as_mut()
163    }
164
165    pub fn insert(&mut self, node_id: NodeIndex, value: T) {
166        self.0[node_id.0] = Some(value)
167    }
168
169    pub fn to_subset(&self) -> NodeSubset {
170        NodeSubset(self.0.iter().map(Option::is_some).collect())
171    }
172
173    pub fn item_count(&self) -> usize {
174        self.iter().count()
175    }
176}
177
178impl<T: 'static> IntoIterator for NodeMap<T> {
179    type Item = (NodeIndex, T);
180    type IntoIter = Box<dyn Iterator<Item = (NodeIndex, T)>>;
181    fn into_iter(self) -> Self::IntoIter {
182        Box::new(self.into_iter())
183    }
184}
185
186impl<'a, T> IntoIterator for &'a NodeMap<T> {
187    type Item = (NodeIndex, &'a T);
188    type IntoIter = Box<dyn Iterator<Item = (NodeIndex, &'a T)> + 'a>;
189    fn into_iter(self) -> Self::IntoIter {
190        Box::new(self.iter())
191    }
192}
193
194impl<'a, T> IntoIterator for &'a mut NodeMap<T> {
195    type Item = (NodeIndex, &'a mut T);
196    type IntoIter = Box<dyn Iterator<Item = (NodeIndex, &'a mut T)> + 'a>;
197    fn into_iter(self) -> Self::IntoIter {
198        Box::new(self.iter_mut())
199    }
200}
201
202impl<T: fmt::Display> fmt::Display for NodeMap<T> {
203    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204        write!(f, "[")?;
205        let mut it = self.iter().peekable();
206        while let Some((id, item)) = it.next() {
207            write!(f, "({}, {})", id.0, item)?;
208            if it.peek().is_some() {
209                write!(f, ", ")?;
210            }
211        }
212        write!(f, "]")?;
213        Ok(())
214    }
215}
216
217#[derive(Clone, Eq, PartialEq, Hash, Debug, Default)]
218pub struct NodeSubset(bit_vec::BitVec<u32>);
219
220impl NodeSubset {
221    pub fn with_size(capacity: NodeCount) -> Self {
222        NodeSubset(bit_vec::BitVec::from_elem(capacity.0, false))
223    }
224
225    pub fn insert(&mut self, i: NodeIndex) {
226        self.0.set(i.0, true);
227    }
228
229    pub fn size(&self) -> usize {
230        self.0.len()
231    }
232
233    pub fn elements(&self) -> impl Iterator<Item = NodeIndex> + '_ {
234        self.0
235            .iter()
236            .enumerate()
237            .filter_map(|(i, b)| if b { Some(i.into()) } else { None })
238    }
239
240    pub fn len(&self) -> usize {
241        self.elements().count()
242    }
243
244    pub fn is_empty(&self) -> bool {
245        self.len() == 0
246    }
247}
248
249impl Encode for NodeSubset {
250    fn encode_to<T: Output + ?Sized>(&self, dest: &mut T) {
251        (self.0.len() as u32).encode_to(dest);
252        self.0.to_bytes().encode_to(dest);
253    }
254}
255
256impl Decode for NodeSubset {
257    fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
258        let capacity = u32::decode(input)? as usize;
259        let bytes = Vec::decode(input)?;
260        let mut bv = bit_vec::BitVec::from_bytes(&bytes);
261        // Length should be capacity rounded up to the closest multiple of 8
262        if bv.len() != 8 * ((capacity + 7) / 8) {
263            return Err(Error::from(
264                "Length of bitvector inconsistent with encoded capacity.",
265            ));
266        }
267        while bv.len() > capacity {
268            if bv.pop() != Some(false) {
269                return Err(Error::from(
270                    "Non-canonical encoding. Trailing bits should be all 0.",
271                ));
272            }
273        }
274        bv.truncate(capacity);
275        Ok(NodeSubset(bv))
276    }
277}
278
279impl StdIndex<NodeIndex> for NodeSubset {
280    type Output = bool;
281
282    fn index(&self, vidx: NodeIndex) -> &bool {
283        &self.0[vidx.0]
284    }
285}
286
287impl fmt::Display for NodeSubset {
288    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289        let mut v: Vec<usize> = self.elements().map(|n| n.into()).collect();
290        v.sort();
291        write!(f, "{:?}", v)
292    }
293}
294
#[cfg(test)]
mod tests {

    use crate::node::{NodeIndex, NodeSubset};
    use codec::{Decode, Encode};

    #[test]
    fn decoding_node_index_works() {
        for i in 0..1000 {
            let node_index = NodeIndex(i);
            let mut encoded: &[u8] = &node_index.encode();
            let decoded = NodeIndex::decode(&mut encoded);
            assert_eq!(node_index, decoded.unwrap());
        }
    }

    #[test]
    fn node_index_is_encoded_as_little_endian_u64() {
        assert_eq!(NodeIndex(1).encode(), vec![1, 0, 0, 0, 0, 0, 0, 0]);
        assert_eq!(NodeIndex(0x0102).encode(), vec![2, 1, 0, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn decoding_node_index_rejects_short_input() {
        // Fewer than the 8 bytes a NodeIndex occupies.
        let mut encoded: &[u8] = &[1, 2, 3];
        assert!(NodeIndex::decode(&mut encoded).is_err());
    }

    #[test]
    fn bool_node_map_decoding_works() {
        // Round-trip every subset of every capacity below 12.
        for len in 0..12 {
            for mask in 0..(1 << len) {
                let mut bnm = NodeSubset::with_size(len.into());
                for i in 0..len {
                    if (1 << i) & mask != 0 {
                        bnm.insert(i.into());
                    }
                }
                let encoded: Vec<_> = bnm.encode();
                let decoded =
                    NodeSubset::decode(&mut encoded.as_slice()).expect("decode should work");
                assert_eq!(decoded, bnm);
            }
        }
    }

    #[test]
    fn bool_node_map_decoding_deals_with_trailing_zeros() {
        let mut encoded = vec![1, 0, 0, 0];
        encoded.extend(vec![128u8].encode());
        //128 encodes bit-vec 10000000
        let decoded = NodeSubset::decode(&mut encoded.as_slice()).expect("decode should work");
        assert_eq!(decoded, NodeSubset([true].iter().cloned().collect()));

        let mut encoded = vec![1, 0, 0, 0];
        encoded.extend(vec![129u8].encode());
        //129 encodes bit-vec 10000001
        assert!(NodeSubset::decode(&mut encoded.as_slice()).is_err());
    }

    #[test]
    fn bool_node_map_decoding_deals_with_too_long_bitvec() {
        let mut encoded = vec![1, 0, 0, 0];
        encoded.extend(vec![128u8, 0].encode());
        //[128, 0] encodes bit-vec 1000000000000000
        assert!(NodeSubset::decode(&mut encoded.as_slice()).is_err());
    }

    #[test]
    fn bool_node_map_decoding_deals_with_too_short_bitvec() {
        // Capacity 9 needs two bytes, but only one is provided.
        let mut encoded = vec![9, 0, 0, 0];
        encoded.extend(vec![0u8].encode());
        assert!(NodeSubset::decode(&mut encoded.as_slice()).is_err());
    }

    #[test]
    fn bool_node_map_decoding_rejects_huge_capacity() {
        // Capacity u32::MAX with no payload bytes must be rejected on every platform.
        let mut encoded = vec![255u8, 255, 255, 255];
        encoded.extend(Vec::<u8>::new().encode());
        assert!(NodeSubset::decode(&mut encoded.as_slice()).is_err());
    }

    #[test]
    fn decoding_bool_node_map_works() {
        let bool_node_map = NodeSubset([true, false, true, true, true].iter().cloned().collect());
        let encoded: Vec<_> = bool_node_map.encode();
        let decoded = NodeSubset::decode(&mut encoded.as_slice()).expect("decode should work");
        assert_eq!(decoded, bool_node_map);
    }

    #[test]
    fn test_bool_node_map_has_efficient_encoding() {
        let mut bnm = NodeSubset::with_size(100.into());
        for i in 0..50 {
            bnm.insert(i.into())
        }
        assert!(bnm.encode().len() < 20);
    }
}