use bytes::Bytes;
use futures::{future::Either, prelude::*, sink, stream::StreamFuture};
use crate::protocol::{Dialer, DialerFuture, DialerToListenerMessage, ListenerToDialerMessage};
use log::trace;
use std::mem;
use tokio_io::{AsyncRead, AsyncWrite};
use crate::ProtocolChoiceError;
/// Future returned by `dialer_select_proto`.
///
/// It is `Either::A` wrapping a `DialerSelectSeq` or `Either::B` wrapping a
/// `DialerSelectPar`, depending on which strategy was picked.
pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;
/// Negotiates a protocol on `inner` as the dialing side, choosing a strategy
/// from the number of candidate protocols.
///
/// `protocols` is turned into an iterator and its `size_hint` upper bound is
/// inspected:
///
/// * if the upper bound is known and is at most 3, the sequential strategy
///   (`dialer_select_proto_serial`) is used;
/// * otherwise (larger or unknown upper bound) the list-based strategy
///   (`dialer_select_proto_parallel`) is used.
///
/// The returned future resolves to the chosen protocol item together with the
/// underlying I/O object, or fails with a `ProtocolChoiceError`.
#[inline]
pub fn dialer_select_proto<R, I>(inner: R, protocols: I) -> DialerSelectFuture<R, I::IntoIter>
where
    R: AsyncRead + AsyncWrite,
    I: IntoIterator,
    I::Item: AsRef<[u8]>
{
    // Largest candidate count for which the sequential strategy is selected.
    const MAX_SERIAL_PROTOCOLS: usize = 3;

    let iter = protocols.into_iter();
    match iter.size_hint() {
        (_, Some(upper)) if upper <= MAX_SERIAL_PROTOCOLS => {
            Either::A(dialer_select_proto_serial(inner, iter))
        }
        // Too many candidates, or the iterator cannot bound its length.
        _ => Either::B(dialer_select_proto_parallel(inner, iter)),
    }
}
/// Builds a future that negotiates a protocol by proposing the items of
/// `protocols` to the remote one at a time, in iteration order.
///
/// The future starts out waiting for `Dialer::new(inner)` to complete; see the
/// `Future` implementation of `DialerSelectSeq` for the rest of the exchange.
pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I) -> DialerSelectSeq<R, I>
where
    R: AsyncRead + AsyncWrite,
    I: Iterator,
    I::Item: AsRef<[u8]>
{
    let dialer_fut = Dialer::new(inner);
    let initial = DialerSelectSeqState::AwaitDialer { dialer_fut, protocols };
    DialerSelectSeq { inner: initial }
}
/// Future created by `dialer_select_proto_serial`.
///
/// Proposes candidate protocols to the remote one after another until one is
/// acknowledged or the candidates run out.
pub struct DialerSelectSeq<R, I>
where
    R: AsyncRead + AsyncWrite,
    I: Iterator,
{
    /// Current step of the negotiation.
    inner: DialerSelectSeqState<R, I>,
}
/// Steps of the sequential negotiation driven by `DialerSelectSeq`.
enum DialerSelectSeqState<R, I>
where
    R: AsyncRead + AsyncWrite,
    I: Iterator,
{
    /// Waiting for the `Dialer` to be constructed.
    AwaitDialer {
        dialer_fut: DialerFuture<R>,
        protocols: I,
    },
    /// The dialer is idle; the next candidate has to be pulled from `protocols`.
    NextProtocol {
        dialer: Dialer<R>,
        protocols: I,
    },
    /// A request for `proto_name` is being written to the dialer.
    SendProtocol {
        sender: sink::Send<Dialer<R>>,
        proto_name: I::Item,
        protocols: I,
    },
    /// Waiting for the remote's answer to the request for `proto_name`.
    AwaitProtocol {
        stream: StreamFuture<Dialer<R>>,
        proto_name: I::Item,
        protocols: I,
    },
    /// Placeholder left in place while a step is being processed; it is also
    /// what remains once the future has finished or failed.
    Undefined,
}
impl<R, I> Future for DialerSelectSeq<R, I>
where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite,
{
type Item = (I::Item, R);
type Error = ProtocolChoiceError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.inner, DialerSelectSeqState::Undefined) {
DialerSelectSeqState::AwaitDialer { mut dialer_fut, protocols } => {
let dialer = match dialer_fut.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectSeqState::AwaitDialer { dialer_fut, protocols };
return Ok(Async::NotReady)
}
};
self.inner = DialerSelectSeqState::NextProtocol { dialer, protocols }
}
DialerSelectSeqState::NextProtocol { dialer, mut protocols } => {
let proto_name =
protocols.next().ok_or(ProtocolChoiceError::NoProtocolFound)?;
let req = DialerToListenerMessage::ProtocolRequest {
name: Bytes::from(proto_name.as_ref())
};
trace!("sending {:?}", req);
let sender = dialer.send(req);
self.inner = DialerSelectSeqState::SendProtocol {
sender,
proto_name,
protocols
}
}
DialerSelectSeqState::SendProtocol { mut sender, proto_name, protocols } => {
let dialer = match sender.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectSeqState::SendProtocol {
sender,
proto_name,
protocols
};
return Ok(Async::NotReady)
}
};
let stream = dialer.into_future();
self.inner = DialerSelectSeqState::AwaitProtocol {
stream,
proto_name,
protocols
};
}
DialerSelectSeqState::AwaitProtocol { mut stream, proto_name, protocols } => {
let (m, r) = match stream.poll() {
Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => {
self.inner = DialerSelectSeqState::AwaitProtocol {
stream,
proto_name,
protocols
};
return Ok(Async::NotReady)
}
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
};
trace!("received {:?}", m);
match m.ok_or(ProtocolChoiceError::UnexpectedMessage)? {
ListenerToDialerMessage::ProtocolAck { ref name }
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, r.into_inner())))
},
ListenerToDialerMessage::NotAvailable => {
self.inner = DialerSelectSeqState::NextProtocol { dialer: r, protocols }
}
_ => return Err(ProtocolChoiceError::UnexpectedMessage)
}
}
DialerSelectSeqState::Undefined =>
panic!("DialerSelectSeqState::poll called after completion")
}
}
}
}
/// Builds a future that negotiates a protocol by first asking the remote for
/// its list of protocols and then requesting the first item of `protocols`
/// that appears in that list.
///
/// The future starts out waiting for `Dialer::new(inner)` to complete; see the
/// `Future` implementation of `DialerSelectPar` for the rest of the exchange.
pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I>
where
    I: Iterator,
    I::Item: AsRef<[u8]>,
    R: AsyncRead + AsyncWrite
{
    let dialer_fut = Dialer::new(inner);
    let initial = DialerSelectParState::AwaitDialer { dialer_fut, protocols };
    DialerSelectPar { inner: initial }
}
/// Future created by `dialer_select_proto_parallel`.
///
/// Asks the remote for its protocol list, then requests the first local
/// candidate found in that list.
pub struct DialerSelectPar<R, I>
where
    R: AsyncRead + AsyncWrite,
    I: Iterator,
{
    /// Current step of the negotiation.
    inner: DialerSelectParState<R, I>,
}
/// Steps of the list-based negotiation driven by `DialerSelectPar`.
enum DialerSelectParState<R, I>
where
    R: AsyncRead + AsyncWrite,
    I: Iterator,
{
    /// Waiting for the `Dialer` to be constructed.
    AwaitDialer {
        dialer_fut: DialerFuture<R>,
        protocols: I,
    },
    /// The protocols-list request is being written to the dialer.
    SendRequest {
        sender: sink::Send<Dialer<R>>,
        protocols: I,
    },
    /// Waiting for the remote's protocols-list response.
    AwaitResponse {
        stream: StreamFuture<Dialer<R>>,
        protocols: I,
    },
    /// A request for the selected `proto_name` is being written to the dialer.
    SendProtocol {
        sender: sink::Send<Dialer<R>>,
        proto_name: I::Item,
    },
    /// Waiting for the remote to acknowledge `proto_name`.
    AwaitProtocol {
        stream: StreamFuture<Dialer<R>>,
        proto_name: I::Item,
    },
    /// Placeholder left in place while a step is being processed; it is also
    /// what remains once the future has finished or failed.
    Undefined,
}
impl<R, I> Future for DialerSelectPar<R, I>
where
I: Iterator,
I::Item: AsRef<[u8]>,
R: AsyncRead + AsyncWrite,
{
type Item = (I::Item, R);
type Error = ProtocolChoiceError;
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
match mem::replace(&mut self.inner, DialerSelectParState::Undefined) {
DialerSelectParState::AwaitDialer { mut dialer_fut, protocols } => {
let dialer = match dialer_fut.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectParState::AwaitDialer { dialer_fut, protocols };
return Ok(Async::NotReady)
}
};
trace!("requesting protocols list");
let sender = dialer.send(DialerToListenerMessage::ProtocolsListRequest);
self.inner = DialerSelectParState::SendRequest { sender, protocols };
}
DialerSelectParState::SendRequest { mut sender, protocols } => {
let dialer = match sender.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectParState::SendRequest { sender, protocols };
return Ok(Async::NotReady)
}
};
let stream = dialer.into_future();
self.inner = DialerSelectParState::AwaitResponse { stream, protocols };
}
DialerSelectParState::AwaitResponse { mut stream, protocols } => {
let (m, d) = match stream.poll() {
Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => {
self.inner = DialerSelectParState::AwaitResponse { stream, protocols };
return Ok(Async::NotReady)
}
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
};
trace!("protocols list response: {:?}", m);
let list = match m {
Some(ListenerToDialerMessage::ProtocolsListResponse { list }) => list,
_ => return Err(ProtocolChoiceError::UnexpectedMessage),
};
let mut found = None;
for local_name in protocols {
for remote_name in &list {
if remote_name.as_ref() == local_name.as_ref() {
found = Some(local_name);
break;
}
}
if found.is_some() {
break;
}
}
let proto_name = found.ok_or(ProtocolChoiceError::NoProtocolFound)?;
trace!("sending {:?}", proto_name.as_ref());
let sender = d.send(DialerToListenerMessage::ProtocolRequest {
name: Bytes::from(proto_name.as_ref())
});
self.inner = DialerSelectParState::SendProtocol { sender, proto_name };
}
DialerSelectParState::SendProtocol { mut sender, proto_name } => {
let dialer = match sender.poll()? {
Async::Ready(d) => d,
Async::NotReady => {
self.inner = DialerSelectParState::SendProtocol {
sender,
proto_name
};
return Ok(Async::NotReady)
}
};
let stream = dialer.into_future();
self.inner = DialerSelectParState::AwaitProtocol {
stream,
proto_name
};
}
DialerSelectParState::AwaitProtocol { mut stream, proto_name } => {
let (m, r) = match stream.poll() {
Ok(Async::Ready(x)) => x,
Ok(Async::NotReady) => {
self.inner = DialerSelectParState::AwaitProtocol {
stream,
proto_name
};
return Ok(Async::NotReady)
}
Err((e, _)) => return Err(ProtocolChoiceError::from(e))
};
trace!("received {:?}", m);
match m {
Some(ListenerToDialerMessage::ProtocolAck { ref name })
if name.as_ref() == proto_name.as_ref() =>
{
return Ok(Async::Ready((proto_name, r.into_inner())))
}
_ => return Err(ProtocolChoiceError::UnexpectedMessage)
}
}
DialerSelectParState::Undefined =>
panic!("DialerSelectParState::poll called after completion")
}
}
}
}