use std::{
fmt,
error::Error,
sync::Arc,
ops::{Deref, DerefMut},
cell::{self, UnsafeCell, RefCell},
mem::{MaybeUninit, ManuallyDrop}
};
use parking_lot::{Once, ReentrantMutex, ReentrantMutexGuard};
/// Error returned by [`Global::lock`] and [`Global::lock_mut`] when the value is
/// already borrowed on the *current* thread in a way that conflicts with the
/// requested borrow.
///
/// `lock` fails while this thread holds a mutable guard. `lock_mut` fails while
/// this thread holds any guard. Other threads never see this error; they block
/// on the global's mutex instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BorrowFail;

impl fmt::Display for BorrowFail {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str("failed to borrow global value twice in same thread")
    }
}

impl Error for BorrowFail {}
/// Storage shared by every lock taken on a [`Global`]: a reference-counted,
/// reentrant mutex wrapping a `RefCell` that holds the value.
type InnerPointer<T> = Arc<ReentrantMutex<RefCell<T>>>;

/// A global value that is created through `Default` on first use and can be
/// locked from any thread.
///
/// Locking goes through a reentrant mutex, so other threads block. It then goes
/// through a `RefCell`, so conflicting borrows on the same thread fail with
/// [`BorrowFail`] instead of deadlocking.
pub struct Global<T>(Immutable<InnerPointer<T>>);

// SAFETY: sending a `Global<T>` to another thread sends the `T` it may own,
// so `T: Send` is required, and nothing else is thread-bound.
unsafe impl<T> Send for Global<T> where T: Send {}
// SAFETY: every access to the `T` is serialised by the `ReentrantMutex`, so
// only one thread touches it at a time. As with `std::sync::Mutex`, that
// requires `T: Send` but not `T: Sync`.
unsafe impl<T> Sync for Global<T> where T: Send {}
impl<T> Global<T> {
    /// Creates an empty global, usable in a `static` because it is a `const fn`.
    ///
    /// Nothing is allocated here. The shared storage and the `T` inside it are
    /// built through `Default` the first time the global is accessed.
    pub const fn new() -> Self {
        Self(Immutable::new())
    }
}

/// Equivalent to [`Global::new`]: the value is still created lazily.
impl<T> Default for Global<T> {
    fn default() -> Self {
        Self::new()
    }
}
impl<T: Default + 'static> Global<T> {
    /// Calls `f` with a shared reference to the value, creating the value
    /// through `Default` first if this is the first access.
    ///
    /// # Panics
    ///
    /// Panics if [`Global::lock`] fails, i.e. if the current thread already
    /// holds a [`GlobalGuardMut`] for this global.
    pub fn with<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&T) -> R,
    {
        f(&*self.lock().expect("Couldn't immutably access global variable"))
    }
    /// Calls `f` with a mutable reference to the value, creating the value
    /// through `Default` first if this is the first access.
    ///
    /// # Panics
    ///
    /// Panics if [`Global::lock_mut`] fails, i.e. if the current thread
    /// already holds any guard (shared or mutable) for this global.
    pub fn with_mut<F, R>(&self, f: F) -> R
    where
        F: FnOnce(&mut T) -> R,
    {
        f(&mut *self.lock_mut().expect("Couldn't mutably access global variable"))
    }
    /// Acquires shared access to the value.
    ///
    /// Blocks while another thread holds the global's reentrant mutex. The
    /// returned guard owns its own `Arc` clone of the storage, so it does not
    /// borrow `self`.
    ///
    /// # Errors
    ///
    /// Returns [`BorrowFail`] if the current thread already holds a
    /// [`GlobalGuardMut`] for this global (the `RefCell` is mutably borrowed).
    pub fn lock(&self) -> Result<GlobalGuard<T>, BorrowFail> {
        // Dereferencing the `Immutable` lazily creates the shared storage.
        let mutex: Arc<_> = Arc::clone(&*self.0);
        // Going through a raw pointer detaches the guard's lifetime from the
        // local `mutex`, so it can be stored as `'static` in `GlobalGuard`.
        let mutex_ptr = &*mutex as *const ReentrantMutex<RefCell<T>>;
        // SAFETY: `mutex_ptr` points into the `Arc` allocation kept alive by
        // `mutex`. On success, `mutex` moves into the guard, whose `Drop`
        // releases it after `mutex_guard`. On the `?` error path below,
        // `mutex_guard` is declared after `mutex`, so it is dropped first.
        let mutex_guard = unsafe { (*mutex_ptr).lock() };
        let mutex_guard_ptr = &*mutex_guard as *const RefCell<T>;
        // SAFETY: the `RefCell` lives in the same `Arc` allocation, so the
        // pointer stays valid while `mutex` is alive. Only the thread holding
        // `mutex_guard` can reach it.
        let ref_cell_guard = unsafe {
            (*mutex_guard_ptr)
                .try_borrow()
                .map_err(|_| BorrowFail)?
        };
        Ok(GlobalGuard {
            mutex: ManuallyDrop::new(mutex),
            mutex_guard: ManuallyDrop::new(mutex_guard),
            ref_cell_guard: ManuallyDrop::new(ref_cell_guard),
        })
    }
    /// Acquires exclusive access to the value.
    ///
    /// Blocks while another thread holds the global's reentrant mutex. The
    /// returned guard owns its own `Arc` clone of the storage, so it does not
    /// borrow `self`.
    ///
    /// # Errors
    ///
    /// Returns [`BorrowFail`] if the current thread already holds any guard
    /// (shared or mutable) for this global (the `RefCell` is borrowed).
    pub fn lock_mut(&self) -> Result<GlobalGuardMut<T>, BorrowFail> {
        // Dereferencing the `Immutable` lazily creates the shared storage.
        let mutex: Arc<_> = Arc::clone(&*self.0);
        // Raw pointer so the guard's lifetime can be stored as `'static`.
        let mutex_ptr = &*mutex as *const ReentrantMutex<RefCell<T>>;
        // SAFETY: same reasoning as in `lock`: `mutex` keeps the allocation
        // alive and outlives `mutex_guard` on both the success and error paths.
        let mutex_guard = unsafe { (*mutex_ptr).lock() };
        let mutex_guard_ptr = &*mutex_guard as *const RefCell<T>;
        // SAFETY: the `RefCell` is inside the `Arc` allocation and is only
        // reachable by the thread that holds `mutex_guard`.
        let ref_cell_guard = unsafe {
            (*mutex_guard_ptr)
                .try_borrow_mut()
                .map_err(|_| BorrowFail)?
        };
        Ok(GlobalGuardMut {
            mutex: ManuallyDrop::new(mutex),
            mutex_guard: ManuallyDrop::new(mutex_guard),
            ref_cell_guard: ManuallyDrop::new(ref_cell_guard),
        })
    }
    /// Creates the shared storage (and the `T` inside it, through `Default`)
    /// now instead of on first access. The value is not locked.
    pub fn force_init(&self) {
        self.0.ensure_exists();
    }
}
pub struct GlobalGuardMut<T: 'static> {
mutex: ManuallyDrop<Arc<ReentrantMutex<RefCell<T>>>>,
mutex_guard: ManuallyDrop<ReentrantMutexGuard<'static, RefCell<T>>>,
ref_cell_guard: ManuallyDrop<cell::RefMut<'static, T>>,
}
impl<T: 'static> Drop for GlobalGuardMut<T> {
fn drop(&mut self) {
unsafe {
ManuallyDrop::drop(&mut self.ref_cell_guard);
ManuallyDrop::drop(&mut self.mutex_guard);
ManuallyDrop::drop(&mut self.mutex);
}
}
}
impl<T: 'static> Deref for GlobalGuardMut<T> {
type Target = T;
fn deref(&self) -> &T {
&*self.ref_cell_guard
}
}
impl<T: 'static> DerefMut for GlobalGuardMut<T> {
fn deref_mut(&mut self) -> &mut T {
&mut *self.ref_cell_guard
}
}
pub struct GlobalGuard<T: 'static> {
mutex: ManuallyDrop<Arc<ReentrantMutex<RefCell<T>>>>,
mutex_guard: ManuallyDrop<ReentrantMutexGuard<'static, RefCell<T>>>,
ref_cell_guard: ManuallyDrop<cell::Ref<'static, T>>,
}
impl<T: 'static> Drop for GlobalGuard<T> {
fn drop(&mut self) {
unsafe {
ManuallyDrop::drop(&mut self.ref_cell_guard);
ManuallyDrop::drop(&mut self.mutex_guard);
ManuallyDrop::drop(&mut self.mutex);
}
}
}
impl<T: 'static> Deref for GlobalGuard<T> {
type Target = T;
fn deref(&self) -> &T {
&*self.ref_cell_guard
}
}
pub struct Immutable<T> {
once: Once,
inner: UnsafeCell<MaybeUninit<T>>,
}
impl<T> Drop for Immutable<T> {
fn drop(&mut self) {
if let parking_lot::OnceState::Done = self.once.state() {
drop(unsafe {
std::ptr::drop_in_place((*self.inner.get()).as_mut_ptr());
});
}
}
}
// SAFETY: an `Immutable<T>` owns its (possibly initialised) `T`, so sending the
// cell to another thread sends the `T`; `T: Send` is exactly what is needed.
unsafe impl<T: Send> Send for Immutable<T> {}
// SAFETY: a shared `&Immutable<T>` hands out `&T` to every thread, which
// requires `T: Sync`. The `T` may also be created by whichever thread performs
// the first deref and later dropped by the thread that owns the cell, which
// effectively sends it, so `T: Send` is required too. These are the same bounds
// `std::sync::OnceLock<T>` uses for `Sync`.
unsafe impl<T: Send + Sync> Sync for Immutable<T> {}
impl<T: Default> Immutable<T> {
    /// Builds the value with `T::default()` the first time it is called.
    ///
    /// The initialiser runs at most once. Concurrent callers wait for it to
    /// finish. If `T::default()` panics, the `Once` is poisoned and later calls
    /// panic as well (parking_lot `call_once` behaviour).
    fn ensure_exists(&self) {
        let slot = self.inner.get();
        self.once.call_once(move || {
            let value = T::default();
            // SAFETY: `call_once` runs this closure at most once and blocks
            // other callers until it returns, so nothing else is reading or
            // writing `inner` while it is filled in.
            unsafe { *slot = MaybeUninit::new(value) };
        });
    }

    /// Initialises the value now rather than on the first dereference.
    pub fn force_init(&self) {
        self.ensure_exists()
    }
}
impl<T> Immutable<T> {
    /// Creates an empty, uninitialised cell.
    ///
    /// Being a `const fn`, it can initialise a `static`. The `T` is only built
    /// on the first dereference or `force_init`.
    pub const fn new() -> Self {
        Self {
            once: Once::new(),
            inner: UnsafeCell::new(MaybeUninit::uninit()),
        }
    }
}

/// Equivalent to [`Immutable::new`]: the cell starts empty and `T` is built lazily.
impl<T> Default for Immutable<T> {
    fn default() -> Self {
        Self::new()
    }
}
impl<T: Default> Deref for Immutable<T> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.ensure_exists();
unsafe {
&*(*self.inner.get()).as_ptr()
}
}
}
#[cfg(test)]
mod test {
    use std::{
        sync::{
            atomic::{AtomicU32, Ordering},
            mpsc,
        },
        thread,
        time::Duration,
    };
    use super::{Global, Immutable};

    /// Many threads incrementing concurrently must not lose any update.
    #[test]
    fn no_race_condition() {
        static NUM: Global<i32> = Global::new();
        let mut v = Vec::new();
        for _ in 0..1000 {
            v.push(thread::spawn(|| {
                for _ in 0..100 {
                    *NUM.lock_mut().unwrap() += 1;
                }
            }));
        }
        for thread in v {
            thread.join().unwrap();
        }
        assert_eq!(*NUM.lock().unwrap(), 100_000);
    }

    /// A second thread must wait while the first holds a long-lived guard.
    #[test]
    fn no_race_extended_lock() {
        static NUM: Global<i32> = Global::new();
        let (tx, rx) = mpsc::channel();
        let t1 = thread::spawn(move || {
            let mut lock = NUM.lock_mut().unwrap();
            tx.send(()).unwrap();
            thread::sleep(Duration::from_millis(1));
            *lock += 1;
        });
        let t2 = thread::spawn(move || {
            let () = rx.recv().unwrap();
            *NUM.lock_mut().unwrap() += 1;
        });
        t1.join().unwrap();
        t2.join().unwrap();
        assert_eq!(*NUM.lock().unwrap(), 2);
    }

    /// Shared borrow while this thread holds a mutable guard fails.
    #[test]
    #[should_panic]
    fn borrow_immutably_while_mutably_borrowed() {
        static NUM: Global<i32> = Global::new();
        let _x = NUM.lock_mut().unwrap();
        let _y = NUM.lock().unwrap();
    }

    /// Two mutable guards on the same thread are rejected.
    #[test]
    #[should_panic]
    fn borrow_mutably_while_mutably_borrowed() {
        static NUM: Global<i32> = Global::new();
        let _x = NUM.lock_mut().unwrap();
        let _y = NUM.lock_mut().unwrap();
    }

    /// Mutable borrow while this thread holds a shared guard fails.
    #[test]
    #[should_panic]
    fn borrow_mutably_while_immutably_borrowed() {
        static NUM: Global<i32> = Global::new();
        let _x = NUM.lock().unwrap();
        let _y = NUM.lock_mut().unwrap();
    }

    /// Multiple shared guards on the same thread are allowed.
    #[test]
    fn borrow_immutably_while_immutably_borrowed() {
        static NUM: Global<i32> = Global::new();
        let _x = NUM.lock().unwrap();
        let _y = NUM.lock().unwrap();
    }

    /// Same-thread conflicts fail fast while another thread blocks until release.
    #[test]
    fn complex_thread_interactions() {
        static NUM: Global<i32> = Global::new();
        let lock1 = NUM.lock().unwrap();
        let lock2 = NUM.lock().unwrap();
        let lock3 = NUM.lock().unwrap();
        let t = thread::spawn(|| {
            *NUM.lock_mut().unwrap() += 1;
            assert!(NUM.lock().is_ok());
        });
        thread::sleep(Duration::from_millis(100));
        assert!(NUM.lock_mut().is_err());
        drop(lock1);
        drop(lock2);
        drop(lock3);
        *NUM.lock_mut().unwrap() += 1;
        t.join().unwrap();
        assert_eq!(2, *NUM.lock().unwrap());
    }

    /// Values are dropped only if they were actually initialised.
    #[test]
    fn ensure_drop() {
        // An atomic counter avoids `static mut` and the `unsafe` it requires.
        static COUNTER: AtomicU32 = AtomicU32::new(0);
        #[derive(Default)]
        struct Increase;
        impl Drop for Increase {
            fn drop(&mut self) {
                COUNTER.fetch_add(1, Ordering::SeqCst);
            }
        }
        let immutable = Immutable::<Increase>::new();
        drop(immutable);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 0);
        let global = Global::<Increase>::new();
        drop(global);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 0);
        let immutable = Immutable::<Increase>::new();
        immutable.force_init();
        drop(immutable);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 1);
        let global = Global::<Increase>::new();
        global.force_init();
        drop(global);
        assert_eq!(COUNTER.load(Ordering::SeqCst), 2);
    }
}