use core::{borrow, cmp, fmt, hash, mem, ops, ptr};
use core::alloc::Layout;
use core::cell::Cell;
use crate::uninit::{Uninit, UninitView};
/// A reference-counted pointer whose value and counters live in caller-provided memory that is
/// borrowed for `'a`.
///
/// This type never allocates or frees memory itself: the block is passed to [`Rc::new`] and can
/// be taken back with [`Rc::into_raw`] or, once every handle is gone, [`Weak::try_unwrap`].
pub struct Rc<'a, T> {
/// View of the shared box holding the value together with both counters.
inner: UninitView<'a, RcBox<T>>,
}
/// A non-owning handle to the box shared with [`Rc`].
///
/// It does not keep the value alive: [`Weak::upgrade`] yields a new `Rc` only while the strong
/// count is non-zero, and [`Weak::try_unwrap`] hands back the memory once this is the last
/// handle of any kind.
pub struct Weak<'a, T> {
/// The same box the originating `Rc` points to.
inner: UninitView<'a, RcBox<T>>,
}
/// The shared block behind `Rc` and `Weak`: the value followed by its two counters.
///
/// `repr(C)` with `val` as the first field puts `val` at offset zero, which is what lets
/// `Rc::as_ptr`/`Rc::as_mut_ptr` cast a pointer to the box directly into a pointer to the value.
#[repr(C)]
struct RcBox<T> {
/// The shared value.
val: T,
/// Number of live `Rc` handles.
strong: Cell<usize>,
/// Number of `Weak` handles, plus one implicit reference held collectively by all `Rc`s for as
/// long as the strong count is non-zero.
weak: Cell<usize>,
}
impl<'a, T> Rc<'a, T> {
    /// Move `val` into `memory` and return the first strong handle to it.
    ///
    /// Both counters start at one: the strong count for the returned handle and the weak count
    /// for the implicit weak reference held by the strong handles as a group.
    ///
    /// # Panics
    ///
    /// Panics if `memory` does not fit [`Rc::layout`].
    pub fn new(val: T, memory: Uninit<'a, ()>) -> Self {
        assert!(memory.fits(Self::layout()), "Provided memory must fit the inner layout");
        let mut rc_box = memory.cast::<RcBox<T>>().unwrap();
        let initial = RcBox {
            val,
            strong: Cell::new(1),
            weak: Cell::new(1),
        };
        rc_box.borrow_mut().init(initial);
        Rc { inner: rc_box.into() }
    }

    /// Re-wrap memory previously produced by [`Rc::into_raw`] into a strong handle.
    ///
    /// # Safety
    ///
    /// `init` must point to a fully initialized `RcBox<T>` whose value is still valid, i.e. memory
    /// returned by an earlier `Rc::into_raw`. The counters stored in that memory are reused as
    /// found; this function does not reset them.
    pub unsafe fn from_raw(init: Uninit<'a, T>) -> Self {
        debug_assert!(init.fits(Self::layout()), "Provided memory must fit the inner layout");
        let rc_box = init.cast::<RcBox<T>>().unwrap();
        Rc { inner: rc_box.into() }
    }

    /// Give back the underlying memory with the value still inside it.
    ///
    /// Only succeeds when `rc` is the sole strong handle and no `Weak` exists; otherwise `rc` is
    /// returned unchanged. The value is not dropped, and the counters are left at one so that
    /// [`Rc::from_raw`] can restore the handle.
    pub fn into_raw(rc: Self) -> Result<Uninit<'a, T>, Self> {
        if !Rc::is_unique(&rc) {
            return Err(rc);
        }
        // Suppress the destructor: ownership of the value and memory moves to the caller.
        let rc = mem::ManuallyDrop::new(rc);
        let (start, size) = (rc.inner.as_non_null(), rc.inner.size());
        // SAFETY: `rc` was the only handle of any kind, so nothing else can observe this memory,
        // which is exactly the region `inner` describes.
        let memory = unsafe { Uninit::from_memory(start.cast(), size) };
        Ok(memory.cast().unwrap())
    }

    /// Move the value out if `rc` is the last strong handle.
    ///
    /// On success the strong count drops to zero and the implicit weak reference is handed to the
    /// returned `Weak`, which can later reclaim the memory via [`Weak::try_unwrap`].
    pub fn try_unwrap(rc: Self) -> Result<(T, Weak<'a, T>), Self> {
        if Rc::strong_count(&rc) != 1 {
            return Err(rc);
        }
        let rc = mem::ManuallyDrop::new(rc);
        rc.dec_strong();
        // SAFETY: the strong count is now zero, so no other `Rc` can read the value again and it
        // is moved out exactly once.
        let val = unsafe { ptr::read(rc.as_ptr()) };
        Ok((val, Weak { inner: rc.inner }))
    }

    /// Create a new `Weak` handle to the same box.
    pub fn downgrade(rc: &Self) -> Weak<'a, T> {
        rc.inc_weak();
        let inner = rc.inner;
        Weak { inner }
    }
}
impl<T> Rc<'_, T> {
    /// The layout of the memory block that [`Rc::new`] requires.
    ///
    /// This is the layout of the internal box holding the value together with both counters.
    pub fn layout() -> Layout {
        Layout::new::<RcBox<T>>()
    }

    /// The raw weak counter of the shared box.
    ///
    /// While any strong handle exists, this includes the one implicit weak reference held by the
    /// strong handles, so it is at least `1` for a live `Rc`.
    pub fn weak_count(rc: &Self) -> usize {
        rc.inner().weak.get()
    }

    /// The number of `Rc` handles currently sharing the value.
    pub fn strong_count(rc: &Self) -> usize {
        rc.inner().strong.get()
    }

    /// Mutable access to the value, if this is the only `Rc` and no `Weak` exists.
    pub fn get_mut(rc: &mut Self) -> Option<&mut T> {
        if rc.is_unique() {
            // SAFETY: strong == 1 and weak == 1 (only the implicit reference), so no other handle
            // can reach the value while `rc` is mutably borrowed.
            Some(unsafe { &mut *rc.as_mut_ptr() })
        } else {
            None
        }
    }

    /// Whether both handles point at the same box (identity, not value equality).
    pub fn ptr_eq(this: &Self, other: &Self) -> bool {
        this.inner.as_ptr() == other.inner.as_ptr()
    }

    /// Shared access to the whole box.
    fn inner(&self) -> &RcBox<T> {
        // SAFETY: `inner` points to an `RcBox<T>` initialized by `new` (or restored through
        // `from_raw`) and borrowed for the handle's lifetime.
        unsafe { self.inner.as_ref() }
    }

    /// True when this is the only strong handle and no `Weak` exists.
    fn is_unique(&self) -> bool {
        Rc::strong_count(self) == 1 && Rc::weak_count(self) == 1
    }

    /// Pointer to the value; valid because `RcBox` is `repr(C)` with `val` first.
    fn as_mut_ptr(&mut self) -> *mut T {
        self.inner.as_ptr() as *mut T
    }

    /// Pointer to the value; valid because `RcBox` is `repr(C)` with `val` first.
    fn as_ptr(&self) -> *const T {
        self.inner.as_ptr() as *const T
    }

    /// Increment the strong counter.
    ///
    /// # Panics
    ///
    /// Panics instead of wrapping on overflow: a wrapped count could reach zero again and drop
    /// the value while handles to it are still alive.
    fn inc_strong(&self) {
        let val = Self::strong_count(self)
            .checked_add(1)
            .expect("Rc strong count overflow");
        self.inner().strong.set(val);
    }

    fn dec_strong(&self) {
        let val = Self::strong_count(self) - 1;
        self.inner().strong.set(val);
    }

    /// Increment the weak counter.
    ///
    /// # Panics
    ///
    /// Panics instead of wrapping on overflow: a wrapped count would let `Weak::try_unwrap`
    /// reclaim memory that other handles still point at.
    fn inc_weak(&self) {
        let val = Self::weak_count(self)
            .checked_add(1)
            .expect("Rc weak count overflow");
        self.inner().weak.set(val);
    }

    fn dec_weak(&self) {
        let val = Self::weak_count(self) - 1;
        self.inner().weak.set(val);
    }
}
impl<'a, T> Weak<'a, T> {
pub fn try_unwrap(self) -> Result<Uninit<'a, ()>, Self> {
if !self.is_unique_to_rc_memory() {
return Err(self);
}
let ptr = self.inner.as_non_null();
let len = self.inner.size();
unsafe {
Ok(Uninit::from_memory(ptr.cast(), len))
}
}
pub fn upgrade(&self) -> Option<Rc<'a, T>> {
if self.strong_count() == 0 {
None
} else {
let rc = Rc { inner: self.inner };
rc.inc_strong();
Some(rc)
}
}
}
impl<T> Weak<'_, T> {
    /// The number of `Rc` handles still sharing the value; `0` once it has been dropped.
    pub fn strong_count(&self) -> usize {
        self.strong().get()
    }

    /// The raw weak counter, including the implicit reference held by the `Rc`s while any exist.
    pub fn weak_count(&self) -> usize {
        self.weak().get()
    }

    /// True when the value is gone and this is the last handle of any kind.
    fn is_unique_to_rc_memory(&self) -> bool {
        self.strong_count() == 0 && self.weak_count() == 1
    }

    fn weak(&self) -> &Cell<usize> {
        // SAFETY: only the counter field is borrowed, never the possibly-dropped value.
        unsafe { &(*self.inner.as_ptr()).weak }
    }

    fn strong(&self) -> &Cell<usize> {
        // SAFETY: only the counter field is borrowed, never the possibly-dropped value.
        unsafe { &(*self.inner.as_ptr()).strong }
    }

    /// Increment the weak counter.
    ///
    /// # Panics
    ///
    /// Panics instead of wrapping on overflow: a wrapped count would let `try_unwrap` reclaim
    /// memory that other `Weak` handles still point at.
    fn inc_weak(&self) {
        let val = Weak::weak_count(self)
            .checked_add(1)
            .expect("Rc weak count overflow");
        self.weak().set(val);
    }

    fn dec_weak(&self) {
        let val = Weak::weak_count(self);
        self.weak().set(val - 1);
    }
}
impl<T> Drop for Rc<'_, T> {
    /// Release one strong handle. The last one drops the value and then releases the implicit
    /// weak reference. The memory itself is never freed here.
    fn drop(&mut self) {
        self.dec_strong();
        if Rc::strong_count(self) != 0 {
            return;
        }
        // SAFETY: the strong count just reached zero, so this was the last `Rc`, and
        // `Weak::upgrade` refuses to create new ones from here on.
        unsafe { ptr::drop_in_place(self.as_mut_ptr()) }
        // Release the implicit weak reference only *after* the value is gone. Doing it first
        // would let a `Weak` reachable from the value's own destructor see strong == 0 and
        // weak == 1, and reclaim the memory through `Weak::try_unwrap` mid-drop. If the value's
        // destructor unwinds, this line is skipped and the memory simply stays unreclaimable.
        self.dec_weak();
    }
}
impl<T> ops::Deref for Rc<'_, T> {
type Target = T;
fn deref(&self) -> &T {
&self.inner().val
}
}
impl<T> Clone for Rc<'_, T> {
fn clone(&self) -> Self {
self.inc_strong();
Rc {
inner: self.inner,
}
}
}
impl<T> Drop for Weak<'_, T> {
    /// Give up this handle's share of the weak counter; the memory itself is not freed.
    fn drop(&mut self) {
        Self::dec_weak(self);
    }
}
impl<T> Clone for Weak<'_, T> {
fn clone(&self) -> Self {
self.inc_weak();
Weak {
inner: self.inner,
}
}
}
/// Compares the pointed-to values, not the pointers; see [`Rc::ptr_eq`] for identity.
impl<'a, 'b, T: PartialEq> PartialEq<Rc<'b, T>> for Rc<'a, T> {
    #[inline]
    fn eq(&self, other: &Rc<'b, T>) -> bool {
        **self == **other
    }

    #[inline]
    fn ne(&self, other: &Rc<'b, T>) -> bool {
        **self != **other
    }
}
/// Value equality on the shared contents is an equivalence relation whenever `T`'s is.
impl<T> Eq for Rc<'_, T> where T: Eq {}
/// Orders handles by the values they point to.
impl<'a, 'b, T: PartialOrd> PartialOrd<Rc<'b, T>> for Rc<'a, T> {
    #[inline]
    fn partial_cmp(&self, other: &Rc<'b, T>) -> Option<cmp::Ordering> {
        (**self).partial_cmp(&**other)
    }

    #[inline]
    fn lt(&self, other: &Rc<'b, T>) -> bool {
        **self < **other
    }

    #[inline]
    fn le(&self, other: &Rc<'b, T>) -> bool {
        **self <= **other
    }

    #[inline]
    fn ge(&self, other: &Rc<'b, T>) -> bool {
        **self >= **other
    }

    #[inline]
    fn gt(&self, other: &Rc<'b, T>) -> bool {
        **self > **other
    }
}
impl<T: Ord> Ord for Rc<'_, T> {
#[inline]
fn cmp(&self, other: &Rc<T>) -> cmp::Ordering {
Ord::cmp(&**self, &**other)
}
}
/// Hashes the pointed-to value, consistent with the value-based `PartialEq`.
impl<T: hash::Hash> hash::Hash for Rc<'_, T> {
    fn hash<H: hash::Hasher>(&self, state: &mut H) {
        let value: &T = self;
        hash::Hash::hash(value, state)
    }
}
/// Displays the pointed-to value.
impl<T: fmt::Display> fmt::Display for Rc<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let value: &T = self;
        fmt::Display::fmt(value, f)
    }
}
/// Debug-formats the pointed-to value, without the counters.
impl<T: fmt::Debug> fmt::Debug for Rc<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let value: &T = self;
        fmt::Debug::fmt(value, f)
    }
}
/// Formats the address of the pointed-to value.
impl<T> fmt::Pointer for Rc<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let value_ptr: *const T = self.as_ptr();
        fmt::Pointer::fmt(&value_ptr, f)
    }
}
impl<T> borrow::Borrow<T> for Rc<'_, T> {
fn borrow(&self) -> &T {
&**self
}
}
/// Cheap reference-to-reference conversion to the pointed-to value.
impl<T> AsRef<T> for Rc<'_, T> {
    fn as_ref(&self) -> &T {
        // Deref coercion from `&Rc<T>` to `&T`.
        self
    }
}
#[cfg(test)]
mod tests {
use core::alloc::Layout;
use core::cell::Cell;
use super::{RcBox, Rc, Weak};
use static_alloc::Bump;
use crate::alloc::LocalAllocLeakExt;
// `val` must sit at offset zero of `RcBox` (guaranteed by `repr(C)` with `val` first), because
// `Rc::as_ptr`/`as_mut_ptr` cast a box pointer straight into a value pointer.
#[test]
fn layout_box_compatible() {
let mut boxed = RcBox {
val: 0usize,
strong: Cell::new(1),
weak: Cell::new(1),
};
let box_ptr = &mut boxed as *mut RcBox<usize>;
let val_ptr = box_ptr as *const usize;
assert_eq!(unsafe { *val_ptr }, 0);
unsafe { (*box_ptr).val = 0xdeadbeef };
assert_eq!(unsafe { *val_ptr }, 0xdeadbeef);
}
// Exercises the counter bookkeeping of `Rc`/`Weak` directly.
#[test]
fn control_through_counters() {
struct Duck;
struct NeverDrop;
impl Drop for NeverDrop {
fn drop(&mut self) {
panic!("dropped!");
}
}
let slab: Bump<[u8; 1024]> = Bump::uninit();
// Leak one strong count: dropping the only handle leaves strong == 1, so the value must not
// be dropped (`NeverDrop` would panic if it were).
let rc = slab.rc(NeverDrop).unwrap();
rc.inc_strong();
drop(rc);
// Both pointer accessors must resolve to the start of the box.
let mut rc = slab.rc(Duck).unwrap();
assert_eq!(rc.as_mut_ptr() as *const u8, rc.inner.as_ptr() as *const u8);
assert_eq!(rc.as_ptr() as *const u8, rc.inner.as_ptr() as *const u8);
// Dropping the sole handle zeroes the strong count and releases the implicit weak reference.
let rc = slab.rc(Duck).unwrap();
let inner = rc.inner;
drop(rc);
unsafe {
assert_eq!((*inner.as_ptr()).strong.get(), 0);
assert_eq!((*inner.as_ptr()).weak.get(), 0);
}
// `try_unwrap` moves the value out; the returned `Weak` inherits the implicit weak reference.
let rc = slab.rc(Duck).unwrap();
let (_, weak) = Rc::try_unwrap(rc).ok().unwrap();
assert_eq!(Weak::strong_count(&weak), 0);
assert_eq!(Weak::weak_count(&weak), 1);
// Dropping that last `Weak` brings the weak counter down to zero.
let inner = weak.inner;
drop(weak);
unsafe {
assert_eq!((*inner.as_ptr()).strong.get(), 0);
assert_eq!((*inner.as_ptr()).weak.get(), 0);
}
}
// Memory laid out for `Foo` alone cannot hold `RcBox<Foo>` (which adds two counters), so
// `Rc::new` must fail its "must fit the inner layout" assertion.
#[test]
#[should_panic = "inner layout"]
fn wrong_layout_panics() {
use core::convert::TryInto;
struct Foo(u32);
let slab: Bump<[u8; 1024]> = Bump::uninit();
let layout = Layout::new::<Foo>().try_into().unwrap();
let wrong_alloc = slab.alloc_layout(layout).unwrap();
let _ = Rc::new(Foo(0), wrong_alloc.uninit);
}
}