use std::{
any::Any,
io::{Error, ErrorKind},
pin::Pin,
task::{Context, Poll},
};
use futures_channel::mpsc;
use futures_core::Stream as _;
use futures_sink::Sink as _;
use crate::{Deserialize, Serialize, Sink, Stream};
/// Type-erased payload carried by every in-memory channel in this module.
///
/// Sinks box the concrete value they are handed and streams downcast it back,
/// so values cross the channel without being serialized. (`Box<dyn …>` already
/// defaults to a `'static` object lifetime.)
type Item = Box<dyn Any + Send + Sync>;
#[derive(Debug, Clone)]
pub struct MemorySink(mpsc::Sender<Item>);
impl Sink for MemorySink {
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_ready(cx)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e))
}
fn start_send<Item: Serialize>(
mut self: Pin<&mut Self>,
item: Item,
) -> Result<(), Self::Error> {
Pin::new(&mut self.0)
.start_send(Box::new(item))
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_flush(cx)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_close(cx)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e))
}
}
/// Receiving half of a bounded in-memory channel, created by [`channel`].
///
/// Values arrive as type-erased [`Item`] boxes and are downcast to the type
/// requested by the caller of `poll_next`.
#[derive(Debug)]
pub struct MemoryStream(mpsc::Receiver<Item>);

impl Stream for MemoryStream {
    type Error = Error;

    /// Polls for the next value and downcasts it to `T`.
    ///
    /// Returns `Ready(None)` when the underlying receiver reports end of stream.
    /// A value whose concrete type is not `T` is dropped and reported as an
    /// `InvalidData` error.
    fn poll_next<T: Deserialize>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<T, Self::Error>>> {
        match Pin::new(&mut self.0).poll_next(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(None) => Poll::Ready(None),
            Poll::Ready(Some(erased)) => {
                let decoded = match erased.downcast::<T>() {
                    Ok(value) => Ok(*value),
                    Err(_) => Err(Error::new(
                        ErrorKind::InvalidData,
                        "sender sent an unexpected type",
                    )),
                };
                Poll::Ready(Some(decoded))
            }
        }
    }
}
pub fn channel(buffer: usize) -> (MemorySink, MemoryStream) {
let (sender, receiver) = mpsc::channel(buffer);
(MemorySink(sender), MemoryStream(receiver))
}
#[derive(Debug, Clone)]
pub struct UnboundedMemorySink(mpsc::UnboundedSender<Item>);
impl Sink for UnboundedMemorySink {
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_ready(cx)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e))
}
fn start_send<Item: Serialize>(
mut self: Pin<&mut Self>,
item: Item,
) -> Result<(), Self::Error> {
Pin::new(&mut self.0)
.start_send(Box::new(item))
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e))
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_flush(cx)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e))
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.0)
.poll_close(cx)
.map_err(|e| Error::new(ErrorKind::ConnectionAborted, e))
}
}
/// Receiving half of an unbounded in-memory channel, created by [`unbounded`].
///
/// Yields the type-erased [`Item`]s sent by the paired sink, downcast to the
/// type requested by the caller.
#[derive(Debug)]
pub struct UnboundedMemoryStream(mpsc::UnboundedReceiver<Item>);

impl Stream for UnboundedMemoryStream {
    type Error = Error;

    /// Polls for the next value and downcasts it to `T`.
    ///
    /// Propagates `Pending` and end of stream unchanged; a value of any other
    /// concrete type is dropped and reported as `InvalidData`.
    fn poll_next<T: Deserialize>(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<T, Self::Error>>> {
        let erased = match Pin::new(&mut self.0).poll_next(cx) {
            Poll::Ready(Some(erased)) => erased,
            Poll::Ready(None) => return Poll::Ready(None),
            Poll::Pending => return Poll::Pending,
        };
        let decoded = erased
            .downcast::<T>()
            .map(|value| *value)
            .map_err(|_| Error::new(ErrorKind::InvalidData, "sender sent an unexpected type"));
        Poll::Ready(Some(decoded))
    }
}
pub fn unbounded() -> (UnboundedMemorySink, UnboundedMemoryStream) {
let (sender, receiver) = mpsc::unbounded();
(UnboundedMemorySink(sender), UnboundedMemoryStream(receiver))
}
#[derive(Debug)]
pub struct MemoryDuplex {
sink: MemorySink,
stream: MemoryStream,
}
impl MemoryDuplex {
pub fn into_inner(self) -> (MemorySink, MemoryStream) {
(self.sink, self.stream)
}
pub fn sink_mut(&mut self) -> &mut MemorySink {
&mut self.sink
}
pub fn stream_mut(&mut self) -> &mut MemoryStream {
&mut self.stream
}
}
impl Sink for MemoryDuplex {
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink).poll_ready(cx)
}
fn start_send<Item: Serialize>(
mut self: Pin<&mut Self>,
item: Item,
) -> Result<(), Self::Error> {
Pin::new(&mut self.sink).start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink).poll_close(cx)
}
}
impl Stream for MemoryDuplex {
type Error = Error;
fn poll_next<Item: Deserialize>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, Self::Error>>> {
Pin::new(&mut self.stream).poll_next(cx)
}
}
/// Creates a pair of connected bounded endpoints: whatever one sends, the other receives.
///
/// Internally this is two one-way [`channel`]s, each created with `buffer`,
/// cross-wired between the endpoints.
pub fn duplex(buffer: usize) -> (MemoryDuplex, MemoryDuplex) {
    let (left_tx, right_rx) = channel(buffer);
    let (right_tx, left_rx) = channel(buffer);
    let left = MemoryDuplex {
        sink: left_tx,
        stream: left_rx,
    };
    let right = MemoryDuplex {
        sink: right_tx,
        stream: right_rx,
    };
    (left, right)
}
#[derive(Debug)]
pub struct UnboundedMemoryDuplex {
sink: UnboundedMemorySink,
stream: UnboundedMemoryStream,
}
impl UnboundedMemoryDuplex {
pub fn into_inner(self) -> (UnboundedMemorySink, UnboundedMemoryStream) {
(self.sink, self.stream)
}
pub fn sink_mut(&mut self) -> &mut UnboundedMemorySink {
&mut self.sink
}
pub fn stream_mut(&mut self) -> &mut UnboundedMemoryStream {
&mut self.stream
}
}
impl Sink for UnboundedMemoryDuplex {
type Error = Error;
fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink).poll_ready(cx)
}
fn start_send<Item: Serialize>(
mut self: Pin<&mut Self>,
item: Item,
) -> Result<(), Self::Error> {
Pin::new(&mut self.sink).start_send(item)
}
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink).poll_flush(cx)
}
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Pin::new(&mut self.sink).poll_close(cx)
}
}
impl Stream for UnboundedMemoryDuplex {
type Error = Error;
fn poll_next<Item: Deserialize>(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Item, Self::Error>>> {
Pin::new(&mut self.stream).poll_next(cx)
}
}
/// Creates a pair of connected unbounded endpoints: whatever one sends, the other receives.
pub fn unbounded_duplex() -> (UnboundedMemoryDuplex, UnboundedMemoryDuplex) {
    // Each channel carries traffic in one direction; its halves go to opposite endpoints.
    let (to_second, from_first) = unbounded();
    let (to_first, from_second) = unbounded();
    (
        UnboundedMemoryDuplex {
            sink: to_second,
            stream: from_second,
        },
        UnboundedMemoryDuplex {
            sink: to_first,
            stream: from_first,
        },
    )
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{SinkExt, StreamExt};
    use futures::executor::block_on;
    use std::io::ErrorKind;

    #[test]
    fn test_channel() {
        let (mut sink, mut stream) = channel(1);
        block_on(async {
            sink.send(42u8).await.unwrap();
            assert_eq!(stream.next::<u8>().await.unwrap().unwrap(), 42);
        })
    }

    /// A mismatched type must surface as `InvalidData`, not just "some panic".
    /// (`#[should_panic]` would also pass if the send itself failed.)
    #[test]
    fn test_channel_type_mismatch() {
        let (mut sink, mut stream) = channel(1);
        block_on(async {
            sink.send(42u16).await.unwrap();
            let err = stream.next::<u8>().await.unwrap().unwrap_err();
            assert_eq!(err.kind(), ErrorKind::InvalidData);
        })
    }

    #[test]
    fn test_channel_receiver_dropped() {
        let (mut sink, stream) = channel(1);
        drop(stream);
        block_on(async {
            let err = sink.send(42u8).await.unwrap_err();
            assert_eq!(err.kind(), ErrorKind::ConnectionAborted);
        })
    }

    #[test]
    fn test_channel_sender_dropped() {
        let (sink, mut stream) = channel(1);
        drop(sink);
        block_on(async {
            assert!(stream.next::<u8>().await.is_none());
        })
    }

    #[test]
    fn test_unbounded() {
        let (mut sink, mut stream) = unbounded();
        block_on(async {
            sink.send(42u8).await.unwrap();
            assert_eq!(stream.next::<u8>().await.unwrap().unwrap(), 42);
        })
    }

    #[test]
    fn test_duplex() {
        let (mut a, mut b) = duplex(1);
        block_on(async {
            a.send(42u8).await.unwrap();
            assert_eq!(b.next::<u8>().await.unwrap().unwrap(), 42);
            b.send(7u8).await.unwrap();
            assert_eq!(a.next::<u8>().await.unwrap().unwrap(), 7);
        })
    }

    #[test]
    fn test_duplex_type_mismatch() {
        let (mut a, mut b) = duplex(1);
        block_on(async {
            a.send(42u16).await.unwrap();
            let err = b.next::<u8>().await.unwrap().unwrap_err();
            assert_eq!(err.kind(), ErrorKind::InvalidData);
        })
    }

    #[test]
    fn test_unbounded_duplex() {
        let (mut a, mut b) = unbounded_duplex();
        block_on(async {
            a.send(42u8).await.unwrap();
            assert_eq!(b.next::<u8>().await.unwrap().unwrap(), 42);
            b.send(7u8).await.unwrap();
            assert_eq!(a.next::<u8>().await.unwrap().unwrap(), 7);
        })
    }
}