use bytes::Bytes;
use futures::{prelude::*, sink, Async, AsyncSink, StartSend, try_ready};
use crate::length_delimited::LengthDelimited;
use crate::protocol::DialerToListenerMessage;
use crate::protocol::ListenerToDialerMessage;
use crate::protocol::MultistreamSelectError;
use crate::protocol::MULTISTREAM_PROTOCOL_WITH_LF;
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint::decode;
/// Dialer side of a multistream-select negotiation.
///
/// Wraps `R` in length-delimited framing. It implements `Sink` for outgoing
/// `DialerToListenerMessage`s and `Stream` for incoming
/// `ListenerToDialerMessage`s. Instances are produced by the future returned
/// from `Dialer::new`, which first sends the multistream header.
pub struct Dialer<R> {
// Length-delimited framing over the raw I/O object.
inner: LengthDelimited<Bytes, R>,
// Set once the listener's multistream header has been received and checked
// (see `Stream::poll`); until then the first incoming frame must be that header.
handshake_finished: bool,
}
impl<R> Dialer<R>
where
R: AsyncRead + AsyncWrite,
{
pub fn new(inner: R) -> DialerFuture<R> {
let sender = LengthDelimited::new(inner);
DialerFuture {
inner: sender.send(Bytes::from(MULTISTREAM_PROTOCOL_WITH_LF))
}
}
#[inline]
pub fn into_inner(self) -> R {
self.inner.into_inner()
}
}
impl<R> Sink for Dialer<R>
where
    R: AsyncRead + AsyncWrite,
{
    type SinkItem = DialerToListenerMessage;
    type SinkError = MultistreamSelectError;

    /// Serializes `item` into a `\n`-terminated frame and offers it to the
    /// underlying length-delimited sink.
    ///
    /// * `ProtocolRequest { name }` is sent as `name` followed by `\n`. If the
    ///   inner sink is not ready, the request is handed back in
    ///   `AsyncSink::NotReady` with the appended `\n` removed again.
    /// * `ProtocolsListRequest` is sent as the literal bytes `ls\n`.
    ///
    /// # Errors
    ///
    /// Returns `WrongProtocolName` when a requested name does not start with `/`.
    /// Errors from the inner sink are converted into `MultistreamSelectError`.
    fn start_send(&mut self, item: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
        let name = match item {
            DialerToListenerMessage::ProtocolRequest { name } => name,
            DialerToListenerMessage::ProtocolsListRequest => {
                let request = Bytes::from(&b"ls\n"[..]);
                return Ok(match self.inner.start_send(request)? {
                    AsyncSink::Ready => AsyncSink::Ready,
                    AsyncSink::NotReady(_) => {
                        AsyncSink::NotReady(DialerToListenerMessage::ProtocolsListRequest)
                    }
                });
            }
        };

        if !name.starts_with(b"/") {
            return Err(MultistreamSelectError::WrongProtocolName);
        }

        let mut line = name;
        line.extend_from_slice(&[b'\n']);

        match self.inner.start_send(line)? {
            AsyncSink::Ready => Ok(AsyncSink::Ready),
            AsyncSink::NotReady(mut rejected) => {
                // Strip the `\n` appended above so the caller gets back exactly
                // the name it submitted.
                let without_lf = rejected.len() - 1;
                rejected.truncate(without_lf);
                Ok(AsyncSink::NotReady(
                    DialerToListenerMessage::ProtocolRequest { name: rejected },
                ))
            }
        }
    }

    /// Flushes the underlying framed sink.
    #[inline]
    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
        self.inner.poll_complete().map_err(Into::into)
    }

    /// Closes the underlying framed sink.
    #[inline]
    fn close(&mut self) -> Poll<(), Self::SinkError> {
        self.inner.close().map_err(Into::into)
    }
}
impl<R> Stream for Dialer<R>
where
    R: AsyncRead + AsyncWrite,
{
    type Item = ListenerToDialerMessage;
    type Error = MultistreamSelectError;

    /// Reads and decodes the next message from the listener.
    ///
    /// The first frame must be the multistream header
    /// (`MULTISTREAM_PROTOCOL_WITH_LF`). It is checked and swallowed, never
    /// yielded. Each later frame is decoded as one of:
    ///
    /// * `ProtocolAck`: `/<name>\n`, where the trailing `\n` is the frame's only
    ///   newline. The yielded name excludes that `\n`.
    /// * `NotAvailable`: exactly `na\n`.
    /// * `ProtocolsListResponse`: anything else. This is a varint entry count
    ///   followed by that many varint-length-prefixed, `\n`-terminated names.
    ///
    /// # Errors
    ///
    /// * `FailedHandshake` if the first frame is not the multistream header.
    /// * `VarintParseError` if a list response announces more than
    ///   `MAX_PROTOCOLS` entries.
    /// * `UnknownMessage` if a list entry is empty, overruns the frame, or is not
    ///   `\n`-terminated.
    /// * Varint decoding errors and errors from the underlying framed stream,
    ///   converted into `MultistreamSelectError`.
    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        // Cap on the entry count of a protocols-list response, so a peer cannot
        // make us pre-allocate an arbitrarily large `Vec`.
        const MAX_PROTOCOLS: usize = 1000;

        loop {
            let mut frame = match self.inner.poll() {
                Ok(Async::Ready(Some(frame))) => frame,
                Ok(Async::Ready(None)) => return Ok(Async::Ready(None)),
                Ok(Async::NotReady) => return Ok(Async::NotReady),
                Err(err) => return Err(err.into()),
            };

            if !self.handshake_finished {
                if frame == MULTISTREAM_PROTOCOL_WITH_LF {
                    self.handshake_finished = true;
                    continue;
                }
                return Err(MultistreamSelectError::FailedHandshake);
            }

            // A list response whose varint count is 47 (0x2f == b'/') also starts
            // with `/`. Because every entry is `\n`-terminated, it also ends with
            // `\n`. Requiring the trailing `\n` to be the frame's only newline
            // keeps such a response from being mistaken for an acknowledgement.
            let is_ack = match frame.split_last() {
                Some((&b'\n', name)) => name.first() == Some(&b'/') && !name.contains(&b'\n'),
                _ => false,
            };

            if is_ack {
                let name_len = frame.len() - 1;
                let name = frame.split_to(name_len);
                return Ok(Async::Ready(Some(ListenerToDialerMessage::ProtocolAck {
                    name,
                })));
            }

            if frame == b"na\n"[..] {
                return Ok(Async::Ready(Some(ListenerToDialerMessage::NotAvailable)));
            }

            let (num_protocols, mut remaining) = decode::usize(&frame)?;
            if num_protocols > MAX_PROTOCOLS {
                return Err(MultistreamSelectError::VarintParseError(
                    "too many protocols".into(),
                ));
            }
            let mut list = Vec::with_capacity(num_protocols);
            for _ in 0..num_protocols {
                let (len, rest) = decode::usize(remaining)?;
                // Each entry is `len` bytes and must end with `\n`. `len == 0` is
                // rejected first so `len - 1` cannot underflow.
                if len == 0 || len > rest.len() || rest[len - 1] != b'\n' {
                    return Err(MultistreamSelectError::UnknownMessage);
                }
                list.push(Bytes::from(&rest[..len - 1]));
                remaining = &rest[len..];
            }
            return Ok(Async::Ready(Some(
                ListenerToDialerMessage::ProtocolsListResponse { list },
            )));
        }
    }
}
/// Future returned by `Dialer::new`.
///
/// It drives the send of the multistream header and then resolves to a `Dialer`.
/// The `T: AsyncWrite` bound is kept on the struct because `sink::Send`
/// requires its parameter to be a `Sink`.
pub struct DialerFuture<T: AsyncWrite> {
// In-flight send of the multistream header over the length-delimited framing.
inner: sink::Send<LengthDelimited<Bytes, T>>
}
impl<T: AsyncWrite> Future for DialerFuture<T> {
    type Item = Dialer<T>;
    type Error = MultistreamSelectError;

    /// Drives the pending header send. Once it completes, yields a `Dialer` that
    /// has not yet seen the listener's header (`handshake_finished == false`).
    fn poll(&mut self) -> Poll<Self::Item, Self::Error> {
        match self.inner.poll()? {
            Async::NotReady => Ok(Async::NotReady),
            Async::Ready(framed) => Ok(Async::Ready(Dialer {
                inner: framed,
                handshake_finished: false,
            })),
        }
    }
}
#[cfg(test)]
mod tests {
    use tokio::runtime::current_thread::Runtime;
    use tokio_tcp::{TcpListener, TcpStream};
    use bytes::Bytes;
    use futures::Future;
    use futures::{Sink, Stream};
    use crate::protocol::{Dialer, DialerToListenerMessage, MultistreamSelectError};

    /// Requesting a protocol whose name lacks the leading `/` must be rejected
    /// by the dialer with `WrongProtocolName`.
    #[test]
    fn wrong_proto_name() {
        let bind_addr = "127.0.0.1:0".parse().unwrap();
        let listener = TcpListener::bind(&bind_addr).unwrap();
        let listener_addr = listener.local_addr().unwrap();

        // Accept a single inbound connection, then finish.
        let server = listener
            .incoming()
            .into_future()
            .map(|_| ())
            .map_err(|(err, _)| err.into());

        let client = TcpStream::connect(&listener_addr)
            .from_err()
            .and_then(|stream| Dialer::new(stream))
            .and_then(|dialer| {
                let request = DialerToListenerMessage::ProtocolRequest {
                    name: Bytes::from("invalid_name"),
                };
                dialer.send(request)
            });

        let mut rt = Runtime::new().unwrap();
        let outcome = rt.block_on(server.join(client));
        if let Err(MultistreamSelectError::WrongProtocolName) = outcome {
            return;
        }
        panic!();
    }
}