use super::SendSyncMessage;
use std::{
fmt,
ops::{Deref, DerefMut},
sync::atomic::{AtomicUsize, Ordering},
};
use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use static_assertions::{assert_impl_all, assert_not_impl_all};
use crate::{
sink::{PollSend, Sink},
stream::{PollRecv, Stream},
sync::{shared, ReceiverShared, SenderShared},
};
/// Constructs a watch channel whose initial value is `T::default()`.
///
/// This is a thin wrapper around [`channel_with`].
pub fn channel<T: Clone + Default>() -> (Sender<T>, Receiver<T>) {
    let initial: T = Default::default();
    channel_with(initial)
}
/// Constructs a watch channel that starts out holding `value`.
///
/// Both the shared state and the returned receiver start at generation zero,
/// so the receiver's first poll yields `value` (see `try_recv_internal`).
pub fn channel_with<T: Clone>(value: T) -> (Sender<T>, Receiver<T>) {
    #[cfg(feature = "debug")]
    log::error!("Creating watch channel");

    let state = StateExtension::new(value);
    let (sender_half, receiver_half) = shared(state);

    (
        Sender {
            shared: sender_half,
        },
        Receiver {
            shared: receiver_half,
            generation: AtomicUsize::new(0),
        },
    )
}
pub fn channel_with_option<T: Clone>() -> (Sender<Option<T>>, Receiver<Option<T>>) {
channel::<Option<T>>()
}
/// The sending half of a watch channel.
///
/// A new value can be published through the `Sink` implementation, or the
/// stored value can be mutated in place through `borrow_mut`.
pub struct Sender<T> {
// Sender-side handle to the state shared with every receiver of this channel.
pub(in crate::channels::watch) shared: SenderShared<StateExtension<T>>,
}
// Compile-time checks: a sender is `Send + Sync + Debug`, and deliberately not `Clone`.
assert_impl_all!(Sender<SendSyncMessage>: Send, Sync, fmt::Debug);
assert_not_impl_all!(Sender<SendSyncMessage>: Clone);
impl<T> Sink for Sender<T> {
    type Item = T;

    /// Overwrites the stored value with `value` and notifies the receivers.
    ///
    /// Never returns `Pending`: the value is handed back as `Rejected` when the
    /// shared state reports the channel closed, otherwise the send completes
    /// immediately with `Ready`. The task context is not used.
    fn poll_send(
        self: std::pin::Pin<&mut Self>,
        _cx: &mut crate::Context<'_>,
        value: Self::Item,
    ) -> PollSend<Self::Item> {
        let channel = &self.shared;

        if channel.is_closed() {
            PollSend::Rejected(value)
        } else {
            channel.extension().push(value);
            channel.notify_receivers();
            PollSend::Ready
        }
    }
}
// The explicit `'s` lifetimes are kept for readability of the guard signatures.
#[allow(clippy::needless_lifetimes)]
impl<T> Sender<T> {
    /// Mutably borrows the stored value, holding the write lock for as long as
    /// the returned guard lives.
    ///
    /// Dropping the guard bumps the generation and notifies the receivers,
    /// whether or not the value was actually modified.
    pub fn borrow_mut<'s>(&'s mut self) -> RefMut<'s, T> {
        let state = self.shared.extension();
        let write_guard = state.value.write();
        let shared = self.shared.clone();

        RefMut {
            lock: write_guard,
            shared,
        }
    }

    /// Creates a new receiver attached to this channel.
    ///
    /// The receiver starts at generation zero, so its first poll yields the
    /// value currently stored in the channel.
    pub fn subscribe(&mut self) -> Receiver<T> {
        let shared = self.shared.clone_receiver();
        let generation = AtomicUsize::new(0);

        Receiver { shared, generation }
    }

    /// Immutably borrows the stored value, holding the read lock for as long as
    /// the returned guard lives.
    pub fn borrow<'s>(&'s mut self) -> Ref<'s, T> {
        let state = self.shared.extension();
        let read_guard = state.value.read();

        Ref { lock: read_guard }
    }
}
impl<T> fmt::Debug for Sender<T> {
    /// Writes only the type name; no internal state is exposed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Sender")
    }
}
/// `futures::sink::Sink` adapter for the watch sender, compiled only with the
/// `futures-traits` feature.
#[cfg(feature = "futures-traits")]
mod impl_futures {
    use crate::sink::SendError;
    use std::task::Poll;

    impl<T> futures::sink::Sink<T> for super::Sender<T> {
        type Error = SendError<T>;

        /// Always ready: sending overwrites the stored value, so it never waits.
        fn poll_ready(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Stores `item` and notifies receivers, or returns it inside
        /// `SendError` when the channel is closed.
        fn start_send(self: std::pin::Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
            let channel = &self.shared;

            if channel.is_closed() {
                Err(SendError(item))
            } else {
                channel.extension().push(item);
                channel.notify_receivers();
                Ok(())
            }
        }

        /// Nothing is buffered, so flushing completes immediately.
        fn poll_flush(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        /// Completes immediately without touching the shared state.
        fn poll_close(
            self: std::pin::Pin<&mut Self>,
            _cx: &mut std::task::Context<'_>,
        ) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }
    }
}
/// The receiving half of a watch channel.
///
/// Polling it as a `Stream` yields a clone of the stored value each time the
/// channel's generation has caught up with this receiver's own counter.
pub struct Receiver<T> {
// Receiver-side handle to the state shared with the sender.
pub(in crate::channels::watch) shared: ReceiverShared<StateExtension<T>>,
// The next channel generation this receiver is waiting for; starts at 0 so the
// initial value is delivered on the first poll.
pub(in crate::channels::watch) generation: AtomicUsize,
}
// Compile-time check: a receiver is `Clone + Send + Sync + Debug`.
assert_impl_all!(Receiver<SendSyncMessage>: Clone, Send, Sync, fmt::Debug);
// Stream implementation: yields clones of the watched value as new generations appear.
impl<T> Stream for Receiver<T>
where
T: Clone,
{
type Item = T;
/// Polls for a value this receiver has not observed yet.
///
/// Returns `Ready` with a clone of the stored value when `try_recv_internal`
/// reports one. Otherwise returns `Closed` if the shared state reports the
/// channel closed, or `Pending` after registering `cx` via `subscribe_send`.
fn poll_recv(
self: std::pin::Pin<&mut Self>,
cx: &mut crate::Context<'_>,
) -> PollRecv<Self::Item> {
loop {
// Acquired before checking for a value so that expiry can be detected after
// the waker is registered. Presumably the guard expires when a send happens
// in between; TODO confirm against `send_guard`.
let guard = self.shared.send_guard();
match self.try_recv_internal() {
TryRecv::Pending => {
// Nothing new to deliver: report closure before registering for a wake-up.
if self.shared.is_closed() {
return PollRecv::Closed;
}
self.shared.subscribe_send(cx);
// The guard expired while we were registering, so re-check for a value
// instead of risking a missed notification.
if guard.is_expired() {
continue;
}
return PollRecv::Pending;
}
TryRecv::Ready(v) => return PollRecv::Ready(v),
}
}
}
}
impl<T> Receiver<T>
where
    T: Clone,
{
    /// Attempts to take a value this receiver has not observed yet.
    ///
    /// Returns `TryRecv::Pending` when the receiver's own counter is already
    /// ahead of the channel generation. Otherwise clones the stored value,
    /// advances the receiver's counter to one past the channel generation, and
    /// returns `TryRecv::Ready` with the clone.
    fn try_recv_internal(&self) -> TryRecv<T> {
        let state = self.shared.extension();

        if self.generation.load(Ordering::SeqCst) > state.generation(Ordering::SeqCst) {
            return TryRecv::Pending;
        }

        // Take the read lock first and re-read the generation while holding it:
        // writers bump the generation under the write lock, so the generation
        // recorded here belongs to the value that gets cloned below.
        let borrow = state.value.read();
        let stored_generation = state.generation(Ordering::SeqCst);

        // Use the same ordering as every other access to the generation counters.
        self.generation
            .store(stored_generation + 1, Ordering::SeqCst);

        TryRecv::Ready(borrow.clone())
    }
}
/// Result of a non-blocking receive attempt made by `Receiver::try_recv_internal`.
enum TryRecv<T> {
// The receiver has already observed the current generation.
Pending,
// A clone of the stored value that the receiver had not observed yet.
Ready(T),
}
impl<T> Clone for Receiver<T> {
    /// Creates another receiver on the same channel.
    ///
    /// The clone does not inherit this receiver's progress: its counter is reset
    /// to zero, so its first poll yields the value currently stored.
    fn clone(&self) -> Self {
        let shared = self.shared.clone();
        let generation = AtomicUsize::new(0);

        Self { shared, generation }
    }
}
impl<T> fmt::Debug for Receiver<T> {
    /// Writes only the type name; no internal state is exposed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Receiver")
    }
}
/// Write guard over the value stored in a watch channel.
///
/// Obtained from `Sender::borrow_mut`. Dropping it bumps the channel generation
/// and notifies the receivers.
pub struct RefMut<'t, T> {
    // Declared first so the write lock is released before `shared` is dropped.
    lock: RwLockWriteGuard<'t, T>,
    shared: SenderShared<StateExtension<T>>,
}

impl<'t, T> DerefMut for RefMut<'t, T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Deref coercion turns `&mut RwLockWriteGuard<T>` into `&mut T`.
        &mut self.lock
    }
}

impl<'t, T> Deref for RefMut<'t, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        &self.lock
    }
}

impl<'t, T> Drop for RefMut<'t, T> {
    /// Publishes the borrow as a new generation.
    ///
    /// This runs before the fields are dropped, so the write lock is still held
    /// while the generation is bumped and the receivers are notified.
    fn drop(&mut self) {
        let channel = &self.shared;
        channel.extension().increment();
        channel.notify_receivers();
    }
}
/// Read guard over the value stored in a watch channel.
///
/// The read lock is held for as long as this guard is alive.
pub struct Ref<'t, T> {
    lock: RwLockReadGuard<'t, T>,
}

impl<'t, T> Deref for Ref<'t, T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // Deref coercion turns `&RwLockReadGuard<T>` into `&T`.
        &self.lock
    }
}
impl<T> Receiver<T> {
    /// Immutably borrows the stored value without advancing this receiver's
    /// generation counter.
    ///
    /// The read lock is held for as long as the returned guard lives.
    pub fn borrow(&self) -> Ref<'_, T> {
        Ref {
            lock: self.shared.extension().value.read(),
        }
    }
}
/// Channel state shared by the sender and all receivers: the current value and
/// a counter that is bumped every time the value is published.
struct StateExtension<T> {
    generation: AtomicUsize,
    value: RwLock<T>,
}

impl<T> StateExtension<T> {
    /// Creates the state at generation zero, holding `value`.
    pub fn new(value: T) -> Self {
        let generation = AtomicUsize::new(0);
        let value = RwLock::new(value);

        Self { generation, value }
    }

    /// Replaces the stored value and bumps the generation.
    ///
    /// The generation is incremented while the write lock is still held, so a
    /// reader holding the read lock sees a matching value/generation pair.
    pub fn push(&self, value: T) {
        {
            let mut slot = self.value.write();
            *slot = value;
            self.generation.fetch_add(1, Ordering::SeqCst);
            // `slot` goes out of scope here, releasing the write lock.
        }
    }

    /// Bumps the generation without touching the stored value.
    pub fn increment(&self) {
        self.generation.fetch_add(1, Ordering::SeqCst);
    }

    /// Loads the current generation with the requested memory ordering.
    pub fn generation(&self, ordering: Ordering) -> usize {
        self.generation.load(ordering)
    }
}
// Unit tests that drive the watch channel by polling it directly with hand-built contexts.
#[cfg(test)]
mod tests {
use std::{pin::Pin, task::Context};
use super::channel;
use crate::{
sink::{PollSend, Sink},
stream::{PollRecv, Stream},
test::{noop_context, panic_context},
};
use futures_test::task::new_count_waker;
// Minimal payload type; `State(0)` is the default a fresh channel holds.
#[derive(Clone, Debug, PartialEq, Eq)]
struct State(usize);
impl Default for State {
fn default() -> Self {
State(0)
}
}
// Consecutive sends are accepted even though nothing is received in between.
#[test]
fn send_accepted() {
let mut cx = noop_context();
let (mut tx, _rx) = channel();
assert_eq!(
PollSend::Ready,
Pin::new(&mut tx).poll_send(&mut cx, State(1))
);
assert_eq!(
PollSend::Ready,
Pin::new(&mut tx).poll_send(&mut cx, State(2))
);
}
// A sent value is received once; the following poll is pending.
#[test]
fn send_recv() {
let mut cx = noop_context();
let (mut tx, mut rx) = channel();
assert_eq!(
PollSend::Ready,
Pin::new(&mut tx).poll_send(&mut cx, State(1))
);
assert_eq!(
PollRecv::Ready(State(1)),
Pin::new(&mut rx).poll_recv(&mut cx)
);
assert_eq!(PollRecv::Pending, Pin::new(&mut rx).poll_recv(&mut cx));
}
// A fresh receiver first yields the default value, then becomes pending.
#[test]
fn recv_default() {
let mut cx = panic_context();
let (_tx, mut rx) = channel();
assert_eq!(
PollRecv::Ready(State(0)),
Pin::new(&mut rx).poll_recv(&mut cx)
);
assert_eq!(
PollRecv::Pending,
Pin::new(&mut rx).poll_recv(&mut noop_context())
);
}
// Borrowing from a fresh receiver exposes the default value.
#[test]
fn borrow_default() {
let (_tx, rx) = channel();
assert_eq!(&State(0), &*rx.borrow());
}
// Borrowing from the receiver after a send exposes the sent value.
#[test]
fn borrow_sent() {
let mut cx = panic_context();
let (mut tx, rx) = channel();
assert_eq!(
PollSend::Ready,
Pin::new(&mut tx).poll_send(&mut cx, State(1))
);
assert_eq!(&State(1), &*rx.borrow());
}
// Writing through `borrow_mut` wakes a pending receiver exactly once and the
// new value is then both borrowable and receivable.
#[test]
fn borrow_mut_notifies() {
let mut cx = noop_context();
let (mut tx, mut rx) = channel();
assert_eq!(
PollRecv::Ready(State(0)),
Pin::new(&mut rx).poll_recv(&mut cx)
);
let (w1, w1_count) = new_count_waker();
let w1_context = Context::from_waker(&w1);
assert_eq!(
PollRecv::Pending,
Pin::new(&mut rx).poll_recv(&mut w1_context.into())
);
*tx.borrow_mut() = State(1);
assert_eq!(1, w1_count.get());
assert_eq!(&State(1), &*rx.borrow());
assert_eq!(
PollRecv::Ready(State(1)),
Pin::new(&mut rx).poll_recv(&mut cx)
);
}
// After the sender is dropped, each receiver (including a clone) still gets
// the last value once and then observes `Closed`.
#[test]
fn sender_disconnect() {
let mut cx = noop_context();
let (mut tx, mut rx) = channel();
let mut rx2 = rx.clone();
assert_eq!(
PollSend::Ready,
Pin::new(&mut tx).poll_send(&mut cx, State(1))
);
drop(tx);
assert_eq!(
PollRecv::Ready(State(1)),
Pin::new(&mut rx).poll_recv(&mut cx)
);
assert_eq!(PollRecv::Closed, Pin::new(&mut rx).poll_recv(&mut cx));
assert_eq!(
PollRecv::Ready(State(1)),
Pin::new(&mut rx2).poll_recv(&mut cx)
);
assert_eq!(PollRecv::Closed, Pin::new(&mut rx2).poll_recv(&mut cx));
}
// Sending after the only receiver is dropped hands the value back as `Rejected`.
#[test]
fn receiver_disconnect() {
let mut cx = noop_context();
let (mut tx, rx) = channel();
drop(rx);
assert_eq!(
PollSend::Rejected(State(1)),
Pin::new(&mut tx).poll_send(&mut cx, State(1))
);
}
// A send succeeds while the receiver exists and is rejected once it is dropped.
#[test]
fn send_then_receiver_disconnect() {
let mut cx = noop_context();
let (mut tx, rx) = channel();
assert_eq!(
PollSend::Ready,
Pin::new(&mut tx).poll_send(&mut cx, State(1))
);
drop(rx);
assert_eq!(
PollSend::Rejected(State(2)),
Pin::new(&mut tx).poll_send(&mut cx, State(2))
);
}
// A pending receiver's waker fires on the first send; a second send without
// an intervening poll does not wake it again.
#[test]
fn wake_receiver() {
let mut cx = panic_context();
let (mut tx, mut rx) = channel();
let (w1, w1_count) = new_count_waker();
let w1_context = Context::from_waker(&w1);
assert_eq!(
PollRecv::Ready(State(0)),
Pin::new(&mut rx).poll_recv(&mut cx)
);
assert_eq!(
PollRecv::Pending,
Pin::new(&mut rx).poll_recv(&mut w1_context.into())
);
assert_eq!(0, w1_count.get());
assert_eq!(
PollSend::Ready,
Pin::new(&mut tx).poll_send(&mut cx, State(1))
);
assert_eq!(1, w1_count.get());
assert_eq!(
PollSend::Ready,
Pin::new(&mut tx).poll_send(&mut cx, State(2))
);
assert_eq!(1, w1_count.get());
}
// Dropping the sender wakes a pending receiver.
#[test]
fn wake_receiver_on_disconnect() {
let (tx, mut rx) = channel::<State>();
let (w1, w1_count) = new_count_waker();
let w1_context = Context::from_waker(&w1);
assert_eq!(
PollRecv::Ready(State(0)),
Pin::new(&mut rx).poll_recv(&mut panic_context())
);
assert_eq!(
PollRecv::Pending,
Pin::new(&mut rx).poll_recv(&mut w1_context.into())
);
assert_eq!(0, w1_count.get());
drop(tx);
assert_eq!(1, w1_count.get());
}
// A receiver created through `subscribe` first yields the default value.
// NOTE(review): the body never awaits, so this could be a plain `#[test]`.
#[async_std::test]
async fn subscribe_default() {
let mut cx = panic_context();
let (mut tx, _rx) = channel();
let mut rx2 = tx.subscribe();
assert_eq!(
PollRecv::Ready(State(0)),
Pin::new(&mut rx2).poll_recv(&mut cx)
);
assert_eq!(
PollRecv::Pending,
Pin::new(&mut rx2).poll_recv(&mut noop_context())
);
}
// The original receiver and a subscribed one each see the default value once,
// independently of each other.
// NOTE(review): the body never awaits, so this could be a plain `#[test]`.
#[async_std::test]
async fn subscribe_both_receive_value() {
let mut cx = panic_context();
let (mut tx, mut rx) = channel();
let mut rx2 = tx.subscribe();
assert_eq!(
PollRecv::Ready(State(0)),
Pin::new(&mut rx).poll_recv(&mut cx)
);
assert_eq!(
PollRecv::Pending,
Pin::new(&mut rx).poll_recv(&mut noop_context())
);
assert_eq!(
PollRecv::Ready(State(0)),
Pin::new(&mut rx2).poll_recv(&mut cx)
);
assert_eq!(
PollRecv::Pending,
Pin::new(&mut rx2).poll_recv(&mut noop_context())
);
}
}
// End-to-end tests of the watch channel running on the tokio runtime.
#[cfg(test)]
mod tokio_tests {
use tokio::{spawn, time::timeout};
use crate::{
sink::Sink,
stream::Stream,
test::{Channel, Channels, Message, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
};
// One sender task, one receiver; the receiver is allowed to skip messages.
#[tokio::test]
async fn simple() {
let (mut tx, mut rx) = super::channel();
tokio::task::spawn(async move {
let mut iter = Message::new_iter(0);
// Discard the first message of the sequence before sending the rest.
iter.next();
for message in iter {
tx.send(message).await.expect("send failed");
}
});
timeout(TEST_TIMEOUT, async move {
let mut channel = Channel::new(0).allow_skips();
while let Some(message) = rx.recv().await {
channel.assert_message(&message);
}
})
.await
.expect("test timeout");
}
// Same as `simple`, but each message is published by assigning through `borrow_mut`.
#[tokio::test]
async fn send_borrow_mut() {
let (mut tx, mut rx) = super::channel();
tokio::task::spawn(async move {
let mut iter = Message::new_iter(0);
// Discard the first message of the sequence before publishing the rest.
iter.next();
for message in iter {
*tx.borrow_mut() = message;
}
});
timeout(TEST_TIMEOUT, async move {
let mut channel = Channel::new(0).allow_skips();
while let Some(message) = rx.recv().await {
channel.assert_message(&message);
}
})
.await
.expect("test timeout");
}
// One sender task and `CHANNEL_TEST_RECEIVERS` cloned receivers, each checked in its own task.
#[tokio::test]
async fn multi_receiver() {
let (mut tx, rx) = super::channel();
tokio::task::spawn(async move {
let mut iter = Message::new_iter(0);
// Discard the first message of the sequence before sending the rest.
iter.next();
for message in iter {
tx.send(message).await.expect("send failed");
}
});
// `map` is lazy: the receiver tasks are spawned as the loop below iterates.
let handles = (0..CHANNEL_TEST_RECEIVERS).map(move |_| {
let mut rx2 = rx.clone();
let mut channels = Channels::new(CHANNEL_TEST_RECEIVERS).allow_skips();
spawn(async move {
while let Some(message) = rx2.recv().await {
channels.assert_message(&message);
}
})
});
timeout(TEST_TIMEOUT, async move {
for handle in handles {
handle.await.expect("join failed");
}
})
.await
.expect("test timeout");
}
}
// End-to-end tests of the watch channel running on the async-std runtime.
#[cfg(test)]
mod async_std_tests {
    use async_std::{future::timeout, task::spawn};

    use crate::{
        sink::Sink,
        stream::Stream,
        test::{Channel, Channels, Message, CHANNEL_TEST_RECEIVERS, TEST_TIMEOUT},
    };

    // One sender task, one receiver; the receiver is allowed to skip messages.
    #[async_std::test]
    async fn simple() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // Discard the first message of the sequence before sending the rest.
            iter.next();
            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    // Same as `simple`, but each message is published by assigning through `borrow_mut`.
    #[async_std::test]
    async fn send_borrow_mut() {
        let (mut tx, mut rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // Discard the first message of the sequence before publishing the rest.
            iter.next();
            for message in iter {
                *tx.borrow_mut() = message;
            }
        });

        timeout(TEST_TIMEOUT, async move {
            let mut channel = Channel::new(0).allow_skips();
            while let Some(message) = rx.recv().await {
                channel.assert_message(&message);
            }
        })
        .await
        .expect("test timeout");
    }

    // One sender task and `CHANNEL_TEST_RECEIVERS` cloned receivers, each checked in its own
    // task. Everything runs on async-std: the test attribute and the sender task previously
    // used tokio, so this module never exercised the multi-receiver path on async-std alone.
    #[async_std::test]
    async fn multi_receiver() {
        let (mut tx, rx) = super::channel();

        spawn(async move {
            let mut iter = Message::new_iter(0);
            // Discard the first message of the sequence before sending the rest.
            iter.next();
            for message in iter {
                tx.send(message).await.expect("send failed");
            }
        });

        // `map` is lazy: the receiver tasks are spawned as the loop below iterates.
        let handles = (0..CHANNEL_TEST_RECEIVERS).map(move |_| {
            let mut rx2 = rx.clone();
            let mut channels = Channels::new(CHANNEL_TEST_RECEIVERS).allow_skips();

            spawn(async move {
                while let Some(message) = rx2.recv().await {
                    channels.assert_message(&message);
                }
            })
        });

        timeout(TEST_TIMEOUT, async move {
            for handle in handles {
                // An async-std `JoinHandle` resolves directly to the task's output.
                handle.await;
            }
        })
        .await
        .expect("test timeout");
    }
}