1use codec::{Decode, Encode, Error, Input, Output};
2use derive_more::{Add, AddAssign, From, Into, Sub, SubAssign, Sum};
3use std::{
4 collections::HashMap,
5 fmt,
6 hash::Hash,
7 ops::{Div, Index as StdIndex, Mul},
8 vec,
9};
10
11#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, From, Into)]
13pub struct NodeIndex(pub usize);
14
15impl Encode for NodeIndex {
16 fn encode_to<T: Output + ?Sized>(&self, dest: &mut T) {
17 let val = self.0 as u64;
18 let bytes = val.to_le_bytes();
19 dest.write(&bytes);
20 }
21}
22
23impl Decode for NodeIndex {
24 fn decode<I: Input>(value: &mut I) -> Result<Self, Error> {
25 let mut arr = [0u8; 8];
26 value.read(&mut arr)?;
27 let val: u64 = u64::from_le_bytes(arr);
28 Ok(NodeIndex(val as usize))
29 }
30}
31
32pub trait Index {
34 fn index(&self) -> NodeIndex;
35}
36
37#[derive(
40 Copy,
41 Clone,
42 Eq,
43 PartialEq,
44 Ord,
45 PartialOrd,
46 Hash,
47 Debug,
48 Default,
49 Add,
50 AddAssign,
51 From,
52 Into,
53 Sub,
54 SubAssign,
55 Sum,
56)]
57pub struct NodeCount(pub usize);
58
59impl Mul<usize> for NodeCount {
61 type Output = Self;
62 fn mul(self, rhs: usize) -> Self::Output {
63 NodeCount(self.0 * rhs)
64 }
65}
66
67impl Div<usize> for NodeCount {
68 type Output = Self;
69 fn div(self, rhs: usize) -> Self::Output {
70 NodeCount(self.0 / rhs)
71 }
72}
73
74impl NodeCount {
75 pub fn into_range(self) -> core::ops::Range<NodeIndex> {
76 core::ops::Range {
77 start: 0.into(),
78 end: self.0.into(),
79 }
80 }
81
82 pub fn into_iterator(self) -> impl Iterator<Item = NodeIndex> {
83 (0..self.0).map(NodeIndex)
84 }
85
86 pub fn consensus_threshold(&self) -> NodeCount {
88 (*self * 2) / 3 + NodeCount(1)
89 }
90}
91
92#[derive(Clone, Eq, PartialEq, Hash, Debug, Default, Decode, Encode, From)]
94pub struct NodeMap<T>(Vec<Option<T>>);
95
96impl<T> NodeMap<T> {
97 pub fn with_size(len: NodeCount) -> Self
99 where
100 T: Clone,
101 {
102 let v = vec![None; len.into()];
103 NodeMap(v)
104 }
105
106 pub fn from_hashmap(len: NodeCount, hashmap: HashMap<NodeIndex, T>) -> Self
107 where
108 T: Clone,
109 {
110 let v = vec![None; len.into()];
111 let mut nm = NodeMap(v);
112 for (id, item) in hashmap.into_iter() {
113 nm.insert(id, item);
114 }
115 nm
116 }
117
118 pub fn size(&self) -> NodeCount {
119 self.0.len().into()
120 }
121
122 pub fn iter(&self) -> impl Iterator<Item = (NodeIndex, &T)> {
123 self.0
124 .iter()
125 .enumerate()
126 .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value.as_ref()?)))
127 }
128
129 pub fn iter_mut(&mut self) -> impl Iterator<Item = (NodeIndex, &mut T)> {
130 self.0
131 .iter_mut()
132 .enumerate()
133 .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value.as_mut()?)))
134 }
135
136 fn into_iter(self) -> impl Iterator<Item = (NodeIndex, T)>
137 where
138 T: 'static,
139 {
140 self.0
141 .into_iter()
142 .enumerate()
143 .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value?)))
144 }
145
146 pub fn values(&self) -> impl Iterator<Item = &T> {
147 self.iter().map(|(_, value)| value)
148 }
149
150 pub fn into_values(self) -> impl Iterator<Item = T>
151 where
152 T: 'static,
153 {
154 self.into_iter().map(|(_, value)| value)
155 }
156
157 pub fn get(&self, node_id: NodeIndex) -> Option<&T> {
158 self.0[node_id.0].as_ref()
159 }
160
161 pub fn get_mut(&mut self, node_id: NodeIndex) -> Option<&mut T> {
162 self.0[node_id.0].as_mut()
163 }
164
165 pub fn insert(&mut self, node_id: NodeIndex, value: T) {
166 self.0[node_id.0] = Some(value)
167 }
168
169 pub fn to_subset(&self) -> NodeSubset {
170 NodeSubset(self.0.iter().map(Option::is_some).collect())
171 }
172
173 pub fn item_count(&self) -> usize {
174 self.iter().count()
175 }
176}
177
178impl<T: 'static> IntoIterator for NodeMap<T> {
179 type Item = (NodeIndex, T);
180 type IntoIter = Box<dyn Iterator<Item = (NodeIndex, T)>>;
181 fn into_iter(self) -> Self::IntoIter {
182 Box::new(self.into_iter())
183 }
184}
185
186impl<'a, T> IntoIterator for &'a NodeMap<T> {
187 type Item = (NodeIndex, &'a T);
188 type IntoIter = Box<dyn Iterator<Item = (NodeIndex, &'a T)> + 'a>;
189 fn into_iter(self) -> Self::IntoIter {
190 Box::new(self.iter())
191 }
192}
193
194impl<'a, T> IntoIterator for &'a mut NodeMap<T> {
195 type Item = (NodeIndex, &'a mut T);
196 type IntoIter = Box<dyn Iterator<Item = (NodeIndex, &'a mut T)> + 'a>;
197 fn into_iter(self) -> Self::IntoIter {
198 Box::new(self.iter_mut())
199 }
200}
201
202impl<T: fmt::Display> fmt::Display for NodeMap<T> {
203 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204 write!(f, "[")?;
205 let mut it = self.iter().peekable();
206 while let Some((id, item)) = it.next() {
207 write!(f, "({}, {})", id.0, item)?;
208 if it.peek().is_some() {
209 write!(f, ", ")?;
210 }
211 }
212 write!(f, "]")?;
213 Ok(())
214 }
215}
216
217#[derive(Clone, Eq, PartialEq, Hash, Debug, Default)]
218pub struct NodeSubset(bit_vec::BitVec<u32>);
219
220impl NodeSubset {
221 pub fn with_size(capacity: NodeCount) -> Self {
222 NodeSubset(bit_vec::BitVec::from_elem(capacity.0, false))
223 }
224
225 pub fn insert(&mut self, i: NodeIndex) {
226 self.0.set(i.0, true);
227 }
228
229 pub fn size(&self) -> usize {
230 self.0.len()
231 }
232
233 pub fn elements(&self) -> impl Iterator<Item = NodeIndex> + '_ {
234 self.0
235 .iter()
236 .enumerate()
237 .filter_map(|(i, b)| if b { Some(i.into()) } else { None })
238 }
239
240 pub fn len(&self) -> usize {
241 self.elements().count()
242 }
243
244 pub fn is_empty(&self) -> bool {
245 self.len() == 0
246 }
247}
248
249impl Encode for NodeSubset {
250 fn encode_to<T: Output + ?Sized>(&self, dest: &mut T) {
251 (self.0.len() as u32).encode_to(dest);
252 self.0.to_bytes().encode_to(dest);
253 }
254}
255
256impl Decode for NodeSubset {
257 fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
258 let capacity = u32::decode(input)? as usize;
259 let bytes = Vec::decode(input)?;
260 let mut bv = bit_vec::BitVec::from_bytes(&bytes);
261 if bv.len() != 8 * ((capacity + 7) / 8) {
263 return Err(Error::from(
264 "Length of bitvector inconsistent with encoded capacity.",
265 ));
266 }
267 while bv.len() > capacity {
268 if bv.pop() != Some(false) {
269 return Err(Error::from(
270 "Non-canonical encoding. Trailing bits should be all 0.",
271 ));
272 }
273 }
274 bv.truncate(capacity);
275 Ok(NodeSubset(bv))
276 }
277}
278
279impl StdIndex<NodeIndex> for NodeSubset {
280 type Output = bool;
281
282 fn index(&self, vidx: NodeIndex) -> &bool {
283 &self.0[vidx.0]
284 }
285}
286
287impl fmt::Display for NodeSubset {
288 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
289 let mut v: Vec<usize> = self.elements().map(|n| n.into()).collect();
290 v.sort();
291 write!(f, "{:?}", v)
292 }
293}
294
#[cfg(test)]
mod tests {
    use crate::node::{NodeIndex, NodeSubset};
    use codec::{Decode, Encode};

    /// Every index in a small range survives an encode/decode round trip.
    #[test]
    fn decoding_node_index_works() {
        for raw in 0..1000 {
            let original = NodeIndex(raw);
            let bytes = original.encode();
            let round_tripped = NodeIndex::decode(&mut bytes.as_slice());
            assert_eq!(original, round_tripped.unwrap());
        }
    }

    /// Exhaustively round-trips every subset of every size below 12.
    #[test]
    fn bool_node_map_decoding_works() {
        for len in 0..12 {
            for mask in 0..(1 << len) {
                // Bit `i` of `mask` decides whether index `i` is a member.
                let mut subset = NodeSubset::with_size(len.into());
                for i in 0..len {
                    if (1 << i) & mask != 0 {
                        subset.insert(i.into());
                    }
                }
                let bytes: Vec<_> = subset.encode();
                let round_tripped =
                    NodeSubset::decode(&mut bytes.as_slice()).expect("decode should work");
                assert_eq!(round_tripped, subset);
            }
        }
    }

    /// Zero padding bits are accepted, a set padding bit is rejected.
    #[test]
    fn bool_node_map_decoding_deals_with_trailing_zeros() {
        // Declared length of one bit, followed by a one-byte bit vector.
        let length_prefix = [1u8, 0, 0, 0];

        let mut canonical = length_prefix.to_vec();
        canonical.extend(vec![128u8].encode());
        let decoded =
            NodeSubset::decode(&mut canonical.as_slice()).expect("decode should work");
        assert_eq!(decoded, NodeSubset([true].iter().cloned().collect()));

        // 129 additionally sets the last bit of the byte, which is padding here.
        let mut non_canonical = length_prefix.to_vec();
        non_canonical.extend(vec![129u8].encode());
        assert!(NodeSubset::decode(&mut non_canonical.as_slice()).is_err());
    }

    /// More bytes than the declared length needs must be rejected.
    #[test]
    fn bool_node_map_decoding_deals_with_too_long_bitvec() {
        let mut bytes = vec![1, 0, 0, 0];
        bytes.extend(vec![128u8, 0].encode());
        assert!(NodeSubset::decode(&mut bytes.as_slice()).is_err());
    }

    /// A subset whose length is not a multiple of 8 round-trips unchanged.
    #[test]
    fn decoding_bool_node_map_works() {
        let flags = [true, false, true, true, true];
        let subset = NodeSubset(flags.iter().cloned().collect());
        let bytes: Vec<_> = subset.encode();
        let round_tripped =
            NodeSubset::decode(&mut bytes.as_slice()).expect("decode should work");
        assert_eq!(round_tripped, subset);
    }

    /// A 100-node subset must encode to fewer than 20 bytes.
    #[test]
    fn test_bool_node_map_has_efficient_encoding() {
        let mut subset = NodeSubset::with_size(100.into());
        for i in 0..50 {
            subset.insert(i.into())
        }
        assert!(subset.encode().len() < 20);
    }
}