#![cfg_attr(not(feature = "sync"), allow(dead_code, unreachable_pub))]
use crate::sync::notify::Notify;
use crate::loom::sync::atomic::AtomicUsize;
use crate::loom::sync::atomic::Ordering::Relaxed;
use crate::loom::sync::{Arc, RwLock, RwLockReadGuard};
use std::fmt;
use std::mem;
use std::ops;
use std::panic;
/// Receiving half of a watch channel.
///
/// Each `Receiver` remembers a version number and compares it with the
/// channel's current version to decide whether the shared value has changed.
/// Cloning a `Receiver` (or calling `Sender::subscribe`) increments the
/// receiver count held in the shared state; dropping one decrements it.
#[derive(Debug)]
pub struct Receiver<T> {
    /// Channel state shared with the `Sender` and every other `Receiver`.
    shared: Arc<Shared<T>>,
    /// Version this receiver compares against to detect changes; advanced
    /// when a newer value is marked as seen.
    version: Version,
}
/// Sending half of a watch channel.
///
/// Dropping the `Sender` marks the channel as closed and wakes receivers that
/// are waiting for a change.
#[derive(Debug)]
pub struct Sender<T> {
    /// Channel state shared with all `Receiver`s.
    shared: Arc<Shared<T>>,
}
/// Read guard giving access to the channel's current value.
///
/// Returned by `Sender::borrow`, `Receiver::borrow`,
/// `Receiver::borrow_and_update` and `Receiver::wait_for`. The shared value
/// stays read-locked for as long as the `Ref` is alive. Sender operations that
/// modify the value take the write lock, so holding a `Ref` for a long time
/// delays them.
#[derive(Debug)]
pub struct Ref<'a, T> {
    /// Read-lock guard on the shared value.
    inner: RwLockReadGuard<'a, T>,
    /// Whether the value was unseen by the borrowing receiver when borrowed.
    has_changed: bool,
}
impl<'a, T> Ref<'a, T> {
    /// Returns whether the borrowed value had not yet been marked as seen by
    /// the receiver at the moment it was borrowed.
    ///
    /// Always `false` for a `Ref` obtained through `Sender::borrow`.
    pub fn has_changed(&self) -> bool {
        self.has_changed
    }
}
/// State shared between the `Sender` and every `Receiver` of one channel.
struct Shared<T> {
    /// The current value, behind a read-write lock.
    value: RwLock<T>,
    /// Version counter and closed flag packed into one atomic word.
    state: AtomicState,
    /// Number of live `Receiver` handles.
    ref_count_rx: AtomicUsize,
    /// Wakes receivers waiting for a change (used by `changed_impl`); notified
    /// on every effective send and when the sender is dropped.
    notify_rx: big_notify::BigNotify,
    /// Wakes tasks waiting in `Sender::closed` once the last receiver drops.
    notify_tx: Notify,
}
impl<T: fmt::Debug> fmt::Debug for Shared<T> {
    /// Formats the value, decoded state and receiver count.
    ///
    /// The state word is loaded once so that the reported version and closed
    /// flag come from the same snapshot.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let snapshot = self.state.load();
        let mut out = f.debug_struct("Shared");
        out.field("value", &self.value);
        out.field("version", &snapshot.version());
        out.field("is_closed", &snapshot.is_closed());
        out.field("ref_count_rx", &self.ref_count_rx);
        out.finish()
    }
}
/// Error types produced by the watch channel.
pub mod error {
    use std::error::Error;
    use std::fmt;

    /// Error returned by `Sender::send` when no receiver exists.
    ///
    /// The value that could not be sent is handed back in the public field.
    #[derive(Clone, Copy, PartialEq, Eq)]
    pub struct SendError<T>(pub T);

    // Written by hand so that no `T: Debug` bound is needed; the payload is
    // intentionally left out of the output.
    impl<T> fmt::Debug for SendError<T> {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.debug_struct("SendError").finish_non_exhaustive()
        }
    }

    impl<T> fmt::Display for SendError<T> {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str("channel closed")
        }
    }

    impl<T> Error for SendError<T> {}

    /// Error returned by receiver operations once the channel is closed, i.e.
    /// after the `Sender` has been dropped.
    #[derive(Clone, Debug)]
    pub struct RecvError(pub(super) ());

    impl fmt::Display for RecvError {
        fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
            formatter.write_str("channel closed")
        }
    }

    impl Error for RecvError {}
}
mod big_notify {
use super::*;
use crate::sync::notify::Notified;
pub(super) struct BigNotify {
#[cfg(not(all(not(loom), feature = "sync", any(feature = "rt", feature = "macros"))))]
next: AtomicUsize,
inner: [Notify; 8],
}
impl BigNotify {
pub(super) fn new() -> Self {
Self {
#[cfg(not(all(
not(loom),
feature = "sync",
any(feature = "rt", feature = "macros")
)))]
next: AtomicUsize::new(0),
inner: Default::default(),
}
}
pub(super) fn notify_waiters(&self) {
for notify in &self.inner {
notify.notify_waiters();
}
}
#[cfg(not(all(not(loom), feature = "sync", any(feature = "rt", feature = "macros"))))]
pub(super) fn notified(&self) -> Notified<'_> {
let i = self.next.fetch_add(1, Relaxed) % 8;
self.inner[i].notified()
}
#[cfg(all(not(loom), feature = "sync", any(feature = "rt", feature = "macros")))]
pub(super) fn notified(&self) -> Notified<'_> {
let i = crate::runtime::context::thread_rng_n(8) as usize;
self.inner[i].notified()
}
}
}
use self::state::{AtomicState, Version};
/// Packed channel state: a version counter plus a "closed" flag in one word.
///
/// Bit 0 is the closed flag. Versions start at 0 and move in steps of 2, so
/// they never touch bit 0.
mod state {
    use crate::loom::sync::atomic::AtomicUsize;
    use crate::loom::sync::atomic::Ordering;

    /// Flag bit set by `AtomicState::set_closed`.
    const CLOSED_BIT: usize = 1;
    /// Amount one version step adds to the state word (skips the flag bit).
    const STEP_SIZE: usize = 2;

    /// A version number with the closed flag masked off.
    #[derive(Clone, Copy, Debug, PartialEq, Eq)]
    pub(super) struct Version(usize);

    /// One atomically loaded value of the state word.
    #[derive(Clone, Copy, Debug)]
    pub(super) struct StateSnapshot(usize);

    /// The atomic state word itself.
    #[derive(Debug)]
    pub(super) struct AtomicState(AtomicUsize);

    impl Version {
        /// Version a freshly created channel starts at.
        pub(super) const INITIAL: Self = Self(0);

        /// Steps this version one increment backwards, wrapping on underflow
        /// to mirror the wrapping behaviour of atomic `fetch_add`.
        pub(super) fn decrement(&mut self) {
            let Self(raw) = self;
            *raw = raw.wrapping_sub(STEP_SIZE);
        }
    }

    impl StateSnapshot {
        /// The version part of the snapshot.
        pub(super) fn version(self) -> Version {
            let without_flag = self.0 & !CLOSED_BIT;
            Version(without_flag)
        }

        /// Whether the closed flag is set in the snapshot.
        pub(super) fn is_closed(self) -> bool {
            (self.0 & CLOSED_BIT) != 0
        }
    }

    impl AtomicState {
        /// Creates the state of a new, open channel at `Version::INITIAL`.
        pub(super) fn new() -> Self {
            Self(AtomicUsize::new(Version::INITIAL.0))
        }

        /// Loads the current state with `Acquire`, pairing with the `Release`
        /// writes below.
        pub(super) fn load(&self) -> StateSnapshot {
            let raw = self.0.load(Ordering::Acquire);
            StateSnapshot(raw)
        }

        /// Advances the version by one step.
        ///
        /// Callers invoke this while holding the value's write lock. `Release`
        /// publishes earlier writes to threads that `Acquire`-load the new state.
        pub(super) fn increment_version_while_locked(&self) {
            self.0.fetch_add(STEP_SIZE, Ordering::Release);
        }

        /// Sets the closed flag, leaving the version untouched.
        pub(super) fn set_closed(&self) {
            self.0.fetch_or(CLOSED_BIT, Ordering::Release);
        }
    }
}
/// Creates a watch channel holding `init`, returning its sender and a first
/// receiver.
///
/// The returned receiver starts at the channel's initial version, so `init` is
/// already treated as seen by it.
pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
    let shared = Arc::new(Shared {
        value: RwLock::new(init),
        state: AtomicState::new(),
        // Accounts for the receiver built below, which does not go through
        // `Receiver::from_shared` and therefore does not bump the count itself.
        ref_count_rx: AtomicUsize::new(1),
        notify_rx: big_notify::BigNotify::new(),
        notify_tx: Notify::new(),
    });
    let rx = Receiver {
        shared: shared.clone(),
        version: Version::INITIAL,
    };
    let tx = Sender { shared };
    (tx, rx)
}
impl<T> Receiver<T> {
    /// Builds a receiver on `shared` that treats `version` as already seen,
    /// incrementing the shared receiver count.
    fn from_shared(version: Version, shared: Arc<Shared<T>>) -> Self {
        // Relaxed: the count is only ever read as a snapshot (`receiver_count`).
        shared.ref_count_rx.fetch_add(1, Relaxed);
        Self { shared, version }
    }
    /// Returns a read guard to the current value without marking it as seen.
    ///
    /// `Ref::has_changed` on the guard reports whether the value's version
    /// differs from the version this receiver last marked as seen. The value
    /// stays read-locked until the guard is dropped.
    pub fn borrow(&self) -> Ref<'_, T> {
        let inner = self.shared.value.read().unwrap();
        // The sender bumps the version while holding the write lock, so the
        // version loaded under this read lock matches the value behind `inner`.
        let new_version = self.shared.state.load().version();
        let has_changed = self.version != new_version;
        Ref { inner, has_changed }
    }
    /// Returns a read guard to the current value and marks it as seen.
    ///
    /// `Ref::has_changed` reports whether the value was unseen before this
    /// call. Afterwards `has_changed` returns `Ok(false)` until a new value is
    /// sent.
    pub fn borrow_and_update(&mut self) -> Ref<'_, T> {
        let inner = self.shared.value.read().unwrap();
        // Loaded under the read lock for the same reason as in `borrow`.
        let new_version = self.shared.state.load().version();
        let has_changed = self.version != new_version;
        self.version = new_version;
        Ref { inner, has_changed }
    }
    /// Returns `Ok(true)` if the current value has not been marked as seen by
    /// this receiver, without marking it.
    ///
    /// # Errors
    ///
    /// Returns `RecvError` if the sender has been dropped. The closed flag is
    /// checked first, so this errors even when an unseen value is present.
    pub fn has_changed(&self) -> Result<bool, error::RecvError> {
        let state = self.shared.state.load();
        if state.is_closed() {
            return Err(error::RecvError(()));
        }
        let new_version = state.version();
        Ok(self.version != new_version)
    }
    /// Forces the current value to be reported as unseen.
    ///
    /// Moves this receiver's version one step back, so `has_changed` reports a
    /// change and the next `changed` call completes immediately.
    pub fn mark_changed(&mut self) {
        self.version.decrement();
    }
    /// Waits until a value not yet seen by this receiver is available, then
    /// marks it as seen. Completes immediately if such a value already exists.
    ///
    /// # Errors
    ///
    /// Returns `RecvError` if the sender has been dropped and no unseen value
    /// remains.
    pub async fn changed(&mut self) -> Result<(), error::RecvError> {
        changed_impl(&self.shared, &mut self.version).await
    }
    /// Waits until the current value satisfies `f`, returning a read guard to
    /// that value.
    ///
    /// `f` is first evaluated on the current value, then again after each
    /// change. Every value passed to `f` is marked as seen, whether or not `f`
    /// accepts it. Once the channel has been found closed, `f` is only
    /// evaluated again if the version moved since the last check.
    ///
    /// # Errors
    ///
    /// Returns `RecvError` if the sender is dropped before a value satisfying
    /// `f` is found.
    ///
    /// # Panics
    ///
    /// If `f` panics, the read lock is released and the panic is then resumed
    /// in the caller.
    pub async fn wait_for(
        &mut self,
        mut f: impl FnMut(&T) -> bool,
    ) -> Result<Ref<'_, T>, error::RecvError> {
        // Becomes true once `changed_impl` reports the channel as closed.
        let mut closed = false;
        loop {
            // Scope the read guard so it is released before awaiting below.
            {
                let inner = self.shared.value.read().unwrap();
                let new_version = self.shared.state.load().version();
                let has_changed = self.version != new_version;
                self.version = new_version;
                if !closed || has_changed {
                    let result = panic::catch_unwind(panic::AssertUnwindSafe(|| f(&inner)));
                    match result {
                        Ok(true) => {
                            return Ok(Ref { inner, has_changed });
                        }
                        Ok(false) => {
                            // Predicate rejected this value; wait for the next one.
                        }
                        Err(panicked) => {
                            // Release the read lock before re-raising the panic.
                            drop(inner);
                            panic::resume_unwind(panicked);
                        }
                    };
                }
            }
            if closed {
                return Err(error::RecvError(()));
            }
            // An `Err` means closed with no newer version; take one more pass
            // through the loop, which then returns `RecvError`.
            closed = changed_impl(&self.shared, &mut self.version).await.is_err();
        }
    }
    /// Returns `true` if both receivers belong to the same channel, i.e. share
    /// the same `Arc`'d state.
    pub fn same_channel(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.shared, &other.shared)
    }
    cfg_process_driver! {
        // Non-blocking variant of `changed`: returns `Some(Ok(()))` after
        // marking a newer value as seen, `Some(Err(_))` if the channel is
        // closed with nothing new, and `None` otherwise. Presumably only
        // compiled when the process driver is enabled (per the macro name).
        pub(crate) fn try_has_changed(&mut self) -> Option<Result<(), error::RecvError>> {
            maybe_changed(&self.shared, &mut self.version)
        }
    }
}
fn maybe_changed<T>(
shared: &Shared<T>,
version: &mut Version,
) -> Option<Result<(), error::RecvError>> {
let state = shared.state.load();
let new_version = state.version();
if *version != new_version {
*version = new_version;
return Some(Ok(()));
}
if state.is_closed() {
return Some(Err(error::RecvError(())));
}
None
}
/// Waits until `version` is behind the channel state, advances it, and
/// returns `Ok(())`.
///
/// Returns `Err(RecvError)` if the channel is closed and no newer version is
/// available (see `maybe_changed`).
async fn changed_impl<T>(
    shared: &Shared<T>,
    version: &mut Version,
) -> Result<(), error::RecvError> {
    // Task-tracing hook; presumably a no-op unless tracing support is enabled.
    crate::trace::async_trace_leaf().await;
    loop {
        // Create the `Notified` future BEFORE checking the version, so that a
        // `notify_waiters` call racing with the check is not lost. This relies
        // on `Notify` delivering `notify_waiters` wakeups to `Notified` futures
        // from the moment they are created, even before they are first polled.
        let notified = shared.notify_rx.notified();
        if let Some(ret) = maybe_changed(shared, version) {
            return ret;
        }
        notified.await;
        // Re-check the state after every wakeup rather than assuming it changed.
    }
}
impl<T> Clone for Receiver<T> {
    /// Creates another receiver on the same channel.
    ///
    /// The clone starts with this receiver's seen-version, and the shared
    /// receiver count is incremented.
    fn clone(&self) -> Self {
        Self::from_shared(self.version, self.shared.clone())
    }
}
impl<T> Drop for Receiver<T> {
    /// Decrements the receiver count and, if this was the last receiver, wakes
    /// tasks waiting in `Sender::closed`.
    fn drop(&mut self) {
        // `fetch_sub` yields the count from before the decrement, so 1 means
        // this handle was the last receiver.
        let previous = self.shared.ref_count_rx.fetch_sub(1, Relaxed);
        if previous == 1 {
            self.shared.notify_tx.notify_waiters();
        }
    }
}
impl<T> Sender<T> {
    /// Creates a sender for a new channel holding `init`, with no receivers.
    ///
    /// Receivers can be attached afterwards with `subscribe`.
    pub fn new(init: T) -> Self {
        let (tx, _) = channel(init);
        tx
    }
    /// Stores `value` and notifies receivers, provided at least one exists.
    ///
    /// # Errors
    ///
    /// Returns `SendError(value)`, handing the value back, when
    /// `receiver_count()` is zero; the stored value is then left unchanged.
    pub fn send(&self, value: T) -> Result<(), error::SendError<T>> {
        // `receiver_count` is a Relaxed snapshot; receivers may be added or
        // dropped concurrently with this check.
        if 0 == self.receiver_count() {
            return Err(error::SendError(value));
        }
        self.send_replace(value);
        Ok(())
    }
    /// Modifies the value in place with `modify`, then bumps the version and
    /// notifies receivers unconditionally (even when there are none).
    pub fn send_modify<F>(&self, modify: F)
    where
        F: FnOnce(&mut T),
    {
        self.send_if_modified(|value| {
            modify(value);
            true
        });
    }
    /// Calls `modify` on the value under the write lock. If it returns `true`,
    /// bumps the version and notifies receivers.
    ///
    /// Returns the value `modify` returned. When it returns `false`, receivers
    /// are neither notified nor shown a new version, even if `modify` did
    /// mutate the value.
    ///
    /// # Panics
    ///
    /// If `modify` panics, the write lock is released first and the panic is
    /// then resumed. Any partial mutation made by `modify` stays in place, but
    /// the version is not bumped.
    pub fn send_if_modified<F>(&self, modify: F) -> bool
    where
        F: FnOnce(&mut T) -> bool,
    {
        {
            let mut lock = self.shared.value.write().unwrap();
            let result = panic::catch_unwind(panic::AssertUnwindSafe(|| modify(&mut lock)));
            match result {
                Ok(modified) => {
                    if !modified {
                        return false;
                    }
                }
                Err(panicked) => {
                    // Drop the guard outside of unwinding so the lock is not
                    // poisoned, then forward the panic to the caller.
                    drop(lock);
                    panic::resume_unwind(panicked);
                }
            };
            // Bump the version while still holding the write lock. Receivers
            // load the version while holding a read lock, so the version they
            // observe always matches the value they are looking at.
            self.shared.state.increment_version_while_locked();
            drop(lock);
        }
        // Wake waiting receivers only after the lock has been released.
        self.shared.notify_rx.notify_waiters();
        true
    }
    /// Replaces the value, notifies receivers, and returns the previous value.
    ///
    /// Unlike `send`, this succeeds even when there are no receivers.
    pub fn send_replace(&self, mut value: T) -> T {
        self.send_modify(|old| mem::swap(old, &mut value));
        // After the swap, `value` holds what was previously stored.
        value
    }
    /// Returns a read guard to the current value. `Ref::has_changed` is always
    /// `false` for guards obtained from the sender.
    pub fn borrow(&self) -> Ref<'_, T> {
        let inner = self.shared.value.read().unwrap();
        let has_changed = false;
        Ref { inner, has_changed }
    }
    /// Returns `true` when there are currently no receivers (Relaxed snapshot).
    pub fn is_closed(&self) -> bool {
        self.receiver_count() == 0
    }
    /// Completes once there are no receivers left.
    ///
    /// Returns immediately if there are none now.
    pub async fn closed(&self) {
        crate::trace::async_trace_leaf().await;
        // Loop because `subscribe` can add receivers again after the count
        // reached zero and this task was woken.
        while self.receiver_count() > 0 {
            // Register for the wakeup before re-checking the count, so the drop
            // of the last receiver cannot slip in between unnoticed.
            let notified = self.shared.notify_tx.notified();
            if self.receiver_count() == 0 {
                return;
            }
            notified.await;
        }
    }
    /// Creates a new receiver that treats the current value as already seen.
    ///
    /// This also works after every previous receiver has been dropped. The
    /// closed flag is only set by dropping the sender, so it needs no reset.
    pub fn subscribe(&self) -> Receiver<T> {
        let shared = self.shared.clone();
        let version = shared.state.load().version();
        Receiver::from_shared(version, shared)
    }
    /// Returns the number of live receivers (a Relaxed load, so only a
    /// snapshot that may already be stale).
    pub fn receiver_count(&self) -> usize {
        self.shared.ref_count_rx.load(Relaxed)
    }
}
impl<T> Drop for Sender<T> {
    /// Marks the channel as closed and wakes every receiver waiting for a change.
    fn drop(&mut self) {
        // Set the closed flag (Release) before waking receivers, so a woken
        // receiver sees it when it re-checks the state with an Acquire load.
        self.shared.state.set_closed();
        self.shared.notify_rx.notify_waiters();
    }
}
impl<T> ops::Deref for Ref<'_, T> {
    type Target = T;

    /// Gives access to the value protected by the underlying read guard.
    fn deref(&self) -> &T {
        &*self.inner
    }
}
// Loom model tests: `loom::model` runs each closure under the loom model
// checker, which explores thread interleavings of the code inside it.
#[cfg(all(test, loom))]
mod tests {
    use futures::future::FutureExt;
    use loom::thread;
    // Races `changed()` against sends on other threads. `now_or_never` polls a
    // future once, so `changed()` may or may not complete. The test passes if
    // no explored interleaving panics (e.g. an `unwrap` on `send` failing).
    #[test]
    fn watch_spurious_wakeup() {
        loom::model(|| {
            let (send, mut recv) = crate::sync::watch::channel(0i32);
            send.send(1).unwrap();
            let send_thread = thread::spawn(move || {
                send.send(2).unwrap();
                send
            });
            recv.changed().now_or_never();
            let send = send_thread.join().unwrap();
            let recv_thread = thread::spawn(move || {
                recv.changed().now_or_never();
                recv.changed().now_or_never();
                recv
            });
            send.send(3).unwrap();
            let mut recv = recv_thread.join().unwrap();
            let send_thread = thread::spawn(move || {
                send.send(2).unwrap();
            });
            recv.changed().now_or_never();
            send_thread.join().unwrap();
        });
    }
    // Same kind of interleavings as above, additionally asserting that both
    // `Sender::borrow` and `Receiver::borrow` see the latest value once the
    // sending thread has been joined.
    #[test]
    fn watch_borrow() {
        loom::model(|| {
            let (send, mut recv) = crate::sync::watch::channel(0i32);
            assert!(send.borrow().eq(&0));
            assert!(recv.borrow().eq(&0));
            send.send(1).unwrap();
            assert!(send.borrow().eq(&1));
            let send_thread = thread::spawn(move || {
                send.send(2).unwrap();
                send
            });
            recv.changed().now_or_never();
            let send = send_thread.join().unwrap();
            let recv_thread = thread::spawn(move || {
                recv.changed().now_or_never();
                recv.changed().now_or_never();
                recv
            });
            send.send(3).unwrap();
            let recv = recv_thread.join().unwrap();
            assert!(recv.borrow().eq(&3));
            assert!(send.borrow().eq(&3));
            send.send(2).unwrap();
            // The spawned thread is not joined explicitly; it checks the value
            // from another thread after the send above.
            thread::spawn(move || {
                assert!(recv.borrow().eq(&2));
            });
            assert!(send.borrow().eq(&2));
        });
    }
}