use crate::sync::{futex_wait_fast, NotSend};
use core::cell::UnsafeCell;
use core::fmt;
use core::ops::{Deref, DerefMut};
use core::ptr::NonNull;
use core::sync::atomic::AtomicU32;
use core::sync::atomic::Ordering::{Acquire, Relaxed, Release};
use rusl::futex::futex_wake;
/// A reader-writer lock guarding a value of type `T`.
///
/// Shared access is obtained through `read`/`try_read` and exclusive access
/// through `write`/`try_write`; each hands back a guard that releases the
/// lock when dropped. Blocking is implemented on top of the futex helpers
/// imported at the top of this file.
pub struct RwLock<T: ?Sized> {
/// Lock word plus writer notification counter; all synchronisation happens here.
inner: InnerLock,
/// The protected value. It is only reached through a guard, or through
/// `into_inner`/`get_mut`, which require ownership or `&mut self`.
data: UnsafeCell<T>,
}
// SAFETY: moving the lock to another thread moves the contained `T` with it,
// so `T: Send` is required and sufficient.
unsafe impl<T: ?Sized + Send> Send for RwLock<T> {}
// SAFETY: a shared `&RwLock<T>` lets several threads hold `&T` at once (read
// guards), which needs `T: Sync`, and lets any thread obtain `&mut T` (write
// guard), which needs `T: Send`.
unsafe impl<T: ?Sized + Send + Sync> Sync for RwLock<T> {}
/// Guard representing one held read lock; dereferences to `&T` and releases
/// the read lock when dropped.
#[must_use = "if unused the RwLock will immediately unlock"]
#[clippy::has_significant_drop]
pub struct RwLockReadGuard<'a, T: ?Sized + 'a> {
/// Pointer to the protected value, taken from the lock's `UnsafeCell` when
/// the guard is created.
data: NonNull<T>,
/// The lock word this guard must release on drop.
inner_lock: &'a InnerLock,
// NOTE(review): presumably a marker that keeps the guard from being `Send`;
// `NotSend` is defined in `crate::sync`, confirm there.
_not_send: NotSend,
}
// SAFETY: a shared reference to the guard only exposes `&T` (see the `Deref`
// impl), so sharing the guard across threads is sound whenever `T: Sync`.
unsafe impl<T: ?Sized + Sync> Sync for RwLockReadGuard<'_, T> {}
impl<'rwlock, T: ?Sized> RwLockReadGuard<'rwlock, T> {
    /// Wraps an already-acquired read lock on `lock` in a guard.
    ///
    /// # Safety
    ///
    /// The caller must have just taken a read lock on `lock`. Dropping the
    /// returned guard releases one read lock unconditionally.
    unsafe fn new(lock: &'rwlock RwLock<T>) -> RwLockReadGuard<'rwlock, T> {
        // `UnsafeCell::get` never yields a null pointer, which is the only
        // requirement of `new_unchecked`.
        let data = NonNull::new_unchecked(lock.data.get());
        let inner_lock = &lock.inner;
        Self {
            data,
            inner_lock,
            _not_send: NotSend::new(),
        }
    }
}
/// Guard representing the held write lock; dereferences to `&T`/`&mut T` and
/// releases the write lock when dropped.
#[must_use = "if unused the RwLock will immediately unlock"]
#[clippy::has_significant_drop]
pub struct RwLockWriteGuard<'a, T: ?Sized + 'a> {
/// The lock this guard belongs to; gives access to both the data cell and
/// the lock word that must be released on drop.
lock: &'a RwLock<T>,
// NOTE(review): presumably a marker that keeps the guard from being `Send`;
// `NotSend` is defined in `crate::sync`, confirm there.
_not_send: NotSend,
}
// SAFETY: a shared reference to the guard only exposes `&T` (mutation needs
// `&mut` access to the guard), so `T: Sync` is sufficient.
unsafe impl<T: ?Sized + Sync> Sync for RwLockWriteGuard<'_, T> {}
impl<'rwlock, T: ?Sized> RwLockWriteGuard<'rwlock, T> {
    /// Wraps an already-acquired write lock on `lock` in a guard.
    ///
    /// # Safety
    ///
    /// The caller must have just taken the write lock on `lock`. Dropping the
    /// returned guard releases the write lock unconditionally.
    unsafe fn new(lock: &'rwlock RwLock<T>) -> RwLockWriteGuard<'rwlock, T> {
        let _not_send = NotSend::new();
        Self { lock, _not_send }
    }
}
impl<T> RwLock<T> {
    /// Creates a new, unlocked lock that owns `t`.
    ///
    /// This is a `const fn`, so it can initialise constants and statics.
    #[inline]
    pub const fn new(t: T) -> RwLock<T> {
        let data = UnsafeCell::new(t);
        let inner = InnerLock::new();
        Self { inner, data }
    }
}
impl<T: ?Sized> RwLock<T> {
    /// Acquires shared (read) access, waiting until it can be granted.
    ///
    /// The returned guard releases the read lock when dropped.
    ///
    /// # Panics
    ///
    /// The contended path of the inner lock panics when the maximum number
    /// of simultaneous readers has been reached.
    #[inline]
    pub fn read(&self) -> RwLockReadGuard<'_, T> {
        self.inner.read();
        // SAFETY: the read lock was acquired on the line above.
        unsafe { RwLockReadGuard::new(self) }
    }

    /// Attempts to acquire shared (read) access without waiting.
    ///
    /// Returns `None` when the inner lock reports that a read lock cannot be
    /// taken right now.
    #[inline]
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, T>> {
        if !self.inner.try_read() {
            return None;
        }
        // SAFETY: `try_read` returned `true`, so this thread now holds a read lock.
        Some(unsafe { RwLockReadGuard::new(self) })
    }

    /// Acquires exclusive (write) access, waiting until it can be granted.
    ///
    /// The returned guard releases the write lock when dropped.
    #[inline]
    pub fn write(&self) -> RwLockWriteGuard<'_, T> {
        self.inner.write();
        // SAFETY: the write lock was acquired on the line above.
        unsafe { RwLockWriteGuard::new(self) }
    }

    /// Attempts to acquire exclusive (write) access without waiting.
    ///
    /// Returns `None` when the lock is currently held by anyone.
    #[inline]
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, T>> {
        if !self.inner.try_write() {
            return None;
        }
        // SAFETY: `try_write` returned `true`, so this thread now holds the write lock.
        Some(unsafe { RwLockWriteGuard::new(self) })
    }

    /// Consumes the lock and returns the protected value.
    ///
    /// No locking is needed: owning `self` proves no guard is alive.
    pub fn into_inner(self) -> T
    where
        T: Sized,
    {
        let cell = self.data;
        cell.into_inner()
    }

    /// Returns a mutable reference to the protected value.
    ///
    /// No locking is needed: `&mut self` proves no guard is alive.
    pub fn get_mut(&mut self) -> &mut T {
        let cell = &mut self.data;
        cell.get_mut()
    }
}
/// The raw lock: two 32-bit atomics that the futex helpers wait on and wake.
struct InnerLock {
/// Lock word. The low 30 bits hold the reader count (or `WRITE_LOCKED`),
/// bit 30 is the readers-waiting flag and bit 31 the writers-waiting flag.
/// Blocked readers wait on this word.
state: AtomicU32,
/// Counter incremented every time a writer is woken. Blocked writers wait
/// on this word instead of `state`.
writer_notify: AtomicU32,
}
// Encoding of `InnerLock::state`:
//   bits 0..=29  number of active readers, or all ones while write-locked
//   bit  30      at least one reader is (about to be) blocked
//   bit  31      at least one writer is (about to be) blocked

/// Amount added to the state for every read lock taken.
const READ_LOCKED: u32 = 1;
/// Selects the low 30 bits that describe who currently holds the lock.
const MASK: u32 = (1 << 30) - 1;
/// Value of the masked bits while a writer holds the lock.
const WRITE_LOCKED: u32 = MASK;
/// Largest reader count the masked bits may hold; one below `WRITE_LOCKED`
/// so a full set of readers can never be mistaken for a writer.
const MAX_READERS: u32 = MASK - 1;
/// Flag bit: readers are waiting for the lock.
const READERS_WAITING: u32 = 1 << 30;
/// Flag bit: writers are waiting for the lock.
const WRITERS_WAITING: u32 = 1 << 31;
/// Returns `true` when neither a reader nor a writer holds the lock.
///
/// The waiting flags are ignored; only the holder bits are inspected.
#[inline]
fn is_unlocked(state: u32) -> bool {
    let holders = state & MASK;
    holders == 0
}
/// Returns `true` when the holder bits equal the write-locked marker.
#[inline]
fn is_write_locked(state: u32) -> bool {
    let holders = state & MASK;
    holders == WRITE_LOCKED
}
#[inline]
fn has_readers_waiting(state: u32) -> bool {
state & READERS_WAITING != 0
}
#[inline]
fn has_writers_waiting(state: u32) -> bool {
state & WRITERS_WAITING != 0
}
#[inline]
fn is_read_lockable(state: u32) -> bool {
state & MASK < MAX_READERS && !has_readers_waiting(state) && !has_writers_waiting(state)
}
/// Returns `true` when the holder bits are exactly at the reader limit.
#[inline]
fn has_reached_max_readers(state: u32) -> bool {
    let holders = state & MASK;
    holders == MAX_READERS
}
/// Lock-word operations: uncontended fast paths, contended slow paths that
/// spin briefly and then block on a futex, and the wake-up logic run on unlock.
impl InnerLock {
/// Creates a lock with both words zeroed: unlocked, nobody waiting.
#[inline]
pub const fn new() -> Self {
Self {
state: AtomicU32::new(0),
writer_notify: AtomicU32::new(0),
}
}
/// Tries to take a read lock without blocking.
///
/// Atomically bumps the reader count if `is_read_lockable` holds for the
/// current state. Returns `true` on success.
#[inline]
pub fn try_read(&self) -> bool {
self.state
.fetch_update(Acquire, Relaxed, |s| {
is_read_lockable(s).then_some(s + READ_LOCKED)
})
.is_ok()
}
/// Takes a read lock, blocking if necessary.
///
/// Fast path: one load and one weak compare-exchange. Anything else —
/// state not read-lockable, or the exchange failing (possibly spuriously) —
/// is delegated to `read_contended`.
#[inline]
pub fn read(&self) {
let state = self.state.load(Relaxed);
if !is_read_lockable(state)
|| self
.state
.compare_exchange_weak(state, state + READ_LOCKED, Acquire, Relaxed)
.is_err()
{
self.read_contended();
}
}
/// Releases one read lock and wakes a waiter if this was the last reader.
///
/// # Safety
///
/// The reader count is decremented unconditionally, so this must only be
/// called by a holder of a read lock, once per lock taken.
#[inline]
pub unsafe fn read_unlock(&self) {
// `fetch_sub` returns the previous value; subtract again to get the
// state this call left behind.
let state = self.state.fetch_sub(READ_LOCKED, Release) - READ_LOCKED;
// Invariant checked in debug builds: when a reader unlocks, the
// readers-waiting bit is only ever set together with the writers-waiting bit.
debug_assert!(!has_readers_waiting(state) || has_writers_waiting(state));
// Only the last reader out wakes anybody, and only if a writer is waiting.
if is_unlocked(state) && has_writers_waiting(state) {
self.wake_writer_or_readers(state);
}
}
/// Slow path of `read`: spin briefly, then block on `state` until a read
/// lock can be taken.
///
/// # Panics
///
/// Panics with "too many active read locks on RwLock" when the reader
/// count has reached `MAX_READERS`.
#[cold]
fn read_contended(&self) {
let mut state = self.spin_read();
loop {
// Lock it if possible; on a failed exchange retry with the fresh state.
if is_read_lockable(state) {
match self
.state
.compare_exchange_weak(state, state + READ_LOCKED, Acquire, Relaxed)
{
Ok(_) => return, Err(s) => {
state = s;
continue;
}
}
}
assert!(
!has_reached_max_readers(state),
"too many active read locks on RwLock"
);
// Announce that a reader is about to sleep, so that the unlocking
// thread knows it has to wake readers.
if !has_readers_waiting(state) {
if let Err(s) =
self.state
.compare_exchange(state, state | READERS_WAITING, Relaxed, Relaxed)
{
state = s;
continue;
}
}
// NOTE(review): `futex_wait_fast` is defined in `crate::sync`; presumably
// it blocks only while `state` still equals the expected value passed
// here and may also return spuriously — confirm there. The loop
// re-examines the state after every return either way.
futex_wait_fast(&self.state, state | READERS_WAITING);
state = self.spin_read();
}
}
/// Tries to take the write lock without blocking.
///
/// Succeeds only when no reader or writer holds the lock; waiting flags
/// already present in the state are preserved. Returns `true` on success.
#[inline]
pub fn try_write(&self) -> bool {
self.state
.fetch_update(Acquire, Relaxed, |s| {
is_unlocked(s).then_some(s + WRITE_LOCKED)
})
.is_ok()
}
/// Takes the write lock, blocking if necessary.
///
/// Fast path: a single weak compare-exchange from the all-zero state.
/// Any failure (including a spurious one) goes to `write_contended`.
#[inline]
pub fn write(&self) {
if self
.state
.compare_exchange_weak(0, WRITE_LOCKED, Acquire, Relaxed)
.is_err()
{
self.write_contended();
}
}
/// Releases the write lock and wakes waiters if any flag is set.
///
/// # Safety
///
/// `WRITE_LOCKED` is subtracted unconditionally, so this must only be
/// called by the current holder of the write lock.
#[inline]
pub unsafe fn write_unlock(&self) {
// `fetch_sub` returns the previous value; subtract again to get the
// state this call left behind.
let state = self.state.fetch_sub(WRITE_LOCKED, Release) - WRITE_LOCKED;
debug_assert!(is_unlocked(state));
if has_writers_waiting(state) || has_readers_waiting(state) {
self.wake_writer_or_readers(state);
}
}
/// Slow path of `write`: spin briefly, then block on `writer_notify`
/// until the write lock can be taken.
#[cold]
fn write_contended(&self) {
let mut state = self.spin_write();
// Becomes `WRITERS_WAITING` once this thread has been through the
// waiting protocol below. From then on the flag is re-applied when the
// lock is finally taken, because other writers may still be blocked and
// this thread cannot tell.
let mut other_writers_waiting = 0;
loop {
// Lock it if nobody holds it; on a failed exchange retry with the
// fresh state.
if is_unlocked(state) {
match self.state.compare_exchange_weak(
state,
state | WRITE_LOCKED | other_writers_waiting,
Acquire,
Relaxed,
) {
Ok(_) => return, Err(s) => {
state = s;
continue;
}
}
}
// Announce that a writer is about to sleep.
if !has_writers_waiting(state) {
if let Err(s) =
self.state
.compare_exchange(state, state | WRITERS_WAITING, Relaxed, Relaxed)
{
state = s;
continue;
}
}
other_writers_waiting = WRITERS_WAITING;
// Read the notification counter *before* re-checking `state`: a wake-up
// issued after this load changes the counter, so the wait below
// presumably returns immediately instead of missing the notification
// (depends on `futex_wait_fast` comparing against `seq`).
let seq = self.writer_notify.load(Acquire);
// Do not sleep if the lock became free, or if the writers-waiting flag
// was cleared in the meantime (it must be set again before sleeping).
state = self.state.load(Relaxed);
if is_unlocked(state) || !has_writers_waiting(state) {
continue;
}
futex_wait_fast(&self.writer_notify, seq);
state = self.spin_write();
}
}
/// Wakes one writer or, failing that, all readers after an unlock.
///
/// `state` is the value the unlocking thread left behind; it must describe
/// an unlocked lock (enforced by the assertion). Every transition is a
/// compare-exchange: if the state no longer matches, another thread has
/// touched the lock and nothing is woken here.
#[cold]
fn wake_writer_or_readers(&self, mut state: u32) {
assert!(is_unlocked(state));
// Only writers waiting: clear the flag and wake one writer.
if state == WRITERS_WAITING {
match self.state.compare_exchange(state, 0, Relaxed, Relaxed) {
Ok(_) => {
self.wake_writer();
return;
}
Err(s) => {
// The state changed under us (for example readers started waiting
// too); fall through and re-evaluate with the fresh value.
state = s;
}
}
}
// Both kinds waiting: keep the readers flagged and prefer waking a writer.
if state == READERS_WAITING + WRITERS_WAITING {
if self
.state
.compare_exchange(state, READERS_WAITING, Relaxed, Relaxed)
.is_err()
{
// State changed again; give up without waking anyone.
return;
}
if self.wake_writer() {
return;
}
// No writer was actually woken, so continue below and wake the
// readers instead of leaving everybody asleep.
state = READERS_WAITING;
}
// Only readers waiting: clear the flag and wake all of them.
if state == READERS_WAITING
&& self
.state
.compare_exchange(state, 0, Relaxed, Relaxed)
.is_ok()
{
let _ = futex_wake(&self.state, i32::MAX);
}
}
/// Bumps the writer notification counter and asks the futex to wake at
/// most one thread blocked on it.
///
/// Returns `true` when the wake call reports a non-zero result
/// (presumably the number of threads woken — confirm against
/// `rusl::futex::futex_wake`).
///
/// # Panics
///
/// Panics if `futex_wake` reports failure, because its result is unwrapped.
fn wake_writer(&self) -> bool {
self.writer_notify.fetch_add(1, Release);
futex_wake(&self.writer_notify, 1).unwrap() != 0
}
/// Re-reads `state` until `f` accepts it or 100 spin iterations have been
/// used up, and returns the last value read.
#[inline]
fn spin_until(&self, f: impl Fn(u32) -> bool) -> u32 {
let mut spin = 100; loop {
let state = self.state.load(Relaxed);
if f(state) || spin == 0 {
return state;
}
core::hint::spin_loop();
spin -= 1;
}
}
/// Spins until the lock is free or some writer is already flagged as waiting.
#[inline]
fn spin_write(&self) -> u32 {
self.spin_until(|state| is_unlocked(state) || has_writers_waiting(state))
}
/// Spins until the lock is no longer write-locked or somebody is already
/// flagged as waiting.
#[inline]
fn spin_read(&self) -> u32 {
self.spin_until(|state| {
!is_write_locked(state) || has_readers_waiting(state) || has_writers_waiting(state)
})
}
}
/// Formats the guarded value with `T`'s own `Debug` implementation.
///
/// The bound includes `?Sized` so that guards over unsized values (for
/// example `str` or slices) are debuggable too, matching the guard type
/// itself and its `Display` impl.
impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLockReadGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `**self` goes through `Deref` to the protected value.
        (**self).fmt(f)
    }
}
/// Formats the guarded value with `T`'s own `Display` implementation.
impl<T: ?Sized + fmt::Display> fmt::Display for RwLockReadGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Reborrow the protected value through `Deref` and delegate.
        let value: &T = &**self;
        fmt::Display::fmt(value, f)
    }
}
/// Formats the guarded value with `T`'s own `Debug` implementation.
///
/// The bound includes `?Sized` so that guards over unsized values (for
/// example `str` or slices) are debuggable too, matching the guard type
/// itself and its `Display` impl.
impl<T: ?Sized + fmt::Debug> fmt::Debug for RwLockWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `**self` goes through `Deref` to the protected value.
        (**self).fmt(f)
    }
}
/// Formats the guarded value with `T`'s own `Display` implementation.
impl<T: ?Sized + fmt::Display> fmt::Display for RwLockWriteGuard<'_, T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Reborrow the protected value through `Deref` and delegate.
        let value: &T = &**self;
        fmt::Display::fmt(value, f)
    }
}
/// Shared access to the protected value for as long as the guard lives.
impl<T: ?Sized> Deref for RwLockReadGuard<'_, T> {
    type Target = T;

    fn deref(&self) -> &T {
        let target = self.data.as_ptr();
        // SAFETY: the pointer was taken from the lock's data cell when the
        // guard was created, and the read lock is only released when the
        // guard is dropped.
        unsafe { &*target }
    }
}
impl<T: ?Sized> Deref for RwLockWriteGuard<'_, T> {
type Target = T;
fn deref(&self) -> &T {
unsafe { &*self.lock.data.get() }
}
}
/// Exclusive access to the protected value for as long as the guard lives.
impl<T: ?Sized> DerefMut for RwLockWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        let target: *mut T = self.lock.data.get();
        // SAFETY: the write lock is only released when the guard is dropped,
        // and `&mut self` ensures this is the only borrow handed out by the guard.
        unsafe { &mut *target }
    }
}
/// Releases the read lock held by this guard.
impl<T: ?Sized> Drop for RwLockReadGuard<'_, T> {
    fn drop(&mut self) {
        let lock_word = self.inner_lock;
        // SAFETY: a guard only exists while its read lock is held, and it is
        // dropped exactly once.
        unsafe {
            lock_word.read_unlock();
        }
    }
}
/// Releases the write lock held by this guard.
impl<T: ?Sized> Drop for RwLockWriteGuard<'_, T> {
    fn drop(&mut self) {
        let lock_word = &self.lock.inner;
        // SAFETY: a guard only exists while the write lock is held, and it is
        // dropped exactly once.
        unsafe {
            lock_word.write_unlock();
        }
    }
}
#[cfg(test)]
mod tests {
    use crate::sync::RwLock;
    use core::time::Duration;
    use std::sync::Arc;
    use std::thread;

    /// A reader spawned while the write guard is held must observe the value
    /// written before that guard is dropped.
    #[test]
    fn can_lock() {
        let shared = Arc::new(super::RwLock::new(0));
        let for_reader = Arc::clone(&shared);
        let mut writer = shared.write();
        let reader = thread::spawn(move || *for_reader.read());
        *writer = 15;
        drop(writer);
        let seen = reader.join().unwrap();
        assert_eq!(15, seen);
    }

    /// Many threads each increment under the write lock and then read back;
    /// every increment must be counted and every reader must see progress.
    #[test]
    fn can_mutex_contended() {
        const NUM_THREADS: usize = 32;
        let count = Arc::new(RwLock::new(0));
        let workers: std::vec::Vec<_> = (0..NUM_THREADS)
            .map(|_| {
                let counter = Arc::clone(&count);
                thread::spawn(move || {
                    let mut writer = counter.write();
                    let before = *writer;
                    thread::sleep(Duration::from_millis(1));
                    *writer += 1;
                    drop(writer);
                    thread::sleep(Duration::from_millis(1));
                    let reader = counter.read();
                    thread::sleep(Duration::from_millis(1));
                    assert!(*reader > before);
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
        assert_eq!(NUM_THREADS, *count.read());
    }

    /// Exercises the non-blocking entry points from a single thread.
    #[test]
    fn can_try_rw_single_thread_contended() {
        let lock = Arc::new(super::RwLock::new(0));
        let alias = Arc::clone(&lock);
        // Unlocked: a read lock can be taken (the temporary guard is released
        // at the end of the statement).
        assert_eq!(0, *alias.try_read().unwrap());
        // Read-locked: more readers are fine, writers are refused.
        let reader = lock.read();
        assert_eq!(0, *alias.try_read().unwrap());
        assert!(alias.try_write().is_none());
        drop(reader);
        // Unlocked again: the write lock can be taken.
        assert_eq!(0, *alias.try_write().unwrap());
        // Write-locked: readers are refused.
        let _writer = lock.write();
        assert!(alias.try_read().is_none());
    }
}