1use std::{
3 cell::{Cell, RefCell},
4 collections::{HashMap, HashSet},
5 rc::Rc,
6};
7
8use crate::ids::{bits, Id};
9
10pub struct Bag {
13 counts: Rc<RefCell<HashMap<Id, u32>>>,
14 size: Cell<u32>,
15
16 mode: Cell<Id>,
17 mode_freq: Cell<u32>,
18
19 threshold: Cell<u32>,
20 met_threshold: Rc<RefCell<HashSet<Id>>>,
21}
22
23impl Default for Bag {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl Clone for Bag {
30 fn clone(&self) -> Self {
31 self.deep_copy()
32 }
33}
34
35impl Bag {
36 pub fn new() -> Self {
37 Self {
38 counts: Rc::new(RefCell::new(HashMap::new())),
39 size: Cell::new(0),
40
41 mode: Cell::new(Id::empty()),
42 mode_freq: Cell::new(0_u32),
43
44 threshold: Cell::new(0_u32),
45 met_threshold: Rc::new(RefCell::new(HashSet::new())),
46 }
47 }
48
49 pub fn deep_copy(&self) -> Self {
50 Self {
51 counts: Rc::new(RefCell::new(self.counts())),
52 size: Cell::new(self.len()),
53
54 mode: Cell::new(self.mode()),
55 mode_freq: Cell::new(self.mode_frequency()),
56
57 threshold: Cell::new(self.threshold()),
58 met_threshold: Rc::new(RefCell::new(self.met_threshold())),
59 }
60 }
61
62 pub fn is_empty(&self) -> bool {
63 self.size.get() == 0
64 }
65
66 pub fn len(&self) -> u32 {
67 self.size.get()
68 }
69
70 pub fn mode(&self) -> Id {
71 self.mode.get()
72 }
73
74 pub fn mode_frequency(&self) -> u32 {
75 self.mode_freq.get()
76 }
77
78 pub fn threshold(&self) -> u32 {
79 self.threshold.get()
80 }
81
82 pub fn met_threshold(&self) -> HashSet<Id> {
84 self.met_threshold.borrow().clone()
85 }
86
87 pub fn list(&self) -> Vec<Id> {
88 let mut ids = Vec::with_capacity(self.counts.borrow().len());
89 ids.extend(self.counts.borrow().keys().copied());
90 ids
91 }
92
93 pub fn counts(&self) -> HashMap<Id, u32> {
94 self.counts.borrow().clone()
95 }
96
97 pub fn set_threshold(&self, threshold: u32) {
98 if self.threshold.get().eq(&threshold) {
99 return;
100 }
101
102 self.threshold.set(threshold);
103 self.met_threshold.borrow_mut().clear();
104
105 for (vote, count) in self.counts.borrow().iter() {
106 if *count >= threshold {
107 self.met_threshold.borrow_mut().insert(*vote);
108 }
109 }
110 }
111
112 pub fn add_count(&self, id: &Id, count: u32) {
113 if count == 0 {
114 return;
115 }
116
117 let mut borrowed_mut_counts = self.counts.borrow_mut();
118 let current_count = borrowed_mut_counts.get(id).unwrap_or(&0);
119 let total_count = *current_count + count;
120
121 borrowed_mut_counts.insert(*id, total_count);
122
123 self.size.set(self.size.get() + count);
124
125 if total_count > self.mode_freq.get() {
126 self.mode.set(*id);
127 self.mode_freq.set(total_count);
128 }
129 if total_count >= self.threshold.get() {
130 self.met_threshold.borrow_mut().insert(*id);
131 }
132 }
133
134 pub fn count(&self, id: &Id) -> u32 {
135 let borrowed_counts = self.counts.borrow();
136 let current_count = borrowed_counts.get(id).unwrap_or(&0);
137 *current_count
138 }
139
140 pub fn equals(&self, other: &Self) -> bool {
141 if self.len() != other.len() {
142 return false;
143 }
144
145 {
146 for (vote, count) in self.counts.borrow().iter() {
147 let cnt = *count;
148
149 let borrowed_other_counts = other.counts.borrow();
150 let found = borrowed_other_counts.get(vote);
151 if found.is_none() {
152 return false;
153 }
154 let other_count = found.unwrap_or(&0);
155 let other_cnt = *other_count;
156 if cnt != other_cnt {
157 return false;
158 }
159 }
160 true
161 }
162 }
163
164 pub fn filter(&self, start: usize, end: usize, id: &Id) -> Self {
167 let new_bag = Self::new();
168 for (vote, count) in self.counts.borrow().iter() {
169 let count = *count;
170
171 if bits::equal_subset(start, end, id, vote) {
172 new_bag.add_count(vote, count);
173 }
174 }
175 new_bag
176 }
177
178 pub fn split(&self, index: usize) -> [Self; 2] {
182 let split_votes = [Self::new(), Self::new()];
183
184 for (vote, count) in self.counts.borrow().iter() {
185 let count = *count;
186
187 let bit = vote.bit(index);
188 split_votes[bit.as_usize()].add_count(vote, count);
189 }
190
191 split_votes
192 }
193}
194
/// Checks every observable statistic of a bag after each `add_count` call.
#[test]
fn test_bag_add() {
    let id0 = Id::empty();
    let id1 = Id::from_slice(&[1_u8]);

    let bag = Bag::new();

    // (count(id0), count(id1), len, distinct ids, mode, mode frequency,
    //  threshold, number of ids meeting the threshold)
    let snapshot = || {
        (
            bag.count(&id0),
            bag.count(&id1),
            bag.len(),
            bag.list().len(),
            bag.mode(),
            bag.mode_frequency(),
            bag.threshold(),
            bag.met_threshold().len(),
        )
    };

    // Freshly created bag.
    assert_eq!(snapshot(), (0, 0, 0, 0, Id::empty(), 0, 0, 0));

    // First vote for id0; with threshold 0 it immediately meets the threshold.
    bag.add_count(&id0, 1);
    assert_eq!(snapshot(), (1, 0, 1, 1, id0, 1, 0, 1));

    // Second vote for the same ID: size grows, distinct IDs do not.
    bag.add_count(&id0, 1);
    assert_eq!(snapshot(), (2, 0, 2, 1, id0, 2, 0, 1));

    // id1 overtakes id0 and becomes the mode.
    bag.add_count(&id1, 3);
    assert_eq!(snapshot(), (2, 3, 5, 2, id1, 3, 0, 2));
}
243
/// Checks that `set_threshold` recomputes which IDs meet the threshold
/// without disturbing the counts, size or mode.
#[test]
fn test_bag_set_threshold() {
    let id0 = Id::empty();
    let id1 = Id::from_slice(&[1_u8]);

    let bag = Bag::new();
    bag.add_count(&id0, 2);
    bag.add_count(&id1, 3);

    // (count(id0), count(id1), len, distinct ids, mode, mode frequency,
    //  threshold, number of ids meeting the threshold)
    let snapshot = || {
        (
            bag.count(&id0),
            bag.count(&id1),
            bag.len(),
            bag.list().len(),
            bag.mode(),
            bag.mode_frequency(),
            bag.threshold(),
            bag.met_threshold().len(),
        )
    };

    // Threshold 0 is already in effect: both IDs qualify.
    bag.set_threshold(0);
    assert_eq!(snapshot(), (2, 3, 5, 2, id1, 3, 0, 2));

    // Raising the threshold to 3 leaves only id1.
    bag.set_threshold(3);
    assert_eq!(snapshot(), (2, 3, 5, 2, id1, 3, 3, 1));
    assert!(bag.met_threshold().contains(&id1));
}
276
/// Checks that `filter(0, 1, &id0)` keeps id0 and id2 and drops id1.
#[test]
fn test_bag_filter() {
    let id0 = Id::empty();
    let id1 = Id::from_slice(&[1_u8]);
    let id2 = Id::from_slice(&[2_u8]);

    let bag = Bag::new();
    for (id, votes) in [(id0, 1), (id1, 3), (id2, 5)] {
        bag.add_count(&id, votes);
    }

    let even = bag.filter(0, 1, &id0);

    // Expected count per ID in the filtered bag.
    for (id, expected) in [(id0, 1), (id1, 0), (id2, 5)] {
        assert_eq!(even.count(&id), expected);
    }
}
296
/// Checks that `split(0)` sends id0 and id2 to slot 0 and id1 to slot 1.
#[test]
fn test_bag_split() {
    let id0 = Id::empty();
    let id1 = Id::from_slice(&[1_u8]);
    let id2 = Id::from_slice(&[2_u8]);

    let bag = Bag::new();
    for (id, votes) in [(id0, 1), (id1, 3), (id2, 5)] {
        bag.add_count(&id, votes);
    }

    let [evens, odds] = bag.split(0);

    // (id, expected count in slot 0, expected count in slot 1)
    for (id, in_evens, in_odds) in [(id0, 1, 0), (id1, 0, 3), (id2, 5, 0)] {
        assert_eq!(evens.count(&id), in_evens);
        assert_eq!(odds.count(&id), in_odds);
    }
}
322
/// Initial capacity reserved for the map inside a newly created `Unique`.
const MIN_UNIQUE_BAG_SIZE: usize = 16;
324
325pub struct Unique(Rc<RefCell<HashMap<Id, Rc<RefCell<bits::Set64>>>>>);
328
329impl Unique {
330 pub fn new() -> Self {
331 let b: HashMap<Id, Rc<RefCell<bits::Set64>>> = HashMap::with_capacity(MIN_UNIQUE_BAG_SIZE);
332 Self(Rc::new(RefCell::new(b)))
333 }
334
335 pub fn union_set(&self, id: Id, set: bits::Set64) {
336 if let Some(v) = self.0.borrow().get(&id) {
337 v.borrow_mut().union(set);
338 return;
339 }
340
341 self.0.borrow_mut().insert(id, Rc::new(RefCell::new(set)));
342 }
343
344 pub fn difference_set(&self, id: Id, set: bits::Set64) {
345 if let Some(v) = self.0.borrow().get(&id) {
346 v.borrow_mut().difference(set)
347 }
348 }
349
350 pub fn add(&self, set_id: u64, ids: Vec<Id>) {
351 let mut bs = bits::Set64::new();
352 bs.add(set_id);
353
354 for id in ids.iter() {
355 self.union_set(*id, bs);
356 }
357 }
358
359 pub fn difference(&self, diff: &Unique) {
360 for (id, v) in self.0.borrow().iter() {
361 if let Some(vv) = diff.0.borrow().get(id) {
362 v.borrow_mut().difference(*vv.borrow());
363 }
364 }
365 }
366
367 pub fn get_set(&self, id: &Id) -> bits::Set64 {
368 if let Some(v) = self.0.borrow().get(id) {
369 *v.borrow()
370 } else {
371 bits::Set64::new()
372 }
373 }
374
375 pub fn remove_set(&self, id: &Id) {
376 self.0.borrow_mut().remove(id);
377 }
378
379 pub fn list(&self) -> Vec<Id> {
380 let mut ids: Vec<Id> = Vec::new();
381 for (id, _) in self.0.borrow().iter() {
382 ids.push(*id)
383 }
384 ids
385 }
386
387 pub fn bag(&self, alpha: u32) -> Bag {
388 let bag = Bag::new();
389 bag.set_threshold(alpha);
390
391 for (id, bs) in self.0.borrow().iter() {
392 bag.add_count(id, bs.borrow().len());
393 }
394 bag
395 }
396
397 pub fn clear(&self) {
398 self.0.borrow_mut().clear()
399 }
400}
401
402impl Default for Unique {
403 fn default() -> Self {
404 Self::new()
405 }
406}
407
408impl std::fmt::Display for Unique {
412 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
413 write!(f, "UniqueBag: (Size = {})", self.0.borrow().len())?;
414 for (id, set) in self.0.borrow().iter() {
415 write!(f, "\n ID[{}]: Members = {}", id, set.borrow())?;
416 }
417 Ok(())
418 }
419}
420
/// Exercises `add`, `union_set`, `get_set`, `difference` and `difference_set`.
#[test]
fn test_unique_bag() {
    // A freshly created bag tracks no IDs.
    let empty = Unique::new();
    assert_eq!(empty.list().len(), 0);

    let id1 = Id::empty().prefix(&[1_u64]).unwrap();
    let id2 = Id::empty().prefix(&[2_u64]).unwrap();

    // `add` records the index under every listed ID.
    let added = Unique::new();
    added.add(1, vec![id1, id2]);
    for id in [id1, id2] {
        assert!(added.get_set(&id).contains(1));
    }

    // `union_set` takes the set by value, so clearing the caller's copy
    // afterwards must not change what the bag stores.
    let mut source = bits::Set64::new();
    source.add(2);
    source.add(4);

    let unioned = Unique::new();
    unioned.union_set(id1, source);
    source.clear();

    let mut fetched = unioned.get_set(&id1);
    assert_eq!(fetched.len(), 2);
    for member in [2, 4] {
        assert!(fetched.contains(member));
    }
    fetched.clear();

    // `difference` removes the other bag's members per ID and leaves the
    // other bag untouched.
    let minuend = Unique::new();
    minuend.add(1, vec![id1]);
    minuend.add(2, vec![id1]);
    minuend.add(5, vec![id2]);
    minuend.add(8, vec![id2]);

    let subtrahend = Unique::new();
    subtrahend.add(5, vec![id2]);
    subtrahend.add(5, vec![id1]);

    minuend.difference(&subtrahend);
    assert_eq!(subtrahend.list().len(), 2);

    let left_for_id1 = minuend.get_set(&id1);
    assert_eq!(left_for_id1.len(), 2);
    assert!(left_for_id1.contains(1));
    assert!(left_for_id1.contains(2));

    let left_for_id2 = minuend.get_set(&id2);
    assert_eq!(left_for_id2.len(), 1);
    assert!(left_for_id2.contains(8));

    // `difference_set` removes the given members from a single ID's set.
    let single = Unique::new();
    for index in [1, 2, 7] {
        single.add(index, vec![id1]);
    }

    let mut to_remove = bits::Set64::new();
    to_remove.add(1);
    to_remove.add(7);

    single.difference_set(id1, to_remove);

    let survivors = single.get_set(&id1);
    assert_eq!(survivors.len(), 1);
    assert!(survivors.contains(2));
}
490
/// Checks that `clear` drops every ID and that lookups afterwards yield
/// empty sets.
#[test]
fn test_unique_bag_clear() {
    let id1 = Id::empty().prefix(&[1_u64]).unwrap();
    let id2 = Id::empty().prefix(&[2_u64]).unwrap();

    let unique = Unique::new();
    unique.add(0, vec![id1]);
    unique.add(1, vec![id1, id2]);

    unique.clear();
    assert_eq!(unique.list().len(), 0);

    for id in [id1, id2] {
        assert_eq!(unique.get_set(&id).len(), 0);
    }
}