#![no_std]
#[cfg(test)]
extern crate alloc;
pub use bytemuck::{Pod, Zeroable};
use core::cell::UnsafeCell;
use core::mem::{align_of, size_of};
use core::sync::atomic::*;
#[repr(transparent)]
pub struct TearCell<T>(UnsafeCell<T>);
unsafe impl<T: Pod + Send> Sync for TearCell<T> {}
unsafe impl<T: Pod + Send> Send for TearCell<T> {}
impl<T> TearCell<T> {
#[inline]
pub const fn new(v: T) -> Self {
Self(UnsafeCell::new(v))
}
#[inline]
pub fn into_inner(self) -> T {
self.0.into_inner()
}
}
/// Reports whether a `T` can be accessed as a whole number of `U` atomics.
///
/// This requires that `T` is at least as aligned as `U`, and that `T`'s size
/// is a non-zero multiple of `U`'s size.
#[inline]
fn tearcell_can_use_atom<T, U: Atom>() -> bool {
    let value_size = size_of::<T>();
    let atom_size = size_of::<U>();
    let aligned_enough = align_of::<T>() >= align_of::<U>();
    aligned_enough && value_size >= atom_size && value_size % atom_size == 0
}
impl<T: Pod> TearCell<T> {
#[inline]
pub fn store(&self, value: T) {
self.store_ref(&value)
}
#[inline]
pub fn load(&self) -> T {
if size_of::<T>() == 0 {
T::zeroed()
} else if tearcell_can_use_atom::<T, AtomicU64>() {
self.do_load::<AtomicU64>()
} else if tearcell_can_use_atom::<T, AtomicU32>() {
self.do_load::<AtomicU32>()
} else if tearcell_can_use_atom::<T, AtomicU16>() {
self.do_load::<AtomicU16>()
} else if tearcell_can_use_atom::<T, AtomicU8>() {
self.do_load::<AtomicU8>()
} else {
unreachable!();
}
}
#[inline]
pub fn store_ref(&self, value: &T) {
if size_of::<T>() == 0 {
return;
} else if tearcell_can_use_atom::<T, AtomicU64>() {
self.do_store::<AtomicU64>(value)
} else if tearcell_can_use_atom::<T, AtomicU32>() {
self.do_store::<AtomicU32>(value)
} else if tearcell_can_use_atom::<T, AtomicU16>() {
self.do_store::<AtomicU16>(value)
} else if tearcell_can_use_atom::<T, AtomicU8>() {
self.do_store::<AtomicU8>(value)
} else {
unreachable!();
}
}
#[inline]
fn atom_slice<A: Atom>(&self) -> &[A] {
let size = size_of::<T>() / size_of::<A>();
assert!(size != 0);
assert!(size * size_of::<A>() == size_of::<T>());
unsafe { core::slice::from_raw_parts(self.0.get() as *const A, size) }
}
#[inline]
fn do_load<A: Atom>(&self) -> T {
let mut result = T::zeroed();
let src: &[A] = self.atom_slice();
let dst: &mut [A::Prim] =
bytemuck::try_cast_slice_mut(core::slice::from_mut(&mut result)).unwrap();
assert_eq!(src.len(), dst.len());
for (db, sb) in dst.iter_mut().zip(src.iter()) {
*db = sb.get();
}
result
}
#[inline]
fn do_store<A: Atom>(&self, v: &T) {
let src: &[A::Prim] = bytemuck::try_cast_slice(core::slice::from_ref(v)).unwrap();
let dst: &[A] = self.atom_slice();
assert_eq!(src.len(), dst.len());
for (d, s) in dst.iter().zip(src.iter()) {
d.set(*s);
}
}
}
/// An atomic integer type that `TearCell` uses to read or write one chunk of
/// its value.
///
/// # Safety
///
/// `TearCell::atom_slice` reinterprets the cell's bytes as `[Self]`, and
/// `do_load`/`do_store` pair those elements one-to-one with
/// `[Self::Prim]`. Implementors must therefore guarantee all of the following:
///
/// - Every initialized bit pattern of `size_of::<Self>()` bytes is a valid
///   `Self`.
/// - `Self` has the same size as `Self::Prim`.
/// - `Self` is at least as aligned as `Self::Prim`.
/// - `get` and `set` perform atomic reads and writes of the whole
///   `Self::Prim` through `&self`, so concurrent calls are free of data races.
unsafe trait Atom: Sync + Send + 'static + Sized {
    /// The plain integer type this atomic loads and stores.
    type Prim: Pod;
    /// Atomically reads the current value.
    fn get(&self) -> Self::Prim;
    /// Atomically overwrites the current value with `p`.
    fn set(&self, p: Self::Prim);
}
// Implements `Atom` for the std atomic `$Atom` whose plain counterpart is
// `$prim`, using `Relaxed` ordering for every access. `TearCell` only needs
// each chunk to be read or written indivisibly, and it does not rely on these
// accesses to order any other memory operations.
macro_rules! def_atom {
    ($Atom:ty, $prim:ty) => {
        // SAFETY: `$Atom` is the std atomic counterpart of `$prim`: it has the
        // same size, is at least as aligned, accepts every bit pattern, and its
        // `load`/`store` are atomic through `&self`.
        unsafe impl Atom for $Atom {
            type Prim = $prim;

            #[inline]
            fn get(&self) -> $prim {
                <$Atom>::load(self, Ordering::Relaxed)
            }

            #[inline]
            fn set(&self, v: $prim) {
                <$Atom>::store(self, v, Ordering::Relaxed)
            }
        }
    };
}

// Every chunk width that `TearCell::load`/`store_ref` can select.
def_atom!(AtomicU64, u64);
def_atom!(AtomicU32, u32);
def_atom!(AtomicU16, u16);
def_atom!(AtomicU8, u8);
#[cfg(test)]
mod test {
    use super::*;
    use core::convert::*;

    /// Zero-sized payloads of several alignments, plus a single byte.
    #[test]
    fn test_basic() {
        let v: TearCell<[u8; 0]> = TearCell::new([]);
        assert_eq!(v.load(), []);
        v.store([]);
        v.store_ref(&[]);
        let v: TearCell<[usize; 0]> = TearCell::new([]);
        assert_eq!(v.load(), []);
        v.store([]);
        v.store_ref(&[]);
        // A third, highly aligned zero-sized type.
        let v: TearCell<[u128; 0]> = TearCell::new([]);
        assert_eq!(v.load(), []);
        v.store([]);
        v.store_ref(&[]);
        let v: TearCell<[u8; 1]> = TearCell::new([0u8; 1]);
        assert_eq!(v.load(), [0]);
        v.store([1]);
        assert_eq!(v.load(), [1]);
        v.store_ref(&[2]);
        assert_eq!(v.load(), [2]);
    }

    // Runs `do_test_arr` for arrays of each element width at every given length.
    macro_rules! test_arr {
        ($($n:expr),+) => {$(
            do_test_arr::<u8, [u8; $n]>();
            do_test_arr::<u16, [u16; $n]>();
            do_test_arr::<u32, [u32; $n]>();
            do_test_arr::<u64, [u64; $n]>();
            do_test_arr::<usize, [usize; $n]>();
            do_test_arr::<u128, [u128; $n]>();
        )+};
    }

    #[test]
    fn test_many() {
        test_arr![2, 3, 4, 5, 6, 7, 8, 9, 10];
    }

    /// A byte array whose alignment exceeds its element size, so wider atoms apply.
    #[test]
    fn test_overaligned() {
        #[derive(Copy, Clone, PartialEq, Default, Debug)]
        #[repr(C, align(16))]
        struct Overaligned([u8; 16]);
        unsafe impl bytemuck::Zeroable for Overaligned {}
        unsafe impl bytemuck::Pod for Overaligned {}
        impl AsRef<[u8]> for Overaligned {
            fn as_ref(&self) -> &[u8] {
                &self.0
            }
        }
        do_test::<u8, Overaligned>(16, move |s: &[u8]| Overaligned(s.try_into().unwrap()));
    }

    // `do_test` for a fixed-size array type `Arr` of `T` elements.
    fn do_test_arr<T, Arr>()
    where
        T: Copy + From<u8> + Default + core::ops::Not<Output = T> + PartialEq + core::fmt::Debug,
        Arr: Pod + AsRef<[T]> + PartialEq + core::fmt::Debug,
        for<'a> Arr: TryFrom<&'a [T]>,
    {
        do_test(
            core::mem::size_of::<Arr>() / core::mem::size_of::<T>(),
            move |v: &[T]| -> Arr { v.try_into().ok().unwrap() },
        )
    }

    // Round-trips several distinct patterns through `new`, `store`, `store_ref`
    // and `load` on two cells. `make` builds an `Arr` from `size` elements.
    fn do_test<T, Arr>(size: usize, make: fn(&[T]) -> Arr)
    where
        T: Copy + From<u8> + Default + core::ops::Not<Output = T> + PartialEq + core::fmt::Debug,
        Arr: Pod + AsRef<[T]> + PartialEq + core::fmt::Debug,
    {
        let mut v0 = alloc::vec![Default::default(); size];
        let a0: Arr = make(&v0);
        let tc0: TearCell<Arr> = TearCell::new(a0);
        assert_eq!(tc0.load(), a0);
        for (i, slot) in v0.iter_mut().enumerate() {
            // Sizes used here are far below 255, so the cast cannot wrap.
            *slot = ((i + 1) as u8).into();
        }
        let a0: Arr = make(&v0);
        let tc1 = TearCell::new(a0);
        assert_eq!(&v0[..], tc1.load().as_ref());
        tc0.store(a0);
        assert_eq!(&v0[..], tc0.load().as_ref());
        v0.reverse();
        let ar0: Arr = make(&v0);
        tc0.store_ref(&ar0);
        assert_eq!(&v0[..], tc0.load().as_ref());
        for x in v0.iter_mut() {
            *x = !*x;
        }
        let a0: Arr = make(&v0);
        tc0.store(a0);
        assert_eq!(&v0[..], tc0.load().as_ref());
        tc1.store_ref(&a0);
        // Check the cell that was just written, not `tc0` again.
        assert_eq!(&v0[..], tc1.load().as_ref());
    }
}