use std::{
array,
fmt::{Debug, Formatter, Pointer},
hash::{Hash, Hasher},
marker::PhantomData,
mem::{forget, size_of},
sync::atomic::{AtomicUsize, Ordering},
};
use atomic::Atomic;
use static_assertions::const_assert;
use crate::ebr_impl::{global_epoch, Guard, Tagged};
use crate::utils::{Raw, RcInner};
use crate::{Weak, WeakSnapshot};
/// An object type that can be managed by [`Rc`] / [`AtomicRc`].
///
/// # Safety
///
/// NOTE(review): the invariants implementors must uphold are not visible in this file.
/// Presumably `pop_edges` must report every `Rc<Self>` the object owns so that the
/// reclamation machinery can release them. Confirm against the crate documentation.
pub unsafe trait RcObject: Sized {
    /// Pushes this object's outgoing strong edges (`Rc<Self>` values) into `out`.
    ///
    /// NOTE(review): the caller of this method and when it runs are not shown in this file.
    fn pop_edges(&mut self, out: &mut Vec<Rc<Self>>);
}
impl<T> Tagged<RcInner<T>> {
    /// Stamps a non-null pointer with the current global epoch in its high tag.
    ///
    /// A null pointer is returned unchanged.
    fn with_timestamp(self) -> Self {
        if !self.is_null() {
            return self.with_high_tag(global_epoch());
        }
        self
    }
}
/// Failure result of the compare-exchange operations on [`AtomicRc`].
///
/// Returned when the stored pointer did not match the expected one.
pub struct CompareExchangeError<P, S> {
    /// The value that was supposed to be stored, handed back to the caller.
    pub desired: P,
    /// The value observed in the atomic when the exchange failed.
    pub current: S,
}
/// An atomically updatable, strong, reference-counted pointer to a `T` (possibly null).
///
/// A non-null link owns one strong reference to its target, which is released when the
/// `AtomicRc` is dropped.
pub struct AtomicRc<T: RcObject> {
    /// Tagged pointer to the shared `RcInner`. Values written by `store`, `swap` and the
    /// compare-exchange methods carry a global-epoch timestamp in their high tag.
    link: Atomic<Raw<T>>,
    _marker: PhantomData<T>,
}
// Sending or sharing an `AtomicRc` lets other threads reach (and eventually release) the
// shared `T`, so both bounds are required (the same rule as `std::sync::Arc`).
unsafe impl<T: RcObject + Send + Sync> Send for AtomicRc<T> {}
unsafe impl<T: RcObject + Send + Sync> Sync for AtomicRc<T> {}
// Compile-time checks: the link is a lock-free atomic whose size is exactly one machine
// word (the same as `usize` / `AtomicUsize`).
const_assert!(Atomic::<Raw<u8>>::is_lock_free());
const_assert!(size_of::<Raw<u8>>() == size_of::<usize>());
const_assert!(size_of::<Atomic<Raw<u8>>>() == size_of::<AtomicUsize>());
impl<T: RcObject> AtomicRc<T> {
    /// Creates an `AtomicRc` that owns a freshly allocated `obj`.
    ///
    /// Unlike `store`/`swap`, the initial pointer is stored without a timestamp.
    #[inline(always)]
    pub fn new(obj: T) -> Self {
        Self {
            link: Atomic::new(Rc::<T>::new(obj).into_raw()),
            _marker: PhantomData,
        }
    }

    /// Creates an `AtomicRc` that points to nothing.
    #[inline(always)]
    pub fn null() -> Self {
        Self {
            link: Atomic::new(Tagged::null()),
            _marker: PhantomData,
        }
    }

    /// Loads the current pointer as a [`Snapshot`] whose lifetime is tied to `guard`.
    ///
    /// No reference count changes. Use [`Snapshot::counted`] to obtain an owned [`Rc`].
    #[inline]
    pub fn load<'g>(&self, order: Ordering, guard: &'g Guard) -> Snapshot<'g, T> {
        Snapshot::from_raw(self.link.load(order), guard)
    }

    /// Replaces the stored pointer with `ptr` and releases the previously stored target.
    ///
    /// `ptr`'s strong reference moves into the atomic. The old value's strong count is
    /// decremented with `guard` supplied, rather than via `Rc::drop`, which passes `None`.
    #[inline]
    pub fn store(&self, ptr: Rc<T>, order: Ordering, guard: &Guard) {
        let new_ptr = ptr.ptr;
        // Non-null pointers are stamped with the current global epoch before publication.
        let old_ptr = self.link.swap(new_ptr.with_timestamp(), order);
        // The atomic now owns `ptr`'s strong reference; skip its destructor.
        forget(ptr);
        // NOTE(review): sound as long as a non-null link always owns a strong reference to
        // a live `RcInner`, which the other methods of this type maintain.
        unsafe {
            if let Some(cnt) = old_ptr.as_raw().as_mut() {
                RcInner::decrement_strong(cnt, 1, Some(guard));
            }
        }
    }

    /// Atomically replaces the stored pointer with `new` and returns the previous value.
    ///
    /// No counts change. `new`'s strong reference moves into the atomic, and the reference
    /// the atomic held moves into the returned [`Rc`].
    #[inline(always)]
    pub fn swap(&self, new: Rc<T>, order: Ordering) -> Rc<T> {
        let new_ptr = new.into_raw();
        let old_ptr = self.link.swap(new_ptr.with_timestamp(), order);
        Rc::from_raw(old_ptr)
    }

    /// Stores `desired` if the current pointer matches `expected`.
    ///
    /// On success, `desired`'s strong reference moves into the atomic and the reference
    /// the atomic previously held is returned as an [`Rc`]. On failure, `desired` is handed
    /// back together with a snapshot of the observed value. `success` and `failure` are
    /// the memory orderings of the underlying CAS.
    #[inline(always)]
    pub fn compare_exchange<'g>(
        &self,
        expected: Snapshot<'g, T>,
        desired: Rc<T>,
        success: Ordering,
        failure: Ordering,
        guard: &'g Guard,
    ) -> Result<Rc<T>, CompareExchangeError<Rc<T>, Snapshot<'g, T>>> {
        let mut expected_raw = expected.ptr;
        // The timestamp is taken once, on entry, and reused across retries.
        let desired_raw = desired.ptr.with_timestamp();
        loop {
            match self
                .link
                .compare_exchange(expected_raw, desired_raw, success, failure)
            {
                Ok(_) => {
                    // The atomic now owns `desired`'s strong reference.
                    forget(desired);
                    // Take over the strong reference the atomic used to hold.
                    let rc = Rc::from_raw(expected_raw);
                    return Ok(rc);
                }
                Err(current_raw) => {
                    // NOTE(review): `ptr_eq` presumably ignores the high (timestamp) tag. If
                    // so, a failure caused only by a different timestamp is retried with the
                    // freshly observed value instead of being reported.
                    if current_raw.ptr_eq(expected_raw) {
                        expected_raw = current_raw;
                    } else {
                        let current = Snapshot::from_raw(current_raw, guard);
                        return Err(CompareExchangeError { desired, current });
                    }
                }
            }
        }
    }

    /// Same contract as [`AtomicRc::compare_exchange`], but built on the weak CAS.
    ///
    /// A spurious failure reports the unchanged value. That value passes the `ptr_eq` check
    /// below, so the attempt is retried rather than returned as an error.
    #[inline(always)]
    pub fn compare_exchange_weak<'g>(
        &self,
        expected: Snapshot<'g, T>,
        desired: Rc<T>,
        success: Ordering,
        failure: Ordering,
        guard: &'g Guard,
    ) -> Result<Rc<T>, CompareExchangeError<Rc<T>, Snapshot<'g, T>>> {
        let mut expected_raw = expected.ptr;
        let desired_raw = desired.ptr.with_timestamp();
        loop {
            match self
                .link
                .compare_exchange_weak(expected_raw, desired_raw, success, failure)
            {
                Ok(_) => {
                    // The atomic now owns `desired`'s strong reference.
                    forget(desired);
                    // Take over the strong reference the atomic used to hold.
                    let rc = Rc::from_raw(expected_raw);
                    return Ok(rc);
                }
                Err(current_raw) => {
                    // Retry when only the bits `ptr_eq` ignores differ, or on a spurious failure.
                    if current_raw.ptr_eq(expected_raw) {
                        expected_raw = current_raw;
                    } else {
                        let current = Snapshot::from_raw(current_raw, guard);
                        return Err(CompareExchangeError { desired, current });
                    }
                }
            }
        }
    }

    /// Re-tags the stored pointer with `desired_tag` if it still matches `expected`.
    ///
    /// The target object stays the same and no reference counts change. On success, the
    /// previously stored value is returned as a snapshot. On failure, both the value that
    /// would have been stored and the observed value are returned.
    #[inline]
    pub fn compare_exchange_tag<'g>(
        &self,
        expected: Snapshot<'g, T>,
        desired_tag: usize,
        success: Ordering,
        failure: Ordering,
        guard: &'g Guard,
    ) -> Result<Snapshot<'g, T>, CompareExchangeError<Snapshot<'g, T>, Snapshot<'g, T>>> {
        let mut expected_raw = expected.ptr;
        // Same pointer as `expected`, new tag, fresh timestamp (computed once, on entry).
        let desired_raw = expected_raw.with_tag(desired_tag).with_timestamp();
        loop {
            match self
                .link
                .compare_exchange(expected_raw, desired_raw, success, failure)
            {
                Ok(current_raw) => return Ok(Snapshot::from_raw(current_raw, guard)),
                Err(current_raw) => {
                    if current_raw.ptr_eq(expected_raw) {
                        expected_raw = current_raw;
                    } else {
                        return Err(CompareExchangeError {
                            desired: Snapshot::from_raw(desired_raw, guard),
                            current: Snapshot::from_raw(current_raw, guard),
                        });
                    }
                }
            }
        }
    }

    /// Moves the stored pointer out, leaving `Raw::default()` (presumably null) behind.
    ///
    /// Requires exclusive access, so no atomic operation is needed.
    #[inline]
    pub fn take(&mut self) -> Rc<T> {
        Rc::from_raw(core::mem::take(self.link.get_mut()))
    }
}
impl<T: RcObject> Drop for AtomicRc<T> {
    /// Releases the strong reference held by a non-null link.
    #[inline(always)]
    fn drop(&mut self) {
        // `&mut self` guarantees exclusive access, so a plain read of the link suffices.
        let current = *self.link.get_mut();
        let inner = current.as_raw();
        if let Some(cnt) = unsafe { inner.as_mut() } {
            unsafe {
                RcInner::decrement_strong(cnt, 1, None);
            }
        }
    }
}
impl<T: RcObject> Default for AtomicRc<T> {
#[inline(always)]
fn default() -> Self {
Self::null()
}
}
impl<T: RcObject> From<Rc<T>> for AtomicRc<T> {
#[inline]
fn from(value: Rc<T>) -> Self {
let ptr = value.into_raw();
Self {
link: Atomic::new(ptr),
_marker: PhantomData,
}
}
}
impl<T: RcObject> From<&Rc<T>> for AtomicRc<T> {
#[inline]
fn from(value: &Rc<T>) -> Self {
Self::from(value.clone())
}
}
impl<T: RcObject> Debug for AtomicRc<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.link.load(Ordering::Relaxed), f)
}
}
impl<T: RcObject> Pointer for AtomicRc<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.link.load(Ordering::Relaxed), f)
}
}
/// An owned, strong, reference-counted pointer to a `T`. It may be null and may carry a tag.
///
/// Cloning increments the strong count and dropping decrements it.
pub struct Rc<T: RcObject> {
    ptr: Raw<T>,
    _marker: PhantomData<T>,
}
// As with `std::sync::Arc`, moving or sharing an `Rc` across threads requires
// `T: Send + Sync`, since the object may then be accessed (and released) from any thread.
unsafe impl<T: RcObject + Send + Sync> Send for Rc<T> {}
unsafe impl<T: RcObject + Send + Sync> Sync for Rc<T> {}
impl<T: RcObject> Clone for Rc<T> {
fn clone(&self) -> Self {
let rc = Self {
ptr: self.ptr,
_marker: PhantomData,
};
unsafe {
if let Some(cnt) = rc.ptr.as_raw().as_ref() {
cnt.increment_strong();
}
}
rc
}
}
impl<T: RcObject> Rc<T> {
    /// Returns a null `Rc`, which owns nothing.
    #[inline(always)]
    pub fn null() -> Self {
        Self::from_raw(Raw::null())
    }

    /// Returns `true` if this `Rc` points to nothing.
    #[inline(always)]
    pub fn is_null(&self) -> bool {
        self.ptr.is_null()
    }

    /// Wraps `ptr` without touching any reference count.
    ///
    /// If `ptr` is non-null, the caller hands over one strong reference, which the returned
    /// `Rc` releases when dropped.
    #[inline(always)]
    pub(crate) fn from_raw(ptr: Raw<T>) -> Self {
        Self {
            ptr,
            _marker: PhantomData,
        }
    }

    /// Allocates `obj` with a single strong reference, owned by the returned `Rc`.
    #[inline(always)]
    pub fn new(obj: T) -> Self {
        let ptr = RcInner::alloc(obj, 1);
        Self {
            ptr: Raw::from(ptr),
            _marker: PhantomData,
        }
    }

    /// Allocates `obj` once with `N` strong references and returns them as `N` `Rc`s.
    ///
    /// NOTE(review): with `N == 0` the allocation starts with no strong reference, and nothing
    /// here releases it. Confirm how `RcInner::alloc` treats a zero count.
    #[inline(always)]
    pub fn new_many<const N: usize>(obj: T) -> [Self; N] {
        let ptr = RcInner::alloc(obj, N as _);
        [(); N].map(|_| Self {
            ptr: Raw::from(ptr),
            _marker: PhantomData,
        })
    }

    /// Allocates `obj` once with `count` strong references and returns an iterator that yields them.
    ///
    /// References that are never yielded are released when the iterator is dropped or aborted.
    #[inline(always)]
    pub fn new_many_iter(obj: T, count: usize) -> NewRcIter<T> {
        let ptr = RcInner::alloc(obj, count as _);
        NewRcIter {
            remain: count,
            ptr: Raw::from(ptr),
        }
    }

    /// Creates `N` weak references to the object `self` points to.
    ///
    /// The weak count is raised by `N` once, and each returned [`Weak`] owns one of those
    /// increments. For a null `Rc`, no count changes and the `N` `Weak`s wrap the same null
    /// pointer, consistent with [`Rc::downgrade`].
    #[inline]
    pub fn weak_many<const N: usize>(&self) -> [Weak<T>; N] {
        if let Some(cnt) = unsafe { self.ptr.as_raw().as_ref() } {
            cnt.increment_weak(N as u32);
        }
        // Each handle must point at the object. Returning null `Weak`s here would leak the
        // `N` weak counts taken above and give the caller unusable handles.
        array::from_fn(|_| Weak::from_raw(self.ptr))
    }

    /// Returns the (low) tag stored in the pointer.
    #[inline(always)]
    pub fn tag(&self) -> usize {
        self.ptr.tag()
    }

    /// Returns the same `Rc` with its pointer re-tagged to `tag`. No count changes.
    #[inline(always)]
    pub fn with_tag(mut self, tag: usize) -> Self {
        self.ptr = self.ptr.with_tag(tag);
        self
    }

    /// Consumes the `Rc` and returns its pointer without releasing the strong reference.
    #[inline]
    pub(crate) fn into_raw(self) -> Raw<T> {
        let new_ptr = self.ptr;
        forget(self);
        new_ptr
    }

    /// Drops this `Rc`, decrementing the strong count with `guard` supplied.
    ///
    /// The regular `Drop` passes `None` instead.
    #[inline]
    pub fn finalize(self, guard: &Guard) {
        unsafe {
            if let Some(cnt) = self.ptr.as_raw().as_mut() {
                RcInner::decrement_strong(cnt, 1, Some(guard));
            }
        }
        // The strong reference was released above; skip `Drop`.
        forget(self);
    }

    /// Creates a weak reference to the same object, incrementing the weak count if non-null.
    #[inline]
    pub fn downgrade(&self) -> Weak<T> {
        if let Some(cnt) = unsafe { self.ptr.as_raw().as_ref() } {
            cnt.increment_weak(1);
        }
        Weak::from_raw(self.ptr)
    }

    /// Views the pointer as a [`Snapshot`] scoped to `guard`. No count changes.
    #[inline]
    pub fn snapshot<'g>(&self, guard: &'g Guard) -> Snapshot<'g, T> {
        Snapshot::from_raw(self.ptr, guard)
    }

    /// Dereferences the pointer without a null check.
    ///
    /// # Safety
    ///
    /// `self` must be non-null. [`Rc::as_ref`] performs that check before calling this.
    #[inline]
    pub unsafe fn deref(&self) -> &T {
        self.ptr.deref().data()
    }

    /// Mutably dereferences the pointer without a null check.
    ///
    /// # Safety
    ///
    /// `self` must be non-null. No other reference to the object, for example through
    /// another `Rc` or a `Snapshot`, may be used while the returned `&mut T` is live.
    #[inline]
    pub unsafe fn deref_mut(&mut self) -> &mut T {
        self.ptr.deref_mut().data_mut()
    }

    /// Returns a shared reference to the object, or `None` if null.
    #[inline]
    pub fn as_ref(&self) -> Option<&T> {
        if self.ptr.is_null() {
            None
        } else {
            Some(unsafe { self.deref() })
        }
    }

    /// Returns a mutable reference to the object, or `None` if null.
    ///
    /// # Safety
    ///
    /// Same aliasing requirement as [`Rc::deref_mut`].
    #[inline]
    pub unsafe fn as_mut(&mut self) -> Option<&mut T> {
        if self.ptr.is_null() {
            None
        } else {
            Some(unsafe { self.deref_mut() })
        }
    }

    /// Compares the pointers of two `Rc`s (not the objects) using `Raw::ptr_eq`.
    #[inline]
    pub fn ptr_eq(&self, other: &Self) -> bool {
        self.ptr.ptr_eq(other.ptr)
    }
}
impl<'g, T: RcObject> From<Snapshot<'g, T>> for Rc<T> {
fn from(value: Snapshot<'g, T>) -> Self {
value.counted()
}
}
impl<T: RcObject + Debug> Debug for Rc<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
if let Some(cnt) = self.as_ref() {
f.debug_tuple("RcObject").field(cnt).finish()
} else {
f.write_str("Null")
}
}
}
impl<T: RcObject> Pointer for Rc<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.ptr, f)
}
}
impl<T: RcObject> Default for Rc<T> {
#[inline]
fn default() -> Self {
Self::null()
}
}
impl<T: RcObject> Drop for Rc<T> {
#[inline(always)]
fn drop(&mut self) {
unsafe {
if let Some(cnt) = self.ptr.as_raw().as_mut() {
RcInner::decrement_strong(cnt, 1, None);
}
}
}
}
impl<T: RcObject + PartialEq> PartialEq for Rc<T> {
    /// Compares the pointed-to objects, not addresses. Two nulls are equal.
    #[inline(always)]
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.as_ref(), other.as_ref());
        lhs == rhs
    }
}
impl<T: RcObject + Eq> Eq for Rc<T> {}
impl<T: RcObject + Hash> Hash for Rc<T> {
    /// Hashes the pointed-to object as an `Option<&T>`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.as_ref(), state);
    }
}
impl<T: RcObject + PartialOrd> PartialOrd for Rc<T> {
    /// Orders by pointed-to object; a null `Rc` sorts before any non-null one.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        let (lhs, rhs) = (self.as_ref(), other.as_ref());
        PartialOrd::partial_cmp(&lhs, &rhs)
    }
}
impl<T: RcObject + Ord> Ord for Rc<T> {
    /// Total order by pointed-to object; a null `Rc` sorts first.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        let (lhs, rhs) = (self.as_ref(), other.as_ref());
        Ord::cmp(&lhs, &rhs)
    }
}
/// Iterator over the `Rc`s that share the single allocation made by [`Rc::new_many_iter`].
///
/// Strong references not yet yielded are released on drop or by [`NewRcIter::abort`].
pub struct NewRcIter<T: RcObject> {
    /// Number of pre-counted strong references still to be handed out.
    remain: usize,
    /// Pointer to the shared allocation.
    ptr: Raw<T>,
}
impl<T: RcObject> Iterator for NewRcIter<T> {
    type Item = Rc<T>;

    /// Hands out one of the pre-counted strong references, or `None` once all are taken.
    ///
    /// No count is modified: each yielded `Rc` takes ownership of one of the references
    /// allocated up front.
    #[inline]
    fn next(&mut self) -> Option<Self::Item> {
        if self.remain == 0 {
            return None;
        }
        self.remain -= 1;
        Some(Rc {
            ptr: self.ptr,
            _marker: PhantomData,
        })
    }

    /// The remaining length is known exactly, which lets `collect` and `extend` preallocate.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remain, Some(self.remain))
    }
}

impl<T: RcObject> ExactSizeIterator for NewRcIter<T> {}
impl<T: RcObject> NewRcIter<T> {
    /// Releases every strong reference not yet yielded, using `guard`.
    ///
    /// Then consumes the iterator without running its `Drop`.
    #[inline]
    pub fn abort(self, guard: &Guard) {
        let left = self.remain;
        if left != 0 {
            unsafe {
                RcInner::decrement_strong(self.ptr.as_raw(), left as _, Some(guard));
            }
        }
        // The leftover references were released above; skip `Drop`.
        forget(self);
    }
}
impl<T: RcObject> Drop for NewRcIter<T> {
    /// Releases every strong reference not yet yielded. No guard is passed.
    #[inline]
    fn drop(&mut self) {
        match self.remain {
            0 => {}
            left => unsafe {
                RcInner::decrement_strong(self.ptr.as_raw(), left as _, None);
            },
        }
    }
}
/// An uncounted view of a (possibly null, possibly tagged) pointer, scoped to `'g`.
///
/// A `Snapshot` owns no strong reference. Its lifetime `'g` comes from the [`Guard`] passed
/// to `from_raw`. Use [`Snapshot::counted`] to obtain an owned [`Rc`].
pub struct Snapshot<'g, T> {
    pub(crate) ptr: Raw<T>,
    pub(crate) _marker: PhantomData<&'g T>,
}
// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add `T: Clone` /
// `T: Copy` bounds, but copying a snapshot only copies the pointer.
impl<'g, T> Clone for Snapshot<'g, T> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<'g, T> Copy for Snapshot<'g, T> {}
impl<'g, T: RcObject> Snapshot<'g, T> {
    /// Returns `true` if the snapshot points to nothing.
    #[inline(always)]
    pub fn is_null(&self) -> bool {
        self.ptr.is_null()
    }

    /// Converts the snapshot into an owned [`Rc`], incrementing the strong count if non-null.
    #[inline]
    pub fn counted(self) -> Rc<T> {
        let owned = Rc::from_raw(self.ptr);
        let inner = owned.ptr.as_raw();
        if let Some(cnt) = unsafe { inner.as_ref() } {
            cnt.increment_strong();
        }
        owned
    }

    /// Reinterprets the snapshot as a [`WeakSnapshot`] with the same lifetime. No counts change.
    #[inline]
    pub fn downgrade(self) -> WeakSnapshot<'g, T> {
        let Self { ptr, .. } = self;
        WeakSnapshot {
            ptr,
            _marker: PhantomData,
        }
    }

    /// Returns the (low) tag stored in the pointer.
    #[inline(always)]
    pub fn tag(self) -> usize {
        self.ptr.tag()
    }

    /// Returns a copy of this snapshot whose pointer is re-tagged to `tag`.
    #[inline]
    pub fn with_tag(self, tag: usize) -> Self {
        Self {
            ptr: self.ptr.with_tag(tag),
            _marker: PhantomData,
        }
    }

    /// Dereferences the pointer for `'g` without a null check.
    ///
    /// # Safety
    ///
    /// The snapshot must be non-null. `as_ref` performs that check before calling this.
    #[inline]
    pub unsafe fn deref(self) -> &'g T {
        self.ptr.deref().data()
    }

    /// Mutably dereferences the pointer for `'g` without a null check.
    ///
    /// # Safety
    ///
    /// The snapshot must be non-null. No other reference to the object may be used while
    /// the returned `&mut T` is live.
    #[inline]
    pub unsafe fn deref_mut(mut self) -> &'g mut T {
        self.ptr.deref_mut().data_mut()
    }

    /// Returns a shared reference valid for `'g`, or `None` if null.
    #[inline]
    pub fn as_ref(self) -> Option<&'g T> {
        if self.ptr.is_null() {
            return None;
        }
        Some(unsafe { self.deref() })
    }

    /// Returns a mutable reference valid for `'g`, or `None` if null.
    ///
    /// # Safety
    ///
    /// Same aliasing requirement as `deref_mut`.
    #[inline]
    pub unsafe fn as_mut(self) -> Option<&'g mut T> {
        if self.ptr.is_null() {
            return None;
        }
        Some(unsafe { self.deref_mut() })
    }

    /// Compares the two pointers (not the objects) using `Raw::ptr_eq`.
    #[inline]
    pub fn ptr_eq(self, other: Self) -> bool {
        self.ptr.ptr_eq(other.ptr)
    }
}
impl<'g, T> Snapshot<'g, T> {
#[inline(always)]
pub fn null() -> Self {
Self {
ptr: Tagged::null(),
_marker: PhantomData,
}
}
#[inline]
pub(crate) fn from_raw(acquired: Raw<T>, _: &'g Guard) -> Self {
Self {
ptr: acquired,
_marker: PhantomData,
}
}
}
impl<'g, T: RcObject> Default for Snapshot<'g, T> {
#[inline]
fn default() -> Self {
Self::null()
}
}
impl<'g, T: RcObject + PartialEq> PartialEq for Snapshot<'g, T> {
#[inline(always)]
fn eq(&self, other: &Self) -> bool {
self.as_ref() == other.as_ref()
}
}
impl<'g, T: RcObject + Eq> Eq for Snapshot<'g, T> {}
impl<'g, T: RcObject + Hash> Hash for Snapshot<'g, T> {
fn hash<H: Hasher>(&self, state: &mut H) {
self.as_ref().hash(state);
}
}
impl<'g, T: RcObject + PartialOrd> PartialOrd for Snapshot<'g, T> {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.as_ref().partial_cmp(&other.as_ref())
}
}
impl<'g, T: RcObject + Ord> Ord for Snapshot<'g, T> {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.as_ref().cmp(&other.as_ref())
}
}
impl<'g, T: RcObject + Debug> Debug for Snapshot<'g, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
if let Some(cnt) = self.as_ref() {
f.debug_tuple("RcObject").field(cnt).finish()
} else {
f.write_str("Null")
}
}
}
impl<'g, T: RcObject> Pointer for Snapshot<'g, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.ptr, f)
}
}