use std::{
cell::{Cell, UnsafeCell},
fmt,
marker::PhantomData,
ops::{Deref, DerefMut},
ptr::NonNull,
};
use crate::vlib::{BarrierHeldMainRef, MainRef};
/// A reader-writer lock whose exclusion is built on the vlib worker barrier.
///
/// Only guards created on the main thread (thread index 0) are tracked. They maintain a
/// non-atomic writer flag and reader count, which are used to panic when conflicting
/// guards are taken on that thread. Guards created on other threads do no bookkeeping.
/// Writers must present a `BarrierHeldMainRef`, presumably so that workers are parked
/// while the value is mutated.
pub struct BarrierRwLock<T: ?Sized> {
    /// `true` while a write guard is alive.
    writer: UnsafeCell<bool>,
    /// Number of live read guards created on the main thread.
    readers: UnsafeCell<u32>,
    /// The protected value. It must stay the last field because `T` may be unsized.
    data: UnsafeCell<T>,
}

impl<T> BarrierRwLock<T> {
    /// Creates a new, unlocked `BarrierRwLock` that wraps `t`.
    #[inline]
    pub const fn new(t: T) -> Self {
        Self {
            writer: UnsafeCell::new(false),
            readers: UnsafeCell::new(0),
            data: UnsafeCell::new(t),
        }
    }
}
impl<T: ?Sized> BarrierRwLock<T> {
#[inline(always)]
pub fn read(&self, vm: &MainRef) -> BarrierRwLockReadGuard<'_, T> {
let main_thread = vm.thread_index() == 0;
unsafe { BarrierRwLockReadGuard::new(self, main_thread) }
}
#[inline(always)]
pub fn write(&self, vm: &BarrierHeldMainRef) -> BarrierRwLockWriteGuard<'_, T> {
debug_assert_eq!(vm.thread_index(), 0);
unsafe { BarrierRwLockWriteGuard::new(self) }
}
pub fn get_mut(&mut self) -> &mut T {
self.data.get_mut()
}
pub const fn data_ptr(&self) -> *mut T {
self.data.get()
}
}
impl<T> BarrierRwLock<T> {
pub fn into_inner(self) -> T {
self.data.into_inner()
}
}
unsafe impl<T: ?Sized + Send> Send for BarrierRwLock<T> {}
unsafe impl<T: ?Sized + Send + Sync> Sync for BarrierRwLock<T> {}
impl<T: Default> Default for BarrierRwLock<T> {
fn default() -> BarrierRwLock<T> {
BarrierRwLock::new(Default::default())
}
}
/// RAII guard granting shared access to the value inside a [`BarrierRwLock`].
///
/// Created by [`BarrierRwLock::read`]. A guard created on the main thread is counted in the
/// lock's reader count until it is dropped.
///
/// The guard is `!Send` because it holds a `NonNull`. This ensures the non-atomic reader
/// count is decremented on the same thread that incremented it.
#[must_use = "if unused the BarrierRwLock will immediately unlock"]
pub struct BarrierRwLockReadGuard<'rwlock, T: ?Sized + 'rwlock> {
    /// Pointer to the protected value inside `lock`.
    data: NonNull<T>,
    /// The lock this guard belongs to; used to update the reader count on drop.
    lock: &'rwlock BarrierRwLock<T>,
    /// Whether this guard was counted in `lock.readers`, i.e. created on the main thread.
    main_thread: bool,
}

// SAFETY: a shared `&BarrierRwLockReadGuard` only exposes `&T`, which may be shared across
// threads when `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for BarrierRwLockReadGuard<'_, T> {}

impl<'rwlock, T: ?Sized> BarrierRwLockReadGuard<'rwlock, T> {
    /// Creates a read guard for `lock`, registering it when `main_thread` is `true`.
    ///
    /// # Safety
    ///
    /// The caller must guarantee two things:
    ///
    /// - `main_thread` is `true` only when running on the main thread. That is the only
    ///   thread allowed to touch the non-atomic bookkeeping cells.
    /// - No `&mut T` to the data can exist while the guard is alive. For non-main threads
    ///   this presumably relies on writers holding the worker barrier.
    ///
    /// # Panics
    ///
    /// On the main thread, panics if a write guard is alive, or if the reader count would
    /// overflow `u32`.
    #[inline(always)]
    unsafe fn new(
        lock: &'rwlock BarrierRwLock<T>,
        main_thread: bool,
    ) -> BarrierRwLockReadGuard<'rwlock, T> {
        if main_thread {
            // SAFETY: per the contract above we are on the main thread, the only thread
            // touching these cells. No reference to them outlives these statements.
            unsafe {
                if *lock.writer.get() {
                    panic!("Write lock already taken by this thread");
                }
                // A plain `+= 1` wraps in release builds. Leaking guards (`mem::forget` is
                // safe) could then bring the count back to 0 while guards are alive, and
                // `write` would hand out `&mut T` aliasing a live `&T`.
                let readers = lock.readers.get();
                *readers = (*readers)
                    .checked_add(1)
                    .expect("Too many read locks taken by this thread");
            }
        }
        // SAFETY: `UnsafeCell::get` never returns a null pointer.
        let data = unsafe { NonNull::new_unchecked(lock.data.get()) };
        Self {
            data,
            lock,
            main_thread,
        }
    }
}

impl<T: ?Sized> Drop for BarrierRwLockReadGuard<'_, T> {
    /// Releases the read guard, decrementing the reader count if this guard was counted.
    #[inline(always)]
    fn drop(&mut self) {
        if self.main_thread {
            // SAFETY: the guard is `!Send`, so it is dropped on the main thread that
            // incremented the count in `new`. The count is at least 1 while this guard lives.
            unsafe {
                *self.lock.readers.get() -= 1;
            }
        }
    }
}
impl<T: ?Sized> Deref for BarrierRwLockReadGuard<'_, T> {
    type Target = T;

    /// Returns a shared reference to the protected value.
    #[inline(always)]
    fn deref(&self) -> &T {
        let ptr = self.data.as_ptr();
        // SAFETY: `data` points into the lock's `UnsafeCell`, which outlives the guard.
        // The guard's existence excludes writers, and the borrow is tied to `&self`.
        unsafe { &*ptr }
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for BarrierRwLockReadGuard<'_, T> {
    /// Formats the protected value with its `Debug` impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for BarrierRwLockReadGuard<'_, T> {
    /// Formats the protected value with its `Display` impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&**self, f)
    }
}
pub struct BarrierRwLockWriteGuard<'rwlock, T: ?Sized + 'rwlock> {
lock: &'rwlock BarrierRwLock<T>,
_phantom: PhantomData<Cell<()>>,
}
impl<'rwlock, T: ?Sized> BarrierRwLockWriteGuard<'rwlock, T> {
#[inline(always)]
unsafe fn new(lock: &'rwlock BarrierRwLock<T>) -> BarrierRwLockWriteGuard<'rwlock, T> {
unsafe {
if *lock.readers.get() != 0 {
panic!("Read lock already taken by this thread");
}
if *lock.writer.get() {
panic!("Write lock already taken by this thread");
}
*lock.writer.get() = true;
}
BarrierRwLockWriteGuard {
lock,
_phantom: PhantomData,
}
}
}
impl<T: ?Sized> Drop for BarrierRwLockWriteGuard<'_, T> {
#[inline(always)]
fn drop(&mut self) {
unsafe {
*self.lock.writer.get() = false;
}
}
}
// SAFETY: a shared `&BarrierRwLockWriteGuard` only exposes `&T` (through `Deref`), which
// may be shared across threads when `T: Sync`.
unsafe impl<T> Sync for BarrierRwLockWriteGuard<'_, T> where T: ?Sized + Sync {}

impl<T: ?Sized> Deref for BarrierRwLockWriteGuard<'_, T> {
    type Target = T;

    /// Returns a shared reference to the protected value.
    #[inline(always)]
    fn deref(&self) -> &T {
        let ptr: *const T = self.lock.data.get();
        // SAFETY: the live write guard excludes other guards on the main thread, and the
        // returned borrow is tied to `&self`.
        unsafe { &*ptr }
    }
}

impl<T: ?Sized> DerefMut for BarrierRwLockWriteGuard<'_, T> {
    /// Returns a mutable reference to the protected value.
    #[inline(always)]
    fn deref_mut(&mut self) -> &mut T {
        let ptr = self.lock.data.get();
        // SAFETY: exclusive access is guaranteed by the write guard, and the borrow is tied
        // to `&mut self`, so it cannot alias another reference handed out by this guard.
        unsafe { &mut *ptr }
    }
}

impl<T: ?Sized + fmt::Debug> fmt::Debug for BarrierRwLockWriteGuard<'_, T> {
    /// Formats the protected value with its `Debug` impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&**self, f)
    }
}

impl<T: ?Sized + fmt::Display> fmt::Display for BarrierRwLockWriteGuard<'_, T> {
    /// Formats the protected value with its `Display` impl.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&**self, f)
    }
}
#[cfg(test)]
mod tests {
    use std::thread;

    use crate::{
        bindings::vlib_main_t,
        vlib::{BarrierHeldMainRef, MainRef, main::sync::BarrierRwLock},
    };

    /// How many read guards each thread takes in `concurrent_reads`.
    const READS_PER_THREAD: usize = 1000;

    /// A reader using the default `vlib_main_t` (presumably thread index 0) and a reader
    /// using thread index 1 can read the same lock concurrently.
    #[test]
    fn concurrent_reads() {
        let lock = BarrierRwLock::new("value".to_string());
        let shared = &lock;
        thread::scope(|scope| {
            let main_reader = scope.spawn(move || {
                let mut vm = vlib_main_t::default();
                let vm_ref = unsafe { MainRef::from_ptr_mut(std::ptr::addr_of_mut!(vm)) };
                for _ in 0..READS_PER_THREAD {
                    assert_eq!(*shared.read(vm_ref), "value");
                }
            });
            let worker_reader = scope.spawn(move || {
                let mut vm = vlib_main_t {
                    thread_index: 1,
                    ..vlib_main_t::default()
                };
                let vm_ref = unsafe { MainRef::from_ptr_mut(std::ptr::addr_of_mut!(vm)) };
                for _ in 0..READS_PER_THREAD {
                    assert_eq!(*shared.read(vm_ref), "value");
                }
            });
            main_reader.join().unwrap();
            worker_reader.join().unwrap();
        });
    }

    /// A value stored through a write guard is visible to a subsequent read guard.
    #[test]
    fn write_guard() {
        let mut vm = vlib_main_t::default();
        let vm_ref = unsafe { BarrierHeldMainRef::from_ptr_mut(std::ptr::addr_of_mut!(vm)) };
        let lock = BarrierRwLock::new("value".to_string());
        *lock.write(vm_ref) = "new value".to_string();
        assert_eq!(*lock.read(vm_ref), "new value");
    }

    /// Taking a read guard while a write guard is alive panics.
    #[test]
    #[should_panic(expected = "Write lock already taken by this thread")]
    fn read_and_write1() {
        let mut vm = vlib_main_t::default();
        let vm_ref = unsafe { BarrierHeldMainRef::from_ptr_mut(std::ptr::addr_of_mut!(vm)) };
        let lock = BarrierRwLock::new("value".to_string());
        let _writer = lock.write(vm_ref);
        let _reader = lock.read(vm_ref);
    }

    /// Taking a write guard while a read guard is alive panics.
    #[test]
    #[should_panic(expected = "Read lock already taken by this thread")]
    fn read_and_write2() {
        let mut vm = vlib_main_t::default();
        let vm_ref = unsafe { BarrierHeldMainRef::from_ptr_mut(std::ptr::addr_of_mut!(vm)) };
        let lock = BarrierRwLock::new("value".to_string());
        let _reader = lock.read(vm_ref);
        let _writer = lock.write(vm_ref);
    }

    /// Taking a second write guard while the first is alive panics.
    #[test]
    #[should_panic(expected = "Write lock already taken by this thread")]
    fn write_write() {
        let mut vm = vlib_main_t::default();
        let vm_ref = unsafe { BarrierHeldMainRef::from_ptr_mut(std::ptr::addr_of_mut!(vm)) };
        let lock = BarrierRwLock::new("value".to_string());
        let _first_writer = lock.write(vm_ref);
        let _second_writer = lock.write(vm_ref);
    }

    /// Covers `Default`, `get_mut`, guard formatting, `data_ptr` and `into_inner`.
    #[test]
    fn misc() {
        let mut vm = vlib_main_t::default();
        let vm_ref = unsafe { BarrierHeldMainRef::from_ptr_mut(std::ptr::addr_of_mut!(vm)) };
        let mut lock: BarrierRwLock<String> = BarrierRwLock::default();
        assert_eq!(*lock.write(vm_ref), "");

        *lock.get_mut() = "value".to_string();

        assert_eq!(lock.write(vm_ref).to_string(), "value");
        assert_eq!(format!("{:?}", lock.write(vm_ref)), "\"value\"");
        assert_eq!(lock.read(vm_ref).to_string(), "value");
        assert_eq!(format!("{:?}", lock.read(vm_ref)), "\"value\"");

        let raw = lock.data_ptr();
        unsafe {
            assert_eq!(&*raw, "value");
        }

        assert_eq!(lock.into_inner(), "value");
    }
}