use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use futures::channel::mpsc::{unbounded, UnboundedReceiver, UnboundedSender};
use futures::channel::oneshot;
use futures::future::{BoxFuture, Shared};
use futures::lock::Mutex as AsyncMutex;
use futures::{FutureExt, StreamExt};
use parking_lot::Mutex;
use crate::error::ChannelError;
use crate::message::Message;
/// A bidirectional message transport with an explicit start/close lifecycle.
///
/// The `Send + Sync` supertraits allow a single endpoint to be shared between
/// tasks and threads (for example behind an `Arc`).
pub trait CommandChannel: Send + Sync {
    /// Identifier of this endpoint.
    fn id(&self) -> &str;

    /// Starts the channel, resolving to `Err` if that fails.
    ///
    /// The in-memory implementation has nothing to set up and always succeeds.
    fn start(&self) -> BoxFuture<'_, Result<(), ChannelError>>;

    /// Shuts this endpoint down.
    fn close(&self) -> BoxFuture<'_, ()>;

    /// Hands `msg` to the transport; this call is synchronous and does not
    /// await delivery.
    fn send(&self, msg: Message) -> Result<(), ChannelError>;

    /// Waits for the next incoming message.
    ///
    /// Resolves to `None` when no further message can be received
    /// (in the in-memory implementation, after a close).
    fn recv(&self) -> BoxFuture<'_, Option<Message>>;
}
/// One endpoint of an in-process [`CommandChannel`]; connected endpoints are
/// created together with [`InMemoryChannel::pair`].
pub struct InMemoryChannel {
/// Identifier returned by `CommandChannel::id`.
id: String,
/// Sending half of the queue whose receiving half belongs to the peer endpoint.
outbound: UnboundedSender<Message>,
/// Receiving half of the queue the peer sends into. The async mutex serializes
/// concurrent `recv` calls; when the `Option` is `None`, `recv` returns `None`.
inbound: AsyncMutex<Option<UnboundedReceiver<Message>>>,
/// Fired at most once (it is `take`n) by `close` to wake a `recv` waiting on this endpoint.
close_tx: Mutex<Option<oneshot::Sender<()>>>,
/// Receiving side of `close_tx`, wrapped in `Shared` so each `recv` can await its own clone.
close_rx: Shared<oneshot::Receiver<()>>,
/// Set by `close`; `send` and `recv` check it before doing any work.
closed: AtomicBool,
}
impl InMemoryChannel {
pub fn pair(id_a: impl Into<String>, id_b: impl Into<String>) -> (Arc<Self>, Arc<Self>) {
let (tx_a_to_b, rx_b) = unbounded();
let (tx_b_to_a, rx_a) = unbounded();
let (close_a_tx, close_a_rx) = oneshot::channel::<()>();
let (close_b_tx, close_b_rx) = oneshot::channel::<()>();
let a = Arc::new(Self {
id: id_a.into(),
outbound: tx_a_to_b,
inbound: AsyncMutex::new(Some(rx_a)),
close_tx: Mutex::new(Some(close_a_tx)),
close_rx: close_a_rx.shared(),
closed: AtomicBool::new(false),
});
let b = Arc::new(Self {
id: id_b.into(),
outbound: tx_b_to_a,
inbound: AsyncMutex::new(Some(rx_b)),
close_tx: Mutex::new(Some(close_b_tx)),
close_rx: close_b_rx.shared(),
closed: AtomicBool::new(false),
});
(a, b)
}
}
impl CommandChannel for InMemoryChannel {
    fn id(&self) -> &str {
        &self.id
    }

    /// Nothing to set up for an in-memory endpoint; always succeeds.
    fn start(&self) -> BoxFuture<'_, Result<(), ChannelError>> {
        Box::pin(async { Ok(()) })
    }

    /// Closes this endpoint. Calling it again is harmless.
    ///
    /// After this resolves:
    /// - `send` on this endpoint returns [`ChannelError::Closed`];
    /// - `recv` on this endpoint returns `None`, and a `recv` already waiting
    ///   here is woken;
    /// - the peer can still drain messages sent before the close, after which
    ///   its `recv` returns `None`;
    /// - this endpoint's receiver is dropped (best effort, see
    ///   `release_inbound_if_closed`), so the peer's `send` fails instead of
    ///   queueing messages that will never be read.
    fn close(&self) -> BoxFuture<'_, ()> {
        Box::pin(async move {
            self.closed.store(true, Ordering::SeqCst);
            // Ends the peer's inbound stream once it has drained what is queued.
            self.outbound.close_channel();
            // Wake any `recv` currently parked on this endpoint.
            if let Some(tx) = self.close_tx.lock().take() {
                let _ = tx.send(());
            }
            // Without this the receiver stays alive for the endpoint's whole
            // lifetime, so peer sends keep succeeding and accumulate unread.
            self.release_inbound_if_closed();
        })
    }

    /// Queues `msg` for the peer without waiting.
    ///
    /// # Errors
    /// - [`ChannelError::Closed`] if this endpoint has been closed.
    /// - [`ChannelError::Send`] if the queue to the peer no longer accepts
    ///   messages (the peer dropped or closed its receiver, or this endpoint
    ///   was closed concurrently).
    fn send(&self, msg: Message) -> Result<(), ChannelError> {
        if self.closed.load(Ordering::SeqCst) {
            return Err(ChannelError::Closed);
        }
        self.outbound
            .unbounded_send(msg)
            .map_err(|e| ChannelError::Send(e.to_string()))
    }

    /// Waits for the next message from the peer.
    ///
    /// Returns `None` once this endpoint is closed, or once the peer has closed
    /// (or dropped) its side and every queued message has been delivered.
    /// Concurrent calls are serialized by the inbound lock.
    fn recv(&self) -> BoxFuture<'_, Option<Message>> {
        Box::pin(async move {
            if self.closed.load(Ordering::SeqCst) {
                self.release_inbound_if_closed();
                return None;
            }
            let mut guard = self.inbound.lock().await;
            // `None` here means `close` has already dropped the receiver.
            let rx = guard.as_mut()?;
            let close_fut = self.close_rx.clone();
            // Biased toward the queue: an already-queued message wins over a
            // simultaneous close signal.
            let msg = futures::select_biased! {
                msg = rx.next().fuse() => msg,
                _ = close_fut.fuse() => None,
            };
            // Release the lock before cleaning up. If `close` ran while we held
            // it, its `try_lock` failed, so dropping the receiver is now our job.
            drop(guard);
            self.release_inbound_if_closed();
            msg
        })
    }
}

impl InMemoryChannel {
    /// Drops this endpoint's inbound receiver if the endpoint has been closed.
    /// The peer's `send` then returns an error instead of queueing messages
    /// nobody will read.
    ///
    /// Best effort and non-blocking. If a `recv` currently holds the inbound
    /// lock, nothing happens here. A `recv` that runs to completion calls this
    /// again after releasing the lock. Otherwise the next `close`/`recv` call,
    /// or dropping the endpoint, releases the receiver. The lock is never
    /// awaited, so `close` cannot deadlock against a `recv` future that is
    /// alive but not currently being polled.
    fn release_inbound_if_closed(&self) {
        if !self.closed.load(Ordering::SeqCst) {
            return;
        }
        if let Some(mut inbound) = self.inbound.try_lock() {
            *inbound = None;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::message::MessageId;
    use futures::executor::block_on;
    use futures::future::join;

    /// A small request message carrying the given id.
    fn ping(id: MessageId) -> Message {
        Message::ListCommandsRequest { id, meta: None }
    }

    /// A ping with a freshly generated id.
    fn fresh_ping() -> Message {
        ping(MessageId::new_v4())
    }

    #[test]
    fn pair_sends_in_both_directions() {
        let (alice, bob) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            assert_eq!(alice.id(), "alice");
            assert_eq!(bob.id(), "bob");
            let for_bob = fresh_ping();
            let for_alice = fresh_ping();
            alice.send(for_bob.clone()).unwrap();
            bob.send(for_alice.clone()).unwrap();
            assert_eq!(bob.recv().await, Some(for_bob));
            assert_eq!(alice.recv().await, Some(for_alice));
        });
    }

    #[test]
    fn recv_awaits_future_send() {
        let (alice, bob) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            let outgoing = fresh_ping();
            let expected = outgoing.clone();
            let sending = async {
                alice.send(outgoing).unwrap();
            };
            let ((), received) = join(sending, bob.recv()).await;
            assert_eq!(received, Some(expected));
        });
    }

    #[test]
    fn close_stops_recv_on_both_sides() {
        let (alice, bob) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            alice.close().await;
            assert_eq!(alice.recv().await, None);
            assert_eq!(bob.recv().await, None);
        });
    }

    #[test]
    fn send_after_close_is_error() {
        let (alice, _bob) = InMemoryChannel::pair("alice", "bob");
        block_on(alice.close());
        let outcome = alice.send(fresh_ping());
        assert!(matches!(outcome, Err(ChannelError::Closed)));
    }

    #[test]
    fn queued_messages_drain_after_peer_close() {
        let (alice, bob) = InMemoryChannel::pair("alice", "bob");
        block_on(async {
            let queued = fresh_ping();
            alice.send(queued.clone()).unwrap();
            alice.close().await;
            assert_eq!(bob.recv().await, Some(queued));
            assert_eq!(bob.recv().await, None);
        });
    }
}