aleph_bft_crypto/
node.rs

1use codec::{Decode, Encode, Error, Input, Output};
2use derive_more::{Add, AddAssign, From, Into, Sub, SubAssign, Sum};
3use std::{
4    collections::HashMap,
5    fmt,
6    hash::Hash,
7    ops::{Div, Index as StdIndex, Mul},
8    vec,
9};
10
11/// The index of a node
12#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, From, Into)]
13pub struct NodeIndex(pub usize);
14
15impl Encode for NodeIndex {
16    fn encode_to<T: Output + ?Sized>(&self, dest: &mut T) {
17        let val = self.0 as u64;
18        let bytes = val.to_le_bytes();
19        dest.write(&bytes);
20    }
21}
22
23impl Decode for NodeIndex {
24    fn decode<I: Input>(value: &mut I) -> Result<Self, Error> {
25        let mut arr = [0u8; 8];
26        value.read(&mut arr)?;
27        let val: u64 = u64::from_le_bytes(arr);
28        Ok(NodeIndex(val as usize))
29    }
30}
31
32/// Indicates that an implementor has been assigned some index.
33pub trait Index {
34    fn index(&self) -> NodeIndex;
35}
36
37/// Node count. Right now it doubles as node weight in many places in the code, in the future we
38/// might need a new type for that.
39#[derive(
40    Copy,
41    Clone,
42    Eq,
43    PartialEq,
44    Ord,
45    PartialOrd,
46    Hash,
47    Debug,
48    Default,
49    Add,
50    AddAssign,
51    From,
52    Into,
53    Sub,
54    SubAssign,
55    Sum,
56)]
57pub struct NodeCount(pub usize);
58
59// deriving Mul and Div is somehow cumbersome
60impl Mul<usize> for NodeCount {
61    type Output = Self;
62    fn mul(self, rhs: usize) -> Self::Output {
63        NodeCount(self.0 * rhs)
64    }
65}
66
67impl Div<usize> for NodeCount {
68    type Output = Self;
69    fn div(self, rhs: usize) -> Self::Output {
70        NodeCount(self.0 / rhs)
71    }
72}
73
74impl NodeCount {
75    pub fn into_range(self) -> core::ops::Range<NodeIndex> {
76        core::ops::Range {
77            start: 0.into(),
78            end: self.0.into(),
79        }
80    }
81
82    pub fn into_iterator(self) -> impl Iterator<Item = NodeIndex> {
83        (0..self.0).map(NodeIndex)
84    }
85
86    /// If this is the total node count, what number of nodes is required for secure consensus.
87    pub fn consensus_threshold(&self) -> NodeCount {
88        (*self * 2) / 3 + NodeCount(1)
89    }
90}
91
92/// A container keeping items indexed by NodeIndex.
93#[derive(Clone, Eq, PartialEq, Hash, Debug, Default, Decode, Encode, From)]
94pub struct NodeMap<T>(Vec<Option<T>>);
95
96impl<T> NodeMap<T> {
97    /// Constructs a new node map with a given length.
98    pub fn with_size(len: NodeCount) -> Self
99    where
100        T: Clone,
101    {
102        let v = vec![None; len.into()];
103        NodeMap(v)
104    }
105
106    pub fn from_hashmap(len: NodeCount, hashmap: HashMap<NodeIndex, T>) -> Self
107    where
108        T: Clone,
109    {
110        let v = vec![None; len.into()];
111        let mut nm = NodeMap(v);
112        for (id, item) in hashmap.into_iter() {
113            nm.insert(id, item);
114        }
115        nm
116    }
117
118    pub fn size(&self) -> NodeCount {
119        self.0.len().into()
120    }
121
122    pub fn iter(&self) -> impl Iterator<Item = (NodeIndex, &T)> {
123        self.0
124            .iter()
125            .enumerate()
126            .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value.as_ref()?)))
127    }
128
129    pub fn iter_mut(&mut self) -> impl Iterator<Item = (NodeIndex, &mut T)> {
130        self.0
131            .iter_mut()
132            .enumerate()
133            .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value.as_mut()?)))
134    }
135
136    fn into_iter(self) -> impl Iterator<Item = (NodeIndex, T)>
137    where
138        T: 'static,
139    {
140        self.0
141            .into_iter()
142            .enumerate()
143            .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value?)))
144    }
145
146    pub fn values(&self) -> impl Iterator<Item = &T> {
147        self.iter().map(|(_, value)| value)
148    }
149
150    pub fn into_values(self) -> impl Iterator<Item = T>
151    where
152        T: 'static,
153    {
154        self.into_iter().map(|(_, value)| value)
155    }
156
157    pub fn get(&self, node_id: NodeIndex) -> Option<&T> {
158        self.0[node_id.0].as_ref()
159    }
160
161    pub fn get_mut(&mut self, node_id: NodeIndex) -> Option<&mut T> {
162        self.0[node_id.0].as_mut()
163    }
164
165    pub fn insert(&mut self, node_id: NodeIndex, value: T) {
166        self.0[node_id.0] = Some(value)
167    }
168
169    pub fn delete(&mut self, node_id: NodeIndex) {
170        self.0[node_id.0] = None
171    }
172
173    pub fn to_subset(&self) -> NodeSubset {
174        NodeSubset(self.0.iter().map(Option::is_some).collect())
175    }
176
177    pub fn item_count(&self) -> usize {
178        self.iter().count()
179    }
180}
181
182impl<T: 'static> IntoIterator for NodeMap<T> {
183    type Item = (NodeIndex, T);
184    type IntoIter = Box<dyn Iterator<Item = (NodeIndex, T)>>;
185    fn into_iter(self) -> Self::IntoIter {
186        Box::new(self.into_iter())
187    }
188}
189
190impl<'a, T> IntoIterator for &'a NodeMap<T> {
191    type Item = (NodeIndex, &'a T);
192    type IntoIter = Box<dyn Iterator<Item = (NodeIndex, &'a T)> + 'a>;
193    fn into_iter(self) -> Self::IntoIter {
194        Box::new(self.iter())
195    }
196}
197
198impl<'a, T> IntoIterator for &'a mut NodeMap<T> {
199    type Item = (NodeIndex, &'a mut T);
200    type IntoIter = Box<dyn Iterator<Item = (NodeIndex, &'a mut T)> + 'a>;
201    fn into_iter(self) -> Self::IntoIter {
202        Box::new(self.iter_mut())
203    }
204}
205
206impl<T: fmt::Display> fmt::Display for NodeMap<T> {
207    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
208        write!(f, "[")?;
209        let mut it = self.iter().peekable();
210        while let Some((id, item)) = it.next() {
211            write!(f, "({}, {})", id.0, item)?;
212            if it.peek().is_some() {
213                write!(f, ", ")?;
214            }
215        }
216        write!(f, "]")?;
217        Ok(())
218    }
219}
220
221#[derive(Clone, Eq, PartialEq, Hash, Debug, Default)]
222pub struct NodeSubset(bit_vec::BitVec<u32>);
223
224impl NodeSubset {
225    pub fn with_size(capacity: NodeCount) -> Self {
226        NodeSubset(bit_vec::BitVec::from_elem(capacity.0, false))
227    }
228
229    pub fn insert(&mut self, i: NodeIndex) {
230        self.0.set(i.0, true);
231    }
232
233    pub fn size(&self) -> usize {
234        self.0.len()
235    }
236
237    pub fn elements(&self) -> impl Iterator<Item = NodeIndex> + '_ {
238        self.0
239            .iter()
240            .enumerate()
241            .filter_map(|(i, b)| if b { Some(i.into()) } else { None })
242    }
243
244    pub fn complement(&self) -> NodeSubset {
245        let mut result = self.0.clone();
246        result.negate();
247        NodeSubset(result)
248    }
249
250    pub fn len(&self) -> usize {
251        self.elements().count()
252    }
253
254    pub fn is_empty(&self) -> bool {
255        self.len() == 0
256    }
257}
258
259impl Encode for NodeSubset {
260    fn encode_to<T: Output + ?Sized>(&self, dest: &mut T) {
261        (self.0.len() as u32).encode_to(dest);
262        self.0.to_bytes().encode_to(dest);
263    }
264}
265
266impl Decode for NodeSubset {
267    fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
268        let capacity = u32::decode(input)? as usize;
269        let bytes = Vec::decode(input)?;
270        let mut bv = bit_vec::BitVec::from_bytes(&bytes);
271        // Length should be capacity rounded up to the closest multiple of 8
272        if bv.len() != 8 * ((capacity + 7) / 8) {
273            return Err(Error::from(
274                "Length of bitvector inconsistent with encoded capacity.",
275            ));
276        }
277        while bv.len() > capacity {
278            if bv.pop() != Some(false) {
279                return Err(Error::from(
280                    "Non-canonical encoding. Trailing bits should be all 0.",
281                ));
282            }
283        }
284        bv.truncate(capacity);
285        Ok(NodeSubset(bv))
286    }
287}
288
289impl StdIndex<NodeIndex> for NodeSubset {
290    type Output = bool;
291
292    fn index(&self, vidx: NodeIndex) -> &bool {
293        &self.0[vidx.0]
294    }
295}
296
297impl fmt::Display for NodeSubset {
298    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299        let mut v: Vec<usize> = self.elements().map(|n| n.into()).collect();
300        v.sort();
301        write!(f, "{:?}", v)
302    }
303}
304
#[cfg(test)]
mod tests {
    //! Encoding round-trips and basic behaviour of the node index/count/map/subset types.

    use crate::node::{NodeCount, NodeIndex, NodeMap, NodeSubset};
    use codec::{Decode, Encode};

    #[test]
    fn decoding_node_index_works() {
        for i in 0..1000 {
            let node_index = NodeIndex(i);
            let mut encoded: &[u8] = &node_index.encode();
            let decoded = NodeIndex::decode(&mut encoded);
            assert_eq!(node_index, decoded.unwrap());
        }
    }

    #[test]
    fn decoding_node_index_rejects_short_input() {
        let mut encoded: &[u8] = &[1, 2, 3];
        assert!(NodeIndex::decode(&mut encoded).is_err());
    }

    #[test]
    fn bool_node_map_decoding_works() {
        for len in 0..12 {
            for mask in 0..(1 << len) {
                let mut bnm = NodeSubset::with_size(len.into());
                for i in 0..len {
                    if (1 << i) & mask != 0 {
                        bnm.insert(i.into());
                    }
                }
                let encoded: Vec<_> = bnm.encode();
                let decoded =
                    NodeSubset::decode(&mut encoded.as_slice()).expect("decode should work");
                assert!(decoded == bnm);
            }
        }
    }

    #[test]
    fn bool_node_map_decoding_deals_with_trailing_zeros() {
        let mut encoded = vec![1, 0, 0, 0];
        encoded.extend(vec![128u8].encode());
        // 128 encodes bit-vec 10000000
        let decoded = NodeSubset::decode(&mut encoded.as_slice()).expect("decode should work");
        assert_eq!(decoded, NodeSubset([true].iter().cloned().collect()));

        let mut encoded = vec![1, 0, 0, 0];
        encoded.extend(vec![129u8].encode());
        // 129 encodes bit-vec 10000001
        assert!(NodeSubset::decode(&mut encoded.as_slice()).is_err());
    }

    #[test]
    fn bool_node_map_decoding_deals_with_too_long_bitvec() {
        let mut encoded = vec![1, 0, 0, 0];
        encoded.extend(vec![128u8, 0].encode());
        // [128, 0] encodes bit-vec 1000000000000000
        assert!(NodeSubset::decode(&mut encoded.as_slice()).is_err());
    }

    #[test]
    fn bool_node_map_decoding_rejects_huge_capacity_without_data() {
        // A capacity of u32::MAX claims ~512 MiB of bits but supplies none; this must be a
        // clean error (and must not overflow on 32-bit targets).
        let mut encoded = u32::MAX.encode();
        encoded.extend(Vec::<u8>::new().encode());
        assert!(NodeSubset::decode(&mut encoded.as_slice()).is_err());
    }

    #[test]
    fn decoding_bool_node_map_works() {
        let bool_node_map = NodeSubset([true, false, true, true, true].iter().cloned().collect());
        let encoded: Vec<_> = bool_node_map.encode();
        let decoded = NodeSubset::decode(&mut encoded.as_slice()).expect("decode should work");
        assert_eq!(decoded, bool_node_map);
    }

    #[test]
    fn test_bool_node_map_has_efficient_encoding() {
        let mut bnm = NodeSubset::with_size(100.into());
        for i in 0..50 {
            bnm.insert(i.into())
        }
        assert!(bnm.encode().len() < 20);
    }

    #[test]
    fn consensus_threshold_is_more_than_two_thirds() {
        assert_eq!(NodeCount(1).consensus_threshold(), NodeCount(1));
        assert_eq!(NodeCount(3).consensus_threshold(), NodeCount(3));
        assert_eq!(NodeCount(4).consensus_threshold(), NodeCount(3));
        assert_eq!(NodeCount(7).consensus_threshold(), NodeCount(5));
        assert_eq!(NodeCount(10).consensus_threshold(), NodeCount(7));
    }

    #[test]
    fn node_map_insert_get_delete_and_display() {
        let mut map = NodeMap::with_size(NodeCount(4));
        map.insert(NodeIndex(1), 10);
        map.insert(NodeIndex(3), 30);
        assert_eq!(map.item_count(), 2);
        assert_eq!(map.get(NodeIndex(0)), None);
        assert_eq!(map.get(NodeIndex(3)), Some(&30));
        assert_eq!(format!("{}", map), "[(1, 10), (3, 30)]");
        map.delete(NodeIndex(1));
        assert_eq!(
            map.to_subset().elements().collect::<Vec<_>>(),
            vec![NodeIndex(3)]
        );
    }

    #[test]
    fn node_subset_complement_len_and_display() {
        let mut subset = NodeSubset::with_size(NodeCount(5));
        subset.insert(NodeIndex(0));
        subset.insert(NodeIndex(3));
        assert_eq!(subset.len(), 2);
        assert_eq!(subset.size(), 5);
        assert!(subset[NodeIndex(3)]);
        assert!(!subset[NodeIndex(1)]);
        assert_eq!(
            subset.complement().elements().collect::<Vec<_>>(),
            vec![NodeIndex(1), NodeIndex(2), NodeIndex(4)]
        );
        assert_eq!(format!("{}", subset), "[0, 3]");
        assert!(NodeSubset::with_size(NodeCount(3)).is_empty());
        assert!(!subset.is_empty());
    }
}