use futures::{AsyncRead, AsyncWrite, Sink, Stream};
use crate::{
Negotiated, NegotiationError,
protocol::{Message, MessageIO, Protocol},
};
use std::{
iter, mem,
pin::Pin,
task::{Context, Poll},
};
/// Future that drives the dialer side of a protocol negotiation over an I/O stream.
///
/// Created with [`DialerSelectFuture::new`]; the negotiation itself happens in the
/// `Future` implementation.
#[pin_project::pin_project]
pub struct DialerSelectFuture<R, I>
where
    I: Iterator,
{
    /// Candidate protocols that have not been proposed yet. Peekable so that `poll`
    /// can tell whether the proposal currently being sent is the last one.
    protocols: iter::Peekable<I>,
    /// Current step of the negotiation state machine.
    state: State<R, I::Item>,
    /// When `true` and the proposal being sent is the last candidate, `poll` resolves
    /// right after queueing the proposal instead of waiting for the remote's answer.
    lazy: bool,
}
impl<R, I> DialerSelectFuture<R, I>
where
    R: AsyncRead + AsyncWrite,
    I: Iterator,
    I::Item: AsRef<str>,
{
    /// Creates a negotiation future that proposes `protocols` over `io`.
    ///
    /// The candidates are proposed one at a time, in iteration order. The returned
    /// future starts in the initial state with lazy negotiation disabled; nothing is
    /// read from or written to `io` until the future is polled.
    pub fn new(io: R, protocols: I) -> Self {
        let protocols = protocols.peekable();
        let io = MessageIO::new(io);

        Self {
            protocols,
            state: State::Initial { io },
            lazy: false,
        }
    }
}
/// Steps of the dialer negotiation state machine.
///
/// `R` is the underlying I/O resource and `P` the caller's protocol type.
enum State<R, P> {
    /// Nothing has been proposed yet; the next candidate still has to be picked.
    Initial {
        io: MessageIO<R>,
    },
    /// `protocol` has been picked and must be queued on the sink.
    SendProtocol {
        io: MessageIO<R>,
        protocol: P,
    },
    /// The proposal for `protocol` has been queued and must be flushed.
    FlushProtocol {
        io: MessageIO<R>,
        protocol: P,
    },
    /// The proposal for `protocol` has been flushed; waiting for the remote's reply.
    AwaitProtocol {
        io: MessageIO<R>,
        protocol: P,
    },
    /// Placeholder left behind while a step runs and after the future has finished.
    Done,
}
impl<R, I> Future for DialerSelectFuture<R, I>
where
R: AsyncRead + AsyncWrite + Unpin,
I: Iterator,
I::Item: AsRef<str>,
{
type Output = Result<(I::Item, Negotiated<R>), NegotiationError>;
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();
loop {
match mem::replace(this.state, State::Done) {
State::Initial { mut io } => {
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {}
Poll::Pending => {
*this.state = State::Initial { io };
return Poll::Pending;
}
};
let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
*this.state = State::SendProtocol { io, protocol };
}
State::SendProtocol { mut io, protocol } => {
tracing::trace!("Sending protocol: {}", protocol.as_ref());
match Pin::new(&mut io).poll_ready(cx)? {
Poll::Ready(()) => {}
Poll::Pending => {
*this.state = State::SendProtocol { io, protocol };
return Poll::Pending;
}
};
let p = Protocol::try_from(protocol.as_ref())?;
if let Err(err) = Pin::new(&mut io).start_send(Message::Protocol(p.clone())) {
return Poll::Ready(Err(From::from(err)));
}
if this.protocols.peek().is_some() {
*this.state = State::FlushProtocol { io, protocol };
} else if *this.lazy {
tracing::trace!("Expecting protocol: {}", p.as_ref());
let io = Negotiated::expecting(io.into_reader(), p);
return Poll::Ready(Ok((protocol, io)));
} else {
*this.state = State::FlushProtocol { io, protocol };
}
}
State::FlushProtocol { mut io, protocol } => {
match Pin::new(&mut io).poll_flush(cx)? {
Poll::Ready(()) => {}
Poll::Pending => {
*this.state = State::FlushProtocol { io, protocol };
return Poll::Pending;
}
};
*this.state = State::AwaitProtocol { io, protocol };
}
State::AwaitProtocol { mut io, protocol } => {
let msg = match Pin::new(&mut io).poll_next(cx)? {
Poll::Ready(Some(msg)) => msg,
Poll::Ready(None) => {
tracing::debug!("No message received, connection closed");
return Poll::Ready(Err(NegotiationError::Failed));
}
Poll::Pending => {
*this.state = State::AwaitProtocol { io, protocol };
return Poll::Pending;
}
};
match msg {
Message::Protocol(p) if p.as_ref() == protocol.as_ref() => {
let io = Negotiated::completed(io.into_inner());
return Poll::Ready(Ok((protocol, io)));
}
Message::NotAvailable => {
tracing::debug!("Protocol not available, trying next protocol");
let protocol = this.protocols.next().ok_or(NegotiationError::Failed)?;
*this.state = State::SendProtocol { io, protocol }
}
_ => {
*this.state = State::Initial { io };
}
}
}
_ => panic!("Unexpected state in DialerSelectFuture"),
}
}
}
}