everscale_network/adnl/peers_set.rs

1use std::borrow::Borrow;
2use std::num::NonZeroU32;
3use std::rc::Rc;
4
5use parking_lot::{RwLock, RwLockReadGuard};
6use rand::seq::SliceRandom;
7
8use super::node_id::NodeIdShort;
9use crate::util::{fast_thread_rng, FastDashSet, FastHashMap};
10
11/// A set of unique short node ids
12pub struct PeersSet {
13    state: RwLock<PeersSetState>,
14}
15
16impl PeersSet {
17    /// Constructs new peers set with the specified fixed capacity
18    pub fn with_capacity(capacity: u32) -> Self {
19        Self {
20            state: RwLock::new(PeersSetState::with_capacity(make_capacity(capacity))),
21        }
22    }
23
24    /// Constructs new peers set with some initial peers
25    ///
26    /// NOTE: Only first `capacity` peers will be added
27    pub fn with_peers_and_capacity(peers: &[NodeIdShort], capacity: u32) -> Self {
28        Self {
29            state: RwLock::new(PeersSetState::with_peers_and_capacity(
30                peers,
31                make_capacity(capacity),
32            )),
33        }
34    }
35
36    pub fn version(&self) -> u64 {
37        self.state.read().version
38    }
39
40    pub fn contains(&self, peer: &NodeIdShort) -> bool {
41        self.state.read().cache.contains_key(Wrapper::wrap(peer))
42    }
43
44    pub fn get(&self, index: usize) -> Option<NodeIdShort> {
45        let state = self.state.read();
46
47        let item = state.index.get(index)?;
48        Some(*item.0.borrow())
49    }
50
51    pub fn len(&self) -> usize {
52        self.state.read().index.len()
53    }
54
55    pub fn is_empty(&self) -> bool {
56        self.state.read().index.is_empty()
57    }
58
59    pub fn is_full(&self) -> bool {
60        self.state.read().is_full()
61    }
62
63    pub fn iter(&self) -> Iter {
64        Iter::new(self.state.read())
65    }
66
67    pub fn get_random_peers(&self, amount: u32, except: Option<&NodeIdShort>) -> Vec<NodeIdShort> {
68        let state = self.state.read();
69
70        let items = state.index.choose_multiple(
71            &mut fast_thread_rng(),
72            if except.is_some() { amount + 1 } else { amount } as usize,
73        );
74
75        match except {
76            Some(except) => items
77                .filter(|item| &*item.0 != except)
78                .take(amount as usize)
79                .map(RefId::copy_inner)
80                .collect(),
81            None => items.map(RefId::copy_inner).collect(),
82        }
83    }
84
85    pub fn randomly_fill_from(
86        &self,
87        other: &PeersSet,
88        amount: u32,
89        except: Option<&FastDashSet<NodeIdShort>>,
90    ) {
91        // NOTE: early return, otherwise it will deadlock if `other` is the same as self
92        if std::ptr::eq(self, other) {
93            return;
94        }
95
96        let selected_amount = match except {
97            Some(peers) => amount as usize + peers.len(),
98            None => amount as usize,
99        };
100
101        let other_state = other.state.read();
102        let new_peers = other_state
103            .index
104            .choose_multiple(&mut rand::thread_rng(), selected_amount);
105
106        let mut state = self.state.write();
107
108        let insert = |peer_id: &RefId| {
109            state.insert(peer_id.copy_inner());
110        };
111
112        match except {
113            Some(except) => {
114                new_peers
115                    .filter(|peer_id| !except.contains(&*peer_id.0))
116                    .take(amount as usize)
117                    .for_each(insert);
118            }
119            None => new_peers.for_each(insert),
120        }
121    }
122
123    /// Adds a value to the set.
124    ///
125    /// If the set did not have this value present, `true` is returned.
126    pub fn insert(&self, peer_id: NodeIdShort) -> bool {
127        self.state.write().insert(peer_id)
128    }
129
130    pub fn extend<I>(&self, peers: I)
131    where
132        I: IntoIterator<Item = NodeIdShort>,
133    {
134        let mut state = self.state.write();
135        for peer_id in peers.into_iter() {
136            state.insert(peer_id);
137        }
138    }
139
140    /// Clones internal node ids storage
141    pub fn clone_inner(&self) -> Vec<NodeIdShort> {
142        let state = self.state.read();
143        state.index.iter().map(Ref::copy_inner).collect()
144    }
145}
146
147impl IntoIterator for PeersSet {
148    type Item = NodeIdShort;
149    type IntoIter = IntoIter;
150
151    fn into_iter(self) -> Self::IntoIter {
152        IntoIter {
153            inner: self.state.into_inner().index.into_iter(),
154        }
155    }
156}
157
158pub struct IntoIter {
159    inner: std::vec::IntoIter<Ref<NodeIdShort>>,
160}
161
162impl Iterator for IntoIter {
163    type Item = NodeIdShort;
164
165    fn next(&mut self) -> Option<Self::Item> {
166        loop {
167            let next = self.inner.next()?;
168            if let Ok(id) = Rc::try_unwrap(next.0) {
169                break Some(id);
170            }
171        }
172    }
173
174    fn size_hint(&self) -> (usize, Option<usize>) {
175        self.inner.size_hint()
176    }
177}
178
179pub struct Iter<'a> {
180    _state: RwLockReadGuard<'a, PeersSetState>,
181    iter: std::slice::Iter<'a, Ref<NodeIdShort>>,
182}
183
184impl<'a> Iter<'a> {
185    fn new(state: RwLockReadGuard<'a, PeersSetState>) -> Self {
186        // SAFETY: index array lifetime is bounded to the lifetime of the `RwLockReadGuard`
187        let iter = unsafe {
188            std::slice::from_raw_parts::<'a>(state.index.as_ptr(), state.index.len()).iter()
189        };
190        Self {
191            _state: state,
192            iter,
193        }
194    }
195}
196
197impl<'a> Iterator for Iter<'a> {
198    type Item = &'a NodeIdShort;
199
200    fn next(&mut self) -> Option<Self::Item> {
201        let item = self.iter.next()?;
202        Some(item.0.as_ref())
203    }
204
205    fn size_hint(&self) -> (usize, Option<usize>) {
206        self.iter.size_hint()
207    }
208}
209
210impl<'a> IntoIterator for &'a PeersSet {
211    type Item = &'a NodeIdShort;
212    type IntoIter = Iter<'a>;
213
214    fn into_iter(self) -> Self::IntoIter {
215        self.iter()
216    }
217}
218
219struct PeersSetState {
220    version: u64,
221    cache: FastHashMap<RefId, u32>,
222    index: Vec<RefId>,
223    capacity: NonZeroU32,
224    upper: u32,
225}
226
227impl PeersSetState {
228    fn with_capacity(capacity: NonZeroU32) -> Self {
229        Self {
230            version: 0,
231            cache: FastHashMap::with_capacity_and_hasher(
232                capacity.get() as usize,
233                Default::default(),
234            ),
235            index: Vec::with_capacity(capacity.get() as usize),
236            capacity,
237            upper: 0,
238        }
239    }
240
241    fn with_peers_and_capacity(peers: &[NodeIdShort], capacity: NonZeroU32) -> Self {
242        use std::collections::hash_map::Entry;
243
244        let mut res = Self::with_capacity(capacity);
245        let capacity = res.capacity.get();
246
247        for peer in peers {
248            if res.upper >= capacity {
249                break;
250            }
251
252            let peer = Ref(Rc::new(*peer));
253
254            match res.cache.entry(peer.clone()) {
255                Entry::Vacant(entry) => {
256                    entry.insert(res.upper);
257                    res.index.push(peer);
258                    res.upper += 1;
259                }
260                Entry::Occupied(_) => continue,
261            }
262        }
263
264        res.upper %= capacity;
265        res
266    }
267
268    fn is_full(&self) -> bool {
269        self.index.len() >= self.capacity.get() as usize
270    }
271
272    fn insert(&mut self, peer_id: NodeIdShort) -> bool {
273        use std::collections::hash_map::Entry;
274
275        let peer_id = Ref(Rc::new(peer_id));
276
277        // Insert new peer into cache
278        match self.cache.entry(peer_id.clone()) {
279            Entry::Vacant(entry) => {
280                self.version += 1;
281                entry.insert(self.upper);
282            }
283            Entry::Occupied(_) => return false,
284        };
285
286        let upper = (self.upper + 1) % self.capacity;
287        let index = std::mem::replace(&mut self.upper, upper) as usize;
288
289        match self.index.get_mut(index) {
290            Some(slot) => {
291                let old_peer = std::mem::replace(slot, peer_id);
292
293                // Remove old peer
294                if let Entry::Occupied(entry) = self.cache.entry(old_peer) {
295                    if entry.get() == &(index as u32) {
296                        entry.remove();
297                    }
298                }
299            }
300            None => self.index.push(peer_id),
301        }
302
303        true
304    }
305}
306
307// SAFETY: internal Rcs are not exposed by the api and the reference
308// counts only change in methods with `&mut self`
309unsafe impl Send for PeersSetState {}
310unsafe impl Sync for PeersSetState {}
311
312type RefId = Ref<NodeIdShort>;
313
/// Newtype over `Rc<T>`.
///
/// Hashing and equality are by value, through `Rc`'s own impls.
#[derive(Hash, Eq, PartialEq)]
struct Ref<T>(Rc<T>);

impl<T: Copy> Ref<T> {
    /// Returns a copy of the pointed-to value.
    #[inline]
    fn copy_inner(&self) -> T {
        let Self(shared) = self;
        **shared
    }
}

// Written by hand rather than derived: `#[derive(Clone)]` would add a `T: Clone`
// bound, but cloning only needs to bump the reference count.
impl<T> Clone for Ref<T> {
    fn clone(&self) -> Self {
        Ref(Rc::clone(&self.0))
    }
}
329
/// Transparent newtype used as a borrowed lookup key.
///
/// Because of `#[repr(transparent)]`, a `&T` can be reinterpreted as a
/// `&Wrapper<T>` without copying or allocating.
#[derive(Hash, Eq, PartialEq)]
#[repr(transparent)]
struct Wrapper<T: ?Sized>(T);

impl<T: ?Sized> Wrapper<T> {
    /// Reinterprets `value` as a reference to its wrapper.
    #[inline(always)]
    fn wrap(value: &T) -> &Self {
        let raw: *const T = value;
        // SAFETY: `Wrapper<T>` is `#[repr(transparent)]` over `T`, so both types share
        // layout and pointer metadata. The result inherits `value`'s lifetime.
        unsafe { &*(raw as *const Wrapper<T>) }
    }
}
341
342impl<K, Q> Borrow<Wrapper<Q>> for Ref<K>
343where
344    K: Borrow<Q>,
345    Q: ?Sized,
346{
347    fn borrow(&self) -> &Wrapper<Q> {
348        let k: &K = self.0.borrow();
349        let q: &Q = k.borrow();
350        Wrapper::wrap(q)
351    }
352}
353
/// Converts a requested capacity into a non-zero one, treating `0` as `1`.
fn make_capacity(capacity: u32) -> NonZeroU32 {
    // `max(1)` guarantees a non-zero value, so the safe constructor cannot fail.
    // No `unsafe` is needed for this.
    NonZeroU32::new(capacity.max(1)).expect("capacity is clamped to at least 1")
}
359
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    /// Generates `count` random node ids.
    fn random_peers(count: usize) -> Vec<NodeIdShort> {
        std::iter::repeat_with(NodeIdShort::random)
            .take(count)
            .collect()
    }

    #[test]
    fn test_insertion() {
        let cache = PeersSet::with_capacity(10);

        let peer_id = NodeIdShort::random();
        assert!(cache.insert(peer_id));
        assert!(!cache.insert(peer_id));
        assert!(!cache.is_full());
    }

    #[test]
    fn test_entries_replacing() {
        let cache = PeersSet::with_capacity(3);
        let peers = random_peers(4);

        for peer_id in &peers[..3] {
            assert!(!cache.is_full());
            assert!(cache.insert(*peer_id));
        }

        assert!(cache.is_full());
        assert!(cache.contains(&peers[0]));

        assert!(cache.insert(peers[3]));

        assert!(cache.contains(&peers[3]));
        assert!(!cache.contains(&peers[0]));
    }

    #[test]
    fn test_full_entries_replacing() {
        let cache = PeersSet::with_capacity(3);
        let peers = random_peers(3);

        for peer_id in &peers {
            assert!(!cache.is_full());
            assert!(cache.insert(*peer_id));
        }

        for peer_id in &peers {
            assert!(cache.contains(peer_id));
        }

        for peer_id in random_peers(6) {
            assert!(cache.is_full());
            cache.insert(peer_id);
        }

        for peer_id in &peers {
            assert!(!cache.contains(peer_id));
        }
    }

    #[test]
    fn test_iterator() {
        let cache = PeersSet::with_capacity(10);
        let peers = random_peers(3);

        for peer_id in &peers {
            assert!(cache.insert(*peer_id));
        }

        assert_eq!(peers.len(), cache.iter().count());
        for (cache_peer_id, peer_id) in cache.iter().zip(peers.iter()) {
            assert_eq!(cache_peer_id, peer_id);
        }
    }

    #[test]
    fn test_overlapping_insertion() {
        let cache = PeersSet::with_capacity(10);

        for i in 1..1000 {
            assert!(cache.insert(NodeIdShort::random()));
            assert_eq!(cache.len(), std::cmp::min(i, 10));
        }
    }

    #[test]
    fn test_random_peers() {
        let cache = PeersSet::with_capacity(10);
        cache.extend(random_peers(10));

        let peers = cache.get_random_peers(5, None);
        assert_eq!(peers.len(), 5);
        assert_eq!(peers.into_iter().collect::<HashSet<_>>().len(), 5);

        for i in 0..cache.len() {
            let except = cache.get(i).unwrap();

            let peers = cache.get_random_peers(5, Some(&except));
            assert_eq!(peers.len(), 5);

            let unique_peers = peers.into_iter().collect::<HashSet<_>>();
            assert!(!unique_peers.contains(&except));
            assert_eq!(unique_peers.len(), 5);
        }

        // Asking for more peers than stored returns everything, without duplicates.
        let all = cache.get_random_peers(20, None);
        assert_eq!(all.into_iter().collect::<HashSet<_>>().len(), 10);
    }

    #[test]
    fn with_peers_same_size_as_capacity() {
        let peers = random_peers(10);
        let cache = PeersSet::with_peers_and_capacity(&peers, peers.len() as u32);

        let state = cache.state.read();
        assert_eq!(state.version, 0);
        assert_eq!(state.cache.len(), peers.len());
        assert_eq!(state.index.len(), peers.len());
        assert_eq!(state.upper, 0);
        assert!(state.is_full());
    }

    #[test]
    fn with_peers_less_than_capacity() {
        let peers = random_peers(5);
        let cache = PeersSet::with_peers_and_capacity(&peers, 10);

        let state = cache.state.read();
        assert_eq!(state.cache.len(), peers.len());
        assert_eq!(state.index.len(), peers.len());
        assert_eq!(state.upper, peers.len() as u32);
        assert!(!state.is_full());
    }

    #[test]
    fn with_peers_greater_than_capacity() {
        let peers = random_peers(16);
        let cache = PeersSet::with_peers_and_capacity(&peers, 10);

        let state = cache.state.read();
        assert_eq!(state.cache.len(), 10);
        assert_eq!(state.index.len(), 10);
        assert_eq!(state.upper, 0);
        assert!(state.is_full());
    }

    #[test]
    fn zero_capacity_is_clamped_to_one() {
        let cache = PeersSet::with_capacity(0);
        let peers = random_peers(2);

        assert!(cache.insert(peers[0]));
        assert!(cache.is_full());
        assert!(cache.insert(peers[1]));
        assert_eq!(cache.len(), 1);
        assert!(cache.contains(&peers[1]));
        assert!(!cache.contains(&peers[0]));
    }

    #[test]
    fn version_counts_new_insertions() {
        let cache = PeersSet::with_capacity(2);
        assert_eq!(cache.version(), 0);

        let peer_id = NodeIdShort::random();
        assert!(cache.insert(peer_id));
        assert_eq!(cache.version(), 1);

        assert!(!cache.insert(peer_id));
        assert_eq!(cache.version(), 1);
    }

    #[test]
    fn into_iter_and_clone_inner_yield_all_peers() {
        let peers = random_peers(5);
        let cache = PeersSet::with_peers_and_capacity(&peers, 10);

        assert_eq!(cache.clone_inner(), peers);
        assert_eq!(cache.into_iter().collect::<Vec<_>>(), peers);
    }

    #[test]
    fn randomly_fill_from_other_set() {
        let source = PeersSet::with_peers_and_capacity(&random_peers(10), 10);
        let target = PeersSet::with_capacity(10);

        target.randomly_fill_from(&source, 5, None);
        assert_eq!(target.len(), 5);
        for peer_id in target.clone_inner() {
            assert!(source.contains(&peer_id));
        }

        // Filling a set from itself must neither deadlock nor modify it.
        let version = target.version();
        target.randomly_fill_from(&target, 5, None);
        assert_eq!(target.len(), 5);
        assert_eq!(target.version(), version);
    }
}