1use std::borrow::Borrow;
2use std::num::NonZeroU32;
3use std::rc::Rc;
4
5use parking_lot::{RwLock, RwLockReadGuard};
6use rand::seq::SliceRandom;
7
8use super::node_id::NodeIdShort;
9use crate::util::{fast_thread_rng, FastDashSet, FastHashMap};
10
11pub struct PeersSet {
13 state: RwLock<PeersSetState>,
14}
15
16impl PeersSet {
17 pub fn with_capacity(capacity: u32) -> Self {
19 Self {
20 state: RwLock::new(PeersSetState::with_capacity(make_capacity(capacity))),
21 }
22 }
23
24 pub fn with_peers_and_capacity(peers: &[NodeIdShort], capacity: u32) -> Self {
28 Self {
29 state: RwLock::new(PeersSetState::with_peers_and_capacity(
30 peers,
31 make_capacity(capacity),
32 )),
33 }
34 }
35
36 pub fn version(&self) -> u64 {
37 self.state.read().version
38 }
39
40 pub fn contains(&self, peer: &NodeIdShort) -> bool {
41 self.state.read().cache.contains_key(Wrapper::wrap(peer))
42 }
43
44 pub fn get(&self, index: usize) -> Option<NodeIdShort> {
45 let state = self.state.read();
46
47 let item = state.index.get(index)?;
48 Some(*item.0.borrow())
49 }
50
51 pub fn len(&self) -> usize {
52 self.state.read().index.len()
53 }
54
55 pub fn is_empty(&self) -> bool {
56 self.state.read().index.is_empty()
57 }
58
59 pub fn is_full(&self) -> bool {
60 self.state.read().is_full()
61 }
62
63 pub fn iter(&self) -> Iter {
64 Iter::new(self.state.read())
65 }
66
67 pub fn get_random_peers(&self, amount: u32, except: Option<&NodeIdShort>) -> Vec<NodeIdShort> {
68 let state = self.state.read();
69
70 let items = state.index.choose_multiple(
71 &mut fast_thread_rng(),
72 if except.is_some() { amount + 1 } else { amount } as usize,
73 );
74
75 match except {
76 Some(except) => items
77 .filter(|item| &*item.0 != except)
78 .take(amount as usize)
79 .map(RefId::copy_inner)
80 .collect(),
81 None => items.map(RefId::copy_inner).collect(),
82 }
83 }
84
85 pub fn randomly_fill_from(
86 &self,
87 other: &PeersSet,
88 amount: u32,
89 except: Option<&FastDashSet<NodeIdShort>>,
90 ) {
91 if std::ptr::eq(self, other) {
93 return;
94 }
95
96 let selected_amount = match except {
97 Some(peers) => amount as usize + peers.len(),
98 None => amount as usize,
99 };
100
101 let other_state = other.state.read();
102 let new_peers = other_state
103 .index
104 .choose_multiple(&mut rand::thread_rng(), selected_amount);
105
106 let mut state = self.state.write();
107
108 let insert = |peer_id: &RefId| {
109 state.insert(peer_id.copy_inner());
110 };
111
112 match except {
113 Some(except) => {
114 new_peers
115 .filter(|peer_id| !except.contains(&*peer_id.0))
116 .take(amount as usize)
117 .for_each(insert);
118 }
119 None => new_peers.for_each(insert),
120 }
121 }
122
123 pub fn insert(&self, peer_id: NodeIdShort) -> bool {
127 self.state.write().insert(peer_id)
128 }
129
130 pub fn extend<I>(&self, peers: I)
131 where
132 I: IntoIterator<Item = NodeIdShort>,
133 {
134 let mut state = self.state.write();
135 for peer_id in peers.into_iter() {
136 state.insert(peer_id);
137 }
138 }
139
140 pub fn clone_inner(&self) -> Vec<NodeIdShort> {
142 let state = self.state.read();
143 state.index.iter().map(Ref::copy_inner).collect()
144 }
145}
146
147impl IntoIterator for PeersSet {
148 type Item = NodeIdShort;
149 type IntoIter = IntoIter;
150
151 fn into_iter(self) -> Self::IntoIter {
152 IntoIter {
153 inner: self.state.into_inner().index.into_iter(),
154 }
155 }
156}
157
158pub struct IntoIter {
159 inner: std::vec::IntoIter<Ref<NodeIdShort>>,
160}
161
162impl Iterator for IntoIter {
163 type Item = NodeIdShort;
164
165 fn next(&mut self) -> Option<Self::Item> {
166 loop {
167 let next = self.inner.next()?;
168 if let Ok(id) = Rc::try_unwrap(next.0) {
169 break Some(id);
170 }
171 }
172 }
173
174 fn size_hint(&self) -> (usize, Option<usize>) {
175 self.inner.size_hint()
176 }
177}
178
179pub struct Iter<'a> {
180 _state: RwLockReadGuard<'a, PeersSetState>,
181 iter: std::slice::Iter<'a, Ref<NodeIdShort>>,
182}
183
184impl<'a> Iter<'a> {
185 fn new(state: RwLockReadGuard<'a, PeersSetState>) -> Self {
186 let iter = unsafe {
188 std::slice::from_raw_parts::<'a>(state.index.as_ptr(), state.index.len()).iter()
189 };
190 Self {
191 _state: state,
192 iter,
193 }
194 }
195}
196
197impl<'a> Iterator for Iter<'a> {
198 type Item = &'a NodeIdShort;
199
200 fn next(&mut self) -> Option<Self::Item> {
201 let item = self.iter.next()?;
202 Some(item.0.as_ref())
203 }
204
205 fn size_hint(&self) -> (usize, Option<usize>) {
206 self.iter.size_hint()
207 }
208}
209
210impl<'a> IntoIterator for &'a PeersSet {
211 type Item = &'a NodeIdShort;
212 type IntoIter = Iter<'a>;
213
214 fn into_iter(self) -> Self::IntoIter {
215 self.iter()
216 }
217}
218
219struct PeersSetState {
220 version: u64,
221 cache: FastHashMap<RefId, u32>,
222 index: Vec<RefId>,
223 capacity: NonZeroU32,
224 upper: u32,
225}
226
227impl PeersSetState {
228 fn with_capacity(capacity: NonZeroU32) -> Self {
229 Self {
230 version: 0,
231 cache: FastHashMap::with_capacity_and_hasher(
232 capacity.get() as usize,
233 Default::default(),
234 ),
235 index: Vec::with_capacity(capacity.get() as usize),
236 capacity,
237 upper: 0,
238 }
239 }
240
241 fn with_peers_and_capacity(peers: &[NodeIdShort], capacity: NonZeroU32) -> Self {
242 use std::collections::hash_map::Entry;
243
244 let mut res = Self::with_capacity(capacity);
245 let capacity = res.capacity.get();
246
247 for peer in peers {
248 if res.upper >= capacity {
249 break;
250 }
251
252 let peer = Ref(Rc::new(*peer));
253
254 match res.cache.entry(peer.clone()) {
255 Entry::Vacant(entry) => {
256 entry.insert(res.upper);
257 res.index.push(peer);
258 res.upper += 1;
259 }
260 Entry::Occupied(_) => continue,
261 }
262 }
263
264 res.upper %= capacity;
265 res
266 }
267
268 fn is_full(&self) -> bool {
269 self.index.len() >= self.capacity.get() as usize
270 }
271
272 fn insert(&mut self, peer_id: NodeIdShort) -> bool {
273 use std::collections::hash_map::Entry;
274
275 let peer_id = Ref(Rc::new(peer_id));
276
277 match self.cache.entry(peer_id.clone()) {
279 Entry::Vacant(entry) => {
280 self.version += 1;
281 entry.insert(self.upper);
282 }
283 Entry::Occupied(_) => return false,
284 };
285
286 let upper = (self.upper + 1) % self.capacity;
287 let index = std::mem::replace(&mut self.upper, upper) as usize;
288
289 match self.index.get_mut(index) {
290 Some(slot) => {
291 let old_peer = std::mem::replace(slot, peer_id);
292
293 if let Entry::Occupied(entry) = self.cache.entry(old_peer) {
295 if entry.get() == &(index as u32) {
296 entry.remove();
297 }
298 }
299 }
300 None => self.index.push(peer_id),
301 }
302
303 true
304 }
305}
306
307unsafe impl Send for PeersSetState {}
310unsafe impl Sync for PeersSetState {}
311
/// Shared handle to a peer id, used as the key of the lookup map and as the
/// element of the slot list.
type RefId = Ref<NodeIdShort>;
313
/// Reference-counted handle that hashes and compares like the value it
/// points to.
#[derive(Hash, Eq, PartialEq)]
struct Ref<T>(Rc<T>);

impl<T: Copy> Ref<T> {
    /// Returns a copy of the pointed-to value.
    #[inline]
    fn copy_inner(&self) -> T {
        let Self(inner) = self;
        **inner
    }
}

impl<T> Clone for Ref<T> {
    /// Clones the handle, not the value: both handles share one allocation.
    fn clone(&self) -> Self {
        Ref(Rc::clone(&self.0))
    }
}
329
/// Transparent new-type around `T` that hashes and compares exactly like `T`.
#[derive(Hash, Eq, PartialEq)]
#[repr(transparent)]
struct Wrapper<T: ?Sized>(T);

impl<T: ?Sized> Wrapper<T> {
    /// Reinterprets `&T` as `&Wrapper<T>` without copying the value.
    #[inline(always)]
    fn wrap(value: &T) -> &Self {
        let ptr = value as *const T as *const Self;
        // SAFETY: `Wrapper<T>` is `#[repr(transparent)]` over its single `T`
        // field, so both types share one layout, and `ptr` was derived from a
        // valid reference whose lifetime the returned reference inherits.
        unsafe { &*ptr }
    }
}
341
342impl<K, Q> Borrow<Wrapper<Q>> for Ref<K>
343where
344 K: Borrow<Q>,
345 Q: ?Sized,
346{
347 fn borrow(&self) -> &Wrapper<Q> {
348 let k: &K = self.0.borrow();
349 let q: &Q = k.borrow();
350 Wrapper::wrap(q)
351 }
352}
353
/// Converts a requested capacity into a non-zero one, raising `0` to `1`.
fn make_capacity(capacity: u32) -> NonZeroU32 {
    // Checked construction replaces the former `unsafe` `new_unchecked`: the
    // clamp guarantees the value is non-zero, so the `expect` cannot fire.
    NonZeroU32::new(capacity.max(1)).expect("a value clamped to at least 1 is never zero")
}
359
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Generates `count` random peer ids.
    fn random_peers(count: usize) -> Vec<NodeIdShort> {
        std::iter::repeat_with(NodeIdShort::random)
            .take(count)
            .collect()
    }

    /// A peer can be inserted once; the duplicate is rejected.
    #[test]
    fn test_insertion() {
        let set = PeersSet::with_capacity(10);
        let peer_id = NodeIdShort::random();

        assert!(set.insert(peer_id));
        assert!(!set.insert(peer_id));
        assert!(!set.is_full());
    }

    /// Inserting into a full set evicts the peer in the first slot.
    #[test]
    fn test_entries_replacing() {
        let set = PeersSet::with_capacity(3);
        let peers = random_peers(4);
        let (initial, extra) = peers.split_at(3);

        for &peer_id in initial {
            assert!(!set.is_full());
            assert!(set.insert(peer_id));
        }

        assert!(set.is_full());
        assert!(set.contains(&initial[0]));

        set.insert(extra[0]);

        assert!(set.contains(&extra[0]));
        assert!(!set.contains(&initial[0]));
    }

    /// Cycling through all slots twice leaves none of the initial peers.
    #[test]
    fn test_full_entries_replacing() {
        let set = PeersSet::with_capacity(3);
        let peers = random_peers(3);

        for &peer_id in &peers {
            assert!(!set.is_full());
            assert!(set.insert(peer_id));
        }

        assert!(peers.iter().all(|peer_id| set.contains(peer_id)));

        for peer_id in random_peers(6) {
            assert!(set.is_full());
            set.insert(peer_id);
        }

        assert!(!peers.iter().any(|peer_id| set.contains(peer_id)));
    }

    /// The borrowing iterator yields the peers in insertion order.
    #[test]
    fn test_iterator() {
        let set = PeersSet::with_capacity(10);
        let peers = random_peers(3);

        for &peer_id in &peers {
            assert!(set.insert(peer_id));
        }

        assert_eq!(peers.len(), set.iter().count());
        for (stored, expected) in set.iter().zip(&peers) {
            assert_eq!(stored, expected);
        }
    }

    /// The length grows up to the capacity and then stays there.
    #[test]
    fn test_overlapping_insertion() {
        let set = PeersSet::with_capacity(10);

        for inserted in 1usize..1000 {
            assert!(set.insert(NodeIdShort::random()));
            assert_eq!(set.len(), inserted.min(10));
        }
    }

    /// Random samples are distinct and never contain the excluded peer.
    #[test]
    fn test_random_peers() {
        let set = PeersSet::with_capacity(10);
        for peer_id in random_peers(10) {
            set.insert(peer_id);
        }

        let sample = set.get_random_peers(5, None);
        assert_eq!(sample.len(), 5);
        let unique: HashSet<_> = sample.into_iter().collect();
        assert_eq!(unique.len(), 5);

        for index in 0..set.len() {
            let except = set.get(index).unwrap();

            let sample = set.get_random_peers(5, Some(&except));
            assert_eq!(sample.len(), 5);

            let unique: HashSet<_> = sample.into_iter().collect();
            assert!(!unique.contains(&except));
            assert_eq!(unique.len(), 5);
        }
    }

    /// Exactly filling the set wraps the write position back to zero.
    #[test]
    fn with_peers_same_size_as_capacity() {
        let peers = random_peers(10);
        let set = PeersSet::with_peers_and_capacity(&peers, peers.len() as u32);

        let state = set.state.write();
        assert_eq!(state.version, 0);
        assert_eq!(state.cache.len(), peers.len());
        assert_eq!(state.index.len(), peers.len());
        assert_eq!(state.upper, 0);
        assert!(state.is_full());
    }

    /// A partially filled set leaves the write position after the last peer.
    #[test]
    fn with_peers_less_than_capacity() {
        let peers = random_peers(5);
        let set = PeersSet::with_peers_and_capacity(&peers, 10);

        let state = set.state.write();
        assert_eq!(state.cache.len(), peers.len());
        assert_eq!(state.index.len(), peers.len());
        assert_eq!(state.upper, peers.len() as u32);
        assert!(!state.is_full());
    }

    /// Surplus peers beyond the capacity are ignored.
    #[test]
    fn with_peers_greater_than_capacity() {
        let peers = random_peers(16);
        let set = PeersSet::with_peers_and_capacity(&peers, 10);

        let state = set.state.write();
        assert_eq!(state.cache.len(), 10);
        assert_eq!(state.index.len(), 10);
        assert_eq!(state.upper, 0);
        assert!(state.is_full());
    }
}