use futures_channel::oneshot;
use futures_task::{Context, Poll};
use std::{
cell::UnsafeCell,
clone::Clone,
collections::VecDeque,
future::Future,
ops::{Deref, DerefMut},
pin::Pin,
sync,
};
use super::{FutState, TryLockError};
#[cfg(feature = "tokio")] use tokio::task;
/// Guard representing shared (read) access to the data owned by an
/// [`RwLock`].
///
/// It dereferences to `&T`; dropping it calls `unlock_reader` on the lock
/// handle it carries.
#[derive(Debug)]
pub struct RwLockReadGuard<T: ?Sized> {
    /// Handle to the lock that this guard releases on drop.
    rwlock: RwLock<T>,
}
impl<T: ?Sized> Deref for RwLockReadGuard<T> {
    type Target = T;

    /// Borrows the protected value immutably.
    fn deref(&self) -> &T {
        let data_ptr = self.rwlock.inner.data.get();
        // SAFETY: the pointer comes from the `UnsafeCell` inside the shared
        // `Inner`, which the `Arc` held by `self.rwlock` keeps alive.  The
        // lock's bookkeeping is expected to rule out a concurrent writer while
        // a read guard exists.
        unsafe { &*data_ptr }
    }
}
impl<T: ?Sized> Drop for RwLockReadGuard<T> {
fn drop(&mut self) {
self.rwlock.unlock_reader();
}
}
/// Guard representing exclusive (write) access to the data owned by an
/// [`RwLock`].
///
/// It dereferences to `&T` and `&mut T`; dropping it calls `unlock_writer` on
/// the lock handle it carries.
#[derive(Debug)]
pub struct RwLockWriteGuard<T: ?Sized> {
    /// Handle to the lock that this guard releases on drop.
    rwlock: RwLock<T>,
}
impl<T: ?Sized> Deref for RwLockWriteGuard<T> {
    type Target = T;

    /// Borrows the protected value immutably.
    fn deref(&self) -> &T {
        let data_ptr = self.rwlock.inner.data.get();
        // SAFETY: the pointer comes from the `UnsafeCell` inside the shared
        // `Inner`, which the `Arc` held by `self.rwlock` keeps alive.  The
        // lock's bookkeeping is expected to make this guard the only accessor
        // while it exists.
        unsafe { &*data_ptr }
    }
}
impl<T: ?Sized> DerefMut for RwLockWriteGuard<T> {
    /// Borrows the protected value mutably.
    fn deref_mut(&mut self) -> &mut T {
        let data_ptr = self.rwlock.inner.data.get();
        // SAFETY: the pointer comes from the `UnsafeCell` inside the shared
        // `Inner`, which the `Arc` held by `self.rwlock` keeps alive.  The
        // lock's bookkeeping is expected to make this guard the only accessor
        // while it exists, and `&mut self` prevents aliasing through the guard.
        unsafe { &mut *data_ptr }
    }
}
impl<T: ?Sized> Drop for RwLockWriteGuard<T> {
fn drop(&mut self) {
self.rwlock.unlock_writer();
}
}
/// Future returned by `RwLock::read`; resolves to an [`RwLockReadGuard`].
pub struct RwLockReadFut<T: ?Sized> {
    /// Progress of the acquisition: not yet polled, waiting on a wakeup
    /// channel, or already resolved.
    state: FutState,
    /// Handle to the lock being acquired.
    rwlock: RwLock<T>,
}
impl<T: ?Sized> RwLockReadFut<T> {
fn new(state: FutState, rwlock: RwLock<T>) -> Self {
RwLockReadFut{state, rwlock}
}
}
impl<T: ?Sized> Drop for RwLockReadFut<T> {
    /// Cleans up a read future that is discarded before yielding its guard.
    fn drop(&mut self) {
        match self.state {
            FutState::Pending(ref mut wakeup) => {
                // Refuse any further notification, then check whether one was
                // already delivered.
                wakeup.close();
                if let Ok(Some(())) = wakeup.try_recv() {
                    // The lock was handed to this future, but it was dropped
                    // before being polled again: release the shared lock.
                    self.rwlock.unlock_reader()
                }
                // `Ok(None)` and `Err(Canceled)`: ownership was never
                // received, so there is nothing to release.
            },
            // `New`: the lock has not been touched yet.
            // `Acquired`: the guard handed out by `poll` releases the lock.
            FutState::New | FutState::Acquired => {},
        }
    }
}
impl<T: ?Sized> Future for RwLockReadFut<T> {
    type Output = RwLockReadGuard<T>;

    /// Tries to take the lock in shared mode.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`, or if the
    /// internal mutex is poisoned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let (outcome, next_state) = match self.state {
            FutState::New => {
                let mut shared = self.rwlock.inner.mutex.lock()
                    .expect("sync::Mutex::lock");
                if !shared.exclusive {
                    // No writer holds the lock: join the readers right away.
                    shared.num_readers += 1;
                    let handle = self.rwlock.clone();
                    (Poll::Ready(RwLockReadGuard { rwlock: handle }),
                     FutState::Acquired)
                } else {
                    // A writer holds the lock: queue a wakeup channel.
                    let (sender, mut receiver) = oneshot::channel::<()>();
                    shared.read_waiters.push_back(sender);
                    // The receiver cannot be ready yet, but it must be polled
                    // once so that this task is registered for notification.
                    assert!(Pin::new(&mut receiver).poll(cx).is_pending());
                    (Poll::Pending, FutState::Pending(receiver))
                }
            },
            FutState::Pending(ref mut receiver) => {
                if Pin::new(receiver).poll(cx).is_pending() {
                    return Poll::Pending;
                }
                // The channel resolved: hand out the guard.
                let handle = self.rwlock.clone();
                (Poll::Ready(RwLockReadGuard { rwlock: handle }),
                 FutState::Acquired)
            },
            FutState::Acquired => panic!("Double-poll of ready Future"),
        };
        self.state = next_state;
        outcome
    }
}
/// Future returned by `RwLock::write`; resolves to an [`RwLockWriteGuard`].
pub struct RwLockWriteFut<T: ?Sized> {
    /// Progress of the acquisition: not yet polled, waiting on a wakeup
    /// channel, or already resolved.
    state: FutState,
    /// Handle to the lock being acquired.
    rwlock: RwLock<T>,
}
impl<T: ?Sized> RwLockWriteFut<T> {
fn new(state: FutState, rwlock: RwLock<T>) -> Self {
RwLockWriteFut{state, rwlock}
}
}
impl<T: ?Sized> Drop for RwLockWriteFut<T> {
    /// Cleans up a write future that is discarded before yielding its guard.
    fn drop(&mut self) {
        match self.state {
            FutState::Pending(ref mut wakeup) => {
                // Refuse any further notification, then check whether one was
                // already delivered.
                wakeup.close();
                if let Ok(Some(())) = wakeup.try_recv() {
                    // The lock was handed to this future, but it was dropped
                    // before being polled again: release the exclusive lock
                    // so the next waiter can proceed.
                    self.rwlock.unlock_writer()
                }
                // `Ok(None)` and `Err(Canceled)`: ownership was never
                // received, so there is nothing to release.
            },
            // `New`: the lock has not been touched yet.
            // `Acquired`: the guard handed out by `poll` releases the lock.
            FutState::New | FutState::Acquired => {},
        }
    }
}
impl<T: ?Sized> Future for RwLockWriteFut<T> {
    type Output = RwLockWriteGuard<T>;

    /// Tries to take the lock in exclusive mode.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has returned `Poll::Ready`, or if the
    /// internal mutex is poisoned.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let (outcome, next_state) = match self.state {
            FutState::New => {
                let mut shared = self.rwlock.inner.mutex.lock()
                    .expect("sync::Mutex::lock");
                let busy = shared.exclusive || shared.num_readers > 0;
                if !busy {
                    // Nobody holds the lock: take it exclusively right away.
                    shared.exclusive = true;
                    let handle = self.rwlock.clone();
                    (Poll::Ready(RwLockWriteGuard { rwlock: handle }),
                     FutState::Acquired)
                } else {
                    // Held by a writer or by readers: queue a wakeup channel.
                    let (sender, mut receiver) = oneshot::channel::<()>();
                    shared.write_waiters.push_back(sender);
                    // The receiver cannot be ready yet, but it must be polled
                    // once so that this task is registered for notification.
                    assert!(Pin::new(&mut receiver).poll(cx).is_pending());
                    (Poll::Pending, FutState::Pending(receiver))
                }
            },
            FutState::Pending(ref mut receiver) => {
                if Pin::new(receiver).poll(cx).is_pending() {
                    return Poll::Pending;
                }
                // The channel resolved: hand out the guard.
                let handle = self.rwlock.clone();
                (Poll::Ready(RwLockWriteGuard { rwlock: handle }),
                 FutState::Acquired)
            },
            FutState::Acquired => panic!("Double-poll of ready Future"),
        };
        self.state = next_state;
        outcome
    }
}
/// Bookkeeping for an [`RwLock`], kept behind the lock's internal mutex.
#[derive(Debug, Default)]
struct RwLockData {
    /// `true` while the lock is held (or has been handed over) in exclusive
    /// mode.
    exclusive: bool,

    /// Number of shared holders currently accounted for.
    num_readers: u32,

    /// Wakeup channels of read futures that found the lock exclusively held.
    read_waiters: VecDeque<oneshot::Sender<()>>,

    /// Wakeup channels of write futures that found the lock busy.
    write_waiters: VecDeque<oneshot::Sender<()>>,
}
/// State shared by every handle to the same [`RwLock`].
#[derive(Debug, Default)]
struct Inner<T: ?Sized> {
    /// Protects the lock's bookkeeping (not the user data itself).
    mutex: sync::Mutex<RwLockData>,

    /// The user data; reached through the guards, `get_mut` and `try_unwrap`.
    data: UnsafeCell<T>,
}
/// A futures-aware reader/writer lock.
///
/// The value is a reference-counted handle: cloning it yields another handle
/// to the same underlying lock and data.
#[derive(Debug, Default)]
pub struct RwLock<T: ?Sized> {
    /// Shared lock state and data.
    inner: sync::Arc<Inner<T>>,
}
impl<T: ?Sized> Clone for RwLock<T> {
fn clone(&self) -> RwLock<T> {
RwLock { inner: self.inner.clone()}
}
}
impl<T> RwLock<T> {
pub fn new(t: T) -> RwLock<T> {
let lock_data = RwLockData {
exclusive: false,
num_readers: 0,
read_waiters: VecDeque::new(),
write_waiters: VecDeque::new(),
}; let inner = Inner {
mutex: sync::Mutex::new(lock_data),
data: UnsafeCell::new(t)
}; RwLock { inner: sync::Arc::new(inner)}
}
pub fn try_unwrap(self) -> Result<T, RwLock<T>> {
match sync::Arc::try_unwrap(self.inner) {
Ok(inner) => Ok({
#[allow(unused_unsafe)]
unsafe { inner.data.into_inner() }
}),
Err(arc) => Err(RwLock {inner: arc})
}
}
}
impl<T: ?Sized> RwLock<T> {
    /// Returns a mutable reference to the underlying data if this is the only
    /// handle to the lock, and `None` otherwise.
    ///
    /// No locking takes place: holding `&mut self` on the sole handle already
    /// proves that nobody else can reach the data.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn get_mut(&mut self) -> Option<&mut T> {
        if let Some(inner) = sync::Arc::get_mut(&mut self.inner) {
            let lock_data = inner.mutex.get_mut().unwrap();
            // SAFETY: `Arc::get_mut` succeeded, so `inner` is a unique
            // reference to the shared state and nothing else can alias the
            // cell's contents.  The pointer from `UnsafeCell::get` is never
            // null, so the `unwrap` cannot fail.
            let data = unsafe { inner.data.get().as_mut() }.unwrap();
            // With a single handle left, no guard (each of which owns a
            // handle) can be outstanding, so the lock must be free.
            debug_assert!(!lock_data.exclusive);
            debug_assert_eq!(lock_data.num_readers, 0);
            Some(data)
        } else {
            None
        }
    }

    /// Returns a future that resolves to a shared (read) guard.
    ///
    /// The lock is not touched until the future is first polled.
    pub fn read(&self) -> RwLockReadFut<T> {
        RwLockReadFut::new(FutState::New, self.clone())
    }

    /// Returns a future that resolves to an exclusive (write) guard.
    ///
    /// The lock is not touched until the future is first polled.
    pub fn write(&self) -> RwLockWriteFut<T> {
        RwLockWriteFut::new(FutState::New, self.clone())
    }

    /// Attempts to take the lock in shared mode without waiting.
    ///
    /// # Errors
    ///
    /// Returns `TryLockError` if the lock is currently held exclusively.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn try_read(&self) -> Result<RwLockReadGuard<T>, TryLockError> {
        let mut lock_data = self.inner.mutex.lock().expect("sync::Mutex::lock");
        if lock_data.exclusive {
            Err(TryLockError)
        } else {
            lock_data.num_readers += 1;
            Ok(RwLockReadGuard { rwlock: self.clone() })
        }
    }

    /// Attempts to take the lock in exclusive mode without waiting.
    ///
    /// # Errors
    ///
    /// Returns `TryLockError` if the lock is currently held, in either mode.
    ///
    /// # Panics
    ///
    /// Panics if the internal mutex is poisoned.
    pub fn try_write(&self) -> Result<RwLockWriteGuard<T>, TryLockError> {
        let mut lock_data = self.inner.mutex.lock().expect("sync::Mutex::lock");
        if lock_data.exclusive || lock_data.num_readers > 0 {
            Err(TryLockError)
        } else {
            lock_data.exclusive = true;
            Ok(RwLockWriteGuard { rwlock: self.clone() })
        }
    }

    /// Releases one shared hold on the lock.
    ///
    /// When the last reader leaves, ownership is handed to the first queued
    /// writer whose future is still alive.
    fn unlock_reader(&self) {
        let mut lock_data = self.inner.mutex.lock().expect("sync::Mutex::lock");
        assert!(lock_data.num_readers > 0);
        assert!(!lock_data.exclusive);
        // Readers only queue while a writer holds the lock, and the queue is
        // drained when the lock leaves exclusive mode.
        assert_eq!(lock_data.read_waiters.len(), 0);
        lock_data.num_readers -= 1;
        if lock_data.num_readers == 0 {
            while let Some(tx) = lock_data.write_waiters.pop_front() {
                // A failed send means that writer's future was dropped; skip
                // it and try the next one.
                if tx.send(()).is_ok() {
                    lock_data.exclusive = true;
                    return;
                }
            }
        }
    }

    /// Releases the exclusive hold on the lock.
    ///
    /// Queued writers are served first; if none is alive, every queued reader
    /// is woken at once.
    fn unlock_writer(&self) {
        let mut lock_data = self.inner.mutex.lock().expect("sync::Mutex::lock");
        assert!(lock_data.num_readers == 0);
        assert!(lock_data.exclusive);

        // First try to pass ownership directly to a waiting writer.  The
        // `exclusive` flag stays set across the hand-off.
        while let Some(tx) = lock_data.write_waiters.pop_front() {
            if tx.send(()).is_ok() {
                return;
            }
        }

        // No live writer: switch to shared mode and wake the readers.
        lock_data.exclusive = false;
        // Count only the readers that actually received the wakeup.  A failed
        // send means the reader's future was dropped while pending, so nobody
        // would ever call `unlock_reader` for it; counting it would leave
        // `num_readers` stuck above zero and lock out every later writer.
        let mut woken: u32 = 0;
        for tx in lock_data.read_waiters.drain(..) {
            if tx.send(()).is_ok() {
                woken += 1;
            }
        }
        lock_data.num_readers += woken;
    }
}
impl<T: 'static + ?Sized> RwLock<T> {
    /// Acquires the lock in shared mode, then runs `f` with the guard in a
    /// task spawned via `tokio::spawn`.
    ///
    /// The returned future resolves to the output of the future produced by
    /// `f`.
    ///
    /// # Panics
    ///
    /// The returned future panics if joining the spawned task fails.
    #[cfg(any(feature = "tokio", all(docsrs, rustdoc)))]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    pub fn with_read<B, F, R>(&self, f: F)
        -> impl Future<Output = R>
        where F: FnOnce(RwLockReadGuard<T>) -> B + Send + 'static,
              B: Future<Output = R> + Send + 'static,
              R: Send + 'static,
              T: Send
    {
        let acquire = self.read();
        let handle = tokio::spawn(async move {
            let guard = acquire.await;
            f(guard).await
        });
        async move { handle.await.unwrap() }
    }

    /// Like `with_read`, but for closures and futures that are not `Send`.
    ///
    /// The work is spawned onto a freshly created `LocalSet`; the returned
    /// future first drives that set to completion and then yields the task's
    /// output.
    ///
    /// # Panics
    ///
    /// The returned future panics if joining the spawned task fails.
    #[cfg(any(feature = "tokio", all(docsrs, rustdoc)))]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    pub fn with_read_local<B, F, R>(&self, f: F)
        -> impl Future<Output = R>
        where F: FnOnce(RwLockReadGuard<T>) -> B + 'static + Unpin,
              B: Future<Output = R> + 'static,
              R: 'static
    {
        let local_set = task::LocalSet::new();
        let acquire = self.read();
        let handle = local_set.spawn_local(async move {
            let guard = acquire.await;
            f(guard).await
        });
        async move {
            local_set.await;
            handle.await.unwrap()
        }
    }

    /// Acquires the lock in exclusive mode, then runs `f` with the guard in a
    /// task spawned via `tokio::spawn`.
    ///
    /// The returned future resolves to the output of the future produced by
    /// `f`.
    ///
    /// # Panics
    ///
    /// The returned future panics if joining the spawned task fails.
    #[cfg(any(feature = "tokio", all(docsrs, rustdoc)))]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    pub fn with_write<B, F, R>(&self, f: F)
        -> impl Future<Output = R>
        where F: FnOnce(RwLockWriteGuard<T>) -> B + Send + 'static,
              B: Future<Output = R> + Send + 'static,
              R: Send + 'static,
              T: Send
    {
        let acquire = self.write();
        let handle = tokio::spawn(async move {
            let guard = acquire.await;
            f(guard).await
        });
        async move { handle.await.unwrap() }
    }

    /// Like `with_write`, but for closures and futures that are not `Send`.
    ///
    /// The work is spawned onto a freshly created `LocalSet`; the returned
    /// future first drives that set to completion and then yields the task's
    /// output.
    ///
    /// # Panics
    ///
    /// The returned future panics if joining the spawned task fails.
    #[cfg(any(feature = "tokio", all(docsrs, rustdoc)))]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio")))]
    pub fn with_write_local<B, F, R>(&self, f: F)
        -> impl Future<Output = R>
        where F: FnOnce(RwLockWriteGuard<T>) -> B + 'static + Unpin,
              B: Future<Output = R> + 'static,
              R: 'static
    {
        let local_set = task::LocalSet::new();
        let acquire = self.write();
        let handle = local_set.spawn_local(async move {
            let guard = acquire.await;
            f(guard).await
        });
        async move {
            local_set.await;
            handle.await.unwrap()
        }
    }
}
// The shared state holds an `UnsafeCell<T>`, so these impls are written by
// hand.  The data is only reached through the guards, `get_mut` and
// `try_unwrap` in this file.
//
// The lint is silenced because the `Send` impl below is a deliberate manual
// assertion.
#[allow(clippy::non_send_fields_in_send_ty)]
unsafe impl<T> Send for RwLock<T> where T: ?Sized + Send {}
// Sharing a handle lets several threads hold `&T` at once through read guards
// (hence `Sync`) and lets any thread take `&mut T` or move the value out
// (hence `Send`).
unsafe impl<T> Sync for RwLock<T> where T: ?Sized + Send + Sync {}
#[cfg(test)]
mod t {
    use super::*;

    /// Formatting an `RwLock` with `{:?}` must work without panicking.
    #[test]
    fn debug() {
        let lock = RwLock::<u32>::new(0);
        let _ = format!("{:?}", &lock);
    }

    /// A default-constructed lock wraps the default value of its payload.
    #[test]
    fn test_default() {
        let lock = RwLock::default();
        let expected = u32::default();
        let value: u32 = lock.try_unwrap().unwrap();
        assert_eq!(expected, value);
    }
}