use bytes::BytesMut;
use crate::protocol::{Protocol, MessageReader, Message, Version, ProtocolError};
use futures::{prelude::*, Async, try_ready};
use log::debug;
use tokio_io::{AsyncRead, AsyncWrite};
use std::{mem, io, fmt, error::Error};
/// An I/O stream for which a protocol has been chosen, but whose confirmation
/// by the remote may still be outstanding.
///
/// While waiting for the confirmation, writes are passed straight through to
/// the underlying stream. Reading, shutting down, or awaiting
/// [`Negotiated::complete`] drives negotiation until the confirmation has been
/// received. After that the stream behaves like the underlying `TInner`.
pub struct Negotiated<TInner> {
    state: State<TInner>
}
/// Future returned by [`Negotiated::complete`]. It resolves to the
/// [`Negotiated`] stream once protocol negotiation has finished.
pub struct NegotiatedComplete<TInner> {
    // `None` after the future has resolved or returned an error; polling
    // again in that case panics.
    inner: Option<Negotiated<TInner>>
}
impl<TInner: AsyncRead + AsyncWrite> Future for NegotiatedComplete<TInner> {
    type Item = Negotiated<TInner>;
    type Error = NegotiationError;

    /// Drives negotiation of the wrapped stream and yields the stream once
    /// negotiation has finished.
    ///
    /// # Panics
    ///
    /// Panics if polled again after it has resolved or returned an error.
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        let mut negotiated = self.inner.take()
            .expect("NegotiatedFuture called after completion.");
        match negotiated.poll()? {
            Async::Ready(()) => Ok(Async::Ready(negotiated)),
            Async::NotReady => {
                // Not done yet: hand the stream back for the next call.
                self.inner = Some(negotiated);
                Ok(Async::NotReady)
            }
        }
    }
}
impl<TInner> Negotiated<TInner> {
    /// Creates a stream whose protocol negotiation has already finished.
    ///
    /// `remaining` holds bytes that still have to be written to `io` before
    /// any further data.
    pub(crate) fn completed(io: TInner, remaining: BytesMut) -> Self {
        let state = State::Completed { io, remaining };
        Negotiated { state }
    }

    /// Creates a stream that still waits for the remote to confirm
    /// `protocol`, reading negotiation messages through `io`.
    pub(crate) fn expecting(io: MessageReader<TInner>, protocol: Protocol) -> Self {
        let state = State::Expecting { io, protocol };
        Negotiated { state }
    }

    /// Drives protocol negotiation towards completion.
    ///
    /// Steps:
    /// - Pending writes are flushed first. A `WriteZero` failure is not fatal
    ///   here, so incoming negotiation data can still be read.
    /// - A `Completed` stream then drops whatever is still buffered and
    ///   reports `Ready`.
    /// - An `Expecting` stream reads negotiation messages. `V1` headers are
    ///   skipped, and a confirmation of the expected protocol completes
    ///   negotiation. Any other message fails it.
    ///
    /// # Errors
    ///
    /// Returns `NegotiationError::Failed` on an unexpected message. The
    /// stream is then left in the invalid state. Errors from flushing or
    /// reading are converted into `NegotiationError`, and so is an end of
    /// stream before the confirmation.
    ///
    /// # Panics
    ///
    /// Panics if the stream is in the invalid state.
    fn poll(&mut self) -> Poll<(), NegotiationError>
    where
        TInner: AsyncRead + AsyncWrite
    {
        match self.poll_flush() {
            Ok(Async::NotReady) => return Ok(Async::NotReady),
            Ok(Async::Ready(())) => {}
            // Keep going on `WriteZero` so that data sent by the remote is
            // still read below.
            Err(ref e) if e.kind() == io::ErrorKind::WriteZero => {}
            Err(e) => return Err(e.into()),
        }

        if let State::Completed { remaining, .. } = &mut self.state {
            // Already empty after a successful flush. After a tolerated
            // `WriteZero` it holds bytes that could not be sent; drop them.
            remaining.clear();
            return Ok(Async::Ready(()))
        }

        loop {
            let (mut reader, protocol) = match mem::replace(&mut self.state, State::Invalid) {
                State::Expecting { io, protocol } => (io, protocol),
                _ => panic!("Negotiated: Invalid state"),
            };

            let msg = match reader.poll() {
                Ok(Async::Ready(Some(msg))) => msg,
                Ok(Async::NotReady) => {
                    self.state = State::Expecting { io: reader, protocol };
                    return Ok(Async::NotReady)
                }
                Ok(Async::Ready(None)) => {
                    // End of stream before the confirmation arrived.
                    self.state = State::Expecting { io: reader, protocol };
                    return Err(ProtocolError::IoError(
                        io::ErrorKind::UnexpectedEof.into()).into())
                }
                Err(err) => {
                    self.state = State::Expecting { io: reader, protocol };
                    return Err(err.into())
                }
            };

            match &msg {
                Message::Header(Version::V1) => {
                    // Skip the header and keep waiting for the confirmation.
                    self.state = State::Expecting { io: reader, protocol };
                }
                Message::Protocol(p) if p.as_ref() == protocol.as_ref() => {
                    debug!("Negotiated: Received confirmation for protocol: {}", p);
                    let (io, remaining) = reader.into_inner();
                    self.state = State::Completed { io, remaining };
                    return Ok(Async::Ready(()))
                }
                _ => return Err(NegotiationError::Failed),
            }
        }
    }

    /// Returns a future that resolves to this stream once protocol
    /// negotiation has finished.
    pub fn complete(self) -> NegotiatedComplete<TInner> {
        let inner = Some(self);
        NegotiatedComplete { inner }
    }
}
/// Internal state of a [`Negotiated`] stream.
enum State<R> {
    /// Waiting for the remote to confirm `protocol`. Negotiation messages are
    /// read through `io`.
    Expecting { io: MessageReader<R>, protocol: Protocol },
    /// Negotiation has finished. `remaining` holds bytes not yet written to
    /// `io`; they are written before any further data.
    Completed { io: R, remaining: BytesMut },
    /// Placeholder used while the state is moved out during polling. It is
    /// also left behind when negotiation fails; operations on the stream
    /// panic when they encounter it.
    Invalid,
}
impl<R> io::Read for Negotiated<R>
where
    R: AsyncRead + AsyncWrite
{
    /// Reads from the underlying stream once negotiation has finished and no
    /// buffered outgoing data is left.
    ///
    /// Until then, each iteration drives negotiation via `poll`. If that
    /// cannot make progress, `WouldBlock` is returned. Negotiation errors are
    /// converted into `io::Error`s.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            match &mut self.state {
                State::Completed { io, remaining } if remaining.is_empty() => {
                    return io.read(buf)
                }
                _ => {}
            }

            match self.poll() {
                // Progress was made; check again whether reading can start.
                Ok(Async::Ready(())) => {}
                Ok(Async::NotReady) => return Err(io::ErrorKind::WouldBlock.into()),
                Err(err) => return Err(err.into()),
            }
        }
    }
}
impl<TInner> AsyncRead for Negotiated<TInner>
where
    TInner: AsyncRead + AsyncWrite
{
    /// Delegates to the underlying stream. While negotiation is pending, the
    /// stream inside the `MessageReader` is used.
    ///
    /// # Safety
    ///
    /// Same contract as `AsyncRead::prepare_uninitialized_buffer` of the
    /// underlying stream.
    ///
    /// # Panics
    ///
    /// Panics if the stream is in the invalid state.
    unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
        match self.state {
            State::Completed { ref io, .. } => io.prepare_uninitialized_buffer(buf),
            State::Expecting { ref io, .. } => {
                let inner = io.inner_ref();
                inner.prepare_uninitialized_buffer(buf)
            }
            State::Invalid => panic!("Negotiated: Invalid state"),
        }
    }
}
impl<TInner> io::Write for Negotiated<TInner>
where
    TInner: AsyncWrite
{
    /// Writes `buf` to the underlying stream.
    ///
    /// This is a plain write to the underlying stream in two cases: while
    /// negotiation is pending (`Expecting`), and once it has finished with
    /// nothing left in `remaining`.
    ///
    /// If buffered data remains, `buf` is appended to it and the combined
    /// buffer is handed to a single underlying `write` call, so old and new
    /// data can go out together. Bytes of `buf` that are either sent or kept
    /// in the buffer count as written. The buffer never grows beyond its
    /// length before the call.
    ///
    /// # Errors
    ///
    /// Propagates errors of the underlying stream. In that case the buffer is
    /// restored to exactly its previous contents.
    ///
    /// # Panics
    ///
    /// Panics if the stream is in the invalid state left behind by a failed
    /// negotiation.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        match &mut self.state {
            State::Completed { io, ref mut remaining } => {
                if remaining.is_empty() {
                    return io.write(buf)
                }
                let remaining_len = remaining.len();
                remaining.extend_from_slice(buf);
                match io.write(&remaining) {
                    Err(e) => {
                        // Drop the appended copy of `buf` again. `truncate` avoids
                        // creating (and immediately discarding) the tail handle
                        // that `split_off` would return.
                        remaining.truncate(remaining_len);
                        debug_assert_eq!(remaining.len(), remaining_len);
                        Err(e)
                    }
                    Ok(n) => {
                        // The first `n` bytes reached the underlying stream.
                        remaining.split_to(n);
                        if remaining.is_empty() {
                            return Ok(buf.len())
                        }
                        let written = if n < buf.len() {
                            // More than `remaining_len` bytes are left. Keep only the
                            // first `remaining_len` of them: the unsent old data plus
                            // a prefix of `buf`. Exactly the first `n` bytes of `buf`
                            // are then either sent or buffered.
                            remaining.truncate(remaining_len);
                            n
                        } else {
                            // At most `remaining_len` bytes are left, so all of `buf`
                            // is either sent or buffered.
                            buf.len()
                        };
                        debug_assert!(remaining.len() <= remaining_len);
                        Ok(written)
                    }
                }
            },
            State::Expecting { io, .. } => io.write(buf),
            State::Invalid => panic!("Negotiated: Invalid state")
        }
    }

    /// Writes out all buffered data, then flushes the underlying stream.
    ///
    /// # Errors
    ///
    /// Returns a `WriteZero` error if the underlying stream accepts no bytes
    /// while buffered data remains. Any other error from the underlying
    /// stream is propagated unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the stream is in the invalid state.
    fn flush(&mut self) -> io::Result<()> {
        match &mut self.state {
            State::Completed { io, ref mut remaining } => {
                while !remaining.is_empty() {
                    let n = io.write(remaining)?;
                    if n == 0 {
                        return Err(io::Error::new(
                            io::ErrorKind::WriteZero,
                            "Failed to write remaining buffer."))
                    }
                    // Drop the bytes that were just written.
                    remaining.split_to(n);
                }
                io.flush()
            },
            State::Expecting { io, .. } => io.flush(),
            State::Invalid => panic!("Negotiated: Invalid state")
        }
    }
}
impl<TInner> AsyncWrite for Negotiated<TInner>
where
TInner: AsyncWrite + AsyncRead
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
try_ready!(self.poll().map_err(Into::<io::Error>::into));
match &mut self.state {
State::Completed { io, .. } => io.shutdown(),
State::Expecting { io, .. } => io.shutdown(),
State::Invalid => panic!("Negotiated: Invalid state")
}
}
}
/// Error that can occur while negotiating a protocol.
#[derive(Debug)]
pub enum NegotiationError {
    /// Flushing or reading negotiation messages failed. I/O errors are also
    /// reported through this variant.
    ProtocolError(ProtocolError),
    /// The remote sent something other than a `V1` header or a confirmation
    /// of the expected protocol.
    Failed,
}
impl From<ProtocolError> for NegotiationError {
fn from(err: ProtocolError) -> NegotiationError {
NegotiationError::ProtocolError(err)
}
}
impl From<io::Error> for NegotiationError {
fn from(err: io::Error) -> NegotiationError {
ProtocolError::from(err).into()
}
}
impl Into<io::Error> for NegotiationError {
fn into(self) -> io::Error {
if let NegotiationError::ProtocolError(e) = self {
return e.into()
}
io::Error::new(io::ErrorKind::Other, self)
}
}
impl Error for NegotiationError {
    /// Exposes the wrapped protocol error as the underlying cause. A plain
    /// negotiation failure has no further cause.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let NegotiationError::ProtocolError(cause) = self {
            Some(cause)
        } else {
            None
        }
    }
}
impl fmt::Display for NegotiationError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
write!(fmt, "{}", Error::description(self))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;
    use std::io::Write;

    /// Test sink that takes at most `step` bytes per `write` call. It fails
    /// with `WriteZero` when a write would exceed the capacity of `buf`.
    struct Capped { buf: Vec<u8>, step: usize }

    impl io::Write for Capped {
        fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
            // Reject the whole write if it does not fit into the remaining capacity.
            if self.buf.len() + buf.len() > self.buf.capacity() {
                return Err(io::ErrorKind::WriteZero.into())
            }
            self.buf.write(&buf[.. usize::min(self.step, buf.len())])
        }

        fn flush(&mut self) -> io::Result<()> {
            Ok(())
        }
    }

    impl AsyncWrite for Capped {
        fn shutdown(&mut self) -> Poll<(), io::Error> {
            Ok(().into())
        }
    }

    /// Property test for `Negotiated::write` on a `Completed` stream that
    /// still buffers `rem`. The property checked:
    /// - After a successful write, the buffer is no longer than before.
    /// - After a failed write, the data must not have fit into the sink and
    ///   the buffer is unchanged.
    #[test]
    fn write_remaining() {
        fn prop(rem: Vec<u8>, new: Vec<u8>, free: u8) -> TestResult {
            // The sink can hold the initial buffer plus `free` more bytes.
            // NOTE(review): `Vec::with_capacity` only guarantees *at least*
            // `cap`; the `> cap` assertion below assumes it is exact.
            let cap = rem.len() + free as usize;
            let buf = Capped { buf: Vec::with_capacity(cap), step: free as usize };
            let mut rem = BytesMut::from(rem);
            let mut io = Negotiated::completed(buf, rem.clone());
            let mut written = 0;
            loop {
                match io.write(&new[written..]) {
                    Ok(n) =>
                        // Successful write: check the buffer did not grow.
                        if let State::Completed { remaining, .. } = &io.state {
                            if n == rem.len() + new[written..].len() {
                                assert!(remaining.is_empty())
                            } else {
                                assert!(remaining.len() <= rem.len());
                            }
                            written += n;
                            if written == new.len() {
                                return TestResult::passed()
                            }
                            // The next round compares against the current buffer.
                            rem = remaining.clone();
                        } else {
                            return TestResult::failed()
                        }
                    Err(_) =>
                        // Failed write: the sink must have been too small, and the
                        // buffer must be exactly what it was before the call.
                        if let State::Completed { remaining, .. } = &io.state {
                            assert!(rem.len() + new[written..].len() > cap);
                            assert_eq!(remaining, &rem);
                            return TestResult::passed()
                        } else {
                            return TestResult::failed()
                        }
                }
            }
        }
        quickcheck(prop as fn(_,_,_) -> _)
    }
}