use crate::{
loom::cell::{ConstPtr, MutPtr, UnsafeCell},
spin::RwSpinlock,
util::fmt,
};
use core::{
marker::PhantomData,
ops::{Deref, DerefMut},
};
/// A reader-writer lock protecting a value of type `T`.
///
/// Shared (read) access is obtained with [`RwLock::read`] / [`RwLock::try_read`],
/// and exclusive (write) access with [`RwLock::write`] / [`RwLock::try_write`].
/// The locking itself is delegated to `Lock`, a [`RawRwLock`] implementation
/// that defaults to [`RwSpinlock`].
pub struct RwLock<T: ?Sized, Lock = RwSpinlock> {
    /// The raw lock that arbitrates access to `data`.
    lock: Lock,
    /// The protected value; only accessed through guards or `&mut self`.
    data: UnsafeCell<T>,
}
/// RAII guard granting shared (read) access to the data in a [`RwLock`].
///
/// Created by [`RwLock::read`] and [`RwLock::try_read`]. Dereferences to `&T`;
/// the shared lock is released (via `unlock_shared`) when the guard is dropped.
#[must_use = "if unused, the `RwLock` will immediately unlock"]
pub struct RwLockReadGuard<'lock, T: ?Sized, Lock: RawRwLock = RwSpinlock> {
    /// Read-only pointer to the lock's protected data.
    ptr: ConstPtr<T>,
    /// The raw lock whose shared lock this guard owns.
    lock: &'lock Lock,
    /// Marker type supplied by the lock implementation; this guard's `Send`
    /// impl requires `Lock::GuardMarker: Send`.
    _marker: PhantomData<Lock::GuardMarker>,
}
/// RAII guard granting exclusive (write) access to the data in a [`RwLock`].
///
/// Created by [`RwLock::write`] and [`RwLock::try_write`]. Dereferences to
/// `&T` and `&mut T`; the exclusive lock is released (via `unlock_exclusive`)
/// when the guard is dropped.
#[must_use = "if unused, the `RwLock` will immediately unlock"]
pub struct RwLockWriteGuard<'lock, T: ?Sized, Lock: RawRwLock = RwSpinlock> {
    /// Mutable pointer to the lock's protected data.
    ptr: MutPtr<T>,
    /// The raw lock whose exclusive lock this guard owns.
    lock: &'lock Lock,
    /// Marker type supplied by the lock implementation; this guard's `Send`
    /// impl requires `Lock::GuardMarker: Send`.
    _marker: PhantomData<Lock::GuardMarker>,
}
/// The raw reader-writer lock operations that [`RwLock`] is built on.
///
/// # Safety
///
/// [`RwLock`] hands out `&T` to shared-lock holders and `&mut T` to the
/// exclusive-lock holder based solely on what these methods report. An
/// implementation must therefore guarantee that:
///
/// - while the exclusive lock is held, no other shared or exclusive lock is held;
/// - `lock_shared` / `lock_exclusive` return only once the corresponding lock
///   has actually been acquired, and the `try_*` methods return `true` only
///   when they acquired it.
pub unsafe trait RawRwLock {
    /// Marker type embedded in lock guards. `RwLock` guards are `Send` only when
    /// this type is `Send`, so an implementation can forbid releasing the lock
    /// from a different thread by choosing a non-`Send` marker.
    type GuardMarker;

    /// Acquires a shared lock. `RwLock::read` treats the lock as held as soon
    /// as this returns.
    fn lock_shared(&self);

    /// Attempts to acquire a shared lock, returning `true` if it was acquired.
    fn try_lock_shared(&self) -> bool;

    /// Releases a shared lock.
    ///
    /// # Safety
    ///
    /// The caller must currently hold a shared lock acquired from this lock.
    /// `RwLock` only calls this from a read guard's `Drop`.
    unsafe fn unlock_shared(&self);

    /// Acquires the exclusive lock. `RwLock::write` treats the lock as held as
    /// soon as this returns.
    fn lock_exclusive(&self);

    /// Attempts to acquire the exclusive lock, returning `true` if it was
    /// acquired.
    fn try_lock_exclusive(&self) -> bool;

    /// Releases the exclusive lock.
    ///
    /// # Safety
    ///
    /// The caller must currently hold the exclusive lock acquired from this
    /// lock. `RwLock` only calls this from a write guard's `Drop`.
    unsafe fn unlock_exclusive(&self);

    /// Returns whether the lock is currently held.
    fn is_locked(&self) -> bool;

    /// Returns whether the lock is currently held exclusively; backs
    /// `RwLock::has_writer`.
    fn is_locked_exclusive(&self) -> bool;
}
impl<T> RwLock<T> {
    loom_const_fn! {
        /// Returns a new `RwLock` protecting `data`, using a freshly created
        /// [`RwSpinlock`] (`RwSpinlock::new()`) as the raw lock.
        #[must_use]
        pub fn new(data: T) -> Self {
            Self {
                lock: RwSpinlock::new(),
                data: UnsafeCell::new(data),
            }
        }
    }

    /// Returns the reader count reported by the underlying [`RwSpinlock`].
    ///
    /// NOTE(review): presumably the number of currently held read locks —
    /// confirm against `RwSpinlock::reader_count`. Other threads may change it
    /// immediately after it is read, so treat it as a snapshot.
    #[inline]
    #[must_use]
    pub fn reader_count(&self) -> usize {
        self.lock.reader_count()
    }
}
impl<T: ?Sized, Lock: RawRwLock> RwLock<T, Lock> {
    /// Builds a read guard over this lock's data.
    ///
    /// This does not lock anything. Callers must already hold a shared lock
    /// (from `lock_shared` or a successful `try_lock_shared`). The returned
    /// guard takes ownership of that lock and releases it on drop.
    fn read_guard(&self) -> RwLockReadGuard<'_, T, Lock> {
        RwLockReadGuard {
            ptr: self.data.get(),
            lock: &self.lock,
            _marker: PhantomData,
        }
    }

    /// Builds a write guard over this lock's data.
    ///
    /// This does not lock anything. Callers must already hold the exclusive
    /// lock (from `lock_exclusive` or a successful `try_lock_exclusive`). The
    /// returned guard takes ownership of that lock and releases it on drop.
    fn write_guard(&self) -> RwLockWriteGuard<'_, T, Lock> {
        RwLockWriteGuard {
            ptr: self.data.get_mut(),
            lock: &self.lock,
            _marker: PhantomData,
        }
    }

    /// Acquires a shared (read) lock and returns a guard that dereferences to
    /// `&T`.
    ///
    /// Waits in whatever way [`RawRwLock::lock_shared`] does until the lock is
    /// acquired. The shared lock is released when the guard is dropped.
    #[cfg_attr(test, track_caller)]
    pub fn read(&self) -> RwLockReadGuard<'_, T, Lock> {
        self.lock.lock_shared();
        self.read_guard()
    }

    /// Attempts to acquire a shared (read) lock.
    ///
    /// Returns `None` if [`RawRwLock::try_lock_shared`] reports failure.
    #[cfg_attr(test, track_caller)]
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T, Lock>> {
        self.lock.try_lock_shared().then(|| self.read_guard())
    }

    /// Acquires the exclusive (write) lock and returns a guard that
    /// dereferences to `&mut T`.
    ///
    /// Waits in whatever way [`RawRwLock::lock_exclusive`] does until the lock
    /// is acquired. The exclusive lock is released when the guard is dropped.
    #[cfg_attr(test, track_caller)]
    pub fn write(&self) -> RwLockWriteGuard<'_, T, Lock> {
        self.lock.lock_exclusive();
        self.write_guard()
    }

    /// Returns `true` if the lock is currently held for writing.
    ///
    /// Other threads may lock or unlock concurrently, so the result is only a
    /// snapshot.
    #[inline]
    #[must_use]
    pub fn has_writer(&self) -> bool {
        self.lock.is_locked_exclusive()
    }

    /// Attempts to acquire the exclusive (write) lock.
    ///
    /// Returns `None` if [`RawRwLock::try_lock_exclusive`] reports failure.
    // `track_caller` matches `read`, `try_read` and `write`, so caller
    // locations reported in test builds point at the user of the lock.
    #[cfg_attr(test, track_caller)]
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T, Lock>> {
        self.lock.try_lock_exclusive().then(|| self.write_guard())
    }

    /// Returns a mutable reference to the protected data.
    ///
    /// No locking is performed. `&mut self` statically guarantees that no guard
    /// (each of which borrows `&self`) is alive.
    pub fn get_mut(&mut self) -> &mut T {
        // SAFETY: the exclusive borrow of `self` rules out any outstanding
        // guard, so nothing else can be accessing the data.
        unsafe { self.data.with_mut(|data| &mut *data) }
    }
}
impl<T, Lock: RawRwLock> RwLock<T, Lock> {
    /// Consumes the `RwLock` and hands back the value it was protecting.
    ///
    /// Taking `self` by value means no guard can still be borrowing the lock,
    /// so no locking is needed.
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { data, .. } = self;
        data.into_inner()
    }
}
impl<T: Default, Lock: Default> Default for RwLock<T, Lock> {
    /// Builds an `RwLock` around `T::default()`, using `Lock::default()` as the
    /// raw lock.
    fn default() -> Self {
        // Evaluate the data's default before the lock's, as the struct literal
        // used to.
        let data = UnsafeCell::new(T::default());
        let lock = Lock::default();
        Self { lock, data }
    }
}
impl<T> From<T> for RwLock<T> {
fn from(t: T) -> Self {
RwLock::new(t)
}
}
impl<T, Lock> fmt::Debug for RwLock<T, Lock>
where
    T: fmt::Debug,
    Lock: fmt::Debug + RawRwLock,
{
    /// Formats the protected data and the raw lock.
    ///
    /// The data is read through `try_read`. If no shared lock can be taken,
    /// the placeholder `<write locked>` is shown instead. Any read guard taken
    /// here is held until formatting finishes.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let maybe_guard = self.try_read();
        let data = fmt::opt(&maybe_guard).or_else("<write locked>");
        let mut out = f.debug_struct("RwLock");
        out.field("data", &data);
        out.field("lock", &self.lock);
        out.finish()
    }
}
// SAFETY: moving an `RwLock` moves both the protected `T` and the raw `Lock`
// to another thread, so both must be `Send`.
unsafe impl<T: ?Sized + Send, Lock: Send> Send for RwLock<T, Lock> {}
// SAFETY: sharing `&RwLock` across threads lets them:
// - call the raw lock's methods concurrently, which requires `Lock: Sync`;
// - read `&T` concurrently through read guards, which requires `T: Sync`;
// - obtain `&mut T` on whichever thread holds the write guard, which
//   requires `T: Send`.
unsafe impl<T: ?Sized + Send + Sync, Lock: Sync> Sync for RwLock<T, Lock> {}
impl<T: ?Sized, Lock: RawRwLock> Deref for RwLockReadGuard<'_, T, Lock> {
    type Target = T;

    /// Borrows the protected data immutably for as long as the guard is
    /// borrowed.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let data_ptr = &self.ptr;
        // SAFETY: read guards are only built while a shared lock is held, so
        // no writer can be mutating the data during this borrow.
        unsafe { data_ptr.deref() }
    }
}
impl<T: ?Sized, R: ?Sized, Lock> AsRef<R> for RwLockReadGuard<'_, T, Lock>
where
T: AsRef<R>,
Lock: RawRwLock,
{
#[inline]
fn as_ref(&self) -> &R {
self.deref().as_ref()
}
}
impl<T: ?Sized, Lock: RawRwLock> Drop for RwLockReadGuard<'_, T, Lock> {
    /// Releases the shared lock owned by this guard.
    #[inline]
    #[cfg_attr(test, track_caller)]
    fn drop(&mut self) {
        let raw: &Lock = self.lock;
        // SAFETY: a read guard is only constructed after a shared lock was
        // acquired. Drop runs once, so exactly that lock is released here.
        unsafe {
            raw.unlock_shared();
        }
    }
}
impl<T, Lock> fmt::Debug for RwLockReadGuard<'_, T, Lock>
where
T: ?Sized + fmt::Debug,
Lock: RawRwLock,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
impl<T, Lock> fmt::Display for RwLockReadGuard<'_, T, Lock>
where
T: ?Sized + fmt::Display,
Lock: RawRwLock,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
// SAFETY: a shared `&RwLockReadGuard` only exposes `&T` (via `Deref`, `AsRef`,
// `Debug` and `Display`), so sharing it across threads is sound when
// `T: Sync`.
// NOTE(review): the `Lock: Sync` bound looks conservative, since nothing
// reachable through `&RwLockReadGuard` calls into the raw lock.
unsafe impl<T, Lock> Sync for RwLockReadGuard<'_, T, Lock>
where
    T: ?Sized + Sync,
    Lock: RawRwLock + Sync,
{
}
// SAFETY: sending a read guard to another thread has three consequences:
// - its `&Lock` moves there while other threads may still use the same lock,
//   which requires `Lock: Sync`;
// - that thread reads the data through `&T` alongside other readers, which
//   requires `T: Sync`;
// - `unlock_shared` runs on that thread when the guard is dropped, which the
//   lock implementation must allow via `Lock::GuardMarker: Send`.
unsafe impl<T, Lock> Send for RwLockReadGuard<'_, T, Lock>
where
    T: ?Sized + Sync,
    Lock: RawRwLock + Sync,
    Lock::GuardMarker: Send,
{
}
impl<T: ?Sized, Lock: RawRwLock> Deref for RwLockWriteGuard<'_, T, Lock> {
    type Target = T;

    /// Borrows the exclusively-locked data immutably.
    #[inline]
    fn deref(&self) -> &Self::Target {
        // SAFETY: write guards are only built while the exclusive lock is held,
        // so this guard is the sole accessor of the data.
        let shared: &T = unsafe { self.ptr.deref() };
        shared
    }
}
impl<T: ?Sized, Lock: RawRwLock> DerefMut for RwLockWriteGuard<'_, T, Lock> {
    /// Borrows the exclusively-locked data mutably.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let data_ptr = &self.ptr;
        // SAFETY: the exclusive lock is held for the guard's lifetime, and
        // `&mut self` ensures this is the only live borrow through the guard.
        unsafe { data_ptr.deref() }
    }
}
impl<T: ?Sized, R: ?Sized, Lock> AsRef<R> for RwLockWriteGuard<'_, T, Lock>
where
T: AsRef<R>,
Lock: RawRwLock,
{
#[inline]
fn as_ref(&self) -> &R {
self.deref().as_ref()
}
}
impl<T: ?Sized, Lock: RawRwLock> Drop for RwLockWriteGuard<'_, T, Lock> {
    /// Releases the exclusive lock owned by this guard.
    #[inline]
    #[cfg_attr(test, track_caller)]
    fn drop(&mut self) {
        let raw: &Lock = self.lock;
        // SAFETY: a write guard is only constructed after the exclusive lock
        // was acquired. Drop runs once, so exactly that lock is released here.
        unsafe {
            raw.unlock_exclusive();
        }
    }
}
impl<T, Lock> fmt::Debug for RwLockWriteGuard<'_, T, Lock>
where
T: ?Sized + fmt::Debug,
Lock: RawRwLock,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
impl<T, Lock> fmt::Display for RwLockWriteGuard<'_, T, Lock>
where
T: ?Sized + fmt::Display,
Lock: RawRwLock,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
// SAFETY: sending a write guard to another thread has three consequences:
// - that thread gets `&mut T`, which requires `T: Send`; the `T: Sync` bound
//   is kept as before;
// - `unlock_exclusive` runs on that thread when the guard is dropped, which
//   the lock implementation must allow via `Lock::GuardMarker: Send`;
// - the guard's `&Lock` moves there while the original thread can still call
//   the raw lock through `&RwLock` (e.g. `try_read`), which requires
//   `Lock: Sync`.
// The last bound was missing, so a write guard for a `!Sync` lock could be
// sent across threads (e.g. with `thread::scope`) and race on the lock state.
// This now matches the read guard's `Send` bounds.
unsafe impl<T, Lock> Send for RwLockWriteGuard<'_, T, Lock>
where
    T: ?Sized + Send + Sync,
    Lock: RawRwLock + Sync,
    Lock::GuardMarker: Send,
{
}
// SAFETY: a shared `&RwLockWriteGuard` only exposes `&T` (via `Deref`,
// `AsRef`, `Debug` and `Display`); `DerefMut` and `Drop` need exclusive
// access to the guard. `T: Sync` is therefore what makes sharing sound.
// NOTE(review): the extra `T: Send` bound appears stricter than required.
unsafe impl<T, Lock> Sync for RwLockWriteGuard<'_, T, Lock>
where
    T: ?Sized + Send + Sync,
    Lock: RawRwLock,
{
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::loom::{self, sync::Arc, thread};

    /// Concurrent writers each increment the counter under the write lock;
    /// afterwards every increment must be visible.
    #[test]
    fn write() {
        const WRITERS: usize = 2;
        loom::model(|| {
            let lock = Arc::new(RwLock::<usize>::new(0));
            let threads = (0..WRITERS)
                .map(|_| thread::spawn(writer(lock.clone())))
                .collect::<Vec<_>>();
            for thread in threads {
                thread.join().expect("writer thread mustn't panic");
            }
            let guard = lock.read();
            assert_eq!(*guard, WRITERS, "final state must equal number of writers");
        });
    }

    /// A reader racing with writers must observe a state between "no writer
    /// has run" and "every writer has run".
    #[test]
    fn read_write() {
        // Presumably fewer writers under loom to keep the explored state
        // space tractable.
        const WRITERS: usize = if cfg!(loom) { 1 } else { 2 };
        loom::model(|| {
            let lock = Arc::new(RwLock::<usize>::new(0));
            let w_threads = (0..WRITERS)
                .map(|_| thread::spawn(writer(lock.clone())))
                .collect::<Vec<_>>();
            {
                let guard = lock.read();
                // Each writer adds exactly 1, so any interleaving yields a
                // value in `0..=WRITERS`. Deriving the bound from `WRITERS`
                // keeps it exact even when the writer count changes.
                assert!(
                    *guard <= WRITERS,
                    "reader observed {}, but only {} writers exist",
                    *guard,
                    WRITERS
                );
            }
            for thread in w_threads {
                thread.join().expect("writer thread mustn't panic");
            }
            let guard = lock.read();
            assert_eq!(*guard, WRITERS, "final state must equal number of writers");
        });
    }

    /// Returns a closure that takes the write lock once and increments the
    /// counter.
    fn writer(lock: Arc<RwLock<usize>>) -> impl FnOnce() {
        move || {
            test_debug!("trying to acquire write lock...");
            let mut guard = lock.write();
            test_debug!("got write lock!");
            *guard += 1;
        }
    }
}