Skip to main content

avalanche_types/ids/
bag.rs

1//! The bag abstraction used to group votes in voting.
2use std::{
3    cell::{Cell, RefCell},
4    collections::{HashMap, HashSet},
5    rc::Rc,
6};
7
8use crate::ids::{bits, Id};
9
10/// Represents a bag of multiple Ids for binary voting.
11/// ref. <https://pkg.go.dev/github.com/ava-labs/avalanchego/ids#Bag>
12pub struct Bag {
13    counts: Rc<RefCell<HashMap<Id, u32>>>,
14    size: Cell<u32>,
15
16    mode: Cell<Id>,
17    mode_freq: Cell<u32>,
18
19    threshold: Cell<u32>,
20    met_threshold: Rc<RefCell<HashSet<Id>>>,
21}
22
23impl Default for Bag {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl Clone for Bag {
30    fn clone(&self) -> Self {
31        self.deep_copy()
32    }
33}
34
35impl Bag {
36    pub fn new() -> Self {
37        Self {
38            counts: Rc::new(RefCell::new(HashMap::new())),
39            size: Cell::new(0),
40
41            mode: Cell::new(Id::empty()),
42            mode_freq: Cell::new(0_u32),
43
44            threshold: Cell::new(0_u32),
45            met_threshold: Rc::new(RefCell::new(HashSet::new())),
46        }
47    }
48
49    pub fn deep_copy(&self) -> Self {
50        Self {
51            counts: Rc::new(RefCell::new(self.counts())),
52            size: Cell::new(self.len()),
53
54            mode: Cell::new(self.mode()),
55            mode_freq: Cell::new(self.mode_frequency()),
56
57            threshold: Cell::new(self.threshold()),
58            met_threshold: Rc::new(RefCell::new(self.met_threshold())),
59        }
60    }
61
62    pub fn is_empty(&self) -> bool {
63        self.size.get() == 0
64    }
65
66    pub fn len(&self) -> u32 {
67        self.size.get()
68    }
69
70    pub fn mode(&self) -> Id {
71        self.mode.get()
72    }
73
74    pub fn mode_frequency(&self) -> u32 {
75        self.mode_freq.get()
76    }
77
78    pub fn threshold(&self) -> u32 {
79        self.threshold.get()
80    }
81
82    /// Returns the Ids that have been seen at least threshold times.
83    pub fn met_threshold(&self) -> HashSet<Id> {
84        self.met_threshold.borrow().clone()
85    }
86
87    pub fn list(&self) -> Vec<Id> {
88        let mut ids = Vec::with_capacity(self.counts.borrow().len());
89        ids.extend(self.counts.borrow().keys().copied());
90        ids
91    }
92
93    pub fn counts(&self) -> HashMap<Id, u32> {
94        self.counts.borrow().clone()
95    }
96
97    pub fn set_threshold(&self, threshold: u32) {
98        if self.threshold.get().eq(&threshold) {
99            return;
100        }
101
102        self.threshold.set(threshold);
103        self.met_threshold.borrow_mut().clear();
104
105        for (vote, count) in self.counts.borrow().iter() {
106            if *count >= threshold {
107                self.met_threshold.borrow_mut().insert(*vote);
108            }
109        }
110    }
111
112    pub fn add_count(&self, id: &Id, count: u32) {
113        if count == 0 {
114            return;
115        }
116
117        let mut borrowed_mut_counts = self.counts.borrow_mut();
118        let current_count = borrowed_mut_counts.get(id).unwrap_or(&0);
119        let total_count = *current_count + count;
120
121        borrowed_mut_counts.insert(*id, total_count);
122
123        self.size.set(self.size.get() + count);
124
125        if total_count > self.mode_freq.get() {
126            self.mode.set(*id);
127            self.mode_freq.set(total_count);
128        }
129        if total_count >= self.threshold.get() {
130            self.met_threshold.borrow_mut().insert(*id);
131        }
132    }
133
134    pub fn count(&self, id: &Id) -> u32 {
135        let borrowed_counts = self.counts.borrow();
136        let current_count = borrowed_counts.get(id).unwrap_or(&0);
137        *current_count
138    }
139
140    pub fn equals(&self, other: &Self) -> bool {
141        if self.len() != other.len() {
142            return false;
143        }
144
145        {
146            for (vote, count) in self.counts.borrow().iter() {
147                let cnt = *count;
148
149                let borrowed_other_counts = other.counts.borrow();
150                let found = borrowed_other_counts.get(vote);
151                if found.is_none() {
152                    return false;
153                }
154                let other_count = found.unwrap_or(&0);
155                let other_cnt = *other_count;
156                if cnt != other_cnt {
157                    return false;
158                }
159            }
160            true
161        }
162    }
163
164    /// While retaining the same count values, only selects the IDs
165    /// that have the same bits in the range of [start, end).
166    pub fn filter(&self, start: usize, end: usize, id: &Id) -> Self {
167        let new_bag = Self::new();
168        for (vote, count) in self.counts.borrow().iter() {
169            let count = *count;
170
171            if bits::equal_subset(start, end, id, vote) {
172                new_bag.add_count(vote, count);
173            }
174        }
175        new_bag
176    }
177
178    /// Retaining the same count values, only selects the IDs that
179    /// in the 0th index have a 0 at bit \[index\],
180    /// and all ids in the 1st index have a 1 at bit \[index\].
181    pub fn split(&self, index: usize) -> [Self; 2] {
182        let split_votes = [Self::new(), Self::new()];
183
184        for (vote, count) in self.counts.borrow().iter() {
185            let count = *count;
186
187            let bit = vote.bit(index);
188            split_votes[bit.as_usize()].add_count(vote, count);
189        }
190
191        split_votes
192    }
193}
194
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- ids::bag::test_bag_add --exact --show-output
/// ref. "TestBagAdd"
#[test]
fn test_bag_add() {
    let id0 = Id::empty();
    let id1 = Id::from_slice(&[1_u8]);

    let bag = Bag::new();

    // Each step optionally adds `(id, count)` and then expects:
    // (count(id0), count(id1), len, list().len(), mode, mode_frequency, met_threshold().len()).
    // The threshold stays 0 throughout.
    let steps = [
        (None, (0, 0, 0, 0, Id::empty(), 0, 0)),
        (Some((id0, 1)), (1, 0, 1, 1, id0, 1, 1)),
        (Some((id0, 1)), (2, 0, 2, 1, id0, 2, 1)),
        (Some((id1, 3)), (2, 3, 5, 2, id1, 3, 2)),
    ];

    for (addition, (c0, c1, size, distinct, mode, freq, met)) in steps {
        if let Some((id, n)) = addition {
            bag.add_count(&id, n);
        }
        assert_eq!(bag.count(&id0), c0);
        assert_eq!(bag.count(&id1), c1);
        assert_eq!(bag.len(), size);
        assert_eq!(bag.list().len(), distinct);
        assert_eq!(bag.mode(), mode);
        assert_eq!(bag.mode_frequency(), freq);
        assert_eq!(bag.threshold(), 0);
        assert_eq!(bag.met_threshold().len(), met);
    }
}
243
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- ids::bag::test_bag_set_threshold --exact --show-output
/// ref. "TestBagSetThreshold"
#[test]
fn test_bag_set_threshold() {
    let low = Id::empty();
    let high = Id::from_slice(&[1_u8]);

    let bag = Bag::new();
    bag.add_count(&low, 2);
    bag.add_count(&high, 3);

    // Changing the threshold must leave counts, size and mode untouched.
    let assert_counts_intact = |b: &Bag| {
        assert_eq!(b.count(&low), 2);
        assert_eq!(b.count(&high), 3);
        assert_eq!(b.len(), 5);
        assert_eq!(b.list().len(), 2);
        assert_eq!(b.mode(), high);
        assert_eq!(b.mode_frequency(), 3);
    };

    bag.set_threshold(0);
    assert_counts_intact(&bag);
    assert_eq!(bag.threshold(), 0);
    assert_eq!(bag.met_threshold().len(), 2);

    bag.set_threshold(3);
    assert_counts_intact(&bag);
    assert_eq!(bag.threshold(), 3);
    let met = bag.met_threshold();
    assert_eq!(met.len(), 1);
    assert!(met.contains(&high));
}
276
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- ids::bag::test_bag_filter --exact --show-output
/// ref. "TestBagFilter"
#[test]
fn test_bag_filter() {
    let zero = Id::empty();
    let one = Id::from_slice(&[1_u8]);
    let two = Id::from_slice(&[2_u8]);

    let bag = Bag::new();
    for (id, n) in [(zero, 1), (one, 3), (two, 5)] {
        bag.add_count(&id, n);
    }

    // Keep only the Ids that agree with `zero` on bit 0.
    let even = bag.filter(0, 1, &zero);
    for (id, expected) in [(zero, 1), (one, 0), (two, 5)] {
        assert_eq!(even.count(&id), expected);
    }
}
296
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- ids::bag::test_bag_split --exact --show-output
/// ref. "TestBagSplit"
#[test]
fn test_bag_split() {
    let zero = Id::empty();
    let one = Id::from_slice(&[1_u8]);
    let two = Id::from_slice(&[2_u8]);

    let bag = Bag::new();
    for (id, n) in [(zero, 1), (one, 3), (two, 5)] {
        bag.add_count(&id, n);
    }

    let [evens, odds] = bag.split(0);

    // (id, expected count in evens, expected count in odds)
    for (id, in_evens, in_odds) in [(zero, 1, 0), (one, 0, 3), (two, 5, 0)] {
        assert_eq!(evens.count(&id), in_evens);
        assert_eq!(odds.count(&id), in_odds);
    }
}
322
/// Initial number of Id slots pre-allocated by `Unique::new`.
const MIN_UNIQUE_BAG_SIZE: usize = 16;
324
325/// Maps from an Id to the BitSet.
326/// ref. <https://pkg.go.dev/github.com/ava-labs/avalanchego/ids#UniqueBag>
327pub struct Unique(Rc<RefCell<HashMap<Id, Rc<RefCell<bits::Set64>>>>>);
328
329impl Unique {
330    pub fn new() -> Self {
331        let b: HashMap<Id, Rc<RefCell<bits::Set64>>> = HashMap::with_capacity(MIN_UNIQUE_BAG_SIZE);
332        Self(Rc::new(RefCell::new(b)))
333    }
334
335    pub fn union_set(&self, id: Id, set: bits::Set64) {
336        if let Some(v) = self.0.borrow().get(&id) {
337            v.borrow_mut().union(set);
338            return;
339        }
340
341        self.0.borrow_mut().insert(id, Rc::new(RefCell::new(set)));
342    }
343
344    pub fn difference_set(&self, id: Id, set: bits::Set64) {
345        if let Some(v) = self.0.borrow().get(&id) {
346            v.borrow_mut().difference(set)
347        }
348    }
349
350    pub fn add(&self, set_id: u64, ids: Vec<Id>) {
351        let mut bs = bits::Set64::new();
352        bs.add(set_id);
353
354        for id in ids.iter() {
355            self.union_set(*id, bs);
356        }
357    }
358
359    pub fn difference(&self, diff: &Unique) {
360        for (id, v) in self.0.borrow().iter() {
361            if let Some(vv) = diff.0.borrow().get(id) {
362                v.borrow_mut().difference(*vv.borrow());
363            }
364        }
365    }
366
367    pub fn get_set(&self, id: &Id) -> bits::Set64 {
368        if let Some(v) = self.0.borrow().get(id) {
369            *v.borrow()
370        } else {
371            bits::Set64::new()
372        }
373    }
374
375    pub fn remove_set(&self, id: &Id) {
376        self.0.borrow_mut().remove(id);
377    }
378
379    pub fn list(&self) -> Vec<Id> {
380        let mut ids: Vec<Id> = Vec::new();
381        for (id, _) in self.0.borrow().iter() {
382            ids.push(*id)
383        }
384        ids
385    }
386
387    pub fn bag(&self, alpha: u32) -> Bag {
388        let bag = Bag::new();
389        bag.set_threshold(alpha);
390
391        for (id, bs) in self.0.borrow().iter() {
392            bag.add_count(id, bs.borrow().len());
393        }
394        bag
395    }
396
397    pub fn clear(&self) {
398        self.0.borrow_mut().clear()
399    }
400}
401
402impl Default for Unique {
403    fn default() -> Self {
404        Self::new()
405    }
406}
407
408/// ref. <https://doc.rust-lang.org/std/string/trait.ToString.html>
409/// ref. <https://doc.rust-lang.org/std/fmt/trait.Display.html>
410/// Use "Self.to_string()" to directly invoke this.
411impl std::fmt::Display for Unique {
412    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413        write!(f, "UniqueBag: (Size = {})", self.0.borrow().len())?;
414        for (id, set) in self.0.borrow().iter() {
415            write!(f, "\n    ID[{}]: Members = {}", id, set.borrow())?;
416        }
417        Ok(())
418    }
419}
420
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- ids::bag::test_unique_bag --exact --show-output
/// ref. "TestUniqueBag"
#[test]
fn test_unique_bag() {
    let ub1 = Unique::new();
    assert_eq!(ub1.list().len(), 0);

    let id1 = Id::empty().prefix(&[1_u64]).unwrap();
    let id2 = Id::empty().prefix(&[2_u64]).unwrap();

    // `add` puts the set id into the set of every listed Id.
    let ub2 = Unique::new();
    ub2.add(1, vec![id1, id2]);

    assert!(ub2.get_set(&id1).contains(1));
    assert!(ub2.get_set(&id2).contains(1));

    // `union_set` stores its own copy: clearing the caller's set afterwards
    // must not change what the bag returns.
    let mut bs1 = bits::Set64::new();
    bs1.add(2);
    bs1.add(4);

    let ub3 = Unique::new();
    ub3.union_set(id1, bs1);

    bs1.clear();
    let bs1 = ub3.get_set(&id1);

    assert_eq!(bs1.len(), 2);
    assert!(bs1.contains(2));
    assert!(bs1.contains(4));

    // `difference` removes members per Id but keeps the Ids themselves.
    let ub4 = Unique::new();
    ub4.add(1, vec![id1]);
    ub4.add(2, vec![id1]);
    ub4.add(5, vec![id2]);
    ub4.add(8, vec![id2]);

    let ub5 = Unique::new();
    ub5.add(5, vec![id2]);
    ub5.add(5, vec![id1]);

    ub4.difference(&ub5);
    assert_eq!(ub5.list().len(), 2);
    assert_eq!(ub4.list().len(), 2);

    let ub4_id1 = ub4.get_set(&id1);
    assert_eq!(ub4_id1.len(), 2);
    assert!(ub4_id1.contains(1));
    assert!(ub4_id1.contains(2));

    let ub4_id2 = ub4.get_set(&id2);
    assert_eq!(ub4_id2.len(), 1);
    assert!(ub4_id2.contains(8));

    // `difference_set` removes members from a single Id's set.
    let ub6 = Unique::new();
    ub6.add(1, vec![id1]);
    ub6.add(2, vec![id1]);
    ub6.add(7, vec![id1]);

    let mut diff_bs = bits::Set64::new();
    diff_bs.add(1);
    diff_bs.add(7);

    ub6.difference_set(id1, diff_bs);

    let ub6_id1 = ub6.get_set(&id1);
    assert_eq!(ub6_id1.len(), 1);
    assert!(ub6_id1.contains(2));
}
490
/// RUST_LOG=debug cargo test --package avalanche-types --lib -- ids::bag::test_unique_bag_clear --exact --show-output
/// ref. "TestUniqueBagClear"
#[test]
fn test_unique_bag_clear() {
    let ids = [
        Id::empty().prefix(&[1_u64]).unwrap(),
        Id::empty().prefix(&[2_u64]).unwrap(),
    ];

    let b = Unique::new();
    b.add(0, vec![ids[0]]);
    b.add(1, ids.to_vec());

    b.clear();
    assert!(b.list().is_empty());

    // After clearing, every lookup falls back to an empty set.
    for id in &ids {
        assert_eq!(b.get_set(id).len(), 0);
    }
}