use crate::{
rng::{CryptoRng, Rng},
sync::{CachePadded, SeqLock},
};
use core::{
cell::UnsafeCell,
fmt::{Debug, Formatter},
mem::{self, ManuallyDrop},
panic::{RefUnwindSafe, UnwindSafe},
ptr,
};
/// Number of seqlock shards in the global lock table used by [`lock`].
/// 67 is prime, so `addr % LEN` does not systematically collide for
/// commonly-aligned addresses — NOTE(review): exact value looks heuristic.
const LEN: usize = 67;
/// A mutable memory location offering atomic access to a `T` of any size.
///
/// Unlike hardware atomics, every operation here is serialized through an
/// address-sharded seqlock (see [`lock`]), so `T` needs no `Atomic*` support.
pub struct AtomicCell<T> {
    // All access goes through raw pointers obtained from this cell, guarded
    // by the seqlock selected from the cell's address.
    value: UnsafeCell<T>,
}
impl<T> AtomicCell<T> {
    /// Creates a new cell initialized with `value`.
    #[inline]
    pub const fn new(value: T) -> AtomicCell<T> {
        AtomicCell { value: UnsafeCell::new(value) }
    }

    /// Consumes the cell, returning the contained value.
    #[inline]
    pub fn into_inner(self) -> T {
        // Wrap in `ManuallyDrop` so `Drop for AtomicCell` does not run and
        // drop the value a second time after it is moved out below.
        let this = ManuallyDrop::new(self);
        // SAFETY: `self` is owned, so no other thread can reach the cell;
        // the value is initialized and will not be dropped again (see above).
        unsafe { this.as_ptr().read() }
    }

    /// Stores `value` into the cell, dropping the previous value.
    #[inline]
    pub fn store(&self, value: T) {
        if const { mem::needs_drop::<T>() } {
            // Route through `swap` so the old value's destructor runs here,
            // after the lock guard inside `swap` has been released — arbitrary
            // `Drop` code must not run while the seqlock is held.
            drop(self.swap(value));
        } else {
            // No drop glue: overwrite in place under the write lock.
            let dst = self.as_ptr();
            let _guard = lock(dst.addr()).write();
            unsafe {
                ptr::write(dst, value);
            }
        }
    }

    /// Replaces the contained value with `value` and returns the old value.
    #[inline]
    pub fn swap(&self, value: T) -> T {
        let dst = self.value.get();
        // Write guard: excludes other writers and invalidates concurrent
        // optimistic readers for the duration of the replace.
        let _guard = lock(dst.addr()).write();
        unsafe { ptr::replace(dst, value) }
    }

    /// Returns a raw pointer to the underlying value.
    const fn as_ptr(&self) -> *mut T {
        self.value.get()
    }
}
impl<T> AtomicCell<T>
where
    T: Copy,
{
    /// Loads a copy of the contained value.
    #[inline]
    pub fn load(&self) -> T {
        let src = self.as_ptr();
        let lock = lock(src.addr());
        // Fast path: optimistic read. Snapshot the lock's stamp, copy the
        // bytes, then confirm no writer intervened.
        if let Some(stamp) = lock.optimistic_read() {
            // Volatile read so the compiler cannot elide or reorder this
            // possibly-racing copy; a torn value is thrown away when
            // `validate_read` fails below, and is never otherwise observed.
            let value = unsafe { ptr::read_volatile(src) };
            if lock.validate_read(stamp) {
                return value;
            }
        }
        // Slow path: take the write lock so the read cannot tear.
        let guard = lock.write();
        let value = unsafe { ptr::read(src) };
        // NOTE(review): `abort` presumably releases the guard without
        // publishing a version bump, since nothing was written — confirm
        // against `SeqLock`'s contract.
        guard.abort();
        value
    }
}
impl<T> AtomicCell<T>
where
    T: Copy + Eq,
{
    /// If the current value equals `curr`, replaces it with `new`.
    ///
    /// Returns `Ok` with the previous value on success, or `Err` with the
    /// actual current value when the comparison fails.
    #[inline]
    pub fn compare_exchange(&self, curr: T, new: T) -> Result<T, T> {
        let dest = self.as_ptr();
        let guard = lock(dest.addr()).write();
        // SAFETY: the write guard gives us exclusive access to the slot.
        let actual = unsafe { ptr::read(dest) };
        if actual == curr {
            // SAFETY: still holding the write guard.
            unsafe { ptr::write(dest, new) };
            // Dropping the guard normally publishes the modification.
            Ok(actual)
        } else {
            // Nothing was written; release the guard without signalling a
            // modification to optimistic readers.
            guard.abort();
            Err(actual)
        }
    }

    /// Applies `cb` to the current value and stores the result, retrying
    /// (and re-invoking `cb`) whenever another writer races in between.
    /// Returns the value that was ultimately replaced.
    #[inline]
    pub fn update(&self, mut cb: impl FnMut(T) -> T) -> T {
        let mut current = self.load();
        loop {
            let next = cb(current);
            match self.compare_exchange(current, next) {
                Ok(previous) => break previous,
                Err(actual) => current = actual,
            }
        }
    }

    /// Like [`AtomicCell::update`], except `cb` may decline by returning
    /// `None`; the last observed value is then returned as `Err`.
    #[inline]
    pub fn try_update(&self, mut cb: impl FnMut(T) -> Option<T>) -> Result<T, T> {
        let mut current = self.load();
        loop {
            let Some(next) = cb(current) else {
                return Err(current);
            };
            match self.compare_exchange(current, next) {
                Ok(previous) => return Ok(previous),
                Err(actual) => current = actual,
            }
        }
    }
}
// Marker impls: wrapping a cryptographically secure generator in a cell only
// forwards its output, so the cell (and shared references to it) remain
// `CryptoRng` — presumably a marker trait, given the empty bodies.
impl<T> CryptoRng for AtomicCell<T> where T: Copy + Eq + CryptoRng {}
impl<T> CryptoRng for &AtomicCell<T> where T: Copy + Eq + CryptoRng {}
impl<T> Rng for AtomicCell<T>
where
    T: Copy + Eq + Rng,
{
    /// Forwards to the `Rng` implementation on `&AtomicCell<T>`.
    #[inline]
    fn u8_4(&mut self) -> [u8; 4] {
        let mut shared = &*self;
        shared.u8_4()
    }

    /// Forwards to the `Rng` implementation on `&AtomicCell<T>`.
    #[inline]
    fn u8_8(&mut self) -> [u8; 8] {
        let mut shared = &*self;
        shared.u8_8()
    }

    /// Forwards to the `Rng` implementation on `&AtomicCell<T>`.
    #[inline]
    fn u8_16(&mut self) -> [u8; 16] {
        let mut shared = &*self;
        shared.u8_16()
    }

    /// Forwards to the `Rng` implementation on `&AtomicCell<T>`.
    #[inline]
    fn u8_32(&mut self) -> [u8; 32] {
        let mut shared = &*self;
        shared.u8_32()
    }
}
impl<T> Rng for &AtomicCell<T>
where
    T: Copy + Eq + Rng,
{
    /// Draws 4 bytes by advancing the inner generator inside `update`.
    ///
    /// If the CAS inside `update` loses a race, the closure re-runs, so the
    /// returned bytes always correspond to the state that was stored.
    #[inline]
    fn u8_4(&mut self) -> [u8; 4] {
        let mut bytes = [0; 4];
        let _replaced = self.update(|mut state| {
            bytes = state.u8_4();
            state
        });
        bytes
    }

    /// Draws 8 bytes by advancing the inner generator inside `update`.
    #[inline]
    fn u8_8(&mut self) -> [u8; 8] {
        let mut bytes = [0; 8];
        let _replaced = self.update(|mut state| {
            bytes = state.u8_8();
            state
        });
        bytes
    }

    /// Draws 16 bytes by advancing the inner generator inside `update`.
    #[inline]
    fn u8_16(&mut self) -> [u8; 16] {
        let mut bytes = [0; 16];
        let _replaced = self.update(|mut state| {
            bytes = state.u8_16();
            state
        });
        bytes
    }

    /// Draws 32 bytes by advancing the inner generator inside `update`.
    #[inline]
    fn u8_32(&mut self) -> [u8; 32] {
        let mut bytes = [0; 32];
        let _replaced = self.update(|mut state| {
            bytes = state.u8_32();
            state
        });
        bytes
    }
}
impl<T> Debug for AtomicCell<T>
where
    T: Copy + Debug,
{
    /// Formats as `AtomicCell { value: .. }` from a `load`ed snapshot.
    #[inline]
    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
        let snapshot = self.load();
        let mut builder = f.debug_struct("AtomicCell");
        builder.field("value", &snapshot);
        builder.finish()
    }
}
impl<T> Default for AtomicCell<T>
where
T: Default,
{
#[inline]
fn default() -> AtomicCell<T> {
AtomicCell::new(T::default())
}
}
impl<T> Drop for AtomicCell<T> {
#[inline]
fn drop(&mut self) {
if mem::needs_drop::<T>() {
unsafe {
self.as_ptr().drop_in_place();
}
}
}
}
impl<T> From<T> for AtomicCell<T> {
#[inline]
fn from(value: T) -> AtomicCell<T> {
AtomicCell::new(value)
}
}
// NOTE(review): blanket unwind-safety (no bound on `T`) mirrors the policy
// of seqlock-based cells elsewhere in the ecosystem — confirm this is the
// intended contract rather than requiring `T: RefUnwindSafe`.
impl<T> RefUnwindSafe for AtomicCell<T> {}
// SAFETY: moving the cell moves the owned `T`, so `T: Send` suffices.
unsafe impl<T> Send for AtomicCell<T> where T: Send {}
// SAFETY: shared access never hands out `&T`; every operation copies or
// moves values of `T` under the seqlock, so `T: Send` (not `T: Sync`) is
// the right bound.
unsafe impl<T: Send> Sync for AtomicCell<T> {}
impl<T> UnwindSafe for AtomicCell<T> where T: Send {}
#[expect(clippy::indexing_slicing, reason = "modulo result will always be in-bounds")]
fn lock(addr: usize) -> &'static SeqLock {
static LOCKS: [CachePadded<SeqLock>; LEN] = [const { CachePadded(SeqLock::new()) }; LEN];
&LOCKS[addr % LEN].0
}
#[cfg(feature = "rand_core")]
mod rand_core {
    use crate::{rng::Rng, sync::AtomicCell};
    use core::convert::Infallible;

    // The underlying `Rng` drives these adapters, so both the cell and a
    // shared reference to it satisfy the fallible-crypto marker.
    impl<T> rand_core::TryCryptoRng for AtomicCell<T> where T: Copy + Eq + Rng {}
    impl<T> rand_core::TryCryptoRng for &AtomicCell<T> where T: Copy + Eq + Rng {}

    impl<T> rand_core::TryRng for AtomicCell<T>
    where
        T: Copy + Eq + Rng,
    {
        type Error = Infallible;

        /// Draws a little-endian `u32` from 4 generated bytes; never fails.
        #[inline(always)]
        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            let bytes = self.u8_4();
            Ok(u32::from_le_bytes(bytes))
        }

        /// Draws a little-endian `u64` from 8 generated bytes; never fails.
        #[inline(always)]
        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            let bytes = self.u8_8();
            Ok(u64::from_le_bytes(bytes))
        }

        /// Fills `dst` with generated bytes; never fails.
        #[inline(always)]
        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
            self.fill_slice(dst);
            Ok(())
        }
    }

    impl<T> rand_core::TryRng for &AtomicCell<T>
    where
        T: Copy + Eq + Rng,
    {
        type Error = Infallible;

        /// Draws a little-endian `u32` from 4 generated bytes; never fails.
        #[inline(always)]
        fn try_next_u32(&mut self) -> Result<u32, Self::Error> {
            let bytes = self.u8_4();
            Ok(u32::from_le_bytes(bytes))
        }

        /// Draws a little-endian `u64` from 8 generated bytes; never fails.
        #[inline(always)]
        fn try_next_u64(&mut self) -> Result<u64, Self::Error> {
            let bytes = self.u8_8();
            Ok(u64::from_le_bytes(bytes))
        }

        /// Fills `dst` with generated bytes; never fails.
        #[inline(always)]
        fn try_fill_bytes(&mut self, dst: &mut [u8]) -> Result<(), Self::Error> {
            self.fill_slice(dst);
            Ok(())
        }
    }
}
#[cfg(feature = "serde")]
mod serde {
    use crate::sync::AtomicCell;
    use serde::{Serialize, Serializer};

    impl<T> Serialize for AtomicCell<T>
    where
        T: Copy + Serialize,
    {
        /// Serializes a `load`ed snapshot of the current value.
        #[inline]
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let snapshot = self.load();
            snapshot.serialize(serializer)
        }
    }
}