use crate::Refcounted;
use alloc::boxed::Box;
use core::cell::Cell;
use core::fmt;
use core::mem::{self, ManuallyDrop};
use core::ptr;
use core::sync::atomic::Ordering;
#[cfg(feature = "atomic")]
use core::sync::atomic::{self, AtomicUsize};
/// Heap layout shared by all refcounted pointers: the reference count is
/// stored immediately before the value it guards.
///
/// `#[repr(C)]` pins the field order so `Inner::cast` can recover the start
/// of the allocation from a pointer to `data` by subtracting a computed
/// offset (see `data_offset`).
#[repr(C)]
pub struct Inner<T: ?Sized + Refcounted> {
    pub refcnt: T::Rc,
    // `ManuallyDrop` because the value is dropped explicitly by
    // `dec_strong` (possibly after a finalizer), not by dropping `Inner`.
    pub data: ManuallyDrop<T>,
}
/// Computes the byte offset of the `data` field inside `Inner<T>` for the
/// (possibly unsized) value behind `ptr`.
///
/// With `#[repr(C)]`, `data` sits at `size_of::<T::Rc>()` rounded up to the
/// value's alignment; the expression below is the standard power-of-two
/// round-up, using wrapping ops because the result cannot actually overflow
/// for valid layouts.
///
/// # Safety
/// `ptr` must be dereferenceable so `align_of_val` can inspect the value.
unsafe fn data_offset<T: ?Sized + Refcounted>(ptr: *const T) -> usize {
    let align = mem::align_of_val(&*ptr);
    // round_up(size_of::<T::Rc>(), align); `align` is a power of two.
    mem::size_of::<T::Rc>().wrapping_add(align).wrapping_sub(1) & !align.wrapping_sub(1)
}
/// Replaces the address portion of a (possibly fat) pointer with `data`,
/// preserving any metadata (vtable pointer / slice length) stored with it.
///
/// # Safety
/// Relies on the address being the first word of a fat pointer's in-memory
/// representation — the same layout assumption `std`'s `Rc`/`Arc` make.
unsafe fn set_data_ptr<T: ?Sized, U>(mut ptr: *mut T, data: *mut U) -> *mut T {
    // Overwrite only the leading (address) word of `ptr`.
    ptr::write(&mut ptr as *mut *mut T as *mut *mut u8, data as *mut u8);
    ptr
}
impl<T: ?Sized + Refcounted> Inner<T> {
    /// Recovers the pointer to the containing `Inner<T>` from a pointer to
    /// its `data` field by stepping back over the refcount header.
    ///
    /// # Safety
    /// `val` must point at the `data` field of a live `Inner<T>` allocation.
    pub(crate) unsafe fn cast(val: *mut T) -> *mut Self {
        let offset = data_offset(val);
        // Move the address back to the allocation start while keeping the
        // fat-pointer metadata carried by `val`.
        let new_data = (val as *mut u8).sub(offset);
        let ptr = set_data_ptr(val, new_data) as *mut Self;
        // Round-trip sanity check: `(*ptr).data` must be exactly `val`.
        debug_assert_eq!(&*(*ptr).data as *const T, val as *const T);
        ptr
    }
}
/// Strategy trait describing how a strong reference count behaves.
///
/// All methods operate on a raw `Inner<T>` pointer that the caller must keep
/// valid for the duration of the call.
///
/// # Safety
/// Implementations must uphold the refcounting contract: `dec_strong` drops
/// the value (and frees the allocation) only when the last strong reference
/// is released, and counts must not be allowed to wrap.
pub unsafe trait Refcount: Sized {
    /// Extra per-type data carried by the refcounted value (e.g. a
    /// finalizer fn pointer) — see the `metadata_type!` macro below.
    type Metadata;
    /// Creates a refcount representing a single strong reference.
    unsafe fn new() -> Self;
    /// Increments the strong count.
    unsafe fn inc_strong<T: ?Sized>(ptr: *const Inner<T>)
    where
        T: Refcounted<Rc = Self>;
    /// Decrements the strong count, dropping the value and freeing the
    /// allocation when it reaches zero.
    unsafe fn dec_strong<T: ?Sized>(ptr: *const Inner<T>)
    where
        T: Refcounted<Rc = Self>;
    /// Returns the current strong count (a snapshot for atomic variants).
    unsafe fn strong_count<T: ?Sized>(ptr: *const Inner<T>) -> usize
    where
        T: Refcounted<Rc = Self>;
}
/// Extension of [`Refcount`] for strategies that also track weak references.
///
/// # Safety
/// Same contract as `Refcount`; additionally `dec_weak` must free the
/// allocation only when the last weak reference is released.
pub unsafe trait WeakRefcount: Refcount {
    /// Increments the weak count.
    unsafe fn inc_weak<T: ?Sized>(ptr: *const Inner<T>)
    where
        T: Refcounted<Rc = Self>;
    /// Decrements the weak count, freeing the (already value-dropped)
    /// allocation when it reaches zero.
    unsafe fn dec_weak<T: ?Sized>(ptr: *const Inner<T>)
    where
        T: Refcounted<Rc = Self>;
    /// Attempts to turn a weak reference into a strong one; returns `true`
    /// on success, `false` once the value has already been dropped.
    unsafe fn upgrade<T: ?Sized>(ptr: *const Inner<T>) -> bool
    where
        T: Refcounted<Rc = Self>;
    /// Returns the current weak count (excluding the implicit weak ref
    /// held collectively by the strong refs).
    unsafe fn weak_count<T: ?Sized>(ptr: *const Inner<T>) -> usize
    where
        T: Refcounted<Rc = Self>;
}
/// Reference counts are capped at `isize::MAX` (as in `std`'s `Rc`/`Arc`)
/// so an overflowing count is detected and aborts before it can wrap.
const MAX_REFCOUNT: usize = isize::max_value() as usize;
/// Aborts the process on refcount overflow.
///
/// Panicking while a panic is already unwinding forces an abort, which
/// gives us a guaranteed process abort without depending on
/// `std::process::abort` (this crate only uses `core`/`alloc`).
#[inline(never)]
#[cold]
fn abort() -> ! {
    struct PanicOnDrop;
    impl Drop for PanicOnDrop {
        fn drop(&mut self) {
            panic!("double-panicing to force abort");
        }
    }
    // The guard's destructor runs while the panic below unwinds,
    // triggering the second panic and therefore the abort.
    let _guard = PanicOnDrop;
    panic!("aborting due to excessively large reference count");
}
/// Generates the `Refcount::dec_strong` method body for `decl_refcnt!`.
///
/// The no-argument form drops the value's fields and frees the box; the
/// `finalize` form additionally runs the type's finalizer (stored as
/// `Refcount::Metadata`) before the last strong reference goes away.
macro_rules! dec_strong_impl {
    () => {
        unsafe fn dec_strong<T: ?Sized>(ptr: *const Inner<T>)
        where
            T: Refcounted<Rc = Self>,
        {
            // Drops the value in place; the backing `Box` is freed below
            // only once `dec_strong` reports the allocation is dead.
            let drop_fields = || {
                ManuallyDrop::drop(&mut (*(ptr as *mut Inner<T>)).data);
            };
            if (*ptr).refcnt.inner.dec_strong(drop_fields) {
                drop(Box::from_raw(ptr as *mut Inner<T>));
            }
        }
    };
    (finalize) => {
        unsafe fn dec_strong<T: ?Sized>(ptr: *const Inner<T>)
        where
            T: Refcounted<Rc = Self>,
        {
            let drop_fields = || {
                ManuallyDrop::drop(&mut (*(ptr as *mut Inner<T>)).data);
            };
            // The finalizer fn pointer is the refcount metadata; it is
            // handed a type-erased pointer to the data.
            let finalize = || {
                let finalize_fn = (*ptr).data.refcount_metadata();
                finalize_fn((&(*ptr).data) as *const _ as *const u8);
            };
            if (*ptr)
                .refcnt
                .inner
                .dec_strong_finalize(drop_fields, finalize)
            {
                drop(Box::from_raw(ptr as *mut Inner<T>));
            }
        }
    };
}
/// Maps the optional `finalize` marker to a `Refcount::Metadata` type:
/// plain refcounts carry no metadata, finalizing ones carry the finalizer
/// function pointer invoked by `dec_strong_impl!(finalize)`.
macro_rules! metadata_type {
    () => { () };
    (finalize) => { unsafe fn(*const u8) };
}
/// Declares a public refcount type backed by `RefcntImpl`.
///
/// `$strong` (and optionally `$weak`) select the counter storage; the
/// optional trailing `finalize` marker switches in the finalizing
/// `dec_strong` body and the fn-pointer `Metadata` type. A `WeakRefcount`
/// impl is emitted only when `$weak` is given.
macro_rules! decl_refcnt {
    ($name:ident, [$strong:ty $(, $weak:ty)?] $(, $finalize:ident)?) => {
        pub struct $name {
            inner: RefcntImpl<$strong $(, $weak)?>,
        }
        unsafe impl Refcount for $name {
            type Metadata = metadata_type!($($finalize)?);
            unsafe fn new() -> Self {
                $name {
                    inner: RefcntImpl::new(),
                }
            }
            unsafe fn inc_strong<T: ?Sized>(ptr: *const Inner<T>)
            where
                T: Refcounted<Rc = Self>,
            {
                (*ptr).refcnt.inner.inc_strong()
            }
            dec_strong_impl!($($finalize)?);
            unsafe fn strong_count<T: ?Sized>(ptr: *const Inner<T>) -> usize
            where
                T: Refcounted<Rc = Self>,
            {
                (*ptr).refcnt.inner.strong_count()
            }
        }
        $(
            // Only expanded when `$weak` was supplied.
            unsafe impl WeakRefcount for $name {
                unsafe fn inc_weak<T: ?Sized>(ptr: *const Inner<T>)
                where
                    T: Refcounted<Rc = Self>,
                {
                    // Dummy use of `$weak`: a `$(...)?` group must mention
                    // its repetition variable somewhere or it never expands.
                    type _Ignore = $weak;
                    (*ptr).refcnt.inner.inc_weak()
                }
                unsafe fn dec_weak<T: ?Sized>(ptr: *const Inner<T>)
                where
                    T: Refcounted<Rc = Self>,
                {
                    // The value itself was already dropped by `dec_strong`;
                    // here we only free the allocation.
                    if (*ptr).refcnt.inner.dec_weak() {
                        drop(Box::from_raw(ptr as *mut Inner<T>));
                    }
                }
                unsafe fn upgrade<T: ?Sized>(ptr: *const Inner<T>) -> bool
                where
                    T: Refcounted<Rc = Self>,
                {
                    (*ptr).refcnt.inner.upgrade()
                }
                unsafe fn weak_count<T: ?Sized>(ptr: *const Inner<T>) -> usize
                where
                    T: Refcounted<Rc = Self>,
                {
                    (*ptr).refcnt.inner.weak_count()
                }
            }
        )?
        impl fmt::Debug for $name {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.debug_struct(stringify!($name))
                    .field("strong", &self.inner.strong_count())
                    $(
                        // The "weak" field is shown only for weak-capable
                        // variants (same `$weak`-mention trick as above).
                        .field("weak", {
                            type _Ignore = $weak;
                            &self.inner.weak_count()
                        })
                    )?
                    .finish()
            }
        }
    }
}
// Single-threaded refcount (`Rc`-like); no weak refs, no finalizer.
decl_refcnt!(Local, [Cell<usize>]);
// Single-threaded refcount with weak-reference support.
decl_refcnt!(LocalWeak, [Cell<usize>, Cell<usize>]);
// Thread-safe refcount (`Arc`-like), gated on the `atomic` feature.
#[cfg(feature = "atomic")]
decl_refcnt!(Atomic, [AtomicUsize]);
#[cfg(feature = "atomic")]
decl_refcnt!(AtomicWeak, [AtomicUsize, AtomicUsize]);
// Variants that run a finalizer before the last strong ref is released.
decl_refcnt!(LocalFinalize, [Cell<usize>], finalize);
decl_refcnt!(LocalWeakFinalize, [Cell<usize>, Cell<usize>], finalize);
#[cfg(feature = "atomic")]
decl_refcnt!(AtomicFinalize, [AtomicUsize], finalize);
#[cfg(feature = "atomic")]
decl_refcnt!(AtomicWeakFinalize, [AtomicUsize, AtomicUsize], finalize);
/// Minimal counter interface shared by atomic and non-atomic refcounts.
/// `inc`/`dec` return the *previous* value, fetch-and-modify style.
trait IncDecCount {
    /// Creates a counter initialized to 1.
    fn new() -> Self;
    /// Reads the current value.
    fn load(&self) -> usize;
    /// Decrements by one, returning the previous value.
    fn dec(&self) -> usize;
    /// Increments by one, returning the previous value; aborts on overflow.
    fn inc(&self) -> usize;
    /// Memory fence; a no-op for non-atomic counters.
    fn fence(order: Ordering);
}
#[cfg(feature = "atomic")]
impl IncDecCount for AtomicUsize {
    fn new() -> Self {
        AtomicUsize::new(1)
    }
    fn load(&self) -> usize {
        self.load(Ordering::SeqCst)
    }
    fn dec(&self) -> usize {
        // Release so prior uses of the data happen-before the Acquire
        // fence the caller issues when the count hits zero (see
        // `RefcntImpl::dec_strong`) — the same scheme as `std::sync::Arc`.
        self.fetch_sub(1, Ordering::Release)
    }
    fn inc(&self) -> usize {
        // Relaxed suffices for an increment (as in `Arc::clone`); abort
        // before the count can reach `isize::MAX` and wrap.
        let prev = self.fetch_add(1, Ordering::Relaxed);
        if prev > MAX_REFCOUNT - 1 {
            abort();
        }
        prev
    }
    fn fence(order: Ordering) {
        atomic::fence(order)
    }
}
/// Non-atomic counter for single-threaded refcounts.
impl IncDecCount for Cell<usize> {
    fn new() -> Self {
        // A fresh count starts at 1: the reference being created.
        Cell::new(1)
    }
    fn load(&self) -> usize {
        self.get()
    }
    fn dec(&self) -> usize {
        let old = self.get();
        self.set(old - 1);
        old
    }
    fn inc(&self) -> usize {
        let old = self.get();
        // Abort rather than wrap if the count would overflow `usize`.
        match old.checked_add(1) {
            Some(new) => self.set(new),
            None => abort(),
        }
        old
    }
    // Single-threaded: plain memory accesses are already ordered.
    fn fence(_: Ordering) {}
}
/// Placeholder "weak count" used by refcounts without weak support.
///
/// `RefcntImpl::dec_strong` always ends by calling `dec_weak` (the strong
/// refs collectively own one implicit weak ref); returning 1 from `dec`
/// makes that final `dec_weak` report "last weak ref gone", so the
/// allocation is freed immediately. `load`/`inc`/`fence` are only reachable
/// through weak-ref paths, which are never generated for these variants.
impl IncDecCount for () {
    fn new() -> Self {
        ()
    }
    fn load(&self) -> usize {
        unreachable!()
    }
    fn dec(&self) -> usize {
        // Pretend the previous weak count was 1 → free the box right away.
        1
    }
    fn inc(&self) -> usize {
        unreachable!()
    }
    fn fence(_: Ordering) {
        unreachable!()
    }
}
/// Strong + weak counter pair backing every generated refcount type.
/// `W = ()` disables real weak counting (see `IncDecCount for ()`).
struct RefcntImpl<T, W = ()> {
    strong: T,
    weak: W,
}
impl<T: IncDecCount, W: IncDecCount> RefcntImpl<T, W> {
    /// Both counters start at 1: the initial strong reference, plus the
    /// implicit weak reference held collectively by all strong refs.
    unsafe fn new() -> Self {
        RefcntImpl {
            strong: T::new(),
            weak: W::new(),
        }
    }
    fn strong_count(&self) -> usize {
        self.strong.load()
    }
    /// Number of weak refs, excluding the implicit one owned by the strong
    /// refs. Only called when `W` is a real counter (never `()`).
    fn weak_count(&self) -> usize {
        let weak = self.weak.load();
        let strong = self.strong.load();
        if strong == 0 {
            // Value already dropped; report 0 like `std` does.
            0
        } else {
            // Subtract the implicit weak ref held by the strong refs.
            weak - 1
        }
    }
    unsafe fn inc_strong(&self) {
        self.strong.inc();
    }
    /// Drops one strong ref; returns `true` when the caller must free the
    /// allocation (value dropped AND last weak ref released).
    unsafe fn dec_strong(&self, drop_fields: impl FnOnce()) -> bool {
        let old_count = self.strong.dec();
        if old_count != 1 {
            return false;
        }
        // Acquire fence pairs with the Release decrement so all prior uses
        // of the data happen-before the drop (no-op for `Cell` counters).
        T::fence(Ordering::Acquire);
        drop_fields();
        // Release the implicit weak ref; frees the box if it was the last.
        self.dec_weak()
    }
    unsafe fn inc_weak(&self) {
        self.weak.inc();
    }
    /// Returns `true` when the last weak ref was just released.
    unsafe fn dec_weak(&self) -> bool {
        self.weak.dec() == 1
    }
}
#[cfg(feature = "atomic")]
impl RefcntImpl<AtomicUsize, ()> {
    /// Finalizing decrement for atomic refcounts without weak support.
    ///
    /// When the last strong ref is dropped, the count is restored to 1
    /// while the finalizer runs, so the finalizer can safely take and
    /// release references; the real teardown happens in the trailing
    /// `dec_strong`, which frees only if the finalizer did not keep the
    /// object alive.
    unsafe fn dec_strong_finalize(
        &self,
        drop_fields: impl FnOnce(),
        finalize: impl FnOnce(),
    ) -> bool {
        let old_count = self.strong.fetch_sub(1, Ordering::Release);
        if old_count != 1 {
            return false;
        }
        // Resurrect to 1 for the duration of the finalizer. Relaxed store:
        // with no weak refs and a count of 0, no other thread holds a
        // reference through which it could observe this counter.
        self.strong.store(1, Ordering::Relaxed);
        atomic::fence(Ordering::Acquire);
        finalize();
        self.dec_strong(drop_fields)
    }
}
#[cfg(feature = "atomic")]
impl RefcntImpl<AtomicUsize, AtomicUsize> {
    /// Attempts to obtain a strong ref from a weak one via a CAS loop
    /// (the same shape as `std::sync::Weak::upgrade`); fails once the
    /// strong count has reached 0.
    unsafe fn upgrade(&self) -> bool {
        let mut prev = self.strong.load(Ordering::Relaxed);
        while prev != 0 {
            if prev > MAX_REFCOUNT - 1 {
                abort();
            }
            match self.strong.compare_exchange_weak(
                prev,
                prev + 1,
                Ordering::Relaxed,
                Ordering::Relaxed,
            ) {
                Ok(_) => return true,
                Err(actual) => prev = actual,
            }
        }
        false
    }
    /// Finalizing decrement when weak refs exist: the count is held at 1
    /// (never allowed to reach 0) while the finalizer runs, then
    /// decremented for real by the trailing `dec_strong`.
    unsafe fn dec_strong_finalize(
        &self,
        drop_fields: impl FnOnce(),
        finalize: impl FnOnce(),
    ) -> bool {
        // CAS downward while we are not the last strong ref; fall through
        // once a count of exactly 1 is observed.
        let mut old_count = self.strong.load(Ordering::Relaxed);
        while old_count != 1 {
            match self.strong.compare_exchange_weak(
                old_count,
                old_count - 1,
                Ordering::Release,
                Ordering::Relaxed,
            ) {
                Ok(_) => return false,
                Err(actual) => old_count = actual,
            }
        }
        debug_assert!(old_count == 1);
        // NOTE(review): this path assumes no concurrent `upgrade` can raise
        // the count between the observation above and the finalizer call —
        // confirm that is an intended invariant for holders of the last
        // strong reference.
        atomic::fence(Ordering::Acquire);
        finalize();
        self.dec_strong(drop_fields)
    }
}
impl<W> RefcntImpl<Cell<usize>, W>
where
    W: IncDecCount,
{
    /// Attempts to turn a weak reference into a strong one; returns
    /// `false` once the value has already been dropped. Aborts if the
    /// count would exceed the `isize::MAX` cap.
    unsafe fn upgrade(&self) -> bool {
        match self.strong.get() {
            0 => false,
            old if old > MAX_REFCOUNT - 1 => abort(),
            old => {
                self.strong.set(old + 1);
                true
            }
        }
    }
    /// Drops one strong ref, invoking `finalize` just before the last one
    /// goes away; returns `true` when the allocation must be freed.
    unsafe fn dec_strong_finalize(
        &self,
        drop_fields: impl FnOnce(),
        finalize: impl FnOnce(),
    ) -> bool {
        let old = self.strong.get();
        if old == 1 {
            // Last strong ref: run the finalizer while the count still
            // reads 1, then perform the real decrement via `dec_strong`.
            finalize();
            self.dec_strong(drop_fields)
        } else {
            self.strong.set(old - 1);
            false
        }
    }
}