use std::any::TypeId;
use std::cell::UnsafeCell;
use std::collections::HashSet;
use std::marker::PhantomData;
use super::Invariant;
std::thread_local! {
    // Per-thread registry of the marker types (keyed by `TypeId`) that currently
    // have a live `TLCellOwner` on this thread. `TLCellOwner::try_new` inserts the
    // marker's `TypeId` and refuses to create an owner if it was already present;
    // the owner's `Drop` impl removes it again.
    static SINGLETON_CHECK: std::cell::RefCell<HashSet<TypeId>> = Default::default();
}
/// Marker type whose raw-pointer field is neither `Send` nor `Sync`.
///
/// Embedding `PhantomData<NotSendOrSync>` in a struct therefore strips both
/// auto traits from that struct without storing an actual pointer.
struct NotSendOrSync(
    // Never read: the field exists only to suppress the `Send`/`Sync` auto traits.
    #[allow(dead_code)] *const (),
);
/// Per-thread owner that grants access to every `TLCell<Q, _>` sharing the
/// marker type `Q`.
///
/// At most one `TLCellOwner<Q>` may exist on a given thread at a time; creation
/// goes through `new` / `try_new`, which check a thread-local registry. The
/// owner is neither `Send` nor `Sync`, so it can never leave its thread.
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct TLCellOwner<Q: 'static> {
    // Ties the owner to marker type `Q` through the parent module's `Invariant`
    // wrapper (presumably making the owner invariant in `Q`).
    typ: PhantomData<Invariant<Q>>,
    // Zero-sized; its only purpose is to remove the `Send`/`Sync` auto traits.
    not_send_or_sync: PhantomData<NotSendOrSync>,
}
impl<Q: 'static> Drop for TLCellOwner<Q> {
    /// Releases this thread's claim on marker type `Q`, so a new
    /// `TLCellOwner<Q>` can be created on this thread afterwards.
    fn drop(&mut self) {
        // `try_with` instead of `with`: if this owner is itself stored in a
        // thread-local that is destroyed after `SINGLETON_CHECK` during thread
        // teardown, `with` would panic inside a destructor (a panic escaping a
        // TLS destructor aborts the process). Once the registry is gone there is
        // nothing left to release, so the access error is deliberately ignored.
        let _ = SINGLETON_CHECK.try_with(|set| {
            set.borrow_mut().remove(&TypeId::of::<Q>());
        });
    }
}
impl<Q: 'static> Default for TLCellOwner<Q> {
fn default() -> Self {
TLCellOwner::new()
}
}
impl<Q: 'static> TLCellOwner<Q> {
    /// Creates the owner for marker type `Q` on the current thread.
    ///
    /// # Panics
    ///
    /// Panics if a `TLCellOwner<Q>` already exists on this thread. Use
    /// [`TLCellOwner::try_new`] for a non-panicking alternative.
    #[track_caller]
    pub fn new() -> Self {
        match Self::try_new() {
            Some(owner) => owner,
            None => panic!("Illegal to create two TLCellOwner instances within the same thread with the same marker type parameter"),
        }
    }

    /// Tries to create the owner for marker type `Q` on the current thread.
    ///
    /// Returns `None` if a `TLCellOwner<Q>` already exists on this thread.
    pub fn try_new() -> Option<Self> {
        // `insert` returns `false` when `Q` is already registered on this thread.
        if SINGLETON_CHECK.with(|set| set.borrow_mut().insert(TypeId::of::<Q>())) {
            Some(Self {
                not_send_or_sync: PhantomData,
                typ: PhantomData,
            })
        } else {
            None
        }
    }

    /// Creates a new cell for marker type `Q` holding `value`.
    pub fn cell<T>(&self, value: T) -> TLCell<Q, T> {
        TLCell::<Q, T>::new(value)
    }

    /// Borrows the contents of `tc` immutably for as long as `self` is borrowed.
    #[inline]
    pub fn ro<'a, T: ?Sized>(&'a self, tc: &'a TLCell<Q, T>) -> &'a T {
        // SAFETY: only one `TLCellOwner<Q>` exists per thread and it cannot leave
        // its thread. `TLCell` wraps an `UnsafeCell`, so it is not `Sync`, and `tc`
        // is only reachable from this thread. Holding `&'a self` rules out any
        // `rw*` call (those need `&mut self`) for `'a`, so no `&mut T` can coexist
        // with the returned `&T`.
        unsafe { &*tc.value.get() }
    }

    /// Borrows the contents of `tc` mutably for as long as `self` is borrowed.
    #[inline]
    pub fn rw<'a, T: ?Sized>(&'a mut self, tc: &'a TLCell<Q, T>) -> &'a mut T {
        // SAFETY: `&'a mut self` is the unique handle on this thread's owner for
        // `Q`, so no other `ro`/`rw*` borrow of any `TLCell<Q, _>` can be live
        // during `'a`. The cell is not `Sync`, so no other thread can reach it.
        unsafe { &mut *tc.value.get() }
    }

    /// Borrows two distinct cells mutably at the same time.
    ///
    /// # Panics
    ///
    /// Panics if `tc1` and `tc2` are the same cell (same address).
    #[inline]
    #[track_caller]
    pub fn rw2<'a, T: ?Sized, U: ?Sized>(
        &'a mut self,
        tc1: &'a TLCell<Q, T>,
        tc2: &'a TLCell<Q, U>,
    ) -> (&'a mut T, &'a mut U) {
        // Compare thin addresses only; any pointer metadata of unsized `T`/`U` is dropped.
        assert!(
            !core::ptr::eq(tc1 as *const _ as *const (), tc2 as *const _ as *const ()),
            "Illegal to borrow same TLCell twice with rw2()"
        );
        // SAFETY: as in `rw`, `&'a mut self` excludes every other borrow through
        // this owner. The assertion above guarantees the two cells are distinct,
        // so the two `&mut` references do not alias.
        unsafe { (&mut *tc1.value.get(), &mut *tc2.value.get()) }
    }

    /// Borrows three distinct cells mutably at the same time.
    ///
    /// # Panics
    ///
    /// Panics if any two of `tc1`, `tc2`, `tc3` are the same cell (same address).
    #[inline]
    #[track_caller]
    pub fn rw3<'a, T: ?Sized, U: ?Sized, V: ?Sized>(
        &'a mut self,
        tc1: &'a TLCell<Q, T>,
        tc2: &'a TLCell<Q, U>,
        tc3: &'a TLCell<Q, V>,
    ) -> (&'a mut T, &'a mut U, &'a mut V) {
        // All three pairs must differ; compare thin addresses only.
        assert!(
            !core::ptr::eq(tc1 as *const _ as *const (), tc2 as *const _ as *const ())
                && !core::ptr::eq(tc2 as *const _ as *const (), tc3 as *const _ as *const ())
                && !core::ptr::eq(tc3 as *const _ as *const (), tc1 as *const _ as *const ()),
            "Illegal to borrow same TLCell twice with rw3()"
        );
        // SAFETY: as in `rw`, `&'a mut self` excludes every other borrow through
        // this owner. The assertion above guarantees the three cells are pairwise
        // distinct, so none of the returned `&mut` references alias.
        unsafe {
            (
                &mut *tc1.value.get(),
                &mut *tc2.value.get(),
                &mut *tc3.value.get(),
            )
        }
    }
}
/// Cell whose contents can only be accessed through the current thread's
/// `TLCellOwner<Q>` for the same marker type `Q`.
///
/// `#[repr(transparent)]`: the only non-zero-sized field is the `UnsafeCell<T>`,
/// so the cell has the same layout as `UnsafeCell<T>`. `T` may be unsized
/// (e.g. `Box<TLCell<Q, dyn Trait>>`), which is why `value` must stay the last field.
#[repr(transparent)]
#[cfg_attr(docsrs, doc(cfg(feature = "std")))]
pub struct TLCell<Q, T: ?Sized> {
    // Zero-sized link to the marker type `Q`.
    owner: PhantomData<Invariant<Q>>,
    // The wrapped value; access is mediated by `TLCellOwner<Q>`.
    value: UnsafeCell<T>,
}
impl<Q, T> TLCell<Q, T> {
#[inline]
pub const fn new(value: T) -> TLCell<Q, T> {
TLCell {
owner: PhantomData,
value: UnsafeCell::new(value),
}
}
#[inline]
pub fn into_inner(self) -> T {
self.value.into_inner()
}
}
impl<Q, T: ?Sized> TLCell<Q, T> {
#[inline]
pub fn ro<'a>(&'a self, owner: &'a TLCellOwner<Q>) -> &'a T {
owner.ro(self)
}
#[inline]
pub fn rw<'a>(&'a self, owner: &'a mut TLCellOwner<Q>) -> &'a mut T {
owner.rw(self)
}
#[inline]
pub fn get_mut(&mut self) -> &mut T {
self.value.get_mut()
}
}
impl<Q: 'static, T: Default> Default for TLCell<Q, T> {
fn default() -> Self {
TLCell::new(T::default())
}
}
#[cfg(test)]
mod tests {
    use super::{TLCell, TLCellOwner};

    // Creating a second owner for the same marker on the same thread must panic.
    #[test]
    #[should_panic(expected = "Illegal to create two TLCellOwner instances")]
    fn tlcell_singleton_1() {
        struct Marker;
        let _owner1 = TLCellOwner::<Marker>::new();
        let _owner2 = TLCellOwner::<Marker>::new();
    }

    // Dropping the owner releases the marker, so a new owner can be created.
    #[test]
    fn tlcell_singleton_2() {
        struct Marker;
        let owner1 = TLCellOwner::<Marker>::new();
        drop(owner1);
        let _owner2 = TLCellOwner::<Marker>::new();
    }

    // Distinct marker types may have owners at the same time.
    #[test]
    fn tlcell_singleton_3() {
        struct Marker1;
        struct Marker2;
        let _owner1 = TLCellOwner::<Marker1>::new();
        let _owner2 = TLCellOwner::<Marker2>::new();
    }

    #[test]
    fn tlcell_singleton_try_new() {
        struct Marker;
        let owner1 = TLCellOwner::<Marker>::try_new();
        assert!(owner1.is_some());
        let owner2 = TLCellOwner::<Marker>::try_new();
        assert!(owner2.is_none());
    }

    // `try_new` succeeds again once the previous owner has been dropped.
    #[test]
    fn tlcell_try_new_after_drop() {
        struct Marker;
        let owner1 = TLCellOwner::<Marker>::try_new();
        assert!(owner1.is_some());
        drop(owner1);
        assert!(TLCellOwner::<Marker>::try_new().is_some());
    }

    #[test]
    fn tlcell() {
        struct Marker;
        type ACellOwner = TLCellOwner<Marker>;
        type ACell<T> = TLCell<Marker, T>;
        let mut owner = ACellOwner::new();
        let c1 = ACell::new(100u32);
        let c2 = owner.cell(200u32);
        (*owner.rw(&c1)) += 1;
        (*owner.rw(&c2)) += 2;
        let c1ref = owner.ro(&c1);
        let c2ref = owner.ro(&c2);
        let total = *c1ref + *c2ref;
        assert_eq!(total, 303);
    }

    // The same marker may have one owner per thread.
    #[test]
    fn tlcell_threads() {
        struct Marker;
        type ACellOwner = TLCellOwner<Marker>;
        let mut _owner1 = ACellOwner::new();
        std::thread::spawn(|| {
            let mut _owner2 = ACellOwner::new();
        })
        .join()
        .unwrap();
    }

    #[test]
    fn tlcell_get_mut() {
        struct Marker;
        type ACellOwner = TLCellOwner<Marker>;
        type ACell<T> = TLCell<Marker, T>;
        let owner = ACellOwner::new();
        let mut cell = ACell::new(100u32);
        let mut_ref = cell.get_mut();
        *mut_ref = 50;
        let cell_ref = owner.ro(&cell);
        assert_eq!(*cell_ref, 50);
    }

    #[test]
    fn tlcell_into_inner() {
        struct Marker;
        type ACell<T> = TLCell<Marker, T>;
        let cell = ACell::new(100u32);
        assert_eq!(cell.into_inner(), 100);
    }

    // `rw2` must refuse to hand out two `&mut` to the same cell.
    #[test]
    #[should_panic(expected = "Illegal to borrow same TLCell twice with rw2()")]
    fn tlcell_rw2_same_cell_panics() {
        struct Marker;
        let mut owner = TLCellOwner::<Marker>::new();
        let c = owner.cell(1u32);
        let _ = owner.rw2(&c, &c);
    }

    // `rw3` must detect aliasing between the first and third arguments too.
    #[test]
    #[should_panic(expected = "Illegal to borrow same TLCell twice with rw3()")]
    fn tlcell_rw3_same_cell_panics() {
        struct Marker;
        let mut owner = TLCellOwner::<Marker>::new();
        let c1 = owner.cell(1u32);
        let c2 = owner.cell(2u32);
        let _ = owner.rw3(&c1, &c2, &c1);
    }

    // Cells holding unsized (`dyn Trait`) values work with ro/rw/rw2/rw3.
    #[test]
    fn tlcell_unsized() {
        struct Marker;
        type ACellOwner = TLCellOwner<Marker>;
        type ACell<T> = TLCell<Marker, T>;
        let mut owner = ACellOwner::new();
        struct Squares(u32);
        struct Integers(u64);
        trait Series {
            fn step(&mut self);
            fn value(&self) -> u64;
        }
        impl Series for Squares {
            fn step(&mut self) {
                self.0 += 1;
            }
            fn value(&self) -> u64 {
                (self.0 as u64) * (self.0 as u64)
            }
        }
        impl Series for Integers {
            fn step(&mut self) {
                self.0 += 1;
            }
            fn value(&self) -> u64 {
                self.0
            }
        }
        fn series(init: u32, is_squares: bool) -> Box<ACell<dyn Series>> {
            if is_squares {
                Box::new(ACell::new(Squares(init)))
            } else {
                Box::new(ACell::new(Integers(init as u64)))
            }
        }
        let own = &mut owner;
        let cell1 = series(4, false);
        let cell2 = series(7, true);
        let cell3 = series(3, true);
        assert_eq!(cell1.ro(own).value(), 4);
        cell1.rw(own).step();
        assert_eq!(cell1.ro(own).value(), 5);
        assert_eq!(own.ro(&cell2).value(), 49);
        own.rw(&cell2).step();
        assert_eq!(own.ro(&cell2).value(), 64);
        let (r1, r2, r3) = own.rw3(&cell1, &cell2, &cell3);
        r1.step();
        r2.step();
        r3.step();
        assert_eq!(cell1.ro(own).value(), 6);
        assert_eq!(cell2.ro(own).value(), 81);
        assert_eq!(cell3.ro(own).value(), 16);
        let (r1, r2) = own.rw2(&cell1, &cell2);
        r1.step();
        r2.step();
        assert_eq!(cell1.ro(own).value(), 7);
        assert_eq!(cell2.ro(own).value(), 100);
    }
}