pub mod errors;
use cord_message::{errors::Error as MessageError, Codec, Message, Pattern};
use errors::{Error, ErrorKind, Result};
use futures::{
future::{self, try_select},
stream::SplitSink,
Sink, SinkExt, Stream, StreamExt, TryStreamExt,
};
use futures_locks_pre::Mutex;
use retain_mut::RetainMut;
use tokio::{net::TcpStream, sync::mpsc, sync::oneshot};
use tokio_util::codec::Framed;
use std::{
collections::HashMap,
convert::Into,
net::SocketAddr,
ops::Drop,
pin::Pin,
result,
sync::Arc,
task::{Context, Poll},
};
/// The concrete client returned by `ClientConn::connect`: outgoing messages are
/// written to the write half of a `Codec`-framed TCP connection.
pub type Client = ClientConn<SplitSink<Framed<TcpStream, Codec>, Message>>;
/// A client connection, generic over the sink used for outgoing messages.
///
/// `provide`, `revoke`, `subscribe`, `unsubscribe` and `event` send the
/// corresponding `Message` through `sink`; the type also implements `Sink`
/// itself by delegating to `sink`. In tests `S` is an in-memory channel.
pub struct ClientConn<S> {
// Outgoing message sink.
sink: S,
// State shared with every `Subscriber` created by this connection.
inner: Arc<Inner>,
}
/// A stream of `(pattern, data)` events for one subscription, created by
/// `ClientConn::subscribe`.
///
/// Each `Subscriber` holds a reference to the connection's shared state, so the
/// background router task keeps running while any `Subscriber` is alive.
pub struct Subscriber {
// Receives the messages that `route` forwards for this subscription.
receiver: mpsc::Receiver<Message>,
// Held only for its reference count: the router task is told to stop when the
// last `Arc<Inner>` is dropped (see `impl Drop for Inner`).
_inner: Arc<Inner>,
}
/// State shared between a `ClientConn` and the `Subscriber`s it creates.
struct Inner {
// Local subscriptions: each pattern maps to the senders of the `Subscriber`s
// registered under it; `route` delivers incoming messages through these.
receivers: Mutex<HashMap<Pattern, Vec<mpsc::Sender<Message>>>>,
// Signal used on drop to stop the router task spawned in `connect`.
// Constructed as `Some` and only taken in `Drop`.
detonator: Option<oneshot::Sender<()>>,
}
impl<S> ClientConn<S>
where
    S: Sink<Message, Error = MessageError> + Unpin,
{
    /// Opens a TCP connection to `addr` and starts routing incoming messages.
    ///
    /// A background task reads every message from the connection and passes it
    /// to `route`, which forwards it to the matching [`Subscriber`]s. The task
    /// ends when the incoming stream ends or yields an error, or when the last
    /// handle sharing this connection's state (the client or any `Subscriber`)
    /// is dropped. When the incoming stream ends or fails, all registered
    /// subscriptions are dropped so existing `Subscriber` streams terminate.
    ///
    /// # Errors
    ///
    /// Returns an error if the TCP connection cannot be established.
    pub async fn connect(addr: SocketAddr) -> Result<Client> {
        let (det_tx, det_rx) = oneshot::channel();
        let sock = TcpStream::connect(&addr).await?;
        let framed = Framed::new(sock, Codec::default());
        let (sink, stream) = framed.split();
        let receivers = Mutex::new(HashMap::new());
        let fold_receivers = receivers.clone();
        let cleanup_receivers = receivers.clone();
        let router = Box::pin(async move {
            let outcome = stream
                .map_err(|e| Error::from_kind(ErrorKind::Message(e)))
                .try_fold(fold_receivers, |recv, message| async move {
                    route(&recv, message).await;
                    Ok(recv)
                })
                .await;
            // The incoming stream has ended or failed, so nothing will ever be
            // routed again. Dropping every registered sender lets live
            // `Subscriber` streams finish with `None` instead of hanging forever.
            cleanup_receivers
                .with(|mut guard| {
                    (*guard).clear();
                    future::ready(())
                })
                .await;
            outcome
        });
        // NOTE(review): the router's result, including any stream error, is
        // discarded; consider surfacing it if callers need disconnect reasons.
        tokio::spawn(try_select(router, det_rx));
        Ok(ClientConn {
            sink,
            inner: Arc::new(Inner {
                receivers,
                detonator: Some(det_tx),
            }),
        })
    }

    /// Sends `Message::Provide(namespace)` through the connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be sent.
    pub async fn provide(&mut self, namespace: Pattern) -> Result<()> {
        self.sink
            .send(Message::Provide(namespace))
            .await
            .map_err(|e| ErrorKind::Message(e).into())
    }

    /// Sends `Message::Revoke(namespace)` through the connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be sent.
    pub async fn revoke(&mut self, namespace: Pattern) -> Result<()> {
        self.sink
            .send(Message::Revoke(namespace))
            .await
            .map_err(|e| ErrorKind::Message(e).into())
    }

    /// Registers a local subscription for `namespace`, sends
    /// `Message::Subscribe(namespace)`, and returns a [`Subscriber`] that
    /// yields the matching events.
    ///
    /// # Errors
    ///
    /// Returns an error if the `Subscribe` message cannot be sent.
    pub async fn subscribe(&mut self, namespace: Pattern) -> Result<Subscriber> {
        let (tx, rx) = mpsc::channel(10);
        // Register the sender *before* announcing the subscription; otherwise
        // an event forwarded right after `Subscribe` could be routed before the
        // sender is in the map and would be silently lost.
        let namespace_c = namespace.clone();
        self.inner
            .receivers
            .with(move |mut guard| {
                (*guard)
                    .entry(namespace_c)
                    .or_insert_with(Vec::new)
                    .push(tx);
                future::ready(())
            })
            .await;
        // If this send fails, `rx` is dropped on return and the orphaned sender
        // is pruned when `route` next attempts delivery to it.
        self.sink.send(Message::Subscribe(namespace)).await?;
        Ok(Subscriber {
            receiver: rx,
            _inner: self.inner.clone(),
        })
    }

    /// Sends `Message::Unsubscribe(namespace)` and removes every local sender
    /// registered under exactly `namespace`.
    ///
    /// Every `Subscriber` created for that pattern sees its stream end once its
    /// buffered messages are drained.
    ///
    /// # Errors
    ///
    /// Returns an error if the `Unsubscribe` message cannot be sent; in that
    /// case the local registrations are left untouched.
    pub async fn unsubscribe(&mut self, namespace: Pattern) -> Result<()> {
        let namespace_c = namespace.clone();
        self.sink.send(Message::Unsubscribe(namespace)).await?;
        self.inner
            .receivers
            .with(move |mut guard| {
                (*guard).remove(&namespace_c);
                future::ready(())
            })
            .await;
        Ok(())
    }

    /// Sends `Message::Event(namespace, data)` through the connection.
    ///
    /// # Errors
    ///
    /// Returns an error if the message cannot be sent.
    pub async fn event<Str: Into<String>>(&mut self, namespace: Pattern, data: Str) -> Result<()> {
        self.sink
            .send(Message::Event(namespace, data.into()))
            .await
            .map_err(|e| ErrorKind::Message(e).into())
    }
}
impl<S> ClientConn<S> {
    /// Projects a pinned `ClientConn` onto its pinned `sink` field.
    fn pinned_sink(self: Pin<&mut Self>) -> Pin<&mut S> {
        // SAFETY: `sink` is treated as structurally pinned. Nothing in this file
        // moves `sink` out of a pinned `ClientConn`, `ClientConn` has no `Drop`
        // impl, and there is no manual `Unpin` impl, so `ClientConn<S>` is only
        // `Unpin` when `S` is.
        unsafe { self.map_unchecked_mut(|conn| &mut conn.sink) }
    }
}

/// Lets a `ClientConn` be used directly as a `Sink` (for example as the target
/// of `StreamExt::forward`). Every call is delegated to the wrapped sink and
/// its error is converted into this crate's `Error`.
impl<E, S, T> Sink<T> for ClientConn<S>
where
    S: Sink<T, Error = E>,
    E: Into<Error>,
{
    type Error = Error;

    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<result::Result<(), Self::Error>> {
        let target = self.pinned_sink();
        target.poll_ready(cx).map_err(Into::into)
    }

    fn start_send(self: Pin<&mut Self>, item: T) -> result::Result<(), Self::Error> {
        let target = self.pinned_sink();
        target.start_send(item).map_err(Into::into)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<result::Result<(), Self::Error>> {
        let target = self.pinned_sink();
        target.poll_flush(cx).map_err(Into::into)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<result::Result<(), Self::Error>> {
        let target = self.pinned_sink();
        target.poll_close(cx).map_err(Into::into)
    }
}
/// Yields the `(pattern, data)` payload of each event routed to this
/// subscriber; the stream ends once every sender for it has been dropped.
impl Stream for Subscriber {
    type Item = (Pattern, String);

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Both fields are `Unpin`, so `Subscriber` is too and no unsafe
        // projection is needed.
        let receiver = &mut self.get_mut().receiver;
        loop {
            match Pin::new(&mut *receiver).poll_next(cx) {
                Poll::Ready(Some(Message::Event(pattern, data))) => {
                    return Poll::Ready(Some((pattern, data)));
                }
                // `route` forwards any message whose namespace matches,
                // whatever its variant, so non-events can arrive here from the
                // remote peer. Skip them rather than panicking.
                Poll::Ready(Some(_)) => continue,
                Poll::Ready(None) => return Poll::Ready(None),
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}
/// Tells the router task spawned in `connect` to stop once the last handle
/// sharing this state (the client or a `Subscriber`) goes away.
impl Drop for Inner {
    fn drop(&mut self) {
        let detonator = self
            .detonator
            .take()
            .expect("Inner has already been terminated");
        // The router may already have finished and dropped its receiving end;
        // a failed send is harmless then, so the result is deliberately ignored.
        let _ = detonator.send(());
    }
}
async fn route(receivers: &Mutex<HashMap<Pattern, Vec<mpsc::Sender<Message>>>>, message: Message) {
receivers
.with(move |mut guard| {
(*guard).retain(|namespace, senders| {
if namespace.contains(message.namespace()) {
senders.retain_mut(|tx| tx.try_send(message.clone()).is_ok());
}
!senders.is_empty()
});
future::ready(())
})
.await
}
#[cfg(test)]
mod tests {
use super::*;
use cord_message::errors::ErrorKind as MessageErrorKind;
use futures::channel::mpsc::{unbounded, UnboundedReceiver};
/// A stream that yields the wrapped messages in reverse order (it pops from
/// the end of the `Vec`) and then ends.
struct ForwardStream(Vec<Message>);
impl Stream for ForwardStream {
type Item = Result<Message>;
fn poll_next(mut self: Pin<&mut Self>, _: &mut Context) -> Poll<Option<Self::Item>> {
Poll::Ready(self.0.pop().map(Ok))
}
}
/// Builds a `ClientConn` with no network connection and no router task.
///
/// Returns the client, the receiving end of the in-memory channel its sink
/// writes to, and a clone of its subscription map for inspecting state.
// The detonator's receiver is discarded immediately, so the send in
// `Inner::drop` simply fails and is ignored.
#[allow(clippy::type_complexity)]
fn setup_client() -> (
ClientConn<impl Sink<Message, Error = MessageError>>,
UnboundedReceiver<Message>,
Mutex<HashMap<Pattern, Vec<mpsc::Sender<Message>>>>,
) {
let (tx, rx) = unbounded();
let (det_tx, _) = oneshot::channel();
let receivers = Mutex::new(HashMap::new());
(
ClientConn {
sink: tx.sink_map_err(|e| MessageErrorKind::Msg(format!("{}", e)).into()),
inner: Arc::new(Inner {
receivers: receivers.clone(),
detonator: Some(det_tx),
}),
},
rx,
receivers,
)
}
/// Forwarding a stream into the client sends every item through its sink in
/// stream order (which is reverse `Vec` order for `ForwardStream`).
#[tokio::test]
async fn test_forward() {
let (client, rx, _) = setup_client();
let data_stream = ForwardStream(vec![
Message::Event("/a".into(), "b".into()),
Message::Provide("/a".into()),
]);
data_stream.forward(client).await.unwrap();
let (item, rx) = rx.into_future().await;
assert_eq!(item, Some(Message::Provide("/a".into())));
let (item, _) = rx.into_future().await;
assert_eq!(item, Some(Message::Event("/a".into(), "b".into())));
}
/// `provide` sends a `Provide` message.
#[tokio::test]
async fn test_provide() {
let (mut client, rx, _) = setup_client();
client.provide("/a/b".into()).await.unwrap();
assert_eq!(
rx.into_future().await.0.unwrap(),
Message::Provide("/a/b".into())
);
}
/// `revoke` sends a `Revoke` message.
#[tokio::test]
async fn test_revoke() {
let (mut client, rx, _) = setup_client();
client.revoke("/a/b".into()).await.unwrap();
assert_eq!(
rx.into_future().await.0.unwrap(),
Message::Revoke("/a/b".into())
);
}
/// `subscribe` sends a `Subscribe` message and registers the pattern locally.
#[tokio::test]
async fn test_subscribe() {
let (mut client, rx, receivers) = setup_client();
client.subscribe("/a/b".into()).await.unwrap();
assert_eq!(
rx.into_future().await.0.unwrap(),
Message::Subscribe("/a/b".into())
);
let guard = receivers.lock().await;
assert!((*guard).contains_key(&"/a/b".into()));
}
/// `unsubscribe` sends an `Unsubscribe` message and removes the pattern's
/// local entry.
#[tokio::test]
async fn test_unsubscribe() {
let (mut client, rx, receivers) = setup_client();
receivers
.with(|mut guard| {
(*guard).insert("/a/b".into(), Vec::new());
future::ready(())
})
.await;
client.unsubscribe("/a/b".into()).await.unwrap();
assert_eq!(
rx.into_future().await.0.unwrap(),
Message::Unsubscribe("/a/b".into())
);
let guard = receivers.lock().await;
assert!((*guard).is_empty());
}
/// `event` sends an `Event` message, converting the payload into a `String`.
#[tokio::test]
async fn test_event() {
let (mut client, rx, _) = setup_client();
client.event("/a/b".into(), "moo").await.unwrap();
assert_eq!(
rx.into_future().await.0.unwrap(),
Message::Event("/a/b".into(), "moo".into())
);
}
/// A matching event reaches the subscriber's channel, and the live sender
/// stays registered.
#[tokio::test]
async fn test_route() {
let (tx, rx) = mpsc::channel(10);
let receivers = Mutex::new(HashMap::new());
receivers
.with(|mut guard| {
(*guard).insert("/a/b".into(), vec![tx]);
future::ready(())
})
.await;
let event_msg = Message::Event("/a/b".into(), "Moo!".into());
let event_msg_c = event_msg.clone();
route(&receivers, event_msg).await;
assert_eq!(rx.into_future().await.0.unwrap(), event_msg_c);
let guard = receivers.lock().await;
assert!((*guard).contains_key(&"/a/b".into()));
}
/// A sender whose receiver was dropped is pruned when routing, and the
/// emptied pattern is removed from the map.
#[tokio::test]
async fn test_route_norecv() {
let (tx, _) = mpsc::channel(10);
let receivers = Mutex::new(HashMap::new());
receivers
.with(|mut guard| {
(*guard).insert("/a/b".into(), vec![tx]);
future::ready(())
})
.await;
route(&receivers, Message::Event("/a/b".into(), "Moo!".into())).await;
let guard = receivers.lock().await;
assert!((*guard).is_empty());
}
}