use std::{
cell::UnsafeCell,
pin::Pin,
sync::{
Condvar, Mutex,
atomic::{AtomicBool, AtomicUsize, Ordering as AtomOrd},
},
task::{Context, Poll, Waker},
};
/// A reader-writer lock that can be acquired both by blocking threads
/// ([`FusedRw::read`] / [`FusedRw::write`]) and by async tasks
/// ([`FusedRw::read_async`] / [`FusedRw::write_async`]).
///
/// If a guard is dropped while its thread is panicking, the lock is marked
/// poisoned and subsequent acquisition attempts panic.
pub struct FusedRw<T> {
    /// The protected value; handed out through the guards' `Deref`/`DerefMut`.
    data: UnsafeCell<T>,
    /// Lock word: `0` = held by a writer, `1` = free,
    /// `n > 1` = held by `n - 1` readers.
    state: AtomicUsize,
    /// Set when a guard is dropped while its thread is panicking.
    poisoned: AtomicBool,
    /// Mutex paired with `wait_cvar` for threads blocked in `read`/`write`.
    wait_mutex: Mutex<()>,
    /// Signalled by `notify` whenever the lock may have become available.
    wait_cvar: Condvar,
    /// Wakers of pending async acquisitions; drained and woken by `notify`.
    wait_wakers: Mutex<Vec<Waker>>,
}
// SAFETY: moving a `FusedRw<T>` to another thread moves the `T` it owns,
// which is sound when `T: Send`.
unsafe impl<T: Send> Send for FusedRw<T> {}
// SAFETY: sharing `&FusedRw<T>` across threads lets several threads hold
// `&T` at once (needs `T: Sync`) and lets a writer on any thread obtain
// `&mut T`, through which it can move values in or out (needs `T: Send`).
// These are the same bounds `std::sync::RwLock<T>` requires.
unsafe impl<T: Send + Sync> Sync for FusedRw<T> {}
impl<T> FusedRw<T> {
    /// Creates a new, unlocked and unpoisoned lock protecting `value`.
    pub fn new(value: T) -> Self {
        Self {
            data: UnsafeCell::new(value),
            // 1 = free (0 = write-locked, n > 1 = n - 1 readers).
            state: AtomicUsize::new(1),
            poisoned: AtomicBool::new(false),
            wait_mutex: Mutex::new(()),
            wait_cvar: Condvar::new(),
            wait_wakers: Mutex::new(Vec::new()),
        }
    }

    /// Acquires shared read access, blocking the current thread while a
    /// writer holds the lock.
    ///
    /// Readers are admitted whenever no writer holds the lock, so a steady
    /// stream of readers can keep writers waiting.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, whether on entry or after this call
    /// blocked, and if the reader count would overflow `usize`.
    pub fn read(&self) -> FusedReadGuard<'_, T> {
        self.panic_if_poisoned();
        let guard = self.acquire_blocking(|| self.try_acquire_read());
        self.unless_poisoned(guard)
    }

    /// Acquires exclusive write access, blocking the current thread until
    /// there are no readers and no other writer.
    ///
    /// # Panics
    ///
    /// Panics if the lock is poisoned, whether on entry or after this call
    /// blocked.
    pub fn write(&self) -> FusedWriteGuard<'_, T> {
        self.panic_if_poisoned();
        let guard = self.acquire_blocking(|| self.try_acquire_write());
        self.unless_poisoned(guard)
    }

    /// Returns a future that resolves to a read guard without blocking the
    /// polling thread.
    pub fn read_async(&self) -> ReadFuture<'_, T> {
        ReadFuture { rw: self }
    }

    /// Returns a future that resolves to a write guard without blocking the
    /// polling thread.
    pub fn write_async(&self) -> WriteFuture<'_, T> {
        WriteFuture { rw: self }
    }

    /// Panics with the standard poison message if the lock is poisoned.
    fn panic_if_poisoned(&self) {
        if self.poisoned.load(AtomOrd::Acquire) {
            panic!("FusedRw is poisoned");
        }
    }

    /// Returns `guard`, or releases it and panics if the lock became
    /// poisoned (e.g. a writer panicked while this caller was blocked).
    fn unless_poisoned<G>(&self, guard: G) -> G {
        if self.poisoned.load(AtomOrd::Acquire) {
            drop(guard);
            panic!("FusedRw is poisoned");
        }
        guard
    }

    /// Registers one more reader unless a writer holds the lock.
    ///
    /// The "no writer" test and the increment must be one atomic step. With
    /// a separate load and `fetch_add`, a writer can exchange `1 -> 0` in
    /// between. The increment then turns `0` back into `1`, so this reader
    /// and that writer both hold the lock, and a second writer can take it
    /// as well.
    fn try_acquire_read(&self) -> Option<FusedReadGuard<'_, T>> {
        self.state
            .fetch_update(AtomOrd::AcqRel, AtomOrd::Acquire, |state| {
                (state > 0).then(|| state.checked_add(1).expect("FusedRw reader count overflow"))
            })
            .ok()
            .map(|_| FusedReadGuard { rw: self })
    }

    /// Takes the lock exclusively if it is currently free (`state == 1`).
    fn try_acquire_write(&self) -> Option<FusedWriteGuard<'_, T>> {
        self.state
            .compare_exchange(1, 0, AtomOrd::AcqRel, AtomOrd::Acquire)
            .ok()
            .map(|_| FusedWriteGuard { rw: self })
    }

    /// Runs `attempt` until it succeeds, parking on `wait_cvar` between tries.
    fn acquire_blocking<G>(&self, mut attempt: impl FnMut() -> Option<G>) -> G {
        if let Some(guard) = attempt() {
            return guard;
        }
        let mut parked = self.lock_wait();
        loop {
            // Retrying while holding `wait_mutex` pairs with `notify`, which
            // takes the same mutex before signalling. A release therefore
            // cannot slip in between this check and `wait`.
            if let Some(guard) = attempt() {
                return guard;
            }
            parked = self
                .wait_cvar
                .wait(parked)
                .unwrap_or_else(std::sync::PoisonError::into_inner);
        }
    }

    /// Locks `wait_mutex`, ignoring poison: it guards `()`, so a panicking
    /// holder cannot have left any data inconsistent.
    fn lock_wait(&self) -> std::sync::MutexGuard<'_, ()> {
        self.wait_mutex
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }

    /// Wakes every waiter, async and blocking, after `state` has changed.
    fn notify(&self) {
        // Move the wakers out first so `wake()` runs without `wait_wakers`
        // held. Wake code may re-enter this lock (e.g. by dropping a task that
        // holds a guard), which would deadlock on the non-reentrant `Mutex`.
        let wakers = std::mem::take(
            &mut *self
                .wait_wakers
                .lock()
                .unwrap_or_else(std::sync::PoisonError::into_inner),
        );
        for waker in wakers {
            waker.wake();
        }
        // Briefly taking `wait_mutex` makes sure any thread that saw the old
        // `state` under it has already entered `Condvar::wait`. Otherwise
        // that thread could miss this notification and sleep forever.
        drop(self.lock_wait());
        self.wait_cvar.notify_all();
    }
}
/// RAII shared-access guard produced by [`FusedRw::read`] and
/// [`FusedRw::read_async`]; dereferences to the protected value.
///
/// Dropping it releases one reader slot. Dropping it while the thread is
/// panicking poisons the lock.
#[must_use = "if unused the FusedRw will immediately unlock"]
pub struct FusedReadGuard<'a, T> {
    rw: &'a FusedRw<T>,
}
impl<T> std::ops::Deref for FusedReadGuard<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        // SAFETY: a read guard is only created after registering a reader in
        // `state` while no writer holds the lock. A writer can only obtain
        // `&mut T` by exchanging `state` from exactly 1 (no readers) to 0.
        // So while this guard lives, only shared references to the data exist.
        unsafe { &*self.rw.data.get() }
    }
}
impl<T> Drop for FusedReadGuard<'_, T> {
    fn drop(&mut self) {
        // NOTE: unlike `std::sync::RwLock`, a panic while holding only a read
        // guard also poisons this lock.
        if std::thread::panicking() {
            self.rw.poisoned.store(true, AtomOrd::Release);
        }
        // `fetch_sub` returns the previous count. A value of 2 means this was
        // the last reader, so the lock is now free (1) and waiters can proceed.
        if self.rw.state.fetch_sub(1, AtomOrd::AcqRel) == 2 {
            self.rw.notify();
        }
    }
}
/// RAII exclusive-access guard produced by [`FusedRw::write`] and
/// [`FusedRw::write_async`]; dereferences (mutably) to the protected value.
///
/// Dropping it frees the lock and wakes waiters. Dropping it while the thread
/// is panicking poisons the lock.
#[must_use = "if unused the FusedRw will immediately unlock"]
pub struct FusedWriteGuard<'a, T> {
    rw: &'a FusedRw<T>,
}
impl<T> std::ops::Deref for FusedWriteGuard<'_, T> {
    type Target = T;
    fn deref(&self) -> &T {
        // SAFETY: this guard holds the lock exclusively (`state == 0`). The
        // only other access through it is `deref_mut`, which needs `&mut self`
        // and so cannot overlap this shared borrow.
        unsafe { &*self.rw.data.get() }
    }
}
impl<T> std::ops::DerefMut for FusedWriteGuard<'_, T> {
    fn deref_mut(&mut self) -> &mut T {
        // SAFETY: write guards are only created by exchanging `state` from 1
        // (free) to 0. The lock's acquire paths admit no reader or writer
        // while it is 0, and `&mut self` makes this the sole live borrow
        // through the guard.
        unsafe { &mut *self.rw.data.get() }
    }
}
impl<T> Drop for FusedWriteGuard<'_, T> {
    fn drop(&mut self) {
        if std::thread::panicking() {
            self.rw.poisoned.store(true, AtomOrd::Release);
        }
        // Release ordering publishes our writes to whoever acquires next.
        self.rw.state.store(1, AtomOrd::Release);
        self.rw.notify();
    }
}
/// Future returned by [`FusedRw::read_async`]; resolves to a read guard.
///
/// Polling panics if the lock is poisoned.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReadFuture<'a, T> {
    rw: &'a FusedRw<T>,
}
impl<'a, T> ReadFuture<'a, T> {
    /// Registers one more reader unless a writer holds the lock.
    ///
    /// The check and the increment form a single atomic read-modify-write.
    /// A separate load and `fetch_add` would let a writer exchange `1 -> 0`
    /// in between, after which the increment makes the lock look free while
    /// both sides access the data.
    fn try_acquire(&self) -> Option<FusedReadGuard<'a, T>> {
        self.rw
            .state
            .fetch_update(AtomOrd::AcqRel, AtomOrd::Acquire, |state| {
                (state > 0).then(|| state.checked_add(1).expect("FusedRw reader count overflow"))
            })
            .ok()
            .map(|_| FusedReadGuard { rw: self.rw })
    }
}
impl<'a, T> Future for ReadFuture<'a, T> {
    type Output = FusedReadGuard<'a, T>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.rw.poisoned.load(AtomOrd::Acquire) {
            panic!("FusedRw is poisoned");
        }
        if let Some(guard) = self.try_acquire() {
            return Poll::Ready(guard);
        }
        // A `Vec<Waker>` is valid even if a previous holder panicked, so
        // poison is ignored.
        let mut wakers = self
            .rw
            .wait_wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Re-check under the wakers lock. Releasers change `state` before
        // draining the wakers under this same lock, so either we observe the
        // release here or our waker is registered before the drain.
        if let Some(guard) = self.try_acquire() {
            return Poll::Ready(guard);
        }
        // Clone the waker only if this task is not already registered.
        if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
            wakers.push(cx.waker().clone());
        }
        Poll::Pending
    }
}
/// Future returned by [`FusedRw::write_async`]; resolves to a write guard.
///
/// Polling panics if the lock is poisoned.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct WriteFuture<'a, T> {
    rw: &'a FusedRw<T>,
}
impl<'a, T> WriteFuture<'a, T> {
    /// Takes the lock exclusively if it is currently free (`state == 1`).
    fn try_acquire(&self) -> Option<FusedWriteGuard<'a, T>> {
        self.rw
            .state
            .compare_exchange(1, 0, AtomOrd::AcqRel, AtomOrd::Acquire)
            .ok()
            .map(|_| FusedWriteGuard { rw: self.rw })
    }
}
impl<'a, T> Future for WriteFuture<'a, T> {
    type Output = FusedWriteGuard<'a, T>;
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        if self.rw.poisoned.load(AtomOrd::Acquire) {
            panic!("FusedRw is poisoned");
        }
        if let Some(guard) = self.try_acquire() {
            return Poll::Ready(guard);
        }
        // A `Vec<Waker>` is valid even if a previous holder panicked, so
        // poison is ignored.
        let mut wakers = self
            .rw
            .wait_wakers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Re-check under the wakers lock. Releasers change `state` before
        // draining the wakers under this same lock, so either we observe the
        // release here or our waker is registered before the drain.
        if let Some(guard) = self.try_acquire() {
            return Poll::Ready(guard);
        }
        // Clone the waker only if this task is not already registered.
        if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
            wakers.push(cx.waker().clone());
        }
        Poll::Pending
    }
}