use super::{bitmap::BitMap, Arc, OptionExt, SliceExt};

use core::{array, cell::UnsafeCell, hint::spin_loop, ops::Deref};
use std::sync::atomic::{fence, AtomicU8, Ordering};

const REMOVED_MASK: u8 = 1 << (u8::BITS - 1);
const REFCNT_MASK: u8 = !REMOVED_MASK;
pub const MAX_REFCNT: u8 = REFCNT_MASK;

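/// A single slot in the bucket. The counter packs the entry's state into
/// one `AtomicU8`: the top bit is the "removed" flag and the low 7 bits
/// are the strong count, so e.g. `0b1000_0011` means "removed, with 3
/// strong references still outstanding".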
#[derive(Debug)]
struct Entry<T> {
    counter: AtomicU8,
    val: UnsafeCell<Option<T>>,
}

impl<T> Entry<T> {
    const fn new() -> Self {
        Self {
            counter: AtomicU8::new(0),
            val: UnsafeCell::new(None),
        }
    }
}

impl<T> Drop for Entry<T> {
    fn drop(&mut self) {
        let cnt = self.counter.load(Ordering::Acquire);

        // An entry is only dropped together with its bucket, and every
        // live `ArenaArc` keeps the bucket alive, so at this point the
        // slot is either empty (cnt == 0) or still owned by the arena
        // (cnt == 1).
        debug_assert!(cnt <= 1);

        let val = self.val.get_mut().take();

        if cnt == 0 {
            debug_assert!(val.is_none());
        } else {
            debug_assert!(val.is_some());
        }
    }
}

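/// A fixed-size slab of refcounted slots. The `BitMap` records which
/// slots are allocated; a slot is handed back to it once its entry has
/// been removed and the last `ArenaArc` referring to it is dropped.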
#[derive(Debug)]
pub(crate) struct Bucket<T, const BITARRAY_LEN: usize, const LEN: usize> {
    bitset: BitMap<BITARRAY_LEN>,
    entries: [Entry<T>; LEN],
}

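// SAFETY: `UnsafeCell` suppresses the auto traits, but every access to
// `val` is mediated by the entry counter and the bitmap, so sharing a
// bucket across threads is sound whenever `T: Send + Sync`.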
unsafe impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Sync
    for Bucket<T, BITARRAY_LEN, LEN>
{
}

unsafe impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Send
    for Bucket<T, BITARRAY_LEN, LEN>
{
}

impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Default
    for Bucket<T, BITARRAY_LEN, LEN>
{
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Bucket<T, BITARRAY_LEN, LEN> {
    pub(crate) fn new() -> Self {
        Self {
            bitset: BitMap::new(),
            entries: array::from_fn(|_| Entry::new()),
        }
    }

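    /// Move `value` into a free slot, returning an `ArenaArc` whose
    /// initial strong count is 2: one reference for the caller and one
    /// retained by the arena until `remove` is called. If the bucket is
    /// full, `value` is handed back in `Err`.
    ///
    /// A minimal sketch of the intended call pattern (illustrative only,
    /// since `Bucket` is crate-private and the const parameters here are
    /// made up for the example):
    ///
    /// ```ignore
    /// let bucket: Arc<Bucket<u32, 1, 64>> = Arc::new(Bucket::new());
    /// let arc = Bucket::try_insert(&bucket, 0, 42).expect("bucket has a free slot");
    /// assert_eq!(*arc, 42);
    /// assert_eq!(ArenaArc::strong_count(&arc), 2);
    /// ```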
    pub(crate) fn try_insert(
        this: &Arc<Self>,
        bucket_index: u32,
        value: T,
    ) -> Result<ArenaArc<T, BITARRAY_LEN, LEN>, T> {
        // Reserve a free slot; hand the value back if the bucket is full.
        let index = match this.bitset.allocate() {
            Some(index) => index,
            None => return Err(value),
        };

        let entry = unsafe { this.entries.get_unchecked_on_release(index) };

        // The bitmap gives us exclusive ownership of this slot, so its
        // counter must still be zero.
        let prev_refcnt = entry.counter.load(Ordering::Acquire);
        debug_assert_eq!(prev_refcnt, 0);

        let ptr = entry.val.get();
        let res = unsafe { ptr.replace(Some(value)) };
        debug_assert!(res.is_none());

        // Publish the entry with a count of 2: one reference for the
        // returned `ArenaArc` plus the one retained by the arena.
        if cfg!(debug_assertions) {
            let prev_refcnt = entry.counter.swap(2, Ordering::Relaxed);
            assert_eq!(prev_refcnt, 0);
        } else {
            entry.counter.store(2, Ordering::Relaxed);
        }

        let index = index as u32;

        Ok(ArenaArc {
            slot: bucket_index * (LEN as u32) + index,
            index,
            bucket: Arc::clone(this),
        })
    }

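    /// Common implementation of `get` and `remove`: check the bitmap,
    /// then update the entry's counter in a CAS loop, either bumping the
    /// strong count (`get`) or setting the removed flag (`remove`).
    ///
    /// # Safety
    ///
    /// `index` must be within bounds, i.e. less than `LEN`.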
    unsafe fn access_impl(
        this: Arc<Self>,
        bucket_index: u32,
        index: u32,
        update_refcnt: fn(u8) -> u8,
    ) -> Option<ArenaArc<T, BITARRAY_LEN, LEN>> {
        if this.bitset.load(index) {
            let counter = &this
                .entries
                .get_unchecked_on_release(index as usize)
                .counter;
            let mut refcnt = counter.load(Ordering::Relaxed);

            loop {
                if (refcnt & REMOVED_MASK) != 0 {
                    return None;
                }

                // A zero count means the slot is in transition: either an
                // insert has reserved it but not yet published the value,
                // or the final `drop` is tearing it down. Spin until it
                // settles.
                if refcnt == 0 {
                    spin_loop();
                    refcnt = counter.load(Ordering::Relaxed);
                    continue;
                }

                match counter.compare_exchange_weak(
                    refcnt,
                    update_refcnt(refcnt),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                ) {
                    Ok(_) => break,
                    Err(new_refcnt) => refcnt = new_refcnt,
                }
            }

            Some(ArenaArc {
                slot: bucket_index * (LEN as u32) + index,
                index,
                bucket: this,
            })
        } else {
            None
        }
    }

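    /// Produce another `ArenaArc` for the entry at `index`, bumping its
    /// strong count. Returns `None` if the slot is empty or the entry
    /// has been removed.
    ///
    /// # Safety
    ///
    /// `index` must be within bounds, i.e. less than `LEN`.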
    pub(crate) unsafe fn get(
        this: Arc<Self>,
        bucket_index: u32,
        index: u32,
    ) -> Option<ArenaArc<T, BITARRAY_LEN, LEN>> {
        Self::access_impl(this, bucket_index, index, |refcnt| refcnt + 1)
    }

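    /// Mark the entry at `index` as removed so that no further `get` can
    /// reach it. The returned `ArenaArc` takes over the reference that
    /// the arena held, so the strong count is unchanged. Returns `None`
    /// if the slot is empty or the entry was already removed.
    ///
    /// # Safety
    ///
    /// `index` must be within bounds, i.e. less than `LEN`.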
    pub(crate) unsafe fn remove(
        this: Arc<Self>,
        bucket_index: u32,
        index: u32,
    ) -> Option<ArenaArc<T, BITARRAY_LEN, LEN>> {
        Self::access_impl(this, bucket_index, index, |refcnt| refcnt | REMOVED_MASK)
    }
}

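/// A reference-counted handle to a value stored in a `Bucket`. Once the
/// entry has been removed and the last handle is dropped, the value is
/// destroyed and the slot is returned to the bitmap for reuse.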
#[derive(Debug)]
pub struct ArenaArc<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> {
    slot: u32,
    index: u32,
    bucket: Arc<Bucket<T, BITARRAY_LEN, LEN>>,
}

impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Unpin
    for ArenaArc<T, BITARRAY_LEN, LEN>
{
}

impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> ArenaArc<T, BITARRAY_LEN, LEN> {
    /// Global slot of this entry: `bucket_index * LEN + index`.
    pub fn slot(this: &Self) -> u32 {
        this.slot
    }

    fn get_index(this: &Self) -> usize {
        this.index as usize
    }

    fn get_entry(this: &Self) -> &Entry<T> {
        // `index` was produced by the bitmap allocator, so it is in bounds.
        let entry = unsafe {
            this.bucket
                .entries
                .get_unchecked_on_release(Self::get_index(this))
        };
        debug_assert!((entry.counter.load(Ordering::Relaxed) & REFCNT_MASK) > 0);
        entry
    }

    pub fn strong_count(this: &Self) -> u8 {
        let entry = Self::get_entry(this);
        let cnt = entry.counter.load(Ordering::Relaxed) & REFCNT_MASK;
        debug_assert!(cnt > 0);
        cnt
    }

    pub fn is_removed(this: &Self) -> bool {
        let counter = &Self::get_entry(this).counter;
        let refcnt = counter.load(Ordering::Relaxed);

        (refcnt & REMOVED_MASK) != 0
    }

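    /// Mark this entry as removed and release the reference held by the
    /// arena, without consuming `this`. Returns `false` if the entry was
    /// already removed.
    ///
    /// A sketch of the expected behaviour (illustrative only, since the
    /// types are crate-private and `bucket` is assumed from the earlier
    /// example):
    ///
    /// ```ignore
    /// let arc = Bucket::try_insert(&bucket, 0, 42).unwrap();
    /// assert!(ArenaArc::remove(&arc));   // drops the arena's reference
    /// assert!(!ArenaArc::remove(&arc));  // already removed
    /// assert_eq!(ArenaArc::strong_count(&arc), 1);
    /// ```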
    pub fn remove(this: &Self) -> bool {
        let counter = &Self::get_entry(this).counter;
        let mut refcnt = counter.load(Ordering::Relaxed);

        loop {
            debug_assert_ne!(refcnt & REFCNT_MASK, 0);

            if (refcnt & REMOVED_MASK) != 0 {
                // Another handle already removed the entry.
                return false;
            }

            // `this` holds one reference and the arena holds another, so
            // a live, not-yet-removed entry has a count of at least 2.
            debug_assert_ne!(refcnt, 1);

            // Release the arena's reference and set the removed flag in
            // a single CAS.
            match counter.compare_exchange_weak(
                refcnt,
                (refcnt - 1) | REMOVED_MASK,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => return true,
                Err(new_refcnt) => refcnt = new_refcnt,
            }
        }
    }
}

impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Deref
    for ArenaArc<T, BITARRAY_LEN, LEN>
{
    type Target = T;

    fn deref(&self) -> &Self::Target {
        let ptr = Self::get_entry(self).val.get();

        // The strong count is non-zero while `self` exists, so the value
        // is present and cannot be torn down concurrently.
        unsafe { (*ptr).as_ref().unwrap_unchecked_on_release() }
    }
}

impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Clone
    for ArenaArc<T, BITARRAY_LEN, LEN>
{
    fn clone(&self) -> Self {
        let entry = Self::get_entry(self);

        // As with `Arc`, a relaxed increment suffices for cloning; only
        // check afterwards that the 7-bit count did not overflow.
        if (entry.counter.fetch_add(1, Ordering::Relaxed) & REFCNT_MASK) == MAX_REFCNT {
            panic!("ArenaArc can have at most MAX_REFCNT (127) strong references");
        }

        Self {
            slot: self.slot,
            index: self.index,
            bucket: Arc::clone(&self.bucket),
        }
    }
}

impl<T: Send + Sync, const BITARRAY_LEN: usize, const LEN: usize> Drop
    for ArenaArc<T, BITARRAY_LEN, LEN>
{
    fn drop(&mut self) {
        let entry = Self::get_entry(self);

        // Release ordering makes every earlier use of the value visible
        // to whichever handle observes the count reaching zero.
        let prev_counter = entry.counter.fetch_sub(1, Ordering::Release);
        let prev_refcnt = prev_counter & MAX_REFCNT;

        debug_assert_ne!(prev_refcnt, 0);

        if prev_refcnt == 1 {
            // The count can only reach zero after `remove` set the flag.
            debug_assert_eq!(prev_counter, REMOVED_MASK | 1);

            // Pair with the `Release` decrements of the other handles
            // before destroying the value.
            fence(Ordering::Acquire);

            let option = unsafe { &mut *entry.val.get() };
            *option = None;

            // Reset the counter so the slot can be inserted into again,
            // then hand the slot back to the bitmap allocator.
            entry.counter.store(0, Ordering::Release);

            unsafe { self.bucket.bitset.deallocate(Self::get_index(self)) };
        }
    }
}

#[cfg(test)]
mod tests {
    use super::Arc;
    use super::ArenaArc;

    use parking_lot::Mutex;
    use parking_lot::MutexGuard;

    use std::thread::sleep;
    use std::thread::spawn;
    use std::time::Duration;

    use rayon::prelude::*;

    // One bitmap word's worth of slots (BITARRAY_LEN = 1).
    const LEN: u32 = usize::BITS;
    type Bucket<T> = super::Bucket<T, 1, { LEN as usize }>;

    #[test]
    fn test_basic() {
        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());

        let arcs: Vec<_> = (0..LEN)
            .into_par_iter()
            .map(|i| {
                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();

                assert_eq!(ArenaArc::strong_count(&arc), 2);
                assert_eq!(*arc, i);

                arc
            })
            .collect();

        assert!(Bucket::try_insert(&bucket, 0, 0).is_err());

        for (i, each) in arcs.iter().enumerate() {
            assert_eq!((**each) as usize, i);
        }

        let arcs_get: Vec<_> = arcs
            .par_iter()
            .enumerate()
            .map(|(i, orig_arc)| {
                let arc = unsafe { Bucket::get(Arc::clone(&bucket), 0, orig_arc.index) }.unwrap();

                assert_eq!(ArenaArc::strong_count(&arc), 3);
                assert_eq!(*arc as usize, i);

                arc
            })
            .collect();

        for (i, each) in arcs_get.iter().enumerate() {
            assert_eq!((**each) as usize, i);
        }
    }

    #[test]
    fn test_clone() {
        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());

        let arcs: Vec<_> = (0..LEN)
            .into_par_iter()
            .map(|i| {
                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();

                assert_eq!(ArenaArc::strong_count(&arc), 2);
                assert_eq!(*arc, i);

                arc
            })
            .collect();

        let arcs_cloned: Vec<_> = arcs
            .iter()
            .map(|arc| {
                let new_arc = arc.clone();
                assert_eq!(ArenaArc::strong_count(&new_arc), 3);
                assert_eq!(ArenaArc::strong_count(arc), 3);

                new_arc
            })
            .collect();

        drop(arcs);
        drop(bucket);

        for (i, each) in arcs_cloned.iter().enumerate() {
            assert_eq!((**each) as usize, i);
        }
    }

    #[test]
    fn test_reuse() {
        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());

        let mut arcs: Vec<_> = (0..LEN)
            .into_par_iter()
            .map(|i| {
                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();

                assert_eq!(ArenaArc::strong_count(&arc), 2);
                assert_eq!(*arc, i);

                arc
            })
            .collect();

        for arc in arcs.drain(arcs.len() / 2..) {
            assert_eq!(ArenaArc::strong_count(&arc), 2);
            let new_arc = unsafe { Bucket::remove(bucket.clone(), 0, arc.index) }.unwrap();
            assert_eq!(ArenaArc::strong_count(&arc), 2);

            assert!(ArenaArc::is_removed(&new_arc));

            drop(new_arc);
            assert_eq!(ArenaArc::strong_count(&arc), 1);
        }

        let new_arcs: Vec<_> = (LEN..LEN + LEN / 2)
            .into_par_iter()
            .map(|i| {
                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();

                assert_eq!(ArenaArc::strong_count(&arc), 2);
                assert_eq!(*arc, i);

                arc
            })
            .collect();

        let handle1 = spawn(move || {
            arcs.into_par_iter().enumerate().for_each(|(i, each)| {
                assert_eq!((*each) as usize, i);
            });
        });

        let handle2 = spawn(move || {
            new_arcs
                .into_par_iter()
                .zip(LEN..LEN + LEN / 2)
                .for_each(|(each, i)| {
                    assert_eq!(*each, i);
                });
        });

        handle1.join().unwrap();
        handle2.join().unwrap();
    }

    #[test]
    fn test_reuse2() {
        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());

        let mut arcs: Vec<_> = (0..LEN)
            .into_par_iter()
            .map(|i| {
                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();

                assert_eq!(ArenaArc::strong_count(&arc), 2);
                assert_eq!(*arc, i);

                arc
            })
            .collect();

        for arc in arcs.drain(arcs.len() / 2..) {
            assert_eq!(ArenaArc::strong_count(&arc), 2);
            ArenaArc::remove(&arc);
            assert!(ArenaArc::is_removed(&arc));
            assert_eq!(ArenaArc::strong_count(&arc), 1);
        }

        let new_arcs: Vec<_> = (LEN..LEN + LEN / 2)
            .into_par_iter()
            .map(|i| {
                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();

                assert_eq!(ArenaArc::strong_count(&arc), 2);
                assert_eq!(*arc, i);

                arc
            })
            .collect();

        let handle1 = spawn(move || {
            arcs.into_par_iter().enumerate().for_each(|(i, each)| {
                assert_eq!((*each) as usize, i);
            });
        });

        let handle2 = spawn(move || {
            new_arcs
                .into_par_iter()
                .zip(LEN..LEN + LEN / 2)
                .for_each(|(each, i)| {
                    assert_eq!(*each, i);
                });
        });

        handle1.join().unwrap();
        handle2.join().unwrap();
    }

    #[test]
    fn test_concurrent_remove() {
        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());

        let arcs: Vec<_> = (0..LEN)
            .into_par_iter()
            .map(|i| {
                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();

                assert_eq!(ArenaArc::strong_count(&arc), 2);
                assert_eq!(*arc, i);

                arc
            })
            .collect();

        arcs.into_par_iter().for_each(|arc| {
            assert_eq!(ArenaArc::strong_count(&arc), 2);
            let new_arc = unsafe { Bucket::remove(bucket.clone(), 0, arc.index) }.unwrap();
            assert!(ArenaArc::is_removed(&new_arc));
            assert_eq!(ArenaArc::strong_count(&arc), 2);

            drop(new_arc);
            assert_eq!(ArenaArc::strong_count(&arc), 1);
        });
    }

    #[test]
    fn test_concurrent_remove2() {
        let bucket: Arc<Bucket<u32>> = Arc::new(Bucket::new());

        let arcs: Vec<_> = (0..LEN)
            .into_par_iter()
            .map(|i| {
                let arc = Bucket::try_insert(&bucket, 0, i).unwrap();

                assert_eq!(ArenaArc::strong_count(&arc), 2);
                assert_eq!(*arc, i);

                arc
            })
            .collect();

        arcs.into_par_iter().for_each(|arc| {
            assert_eq!(ArenaArc::strong_count(&arc), 2);
            ArenaArc::remove(&arc);
            assert!(ArenaArc::is_removed(&arc));
            assert_eq!(ArenaArc::strong_count(&arc), 1);
        });
    }

    #[test]
    fn realworld_test() {
        let bucket: Arc<Bucket<Mutex<u32>>> = Arc::new(Bucket::new());

        (0..LEN).into_par_iter().for_each(|i| {
            let arc = Bucket::try_insert(&bucket, 0, Mutex::new(i)).unwrap();

            assert_eq!(ArenaArc::strong_count(&arc), 2);
            assert_eq!(*arc.lock(), i);

            let arc_cloned = arc.clone();

            // Both threads run this closure once, in either order, so the
            // guarded value advances from `i` to `i + 2`.
            let f = move |mut guard: MutexGuard<'_, u32>| {
                if *guard == i {
                    *guard = i + 1;
                } else if *guard == i + 1 {
                    *guard = i + 2;
                } else {
                    panic!("unexpected guard value {}", *guard);
                }
            };

            let handle = spawn(move || {
                sleep(Duration::from_micros(1));

                f(arc_cloned.lock());
            });

            // Join both threads so their assertions run before the test
            // finishes.
            spawn(move || {
                sleep(Duration::from_micros(1));
                f(arc.lock());

                handle.join().unwrap();

                assert_eq!(*arc.lock(), i + 2);
            })
            .join()
            .unwrap();
        });
    }
}
649}