use std::io;
use std::cell::RefCell;
use std::rc::Rc;
use std::i32::MAX;
use std::io::ErrorKind::UnexpectedEof;
use std::fmt::Debug;
use futures::{Stream, Poll, Future, Sink, AsyncSink, StartSend, Async, AndThen};
use futures::sink::Send;
use futures::stream::Fuse;
use tokio_io::{AsyncRead, AsyncWrite};
use multi_producer_sink::MPS;
use multi_consumer_stream::{DefaultMCS, KeyMCS};
use atm_async_utils::sink_futures::Close;
use codec::*;
pub use codec::DecodedPacket;
/// Payload encoding of a packet, mapped onto the codec's type flag bits.
///
/// Derives `Debug`/`PartialEq`/`Eq` (in addition to `Clone`/`Copy`) so callers
/// can log and compare packet types.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PsPacketType {
    /// Raw binary payload (encoded with `TYPE_BINARY`).
    Binary,
    /// Text payload (encoded with `TYPE_STRING`); the text encoding is not checked here.
    String,
    /// JSON payload (encoded with `TYPE_JSON`); the JSON is not validated here.
    Json,
}

impl PsPacketType {
    /// Returns the codec flag bits corresponding to this packet type.
    fn flags(&self) -> u8 {
        match *self {
            PsPacketType::Binary => TYPE_BINARY,
            PsPacketType::String => TYPE_STRING,
            PsPacketType::Json => TYPE_JSON,
        }
    }
}
/// Wraps reader `r` and writer `w` in a packet-stream.
///
/// Returns the incoming half (yielding peer-initiated requests and duplexes)
/// and the outgoing half (for starting requests and duplexes). Both halves share
/// their state through an `Rc<RefCell<_>>`, so they must stay on one thread.
pub fn packet_stream<R, W, B>(r: R, w: W) -> (PsIn<R, W, B>, PsOut<R, W, B>)
    where R: AsyncRead,
          W: AsyncWrite,
          B: AsRef<[u8]>
{
    let shared = Rc::new(RefCell::new(PS::new(r, w)));
    let incoming = PsIn(Rc::clone(&shared));
    let outgoing = PsOut(shared);
    (incoming, outgoing)
}
/// Shared state behind a connected [`PsIn`]/[`PsOut`] pair.
struct PS<R, W, B>
    where R: AsyncRead
{
    /// Multi-producer wrapper around the encoding sink. Every outgoing request,
    /// response and duplex sink holds a clone of it.
    sink: MPS<PSCodecSink<W, B>>,
    /// Multi-consumer wrapper around the decoding stream, keyed by packet id.
    /// `PS::poll` reads from the default consumer. Requests and duplexes get
    /// per-id handles via `key_handle`.
    /// NOTE(review): presumably the default consumer only sees packets that no
    /// live key handle claims; confirm in `multi_consumer_stream`.
    stream: DefaultMCS<Fuse<PSCodecStream<R>>, PacketId, fn(&DecodedPacket) -> PacketId>,
    /// Id for the next locally initiated request or duplex (starts at 1).
    id_counter: PacketId,
    /// Id the next peer-initiated request or duplex must carry to be accepted
    /// (starts at 1).
    accepted_id: PacketId,
}

impl<R, W, B> PS<R, W, B>
    where R: AsyncRead,
          W: AsyncWrite,
          B: AsRef<[u8]>
{
    /// Creates the shared state over reader `r` and writer `w`.
    /// Both id counters start at 1.
    fn new(r: R, w: W) -> PS<R, W, B> {
        PS {
            sink: MPS::new(PSCodecSink::new(w)),
            stream: DefaultMCS::new(PSCodecStream::new(r).fuse(), DecodedPacket::id),
            id_counter: 1,
            accepted_id: 1,
        }
    }

    /// Returns the id after `id`, wrapping from `MAX` back to 1.
    /// This keeps ids strictly positive; the negated id is used for traffic in
    /// the opposite direction.
    fn following_id(id: PacketId) -> PacketId {
        if id == MAX { 1 } else { id + 1 }
    }

    /// Allocates the next id for a locally initiated request or duplex.
    fn next_id(&mut self) -> PacketId {
        let ret = self.id_counter;
        self.increment_id();
        ret
    }

    /// Advances the local id counter, wrapping at `MAX`.
    fn increment_id(&mut self) {
        self.id_counter = Self::following_id(self.id_counter);
    }

    /// Advances the expected peer-initiated id, wrapping at `MAX`.
    fn increment_accepted(&mut self) {
        self.accepted_id = Self::following_id(self.accepted_id);
    }

    /// Polls for the next peer-initiated request or duplex.
    ///
    /// - A packet from the default consumer whose id is not `accepted_id` is
    ///   discarded.
    /// - A matching packet advances `accepted_id`.
    /// - If the matching packet carries the stream flag, it opens a duplex:
    ///   replies use the negated id, and further incoming packets are routed
    ///   by its own id.
    /// - Otherwise the matching packet is a request.
    ///
    /// This is a loop rather than a recursive call to `poll`. A long run of
    /// immediately-ready packets that are all discarded therefore cannot grow
    /// the call stack without bound.
    fn poll(&mut self) -> Poll<Option<IncomingPacket<R, W, B>>, io::Error> {
        loop {
            let p = match try_ready!(self.stream.poll()) {
                Some(p) => p,
                None => return Ok(Async::Ready(None)),
            };

            if p.id() != self.accepted_id {
                // Not the next peer-initiated request/duplex: drop it and keep polling.
                continue;
            }
            self.increment_accepted();

            let incoming = if p.is_stream_packet() {
                let stream_id = p.id();
                let sink_id = -stream_id;
                IncomingPacket::Duplex(p,
                                       PsSink::new(self.sink.clone(), sink_id),
                                       PsStream::new(self.stream.key_handle(stream_id)))
            } else {
                IncomingPacket::Request(InRequest::new(p, self.sink.clone()))
            };
            return Ok(Async::Ready(Some(incoming)));
        }
    }
}
/// Outgoing half of a packet-stream, used to start requests and duplexes
/// towards the peer.
pub struct PsOut<R: AsyncRead, W, B>(Rc<RefCell<PS<R, W, B>>>);

impl<R, W, B> PsOut<R, W, B>
    where R: AsyncRead,
          W: AsyncWrite,
          B: AsRef<[u8]>
{
    /// Starts a request carrying `data` of type `t` under a freshly allocated id.
    ///
    /// Returns two futures:
    /// - one that sends the request packet;
    /// - one that resolves to the packet routed under the negated id, which is
    ///   where the peer's response arrives.
    pub fn request(&self, data: B, t: PsPacketType) -> (SendRequest<W, B>, InResponse<R>) {
        let mut state = self.0.borrow_mut();
        let request_id = state.next_id();
        let response_key = request_id * -1;
        let send = SendRequest::new(state.sink.clone(), data, t, request_id);
        let response = InResponse::new(state.stream.key_handle(response_key));
        (send, response)
    }

    /// Opens a duplex under a freshly allocated id.
    ///
    /// Returns a sink that writes under that id and a stream of the packets
    /// the peer sends under the negated id.
    pub fn duplex(&self) -> (PsSink<W, B>, PsStream<R>) {
        let mut state = self.0.borrow_mut();
        let duplex_id = state.next_id();
        let outgoing = PsSink::new(state.sink.clone(), duplex_id);
        let incoming = PsStream::new(state.stream.key_handle(duplex_id * -1));
        (outgoing, incoming)
    }

    /// Polls `close` on the shared outgoing sink.
    /// NOTE(review): how this interacts with other outstanding sink clones
    /// depends on `MPS::close`; confirm there.
    pub fn close(&mut self) -> Poll<(), io::Error> {
        let mut state = self.0.borrow_mut();
        state.sink.close()
    }
}
/// Future that writes one request packet and then closes its sink handle.
pub struct SendRequest<W: AsyncWrite, B: AsRef<[u8]>>(
    AndThen<Send<MPS<PSCodecSink<W, B>>>,
            Close<MPS<PSCodecSink<W, B>>>,
            fn(MPS<PSCodecSink<W, B>>) -> Close<MPS<PSCodecSink<W, B>>>>
);

impl<W: AsyncWrite, B: AsRef<[u8]>> SendRequest<W, B> {
    /// Builds the send-then-close chain for a request with payload `data`.
    /// The packet's flags are the type flag of `t` only (no `STREAM`/`END`),
    /// and it is sent under `id`.
    fn new(sink_handle: MPS<PSCodecSink<W, B>>,
           data: B,
           t: PsPacketType,
           id: PacketId)
           -> SendRequest<W, B> {
        let header = MetaData { flags: t.flags(), id };
        let close_after: fn(MPS<PSCodecSink<W, B>>) -> Close<MPS<PSCodecSink<W, B>>> =
            |handle| Close::new(handle);
        SendRequest(sink_handle.send((data, header)).and_then(close_after))
    }
}

impl<W: AsyncWrite, B: AsRef<[u8]>> Future for SendRequest<W, B> {
    type Item = ();
    type Error = io::Error;

    /// Drives the chain to completion, discarding the closed sink handle.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        match self.0.poll() {
            Ok(Async::Ready(_closed)) => Ok(Async::Ready(())),
            Ok(Async::NotReady) => Ok(Async::NotReady),
            Err(err) => Err(From::from(err)),
        }
    }
}
/// Future resolving to the first packet routed to a request's response key.
pub struct InResponse<R: AsyncRead>(
    KeyMCS<Fuse<PSCodecStream<R>>, PacketId, fn(&DecodedPacket) -> PacketId>
);

impl<R: AsyncRead> InResponse<R> {
    /// Wraps a keyed consumer handle for the response id.
    fn new(stream_handle: KeyMCS<Fuse<PSCodecStream<R>>,
                                 PacketId,
                                 fn(&DecodedPacket) -> PacketId>)
           -> InResponse<R> {
        InResponse(stream_handle)
    }
}

impl<R: AsyncRead> Future for InResponse<R> {
    type Item = DecodedPacket;
    type Error = io::Error;

    /// Yields the first routed packet.
    /// If the stream ends before any packet arrives, fails with
    /// `UnexpectedEof`.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        let polled = self.0.poll()?;
        match polled {
            Async::NotReady => Ok(Async::NotReady),
            Async::Ready(Some(response)) => Ok(Async::Ready(response)),
            Async::Ready(None) => {
                Err(io::Error::new(io::ErrorKind::UnexpectedEof,
                                   "packet-stream closed before response was received"))
            }
        }
    }
}
/// Sink half of a duplex.
///
/// Each item `(payload, type, end)` becomes a packet sent under this sink's id.
/// The packet carries the type flag plus `STREAM`, and also `END` when `end`
/// is true.
pub struct PsSink<W: AsyncWrite, B: AsRef<[u8]>> {
    sink: MPS<PSCodecSink<W, B>>,
    id: PacketId,
}

impl<W, B> PsSink<W, B>
    where W: AsyncWrite,
          B: AsRef<[u8]>
{
    /// Creates a duplex sink writing through `sink` under `id`.
    fn new(sink: MPS<PSCodecSink<W, B>>, id: PacketId) -> PsSink<W, B> {
        Self { id, sink }
    }
}

impl<W, B> Sink for PsSink<W, B>
    where W: AsyncWrite,
          B: AsRef<[u8]> + Debug
{
    type SinkItem = (B, PsPacketType, bool);
    type SinkError = io::Error;

    /// Encodes the item's flags and forwards it to the shared sink.
    /// On back-pressure the original `(payload, type, end)` triple is handed
    /// back to the caller.
    fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
        let (payload, packet_type, is_end) = item;
        let end_bits = if is_end { END } else { 0 };
        let header = MetaData {
            flags: packet_type.flags() | STREAM | end_bits,
            id: self.id,
        };
        match self.sink.start_send((payload, header))? {
            AsyncSink::Ready => Ok(AsyncSink::Ready),
            AsyncSink::NotReady((payload, _)) => {
                Ok(AsyncSink::NotReady((payload, packet_type, is_end)))
            }
        }
    }

    /// Delegates to the shared sink.
    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
        self.sink.poll_complete()
    }

    /// Delegates to the shared sink handle's `close`.
    fn close(&mut self) -> Poll<(), Self::SinkError> {
        self.sink.close()
    }
}
/// Stream half of a duplex, yielding the packets routed to its key.
pub struct PsStream<R: AsyncRead> {
    stream: KeyMCS<Fuse<PSCodecStream<R>>, PacketId, fn(&DecodedPacket) -> PacketId>,
}

impl<R: AsyncRead> PsStream<R> {
    /// Wraps a keyed consumer handle.
    fn new(stream: KeyMCS<Fuse<PSCodecStream<R>>, PacketId, fn(&DecodedPacket) -> PacketId>)
           -> PsStream<R> {
        Self { stream }
    }
}

impl<R: AsyncRead> Stream for PsStream<R> {
    type Item = DecodedPacket;
    type Error = io::Error;

    /// Forwards routed packets.
    /// If the underlying stream ends, this reports an `UnexpectedEof` error
    /// instead of ending normally.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        match self.stream.poll()? {
            Async::NotReady => Ok(Async::NotReady),
            Async::Ready(Some(packet)) => Ok(Async::Ready(Some(packet))),
            Async::Ready(None) => {
                Err(io::Error::new(UnexpectedEof,
                                   "packet stream closed while substream was waiting for items"))
            }
        }
    }
}
/// Incoming half of a packet-stream, yielding peer-initiated requests and
/// duplexes.
pub struct PsIn<R: AsyncRead, W, B>(Rc<RefCell<PS<R, W, B>>>);

impl<R, W, B> Stream for PsIn<R, W, B>
    where R: AsyncRead,
          W: AsyncWrite,
          B: AsRef<[u8]>
{
    type Item = IncomingPacket<R, W, B>;
    type Error = io::Error;

    /// Delegates to the shared state's `poll`.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        let mut state = self.0.borrow_mut();
        state.poll()
    }
}
/// An item yielded by [`PsIn`]: either a request or a newly opened duplex.
pub enum IncomingPacket<R, W, B>
    where R: AsyncRead,
          W: AsyncWrite,
          B: AsRef<[u8]>
{
    /// A non-stream packet opening a request.
    /// Answer it with `InRequest::respond` or `InRequest::respond_error`.
    Request(InRequest<W, B>),
    /// The first packet of a peer-initiated duplex, together with the sink for
    /// replying on it and the stream of its subsequent packets.
    Duplex(DecodedPacket, PsSink<W, B>, PsStream<R>),
}
pub struct InRequest<W, B> {
packet: DecodedPacket,
sink: MPS<PSCodecSink<W, B>>,
}
impl<W, B> InRequest<W, B> {
fn new(packet: DecodedPacket, sink: MPS<PSCodecSink<W, B>>) -> InRequest<W, B> {
InRequest { packet, sink }
}
pub fn packet(&self) -> &DecodedPacket {
&self.packet
}
}
impl<W: AsyncWrite, B: AsRef<[u8]>> InRequest<W, B> {
pub fn respond(self, bytes: B, t: PsPacketType) -> SendResponse<W, B> {
SendResponse::new(self.sink, self.packet.id() * -1, bytes, t)
}
pub fn respond_error(self, bytes: B, t: PsPacketType) -> SendResponseError<W, B> {
SendResponseError::new(self.sink, self.packet.id() * -1, bytes, t)
}
}
/// Future that writes one response packet and then closes its sink handle.
pub struct SendResponse<W: AsyncWrite, B: AsRef<[u8]>>(
    AndThen<Send<MPS<PSCodecSink<W, B>>>,
            Close<MPS<PSCodecSink<W, B>>>,
            fn(MPS<PSCodecSink<W, B>>) -> Close<MPS<PSCodecSink<W, B>>>>
);

impl<W: AsyncWrite, B: AsRef<[u8]>> SendResponse<W, B> {
    /// Builds the send-then-close chain for a response with payload `data` of
    /// type `t`, sent under `id`.
    /// `id` must be negative; this is checked only in debug builds.
    fn new(sink_handle: MPS<PSCodecSink<W, B>>,
           id: PacketId,
           data: B,
           t: PsPacketType)
           -> SendResponse<W, B> {
        debug_assert!(id < 0);
        let header = MetaData { flags: t.flags(), id };
        let close_after: fn(MPS<PSCodecSink<W, B>>) -> Close<MPS<PSCodecSink<W, B>>> =
            |handle| Close::new(handle);
        SendResponse(sink_handle.send((data, header)).and_then(close_after))
    }
}

impl<W: AsyncWrite, B: AsRef<[u8]>> Future for SendResponse<W, B> {
    type Item = ();
    type Error = io::Error;

    /// Drives the chain to completion, discarding the closed sink handle.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        match self.0.poll() {
            Ok(Async::NotReady) => Ok(Async::NotReady),
            Ok(Async::Ready(_closed)) => Ok(Async::Ready(())),
            Err(err) => Err(From::from(err)),
        }
    }
}
/// Future that writes one error response (with `END` set) and then closes its
/// sink handle.
pub struct SendResponseError<W: AsyncWrite, B: AsRef<[u8]>>(
    AndThen<Send<MPS<PSCodecSink<W, B>>>,
            Close<MPS<PSCodecSink<W, B>>>,
            fn(MPS<PSCodecSink<W, B>>) -> Close<MPS<PSCodecSink<W, B>>>>
);

impl<W: AsyncWrite, B: AsRef<[u8]>> SendResponseError<W, B> {
    /// Builds the send-then-close chain for an error response with payload
    /// `data` of type `t`, sent under `id` with `END` added to the type flag.
    /// `id` must be negative; this is checked only in debug builds.
    fn new(sink_handle: MPS<PSCodecSink<W, B>>,
           id: PacketId,
           data: B,
           t: PsPacketType)
           -> SendResponseError<W, B> {
        debug_assert!(id < 0);
        let error_flags = t.flags() | END;
        let header = MetaData { flags: error_flags, id };
        let close_after: fn(MPS<PSCodecSink<W, B>>) -> Close<MPS<PSCodecSink<W, B>>> =
            |handle| Close::new(handle);
        SendResponseError(sink_handle.send((data, header)).and_then(close_after))
    }
}

impl<W: AsyncWrite, B: AsRef<[u8]>> Future for SendResponseError<W, B> {
    type Item = ();
    type Error = io::Error;

    /// Drives the chain to completion, discarding the closed sink handle.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        match self.0.poll() {
            Err(err) => Err(From::from(err)),
            Ok(Async::NotReady) => Ok(Async::NotReady),
            Ok(Async::Ready(_closed)) => Ok(Async::Ready(())),
        }
    }
}
// Property-based round-trip tests.
// Two packet-streams `a` and `b` are wired back to back through two in-memory
// ring buffers: `a` writes into ring b and reads ring a, while `b` does the
// opposite. quickcheck varies the buffer sizes and the partial I/O schedules
// (`PartialWithErrors<GenInterruptedWouldBlock>`) applied to each reader and
// writer.
#[cfg(test)]
mod tests {
use super::*;
use partial_io::{PartialAsyncRead, PartialAsyncWrite, PartialWithErrors};
use partial_io::quickcheck_types::GenInterruptedWouldBlock;
use quickcheck::{QuickCheck, StdGen};
use async_ringbuffer::*;
use rand;
use futures::stream::iter_ok;
use futures::future::{ok, poll_fn};
// Runs `test_requests` on 100 generated cases (StdGen size parameter 20).
#[test]
fn requests() {
let rng = StdGen::new(rand::thread_rng(), 20);
let mut quickcheck = QuickCheck::new().gen(rng).tests(100);
quickcheck.quickcheck(test_requests as
fn(usize,
usize,
PartialWithErrors<GenInterruptedWouldBlock>,
PartialWithErrors<GenInterruptedWouldBlock>,
PartialWithErrors<GenInterruptedWouldBlock>,
PartialWithErrors<GenInterruptedWouldBlock>)
-> bool);
}
// Peer `a` sends three single-byte requests ([0], [1], [2]).
// Peer `b` echoes each request's payload back as a binary response.
// Returns true iff every response carries the matching payload and is a
// buffer packet.
fn test_requests(buf_size_a: usize,
buf_size_b: usize,
write_ops_a: PartialWithErrors<GenInterruptedWouldBlock>,
read_ops_a: PartialWithErrors<GenInterruptedWouldBlock>,
write_ops_b: PartialWithErrors<GenInterruptedWouldBlock>,
read_ops_b: PartialWithErrors<GenInterruptedWouldBlock>)
-> bool {
// `+ 1` presumably keeps the ring-buffer capacity non-zero when quickcheck picks 0.
let (writer_a, reader_a) = ring_buffer(buf_size_a + 1);
let writer_a = PartialAsyncWrite::new(writer_a, write_ops_a);
let reader_a = PartialAsyncRead::new(reader_a, read_ops_a);
let (writer_b, reader_b) = ring_buffer(buf_size_b + 1);
let writer_b = PartialAsyncWrite::new(writer_b, write_ops_b);
let reader_b = PartialAsyncRead::new(reader_b, read_ops_b);
// Cross-wire: `a` reads ring a and writes ring b; `b` reads ring b and writes ring a.
let (a_in, mut a_out) = packet_stream(reader_a, writer_b);
let (b_in, mut b_out) = packet_stream(reader_b, writer_a);
// Peer `b` answers every request with its own payload.
// No duplexes are expected in this test.
// Once `b`'s incoming side ends, `b`'s outgoing side is closed.
let echo = b_in.for_each(|incoming_packet| match incoming_packet {
IncomingPacket::Request(in_request) => {
let data = in_request.packet().data().clone();
in_request.respond(data, PsPacketType::Binary)
}
IncomingPacket::Duplex(_, _, _) => unreachable!(),
})
.and_then(|_| poll_fn(|| b_out.close()));
// Drive `a`'s incoming side to completion, discarding whatever it yields.
let consume_a = a_in.for_each(|_| ok(()));
let (req0, res0) = a_out.request([0], PsPacketType::Binary);
let (req1, res1) = a_out.request([1], PsPacketType::Binary);
let (req2, res2) = a_out.request([2], PsPacketType::Binary);
// Close `a`'s outgoing side only after all three request sends have completed.
let send_all = req0.join3(req1, req2)
.and_then(|_| poll_fn(|| a_out.close()));
let receive_all = res0.join3(res1, res2)
.map(|(r0, r1, r2)| {
return r0.data() == &vec![0u8].into_boxed_slice() && r0.is_buffer_packet() &&
r1.data() == &vec![1u8].into_boxed_slice() &&
r1.is_buffer_packet() &&
r2.data() == &vec![2u8].into_boxed_slice() &&
r2.is_buffer_packet();
});
// Run both peers and both sides of `a` concurrently on the current thread.
return echo.join4(consume_a, send_all, receive_all)
.map(|(_, _, _, worked)| worked)
.wait()
.unwrap();
}
// Runs `test_duplexes` on 100 generated cases (StdGen size parameter 20).
#[test]
fn duplexes() {
let rng = StdGen::new(rand::thread_rng(), 20);
let mut quickcheck = QuickCheck::new().gen(rng).tests(100);
quickcheck.quickcheck(test_duplexes as
fn(usize,
usize,
PartialWithErrors<GenInterruptedWouldBlock>,
PartialWithErrors<GenInterruptedWouldBlock>,
PartialWithErrors<GenInterruptedWouldBlock>,
PartialWithErrors<GenInterruptedWouldBlock>)
-> bool);
}
// Peer `a` opens three duplexes. On each it sends a few packets, the last of
// which has the END flag set.
// Peer `b` echoes every non-END packet back on the same duplex.
// Returns true iff `a` receives the expected echoed payloads.
fn test_duplexes(buf_size_a: usize,
buf_size_b: usize,
write_ops_a: PartialWithErrors<GenInterruptedWouldBlock>,
read_ops_a: PartialWithErrors<GenInterruptedWouldBlock>,
write_ops_b: PartialWithErrors<GenInterruptedWouldBlock>,
read_ops_b: PartialWithErrors<GenInterruptedWouldBlock>)
-> bool {
let (writer_a, reader_a) = ring_buffer(buf_size_a + 1);
let writer_a = PartialAsyncWrite::new(writer_a, write_ops_a);
let reader_a = PartialAsyncRead::new(reader_a, read_ops_a);
let (writer_b, reader_b) = ring_buffer(buf_size_b + 1);
let writer_b = PartialAsyncWrite::new(writer_b, write_ops_b);
let reader_b = PartialAsyncRead::new(reader_b, read_ops_b);
let (a_in, mut a_out) = packet_stream(reader_a, writer_b);
let (b_in, mut b_out) = packet_stream(reader_b, writer_a);
// Peer `b` forwards each duplex's packets back into its sink, up to (but not
// including) the first END packet. No plain requests are expected.
let echo =
b_in.for_each(|incoming_packet| match incoming_packet {
IncomingPacket::Request(_) => unreachable!(),
IncomingPacket::Duplex(_, sink, stream) => {
stream
.take_while(|packet| ok(!packet.is_end_packet()))
.map(|packet| {
(packet.into_data(), PsPacketType::Binary, false)
})
.forward(sink)
.map(|_| ())
}
})
.and_then(move |_| poll_fn(move || b_out.close()));
let consume_a = a_in.for_each(|_| ok(()));
let (sink0_a, stream0_a) = a_out.duplex();
let (sink1_a, stream1_a) = a_out.duplex();
let (sink2_a, stream2_a) = a_out.duplex();
// Duplex i sends (i + 1) copies of [i] and then an END packet ([42 + i]).
// Duplex 0 is the exception: it sends two copies of [0] before its END packet.
let send_0 =
sink0_a.send_all(iter_ok::<_, io::Error>(vec![(vec![0], PsPacketType::Binary, false),
(vec![0], PsPacketType::Binary, false),
(vec![42], PsPacketType::Binary, true)]));
let send_1 =
sink1_a.send_all(iter_ok::<_, io::Error>(vec![(vec![1], PsPacketType::Binary, false),
(vec![1], PsPacketType::Binary, false),
(vec![1], PsPacketType::Binary, false),
(vec![43], PsPacketType::Binary, true)]));
let send_2 =
sink2_a.send_all(iter_ok::<_, io::Error>(vec![(vec![2], PsPacketType::Binary, false),
(vec![2], PsPacketType::Binary, false),
(vec![2], PsPacketType::Binary, false),
(vec![2], PsPacketType::Binary, false),
(vec![44], PsPacketType::Binary, true)]));
let send_all = send_0
.join3(send_1, send_2)
.and_then(move |_| poll_fn(move || a_out.close()));
// Only the first echoed packet of duplex 0 is checked, although two are echoed.
let receive_0 =
stream0_a
.take(1)
.fold(false, |_, packet| {
ok::<_, io::Error>(packet.into_data() == vec![0].into_boxed_slice())
});
// Duplexes 1 and 2 check only the first 2 and 3 echoed packets respectively.
let receive_1 = stream1_a
.take(2)
.fold(true, |acc, packet| {
ok::<_, io::Error>(acc && packet.into_data() == vec![1].into_boxed_slice())
});
let receive_2 = stream2_a
.take(3)
.fold(true, |acc, packet| {
ok::<_, io::Error>(acc && packet.into_data() == vec![2].into_boxed_slice())
});
let receive_all = receive_0
.join3(receive_1, receive_2)
.map(|(a, b, c)| a && b && c);
return echo.join4(consume_a, send_all, receive_all)
.map(|(_, _, _, worked)| worked)
.wait()
.unwrap();
}
}