use futures::task::AtomicWaker;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
use std::sync::atomic::Ordering;
use std::task::Context;
use std::task::Poll;
use crate::runtime::BlockMessage;
use crate::runtime::PortId;
use crate::runtime::channel::mpsc;
#[derive(Debug)]
struct BlockNotifyState {
pending: AtomicBool,
waker: AtomicWaker,
}
impl Default for BlockNotifyState {
fn default() -> Self {
Self {
pending: AtomicBool::new(false),
waker: AtomicWaker::new(),
}
}
}
/// Cloneable, coalescing wake-up signal shared between a [`BlockInbox`] and its
/// reader.
///
/// Any number of [`notify`](Self::notify) calls made before the signal is
/// consumed collapse into a single pending flag. The flag is consumed either
/// by [`take_pending`](Self::take_pending) or by awaiting
/// [`notified`](Self::notified).
#[derive(Clone, Debug, Default)]
pub struct BlockNotifier {
    state: Arc<BlockNotifyState>,
}

impl BlockNotifier {
    /// Creates a notifier with nothing pending and no registered waker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Marks a notification as pending.
    ///
    /// The registered waker is only woken on the not-pending -> pending
    /// transition; extra calls while already pending are absorbed.
    pub fn notify(&self) {
        let was_already_pending = self.state.pending.swap(true, Ordering::AcqRel);
        if was_already_pending {
            // Someone will still observe the existing flag; no extra wake needed.
            return;
        }
        self.state.waker.wake();
    }

    /// Clears the pending flag and reports whether it had been set.
    pub fn take_pending(&self) -> bool {
        self.state.pending.swap(false, Ordering::AcqRel)
    }

    /// Returns a future that completes once a notification is pending,
    /// consuming it.
    pub fn notified(&self) -> Notified {
        let state = Arc::clone(&self.state);
        Notified { state }
    }
}
/// Future returned by [`BlockNotifier::notified`].
///
/// It resolves as soon as a notification is pending and clears the pending
/// flag when it does. If nothing is pending on a poll, it registers the
/// current task's waker and returns `Poll::Pending`. A later
/// [`BlockNotifier::notify`] wakes that waker.
///
/// All clones of a notifier share a single `AtomicWaker` slot. Concurrent
/// awaiters in different tasks therefore compete for it, so this is intended
/// to be awaited by one consumer at a time.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Notified {
    state: Arc<BlockNotifyState>,
}

impl Future for Notified {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let pending = &self.state.pending;

        // Fast path: a notification is already waiting, so consume it.
        if pending.swap(false, Ordering::AcqRel) {
            return Poll::Ready(());
        }

        self.state.waker.register(cx.waker());

        // Re-check after registering. A `notify` that ran between the check
        // above and `register` would have woken an old (or no) waker, so its
        // flag must be observed here or the wake-up would be lost.
        if pending.swap(false, Ordering::AcqRel) {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}
/// Sending half of a block's inbox.
///
/// Wraps the control-message sender together with the [`BlockNotifier`] shared
/// with the reader. A successful [`send`](Self::send) also raises that
/// notifier.
#[derive(Clone, Debug)]
pub struct BlockInbox {
    control: mpsc::Sender<BlockMessage>,
    notifier: BlockNotifier,
}

impl BlockInbox {
    /// Pairs an existing control sender with a notifier.
    pub(crate) fn new(control: mpsc::Sender<BlockMessage>, notifier: BlockNotifier) -> Self {
        Self { notifier, control }
    }

    /// Builds an inbox whose receiving half is dropped immediately, with its
    /// own private notifier.
    pub fn disconnected() -> Self {
        let (sender, _) = mpsc::channel::<BlockMessage>(0);
        let notifier = BlockNotifier::new();
        Self::new(sender, notifier)
    }

    /// Returns another handle to the notifier shared with the reader.
    pub fn notifier(&self) -> BlockNotifier {
        BlockNotifier::clone(&self.notifier)
    }

    /// Raises the notifier without enqueueing any control message.
    pub fn notify(&self) {
        self.notifier.notify()
    }

    /// Forwards to the control sender's `is_closed`.
    pub fn is_closed(&self) -> bool {
        self.control.is_closed()
    }

    /// Sends `BlockMessage::StreamInputDone` for `input_id`.
    pub async fn stream_input_done(&self, input_id: PortId) -> Result<(), crate::runtime::Error> {
        let msg = BlockMessage::StreamInputDone { input_id };
        self.send(msg).await
    }

    /// Sends `BlockMessage::StreamOutputDone` for `output_id`.
    pub async fn stream_output_done(&self, output_id: PortId) -> Result<(), crate::runtime::Error> {
        let msg = BlockMessage::StreamOutputDone { output_id };
        self.send(msg).await
    }

    /// Enqueues `msg` and then raises the notifier.
    ///
    /// If the underlying send fails, its error is converted and returned, and
    /// the notifier is left untouched.
    pub(crate) async fn send(&self, msg: BlockMessage) -> Result<(), crate::runtime::Error> {
        // Enqueue before notifying so the reader finds the message once it
        // observes the pending flag.
        self.control.send(msg).await?;
        self.notifier.notify();
        Ok(())
    }
}

impl Default for BlockInbox {
    /// Same as [`BlockInbox::disconnected`].
    fn default() -> Self {
        Self::disconnected()
    }
}
/// Receiving half of a block's inbox: the control-message receiver plus the
/// notifier shared with the matching [`BlockInbox`].
#[derive(Debug)]
pub(crate) struct BlockInboxReader {
    control: mpsc::Receiver<BlockMessage>,
    notifier: BlockNotifier,
}

impl BlockInboxReader {
    /// Pairs an existing control receiver with a notifier.
    pub fn new(control: mpsc::Receiver<BlockMessage>, notifier: BlockNotifier) -> Self {
        Self { notifier, control }
    }

    /// Non-blocking receive; every `try_recv` error is collapsed into `None`.
    pub fn try_recv(&mut self) -> Option<BlockMessage> {
        let attempt = self.control.try_recv();
        attempt.ok()
    }

    /// Awaits the next control message via the receiver's `recv`.
    pub async fn recv(&mut self) -> Option<BlockMessage> {
        self.control.recv().await
    }

    /// Clears the shared pending flag, returning whether it was set.
    pub fn take_pending(&self) -> bool {
        self.notifier.take_pending()
    }

    /// Future that completes once the shared notifier is raised.
    // NOTE(review): allowed as dead code, presumably because no in-crate caller
    // awaits it yet — confirm before removing the attribute.
    #[allow(dead_code)]
    pub fn notified(&self) -> Notified {
        self.notifier.notified()
    }
}
/// Creates a connected `(BlockInbox, BlockInboxReader)` pair.
///
/// Both halves share one fresh [`BlockNotifier`]. `size` is passed straight
/// through as the capacity argument of `mpsc::channel`.
pub(crate) fn channel(size: usize) -> (BlockInbox, BlockInboxReader) {
    let notifier = BlockNotifier::new();
    let (sender, receiver) = mpsc::channel::<BlockMessage>(size);
    let inbox = BlockInbox::new(sender, notifier.clone());
    let reader = BlockInboxReader::new(receiver, notifier);
    (inbox, reader)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::runtime::BlockMessage;
    use futures::executor::block_on;
    use std::sync::atomic::AtomicUsize;
    use std::task::Wake;
    use std::task::Waker;

    /// Waker that only counts how often it has been woken.
    #[derive(Default)]
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn coalesces_multiple_notifies() {
        let n = BlockNotifier::new();
        n.notify();
        n.notify();
        n.notify();
        assert!(n.take_pending());
        assert!(!n.take_pending());
    }

    #[test]
    fn notified_completes_after_notify() {
        let n = BlockNotifier::new();
        n.notify();
        block_on(n.notified());
        assert!(!n.take_pending());
    }

    /// Exercises the slow path of `Notified::poll`: park, get woken, resolve.
    #[test]
    fn pending_notified_is_woken_once_by_coalesced_notifies() {
        let n = BlockNotifier::new();
        let counter = Arc::new(CountingWaker::default());
        let waker = Waker::from(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);
        let mut fut = n.notified();

        // Nothing pending yet: the future parks and registers our waker.
        assert!(Pin::new(&mut fut).poll(&mut cx).is_pending());
        assert_eq!(counter.0.load(Ordering::SeqCst), 0);

        // Only the first notify transitions the flag, so only it wakes.
        n.notify();
        n.notify();
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);

        // The re-poll consumes the single coalesced notification.
        assert!(Pin::new(&mut fut).poll(&mut cx).is_ready());
        assert!(!n.take_pending());
    }

    #[test]
    fn send_enqueues_and_wakes_reader() {
        let (tx, mut rx) = channel(1);
        block_on(tx.send(BlockMessage::Initialize)).unwrap();
        assert!(rx.take_pending());
        assert!(matches!(rx.try_recv(), Some(BlockMessage::Initialize)));
    }

    #[test]
    fn recv_waits_for_message() {
        let (tx, mut rx) = channel(1);
        block_on(tx.send(BlockMessage::Initialize)).unwrap();
        assert!(matches!(
            block_on(rx.recv()),
            Some(BlockMessage::Initialize)
        ));
    }

    #[test]
    fn notify_wakes_without_message() {
        let (tx, mut rx) = channel(1);
        tx.notify();
        assert!(rx.take_pending());
        assert!(rx.try_recv().is_none());
    }

    #[test]
    fn multiple_sends_coalesce_but_keep_messages() {
        let (tx, mut rx) = channel(4);
        block_on(tx.send(BlockMessage::Initialize)).unwrap();
        block_on(tx.send(BlockMessage::Terminate)).unwrap();
        assert!(rx.take_pending());
        assert!(!rx.take_pending());
        assert!(matches!(rx.try_recv(), Some(BlockMessage::Initialize)));
        assert!(matches!(rx.try_recv(), Some(BlockMessage::Terminate)));
        assert!(rx.try_recv().is_none());
    }
}