#![warn(
clippy::pedantic,
rust_2018_idioms,
missing_docs,
unused_qualifications
)]
#![cfg_attr(not(test), no_std)]
use ::core::{
cell::UnsafeCell,
fmt::{self, Debug, Display, Formatter},
marker::PhantomData,
mem::ManuallyDrop,
ops::{Deref, DerefMut},
ptr::NonNull,
sync::atomic::{self, AtomicUsize},
};
/// A reader-writer lock whose acquisition methods never block: each attempt
/// either succeeds immediately or returns `None`.
///
/// Any number of [`ReadGuard`]s may coexist, while a [`WriteGuard`] is exclusive.
#[derive(Default)]
pub struct TryRwLock<T> {
    /// The protected value.
    data: UnsafeCell<T>,
    /// Lock state: `0` when unlocked, `usize::MAX` while write-locked, and
    /// otherwise the number of outstanding read locks.
    readers: AtomicUsize,
}
impl<T> TryRwLock<T> {
    /// Creates a new, unlocked lock holding `data`.
    #[must_use]
    pub const fn new(data: T) -> Self {
        Self {
            readers: AtomicUsize::new(0),
            data: UnsafeCell::new(data),
        }
    }

    /// Attempts to take a shared read lock without blocking.
    ///
    /// Returns `None` if the lock is write-locked. It also returns `None` if the
    /// reader count is so large that one more reader would collide with the
    /// write-lock marker (`usize::MAX`).
    pub fn try_read(&self) -> Option<ReadGuard<'_, T>> {
        self.readers
            .fetch_update(
                atomic::Ordering::Acquire,
                atomic::Ordering::Relaxed,
                // `usize::MAX` means "write-locked", so the reader count must stop one
                // short of it. Otherwise a crowd of readers would be indistinguishable
                // from a writer (e.g. to `is_write_locked`).
                |readers| readers.checked_add(1).filter(|&next| next != usize::MAX),
            )
            .ok()
            // SAFETY: one read lock was just registered; the guard takes ownership of it.
            .map(|_| unsafe { ReadGuard::new(self) })
    }

    /// Attempts to take an exclusive write lock without blocking.
    ///
    /// Returns `None` if any read or write lock is currently held.
    pub fn try_write(&self) -> Option<WriteGuard<'_, T>> {
        self.readers
            .compare_exchange(
                0,
                usize::MAX,
                atomic::Ordering::Acquire,
                atomic::Ordering::Relaxed,
            )
            .ok()
            // SAFETY: the state went from "unlocked" straight to "write-locked"; the
            // guard takes ownership of that exclusive lock.
            .map(|_| unsafe { WriteGuard::new(self) })
    }

    /// Consumes the lock and returns the protected value.
    ///
    /// Owning the lock means no guard can still borrow it, so no locking is needed.
    #[must_use]
    pub fn into_inner(self) -> T {
        self.data.into_inner()
    }

    /// Returns a mutable reference to the protected value without locking.
    ///
    /// The `&mut self` borrow statically rules out any live guard.
    #[must_use]
    pub fn get_mut(&mut self) -> &mut T {
        self.data.get_mut()
    }

    /// Returns `true` if any read or write lock is currently held.
    ///
    /// This is a snapshot; other threads may change the state immediately afterwards.
    #[must_use]
    pub fn is_locked(&self) -> bool {
        self.readers.load(atomic::Ordering::Acquire) != 0
    }

    /// Returns `true` if a write lock is currently held.
    ///
    /// This is a snapshot; other threads may change the state immediately afterwards.
    #[must_use]
    pub fn is_write_locked(&self) -> bool {
        self.readers.load(atomic::Ordering::Acquire) == usize::MAX
    }
}
impl<T: Debug> Debug for TryRwLock<T> {
    /// Formats the lock as `TryRwLock { data: .. }`.
    ///
    /// Only a non-blocking read is attempted. When it fails (e.g. while
    /// write-locked), the value is shown as `<locked>`.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        // Stand-in printed when the value cannot be borrowed for reading.
        struct LockedPlaceholder;

        impl Debug for LockedPlaceholder {
            fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
                f.write_str("<locked>")
            }
        }

        let mut debug = f.debug_struct("TryRwLock");
        if let Some(guard) = self.try_read() {
            debug.field("data", &*guard);
        } else {
            debug.field("data", &LockedPlaceholder);
        }
        debug.finish()
    }
}
impl<T> From<T> for TryRwLock<T> {
fn from(data: T) -> Self {
Self::new(data)
}
}
// SAFETY: the lock owns its `T`, so moving the lock to another thread moves the
// `T` with it — exactly what `T: Send` permits.
unsafe impl<T> Send for TryRwLock<T> where T: Send {}
// SAFETY: through a shared `&TryRwLock<T>`, several threads can hold `&T` at once
// (needs `T: Sync`). Any thread can also obtain `&mut T` via a write guard and so
// move the value between threads (needs `T: Send`).
unsafe impl<T> Sync for TryRwLock<T> where T: Send + Sync {}
/// RAII guard for a shared read lock on a [`TryRwLock`].
///
/// `T` is the type stored in the lock. `U` is the type the guard dereferences
/// to, which differs from `T` after [`ReadGuard::map`]. Dropping the guard
/// releases its read lock.
#[must_use = "if unused the TryRwLock will immediately unlock"]
pub struct ReadGuard<'lock, T, U = T> {
    /// The lock whose read count this guard decrements on drop.
    lock: &'lock TryRwLock<T>,
    /// Pointer to the (possibly mapped) data this guard reads.
    data: NonNull<U>,
}
// SAFETY: a `ReadGuard` hands out `&U`, so moving it to another thread needs
// `U: Sync`. It also carries the `&'lock TryRwLock<T>` it came from:
// - on the receiving thread, `ReadGuard::try_upgrade` can produce `&mut T`, which
//   lets the `T` be moved out or dropped there;
// - `ReadGuard::rwlock` exposes the lock itself.
// Sending the guard therefore amounts to sending `&TryRwLock<T>`, which is only
// sound when `TryRwLock<T>: Sync`, i.e. `T: Send + Sync`. `T: Sync` alone would
// let a `Sync + !Send` value be dropped on a foreign thread.
unsafe impl<T: Send + Sync, U: Sync> Send for ReadGuard<'_, T, U> {}
// SAFETY: sharing `&ReadGuard` exposes `&U` (needs `U: Sync`). The lock reachable
// through `ReadGuard::rwlock` stays read-locked by this guard while it is
// borrowed, so other threads can only take further shared reads of `T` through
// it (needs `T: Sync`).
unsafe impl<T: Sync, U: Sync> Sync for ReadGuard<'_, T, U> {}
impl<'lock, T> ReadGuard<'lock, T> {
    /// Wraps an already-registered read lock in a guard over the whole value.
    ///
    /// # Safety
    ///
    /// The caller must have registered one read lock on `lock` (a unit in its
    /// reader count) and hand it over to this guard. The guard releases it on drop.
    unsafe fn new(lock: &'lock TryRwLock<T>) -> Self {
        let data = NonNull::new(lock.data.get()).expect("`UnsafeCell::get` never returns null");
        Self { lock, data }
    }
}
impl<'lock, T, U> ReadGuard<'lock, T, U> {
    /// Returns the lock this guard was acquired from.
    ///
    /// Called as `ReadGuard::rwlock(&guard)`. As an associated function, it cannot
    /// collide with methods of `U` reached through `Deref`.
    #[must_use]
    pub fn rwlock(guard: &Self) -> &'lock TryRwLock<T> {
        guard.lock
    }

    /// Attempts to turn this read lock into a write lock without blocking.
    ///
    /// Succeeds only if this guard holds the sole read lock. The returned
    /// [`WriteGuard`] covers the whole value and discards any mapping applied to
    /// this guard.
    ///
    /// # Errors
    ///
    /// Returns the original guard, still holding its read lock, if any other
    /// read lock is outstanding.
    pub fn try_upgrade(guard: Self) -> Result<WriteGuard<'lock, T>, Self> {
        match guard.lock.readers.compare_exchange(
            1,
            usize::MAX,
            atomic::Ordering::Acquire,
            atomic::Ordering::Relaxed,
        ) {
            Ok(_) => {
                // Our read lock has become the write lock, so the decrement in our
                // `Drop` must not run.
                let guard = ManuallyDrop::new(guard);
                // SAFETY: the lock is now write-locked, and that lock is handed to the new guard.
                Ok(unsafe { WriteGuard::new(guard.lock) })
            }
            Err(_) => Err(guard),
        }
    }

    /// Narrows the guard to a component of the guarded data, keeping the read lock.
    ///
    /// If `f` panics and the panic unwinds, `guard` is dropped and its read lock
    /// is released instead of being leaked.
    pub fn map<V>(guard: Self, f: impl FnOnce(&U) -> &V) -> ReadGuard<'lock, T, V> {
        // Run `f` while `guard` still owns the read lock, so a panic inside `f`
        // unlocks normally. Wrapping in `ManuallyDrop` first would leak the lock.
        let data = NonNull::from(f(&*guard));
        // The read lock moves to the new guard; suppress the old guard's release.
        let guard = ManuallyDrop::new(guard);
        ReadGuard {
            data,
            lock: guard.lock,
        }
    }

    /// Widens a mapped guard back to the whole value, keeping the read lock.
    pub fn unmap(guard: Self) -> ReadGuard<'lock, T> {
        let guard = ManuallyDrop::new(guard);
        // SAFETY: the read lock held by `guard` is transferred to the new guard.
        unsafe { ReadGuard::new(guard.lock) }
    }
}
impl<T, U> Deref for ReadGuard<'_, T, U> {
    type Target = U;

    fn deref(&self) -> &U {
        // SAFETY: `data` points into the locked value (or the part of it chosen by
        // `map`). The read lock held by `self` keeps writers out while the returned
        // borrow lives.
        unsafe { &*self.data.as_ptr() }
    }
}
impl<T, U> Drop for ReadGuard<'_, T, U> {
fn drop(&mut self) {
self.lock.readers.fetch_sub(1, atomic::Ordering::Release);
}
}
impl<T, U: Debug> Debug for ReadGuard<'_, T, U> {
    /// Formats as `TryRwLockReadGuard { data: .. }` using the guarded value.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let data: &U = self;
        let mut debug = f.debug_struct("TryRwLockReadGuard");
        debug.field("data", data);
        debug.finish()
    }
}
impl<T, U: Display> Display for ReadGuard<'_, T, U> {
    /// Forwards to the guarded value's `Display`, passing the formatter (and its
    /// flags) through unchanged.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        <U as Display>::fmt(self, f)
    }
}
/// RAII guard for an exclusive write lock on a [`TryRwLock`].
///
/// `T` is the type stored in the lock. `U` is the type the guard dereferences
/// to, which differs from `T` after [`WriteGuard::map`]. Dropping the guard
/// releases the write lock.
#[must_use = "if unused the TryRwLock will immediately unlock"]
pub struct WriteGuard<'lock, T, U = T> {
    /// The lock this guard resets to "unlocked" on drop.
    lock: &'lock TryRwLock<T>,
    /// Pointer to the (possibly mapped) data this guard may mutate.
    data: NonNull<U>,
    /// Makes the guard invariant in `U`, as `&'lock mut U` is. `NonNull<U>` on
    /// its own is covariant, which would be unsound for a mutable view.
    _invariant_over_u: PhantomData<&'lock mut U>,
}
// SAFETY: a `WriteGuard` hands out `&mut U`, so moving it to another thread needs
// `U: Send`. It also carries the `&'lock TryRwLock<T>` it came from:
// - on the receiving thread, `WriteGuard::unmap` yields `&mut T`;
// - `WriteGuard::downgrade` plus reads on the original thread give concurrent `&T`;
// - `WriteGuard::rwlock` exposes the lock itself.
// Sending the guard therefore amounts to sending `&TryRwLock<T>`, which is only
// sound when `TryRwLock<T>: Sync`, i.e. `T: Send + Sync`. With no bound on `T`,
// a guard mapped to a `Send` field of a `!Send` value could smuggle that value
// across threads.
unsafe impl<T: Send + Sync, U: Send> Send for WriteGuard<'_, T, U> {}
// SAFETY: sharing `&WriteGuard` exposes only `&U` (needs `U: Sync`). The lock
// reachable through `WriteGuard::rwlock` stays write-locked while the guard is
// borrowed, so no access to `T` can be obtained through it.
unsafe impl<T, U: Sync> Sync for WriteGuard<'_, T, U> {}
impl<'lock, T> WriteGuard<'lock, T> {
    /// Wraps an already-acquired write lock in a guard over the whole value.
    ///
    /// # Safety
    ///
    /// `lock` must be in the write-locked state (`usize::MAX`). The caller must
    /// hand that exclusive lock over to this guard, which releases it on drop.
    unsafe fn new(lock: &'lock TryRwLock<T>) -> Self {
        let data = NonNull::new(lock.data.get()).expect("`UnsafeCell::get` never returns null");
        Self {
            lock,
            data,
            _invariant_over_u: PhantomData,
        }
    }
}
impl<'lock, T, U> WriteGuard<'lock, T, U> {
    /// Returns the lock this guard was acquired from.
    ///
    /// Called as `WriteGuard::rwlock(&guard)`. As an associated function, it cannot
    /// collide with methods of `U` reached through `Deref`.
    #[must_use]
    pub fn rwlock(guard: &Self) -> &'lock TryRwLock<T> {
        guard.lock
    }

    /// Turns this write lock into a read lock without unlocking in between, so
    /// no other writer can slip in.
    ///
    /// The returned [`ReadGuard`] covers the whole value and discards any mapping
    /// applied to this guard.
    pub fn downgrade(guard: Self) -> ReadGuard<'lock, T> {
        // The write lock is converted, not released: skip this guard's `Drop`.
        let guard = ManuallyDrop::new(guard);
        // Exactly one reader (the returned guard) now holds the lock. `Release`
        // publishes the writes made while the lock was exclusive.
        guard.lock.readers.store(1, atomic::Ordering::Release);
        // SAFETY: the read lock registered by the store above is owned by the new guard.
        unsafe { ReadGuard::new(guard.lock) }
    }

    /// Narrows the guard to a component of the guarded data, keeping the write lock.
    ///
    /// If `f` panics and the panic unwinds, `guard` is dropped and the write
    /// lock is released instead of being leaked.
    pub fn map<V>(mut guard: Self, f: impl FnOnce(&mut U) -> &mut V) -> WriteGuard<'lock, T, V> {
        // Run `f` while `guard` still owns the write lock, so a panic inside `f`
        // unlocks normally. Wrapping in `ManuallyDrop` first would leave the lock
        // write-locked forever.
        let data = NonNull::from(f(&mut *guard));
        // The write lock moves to the new guard; suppress the old guard's release.
        let guard = ManuallyDrop::new(guard);
        WriteGuard {
            data,
            lock: guard.lock,
            _invariant_over_u: PhantomData,
        }
    }

    /// Widens a mapped guard back to the whole value, keeping the write lock.
    pub fn unmap(guard: Self) -> WriteGuard<'lock, T> {
        let guard = ManuallyDrop::new(guard);
        // SAFETY: the lock stays write-locked and ownership passes to the new guard.
        unsafe { WriteGuard::new(guard.lock) }
    }
}
impl<T, U> Deref for WriteGuard<'_, T, U> {
    type Target = U;

    fn deref(&self) -> &U {
        // SAFETY: `data` points into the locked value (or the part chosen by `map`).
        // The write lock held by `self` excludes every other guard.
        unsafe { &*self.data.as_ptr() }
    }
}
impl<T, U> DerefMut for WriteGuard<'_, T, U> {
    fn deref_mut(&mut self) -> &mut U {
        // SAFETY: the write lock held by `self` excludes every other guard, and
        // `&mut self` prevents aliasing through this guard itself.
        unsafe { &mut *self.data.as_ptr() }
    }
}
impl<T, U> Drop for WriteGuard<'_, T, U> {
    /// Releases the write lock, returning the lock to the unlocked state.
    fn drop(&mut self) {
        let Self { lock, .. } = self;
        // `Release` publishes every write made under the lock to the next acquirer.
        lock.readers.store(0, atomic::Ordering::Release);
    }
}
impl<T, U: Debug> Debug for WriteGuard<'_, T, U> {
fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
f.debug_struct("TryRwLockWriteGuard")
.field("data", &**self)
.finish()
}
}
impl<T, U: Display> Display for WriteGuard<'_, T, U> {
    /// Forwards to the guarded value's `Display`, passing the formatter (and its
    /// flags) through unchanged.
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        <U as Display>::fmt(self, f)
    }
}
/// Walks a lock through shared reads, a failed and a successful upgrade, and a
/// downgrade, checking the observable lock state after each step.
#[test]
fn test_read() {
    let lock = TryRwLock::new("Hello World!".to_owned());
    // A fresh lock is completely unlocked.
    assert!(!lock.is_locked());
    assert!(!lock.is_write_locked());
    // Two readers may hold the lock at the same time...
    let guard_1 = lock.try_read().unwrap();
    let guard_2 = lock.try_read().unwrap();
    assert_eq!(&*guard_1, "Hello World!");
    assert_eq!(&*guard_2, "Hello World!");
    // ...but a writer is refused while any reader is active.
    assert!(lock.try_write().is_none());
    assert!(lock.is_locked());
    assert!(!lock.is_write_locked());
    // Upgrading fails while another reader exists; each guard is handed back intact.
    let guard_1 = ReadGuard::try_upgrade(guard_1).unwrap_err();
    let guard_2 = ReadGuard::try_upgrade(guard_2).unwrap_err();
    drop(guard_1);
    // One reader remains: writing is still refused, reading still allowed (the
    // temporary read guard below is released at the end of its statement).
    assert!(lock.try_write().is_none());
    assert!(lock.try_read().is_some());
    // As the sole reader, `guard_2` can now upgrade to an exclusive write lock.
    let guard_2 = ReadGuard::try_upgrade(guard_2).unwrap();
    assert!(lock.try_read().is_none());
    // Downgrading keeps the lock held but lets other readers in again.
    let guard_2 = WriteGuard::downgrade(guard_2);
    assert!(lock.try_read().is_some());
    drop(guard_2);
    // With every guard gone the lock is fully released.
    assert!(!lock.is_locked());
    assert!(!lock.is_write_locked());
}
/// Checks that mapped read guards each hold one read lock, that `unmap` keeps
/// it, and that dropping all of them unlocks.
#[test]
fn test_read_map() {
    let lock = TryRwLock::new(vec![1u8, 2, 3]);
    // Each guard is narrowed to one element but still holds its own read lock.
    let guard_1 = ReadGuard::map(lock.try_read().unwrap(), |v| &v[0]);
    let guard_2 = ReadGuard::map(lock.try_read().unwrap(), |v| &v[1]);
    let guard_3 = ReadGuard::map(lock.try_read().unwrap(), |v| &v[2]);
    assert!(lock.is_locked());
    assert!(!lock.is_write_locked());
    // Mapping neither adds nor removes read locks: three guards, count of three.
    assert_eq!(lock.readers.load(atomic::Ordering::Relaxed), 3);
    assert_eq!(*guard_1, 1);
    assert_eq!(*guard_2, 2);
    assert_eq!(*guard_3, 3);
    // Unmapping widens the view back to the whole vector with the same read lock.
    let guard_1 = ReadGuard::unmap(guard_1);
    assert_eq!(*guard_1, [1, 2, 3]);
    drop(guard_1);
    drop(guard_2);
    drop(guard_3);
    // Dropping every guard, mapped or not, releases all read locks.
    assert!(!lock.is_locked());
    assert!(!lock.is_write_locked());
    assert_eq!(lock.readers.load(atomic::Ordering::Relaxed), 0);
}
/// Checks that a write guard can replace the value, excludes all other access
/// while held, and publishes its write once dropped.
#[test]
fn test_write() {
    let lock = TryRwLock::new(String::from("Hello World!"));

    // The writer sees the initial value and can replace it.
    let mut writer = lock.try_write().unwrap();
    assert_eq!(writer.as_str(), "Hello World!");
    *writer = String::from("Foo");
    assert_eq!(writer.as_str(), "Foo");

    // While the write lock is held, the lock reports it and refuses everyone else.
    assert!(lock.is_locked());
    assert!(lock.is_write_locked());
    assert!(lock.try_read().is_none());
    assert!(lock.try_write().is_none());

    // Releasing the writer unlocks completely and the new value is visible.
    drop(writer);
    assert!(!lock.is_locked());
    assert!(!lock.is_write_locked());
    let reader = lock.try_read().unwrap();
    assert_eq!(reader.as_str(), "Foo");
}
/// Checks that mapping and unmapping a write guard keeps the exclusive lock,
/// and that dropping it unlocks.
#[test]
fn test_write_map() {
    let state = |l: &TryRwLock<Vec<u8>>| l.readers.load(atomic::Ordering::Relaxed);
    let lock = TryRwLock::new(vec![1u8, 2, 3]);

    // Mapping narrows the view to one element but keeps the exclusive lock.
    let element = WriteGuard::map(lock.try_write().unwrap(), |items| &mut items[0]);
    assert!(lock.is_locked());
    assert!(lock.is_write_locked());
    assert_eq!(state(&lock), usize::MAX);
    assert_eq!(*element, 1);

    // Unmapping widens the view again without releasing the lock.
    let whole = WriteGuard::unmap(element);
    assert_eq!(*whole, [1, 2, 3]);

    drop(whole);
    assert!(!lock.is_locked());
    assert!(!lock.is_write_locked());
    assert_eq!(state(&lock), 0);
}