use crate::{Ref, compare::CasResult};
use std::sync::{
Arc,
atomic::{AtomicPtr, Ordering},
};
/// A shared, atomically replaceable pointer to an `Arc<T>`.
///
/// The atomic slot lives behind an `Arc`, and `Clone` is derived. As a result, every clone of an
/// `AtomPtr` shares the same slot: a `swap` or `compare_exchange` made through one clone is
/// seen by all of them.
///
/// NOTE(review): `AtomicPtr<X>` is `Send + Sync` for every `X`. That makes this struct
/// automatically `Send + Sync` even when `T` is not (e.g. `Rc<_>` or `Cell<_>`), which lets
/// `get_ref` hand out `T` on several threads. Consider adding a `PhantomData<Arc<T>>` field so
/// the auto traits follow `T`.
#[derive(Clone, Debug)]
pub struct AtomPtr<T> {
// Holds a raw pointer to a heap box containing the current `Arc<T>`. The box is created with
// `Box::into_raw` (see `make_raw_ptr`). It is reclaimed either by `Drop` or by the `Ref`
// returned from `swap` / a successful `compare_exchange`.
inner: Arc<AtomicPtr<Arc<T>>>,
}
impl<T: Default> Default for AtomPtr<T> {
fn default() -> Self {
Self::new(T::default())
}
}
impl<T> PartialEq for AtomPtr<T> {
    /// Identity comparison.
    ///
    /// Two `AtomPtr`s are equal when they currently hold the very same `Arc` allocation. `T`
    /// itself is never compared, so no `T: PartialEq` bound is needed.
    fn eq(&self, other: &Self) -> bool {
        let lhs = self.get_ref();
        let rhs = other.get_ref();
        Arc::ptr_eq(&lhs.inner, &rhs.inner)
    }
}
impl<T> Drop for AtomPtr<T> {
fn drop(&mut self) {
if Arc::strong_count(&self.inner) == 1 {
let ptr = self.inner.load(Ordering::Acquire);
let _b = unsafe { Box::from_raw(ptr) };
}
}
}
impl<T> AtomPtr<T> {
    /// Moves `t` into a fresh `Arc`, boxes it, and returns the raw box pointer.
    ///
    /// That pointer is what gets stored in the atomic slot. Ownership of the box passes to
    /// whoever later calls `Box::from_raw` on it (`Drop`, or the `Ref` returned by `swap` or a
    /// successful compare-exchange).
    fn make_raw_ptr(t: T) -> *mut Arc<T> {
        Box::into_raw(Box::new(Arc::new(t)))
    }

    /// Wraps an *owned* slot box in a [`Ref`]. The `Ref` takes over the box itself.
    ///
    /// # Safety
    ///
    /// `ptr` must come from [`Self::make_raw_ptr`] and must no longer be stored in the slot, so
    /// that this call becomes the box's only owner.
    unsafe fn owning_ref(ptr: *mut Arc<T>) -> Ref<T> {
        // SAFETY: the caller transfers sole ownership of the box to us.
        let inner = unsafe { Box::from_raw(ptr) };
        Ref { inner, ptr }
    }

    /// Builds a [`Ref`] that *shares* the `Arc` stored at `ptr` without taking over its box.
    ///
    /// The `Arc` is cloned into a brand-new box, and the original box is left untouched.
    ///
    /// # Safety
    ///
    /// `ptr` must come from [`Self::make_raw_ptr`], and its box must not be freed while this
    /// call runs.
    unsafe fn sharing_ref(ptr: *mut Arc<T>) -> Ref<T> {
        // SAFETY: the caller guarantees `ptr` is live. We only take a shared borrow to bump the
        // refcount. We never build a `Box` here, because a `Box` claims unique ownership and
        // other threads may be reading the same pointer.
        let arc = Arc::clone(unsafe { &*ptr });
        Ref {
            inner: Box::new(arc),
            ptr,
        }
    }

    /// Turns the outcome of a compare-exchange on the slot into a [`CasResult`].
    ///
    /// - On success, the previous box has left the slot and is handed to the caller, as in
    ///   `swap`.
    /// - On failure, `new` was never published, so it is freed here. The value still installed
    ///   in the slot is only *shared*, never taken over.
    ///
    /// # Safety
    ///
    /// `outcome` must come from a compare-exchange on `self.inner`, and `new` must be the box
    /// (from [`Self::make_raw_ptr`]) that was offered to that exchange.
    unsafe fn resolve_cas(
        new: *mut Arc<T>,
        outcome: Result<*mut Arc<T>, *mut Arc<T>>,
    ) -> CasResult<T> {
        match outcome {
            // SAFETY: the exchange replaced `old` with `new`, so `old` is no longer in the slot
            // and this call is its sole owner.
            Ok(old) => CasResult::Success(unsafe { Self::owning_ref(old) }),
            Err(current) => {
                // SAFETY: `new` came from `make_raw_ptr` and was never stored, so we still own
                // it. Freeing it here avoids leaking the rejected value.
                drop(unsafe { Box::from_raw(new) });
                // `current` is still installed in (and owned by) the slot. Taking ownership of
                // it here would free a box the slot keeps using, so only share it.
                // SAFETY: same liveness caveat as `get_ref`.
                CasResult::Failure(unsafe { Self::sharing_ref(current) })
            }
        }
    }

    /// Creates a new `AtomPtr` whose slot initially holds `t`.
    pub fn new(t: T) -> Self {
        let ptr = Self::make_raw_ptr(t);
        let inner = Arc::new(AtomicPtr::from(ptr));
        Self { inner }
    }

    /// Returns a [`Ref`] to the current value. The `Ref` shares (does not own) the stored `Arc`.
    ///
    /// NOTE(review): another thread can replace the value and drop the returned owning `Ref`
    /// between the load and the `Arc::clone` below. If dropping a `Ref` frees its box, the box
    /// is then read after being freed. Closing that window needs a reclamation scheme (hazard
    /// pointers / epochs) that this type does not have.
    pub fn get_ref(&self) -> Ref<T> {
        // Acquire pairs with the release half of `swap` / `compare_exchange`. This makes the box
        // written by the publishing thread visible before we dereference it.
        let ptr = self.inner.load(Ordering::Acquire);
        // SAFETY: `ptr` was produced by `make_raw_ptr` and is installed in the slot (subject to
        // the race noted above).
        unsafe { Self::sharing_ref(ptr) }
    }

    /// Returns a mutable reference to the current value when no other `Arc` clone of it exists.
    ///
    /// Returns `None` when `Arc::get_mut` reports the stored `Arc` as shared, e.g. while a
    /// `Ref` obtained from [`Self::get_ref`] is still alive.
    ///
    /// # Safety
    ///
    /// The returned `&mut T` is derived from `&self` without any lock. While it is alive, the
    /// caller must ensure that:
    /// - no other call on this `AtomPtr` or its clones reads, replaces, or frees the current
    ///   value;
    /// - no second reference is obtained through `inplace_mut`.
    pub unsafe fn inplace_mut(&self) -> Option<&mut T> {
        let ptr = self.inner.load(Ordering::Acquire);
        // SAFETY: liveness and exclusive access to the slot box are upheld by the caller (see
        // `# Safety`). `as_mut` maps a null pointer to `None`.
        let slot = unsafe { ptr.as_mut() }?;
        Arc::get_mut(slot)
    }

    /// Stores `new` and returns a [`Ref`] that owns the previously stored value.
    pub fn swap(&self, new: T) -> Ref<T> {
        let new = Self::make_raw_ptr(new);
        // AcqRel: the release half publishes the new box, and the acquire half makes the old
        // box's contents visible before we take it over.
        let prev = self.inner.swap(new, Ordering::AcqRel);
        // SAFETY: `swap` removed `prev` from the slot, so this call is now its only owner.
        unsafe { Self::owning_ref(prev) }
    }

    /// Replaces the current value with `new` if the slot still holds the value `prev` refers to.
    ///
    /// - On success, returns `CasResult::Success` owning the previous value (as `swap` does).
    /// - On failure, `new` is dropped, and `CasResult::Failure` carries a shared `Ref` to the
    ///   value currently stored.
    pub fn compare_exchange(&self, prev: Ref<T>, new: T) -> CasResult<T> {
        let new = Self::make_raw_ptr(new);
        // `prev` itself stays alive until the end of this function, i.e. past the exchange.
        let expected: *const Arc<T> = prev.as_ptr();
        let outcome = self.inner.compare_exchange(
            expected.cast_mut(),
            new,
            Ordering::SeqCst,
            Ordering::Acquire,
        );
        // SAFETY: `outcome` comes from the exchange on `self.inner` that was offered `new`.
        unsafe { Self::resolve_cas(new, outcome) }
    }

    /// Like [`Self::compare_exchange`], but it may fail spuriously even when the slot still
    /// holds `prev`'s value.
    ///
    /// Intended for retry loops. On any failure, `new` is dropped, and the `Failure` `Ref`
    /// shares the value currently stored.
    pub fn compare_exchange_weak(&self, prev: Ref<T>, new: T) -> CasResult<T> {
        let new = Self::make_raw_ptr(new);
        // `prev` itself stays alive until the end of this function, i.e. past the exchange.
        let expected: *const Arc<T> = prev.as_ptr();
        let outcome = self.inner.compare_exchange_weak(
            expected.cast_mut(),
            new,
            Ordering::SeqCst,
            Ordering::Acquire,
        );
        // SAFETY: `outcome` comes from the exchange on `self.inner` that was offered `new`.
        unsafe { Self::resolve_cas(new, outcome) }
    }
}