use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError};
use bytes::{BytesMut, Buf};
use futures::{prelude::*, io::{IoSlice, IoSliceMut}, ready};
use pin_project::{pin_project, project};
use std::{error::Error, fmt, io, mem, pin::Pin, task::{Context, Poll}};
/// An I/O stream of type `TInner` wrapped together with the state of a
/// protocol negotiation.
///
/// The `AsyncRead` and `AsyncWrite` implementations in this file operate on
/// this wrapper: reading first drives any outstanding negotiation via
/// `Negotiated::poll`, while writing is forwarded to whichever I/O object the
/// current state holds.
#[pin_project]
#[derive(Debug)]
pub struct Negotiated<TInner> {
// Current negotiation state; pinned because it owns the underlying I/O.
#[pin]
state: State<TInner>
}
/// A `Future` that resolves to a `Negotiated` once `Negotiated::poll` reports
/// that negotiation has completed. Created by `Negotiated::complete`.
#[derive(Debug)]
pub struct NegotiatedComplete<TInner> {
// Held in an `Option` so the value can be moved out when the future resolves
// successfully.
inner: Option<Negotiated<TInner>>,
}
impl<TInner> Future for NegotiatedComplete<TInner>
where
    // `Unpin` is needed because `poll` reaches `self.inner` through the pin and
    // re-pins the wrapped `Negotiated` with `Pin::new`.
    TInner: AsyncRead + AsyncWrite + Unpin,
{
    type Output = Result<Negotiated<TInner>, NegotiationError>;

    /// Drives the wrapped `Negotiated` and hands it out once negotiation is done.
    ///
    /// While the negotiation is pending or has failed, the `Negotiated` is put
    /// back into `self.inner`; only on success is it moved into the output.
    ///
    /// # Panics
    ///
    /// Panics if polled again after having resolved successfully.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut negotiated = self.inner.take()
            .expect("NegotiatedFuture called after completion.");

        // Success is the only outcome that gives the stream away.
        let unfinished = match Negotiated::poll(Pin::new(&mut negotiated), cx) {
            Poll::Ready(Ok(())) => return Poll::Ready(Ok(negotiated)),
            Poll::Ready(Err(err)) => Poll::Ready(Err(err)),
            Poll::Pending => Poll::Pending,
        };

        // Pending or failed: retain the stream before reporting the outcome.
        self.inner = Some(negotiated);
        unfinished
    }
}
impl<TInner> Negotiated<TInner> {
    /// Creates a `Negotiated` whose negotiation is already complete.
    ///
    /// `remaining` holds bytes that the write paths of this type send to `io`
    /// ahead of any caller-supplied data.
    pub(crate) fn completed(io: TInner, remaining: BytesMut) -> Self {
        Negotiated { state: State::Completed { io, remaining } }
    }

    /// Creates a `Negotiated` that still expects messages from the remote:
    /// a header carrying `version` and/or the confirmation of `protocol`.
    pub(crate) fn expecting(io: MessageReader<TInner>, protocol: Protocol, version: Version) -> Self {
        Negotiated { state: State::Expecting { io, protocol, version } }
    }

    /// Polls the `Negotiated` for completion of the protocol negotiation.
    ///
    /// Returns `Poll::Ready(Ok(()))` once the state is `State::Completed`,
    /// `Poll::Pending` while flushing or reading cannot make progress, and an
    /// error if flushing fails (other than with `WriteZero`), the message
    /// stream ends or fails, or an unexpected message is received
    /// (`NegotiationError::Failed`).
    ///
    /// # Panics
    ///
    /// Panics if the state is `State::Invalid`. Note that the error paths of
    /// the read loop leave the state as `State::Invalid`, so the value must
    /// not be polled or written to again after such an error.
    #[project]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), NegotiationError>>
    where
        TInner: AsyncRead + AsyncWrite + Unpin
    {
        // Flush any pending negotiation data first.
        match self.as_mut().poll_flush(cx) {
            Poll::Ready(Ok(())) => {},
            Poll::Pending => return Poll::Pending,
            Poll::Ready(Err(e)) => {
                // A `WriteZero` failure is deliberately tolerated so that the
                // read loop below still runs (presumably so data already sent
                // by a remote that stopped accepting writes can be consumed).
                if e.kind() != io::ErrorKind::WriteZero {
                    return Poll::Ready(Err(e.into()))
                }
            }
        }

        let mut this = self.project();

        #[project]
        match this.state.as_mut().project() {
            State::Completed { remaining, .. } => {
                // The flush above has written out everything in `remaining`.
                debug_assert!(remaining.is_empty());
                return Poll::Ready(Ok(()))
            }
            _ => {}
        }

        // Read outstanding negotiation messages.
        loop {
            // Take ownership of the state; `State::Invalid` is only a
            // placeholder and must be replaced again before looping or
            // returning `Pending`.
            match mem::replace(&mut *this.state, State::Invalid) {
                State::Expecting { mut io, protocol, version } => {
                    let msg = match Pin::new(&mut io).poll_next(cx)? {
                        Poll::Ready(Some(msg)) => msg,
                        Poll::Pending => {
                            *this.state = State::Expecting { io, protocol, version };
                            return Poll::Pending
                        },
                        Poll::Ready(None) => {
                            return Poll::Ready(Err(ProtocolError::IoError(
                                io::ErrorKind::UnexpectedEof.into()).into()));
                        }
                    };

                    if let Message::Header(v) = &msg {
                        if *v == version {
                            // Restore the state before reading the next
                            // message; otherwise the next iteration would find
                            // `State::Invalid` and panic.
                            *this.state = State::Expecting { io, protocol, version };
                            continue
                        }
                    }

                    if let Message::Protocol(p) = &msg {
                        if p.as_ref() == protocol.as_ref() {
                            log::debug!("Negotiated: Received confirmation for protocol: {}", p);
                            let (io, remaining) = io.into_inner();
                            *this.state = State::Completed { io, remaining };
                            return Poll::Ready(Ok(()));
                        }
                    }

                    // Neither the expected header nor the expected protocol.
                    return Poll::Ready(Err(NegotiationError::Failed));
                }

                _ => panic!("Negotiated: Invalid state")
            }
        }
    }

    /// Returns a `NegotiatedComplete` future that resolves to this
    /// `Negotiated` once the negotiation has completed.
    pub fn complete(self) -> NegotiatedComplete<TInner> {
        NegotiatedComplete { inner: Some(self) }
    }
}
/// The states of a `Negotiated` I/O stream.
#[pin_project]
#[derive(Debug)]
enum State<R> {
// Negotiation is still in progress: messages are read from `io` until the
// confirmation for `protocol` arrives. A header message carrying `version`
// is accepted and skipped along the way.
Expecting {
#[pin]
io: MessageReader<R>,
protocol: Protocol,
version: Version
},
// Negotiation is complete. `remaining` holds bytes that the write paths
// send to `io` before any caller-supplied data.
Completed { #[pin] io: R, remaining: BytesMut },
// Placeholder left behind by `mem::replace` while `Negotiated::poll` owns
// the real state; observing it elsewhere results in a panic.
Invalid,
}
impl<TInner> AsyncRead for Negotiated<TInner>
where
TInner: AsyncRead + AsyncWrite + Unpin
{
#[project]
fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8])
-> Poll<Result<usize, io::Error>>
{
loop {
#[project]
match self.as_mut().project().state.project() {
State::Completed { io, remaining } => {
if remaining.is_empty() {
return io.poll_read(cx, buf)
}
},
_ => {}
}
match self.as_mut().poll(cx) {
Poll::Ready(Ok(())) => {},
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
}
}
#[project]
fn poll_read_vectored(mut self: Pin<&mut Self>, cx: &mut Context, bufs: &mut [IoSliceMut])
-> Poll<Result<usize, io::Error>>
{
loop {
#[project]
match self.as_mut().project().state.project() {
State::Completed { io, remaining } => {
if remaining.is_empty() {
return io.poll_read_vectored(cx, bufs)
}
},
_ => {}
}
match self.as_mut().poll(cx) {
Poll::Ready(Ok(())) => {},
Poll::Pending => return Poll::Pending,
Poll::Ready(Err(err)) => return Poll::Ready(Err(From::from(err))),
}
}
}
}
impl<TInner> AsyncWrite for Negotiated<TInner>
where
TInner: AsyncWrite + AsyncRead + Unpin
{
#[project]
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
#[project]
match self.project().state.project() {
State::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_write(cx, buf)
},
State::Expecting { io, .. } => io.poll_write(cx, buf),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
#[project]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
#[project]
match self.project().state.project() {
State::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_flush(cx)
},
State::Expecting { io, .. } => io.poll_flush(cx),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
#[project]
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
ready!(self.as_mut().poll(cx).map_err(Into::<io::Error>::into)?);
ready!(self.as_mut().poll_flush(cx).map_err(Into::<io::Error>::into)?);
#[project]
match self.project().state.project() {
State::Completed { io, .. } => io.poll_close(cx),
State::Expecting { io, .. } => io.poll_close(cx),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
#[project]
fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context, bufs: &[IoSlice])
-> Poll<Result<usize, io::Error>>
{
#[project]
match self.project().state.project() {
State::Completed { mut io, remaining } => {
while !remaining.is_empty() {
let n = ready!(io.as_mut().poll_write(cx, &remaining)?);
if n == 0 {
return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
}
remaining.advance(n);
}
io.poll_write_vectored(cx, bufs)
},
State::Expecting { io, .. } => io.poll_write_vectored(cx, bufs),
State::Invalid => panic!("Negotiated: Invalid state"),
}
}
}
/// Error that can happen when negotiating a protocol with the remote.
#[derive(Debug)]
pub enum NegotiationError {
// A lower-level protocol error, including I/O errors converted via
// `ProtocolError`.
ProtocolError(ProtocolError),
// A message was received that was neither the expected header nor the
// confirmation of the expected protocol.
Failed,
}
impl From<ProtocolError> for NegotiationError {
fn from(err: ProtocolError) -> NegotiationError {
NegotiationError::ProtocolError(err)
}
}
impl From<io::Error> for NegotiationError {
fn from(err: io::Error) -> NegotiationError {
ProtocolError::from(err).into()
}
}
impl From<NegotiationError> for io::Error {
    /// Converts a `NegotiationError` into an `io::Error`.
    ///
    /// A wrapped `ProtocolError` is converted through its own conversion into
    /// `io::Error`; any other error is boxed into an `io::Error` of kind
    /// `Other`.
    fn from(err: NegotiationError) -> io::Error {
        match err {
            NegotiationError::ProtocolError(inner) => inner.into(),
            other => io::Error::new(io::ErrorKind::Other, other),
        }
    }
}
impl Error for NegotiationError {
    /// Returns the wrapped `ProtocolError` as the underlying cause, if any.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let NegotiationError::ProtocolError(cause) = self {
            Some(cause)
        } else {
            None
        }
    }
}
impl fmt::Display for NegotiationError {
    /// Writes a human-readable description of the error.
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
        match self {
            NegotiationError::Failed =>
                fmt.write_str("Protocol negotiation failed."),
            NegotiationError::ProtocolError(cause) =>
                write!(fmt, "Protocol error: {}", cause),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;
    use std::{io::Write, task::Poll};

    /// An I/O sink that accepts at most `step` bytes per write and fails with
    /// `WriteZero` when a write would exceed the capacity of `buf`.
    struct Capped { buf: Vec<u8>, step: usize }

    impl AsyncRead for Capped {
        fn poll_read(self: Pin<&mut Self>, _: &mut Context, _: &mut [u8]) -> Poll<Result<usize, io::Error>> {
            // The test below only ever writes.
            unreachable!()
        }
    }

    impl AsyncWrite for Capped {
        fn poll_write(mut self: Pin<&mut Self>, _: &mut Context, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
            let exceeds_capacity = self.buf.len() + buf.len() > self.buf.capacity();
            if exceeds_capacity {
                return Poll::Ready(Err(io::ErrorKind::WriteZero.into()))
            }
            // Accept no more than `step` bytes in a single write.
            let chunk = &buf[.. usize::min(self.step, buf.len())];
            let n = Write::write(&mut self.buf, chunk).unwrap();
            Poll::Ready(Ok(n))
        }

        fn poll_flush(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }

        fn poll_close(self: Pin<&mut Self>, _: &mut Context) -> Poll<Result<(), io::Error>> {
            Poll::Ready(Ok(()))
        }
    }

    /// Checks that writing to a completed `Negotiated` first drains the
    /// `remaining` bytes, and that running over the sink's capacity surfaces
    /// as `WriteZero`.
    #[test]
    fn write_remaining() {
        fn prop(rem: Vec<u8>, new: Vec<u8>, free: u8, step: u8) -> TestResult {
            let cap = rem.len() + free as usize;
            let step = u8::min(free, step) as usize + 1;
            let sink = Capped { buf: Vec::with_capacity(cap), step };
            let rem = BytesMut::from(&rem[..]);
            let mut io = Negotiated::completed(sink, rem.clone());
            let mut written = 0;
            loop {
                // Keep writing until all of `new` has been accepted or the
                // sink reports `WriteZero`.
                let attempt = future::poll_fn(|cx| Pin::new(&mut io).poll_write(cx, &new[written..]))
                    .now_or_never()
                    .unwrap();
                match attempt {
                    Ok(n) => match &io.state {
                        State::Completed { remaining, .. } => {
                            assert!(remaining.is_empty());
                            written += n;
                            if written == new.len() {
                                return TestResult::passed()
                            }
                        }
                        _ => return TestResult::failed(),
                    },
                    Err(e) if e.kind() == io::ErrorKind::WriteZero => {
                        return match &io.state {
                            State::Completed { .. } => {
                                assert!(rem.len() + new.len() > cap);
                                TestResult::passed()
                            }
                            _ => TestResult::failed(),
                        }
                    }
                    Err(e) => panic!("Unexpected error: {:?}", e),
                }
            }
        }
        quickcheck(prop as fn(_,_,_,_) -> _)
    }
}