use crate::protocol::{Protocol, ProtocolError, MessageIO, Message, Version};
use futures::{future::Either, prelude::*};
use log::debug;
use std::{io, iter, mem, convert::TryFrom};
use tokio_io::{AsyncRead, AsyncWrite};
use crate::{Negotiated, NegotiationError};
/// Future returned by [`dialer_select_proto`].
///
/// Holds either the serial strategy ([`DialerSelectSeq`]) or the parallel
/// strategy ([`DialerSelectPar`]), depending on which one was chosen.
pub type DialerSelectFuture<R, I> = Either<DialerSelectSeq<R, I>, DialerSelectPar<R, I>>;

/// Negotiates, as the dialer, one of `protocols` over the stream `inner`.
///
/// The strategy is picked from the upper bound of the iterator's
/// `size_hint`. When that bound is known and is at most 3, protocols are
/// proposed one at a time ([`dialer_select_proto_serial`]). Otherwise the
/// listener is first asked for its supported protocols
/// ([`dialer_select_proto_parallel`]).
pub fn dialer_select_proto<R, I>(inner: R, protocols: I) -> DialerSelectFuture<R, I::IntoIter>
where
    R: AsyncRead + AsyncWrite,
    I: IntoIterator,
    I::Item: AsRef<[u8]>
{
    let candidates = protocols.into_iter();
    let (_, upper_bound) = candidates.size_hint();
    let use_serial = match upper_bound {
        Some(max_len) => max_len <= 3,
        None => false,
    };
    if !use_serial {
        return Either::B(dialer_select_proto_parallel(inner, candidates));
    }
    Either::A(dialer_select_proto_serial(inner, candidates))
}
/// Negotiates a protocol by proposing the candidates one after another.
///
/// The returned future first sends the multistream header. It then proposes
/// `protocols` in iteration order, moving on to the next candidate whenever
/// the listener answers `NotAvailable`.
pub fn dialer_select_proto_serial<R, I>(inner: R, protocols: I) -> DialerSelectSeq<R, I::IntoIter>
where
    R: AsyncRead + AsyncWrite,
    I: IntoIterator,
    I::Item: AsRef<[u8]>
{
    // Peekable: the future must know whether a proposal is the last one.
    let candidates = protocols.into_iter().peekable();
    let io = MessageIO::new(inner);
    DialerSelectSeq { protocols: candidates, state: SeqState::SendHeader { io } }
}
/// Negotiates a protocol by first asking the listener which ones it supports.
///
/// The returned future sends the multistream header and a `ListProtocols`
/// request. It then proposes the first of `protocols` that appears in the
/// listener's answer.
pub fn dialer_select_proto_parallel<R, I>(inner: R, protocols: I) -> DialerSelectPar<R, I::IntoIter>
where
    R: AsyncRead + AsyncWrite,
    I: IntoIterator,
    I::Item: AsRef<[u8]>
{
    let candidates = protocols.into_iter();
    let io = MessageIO::new(inner);
    DialerSelectPar { protocols: candidates, state: ParState::SendHeader { io } }
}
/// Future returned by [`dialer_select_proto_serial`].
pub struct DialerSelectSeq<R: AsyncRead + AsyncWrite, I: Iterator>
where
    I::Item: AsRef<[u8]>
{
    /// Candidates not yet proposed; peekable so the future can tell whether
    /// the protocol being proposed is the last one.
    protocols: iter::Peekable<I>,
    /// Current step of the negotiation state machine.
    state: SeqState<R, I::Item>,
}
/// States of the serial negotiation driven by [`DialerSelectSeq`].
enum SeqState<R: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
    /// The multistream header still has to be sent.
    SendHeader { io: MessageIO<R> },
    /// `protocol` still has to be sent as a proposal.
    SendProtocol { io: MessageIO<R>, protocol: N },
    /// `protocol` was sent; the sink still has to be flushed.
    FlushProtocol { io: MessageIO<R>, protocol: N },
    /// Waiting for the listener to confirm or reject `protocol`.
    AwaitProtocol { io: MessageIO<R>, protocol: N },
    /// Placeholder while a state is being processed, and the state left
    /// behind after completion or an error. Polling in this state panics.
    Done,
}
// Serial negotiation: send the header, then propose candidates one by one.
impl<R, I> Future for DialerSelectSeq<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
type Item = (I::Item, Negotiated<R>);
type Error = NegotiationError;
/// Advances the serial negotiation as far as the I/O stream currently allows.
///
/// Resolves to the selected protocol name together with a `Negotiated`
/// stream. For every candidate except the last, the proposal is flushed and
/// the listener's answer is awaited. The last candidate is not awaited:
/// the future resolves right after sending it, with a stream built by
/// `Negotiated::expecting`.
///
/// # Errors
///
/// - `NegotiationError::Failed` if there are no candidates or every one of
///   them was rejected.
/// - An I/O `UnexpectedEof` error if the stream ends while waiting for an
///   answer.
/// - `ProtocolError::InvalidMessage` on any unexpected message.
/// - Any error from the underlying sink/stream, or from converting a
///   candidate into a `Protocol`.
///
/// # Panics
///
/// If polled again after it resolved or returned an error.
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
// Move the current state out, leaving `Done` behind. Every arm either
// stores the follow-up state or returns. An early return (success or
// error) therefore leaves the state as `Done`.
match mem::replace(&mut self.state, SeqState::Done) {
SeqState::SendHeader { mut io } => {
if io.start_send(Message::Header(Version::V1))?.is_not_ready() {
// Sink not ready: restore the state so the next poll retries.
self.state = SeqState::SendHeader { io };
return Ok(Async::NotReady)
}
// No candidates at all: the negotiation fails right away.
let protocol = self.protocols.next().ok_or(NegotiationError::Failed)?;
self.state = SeqState::SendProtocol { io, protocol };
}
SeqState::SendProtocol { mut io, protocol } => {
let p = Protocol::try_from(protocol.as_ref())?;
if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() {
self.state = SeqState::SendProtocol { io, protocol };
return Ok(Async::NotReady)
}
debug!("Dialer: Proposed protocol: {}", p);
if self.protocols.peek().is_some() {
// More candidates remain: flush, then wait for the answer.
self.state = SeqState::FlushProtocol { io, protocol }
} else {
// Last candidate: do not wait for the answer here. The returned
// stream is built with `Negotiated::expecting(.., p)`.
// NOTE(review): presumably the confirmation is checked lazily by
// `Negotiated`; confirm there.
debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p);
return Ok(Async::Ready((protocol, io)))
}
}
SeqState::FlushProtocol { mut io, protocol } => {
if io.poll_complete()?.is_not_ready() {
self.state = SeqState::FlushProtocol { io, protocol };
return Ok(Async::NotReady)
}
self.state = SeqState::AwaitProtocol { io, protocol }
}
SeqState::AwaitProtocol { mut io, protocol } => {
let msg = match io.poll()? {
Async::NotReady => {
self.state = SeqState::AwaitProtocol { io, protocol };
return Ok(Async::NotReady)
}
// The stream ended before the listener answered.
Async::Ready(None) =>
return Err(NegotiationError::from(
io::Error::from(io::ErrorKind::UnexpectedEof))),
Async::Ready(Some(msg)) => msg,
};
match msg {
// A header from the listener is tolerated; keep waiting.
Message::Header(Version::V1) => {
self.state = SeqState::AwaitProtocol { io, protocol };
}
// The listener echoed our proposal: negotiation complete. Any
// `remaining` data returned by `into_inner` is handed to
// `Negotiated::completed` (presumably bytes already buffered).
Message::Protocol(ref p) if p.as_ref() == protocol.as_ref() => {
debug!("Dialer: Received confirmation for protocol: {}", p);
let (io, remaining) = io.into_inner();
let io = Negotiated::completed(io, remaining);
return Ok(Async::Ready((protocol, io)))
}
// Proposal rejected: try the next candidate, or fail if none is left.
Message::NotAvailable => {
debug!("Dialer: Received rejection of protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
let protocol = self.protocols.next()
.ok_or(NegotiationError::Failed)?;
self.state = SeqState::SendProtocol { io, protocol }
}
// Anything else (including a different protocol) is a violation.
_ => return Err(ProtocolError::InvalidMessage.into())
}
}
SeqState::Done => panic!("SeqState::poll called after completion")
}
}
}
}
/// Future returned by [`dialer_select_proto_parallel`].
pub struct DialerSelectPar<R: AsyncRead + AsyncWrite, I: Iterator>
where
    I::Item: AsRef<[u8]>
{
    /// Candidate protocols, searched once the listener's list has arrived.
    protocols: I,
    /// Current step of the negotiation state machine.
    state: ParState<R, I::Item>,
}
/// States of the parallel negotiation driven by [`DialerSelectPar`].
enum ParState<R: AsyncRead + AsyncWrite, N: AsRef<[u8]>> {
    /// The multistream header still has to be sent.
    SendHeader { io: MessageIO<R> },
    /// The `ListProtocols` request still has to be sent.
    SendProtocolsRequest { io: MessageIO<R> },
    /// Header and request were sent; the sink still has to be flushed.
    Flush { io: MessageIO<R> },
    /// Waiting for the listener's list of supported protocols.
    RecvProtocols { io: MessageIO<R> },
    /// The selected `protocol` still has to be sent.
    SendProtocol { io: MessageIO<R>, protocol: N },
    /// Placeholder while a state is being processed, and the state left
    /// behind after completion or an error. Polling in this state panics.
    Done,
}
// Parallel negotiation: ask for the listener's protocols, then pick one.
impl<R, I> Future for DialerSelectPar<R, I>
where
R: AsyncRead + AsyncWrite,
I: Iterator,
I::Item: AsRef<[u8]>
{
type Item = (I::Item, Negotiated<R>);
type Error = NegotiationError;
/// Advances the parallel negotiation as far as the I/O stream currently allows.
///
/// The negotiation proceeds in these steps:
/// 1. Send the header and a `ListProtocols` request, then flush.
/// 2. Read the listener's `Protocols` answer.
/// 3. Send the first candidate that the listener lists.
///
/// After sending, the future resolves with a stream built by
/// `Negotiated::expecting`; the confirmation is not awaited here.
///
/// # Errors
///
/// - `NegotiationError::Failed` if no candidate is in the listener's list.
/// - An I/O `UnexpectedEof` error if the stream ends while waiting for the
///   list.
/// - `ProtocolError::InvalidMessage` on any unexpected message.
/// - Any error from the underlying sink/stream, or from converting the
///   chosen candidate into a `Protocol`.
///
/// # Panics
///
/// If polled again after it resolved or returned an error.
fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
loop {
// Move the current state out, leaving `Done` behind. Every arm either
// stores the follow-up state or returns. An early return (success or
// error) therefore leaves the state as `Done`.
match mem::replace(&mut self.state, ParState::Done) {
ParState::SendHeader { mut io } => {
if io.start_send(Message::Header(Version::V1))?.is_not_ready() {
// Sink not ready: restore the state so the next poll retries.
self.state = ParState::SendHeader { io };
return Ok(Async::NotReady)
}
self.state = ParState::SendProtocolsRequest { io };
}
ParState::SendProtocolsRequest { mut io } => {
if io.start_send(Message::ListProtocols)?.is_not_ready() {
self.state = ParState::SendProtocolsRequest { io };
return Ok(Async::NotReady)
}
debug!("Dialer: Requested supported protocols.");
self.state = ParState::Flush { io }
}
ParState::Flush { mut io } => {
if io.poll_complete()?.is_not_ready() {
self.state = ParState::Flush { io };
return Ok(Async::NotReady)
}
self.state = ParState::RecvProtocols { io }
}
ParState::RecvProtocols { mut io } => {
let msg = match io.poll()? {
Async::NotReady => {
self.state = ParState::RecvProtocols { io };
return Ok(Async::NotReady)
}
// The stream ended before the listener answered.
Async::Ready(None) =>
return Err(NegotiationError::from(
io::Error::from(io::ErrorKind::UnexpectedEof))),
Async::Ready(Some(msg)) => msg,
};
match &msg {
// A header from the listener is tolerated; keep waiting.
Message::Header(Version::V1) => {
self.state = ParState::RecvProtocols { io }
}
// Pick the first candidate (in our iteration order) that the
// listener listed. `by_ref` keeps the iterator in `self`;
// candidates up to and including the match are consumed.
Message::Protocols(supported) => {
let protocol = self.protocols.by_ref()
.find(|p| supported.iter().any(|s|
s.as_ref() == p.as_ref()))
.ok_or(NegotiationError::Failed)?;
debug!("Dialer: Found supported protocol: {}",
String::from_utf8_lossy(protocol.as_ref()));
self.state = ParState::SendProtocol { io, protocol };
}
_ => return Err(ProtocolError::InvalidMessage.into())
}
}
ParState::SendProtocol { mut io, protocol } => {
let p = Protocol::try_from(protocol.as_ref())?;
if io.start_send(Message::Protocol(p.clone()))?.is_not_ready() {
self.state = ParState::SendProtocol { io, protocol };
return Ok(Async::NotReady)
}
// Resolve without flushing or waiting for the confirmation; the
// returned stream is built with `Negotiated::expecting(.., p)`.
// NOTE(review): presumably `Negotiated` flushes and verifies the
// confirmation later; confirm there.
debug!("Dialer: Expecting proposed protocol: {}", p);
let io = Negotiated::expecting(io.into_reader(), p);
return Ok(Async::Ready((protocol, io)))
}
ParState::Done => panic!("ParState::poll called after completion")
}
}
}
}