concurrent_arena/
bucket.rs

1use super::{bitmap::BitMap, Arc, OptionExt, SliceExt};
2
3use core::{array, cell::UnsafeCell, hint::spin_loop, ops::Deref};
4use std::sync::atomic::{fence, AtomicU8, Ordering};
5
/// Bit of an entry's `counter` that flags the entry as removed
/// (the most significant bit of the `u8`).
const REMOVED_MASK: u8 = 1 << (u8::BITS - 1);
/// Bits of an entry's `counter` that hold the reference count
/// (every bit except `REMOVED_MASK`).
const REFCNT_MASK: u8 = !REMOVED_MASK;
/// Largest reference count that fits in the `REFCNT_MASK` bits of a counter.
pub const MAX_REFCNT: u8 = REFCNT_MASK;
9
/// One slot of a `Bucket`: a packed atomic counter plus the stored value.
#[derive(Debug)]
struct Entry<T> {
    /// Removed flag (`REMOVED_MASK`) combined with the reference count
    /// (`REFCNT_MASK`). A value of `0` means the slot holds no live value.
    counter: AtomicU8,
    /// The stored value, or `None` while the slot is empty. It is read and
    /// written through raw pointers, with `counter` deciding who may touch it.
    val: UnsafeCell<Option<T>>,
}
15
16impl<T> Entry<T> {
17    const fn new() -> Self {
18        Self {
19            counter: AtomicU8::new(0),
20            val: UnsafeCell::new(None),
21        }
22    }
23}
24
25impl<T> Drop for Entry<T> {
26    fn drop(&mut self) {
27        // Use `Acquire` here to make sure option is set to None before
28        // the entry is dropped.
29        let cnt = self.counter.load(Ordering::Acquire);
30
31        // It must be either deleted, or is still alive
32        // but no `ArenaArc` reference exist.
33        debug_assert!(cnt <= 1);
34
35        let val = self.val.get_mut().take();
36
37        if cnt == 0 {
38            debug_assert!(val.is_none());
39        } else {
40            debug_assert!(val.is_some());
41        }
42    }
43}
44
/// A fixed-size group of `LEN` entries together with the bitmap that hands
/// out and reclaims entry indices.
#[derive(Debug)]
pub(crate) struct Bucket<T, const BITARRAY_LEN: usize, const LEN: usize> {
    /// Allocation bitmap: `allocate` yields a free index, `deallocate`
    /// returns one, `load` reports whether an index is currently allocated.
    bitset: BitMap<BITARRAY_LEN>,
    /// Storage for the elements, addressed by the indices from `bitset`.
    entries: [Entry<T>; LEN],
}
50
// SAFETY: NOTE(review) — reasoning reconstructed from this file, confirm:
// `Entry::val` is an `UnsafeCell`, so `Bucket` is not `Sync` automatically.
// Access to `val` is gated by each entry's atomic `counter` (the value is
// written only while the slot is unshared and dropped only by the holder of
// the last reference). Values are handed to other threads by shared
// reference, hence the `T: Send + Sync` bound.
unsafe impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Sync
    for Bucket<T, BITARRAY_LEN, LEN>
{
}
55
// SAFETY: NOTE(review) — reasoning reconstructed from this file, confirm:
// the bucket owns its values outright (`Option<T>` inside each entry), so
// moving the bucket to another thread moves the values with it; the bound
// `T: Send + Sync` is the same one used for the `Sync` impl.
unsafe impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Send
    for Bucket<T, BITARRAY_LEN, LEN>
{
}
60
61impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Default
62    for Bucket<T, BITARRAY_LEN, LEN>
63{
64    fn default() -> Self {
65        Self::new()
66    }
67}
68
69impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Bucket<T, BITARRAY_LEN, LEN> {
70    pub(crate) fn new() -> Self {
71        Self {
72            bitset: BitMap::new(),
73            entries: array::from_fn(|_| Entry::new()),
74        }
75    }
76
77    pub(crate) fn try_insert(
78        this: &Arc<Self>,
79        bucket_index: u32,
80        value: T,
81    ) -> Result<ArenaArc<T, BITARRAY_LEN, LEN>, T> {
82        let index = match this.bitset.allocate() {
83            Some(index) => index,
84            None => return Err(value),
85        };
86
87        // Safety: index <= LEN
88        let entry = unsafe { this.entries.get_unchecked_on_release(index) };
89
90        // Use `Acquire` here to make sure option is set to None before
91        // the entry is reused again.
92        let prev_refcnt = entry.counter.load(Ordering::Acquire);
93        debug_assert_eq!(prev_refcnt, 0);
94
95        let ptr = entry.val.get();
96        // Safety: ptr can only accessed by this thread
97        let res = unsafe { ptr.replace(Some(value)) };
98        debug_assert!(res.is_none());
99
100        // 1 for the ArenaArc, another is for the Bucket itself.
101        //
102        // Set counter after option is set to `Some(...)` to avoid
103        // race condition with `remove`.
104        if cfg!(debug_assertions) {
105            let prev_refcnt = entry.counter.swap(2, Ordering::Relaxed);
106            assert_eq!(prev_refcnt, 0);
107        } else {
108            entry.counter.store(2, Ordering::Relaxed);
109        }
110
111        let index = index as u32;
112
113        Ok(ArenaArc {
114            slot: bucket_index * (LEN as u32) + index,
115            index,
116            bucket: Arc::clone(this),
117        })
118    }
119
120    /// # Safety
121    ///
122    /// `index` <= `LEN`
123    unsafe fn access_impl(
124        this: Arc<Self>,
125        bucket_index: u32,
126        index: u32,
127        update_refcnt: fn(u8) -> u8,
128    ) -> Option<ArenaArc<T, BITARRAY_LEN, LEN>> {
129        if this.bitset.load(index) {
130            let counter = &this
131                .entries
132                .get_unchecked_on_release(index as usize)
133                .counter;
134            let mut refcnt = counter.load(Ordering::Relaxed);
135
136            loop {
137                if (refcnt & REMOVED_MASK) != 0 {
138                    return None;
139                }
140
141                if refcnt == 0 {
142                    // The variable is not yet fully initialized.
143                    // Reload the refcnt and check again.
144                    spin_loop();
145                    refcnt = counter.load(Ordering::Relaxed);
146                    continue;
147                }
148
149                match counter.compare_exchange_weak(
150                    refcnt,
151                    update_refcnt(refcnt),
152                    Ordering::Relaxed,
153                    Ordering::Relaxed,
154                ) {
155                    Ok(_) => break,
156                    Err(new_refcnt) => refcnt = new_refcnt,
157                }
158            }
159
160            Some(ArenaArc {
161                slot: bucket_index * (LEN as u32) + index,
162                index,
163                bucket: this,
164            })
165        } else {
166            None
167        }
168    }
169
170    /// # Safety
171    ///
172    /// `index` <= `LEN`
173    pub(crate) unsafe fn get(
174        this: Arc<Self>,
175        bucket_index: u32,
176        index: u32,
177    ) -> Option<ArenaArc<T, BITARRAY_LEN, LEN>> {
178        Self::access_impl(this, bucket_index, index, |refcnt| refcnt + 1)
179    }
180
181    /// # Safety
182    ///
183    /// `index` <= `LEN`
184    pub(crate) unsafe fn remove(
185        this: Arc<Self>,
186        bucket_index: u32,
187        index: u32,
188    ) -> Option<ArenaArc<T, BITARRAY_LEN, LEN>> {
189        Self::access_impl(this, bucket_index, index, |refcnt| refcnt | REMOVED_MASK)
190    }
191}
192
/// A reference-counted handle to one element stored in a `Bucket`.
///
/// Can have at most `MAX_REFCNT` refcount.
#[derive(Debug)]
pub struct ArenaArc<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> {
    /// Slot number reported by `ArenaArc::slot`; built as
    /// `bucket_index * LEN + index` when the handle is created.
    slot: u32,
    /// Position of the element inside `bucket.entries`.
    index: u32,
    /// Shared pointer to the bucket that stores the element.
    bucket: Arc<Bucket<T, BITARRAY_LEN, LEN>>,
}
200
// The handle consists of two `u32`s and an `Arc`; the value it points to
// lives inside the bucket, not inside the handle.
impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Unpin
    for ArenaArc<T, BITARRAY_LEN, LEN>
{
}
205
206impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> ArenaArc<T, BITARRAY_LEN, LEN> {
207    pub fn slot(this: &Self) -> u32 {
208        this.slot
209    }
210
211    fn get_index(this: &Self) -> usize {
212        this.index as usize
213    }
214
215    fn get_entry(this: &Self) -> &Entry<T> {
216        // Safety: `Self::get_index(this)` <= `LEN`
217        let entry = unsafe {
218            this.bucket
219                .entries
220                .get_unchecked_on_release(Self::get_index(this))
221        };
222        debug_assert!((entry.counter.load(Ordering::Relaxed) & REFCNT_MASK) > 0);
223        entry
224    }
225
226    pub fn strong_count(this: &Self) -> u8 {
227        let entry = Self::get_entry(this);
228        let cnt = entry.counter.load(Ordering::Relaxed) & REFCNT_MASK;
229        debug_assert!(cnt > 0);
230        cnt
231    }
232
233    pub fn is_removed(this: &Self) -> bool {
234        let counter = &Self::get_entry(this).counter;
235        let refcnt = counter.load(Ordering::Relaxed);
236
237        (refcnt & REMOVED_MASK) != 0
238    }
239
240    /// Remove this element.
241    ///
242    /// Return true if succeeds, false if it is already removed.
243    pub fn remove(this: &Self) -> bool {
244        let counter = &Self::get_entry(this).counter;
245        let mut refcnt = counter.load(Ordering::Relaxed);
246
247        loop {
248            debug_assert_ne!(refcnt & REFCNT_MASK, 0);
249
250            if (refcnt & REMOVED_MASK) != 0 {
251                // already removed
252                return false;
253            }
254
255            // Since the element is not removed, there is at least two ref to it:
256            //  - From the bucket itself
257            //  - From `self`
258            debug_assert_ne!(refcnt, 1);
259
260            match counter.compare_exchange_weak(
261                refcnt,
262                // Reduce refcnt by one since it is removed from bucket.
263                (refcnt - 1) | REMOVED_MASK,
264                Ordering::Relaxed,
265                Ordering::Relaxed,
266            ) {
267                Ok(_) => return true,
268                Err(new_refcnt) => refcnt = new_refcnt,
269            }
270        }
271    }
272}
273
274impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Deref
275    for ArenaArc<T, BITARRAY_LEN, LEN>
276{
277    type Target = T;
278
279    fn deref(&self) -> &Self::Target {
280        let ptr = Self::get_entry(self).val.get();
281
282        // Safety: `Self::get_index(this)` <= `LEN`
283        unsafe { (*ptr).as_ref().unwrap_unchecked_on_release() }
284    }
285}
286
287impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Clone
288    for ArenaArc<T, BITARRAY_LEN, LEN>
289{
290    fn clone(&self) -> Self {
291        let entry = Self::get_entry(self);
292
293        // According to [Boost documentation][1], increasing the refcount
294        // can be done using Relaxed operation since there are at least one
295        // reference alive.
296        //
297        // [1]: https://www.boost.org/doc/libs/1_77_0/doc/html/atomic/usage_examples.html
298        if (entry.counter.fetch_add(1, Ordering::Relaxed) & REFCNT_MASK) == MAX_REFCNT {
299            panic!("ArenaArc can have at most u8::MAX refcount");
300        }
301
302        Self {
303            slot: self.slot,
304            index: self.index,
305            bucket: Arc::clone(&self.bucket),
306        }
307    }
308}
309
310impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Drop
311    for ArenaArc<T, BITARRAY_LEN, LEN>
312{
313    fn drop(&mut self) {
314        let entry = Self::get_entry(self);
315
316        // According to [Boost documentation][1], decreasing refcount must be done
317        // using Release to ensure the write to the value happens before the
318        // reference is dropped.
319        //
320        // [1]: https://www.boost.org/doc/libs/1_77_0/doc/html/atomic/usage_examples.html
321        let prev_counter = entry.counter.fetch_sub(1, Ordering::Release);
322        let prev_refcnt = prev_counter & MAX_REFCNT;
323
324        debug_assert_ne!(prev_refcnt, 0);
325
326        if prev_refcnt == 1 {
327            debug_assert_eq!(prev_counter, REMOVED_MASK | 1);
328
329            // This is the last reference, drop the value.
330
331            // According to [Boost documentation][1], an Acquire fence must be used
332            // before dropping value to ensure that all write to the value happens
333            // before it is dropped.
334            fence(Ordering::Acquire);
335
336            // Now entry.counter == 0
337
338            // Safety: `entry.val` can only be accessed by this thread now.
339            let option = unsafe { &mut *entry.val.get() };
340            *option = None;
341
342            // Make sure drop is written to memory before
343            // the entry is reused again.
344            entry.counter.store(0, Ordering::Release);
345
346            // Safety:
347            //
348            // `Self::get_index(self)` <= `LEN` == `BITARRAY_LEN / usize::BITS`
349            unsafe { self.bucket.bitset.deallocate(Self::get_index(self)) };
350        }
351    }
352}
353
354#[cfg(test)]
355mod tests {
356    use super::Arc;
357    use super::ArenaArc;
358
359    use parking_lot::Mutex;
360    use parking_lot::MutexGuard;
361
362    use std::thread::sleep;
363    use std::thread::spawn;
364    use std::time::Duration;
365
366    use rayon::prelude::*;
367
    /// Number of entries in the bucket under test: one `usize` worth of bits.
    const LEN: u32 = usize::BITS;
    /// Bucket type used by the tests: `BITARRAY_LEN = 1` and `LEN` entries.
    type Bucket<T> = super::Bucket<T, 1, { LEN as usize }>;
370
371    #[test]
372    fn test_basic() {
373        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());
374
375        let arcs: Vec<_> = (0..LEN)
376            .into_par_iter()
377            .map(|i| {
378                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();
379
380                assert_eq!(ArenaArc::strong_count(&arc), 2);
381                assert_eq!(*arc, i);
382
383                arc
384            })
385            .collect();
386
387        assert!(Bucket::try_insert(&bucket, 0, 0).is_err());
388
389        for (i, each) in arcs.iter().enumerate() {
390            assert_eq!((**each) as usize, i);
391        }
392
393        let arcs_get: Vec<_> = (&arcs)
394            .into_par_iter()
395            .enumerate()
396            .map(|(i, orig_arc)| {
397                let arc = unsafe { Bucket::get(Arc::clone(&bucket), 0, orig_arc.index) }.unwrap();
398
399                assert_eq!(ArenaArc::strong_count(&arc), 3);
400                assert_eq!(*arc as usize, i);
401
402                arc
403            })
404            .collect();
405
406        for (i, each) in arcs_get.iter().enumerate() {
407            assert_eq!((**each) as usize, i);
408        }
409    }
410
411    #[test]
412    fn test_clone() {
413        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());
414
415        let arcs: Vec<_> = (0..LEN)
416            .into_par_iter()
417            .map(|i| {
418                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();
419
420                assert_eq!(ArenaArc::strong_count(&arc), 2);
421                assert_eq!(*arc, i);
422
423                arc
424            })
425            .collect();
426
427        let arcs_cloned: Vec<_> = arcs
428            .iter()
429            .map(|arc| {
430                let new_arc = arc.clone();
431                assert_eq!(ArenaArc::strong_count(&new_arc), 3);
432                assert_eq!(ArenaArc::strong_count(arc), 3);
433
434                new_arc
435            })
436            .collect();
437
438        drop(arcs);
439        drop(bucket);
440
441        // bucket are dropped, however as long as the arcs
442        // are alive, these values are still kept alive.
443        for (i, each) in arcs_cloned.iter().enumerate() {
444            assert_eq!((**each) as usize, i);
445        }
446    }
447
448    #[test]
449    fn test_reuse() {
450        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());
451
452        let mut arcs: Vec<_> = (0..LEN)
453            .into_par_iter()
454            .map(|i| {
455                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();
456
457                assert_eq!(ArenaArc::strong_count(&arc), 2);
458                assert_eq!(*arc, i);
459
460                arc
461            })
462            .collect();
463
464        for arc in arcs.drain(arcs.len() / 2..) {
465            assert_eq!(ArenaArc::strong_count(&arc), 2);
466            let new_arc = unsafe { Bucket::remove(bucket.clone(), 0, arc.index) }.unwrap();
467            assert_eq!(ArenaArc::strong_count(&arc), 2);
468
469            assert!(ArenaArc::is_removed(&new_arc));
470
471            drop(new_arc);
472            assert_eq!(ArenaArc::strong_count(&arc), 1);
473        }
474
475        let new_arcs: Vec<_> = (LEN..LEN + LEN / 2)
476            .into_par_iter()
477            .map(|i| {
478                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();
479
480                assert_eq!(ArenaArc::strong_count(&arc), 2);
481                assert_eq!(*arc, i);
482
483                arc
484            })
485            .collect();
486
487        let handle1 = spawn(move || {
488            arcs.into_par_iter().enumerate().for_each(|(i, each)| {
489                assert_eq!((*each) as usize, i);
490            });
491        });
492
493        let handle2 = spawn(move || {
494            new_arcs
495                .into_par_iter()
496                .zip(LEN..LEN + LEN / 2)
497                .for_each(|(each, i)| {
498                    assert_eq!(*each, i);
499                });
500        });
501
502        handle1.join().unwrap();
503        handle2.join().unwrap();
504    }
505
506    #[test]
507    fn test_reuse2() {
508        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());
509
510        let mut arcs: Vec<_> = (0..LEN)
511            .into_par_iter()
512            .map(|i| {
513                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();
514
515                assert_eq!(ArenaArc::strong_count(&arc), 2);
516                assert_eq!(*arc, i);
517
518                arc
519            })
520            .collect();
521
522        for arc in arcs.drain(arcs.len() / 2..) {
523            assert_eq!(ArenaArc::strong_count(&arc), 2);
524            ArenaArc::remove(&arc);
525            assert!(ArenaArc::is_removed(&arc));
526            assert_eq!(ArenaArc::strong_count(&arc), 1);
527        }
528
529        let new_arcs: Vec<_> = (LEN..LEN + LEN / 2)
530            .into_par_iter()
531            .map(|i| {
532                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();
533
534                assert_eq!(ArenaArc::strong_count(&arc), 2);
535                assert_eq!(*arc, i);
536
537                arc
538            })
539            .collect();
540
541        let handle1 = spawn(move || {
542            arcs.into_par_iter().enumerate().for_each(|(i, each)| {
543                assert_eq!((*each) as usize, i);
544            });
545        });
546
547        let handle2 = spawn(move || {
548            new_arcs
549                .into_par_iter()
550                .zip(LEN..LEN + LEN / 2)
551                .for_each(|(each, i)| {
552                    assert_eq!(*each, i);
553                });
554        });
555
556        handle1.join().unwrap();
557        handle2.join().unwrap();
558    }
559
560    #[test]
561    fn test_concurrent_remove() {
562        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());
563
564        let arcs: Vec<_> = (0..LEN)
565            .into_par_iter()
566            .map(|i| {
567                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();
568
569                assert_eq!(ArenaArc::strong_count(&arc), 2);
570                assert_eq!(*arc, i);
571
572                arc
573            })
574            .collect();
575
576        arcs.into_par_iter().for_each(|arc| {
577            assert_eq!(ArenaArc::strong_count(&arc), 2);
578            let new_arc = unsafe { Bucket::remove(bucket.clone(), 0, arc.index) }.unwrap();
579            assert!(ArenaArc::is_removed(&new_arc));
580            assert_eq!(ArenaArc::strong_count(&arc), 2);
581
582            drop(new_arc);
583            assert_eq!(ArenaArc::strong_count(&arc), 1);
584        });
585    }
586
587    #[test]
588    fn test_concurrent_remove2() {
589        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());
590
591        let arcs: Vec<_> = (0..LEN)
592            .into_par_iter()
593            .map(|i| {
594                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();
595
596                assert_eq!(ArenaArc::strong_count(&arc), 2);
597                assert_eq!(*arc, i);
598
599                arc
600            })
601            .collect();
602
603        arcs.into_par_iter().for_each(|arc| {
604            assert_eq!(ArenaArc::strong_count(&arc), 2);
605            ArenaArc::remove(&arc);
606            assert!(ArenaArc::is_removed(&arc));
607            assert_eq!(ArenaArc::strong_count(&arc), 1);
608        });
609    }
610
611    #[test]
612    fn realworld_test() {
613        let bucket: Arc<Bucket<Mutex<u32>>> = Arc::new(Bucket::new());
614
615        (0..LEN).into_par_iter().for_each(|i| {
616            let arc = Bucket::try_insert(&bucket, 0, Mutex::new(i)).unwrap();
617
618            assert_eq!(ArenaArc::strong_count(&arc), 2);
619            assert_eq!(*arc.lock(), i);
620
621            let arc_cloned = arc.clone();
622
623            let f = move |mut guard: MutexGuard<'_, u32>| {
624                if *guard == i {
625                    *guard = i + 1;
626                } else if *guard == i + 1 {
627                    *guard = i + 2;
628                } else {
629                    panic!("");
630                }
631            };
632
633            let handle = spawn(move || {
634                sleep(Duration::from_micros(1));
635
636                f(arc_cloned.lock());
637            });
638
639            spawn(move || {
640                sleep(Duration::from_micros(1));
641                f(arc.lock());
642
643                handle.join().unwrap();
644
645                assert_eq!(*arc.lock(), i + 2);
646            });
647        });
648    }
649}