mod borrow_state;
pub use borrow_state::{ExclusiveBorrow, SharedBorrow};
use crate::error;
use borrow_state::BorrowState;
use core::cell::UnsafeCell;
use core::marker::PhantomData;
use core::ops::{Deref, DerefMut};
#[cfg(feature = "thread_local")]
use std::thread::ThreadId;
/// Interior-mutability cell that tracks borrows at runtime in a `BorrowState`.
///
/// With the `thread_local` feature it also records whether the stored value may be
/// sent to / shared with other threads; `borrow` and `borrow_mut` check those flags
/// before handing out references.
#[doc(hidden)]
pub struct AtomicRefCell<T: ?Sized> {
/// Bookkeeping for outstanding shared and exclusive borrows.
borrow_state: BorrowState,
/// `Some(id)` for values built with `new_non_send`/`new_non_send_sync` (not required
/// to be `Send`): only thread `id` may mutably borrow them, and, when `is_sync` is
/// `false`, only thread `id` may borrow them at all. `None` for `Send` values.
#[cfg(feature = "thread_local")]
send: Option<ThreadId>,
/// `false` for values built with `new_non_sync`/`new_non_send_sync` (not required to
/// be `Sync`). Shared borrows of such values either go through
/// `BorrowState::exclusive_read` (no owning thread) or require being on the owning thread.
#[cfg(feature = "thread_local")]
is_sync: bool,
/// Raw-pointer marker that removes the auto `Send`/`Sync` impls; they are provided
/// manually by the `unsafe impl`s that follow this struct.
_non_send_sync: PhantomData<*const ()>,
/// The wrapped value; accessed via `borrow`, `borrow_mut`, `get_mut` and `into_inner`.
inner: UnsafeCell<T>,
}
// SAFETY: without `thread_local`, the only constructor this module shows is `new`,
// which requires `T: Send + Sync`, so every cell that can exist holds a `Send` value.
// The lint is allowed because the `PhantomData<*const ()>` marker and the unbounded
// `T` are not `Send` on their own; this manual impl intentionally overrides them.
#[cfg(not(feature = "thread_local"))]
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<T: ?Sized> Send for AtomicRefCell<T> {}
// SAFETY: every access through `&AtomicRefCell` goes through `borrow`/`borrow_mut`,
// which register the borrow in `BorrowState` and, with `thread_local`, refuse the
// cross-thread accesses the value's constructor did not allow.
// NOTE(review): soundness also depends on `BorrowState` being safe to use from
// several threads at once — confirm in `borrow_state`.
unsafe impl<T: ?Sized> Sync for AtomicRefCell<T> {}
impl<T: Send + Sync> AtomicRefCell<T> {
    /// Wraps a `Send + Sync` value; it is not pinned to any thread and shared
    /// borrows use the regular `BorrowState::read` path.
    #[inline]
    pub(crate) fn new(value: T) -> Self {
        let borrow_state = BorrowState::new();
        Self {
            borrow_state,
            #[cfg(feature = "thread_local")]
            send: None,
            #[cfg(feature = "thread_local")]
            is_sync: true,
            _non_send_sync: PhantomData,
            inner: UnsafeCell::new(value),
        }
    }
}
#[cfg(feature = "thread_local")]
impl<T: Sync> AtomicRefCell<T> {
    /// Wraps a value that is `Sync` but not required to be `Send`.
    ///
    /// Shared borrows are allowed from any thread (the value is `Sync`), but
    /// `borrow_mut` only succeeds on `world_thread_id`.
    #[inline]
    pub(crate) fn new_non_send(value: T, world_thread_id: ThreadId) -> Self {
        AtomicRefCell {
            borrow_state: BorrowState::new(),
            send: Some(world_thread_id),
            // The enclosing impl is already gated on `thread_local`, so the field needs
            // no `cfg` of its own.
            is_sync: true,
            _non_send_sync: PhantomData,
            inner: UnsafeCell::new(value),
        }
    }
}
#[cfg(feature = "thread_local")]
impl<T: Send> AtomicRefCell<T> {
    /// Wraps a value that is `Send` but not required to be `Sync`.
    ///
    /// The value is not pinned to a thread (`send: None`), but `borrow` takes its
    /// shared borrows through `BorrowState::exclusive_read` rather than `read`.
    // NOTE(review): presumably so a `&T` is never observed from two threads at once —
    // confirm against `BorrowState::exclusive_read`.
    #[inline]
    pub(crate) fn new_non_sync(value: T) -> Self {
        AtomicRefCell {
            borrow_state: BorrowState::new(),
            // Not pinned to any thread. The enclosing impl is already gated on
            // `thread_local`, so the field needs no `cfg` of its own.
            send: None,
            is_sync: false,
            _non_send_sync: PhantomData,
            inner: UnsafeCell::new(value),
        }
    }
}
#[cfg(feature = "thread_local")]
impl<T> AtomicRefCell<T> {
    /// Wraps a value that is required to be neither `Send` nor `Sync`.
    ///
    /// Every borrow, shared or exclusive, is refused with `WrongThread` unless it
    /// is made on `world_thread_id`.
    #[inline]
    pub(crate) fn new_non_send_sync(value: T, world_thread_id: ThreadId) -> Self {
        let borrow_state = BorrowState::new();
        let send = Some(world_thread_id);
        Self {
            borrow_state,
            send,
            is_sync: false,
            _non_send_sync: PhantomData,
            inner: UnsafeCell::new(value),
        }
    }
}
impl<T> AtomicRefCell<T> {
    /// Consumes the cell and returns the wrapped value.
    ///
    /// Taking `self` by value means no `Ref`/`RefMut` can still be alive.
    #[inline]
    pub(crate) fn into_inner(self) -> T {
        let Self { inner, .. } = self;
        inner.into_inner()
    }
}
impl<T: ?Sized> AtomicRefCell<T> {
    /// Immutably borrows the wrapped value.
    ///
    /// # Errors
    ///
    /// Propagates the `BorrowState` error when a shared borrow cannot be taken
    /// (e.g. `error::Borrow::Unique` while a mutable borrow is alive). With the
    /// `thread_local` feature, returns `error::Borrow::WrongThread` when the value is
    /// neither `Send` nor `Sync` and the caller is not on its owning thread.
    #[inline]
    pub(crate) fn borrow(&self) -> Result<Ref<'_, &'_ T>, error::Borrow> {
        #[cfg(not(feature = "thread_local"))]
        let borrow = self.borrow_state.read()?;

        #[cfg(feature = "thread_local")]
        let borrow = match (self.send, self.is_sync) {
            // `Sync` value: a `&T` may be observed from any thread.
            (_, true) => self.borrow_state.read()?,
            // `Send` but not `Sync`: any thread may access it, but shared borrows use
            // the exclusive-read path of `BorrowState`.
            (None, false) => self.borrow_state.exclusive_read()?,
            // Neither `Send` nor `Sync`: only the owning thread may borrow it.
            (Some(thread_id), false) => {
                if thread_id != std::thread::current().id() {
                    return Err(error::Borrow::WrongThread);
                }
                self.borrow_state.read()?
            }
        };

        Ok(Ref {
            // SAFETY: `borrow` is a live shared borrow registered in `borrow_state`,
            // which makes `borrow_mut` fail until it is dropped; the returned `Ref`
            // keeps it alive for as long as this reference.
            inner: unsafe { &*self.inner.get() },
            borrow,
        })
    }

    /// Mutably borrows the wrapped value.
    ///
    /// # Errors
    ///
    /// Propagates the `BorrowState` error when an exclusive borrow cannot be taken
    /// (`error::Borrow::Shared` / `error::Borrow::Unique` while other borrows are
    /// alive). With the `thread_local` feature, returns `error::Borrow::WrongThread`
    /// when the value is not `Send` and the caller is not on its owning thread.
    #[inline]
    pub(crate) fn borrow_mut(&self) -> Result<RefMut<'_, &'_ mut T>, error::Borrow> {
        #[cfg(feature = "thread_local")]
        {
            // A `&mut T` can move the value out (e.g. via `mem::swap`), so values that
            // are not `Send` may only be mutably borrowed on their own thread.
            if let Some(thread_id) = self.send {
                if thread_id != std::thread::current().id() {
                    return Err(error::Borrow::WrongThread);
                }
            }
        }

        self.borrow_state.write().map(|borrow| RefMut {
            // SAFETY: `borrow` is a live exclusive borrow registered in `borrow_state`,
            // so every other `borrow`/`borrow_mut` fails until it is dropped; the
            // returned `RefMut` keeps it alive for as long as this reference.
            inner: unsafe { &mut *self.inner.get() },
            borrow,
        })
    }

    /// Returns a mutable reference to the wrapped value.
    ///
    /// `&mut self` statically rules out any live `Ref`/`RefMut`, so no runtime
    /// bookkeeping is needed.
    #[inline]
    pub(crate) fn get_mut(&mut self) -> &'_ mut T {
        self.inner.get_mut()
    }
}
/// Guard returned by `AtomicRefCell::borrow`.
///
/// Holds the borrowed data (`inner`, e.g. a `&T`) together with the `SharedBorrow`
/// that keeps the shared borrow registered; exclusive borrows are refused until the
/// guard is dropped.
pub struct Ref<'a, T> {
inner: T,
borrow: SharedBorrow<'a>,
}
impl<'a, T> Ref<'a, T> {
#[inline]
pub unsafe fn destructure(this: Self) -> (T, SharedBorrow<'a>) {
(this.inner, this.borrow)
}
}
impl<'a, T: ?Sized> Ref<'a, &'a T> {
    /// Projects the guard onto a part of the borrowed data, keeping the same borrow.
    ///
    /// `U` may be unsized, so `f` can return views such as `&str` or `&[X]`.
    #[inline]
    pub(crate) fn map<U: ?Sized, F: FnOnce(&T) -> &U>(this: Self, f: F) -> Ref<'a, &'a U> {
        Ref {
            inner: f(this.inner),
            borrow: this.borrow,
        }
    }
}
impl<'a, T: Deref> Deref for Ref<'a, T> {
    type Target = T::Target;

    /// Forwards to the wrapped reference-like value.
    #[inline]
    fn deref(&self) -> &Self::Target {
        Deref::deref(&self.inner)
    }
}
impl<T: Clone> Clone for Ref<'_, T> {
    /// Clones both the borrowed data and its `SharedBorrow` guard.
    // NOTE(review): presumably `SharedBorrow::clone` registers an additional shared
    // borrow — confirm in `borrow_state`.
    #[inline]
    fn clone(&self) -> Self {
        let Ref { inner, borrow } = self;
        Ref {
            inner: inner.clone(),
            borrow: borrow.clone(),
        }
    }
}
/// Guard returned by `AtomicRefCell::borrow_mut`.
///
/// Holds the mutably borrowed data (`inner`, e.g. a `&mut T`) together with the
/// `ExclusiveBorrow` that keeps the exclusive borrow registered; every other borrow
/// is refused until the guard is dropped.
pub struct RefMut<'a, T> {
inner: T,
borrow: ExclusiveBorrow<'a>,
}
impl<'a, T> RefMut<'a, T> {
#[inline]
pub unsafe fn destructure(this: Self) -> (T, ExclusiveBorrow<'a>) {
(this.inner, this.borrow)
}
}
impl<'a, T: ?Sized> RefMut<'a, &'a mut T> {
    /// Projects the guard onto a part of the mutably borrowed data, keeping the same
    /// exclusive borrow.
    ///
    /// `U` may be unsized, so `f` can return views such as `&mut str` or `&mut [X]`.
    #[inline]
    pub(crate) fn map<U: ?Sized, F: FnOnce(&mut T) -> &mut U>(
        this: Self,
        f: F,
    ) -> RefMut<'a, &'a mut U> {
        RefMut {
            inner: f(this.inner),
            borrow: this.borrow,
        }
    }
}
impl<'a, T: Deref> Deref for RefMut<'a, T> {
    type Target = T::Target;

    /// Forwards to the wrapped reference-like value.
    #[inline]
    fn deref(&self) -> &T::Target {
        self.inner.deref()
    }
}
impl<'a, T: DerefMut> DerefMut for RefMut<'a, T> {
    /// Forwards mutably to the wrapped reference-like value.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.inner.deref_mut()
    }
}
#[test]
fn shared() {
    let cell = AtomicRefCell::new(0);

    let held = cell.borrow().unwrap();
    // Shared borrows may coexist...
    assert!(cell.borrow().is_ok());
    // ...but they block any exclusive borrow.
    assert_eq!(cell.borrow_mut().err(), Some(error::Borrow::Shared));

    drop(held);
    assert!(cell.borrow_mut().is_ok());
}
#[test]
fn exclusive() {
    let cell = AtomicRefCell::new(0);

    let held = cell.borrow_mut().unwrap();
    // While an exclusive borrow is alive, neither kind of borrow is granted.
    assert_eq!(cell.borrow().err(), Some(error::Borrow::Unique));
    assert_eq!(cell.borrow_mut().err(), Some(error::Borrow::Unique));

    drop(held);
    assert!(cell.borrow_mut().is_ok());
}
#[cfg(all(feature = "std", not(feature = "thread_local")))]
#[test]
fn shared_thread() {
    use std::sync::Arc;

    let cell = Arc::new(AtomicRefCell::new(0));
    let cell_for_thread = Arc::clone(&cell);
    let held = cell.borrow().unwrap();

    let handle = std::thread::spawn(move || {
        // A shared borrow held on the main thread still allows shared borrows here...
        cell_for_thread.borrow().unwrap();
        // ...but refuses exclusive ones.
        assert_eq!(
            cell_for_thread.borrow_mut().err(),
            Some(error::Borrow::Shared)
        );
    });
    handle.join().unwrap();

    drop(held);
    assert!(cell.borrow_mut().is_ok());
}
#[cfg(all(feature = "std", not(feature = "thread_local")))]
#[test]
fn exclusive_thread() {
    use std::sync::Arc;

    let refcell = Arc::new(AtomicRefCell::new(0));
    let refcell_clone = Arc::clone(&refcell);
    std::thread::spawn(move || {
        // Unwrap so the test fails at the right place if the first exclusive borrow is
        // itself refused, instead of holding an `Err` and mis-reporting the next assert.
        let _first_borrow = refcell_clone.borrow_mut().unwrap();
        assert_eq!(
            refcell_clone.borrow_mut().err(),
            Some(error::Borrow::Unique)
        );
    })
    .join()
    .unwrap();
    // The spawned thread's exclusive borrow ended with the thread.
    refcell.borrow_mut().unwrap();
}
#[cfg(feature = "thread_local")]
#[test]
fn non_send() {
    let refcell = AtomicRefCell::new_non_send(0u32, std::thread::current().id());
    // Pass the address as a `usize` so the closure is `'static`; the thread is joined
    // before `refcell` goes out of scope.
    let address = &refcell as *const AtomicRefCell<u32> as usize;
    std::thread::spawn(move || {
        // SAFETY: `refcell` outlives this thread because the thread is joined below.
        let cell = unsafe { &*(address as *const AtomicRefCell<u32>) };
        // `Sync` value: shared borrows work from any thread...
        cell.borrow().unwrap();
        // ...but mutable access is reserved for the owning thread.
        assert_eq!(cell.borrow_mut().err(), Some(error::Borrow::WrongThread));
    })
    .join()
    .unwrap();
    refcell.borrow().unwrap();
    refcell.borrow_mut().unwrap();
}
#[cfg(feature = "thread_local")]
#[test]
fn non_sync() {
    // `0u32` so the cell really is an `AtomicRefCell<u32>`, matching the type the spawned
    // thread casts the pointer back to (an unsuffixed `0` would infer `i32`).
    let refcell = AtomicRefCell::new_non_sync(0u32);
    let refcell_ptr: *const AtomicRefCell<u32> = &refcell;
    let refcell_ptr = refcell_ptr as usize;
    std::thread::spawn(move || {
        // SAFETY: `refcell` outlives this thread because the thread is joined below.
        let cell = unsafe { &*(refcell_ptr as *const AtomicRefCell<u32>) };
        // `Send` but not `Sync`: not pinned to a thread, so both borrows succeed here
        // while nothing else holds a borrow.
        cell.borrow().unwrap();
        cell.borrow_mut().unwrap();
    })
    .join()
    .unwrap();
    refcell.borrow().unwrap();
    refcell.borrow_mut().unwrap();
}
#[cfg(feature = "thread_local")]
#[test]
fn non_send_sync() {
    let owner = std::thread::current().id();
    let refcell = AtomicRefCell::new_non_send_sync(0u32, owner);
    // Pass the address as a `usize` so the closure is `'static`; the thread is joined
    // before `refcell` goes out of scope.
    let address = &refcell as *const AtomicRefCell<u32> as usize;
    std::thread::spawn(move || {
        // SAFETY: `refcell` outlives this thread because the thread is joined below.
        let cell = unsafe { &*(address as *const AtomicRefCell<u32>) };
        // Neither `Send` nor `Sync`: every access from a foreign thread is refused.
        assert_eq!(cell.borrow().err(), Some(error::Borrow::WrongThread));
        assert_eq!(cell.borrow_mut().err(), Some(error::Borrow::WrongThread));
    })
    .join()
    .unwrap();
    // The owning thread keeps full access.
    refcell.borrow().unwrap();
    refcell.borrow_mut().unwrap();
}