use crate::task::AtomicWaker;
use core::task::Poll::{Pending, Ready};
use core::task::{Context, Poll};
use fnv::FnvHashMap;
use futures_util::future::poll_fn;
use std::ops;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::SeqCst;
use std::sync::{Arc, Mutex, RwLock, RwLockReadGuard, Weak};
#[cfg(feature = "async-traits")]
use futures_core::ready;
#[cfg(feature = "async-traits")]
use futures_util::pin_mut;
#[cfg(feature = "async-traits")]
use std::pin::Pin;
/// Receiving half of a watch channel.
///
/// Each receiver remembers the last version it has observed (`ver`) and owns its
/// own waker slot (`inner`), which is registered in the shared watcher table under `id`.
#[derive(Debug)]
pub struct Receiver<T> {
/// Channel state shared with the sender and all other receivers. Receivers hold the
/// only long-lived strong references, so the sender's `Weak` stays upgradable while
/// at least one receiver is alive.
shared: Arc<Shared<T>>,
/// This receiver's waker slot; the same `Arc` is also stored in `shared.watchers`.
inner: Arc<WatchInner>,
/// Key of `inner` in the watcher table, used to unregister on drop.
id: u64,
/// Last version (CLOSED bit masked off) this receiver has observed. It starts at 0,
/// which never matches the shared version, so the initial value is delivered once.
ver: usize,
}
/// Sending half of a watch channel.
///
/// Holds only a `Weak` reference to the shared state. Once every `Receiver` has been
/// dropped, the upgrade fails and `broadcast` returns the value inside a `SendError`.
#[derive(Debug)]
pub struct Sender<T> {
shared: Weak<Shared<T>>,
}
/// Read-only borrow of the channel's current value.
///
/// Holds the value's read lock for as long as it lives. `Sender::broadcast` takes
/// the write lock, so it blocks until outstanding `Ref`s are dropped; keep them short-lived.
#[derive(Debug)]
pub struct Ref<'a, T> {
inner: RwLockReadGuard<'a, T>,
}
pub mod error {
    //! Error types returned by the watch channel.

    use std::fmt;

    /// Error returned by [`Sender::broadcast`](super::Sender::broadcast) when every
    /// receiver has been dropped.
    ///
    /// The value that could not be delivered is kept in `inner`.
    #[derive(Debug)]
    pub struct SendError<T> {
        pub(crate) inner: T,
    }

    // The message never mentions the payload, so no `T: Debug` bound is required;
    // this lets `SendError<T>` be displayed even when `T` is not `Debug`.
    impl<T> fmt::Display for SendError<T> {
        fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
            write!(fmt, "channel closed")
        }
    }

    // `std::error::Error` requires `Debug`, and the derived `Debug` needs `T: Debug`.
    impl<T: fmt::Debug> ::std::error::Error for SendError<T> {}
}
/// State shared by the sender and all receivers of one channel.
#[derive(Debug)]
struct Shared<T> {
/// Current value, initially the one passed to `channel`.
value: RwLock<T>,
/// Bit 0 is the `CLOSED` flag; the remaining bits form a counter that starts at 2
/// and is bumped by 2 on every broadcast.
version: AtomicUsize,
/// Waker slots of all live receivers, keyed by receiver id.
watchers: Mutex<Watchers>,
/// Waker of a task blocked in `Sender::closed`; woken when this struct is dropped.
cancel: AtomicWaker,
}
/// Registry of receiver waker slots, guarded by `Shared::watchers`.
#[derive(Debug)]
struct Watchers {
/// Id assigned to the next cloned receiver; only ever incremented.
next_id: u64,
/// Receiver id -> that receiver's waker slot.
watchers: FnvHashMap<u64, Arc<WatchInner>>,
}
/// Per-receiver waker slot, woken by the sender on every broadcast and on close.
#[derive(Debug)]
struct WatchInner {
waker: AtomicWaker,
}
/// Low bit of `Shared::version`, set when the sender is dropped. Broadcasts add 2 to
/// the version, so they never disturb this bit.
const CLOSED: usize = 1;
/// Creates a watch channel holding `init`, returning its sending and receiving halves.
///
/// The shared version starts at 2 while the receiver starts at version 0, so the
/// receiver's first `recv`/`recv_ref` completes immediately with `init`.
pub fn channel<T>(init: T) -> (Sender<T>, Receiver<T>) {
    const INIT_ID: u64 = 0;

    let first_watch = Arc::new(WatchInner::new());

    let mut registry = FnvHashMap::default();
    registry.insert(INIT_ID, Arc::clone(&first_watch));

    let shared = Arc::new(Shared {
        value: RwLock::new(init),
        version: AtomicUsize::new(2),
        watchers: Mutex::new(Watchers {
            next_id: INIT_ID + 1,
            watchers: registry,
        }),
        cancel: AtomicWaker::new(),
    });

    // The sender only gets a weak handle; the receiver takes ownership of the strong one.
    let sender = Sender {
        shared: Arc::downgrade(&shared),
    };
    let receiver = Receiver {
        shared,
        inner: first_watch,
        id: INIT_ID,
        ver: 0,
    };

    (sender, receiver)
}
impl<T> Receiver<T> {
    /// Borrows the current value without waiting and without marking it as seen.
    ///
    /// The returned [`Ref`] holds the value's read lock, so `Sender::broadcast`
    /// blocks until it is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the value lock is poisoned.
    pub fn get_ref(&self) -> Ref<'_, T> {
        Ref {
            inner: self.shared.value.read().unwrap(),
        }
    }

    /// Waits for a value this receiver has not yet observed and borrows it.
    ///
    /// Returns `None` once the sender has been dropped and no unseen value remains.
    pub async fn recv_ref(&mut self) -> Option<Ref<'_, T>> {
        let seen = self.ver;
        let (shared, watch) = (&self.shared, &self.inner);

        let (guard, latest) = poll_fn(|cx| poll_lock(cx, shared, watch, seen)).await?;

        // Only record the new version once a value has actually been handed out.
        self.ver = latest;
        Some(guard)
    }
}
/// Polls whether the shared version differs from `ver`.
///
/// Results:
/// - `Ready(Some((guard, version)))` when the stored version (CLOSED bit masked off)
///   differs from `ver`;
/// - `Ready(None)` when the versions match and the channel is closed;
/// - `Pending` otherwise.
///
/// The waker is registered *before* the state is sampled, so a broadcast or sender
/// drop that races with this poll still wakes the task.
fn poll_lock<'a, T>(
    cx: &mut Context<'_>,
    shared: &'a Arc<Shared<T>>,
    inner: &Arc<WatchInner>,
    ver: usize,
) -> Poll<Option<(Ref<'a, T>, usize)>> {
    inner.waker.register_by_ref(cx.waker());

    let state = shared.version.load(SeqCst);
    let current = state & !CLOSED;
    let closed = (state & CLOSED) == CLOSED;

    match (current == ver, closed) {
        (false, _) => {
            let guard = shared.value.read().unwrap();
            Ready(Some((Ref { inner: guard }, current)))
        }
        (true, true) => Ready(None),
        (true, false) => Pending,
    }
}
impl<T: Clone> Receiver<T> {
    /// Waits for a value this receiver has not yet observed and returns an owned clone.
    ///
    /// Returns `None` once the sender has been dropped and no unseen value remains.
    /// The read lock is released before this returns.
    pub async fn recv(&mut self) -> Option<T> {
        let borrowed = self.recv_ref().await?;
        Some(T::clone(&borrowed))
    }
}
#[cfg(feature = "async-traits")]
impl<T: Clone> futures_core::Stream for Receiver<T> {
    type Item = T;

    /// Yields each value this receiver has not yet observed.
    ///
    /// The stream ends once the sender has been dropped and no unseen value remains.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T>> {
        use std::future::Future;

        // `recv` keeps no state across polls: it only captures references and re-checks
        // the version each time. Building a fresh future on every poll is therefore fine.
        let fut = self.get_mut().recv();
        pin_mut!(fut);

        // `recv` already hands back an owned clone, so pass it through without
        // cloning a second time.
        let item = ready!(fut.poll(cx));
        Ready(item)
    }
}
impl<T> Clone for Receiver<T> {
fn clone(&self) -> Self {
let inner = Arc::new(WatchInner::new());
let shared = self.shared.clone();
let id = {
let mut watchers = shared.watchers.lock().unwrap();
let id = watchers.next_id;
watchers.next_id += 1;
watchers.watchers.insert(id, inner.clone());
id
};
let ver = self.ver;
Receiver {
shared,
inner,
id,
ver,
}
}
}
impl<T> Drop for Receiver<T> {
    /// Unregisters this receiver's waker slot so the sender stops waking it.
    fn drop(&mut self) {
        self.shared
            .watchers
            .lock()
            .unwrap()
            .watchers
            .remove(&self.id);
    }
}
impl WatchInner {
    /// Creates a waker slot with no task registered yet.
    fn new() -> Self {
        Self {
            waker: AtomicWaker::new(),
        }
    }
}
impl<T> Sender<T> {
pub fn broadcast(&self, value: T) -> Result<(), error::SendError<T>> {
let shared = match self.shared.upgrade() {
Some(shared) => shared,
None => return Err(error::SendError { inner: value }),
};
{
let mut lock = shared.value.write().unwrap();
*lock = value;
}
shared.version.fetch_add(2, SeqCst);
notify_all(&*shared);
Ok(())
}
pub async fn closed(&mut self) {
poll_fn(|cx| self.poll_close(cx)).await
}
fn poll_close(&mut self, cx: &mut Context<'_>) -> Poll<()> {
match self.shared.upgrade() {
Some(shared) => {
shared.cancel.register_by_ref(cx.waker());
Pending
}
None => Ready(()),
}
}
}
#[cfg(feature = "async-traits")]
impl<T> futures_sink::Sink<T> for Sender<T> {
    type Error = error::SendError<T>;

    /// A watch channel never applies back-pressure, so the sender is always ready.
    fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ready(Ok(()))
    }

    /// Broadcasts `item` immediately. If no receivers remain, `item` is returned
    /// inside the error.
    fn start_send(self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> {
        self.broadcast(item)
    }

    /// Nothing is buffered (`start_send` delivers synchronously), so flushing is a no-op.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ready(Ok(()))
    }

    /// No-op; the channel is only marked closed when the `Sender` itself is dropped.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Ready(Ok(()))
    }
}
/// Wakes the waker slot of every receiver currently in the watcher table.
///
/// The `watchers` lock is held for the whole iteration, so `Receiver::clone` and
/// `Receiver::drop` block until all wakes have been issued.
fn notify_all<T>(shared: &Shared<T>) {
    shared
        .watchers
        .lock()
        .unwrap()
        .watchers
        .values()
        .for_each(|watch| watch.waker.wake());
}
impl<T> Drop for Sender<T> {
    /// Sets the CLOSED bit and wakes all receivers so pending receives can observe it.
    fn drop(&mut self) {
        let shared = match self.shared.upgrade() {
            Some(strong) => strong,
            // Every receiver is already gone; there is nobody to notify.
            None => return,
        };
        shared.version.fetch_or(CLOSED, SeqCst);
        notify_all(&*shared);
    }
}
impl<T> ops::Deref for Ref<'_, T> {
    type Target = T;

    /// Borrows the value protected by the held read lock.
    fn deref(&self) -> &T {
        &*self.inner
    }
}
impl<T> Drop for Shared<T> {
fn drop(&mut self) {
self.cancel.wake();
}
}