use crate::sync::{RawMutex, WatchGuardMut, WatchGuardRef};
use crossbeam_utils::CachePadded;
use std::cell::UnsafeCell;
use std::mem::{self, MaybeUninit};
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::atomic::{self, AtomicUsize, Ordering};
/// Heap-allocated state shared by every clone of an `AtomicCell`.
///
/// Invariant: `val` is initialized from construction until either
/// `into_inner` moves it out or the last handle's `Drop` destroys it.
struct Inner<T> {
// The stored value. `MaybeUninit` lets `into_inner`/`Drop` move or drop it
// manually without the `Box` destructor double-dropping it later.
val: UnsafeCell<MaybeUninit<T>>,
// Lock guarding `val`: shared mode for `get`, exclusive for `get_mut`.
state: RawMutex,
// Number of live `AtomicCell` handles (starts at 1). Cache-padded so the
// counter does not false-share a cache line with the lock word.
ref_count: CachePadded<AtomicUsize>,
}
impl<T> Inner<T> {
    /// Builds fresh shared state: `value` stored and initialized, the lock
    /// unlocked, and a reference count of one for the creating handle.
    fn new(value: T) -> Self {
        let val = UnsafeCell::new(MaybeUninit::new(value));
        let ref_count = CachePadded::new(AtomicUsize::new(1));
        Self {
            val,
            state: RawMutex::new(),
            ref_count,
        }
    }
}
// SAFETY: `AtomicCell<T>` behaves like `Arc<RwLock<T>>`: all clones share one
// allocation, any handle on any thread can obtain `&T` (`get`), `&mut T`
// (`get_mut`), or move the value out (`swap`/`store`), and the last handle
// drops `T` on whatever thread it lives on. Both sending and sharing a handle
// therefore require `T: Send` (value may be dropped/moved on another thread)
// AND `T: Sync` (clones on two threads can hold shared guards to the same
// value at once). The previous `T: Send`-only / `T: Sync`-only bounds were
// unsound: e.g. a `Send + !Sync` value could be read concurrently from two
// threads through cloned handles.
unsafe impl<T: Send + Sync> Send for AtomicCell<T> {}
unsafe impl<T: Send + Sync> Sync for AtomicCell<T> {}
// Like `std::sync::Mutex`, the cell introduces no unwind hazard of its own: a
// panic while a guard is live can at worst leave the value in a logically
// stale (but still initialized and droppable) state, which unwind safety
// explicitly permits.
impl<T> UnwindSafe for AtomicCell<T> {}
impl<T> RefUnwindSafe for AtomicCell<T> {}
/// A clonable handle to a lock-guarded value: `Arc`-style reference counting
/// around a `RawMutex`-protected slot. Cloning is cheap (pointer copy plus a
/// counter increment); the value and allocation are freed by the last handle.
#[repr(transparent)]
pub struct AtomicCell<T> {
// Pointer to the shared allocation, produced by `Box::into_raw` in `new`.
// Never null; valid until the last handle is dropped or `into_inner` runs.
// `repr(transparent)` keeps the handle exactly one pointer wide.
ptr: *const Inner<T>,
}
impl<T> AtomicCell<T> {
    /// Creates a new cell owning `val`, with a reference count of one.
    pub fn new(val: T) -> Self {
        let inner = Box::new(Inner::new(val));
        let ptr = Box::into_raw(inner);
        Self { ptr }
    }

    /// Borrows the shared state.
    #[inline(always)]
    fn inner(&self) -> &Inner<T> {
        // SAFETY: `ptr` came from `Box::into_raw` in `new` and stays valid
        // while any handle (in particular `self`) is alive.
        unsafe { &*self.ptr }
    }

    /// Acquires the lock in shared mode and returns a read guard.
    ///
    /// Multiple shared guards may coexist; blocks while an exclusive guard
    /// is held. The lock is released when the guard is dropped.
    pub fn get(&self) -> WatchGuardRef<'_, T> {
        let lock = &self.inner().state;
        lock.lock_shared();
        // SAFETY: the shared lock is held, so no `&mut` access exists, and
        // the value is initialized for the allocation's whole lifetime.
        let val = unsafe { (&*self.inner().val.get()).assume_init_ref() };
        WatchGuardRef::new(val, lock)
    }

    /// Runs `f` with shared access to the value; the lock is released when
    /// `f` returns.
    pub fn with<R>(&self, f: impl FnOnce(&T) -> R) -> R {
        let guard = self.get();
        f(&guard)
    }

    /// Acquires the lock exclusively and returns a write guard.
    ///
    /// Blocks while any other guard (shared or exclusive) is held.
    pub fn get_mut(&self) -> WatchGuardMut<'_, T> {
        let lock = &self.inner().state;
        lock.lock_exclusive();
        // SAFETY: the exclusive lock is held, so this is the only access to
        // the (initialized) value.
        let val = unsafe { (&mut *self.inner().val.get()).assume_init_mut() };
        WatchGuardMut::new(val, lock)
    }

    /// Runs `f` with exclusive access to the value; the lock is released
    /// when `f` returns.
    pub fn with_mut<R>(&self, f: impl FnOnce(&mut T) -> R) -> R {
        let mut guard = self.get_mut();
        f(&mut guard)
    }

    /// Returns a raw pointer to the stored value. No lock is taken; the
    /// caller must not race with live guards or concurrent cell operations.
    #[inline]
    pub fn as_ptr(&self) -> *mut T {
        self.inner().val.get().cast::<T>()
    }

    /// Consumes the sole handle and returns the inner value, freeing the
    /// allocation.
    ///
    /// # Panics
    /// Panics if other clones of this cell are still alive.
    pub fn into_inner(self) -> T {
        let ref_count = self.inner().ref_count.load(Ordering::Acquire);
        assert_eq!(
            ref_count, 1,
            "cannot call into_inner with multiple references"
        );
        let ptr = self.ptr as *mut Inner<T>;
        // Skip our `Drop` impl: it would drop the value and free the
        // allocation, both of which we take over manually below.
        mem::forget(self);
        // SAFETY: the count was 1 and `self` is consumed, so no other handle
        // exists; `ptr` came from `Box::into_raw`, so reclaiming the box is
        // sound. Reading the value out of the `MaybeUninit` cannot
        // double-drop: dropping `Inner` leaves `MaybeUninit` contents alone.
        let boxed = unsafe { Box::from_raw(ptr) };
        unsafe { boxed.val.get().read().assume_init() }
    }

    /// Applies `f` to the value under the exclusive lock.
    ///
    /// The closure is invoked exactly once, so the bound is now `FnOnce`
    /// (relaxed from `FnMut`); every `FnMut` closure callers passed before
    /// still satisfies it.
    pub fn update<F>(&self, f: F)
    where
        F: FnOnce(&mut T),
    {
        let mut guard = self.get_mut();
        f(&mut *guard);
    }

    /// Replaces the value with `val`, returning the previous value.
    pub fn swap(&self, val: T) -> T {
        let mut guard = self.get_mut();
        mem::replace(&mut *guard, val)
    }

    /// Overwrites the value with `val`, dropping the previous value in place.
    pub fn store(&self, val: T) {
        let mut guard = self.get_mut();
        *guard = val;
    }
}
// Bound relaxed from `Copy + Eq` to `PartialEq`: the body only needs `==`
// (`mem::replace` requires neither copying nor total equality), so the method
// now also works for non-`Copy` types. Every previous caller still compiles.
impl<T: PartialEq> AtomicCell<T> {
    /// If the stored value equals `current`, replaces it with `new` and
    /// returns the displaced value. Otherwise returns the still-held write
    /// guard so the caller can inspect or modify the actual value without
    /// re-acquiring the lock.
    pub fn compare_exchange(&self, current: T, new: T) -> Result<T, WatchGuardMut<'_, T>> {
        let mut guard = self.get_mut();
        if *guard == current {
            // Swap under the exclusive lock and hand back the old value.
            let old = mem::replace(&mut *guard, new);
            Ok(old)
        } else {
            Err(guard)
        }
    }
}
impl<T> Clone for AtomicCell<T> {
    /// Creates another handle to the same allocation.
    fn clone(&self) -> Self {
        // Relaxed suffices for the increment (as in `Arc`): the caller
        // already holds a reference, so no ordering with other memory
        // operations is needed here.
        let old = self.inner().ref_count.fetch_add(1, Ordering::Relaxed);
        // Guard against counter overflow (reachable by `mem::forget`-ing
        // clones in a loop), which would otherwise wrap the count and lead
        // to a premature free. Mirrors `Arc`'s abort-on-overflow policy.
        if old > isize::MAX as usize {
            std::process::abort();
        }
        Self { ptr: self.ptr }
    }
}
impl<T> Drop for AtomicCell<T> {
    /// Releases this handle; the last handle also destroys the value and
    /// frees the shared allocation, using the `Arc`-style
    /// release-decrement / acquire-fence protocol.
    fn drop(&mut self) {
        // SAFETY: `ptr` is valid while this handle exists.
        let previous = unsafe { (*self.ptr).ref_count.fetch_sub(1, Ordering::Release) };
        if previous != 1 {
            return;
        }
        // Synchronize with every other handle's Release decrement so all
        // prior accesses to the value happen-before its destruction.
        atomic::fence(Ordering::Acquire);
        // SAFETY: we were the last handle, so nothing else can touch the
        // allocation. Drop the value in place first (`MaybeUninit` will not
        // drop it again), then reclaim the box.
        unsafe {
            let inner = self.ptr as *mut Inner<T>;
            (*(*inner).val.get()).assume_init_drop();
            drop(Box::from_raw(inner));
        }
    }
}