use crate::sync::Notify;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::{Relaxed, SeqCst};
use crate::loom::sync::{Arc, RwLock, RwLockReadGuard};
use std::ops;
/// Receiving half of the watch channel.
///
/// Created by `channel` and duplicated with `Clone`; each handle remembers the
/// last version it has observed so `changed` can tell whether a newer value
/// has been sent since.
#[derive(Debug)]
pub struct Receiver<T> {
/// State shared with the `Sender` and every other `Receiver`.
shared: Arc<Shared<T>>,
/// Last version observed by this handle, with the `CLOSED` bit masked off
/// (see `maybe_changed`).
version: usize,
}
/// Sending half of the watch channel.
///
/// Dropping it sets the `CLOSED` bit and wakes receivers waiting in
/// `Receiver::changed`.
#[derive(Debug)]
pub struct Sender<T> {
/// State shared with all `Receiver` handles.
shared: Arc<Shared<T>>,
}
/// Borrowed view of the channel's current value, returned by `borrow`.
///
/// It wraps a read guard on the value lock. While a `Ref` is alive,
/// `Sender::send` (which takes the write lock) has to wait, so keep it
/// short-lived.
#[derive(Debug)]
pub struct Ref<'a, T> {
/// Read guard over `Shared::value`.
inner: RwLockReadGuard<'a, T>,
}
/// State shared between the `Sender` and all `Receiver` handles.
#[derive(Debug)]
struct Shared<T> {
/// The most recently sent value (initially the value passed to `channel`).
value: RwLock<T>,
/// Version counter: advanced by 2 on every successful `send`; its low bit
/// (`CLOSED`) is set when the `Sender` is dropped.
version: AtomicUsize,
/// Number of live `Receiver` handles (starts at 1, +1 per clone, -1 per drop).
ref_count_rx: AtomicUsize,
/// Wakes tasks waiting in `Receiver::changed` (on send and on sender drop).
notify_rx: Notify,
/// Wakes tasks waiting in `Sender::closed` when the last receiver is dropped.
notify_tx: Notify,
}
/// Error types produced by the watch channel.
pub mod error {
    use std::fmt;

    /// Returned by `Sender::send` when every receiver is gone; the value that
    /// could not be delivered is handed back in `inner`.
    #[derive(Debug)]
    pub struct SendError<T> {
        pub(crate) inner: T,
    }

    impl<T: fmt::Debug> fmt::Display for SendError<T> {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("channel closed")
        }
    }

    impl<T: fmt::Debug> std::error::Error for SendError<T> {}

    /// Returned by `Receiver::changed` once the sender has been dropped and
    /// no unseen value remains.
    #[derive(Debug)]
    pub struct RecvError(pub(super) ());

    impl fmt::Display for RecvError {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.write_str("channel closed")
        }
    }

    impl std::error::Error for RecvError {}
}
/// Low bit of `Shared::version`, set by `Sender`'s `Drop`. `send` advances the
/// version in steps of 2, so this bit is never disturbed by sends.
const CLOSED: usize = 1;
/// Creates a watch channel whose initial value is `init`.
///
/// Returns the single `Sender` together with a first `Receiver`; more
/// receivers are obtained by cloning. The receiver starts at version 0, which
/// equals the channel's initial version, so `changed` does not complete until
/// a value is sent or the sender is dropped.
pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
    let shared = Arc::new(Shared {
        value: RwLock::new(init),
        version: AtomicUsize::new(0),
        // Exactly one receiver exists once this function returns.
        ref_count_rx: AtomicUsize::new(1),
        notify_rx: Notify::new(),
        notify_tx: Notify::new(),
    });

    let rx = Receiver {
        shared: shared.clone(),
        version: 0,
    };
    let tx = Sender { shared };

    (tx, rx)
}
impl<T> Receiver<T> {
    /// Returns a read guard over the most recently sent value.
    ///
    /// This does not update the receiver's tracked version, so a later
    /// `changed` call may still complete for the value seen here.
    ///
    /// # Panics
    /// Panics if the value lock is poisoned.
    pub fn borrow(&self) -> Ref<'_, T> {
        Ref {
            inner: self.shared.value.read().unwrap(),
        }
    }

    /// Waits until a value newer than the last one this handle observed is sent.
    ///
    /// # Errors
    /// Returns `RecvError` once the `Sender` has been dropped and there is no
    /// unseen value left.
    pub async fn changed(&mut self) -> Result<(), error::RecvError> {
        loop {
            // Register for a wake-up *before* inspecting the version, so a send
            // that lands between the check and the await is not missed.
            let wakeup = self.shared.notify_rx.notified();

            match maybe_changed(&self.shared, &mut self.version) {
                Some(outcome) => return outcome,
                // A wake-up does not guarantee a change; loop and re-check.
                None => wakeup.await,
            }
        }
    }
}
fn maybe_changed<T>(
shared: &Shared<T>,
version: &mut usize,
) -> Option<Result<(), error::RecvError>> {
let state = shared.version.load(SeqCst);
let new_version = state & !CLOSED;
if *version != new_version {
*version = new_version;
return Some(Ok(()));
}
if CLOSED == state & CLOSED {
return Some(Err(error::RecvError(())));
}
None
}
impl<T> Clone for Receiver<T> {
    /// Creates another receiver handle that starts from this one's last-seen
    /// version and counts towards the live-receiver total.
    fn clone(&self) -> Self {
        let shared = self.shared.clone();
        shared.ref_count_rx.fetch_add(1, Relaxed);
        Receiver {
            shared,
            version: self.version,
        }
    }
}
impl<T> Drop for Receiver<T> {
    /// Decrements the live-receiver count. The handle that takes it from 1 to
    /// 0 wakes any task waiting in `Sender::closed`.
    fn drop(&mut self) {
        let count_before = self.shared.ref_count_rx.fetch_sub(1, Relaxed);
        if count_before == 1 {
            self.shared.notify_tx.notify_waiters();
        }
    }
}
impl<T> Sender<T> {
    /// Stores `value` as the channel's current value and wakes every receiver
    /// waiting in `Receiver::changed`.
    ///
    /// # Errors
    /// Returns `SendError` carrying `value` back if no `Receiver` is alive.
    ///
    /// # Panics
    /// Panics if the value lock is poisoned.
    pub fn send(&self, value: T) -> Result<(), error::SendError<T>> {
        // Only a hint: the last receiver may be dropped right after this check,
        // in which case the value is stored but never observed.
        if 0 == self.shared.ref_count_rx.load(Relaxed) {
            return Err(error::SendError { inner: value });
        }

        let replaced = {
            let mut slot = self.shared.value.write().unwrap();
            let replaced = std::mem::replace(&mut *slot, value);

            // Publish the new version while the write lock is still held, so
            // any reader able to see the new value also sees its version.
            // Steps of 2 leave the CLOSED bit untouched.
            self.shared.version.fetch_add(2, SeqCst);

            replaced
            // `slot` (the write guard) is released at the end of this block.
        };

        // Notify all watchers.
        self.shared.notify_rx.notify_waiters();

        // Drop the previous value only after the lock has been released and
        // receivers have been notified. Running `T::drop` under the write lock
        // would poison the lock if the destructor panicked (bricking every
        // later `borrow`/`send`), and could deadlock or panic if the destructor
        // touched this channel.
        drop(replaced);

        Ok(())
    }

    /// Returns a read guard over the current value.
    ///
    /// # Panics
    /// Panics if the value lock is poisoned.
    pub fn borrow(&self) -> Ref<'_, T> {
        let inner = self.shared.value.read().unwrap();
        Ref { inner }
    }

    /// Returns `true` when no `Receiver` handle is alive.
    pub fn is_closed(&self) -> bool {
        self.shared.ref_count_rx.load(Relaxed) == 0
    }

    /// Completes once every `Receiver` handle has been dropped.
    pub async fn closed(&self) {
        // Register interest before checking the count so that a final
        // receiver drop racing with the check still wakes us.
        let notified = self.shared.notify_tx.notified();

        if self.shared.ref_count_rx.load(Relaxed) == 0 {
            return;
        }

        notified.await;
        debug_assert_eq!(0, self.shared.ref_count_rx.load(Relaxed));
    }
}
impl<T> Drop for Sender<T> {
    /// Marks the channel closed by setting the `CLOSED` bit of the version,
    /// then wakes every receiver waiting in `changed` so it can observe that.
    fn drop(&mut self) {
        let shared = &*self.shared;
        shared.version.fetch_or(CLOSED, SeqCst);
        shared.notify_rx.notify_waiters();
    }
}
impl<T> ops::Deref for Ref<'_, T> {
    type Target = T;

    /// Exposes the value protected by the underlying read guard.
    fn deref(&self) -> &T {
        &self.inner
    }
}
// Loom model-checking tests; compiled only with `--cfg loom` under `cargo test`.
#[cfg(all(test, loom))]
mod tests {
use futures::future::FutureExt;
use loom::thread;
// Interleaves sends from other threads with single, non-blocking polls of
// `Receiver::changed` (`now_or_never`), so loom can explore orderings in which
// a receiver registers for a notification around a concurrent send.
#[test]
fn watch_spurious_wakeup() {
loom::model(|| {
let (send, mut recv) = crate::sync::watch::channel(0i32);
send.send(1).unwrap();
// Send concurrently with the receiver's poll below.
let send_thread = thread::spawn(move || {
send.send(2).unwrap();
send
});
recv.changed().now_or_never();
let send = send_thread.join().unwrap();
// Now poll twice from another thread while the main thread sends.
let recv_thread = thread::spawn(move || {
recv.changed().now_or_never();
recv.changed().now_or_never();
recv
});
send.send(3).unwrap();
let mut recv = recv_thread.join().unwrap();
let send_thread = thread::spawn(move || {
send.send(2).unwrap();
});
recv.changed().now_or_never();
send_thread.join().unwrap();
});
}
// Checks that `borrow` on both halves returns the most recently sent value
// once the sending thread has been joined.
#[test]
fn watch_borrow() {
loom::model(|| {
let (send, mut recv) = crate::sync::watch::channel(0i32);
assert!(send.borrow().eq(&0));
assert!(recv.borrow().eq(&0));
send.send(1).unwrap();
assert!(send.borrow().eq(&1));
let send_thread = thread::spawn(move || {
send.send(2).unwrap();
send
});
recv.changed().now_or_never();
let send = send_thread.join().unwrap();
let recv_thread = thread::spawn(move || {
recv.changed().now_or_never();
recv.changed().now_or_never();
recv
});
send.send(3).unwrap();
let recv = recv_thread.join().unwrap();
assert!(recv.borrow().eq(&3));
assert!(send.borrow().eq(&3));
send.send(2).unwrap();
// NOTE(review): this thread is never joined; presumably loom still runs it
// to completion within the model — confirm.
thread::spawn(move || {
assert!(recv.borrow().eq(&2));
});
assert!(send.borrow().eq(&2));
});
}
}