1use codec::{Decode, Encode, Error, Input, Output};
2use derive_more::{Add, AddAssign, From, Into, Sub, SubAssign, Sum};
3use std::{
4 collections::HashMap,
5 fmt,
6 hash::Hash,
7 ops::{Div, Index as StdIndex, Mul},
8 vec,
9};
10
11#[derive(Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug, Default, From, Into)]
13pub struct NodeIndex(pub usize);
14
15impl Encode for NodeIndex {
16 fn encode_to<T: Output + ?Sized>(&self, dest: &mut T) {
17 let val = self.0 as u64;
18 let bytes = val.to_le_bytes();
19 dest.write(&bytes);
20 }
21}
22
23impl Decode for NodeIndex {
24 fn decode<I: Input>(value: &mut I) -> Result<Self, Error> {
25 let mut arr = [0u8; 8];
26 value.read(&mut arr)?;
27 let val: u64 = u64::from_le_bytes(arr);
28 Ok(NodeIndex(val as usize))
29 }
30}
31
/// Implemented by values that can report a [`NodeIndex`].
pub trait Index {
    /// Returns the node index associated with `self`.
    fn index(&self) -> NodeIndex;
}
36
37#[derive(
40 Copy,
41 Clone,
42 Eq,
43 PartialEq,
44 Ord,
45 PartialOrd,
46 Hash,
47 Debug,
48 Default,
49 Add,
50 AddAssign,
51 From,
52 Into,
53 Sub,
54 SubAssign,
55 Sum,
56)]
57pub struct NodeCount(pub usize);
58
59impl Mul<usize> for NodeCount {
61 type Output = Self;
62 fn mul(self, rhs: usize) -> Self::Output {
63 NodeCount(self.0 * rhs)
64 }
65}
66
67impl Div<usize> for NodeCount {
68 type Output = Self;
69 fn div(self, rhs: usize) -> Self::Output {
70 NodeCount(self.0 / rhs)
71 }
72}
73
74impl NodeCount {
75 pub fn into_range(self) -> core::ops::Range<NodeIndex> {
76 core::ops::Range {
77 start: 0.into(),
78 end: self.0.into(),
79 }
80 }
81
82 pub fn into_iterator(self) -> impl Iterator<Item = NodeIndex> {
83 (0..self.0).map(NodeIndex)
84 }
85
86 pub fn consensus_threshold(&self) -> NodeCount {
88 (*self * 2) / 3 + NodeCount(1)
89 }
90}
91
92#[derive(Clone, Eq, PartialEq, Hash, Debug, Default, Decode, Encode, From)]
94pub struct NodeMap<T>(Vec<Option<T>>);
95
96impl<T> NodeMap<T> {
97 pub fn with_size(len: NodeCount) -> Self
99 where
100 T: Clone,
101 {
102 let v = vec![None; len.into()];
103 NodeMap(v)
104 }
105
106 pub fn from_hashmap(len: NodeCount, hashmap: HashMap<NodeIndex, T>) -> Self
107 where
108 T: Clone,
109 {
110 let v = vec![None; len.into()];
111 let mut nm = NodeMap(v);
112 for (id, item) in hashmap.into_iter() {
113 nm.insert(id, item);
114 }
115 nm
116 }
117
118 pub fn size(&self) -> NodeCount {
119 self.0.len().into()
120 }
121
122 pub fn iter(&self) -> impl Iterator<Item = (NodeIndex, &T)> {
123 self.0
124 .iter()
125 .enumerate()
126 .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value.as_ref()?)))
127 }
128
129 pub fn iter_mut(&mut self) -> impl Iterator<Item = (NodeIndex, &mut T)> {
130 self.0
131 .iter_mut()
132 .enumerate()
133 .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value.as_mut()?)))
134 }
135
136 fn into_iter(self) -> impl Iterator<Item = (NodeIndex, T)>
137 where
138 T: 'static,
139 {
140 self.0
141 .into_iter()
142 .enumerate()
143 .filter_map(|(idx, maybe_value)| Some((NodeIndex(idx), maybe_value?)))
144 }
145
146 pub fn values(&self) -> impl Iterator<Item = &T> {
147 self.iter().map(|(_, value)| value)
148 }
149
150 pub fn into_values(self) -> impl Iterator<Item = T>
151 where
152 T: 'static,
153 {
154 self.into_iter().map(|(_, value)| value)
155 }
156
157 pub fn get(&self, node_id: NodeIndex) -> Option<&T> {
158 self.0[node_id.0].as_ref()
159 }
160
161 pub fn get_mut(&mut self, node_id: NodeIndex) -> Option<&mut T> {
162 self.0[node_id.0].as_mut()
163 }
164
165 pub fn insert(&mut self, node_id: NodeIndex, value: T) {
166 self.0[node_id.0] = Some(value)
167 }
168
169 pub fn delete(&mut self, node_id: NodeIndex) {
170 self.0[node_id.0] = None
171 }
172
173 pub fn to_subset(&self) -> NodeSubset {
174 NodeSubset(self.0.iter().map(Option::is_some).collect())
175 }
176
177 pub fn item_count(&self) -> usize {
178 self.iter().count()
179 }
180}
181
182impl<T: 'static> IntoIterator for NodeMap<T> {
183 type Item = (NodeIndex, T);
184 type IntoIter = Box<dyn Iterator<Item = (NodeIndex, T)>>;
185 fn into_iter(self) -> Self::IntoIter {
186 Box::new(self.into_iter())
187 }
188}
189
190impl<'a, T> IntoIterator for &'a NodeMap<T> {
191 type Item = (NodeIndex, &'a T);
192 type IntoIter = Box<dyn Iterator<Item = (NodeIndex, &'a T)> + 'a>;
193 fn into_iter(self) -> Self::IntoIter {
194 Box::new(self.iter())
195 }
196}
197
198impl<'a, T> IntoIterator for &'a mut NodeMap<T> {
199 type Item = (NodeIndex, &'a mut T);
200 type IntoIter = Box<dyn Iterator<Item = (NodeIndex, &'a mut T)> + 'a>;
201 fn into_iter(self) -> Self::IntoIter {
202 Box::new(self.iter_mut())
203 }
204}
205
206impl<T: fmt::Display> fmt::Display for NodeMap<T> {
207 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
208 write!(f, "[")?;
209 let mut it = self.iter().peekable();
210 while let Some((id, item)) = it.next() {
211 write!(f, "({}, {})", id.0, item)?;
212 if it.peek().is_some() {
213 write!(f, ", ")?;
214 }
215 }
216 write!(f, "]")?;
217 Ok(())
218 }
219}
220
221#[derive(Clone, Eq, PartialEq, Hash, Debug, Default)]
222pub struct NodeSubset(bit_vec::BitVec<u32>);
223
224impl NodeSubset {
225 pub fn with_size(capacity: NodeCount) -> Self {
226 NodeSubset(bit_vec::BitVec::from_elem(capacity.0, false))
227 }
228
229 pub fn insert(&mut self, i: NodeIndex) {
230 self.0.set(i.0, true);
231 }
232
233 pub fn size(&self) -> usize {
234 self.0.len()
235 }
236
237 pub fn elements(&self) -> impl Iterator<Item = NodeIndex> + '_ {
238 self.0
239 .iter()
240 .enumerate()
241 .filter_map(|(i, b)| if b { Some(i.into()) } else { None })
242 }
243
244 pub fn complement(&self) -> NodeSubset {
245 let mut result = self.0.clone();
246 result.negate();
247 NodeSubset(result)
248 }
249
250 pub fn len(&self) -> usize {
251 self.elements().count()
252 }
253
254 pub fn is_empty(&self) -> bool {
255 self.len() == 0
256 }
257}
258
259impl Encode for NodeSubset {
260 fn encode_to<T: Output + ?Sized>(&self, dest: &mut T) {
261 (self.0.len() as u32).encode_to(dest);
262 self.0.to_bytes().encode_to(dest);
263 }
264}
265
266impl Decode for NodeSubset {
267 fn decode<I: Input>(input: &mut I) -> Result<Self, Error> {
268 let capacity = u32::decode(input)? as usize;
269 let bytes = Vec::decode(input)?;
270 let mut bv = bit_vec::BitVec::from_bytes(&bytes);
271 if bv.len() != 8 * ((capacity + 7) / 8) {
273 return Err(Error::from(
274 "Length of bitvector inconsistent with encoded capacity.",
275 ));
276 }
277 while bv.len() > capacity {
278 if bv.pop() != Some(false) {
279 return Err(Error::from(
280 "Non-canonical encoding. Trailing bits should be all 0.",
281 ));
282 }
283 }
284 bv.truncate(capacity);
285 Ok(NodeSubset(bv))
286 }
287}
288
289impl StdIndex<NodeIndex> for NodeSubset {
290 type Output = bool;
291
292 fn index(&self, vidx: NodeIndex) -> &bool {
293 &self.0[vidx.0]
294 }
295}
296
297impl fmt::Display for NodeSubset {
298 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
299 let mut v: Vec<usize> = self.elements().map(|n| n.into()).collect();
300 v.sort();
301 write!(f, "{:?}", v)
302 }
303}
304
#[cfg(test)]
mod tests {
    use crate::node::{NodeIndex, NodeSubset};
    use codec::{Decode, Encode};

    /// Every index below 1000 survives an encode/decode round trip.
    #[test]
    fn decoding_node_index_works() {
        for raw in 0..1000 {
            let original = NodeIndex(raw);
            let bytes = original.encode();
            let decoded = NodeIndex::decode(&mut bytes.as_slice()).unwrap();
            assert_eq!(original, decoded);
        }
    }

    /// Every subset of every size below 12 survives an encode/decode round trip.
    #[test]
    fn bool_node_map_decoding_works() {
        for len in 0..12usize {
            for mask in 0..(1u32 << len) {
                let mut bnm = NodeSubset::with_size(len.into());
                (0..len)
                    .filter(|&bit| (mask & (1u32 << bit)) != 0)
                    .for_each(|bit| bnm.insert(NodeIndex(bit)));
                let encoded: Vec<u8> = bnm.encode();
                let decoded =
                    NodeSubset::decode(&mut encoded.as_slice()).expect("decode should work");
                assert!(decoded == bnm);
            }
        }
    }

    /// Zero padding bits are accepted; a set padding bit is rejected.
    #[test]
    fn bool_node_map_decoding_deals_with_trailing_zeros() {
        // Capacity 1 followed by a single byte whose highest bit is set.
        let mut canonical: Vec<u8> = vec![1, 0, 0, 0];
        canonical.extend(vec![128u8].encode());
        let decoded = NodeSubset::decode(&mut canonical.as_slice()).expect("decode should work");
        assert_eq!(decoded, NodeSubset(std::iter::once(true).collect()));

        // Same capacity, but the lowest (padding) bit of the byte is set as well.
        let mut padded: Vec<u8> = vec![1, 0, 0, 0];
        padded.extend(vec![129u8].encode());
        assert!(NodeSubset::decode(&mut padded.as_slice()).is_err());
    }

    /// A byte vector longer than the capacity requires is rejected.
    #[test]
    fn bool_node_map_decoding_deals_with_too_long_bitvec() {
        let mut oversized: Vec<u8> = vec![1, 0, 0, 0];
        oversized.extend(vec![128u8, 0].encode());
        assert!(NodeSubset::decode(&mut oversized.as_slice()).is_err());
    }

    /// A hand-built subset survives an encode/decode round trip.
    #[test]
    fn decoding_bool_node_map_works() {
        let bits = [true, false, true, true, true];
        let bool_node_map = NodeSubset(bits.iter().copied().collect());
        let encoded: Vec<u8> = bool_node_map.encode();
        let decoded = NodeSubset::decode(&mut encoded.as_slice()).expect("decode should work");
        assert_eq!(decoded, bool_node_map);
    }

    /// A 100-position subset encodes to fewer than 20 bytes.
    #[test]
    fn test_bool_node_map_has_efficient_encoding() {
        let mut bnm = NodeSubset::with_size(100.into());
        (0..50).map(NodeIndex).for_each(|idx| bnm.insert(idx));
        assert!(bnm.encode().len() < 20);
    }
}