use std::{
fmt::{Debug, Formatter, Pointer},
marker::PhantomData,
mem::{forget, size_of},
sync::atomic::{AtomicUsize, Ordering},
};
use atomic::Atomic;
use static_assertions::const_assert;
use crate::ebr_impl::{Guard, Tagged};
use crate::utils::{Raw, RcInner};
use crate::{CompareExchangeError, Rc, RcObject, Snapshot};
/// An atomically updatable slot holding a possibly-null, possibly-tagged weak
/// pointer to an `RcInner<T>`.
///
/// The slot owns one weak count for whatever non-null pointer it currently
/// stores: `store`, `swap` and the `From<Weak<T>>` conversion move a `Weak<T>`'s
/// count into the slot (via `forget`/`into_raw`). Dropping the slot releases that
/// count through `RcInner::decrement_weak`.
pub struct AtomicWeak<T> {
    // The (tagged) raw pointer; every concurrent access goes through `Atomic`.
    pub(crate) link: Atomic<Raw<T>>,
}
// SAFETY: NOTE(review): other threads can load from a shared slot and upgrade
// the result to a strong reference to `T`, hence the `T: Send + Sync` bound.
// Soundness also assumes `RcInner`'s counters are updated atomically — confirm
// in `crate::utils`.
unsafe impl<T: Send + Sync> Send for AtomicWeak<T> {}
unsafe impl<T: Send + Sync> Sync for AtomicWeak<T> {}
// Compile-time checks that `Raw<_>` fits a native pointer-sized atomic word, so
// operations on `Atomic<Raw<T>>` are lock-free. `u8` stands in for an arbitrary
// `T`; presumably the size of `Raw<T>` does not depend on `T` — confirm.
const_assert!(Atomic::<Raw<u8>>::is_lock_free());
const_assert!(size_of::<Raw<u8>>() == size_of::<usize>());
const_assert!(size_of::<Atomic<Raw<u8>>>() == size_of::<AtomicUsize>());
impl<T> AtomicWeak<T> {
#[inline(always)]
pub fn null() -> Self {
Self {
link: Atomic::new(Tagged::null()),
}
}
#[inline]
pub fn load<'g>(&self, order: Ordering, guard: &'g Guard) -> WeakSnapshot<'g, T> {
WeakSnapshot::from_raw(self.link.load(order), guard)
}
#[inline]
pub fn store(&self, ptr: Weak<T>, order: Ordering, guard: &Guard) {
let new_ptr = ptr.ptr;
forget(ptr);
let old_ptr = self.link.swap(new_ptr, order);
unsafe {
if let Some(cnt) = old_ptr.as_raw().as_mut() {
RcInner::decrement_weak(cnt, Some(guard));
}
}
}
#[inline(always)]
pub fn swap(&self, new: Weak<T>, order: Ordering) -> Weak<T> {
let new_ptr = new.into_raw();
let old_ptr = self.link.swap(new_ptr, order);
Weak::from_raw(old_ptr)
}
#[inline(always)]
pub fn compare_exchange<'g>(
&self,
expected: WeakSnapshot<'g, T>,
desired: Weak<T>,
success: Ordering,
failure: Ordering,
guard: &'g Guard,
) -> Result<Weak<T>, CompareExchangeError<Weak<T>, WeakSnapshot<'g, T>>> {
match self
.link
.compare_exchange(expected.ptr, desired.ptr, success, failure)
{
Ok(_) => {
forget(desired);
let weak = Weak::from_raw(expected.ptr);
Ok(weak)
}
Err(current) => {
let current = WeakSnapshot::from_raw(current, guard);
Err(CompareExchangeError { desired, current })
}
}
}
#[inline(always)]
pub fn compare_exchange_weak<'g>(
&self,
expected: WeakSnapshot<'g, T>,
desired: Weak<T>,
success: Ordering,
failure: Ordering,
guard: &'g Guard,
) -> Result<Weak<T>, CompareExchangeError<Weak<T>, WeakSnapshot<'g, T>>> {
match self
.link
.compare_exchange_weak(expected.ptr, desired.ptr, success, failure)
{
Ok(_) => {
forget(desired);
let weak = Weak::from_raw(expected.ptr);
Ok(weak)
}
Err(current) => {
let current = WeakSnapshot::from_raw(current, guard);
Err(CompareExchangeError { desired, current })
}
}
}
#[inline]
pub fn compare_exchange_tag<'g>(
&self,
expected: WeakSnapshot<'g, T>,
desired_tag: usize,
success: Ordering,
failure: Ordering,
guard: &'g Guard,
) -> Result<WeakSnapshot<'g, T>, CompareExchangeError<WeakSnapshot<'g, T>, WeakSnapshot<'g, T>>>
{
let desired_raw = expected.ptr.with_tag(desired_tag);
match self
.link
.compare_exchange(expected.ptr, desired_raw, success, failure)
{
Ok(current) => Ok(WeakSnapshot::from_raw(current, guard)),
Err(current) => Err(CompareExchangeError {
desired: WeakSnapshot::from_raw(desired_raw, guard),
current: WeakSnapshot::from_raw(current, guard),
}),
}
}
pub fn get_mut(&mut self) -> &mut Weak<T> {
unsafe { core::mem::transmute(self.link.get_mut()) }
}
}
impl<T> From<Weak<T>> for AtomicWeak<T> {
#[inline]
fn from(value: Weak<T>) -> Self {
let init_ptr = value.into_raw();
Self {
link: Atomic::new(init_ptr),
}
}
}
impl<T> From<&Weak<T>> for AtomicWeak<T> {
#[inline]
fn from(value: &Weak<T>) -> Self {
Self::from(value.clone())
}
}
impl<T: RcObject> From<&Rc<T>> for AtomicWeak<T> {
#[inline]
fn from(value: &Rc<T>) -> Self {
Self::from(value.downgrade())
}
}
impl<T> Debug for AtomicWeak<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.link.load(Ordering::Relaxed), f)
}
}
impl<T> Pointer for AtomicWeak<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.link.load(Ordering::Relaxed), f)
}
}
impl<T> Drop for AtomicWeak<T> {
#[inline(always)]
fn drop(&mut self) {
let ptr = (*self.link.get_mut()).as_raw();
unsafe {
if let Some(cnt) = ptr.as_mut() {
RcInner::decrement_weak(cnt, None);
}
}
}
}
impl<T> Default for AtomicWeak<T> {
#[inline(always)]
fn default() -> Self {
Self::null()
}
}
/// An owning weak reference to a reference-counted object.
///
/// A `Weak<T>` holds one weak count on the pointed-to `RcInner<T>`: `Drop`
/// releases it and `Clone` acquires another. It does not keep `T` itself alive,
/// so [`Weak::upgrade`] may fail. The pointer may be null and may carry tag bits
/// (see [`Weak::tag`] / [`Weak::with_tag`]).
///
/// `#[repr(transparent)]` guarantees that `Weak<T>` has exactly the layout of
/// `Raw<T>`. `AtomicWeak::get_mut` relies on this when it reinterprets the
/// stored `Raw<T>` as a `Weak<T>`; the default `repr(Rust)` gives no such
/// guarantee.
#[repr(transparent)]
pub struct Weak<T> {
    ptr: Raw<T>,
}
// SAFETY: a `Weak<T>` can be upgraded on whichever thread holds it, which gives
// that thread access to `T`, so both traits require `T: Send + Sync`.
// NOTE(review): soundness also assumes `RcInner`'s counters are updated
// atomically — confirm in `crate::utils`.
unsafe impl<T: Send + Sync> Send for Weak<T> {}
unsafe impl<T: Send + Sync> Sync for Weak<T> {}
impl<T> Clone for Weak<T> {
fn clone(&self) -> Self {
let weak = Self { ptr: self.ptr };
unsafe {
if let Some(cnt) = weak.ptr.as_raw().as_ref() {
cnt.increment_weak(1);
}
}
weak
}
}
impl<T> Weak<T> {
#[inline(always)]
pub fn null() -> Self {
Self::from_raw(Raw::null())
}
#[inline(always)]
pub fn is_null(&self) -> bool {
self.ptr.is_null()
}
#[inline(always)]
pub(crate) fn from_raw(ptr: Raw<T>) -> Self {
Self { ptr }
}
#[inline(always)]
pub fn tag(&self) -> usize {
self.ptr.tag()
}
#[inline(always)]
pub fn with_tag(mut self, tag: usize) -> Self {
self.ptr = self.ptr.with_tag(tag);
self
}
#[inline]
pub fn snapshot<'g>(&self, guard: &'g Guard) -> WeakSnapshot<'g, T> {
WeakSnapshot::from_raw(self.ptr, guard)
}
#[inline]
pub(crate) fn into_raw(self) -> Raw<T> {
let new_ptr = self.ptr;
forget(self);
new_ptr
}
#[inline]
pub(crate) fn increment_weak(&self) {
if let Some(ptr) = unsafe { self.ptr.as_raw().as_ref() } {
ptr.increment_weak(1);
}
}
#[inline]
pub fn ptr_eq(&self, other: &Self) -> bool {
self.ptr.ptr_eq(other.ptr)
}
}
impl<T: RcObject> Weak<T> {
    /// Tries to obtain a strong [`Rc`] to the referent.
    ///
    /// Returns `None` when the referent's strong count cannot be incremented
    /// (`increment_strong` returns `false`). A null `Weak` yields `Some` of an
    /// `Rc` built from the same null, possibly tagged, raw pointer.
    #[inline]
    pub fn upgrade(&self) -> Option<Rc<T>> {
        let raw = self.ptr;
        // SAFETY: NOTE(review): a non-null pointer held by a `Weak` presumably
        // stays allocated while its weak count is held — confirm in `RcInner`.
        match unsafe { raw.as_raw().as_ref() } {
            // The match guard runs once, so at most one strong increment is attempted.
            Some(inner) if !inner.increment_strong() => None,
            _ => Some(Rc::from_raw(raw)),
        }
    }
}
impl<T> Drop for Weak<T> {
    /// Releases this reference's weak count; no-op for a null pointer.
    #[inline(always)]
    fn drop(&mut self) {
        let raw = self.ptr.as_raw();
        // SAFETY: a non-null `Weak` owns exactly one weak count, released once here.
        unsafe {
            let Some(inner) = raw.as_mut() else {
                return;
            };
            RcInner::decrement_weak(inner, None);
        }
    }
}
impl<'g, T> From<WeakSnapshot<'g, T>> for Weak<T> {
fn from(value: WeakSnapshot<'g, T>) -> Self {
value.counted()
}
}
impl<'g, T: RcObject> From<Snapshot<'g, T>> for Weak<T> {
fn from(value: Snapshot<'g, T>) -> Self {
value.downgrade().counted()
}
}
impl<T> Debug for Weak<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.ptr, f)
}
}
impl<T> Pointer for Weak<T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.ptr, f)
}
}
/// A non-owning view of a possibly-null, possibly-tagged weak pointer.
///
/// Unlike `Weak<T>`, a snapshot has no `Drop` impl and holds no weak count. Its
/// lifetime `'g` is tied to the `Guard` it was created under (`from_raw` takes
/// `&'g Guard`). Use `counted` to obtain an owning `Weak<T>`.
pub struct WeakSnapshot<'g, T> {
    pub(crate) ptr: Raw<T>,
    // Binds the snapshot to the guard lifetime `'g` without storing a reference.
    pub(crate) _marker: PhantomData<&'g T>,
}
// Implemented by hand rather than derived: `#[derive(Clone, Copy)]` would add
// `T: Clone` / `T: Copy` bounds, but copying a snapshot only copies the raw
// pointer and the zero-sized marker.
impl<'g, T> Clone for WeakSnapshot<'g, T> {
    fn clone(&self) -> Self {
        *self
    }
}
impl<'g, T> Copy for WeakSnapshot<'g, T> {}
impl<'g, T> WeakSnapshot<'g, T> {
#[inline(always)]
pub fn is_null(&self) -> bool {
self.ptr.is_null()
}
#[inline]
pub fn counted(self) -> Weak<T> {
let weak = Weak::from_raw(self.ptr);
weak.increment_weak();
weak
}
pub fn upgrade(self) -> Option<Snapshot<'g, T>> {
let ptr = self.ptr;
if !ptr.is_null() && !unsafe { ptr.deref() }.is_not_destructed() {
return None;
}
Some(Snapshot {
ptr,
_marker: PhantomData,
})
}
#[inline(always)]
pub fn tag(self) -> usize {
self.ptr.tag()
}
#[inline]
pub fn with_tag(self, tag: usize) -> Self {
let mut result = self;
result.ptr = result.ptr.with_tag(tag);
result
}
}
impl<'g, T> WeakSnapshot<'g, T> {
#[inline(always)]
pub fn null() -> Self {
Self {
ptr: Tagged::null(),
_marker: PhantomData,
}
}
#[inline]
pub(crate) fn from_raw(acquired: Raw<T>, _: &'g Guard) -> Self {
Self {
ptr: acquired,
_marker: PhantomData,
}
}
#[inline]
pub fn ptr_eq(self, other: Self) -> bool {
self.ptr.ptr_eq(other.ptr)
}
}
impl<'g, T> Default for WeakSnapshot<'g, T> {
#[inline]
fn default() -> Self {
Self::null()
}
}
impl<'g, T: RcObject> From<Snapshot<'g, T>> for WeakSnapshot<'g, T> {
fn from(value: Snapshot<'g, T>) -> Self {
value.downgrade()
}
}
impl<'g, T> Debug for WeakSnapshot<'g, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Debug::fmt(&self.ptr, f)
}
}
impl<'g, T> Pointer for WeakSnapshot<'g, T> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
Pointer::fmt(&self.ptr, f)
}
}