use std::cell::UnsafeCell;
use std::fmt;
use std::ops::{Deref, DerefMut};
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::sync::atomic::{AtomicU64, Ordering};
use crate::misc::{PhantomRwLock, PhantomRwLockReadGuard, PhantomRwLockWriteGuard};
use crate::result::{LockResult, PoisonError, TryLockError, TryLockResult};
/// A reader-writer lock with poisoning, modeled after `std::sync::RwLock`
/// but implemented on a single atomic status word (readers spin/yield
/// instead of parking).
pub struct RwLock<T: ?Sized> {
    /// Packed status word: shared-lock count, exclusive flag, poison flag
    /// (see the `SHARED_LOCK_MASK` / `EXCLUSIVE_LOCK_FLAG` / `POISON_FLAG`
    /// constants at the bottom of this file).
    lock: AtomicU64,
    /// Phantom marker from `crate::misc`; semantics defined there.
    _phantom: PhantomRwLock<T>,
    /// The protected value. Kept last because `T` may be unsized.
    data: UnsafeCell<T>,
}
impl<T> RwLock<T> {
    /// Creates a new, unlocked lock wrapping `t`.
    pub const fn new(t: T) -> Self {
        Self {
            lock: AtomicU64::new(INIT),
            _phantom: PhantomRwLock {},
            data: UnsafeCell::new(t),
        }
    }

    /// Consumes the lock and returns the inner value.
    ///
    /// # Errors
    ///
    /// If the lock was poisoned, the value is returned wrapped in a
    /// `PoisonError` instead.
    pub fn into_inner(self) -> LockResult<T> {
        // Check the poison flag before `self` is consumed below.
        let poisoned = self.is_poisoned();
        let value = self.data.into_inner();
        if poisoned {
            Err(PoisonError::new(value))
        } else {
            Ok(value)
        }
    }
}
impl<T: ?Sized> RwLock<T> {
    /// Maximum number of shared (read) locks that may be held at once.
    pub const MAX_READ_LOCK_COUNT: u64 = SHARED_LOCK_MASK;

    /// Acquires a shared (read) lock, spinning with `yield_now` while a
    /// writer holds the lock.
    ///
    /// # Errors
    ///
    /// If the lock is poisoned, the shared lock is still acquired and the
    /// guard is returned inside a `PoisonError`.
    pub fn read(&self) -> LockResult<RwLockReadGuard<'_, T>> {
        loop {
            match self.try_lock(acquire_shared_lock, is_locked_exclusively) {
                s if is_locked_exclusively(s) => std::thread::yield_now(),
                s if is_poisoned(s) => return Err(PoisonError::new(RwLockReadGuard::new(self))),
                _ => return Ok(RwLockReadGuard::new(self)),
            }
        }
    }

    /// Attempts to acquire a shared (read) lock without blocking.
    ///
    /// # Errors
    ///
    /// Returns `TryLockError::WouldBlock` (nothing acquired) when a writer
    /// holds the lock, or `TryLockError::Poisoned` (lock acquired, guard
    /// inside) when the lock is poisoned.
    pub fn try_read(&self) -> TryLockResult<RwLockReadGuard<T>> {
        match self.try_lock(acquire_shared_lock, is_locked_exclusively) {
            s if is_locked_exclusively(s) => Err(TryLockError::WouldBlock),
            s if is_poisoned(s) => Err(TryLockError::Poisoned(PoisonError::new(
                RwLockReadGuard::new(self),
            ))),
            _ => Ok(RwLockReadGuard::new(self)),
        }
    }

    /// Attempts to acquire the exclusive (write) lock without blocking.
    ///
    /// # Errors
    ///
    /// Returns `TryLockError::WouldBlock` (nothing acquired) when any lock
    /// is held, or `TryLockError::Poisoned` (lock acquired, guard inside)
    /// when the lock is poisoned.
    pub fn try_write(&self) -> TryLockResult<RwLockWriteGuard<T>> {
        match self.try_lock(acquire_exclusive_lock, is_locked) {
            s if is_locked(s) => Err(TryLockError::WouldBlock),
            s if is_poisoned(s) => Err(TryLockError::Poisoned(PoisonError::new(
                RwLockWriteGuard::new(self),
            ))),
            _ => Ok(RwLockWriteGuard::new(self)),
        }
    }

    /// Acquires the exclusive (write) lock, spinning with `yield_now` while
    /// any other lock (shared or exclusive) is held.
    ///
    /// # Errors
    ///
    /// If the lock is poisoned, the exclusive lock is still acquired and
    /// the guard is returned inside a `PoisonError`.
    pub fn write(&self) -> LockResult<RwLockWriteGuard<'_, T>> {
        loop {
            match self.try_lock(acquire_exclusive_lock, is_locked) {
                s if is_locked(s) => std::thread::yield_now(),
                s if is_poisoned(s) => return Err(PoisonError::new(RwLockWriteGuard::new(self))),
                _ => return Ok(RwLockWriteGuard::new(self)),
            }
        }
    }

    /// CAS loop shared by all acquisition paths.
    ///
    /// Returns the status word observed *before* a successful acquisition,
    /// or the conflicting status (nothing acquired) once `lock_check_fn`
    /// reports an incompatible holder. Callers distinguish the two cases by
    /// re-applying `lock_check_fn` to the returned value: `expected` is only
    /// ever advanced past states that fail the check, so a successful
    /// acquisition can never return a status that passes it.
    fn try_lock<AcqFn, LockCheckFn>(&self, acq_fn: AcqFn, lock_check_fn: LockCheckFn) -> LockStatus
    where
        AcqFn: Fn(LockStatus) -> LockStatus,
        LockCheckFn: Fn(LockStatus) -> bool,
    {
        let mut expected = INIT;
        loop {
            let desired = acq_fn(expected);
            // `AtomicU64::compare_and_swap` is deprecated since Rust 1.50;
            // `compare_exchange` with explicit orderings replaces it. The
            // failure ordering can be `Relaxed`: on failure we only inspect
            // the returned word, never the protected data.
            match self
                .lock
                .compare_exchange(expected, desired, Ordering::Acquire, Ordering::Relaxed)
            {
                // Lock acquired; report the previous status so callers can
                // inspect the poison flag.
                Ok(previous) => return previous,
                // Someone holds an incompatible lock: give up and report it.
                Err(current) if lock_check_fn(current) => return current,
                // Status changed in a compatible way (e.g. another reader
                // arrived); retry from the freshly observed value.
                Err(current) => expected = current,
            }
        }
    }

    /// Whether a writer panicked while holding the lock.
    pub fn is_poisoned(&self) -> bool {
        let status = self.lock.load(Ordering::Relaxed);
        is_poisoned(status)
    }

    /// Returns a mutable reference to the data without locking: `&mut self`
    /// statically guarantees no guard exists.
    ///
    /// # Errors
    ///
    /// If the lock is poisoned, the reference is returned wrapped in a
    /// `PoisonError` instead.
    pub fn get_mut(&mut self) -> LockResult<&mut T> {
        // SAFETY: `&mut self` grants exclusive access to the `UnsafeCell`,
        // so no other reference to the data can exist.
        let data = unsafe { &mut *self.data.get() };
        if self.is_poisoned() {
            Err(PoisonError::new(data))
        } else {
            Ok(data)
        }
    }
}
impl<T> From<T> for RwLock<T> {
fn from(t: T) -> Self {
RwLock::new(t)
}
}
impl<T: Default> Default for RwLock<T> {
fn default() -> Self {
RwLock::new(T::default())
}
}
impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLock<T> {
    /// Formats the lock; the data is shown when a shared lock can be taken
    /// (also when poisoned), and `<locked>` when a writer holds the lock.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `DebugStruct::field` writes eagerly, so the read guard only needs
        // to live for the duration of the call.
        let mut out = f.debug_struct("RwLock");
        match self.try_read() {
            Ok(guard) => out.field("data", &&*guard),
            Err(TryLockError::Poisoned(err)) => out.field("data", &&**err.get_ref()),
            Err(TryLockError::WouldBlock) => {
                // Placeholder shown while a writer holds the lock.
                struct Locked;
                impl fmt::Debug for Locked {
                    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                        f.write_str("<locked>")
                    }
                }
                out.field("data", &Locked)
            }
        }
        .finish()
    }
}
/// RAII guard for shared (read) access; dropping it releases one shared
/// lock.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockReadGuard<'a, T: ?Sized + 'a> {
    /// The lock this guard keeps read-locked.
    rwlock: &'a RwLock<T>,
    /// Phantom marker from `crate::misc`; semantics defined there.
    _phantom: PhantomRwLockReadGuard<'a, T>,
}
impl<'a, T: ?Sized> RwLockReadGuard<'a, T> {
    /// Builds a guard for `rwlock`; a shared lock must already be held.
    fn new(rwlock: &'a RwLock<T>) -> Self {
        let _phantom = Default::default();
        Self { rwlock, _phantom }
    }
}
impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for the guard's lifetime.
    fn deref(&self) -> &Self::Target {
        // SAFETY: the guard's existence means a shared lock is held, so no
        // exclusive (write) access can be granted concurrently; handing out
        // `&T` is therefore sound.
        unsafe { &*self.rwlock.data.get() }
    }
}
impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
    /// Releases one shared lock.
    fn drop(&mut self) {
        // Releasing a shared lock is always a plain decrement of the reader
        // count (`release_shared_lock` is `s - 1`, and with at least one
        // reader held the subtraction cannot borrow into the flag bits), so
        // a single atomic `fetch_sub` replaces the original retry loop that
        // was built on the deprecated `compare_and_swap`.
        let previous = self.rwlock.lock.fetch_sub(1, Ordering::Release);
        // Keep `release_shared_lock`'s debug assertion that at least one
        // reader was actually held; the computed value is discarded.
        let _ = release_shared_lock(previous);
    }
}
impl<T: ?Sized + fmt::Display> fmt::Display for RwLockReadGuard<'_, T> {
    /// Forwards to the protected value's `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&**self, f)
    }
}
impl<T: fmt::Debug> fmt::Debug for RwLockReadGuard<'_, T> {
    /// Shows the owning lock (whose `Debug` impl can still take another
    /// shared lock while this guard is alive, and thus shows the data).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("RwLockReadGuard");
        builder.field("lock", &self.rwlock);
        builder.finish()
    }
}
/// RAII guard for exclusive (write) access; dropping it releases the
/// exclusive lock, poisoning the lock if the thread is panicking.
#[must_use = "if unused the RwLock will immediately unlock"]
pub struct RwLockWriteGuard<'a, T: ?Sized + 'a> {
    /// The lock this guard keeps write-locked.
    rwlock: &'a RwLock<T>,
    /// Phantom marker from `crate::misc`; semantics defined there.
    _phantom: PhantomRwLockWriteGuard<'a, T>,
}
impl<'a, T: ?Sized> RwLockWriteGuard<'a, T> {
    /// Builds a guard for `rwlock`; the exclusive lock must already be held.
    fn new(rwlock: &'a RwLock<T>) -> Self {
        let _phantom = Default::default();
        Self { rwlock, _phantom }
    }
}
impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
    type Target = T;

    /// Borrows the protected value for the guard's lifetime.
    fn deref(&self) -> &Self::Target {
        // SAFETY: the guard's existence means the exclusive lock is held,
        // so no other reference to the data can be produced concurrently.
        unsafe { &*self.rwlock.data.get() }
    }
}
impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
    /// Mutably borrows the protected value for the guard's lifetime.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: the exclusive lock is held and `&mut self` guarantees this
        // is the only access through the guard, so a unique `&mut T` is
        // sound.
        unsafe { &mut *self.rwlock.data.get() }
    }
}
impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {
    /// Releases the exclusive lock, setting the poison flag when the thread
    /// is unwinding from a panic.
    fn drop(&mut self) {
        // A plain load/store (no CAS) is sufficient here: while the
        // exclusive flag is set, every other thread's `try_lock` bails out
        // of its CAS loop without modifying the status word.
        let status = self.rwlock.lock.load(Ordering::Relaxed);
        let released = release_exclusive_lock(status);
        let final_status = if std::thread::panicking() {
            // Mark the data as possibly inconsistent for later users.
            set_poison_flag(released)
        } else {
            released
        };
        self.rwlock.lock.store(final_status, Ordering::Release);
    }
}
impl<T: ?Sized + fmt::Display> fmt::Display for RwLockWriteGuard<'_, T> {
    /// Forwards to the protected value's `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&**self, f)
    }
}
impl<T: fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
    /// Shows the owning lock. Note the lock's own `Debug` uses `try_read`,
    /// which fails while this write guard is alive, so the data prints as
    /// `<locked>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("RwLockWriteGuard");
        builder.field("lock", &self.rwlock);
        builder.finish()
    }
}
// Panic-safety markers: like `std::sync::RwLock`, the lock stays usable
// across unwinding because a panicking writer sets the poison flag (see
// `RwLockWriteGuard::drop`).
impl<T: ?Sized> UnwindSafe for RwLock<T> {}
impl<T: ?Sized> RefUnwindSafe for RwLock<T> {}
// SAFETY: sending the lock moves the `T` it owns to another thread, so
// `T: Send` suffices.
unsafe impl<T: ?Sized + Send> Send for RwLock<T> {}
// SAFETY: a shared `&RwLock<T>` lets other threads obtain `&T` (read
// guards) and `&mut T` (write guards), so `T` must be both `Send` and
// `Sync`.
unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<T> {}
// SAFETY: a shared reference to a read guard only exposes `&T` (via
// `Deref`), so `T: Sync` suffices.
unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, T> {}
// SAFETY: a shared reference to a write guard also only exposes `&T`;
// `&mut T` requires unique access to the guard itself. `T: Sync` suffices.
unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, T> {}
/// Raw lock status word: low 62 bits count shared locks, bit 62 is the
/// exclusive-lock flag, bit 63 the poison flag.
type LockStatus = u64;

/// Unlocked, unpoisoned initial state.
const INIT: LockStatus = 0;
/// Low 62 bits: number of shared (read) locks currently held.
const SHARED_LOCK_MASK: LockStatus = 0x3fff_ffff_ffff_ffff;
/// Bit 62: set while the exclusive (write) lock is held.
const EXCLUSIVE_LOCK_FLAG: LockStatus = 0x4000_0000_0000_0000;
/// Bit 63: set when a writer panicked while holding the lock.
const POISON_FLAG: LockStatus = 0x8000_0000_0000_0000;

/// Whether the poison flag is set in `s`.
#[must_use]
#[inline]
fn is_poisoned(s: LockStatus) -> bool {
    (s & POISON_FLAG) != 0
}

/// Returns `s` with the poison flag set.
#[must_use]
#[inline]
fn set_poison_flag(s: LockStatus) -> LockStatus {
    s | POISON_FLAG
}

/// Whether any lock (shared or exclusive) is held; the poison flag is
/// ignored.
#[must_use]
#[inline]
fn is_locked(s: LockStatus) -> bool {
    s & (!POISON_FLAG) != 0
}

/// Whether the exclusive (write) lock is held. The exclusive lock excludes
/// all shared locks, so the reader count must then be zero.
#[must_use]
#[inline]
fn is_locked_exclusively(s: LockStatus) -> bool {
    let ret = (s & EXCLUSIVE_LOCK_FLAG) != 0;
    if ret {
        debug_assert_eq!(0, s & SHARED_LOCK_MASK);
    }
    ret
}

/// Returns `s` with the exclusive flag set; `s` must be unlocked.
#[must_use]
#[inline]
fn acquire_exclusive_lock(s: LockStatus) -> LockStatus {
    // `debug_assert!(!cond)` rather than `debug_assert_eq!(false, cond)`
    // (clippy::bool_assert_comparison).
    debug_assert!(!is_locked(s));
    s | EXCLUSIVE_LOCK_FLAG
}

/// Returns `s` with the exclusive flag cleared; `s` must currently hold the
/// exclusive lock.
#[must_use]
#[inline]
fn release_exclusive_lock(s: LockStatus) -> LockStatus {
    debug_assert!(is_locked_exclusively(s));
    s & (!EXCLUSIVE_LOCK_FLAG)
}

/// Number of shared locks currently held in `s`.
#[must_use]
#[inline]
fn count_shared_locks(s: LockStatus) -> u64 {
    let ret = s & SHARED_LOCK_MASK;
    if 0 < ret {
        debug_assert_eq!(0, s & EXCLUSIVE_LOCK_FLAG);
    }
    ret
}

/// Returns `s` with one more shared lock; `s` must not be exclusively
/// locked.
///
/// # Panics
///
/// Panics when the reader count is already `SHARED_LOCK_MASK`.
#[must_use]
#[inline]
fn acquire_shared_lock(s: LockStatus) -> LockStatus {
    debug_assert!(!is_locked_exclusively(s));
    if count_shared_locks(s) == SHARED_LOCK_MASK {
        panic!("rwlock maximum reader count exceeded");
    }
    s + 1
}

/// Returns `s` with one shared lock released; at least one must be held.
#[must_use]
#[inline]
fn release_shared_lock(s: LockStatus) -> LockStatus {
    debug_assert!(0 < count_shared_locks(s));
    s - 1
}
#[cfg(test)]
mod rwlock_tests {
    use super::*;

    /// Repeated `try_read`/`try_write` attempts behave consistently while
    /// a writer, then multiple readers, hold the lock.
    #[test]
    fn try_many_times() {
        let lock = RwLock::new(0);
        {
            let mut writer = lock.try_write().unwrap();
            assert_eq!(0, *writer);
            *writer += 1;
            assert_eq!(1, *writer);
            // While the write guard lives, every other attempt must fail.
            for _ in 0..2 {
                assert!(lock.try_read().is_err());
                assert!(lock.try_write().is_err());
            }
        }
        {
            // Stack up readers; each one sees the written value and keeps
            // writers locked out.
            let mut held = Vec::new();
            for _ in 0..3 {
                let reader = lock.try_read().unwrap();
                assert_eq!(1, *reader);
                held.push(reader);
                assert!(lock.try_write().is_err());
            }
        }
    }
}
#[cfg(test)]
mod lock_state_tests {
    use super::*;

    /// The status regions (reader count, exclusive flag, poison flag) must
    /// occupy pairwise-disjoint bits.
    #[test]
    fn flag_duplication() {
        assert_eq!(0, INIT & SHARED_LOCK_MASK);
        assert_eq!(0, INIT & EXCLUSIVE_LOCK_FLAG);
        assert_eq!(0, INIT & POISON_FLAG);
        assert_eq!(0, SHARED_LOCK_MASK & EXCLUSIVE_LOCK_FLAG);
        assert_eq!(0, SHARED_LOCK_MASK & POISON_FLAG);
        assert_eq!(0, EXCLUSIVE_LOCK_FLAG & POISON_FLAG);
    }

    /// Together the regions must cover every bit of the status word.
    #[test]
    fn flag_uses_all_bits() {
        // `u64::MAX` (associated constant, Rust 1.43+) replaces the legacy
        // `std::u64::MAX` module path.
        assert_eq!(
            u64::MAX,
            INIT | SHARED_LOCK_MASK | EXCLUSIVE_LOCK_FLAG | POISON_FLAG
        );
    }
}