use futures::{task::*, Sink, Stream};
use pin_project::pin_project;
use std::{error::Error, pin::Pin};
use tokio::sync::mpsc;
#[derive(thiserror::Error, Debug)]
pub enum ChannelError {
#[error("an error occurred sending over the channel")]
Send(#[source] Box<dyn Error + Send + Sync + 'static>),
}
/// Creates two connected [`UnboundedChannel`] transports.
///
/// Whatever one transport sends through its [`Sink`] is yielded by the other
/// transport's [`Stream`], in both directions. Both directions are backed by
/// tokio unbounded mpsc channels.
pub fn unbounded<SinkItem, Item>() -> (
    UnboundedChannel<SinkItem, Item>,
    UnboundedChannel<Item, SinkItem>,
) {
    // Carries `Item`s from the first transport to the second.
    let (item_tx, item_rx) = mpsc::unbounded_channel();
    // Carries `SinkItem`s from the second transport to the first.
    let (sink_item_tx, sink_item_rx) = mpsc::unbounded_channel();

    let first = UnboundedChannel {
        rx: sink_item_rx,
        tx: item_tx,
    };
    let second = UnboundedChannel {
        rx: item_rx,
        tx: sink_item_tx,
    };
    (first, second)
}
/// A bidirectional in-memory transport backed by tokio unbounded mpsc channels.
///
/// Receives `Item`s through its [`Stream`] impl and sends `SinkItem`s through its
/// [`Sink`] impl. Connected pairs are created with [`unbounded`].
#[derive(Debug)]
pub struct UnboundedChannel<Item, SinkItem> {
    /// Receiving half: yields the items the peer transport sends.
    rx: mpsc::UnboundedReceiver<Item>,
    /// Sending half: delivers items to the peer transport's receiver.
    tx: mpsc::UnboundedSender<SinkItem>,
}
impl<Item, SinkItem> Stream for UnboundedChannel<Item, SinkItem> {
    type Item = Result<Item, ChannelError>;

    /// Polls the underlying receiver and wraps every received item in `Ok`.
    ///
    /// When the receiver reports `None`, the stream ends. This stream never
    /// produces an `Err` itself.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Item, ChannelError>>> {
        let this = self.get_mut();
        match this.rx.poll_recv(cx) {
            Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// Message wrapped in [`ChannelError::Send`] when an [`UnboundedChannel`] cannot
/// accept an item because its channel is closed.
const CLOSED_MESSAGE: &str = "the channel is closed and cannot accept new items for sending";

impl<Item, SinkItem> Sink<SinkItem> for UnboundedChannel<Item, SinkItem> {
    type Error = ChannelError;

    /// Always ready immediately, because the channel is unbounded.
    ///
    /// Reports a [`ChannelError::Send`] if the sender observes that the channel
    /// is closed.
    fn poll_ready(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        let readiness = if !self.tx.is_closed() {
            Ok(())
        } else {
            Err(ChannelError::Send(CLOSED_MESSAGE.into()))
        };
        Poll::Ready(readiness)
    }

    /// Hands `item` to the underlying sender.
    ///
    /// If the send is rejected, the rejected item is dropped and a
    /// [`ChannelError::Send`] is returned.
    fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
        match self.tx.send(item) {
            Ok(()) => Ok(()),
            Err(_rejected) => Err(ChannelError::Send(CLOSED_MESSAGE.into())),
        }
    }

    /// `start_send` delivers directly to the channel and nothing is buffered in
    /// this struct, so flushing completes immediately.
    fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// No-op: the sender is not closed here.
    ///
    /// NOTE(review): presumably the peer only observes end-of-stream once this
    /// transport (and thus its sender) is dropped. Confirm.
    fn poll_close(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }
}
/// Creates two connected [`Channel`] transports backed by bounded `futures` mpsc
/// channels.
///
/// Whatever one transport sends through its [`Sink`] is yielded by the other
/// transport's [`Stream`]. `capacity` is passed as the buffer size to both
/// underlying `futures::channel::mpsc::channel` calls.
pub fn bounded<SinkItem, Item>(
    capacity: usize,
) -> (Channel<SinkItem, Item>, Channel<Item, SinkItem>) {
    // Carries `Item`s from the first transport to the second.
    let (item_tx, item_rx) = futures::channel::mpsc::channel(capacity);
    // Carries `SinkItem`s from the second transport to the first.
    let (sink_item_tx, sink_item_rx) = futures::channel::mpsc::channel(capacity);

    let first = Channel {
        rx: sink_item_rx,
        tx: item_tx,
    };
    let second = Channel {
        rx: item_rx,
        tx: sink_item_tx,
    };
    (first, second)
}
/// A bidirectional in-memory transport backed by bounded `futures` mpsc channels.
///
/// Receives `Item`s through its [`Stream`] impl and sends `SinkItem`s through its
/// [`Sink`] impl. Connected pairs are created with [`bounded`]. Both fields are
/// structurally pinned via `pin_project` so the trait impls can poll them through
/// `Pin`.
#[pin_project]
#[derive(Debug)]
pub struct Channel<Item, SinkItem> {
    /// Receiving half: yields the items the peer transport sends.
    #[pin]
    rx: futures::channel::mpsc::Receiver<Item>,
    /// Sending half: delivers items to the peer transport's receiver.
    #[pin]
    tx: futures::channel::mpsc::Sender<SinkItem>,
}
impl<Item, SinkItem> Stream for Channel<Item, SinkItem> {
    type Item = Result<Item, ChannelError>;

    /// Polls the pinned receiver and wraps every received item in `Ok`.
    ///
    /// The stream ends when the receiver ends. This stream never produces an
    /// `Err` itself.
    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Item, ChannelError>>> {
        let receiver = self.project().rx;
        match receiver.poll_next(cx) {
            Poll::Ready(Some(item)) => Poll::Ready(Some(Ok(item))),
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl<Item, SinkItem> Sink<SinkItem> for Channel<Item, SinkItem> {
type Error = ChannelError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_ready(cx)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn start_send(self: Pin<&mut Self>, item: SinkItem) -> Result<(), Self::Error> {
self.project()
.tx
.start_send(item)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_flush(cx)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project()
.tx
.poll_close(cx)
.map_err(|e| ChannelError::Send(Box::new(e)))
}
}
#[cfg(test)]
#[cfg(feature = "tokio1")]
mod tests {
    use crate::{
        client, context,
        server::{incoming::Incoming, BaseChannel},
        transport::{
            self,
            channel::{Channel, UnboundedChannel},
        },
    };
    use assert_matches::assert_matches;
    use futures::{prelude::*, stream};
    use std::io;
    use tracing::trace;

    /// Compile-time check that both channel types satisfy the crate's `Transport` bound.
    #[test]
    fn ensure_is_transport() {
        fn is_transport<SinkItem, Item, T: crate::Transport<SinkItem, Item>>() {}
        is_transport::<(), (), UnboundedChannel<(), ()>>();
        is_transport::<(), (), Channel<(), ()>>();
    }

    /// End-to-end round trip over an unbounded channel pair.
    ///
    /// The server parses each request as a `u64` and reports `InvalidInput`
    /// for anything that does not parse.
    #[tokio::test]
    async fn integration() -> anyhow::Result<()> {
        let _ = tracing_subscriber::fmt::try_init();

        let (client_transport, server_transport) = transport::channel::unbounded();

        let server = stream::once(future::ready(server_transport))
            .map(BaseChannel::with_defaults)
            .execute(|_ctx, request: String| {
                let parsed = request.parse::<u64>().map_err(|_| {
                    io::Error::new(
                        io::ErrorKind::InvalidInput,
                        format!("{request:?} is not an int"),
                    )
                });
                future::ready(parsed)
            });
        tokio::spawn(server);

        let rpc_client = client::new(client::Config::default(), client_transport).spawn();
        let numeric = rpc_client
            .call(context::current(), "", "123".into())
            .await?;
        let non_numeric = rpc_client
            .call(context::current(), "", "abc".into())
            .await?;

        trace!("response1: {:?}, response2: {:?}", numeric, non_numeric);
        assert_matches!(numeric, Ok(123));
        assert_matches!(non_numeric, Err(ref e) if e.kind() == io::ErrorKind::InvalidInput);
        Ok(())
    }
}