use super::semaphore::{self, Semaphore};
use crate::{
blocking::RawMutex,
loom::cell::{self, UnsafeCell},
spin::Spinlock,
util::fmt,
};
use core::ops::{Deref, DerefMut};
#[cfg(test)]
mod tests;
/// An asynchronous readers-writer lock protecting a value of type `T`.
///
/// Access is arbitrated by a semaphore created with `MAX_READERS` permits:
/// each reader holds one permit, while a writer holds all of them, so a
/// writer excludes every reader and any other writer. `Lock` is the raw
/// mutex type handed to that semaphore (a `Spinlock` by default).
pub struct RwLock<T: ?Sized, Lock: RawMutex = Spinlock> {
    // Semaphore whose permits represent access rights: one per reader, all of
    // them for a writer.
    sem: Semaphore<Lock>,
    // The protected value. In this file it is only accessed through guards
    // holding permits, through `get_mut` (`&mut self`), or `into_inner` (by value).
    data: UnsafeCell<T>,
}
/// Guard providing shared (read) access to the data of an [`RwLock`].
///
/// Returned by `RwLock::read` and `RwLock::try_read`. The guard holds one
/// permit of the lock's semaphore; dropping it unlocks the `RwLock`.
#[must_use = "if unused, the `RwLock` will immediately unlock"]
pub struct RwLockReadGuard<'lock, T: ?Sized, Lock: RawMutex = Spinlock> {
    // Pointer to the lock's data, obtained from `UnsafeCell::get`.
    data: cell::ConstPtr<T>,
    // The single permit held by this reader; kept alive only so it is
    // released when the guard is dropped.
    _permit: semaphore::Permit<'lock, Lock>,
}
/// Guard providing exclusive (write) access to the data of an [`RwLock`].
///
/// Returned by `RwLock::write` and `RwLock::try_write`. The guard holds all
/// `MAX_READERS` permits of the lock's semaphore; dropping it unlocks the
/// `RwLock`.
#[must_use = "if unused, the `RwLock` will immediately unlock"]
pub struct RwLockWriteGuard<'lock, T: ?Sized, Lock: RawMutex = Spinlock> {
    // Mutable pointer to the lock's data, obtained from `UnsafeCell::get_mut`.
    data: cell::MutPtr<T>,
    // Every permit of the semaphore; kept alive only so they are released
    // when the guard is dropped.
    _permit: semaphore::Permit<'lock, Lock>,
}
feature! {
#![feature = "alloc"]
mod owned;
pub use self::owned::{OwnedRwLockReadGuard, OwnedRwLockWriteGuard};
}
impl<T> RwLock<T> {
    loom_const_fn! {
        /// Returns a new `RwLock` protecting `data`.
        ///
        /// The lock's semaphore uses a `Spinlock` as its raw mutex; use
        /// `RwLock::new_with_raw_mutex` to supply a different one.
        #[must_use]
        pub fn new(data: T) -> Self {
            let lock = Spinlock::new();
            Self::new_with_raw_mutex(data, lock)
        }
    }
}
impl<T, Lock: RawMutex> RwLock<T, Lock> {
    loom_const_fn! {
        /// Returns a new `RwLock` protecting `data`, handing `lock` to the
        /// underlying semaphore as its raw mutex.
        ///
        /// The semaphore starts with `MAX_READERS` permits, so the lock is
        /// initially unlocked.
        #[must_use]
        pub fn new_with_raw_mutex(data: T, lock: Lock) -> Self {
            let sem = Semaphore::new_with_raw_mutex(Self::MAX_READERS, lock);
            let data = UnsafeCell::new(data);
            Self { sem, data }
        }
    }

    /// Consumes this `RwLock`, returning the protected data.
    ///
    /// Taking `self` by value guarantees that no guard borrowing the lock is
    /// still alive, so no locking is required.
    #[inline]
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { data, .. } = self;
        data.into_inner()
    }
}
impl<T: ?Sized, Lock: RawMutex> RwLock<T, Lock> {
    /// Number of permits the lock's semaphore is created with.
    ///
    /// A reader holds one permit and a writer holds all of them, so this is
    /// also the maximum number of readers that can hold the lock at once.
    const MAX_READERS: usize = semaphore::MAX_PERMITS;

    /// Locks this `RwLock` for shared read access, waiting until a read
    /// permit can be acquired.
    ///
    /// The returned guard dereferences to the protected data, and the lock
    /// is released when the guard is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the underlying semaphore has been closed, which should
    /// never happen.
    pub async fn read(&self) -> RwLockReadGuard<'_, T, Lock> {
        // Shared access only needs a single permit.
        let acquired = self.sem.acquire(1).await;
        let permit = acquired.expect("RwLock semaphore should never be closed");
        RwLockReadGuard {
            data: self.data.get(),
            _permit: permit,
        }
    }

    /// Locks this `RwLock` for exclusive write access, waiting until every
    /// permit of the semaphore can be acquired.
    ///
    /// The returned guard mutably dereferences to the protected data, and
    /// the lock is released when the guard is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the underlying semaphore has been closed, which should
    /// never happen.
    pub async fn write(&self) -> RwLockWriteGuard<'_, T, Lock> {
        // Holding every permit excludes all readers and any other writer.
        let acquired = self.sem.acquire(Self::MAX_READERS).await;
        let permit = acquired.expect("RwLock semaphore should never be closed");
        RwLockWriteGuard {
            data: self.data.get_mut(),
            _permit: permit,
        }
    }

    /// Attempts to lock this `RwLock` for shared read access without waiting.
    ///
    /// Returns `None` if a read permit cannot be acquired immediately (for
    /// example, because a writer holds the lock).
    ///
    /// # Panics
    ///
    /// Panics if the underlying semaphore has been closed, which should
    /// never happen.
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T, Lock>> {
        let permit = match self.sem.try_acquire(1) {
            Ok(permit) => permit,
            Err(semaphore::TryAcquireError::InsufficientPermits) => return None,
            Err(semaphore::TryAcquireError::Closed) => {
                unreachable!("RwLock semaphore should never be closed")
            }
        };
        Some(RwLockReadGuard {
            data: self.data.get(),
            _permit: permit,
        })
    }

    /// Attempts to lock this `RwLock` for exclusive write access without
    /// waiting.
    ///
    /// Returns `None` if every permit cannot be acquired immediately (for
    /// example, because a reader or another writer holds the lock).
    ///
    /// # Panics
    ///
    /// Panics if the underlying semaphore has been closed, which should
    /// never happen.
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T, Lock>> {
        let permit = match self.sem.try_acquire(Self::MAX_READERS) {
            Ok(permit) => permit,
            Err(semaphore::TryAcquireError::InsufficientPermits) => return None,
            Err(semaphore::TryAcquireError::Closed) => {
                unreachable!("RwLock semaphore should never be closed")
            }
        };
        Some(RwLockWriteGuard {
            data: self.data.get_mut(),
            _permit: permit,
        })
    }

    /// Returns a mutable reference to the protected data without locking.
    ///
    /// Borrowing `self` mutably statically guarantees that no other
    /// reference to this `RwLock`, and therefore no guard borrowing it, is
    /// alive.
    pub fn get_mut(&mut self) -> &mut T {
        // SAFETY: `&mut self` is an exclusive borrow of the lock, so no guard
        // (each of which borrows `&self`) can exist, and the returned
        // reference is bounded by that exclusive borrow.
        unsafe { self.data.with_mut(|ptr| &mut *ptr) }
    }
}
impl<T: Default> Default for RwLock<T> {
    /// Creates an unlocked `RwLock` (backed by the default `Spinlock`)
    /// protecting `T::default()`.
    fn default() -> Self {
        let value = T::default();
        Self::new(value)
    }
}
impl<T, Lock> fmt::Debug for RwLock<T, Lock>
where
    T: ?Sized + fmt::Debug,
    Lock: RawMutex + fmt::Debug,
{
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to `RwLock` forces this
        // impl to be revisited.
        let Self { sem, data: _ } = self;
        let mut out = f.debug_struct("RwLock");
        // Format the semaphore before taking the read guard, so its state
        // is shown without our own permit.
        out.field("sem", sem);
        // Only show the data when a read lock is available without waiting.
        // The guard stays alive until after `finish`.
        let guard = self.try_read();
        out.field("data", &fmt::opt(&guard).or_else("<locked>"));
        out.finish()
    }
}
// SAFETY: an `RwLock` owns both its `T` (inside the `UnsafeCell`) and its
// `Semaphore<Lock>`, so sending the lock to another thread sends both of
// them; that is sound when `T: Send` and `Lock: Send`.
unsafe impl<T, Lock> Send for RwLock<T, Lock>
where
    T: ?Sized + Send,
    Lock: RawMutex + Send,
{
}
// SAFETY: through `&RwLock`, any thread can obtain `&T` via read guards
// (requires `T: Sync`) and `&mut T` via write guards, which effectively moves
// access to the value between threads (requires `T: Send`). The semaphore is
// shared by every thread using the lock, hence `Lock: Sync`.
unsafe impl<T, Lock> Sync for RwLock<T, Lock>
where
    T: ?Sized + Send + Sync,
    Lock: RawMutex + Sync,
{
}
impl<T: ?Sized, Lock: RawMutex> Deref for RwLockReadGuard<'_, T, Lock> {
    type Target = T;

    /// Borrows the protected data immutably.
    #[inline]
    fn deref(&self) -> &T {
        // SAFETY: this guard holds a semaphore permit, and a write guard needs
        // every permit, so no write guard (and thus no `&mut T`) can exist
        // while this guard is alive.
        unsafe { self.data.deref() }
    }
}
impl<T: ?Sized + fmt::Debug, Lock: RawMutex> fmt::Debug for RwLockReadGuard<'_, T, Lock> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
// SAFETY: a read guard only ever exposes `&T`, so sending it to another
// thread is equivalent to sending `&T` (requires `T: Sync`). It also holds a
// `semaphore::Permit<'lock, Lock>` borrowed from the lock, which the other
// thread will release on drop, hence `Lock: Sync`.
unsafe impl<T, Lock> Send for RwLockReadGuard<'_, T, Lock>
where
    T: ?Sized + Sync,
    Lock: RawMutex + Sync,
{
}
// SAFETY: sharing `&RwLockReadGuard` between threads only exposes `&T`
// (requires `T: Sync`). The bounds additionally require `T: Send`, which is
// conservative. The permit refers to the shared semaphore, hence `Lock: Sync`.
unsafe impl<T, Lock> Sync for RwLockReadGuard<'_, T, Lock>
where
    T: ?Sized + Send + Sync,
    Lock: RawMutex + Sync,
{
}
impl<T: ?Sized, Lock: RawMutex> Deref for RwLockWriteGuard<'_, T, Lock> {
    type Target = T;

    /// Borrows the protected data immutably.
    #[inline]
    fn deref(&self) -> &T {
        // SAFETY: write guards are only created after acquiring every permit
        // of the lock's semaphore, so no other read or write guard can access
        // the data while this guard is alive.
        unsafe { self.data.deref() }
    }
}
impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
#[inline]
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe {
self.data.deref()
}
}
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.deref().fmt(f)
}
}
unsafe impl<T> Send for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {}
unsafe impl<T> Sync for RwLockWriteGuard<'_, T> where T: ?Sized + Send + Sync {}