use bytes::{Bytes, BytesMut, BufMut};
use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
use futures::{prelude::*, try_ready};
use log::trace;
use std::{io, fmt, error::Error, convert::TryFrom};
use tokio_io::{AsyncRead, AsyncWrite};
use unsigned_varint as uvi;
/// Largest number of entries accepted when decoding a `Message::Protocols` list.
const MAX_PROTOCOLS: usize = 1_000;
/// Largest accepted protocol name, in bytes (the trailing newline excluded).
const MAX_PROTOCOL_LEN: usize = 140;

/// Wire form of `Message::Header(Version::V1)`.
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// Wire form of `Message::Header(Version::V2)`.
const MSG_MULTISTREAM_2_0: &[u8] = b"/multistream/2.0.0\n";
/// Wire form of `Message::NotAvailable`.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// Wire form of `Message::ListProtocols`.
const MSG_LS: &[u8] = b"ls\n";
/// Multistream-select version carried by a `Message::Header`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Version {
    /// Encoded on the wire as `/multistream/1.0.0\n`.
    V1,
    /// Encoded on the wire as `/multistream/2.0.0\n`.
    V2,
}
/// A protocol name such as `/foo/1.0.0`.
///
/// The `TryFrom` constructors only accept names that start with `/` and are
/// at most `MAX_PROTOCOL_LEN` bytes long.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Protocol(Bytes);
impl AsRef<[u8]> for Protocol {
fn as_ref(&self) -> &[u8] {
self.0.as_ref()
}
}
/// Validates `value` as a protocol name.
///
/// # Errors
///
/// Returns `ProtocolError::InvalidProtocol` unless the bytes start with `/`
/// and are no longer than `MAX_PROTOCOL_LEN`.
impl TryFrom<Bytes> for Protocol {
    type Error = ProtocolError;

    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
        let raw: &[u8] = value.as_ref();
        let well_formed = raw.starts_with(b"/") && raw.len() <= MAX_PROTOCOL_LEN;
        if well_formed {
            Ok(Protocol(value))
        } else {
            Err(ProtocolError::InvalidProtocol)
        }
    }
}
impl TryFrom<&[u8]> for Protocol {
type Error = ProtocolError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
Self::try_from(Bytes::from(value))
}
}
/// Renders the protocol name as text, replacing invalid UTF-8 sequences with
/// U+FFFD.
impl fmt::Display for Protocol {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = String::from_utf8_lossy(&self.0);
        f.write_str(&text)
    }
}
/// A single multistream-select message; see `Message::encode` for the wire
/// format of each variant.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Message {
    /// A version header (`/multistream/<version>\n`).
    Header(Version),
    /// One protocol name, encoded as `<name>\n`.
    Protocol(Protocol),
    /// The `ls\n` message.
    ListProtocols,
    /// A varint-prefixed list of protocol names.
    Protocols(Vec<Protocol>),
    /// The `na\n` message.
    NotAvailable,
}
impl Message {
    /// Appends the wire encoding of this message to `dest`.
    ///
    /// * `Header(V1)` / `Header(V2)` → `/multistream/1.0.0\n` / `/multistream/2.0.0\n`
    /// * `Protocol(p)` → `<p>\n`
    /// * `ListProtocols` → `ls\n`
    /// * `NotAvailable` → `na\n`
    /// * `Protocols(ps)` → varint `ps.len()`, then for every protocol a varint
    ///   length (name length + 1 for the newline), the name, and `\n`.
    ///
    /// Capacity is reserved in `dest` before each write.
    ///
    /// # Errors
    ///
    /// No variant currently produces an error.
    pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> {
        match self {
            Message::Header(Version::V1) => {
                dest.reserve(MSG_MULTISTREAM_1_0.len());
                dest.put(MSG_MULTISTREAM_1_0);
            }
            Message::Header(Version::V2) => {
                dest.reserve(MSG_MULTISTREAM_2_0.len());
                dest.put(MSG_MULTISTREAM_2_0);
            }
            Message::Protocol(p) => {
                let name: &[u8] = p.0.as_ref();
                // +1 for the trailing '\n'.
                dest.reserve(name.len() + 1);
                dest.put(name);
                dest.put(&b"\n"[..]);
            }
            Message::ListProtocols => {
                dest.reserve(MSG_LS.len());
                dest.put(MSG_LS);
            }
            Message::Protocols(ps) => {
                let mut buf = uvi::encode::usize_buffer();
                let mut out_msg = Vec::from(uvi::encode::usize(ps.len(), &mut buf));
                for p in ps {
                    let name: &[u8] = p.0.as_ref();
                    // The per-entry length prefix counts the name plus its '\n'.
                    out_msg.extend_from_slice(uvi::encode::usize(name.len() + 1, &mut buf));
                    out_msg.extend_from_slice(name);
                    out_msg.push(b'\n');
                }
                dest.reserve(out_msg.len());
                dest.put(&out_msg[..]);
            }
            Message::NotAvailable => {
                dest.reserve(MSG_PROTOCOL_NA.len());
                dest.put(MSG_PROTOCOL_NA);
            }
        }
        Ok(())
    }

    /// Decodes one message previously produced by [`Message::encode`].
    ///
    /// The fixed literals and the single-protocol form are tried first; only if
    /// none of them match is `msg` parsed as a varint-prefixed protocol list.
    /// Bytes left over after the last list entry are ignored.
    ///
    /// # Errors
    ///
    /// * `ProtocolError::InvalidProtocol` if a protocol name fails validation.
    /// * `ProtocolError::TooManyProtocols` if a list announces more than
    ///   `MAX_PROTOCOLS` entries.
    /// * `ProtocolError::InvalidMessage` if a list entry is empty, truncated or
    ///   not newline-terminated.
    /// * `ProtocolError::IoError` (kind `InvalidData`) if a varint is malformed.
    pub fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
        if msg == MSG_MULTISTREAM_1_0 {
            return Ok(Message::Header(Version::V1));
        }
        if msg == MSG_MULTISTREAM_2_0 {
            return Ok(Message::Header(Version::V2));
        }
        // A single protocol is `<name>\n`. The newline is not part of the name, so
        // the frame may be one byte longer than MAX_PROTOCOL_LEN; `Protocol::try_from`
        // enforces the real limit. (A protocol list whose first byte is '/' announces
        // 47 entries and needs at least 142 bytes, so it cannot be confused with this.)
        if msg.first() == Some(&b'/')
            && msg.last() == Some(&b'\n')
            && msg.len() <= MAX_PROTOCOL_LEN + 1
        {
            let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
            return Ok(Message::Protocol(p));
        }
        if msg == MSG_PROTOCOL_NA {
            return Ok(Message::NotAvailable);
        }
        if msg == MSG_LS {
            return Ok(Message::ListProtocols);
        }
        let (num_protocols, mut remaining) = uvi::decode::usize(&msg)?;
        if num_protocols > MAX_PROTOCOLS {
            return Err(ProtocolError::TooManyProtocols);
        }
        let mut protocols = Vec::with_capacity(num_protocols);
        for _ in 0..num_protocols {
            let (len, rem) = uvi::decode::usize(remaining)?;
            // `len` includes the trailing '\n', so it must be non-zero, fit in the
            // remaining input, and end on a newline.
            if len == 0 || len > rem.len() || rem[len - 1] != b'\n' {
                return Err(ProtocolError::InvalidMessage);
            }
            let p = Protocol::try_from(Bytes::from(&rem[..len - 1]))?;
            protocols.push(p);
            remaining = &rem[len..];
        }
        Ok(Message::Protocols(protocols))
    }
}
/// Sends and receives [`Message`]s over `R`, framing each message with the
/// length-delimited codec.
pub struct MessageIO<R> {
    inner: LengthDelimited<R>,
}
impl<R> MessageIO<R> {
pub fn new(inner: R) -> MessageIO<R>
where
R: AsyncRead + AsyncWrite
{
Self { inner: LengthDelimited::new(inner) }
}
pub fn into_reader(self) -> MessageReader<R> {
MessageReader { inner: self.inner.into_reader() }
}
pub fn into_inner(self) -> (R, BytesMut) {
self.inner.into_inner()
}
}
/// Each message is encoded into a fresh buffer and handed to the inner codec as
/// one frame.
impl<R> Sink for MessageIO<R>
where
    R: AsyncWrite,
{
    type SinkItem = Message;
    type SinkError = ProtocolError;

    fn start_send(&mut self, msg: Self::SinkItem) -> StartSend<Self::SinkItem, Self::SinkError> {
        let mut frame = BytesMut::new();
        msg.encode(&mut frame)?;
        let outcome = self.inner.start_send(frame.freeze())?;
        match outcome {
            AsyncSink::Ready => Ok(AsyncSink::Ready),
            // Hand the original message back so the caller can retry it.
            AsyncSink::NotReady(_) => Ok(AsyncSink::NotReady(msg)),
        }
    }

    fn poll_complete(&mut self) -> Poll<(), Self::SinkError> {
        self.inner.poll_complete().map_err(From::from)
    }

    fn close(&mut self) -> Poll<(), Self::SinkError> {
        self.inner.close().map_err(From::from)
    }
}
/// Yields one decoded [`Message`] per received frame.
impl<R> Stream for MessageIO<R>
where
    R: AsyncRead,
{
    type Item = Message;
    type Error = ProtocolError;

    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        let frames = &mut self.inner;
        poll_stream(frames)
    }
}
/// Reading half produced by `MessageIO::into_reader`; it still forwards raw
/// writes to the underlying resource.
pub struct MessageReader<R> {
    inner: LengthDelimitedReader<R>,
}
impl<R> MessageReader<R> {
pub fn into_inner(self) -> (R, BytesMut) {
self.inner.into_inner()
}
pub fn inner_ref(&self) -> &R {
self.inner.inner_ref()
}
}
/// Yields one decoded [`Message`] per received frame.
impl<R> Stream for MessageReader<R>
where
    R: AsyncRead,
{
    type Item = Message;
    type Error = ProtocolError;

    fn poll(&mut self) -> Poll<Option<Self::Item>, Self::Error> {
        let frames = &mut self.inner;
        poll_stream(frames)
    }
}
/// Passes writes straight through to the wrapped length-delimited reader.
impl<R> io::Write for MessageReader<R>
where
    R: AsyncWrite,
{
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        io::Write::write(&mut self.inner, buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        io::Write::flush(&mut self.inner)
    }
}
impl<TInner> AsyncWrite for MessageReader<TInner>
where
TInner: AsyncWrite
{
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.inner.shutdown()
}
}
/// Polls `stream` for the next frame and decodes it into a [`Message`].
///
/// Returns `NotReady` while no frame is available and `Ready(None)` once the
/// stream ends. I/O errors and decoding errors are returned as `ProtocolError`.
fn poll_stream<S>(stream: &mut S) -> Poll<Option<Message>, ProtocolError>
where
    S: Stream<Item = Bytes, Error = io::Error>,
{
    match try_ready!(stream.poll()) {
        None => Ok(Async::Ready(None)),
        Some(frame) => {
            let decoded = Message::decode(frame)?;
            trace!("Received message: {:?}", decoded);
            Ok(Async::Ready(Some(decoded)))
        }
    }
}
/// Errors produced while encoding, decoding or transporting messages.
#[derive(Debug)]
pub enum ProtocolError {
    /// An I/O error; malformed varints are also reported through this variant.
    IoError(io::Error),
    /// A received message did not have a valid shape.
    InvalidMessage,
    /// A protocol name failed validation.
    InvalidProtocol,
    /// A protocol list announced more than `MAX_PROTOCOLS` entries.
    TooManyProtocols,
}
/// Wraps an I/O error as `ProtocolError::IoError`.
impl From<io::Error> for ProtocolError {
    fn from(source: io::Error) -> Self {
        ProtocolError::IoError(source)
    }
}
impl Into<io::Error> for ProtocolError {
fn into(self) -> io::Error {
if let ProtocolError::IoError(e) = self {
return e
}
return io::ErrorKind::InvalidData.into()
}
}
impl From<uvi::decode::Error> for ProtocolError {
fn from(err: uvi::decode::Error) -> ProtocolError {
Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
}
}
/// Only the `IoError` variant has an underlying source error.
impl Error for ProtocolError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let ProtocolError::IoError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl fmt::Display for ProtocolError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match self {
ProtocolError::IoError(e) =>
write!(fmt, "I/O error: {}", e),
ProtocolError::InvalidMessage =>
write!(fmt, "Received an invalid message."),
ProtocolError::InvalidProtocol =>
write!(fmt, "A protocol (name) is invalid."),
ProtocolError::TooManyProtocols =>
write!(fmt, "Too many protocols received.")
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;
    use rand::Rng;
    use rand::distributions::Alphanumeric;
    use std::iter;

    impl Arbitrary for Protocol {
        /// Generates a valid protocol name: `/` followed by alphanumeric chars,
        /// never longer than `MAX_PROTOCOL_LEN` bytes in total.
        fn arbitrary<G: Gen>(g: &mut G) -> Protocol {
            // `gen_range(low, high)` panics unless `high > low`, so keep the
            // exclusive upper bound at least 2; capping it at MAX_PROTOCOL_LEN keeps
            // the name (including the leading '/') within the validated limit.
            let upper = g.size().min(MAX_PROTOCOL_LEN).max(2);
            let n = g.gen_range(1, upper);
            let p: String = iter::repeat(())
                .map(|()| g.sample(Alphanumeric))
                .take(n)
                .collect();
            Protocol(Bytes::from(format!("/{}", p)))
        }
    }

    impl Arbitrary for Message {
        /// Generates every `Message` variant, including both header versions.
        fn arbitrary<G: Gen>(g: &mut G) -> Message {
            match g.gen_range(0, 6) {
                0 => Message::Header(Version::V1),
                1 => Message::Header(Version::V2),
                2 => Message::NotAvailable,
                3 => Message::ListProtocols,
                4 => Message::Protocol(Protocol::arbitrary(g)),
                5 => Message::Protocols(Vec::arbitrary(g)),
                _ => unreachable!("gen_range(0, 6) only yields 0..=5"),
            }
        }
    }

    /// Every generated message must survive an encode/decode round trip.
    #[test]
    fn encode_decode_message() {
        fn prop(msg: Message) {
            let mut buf = BytesMut::new();
            msg.encode(&mut buf)
                .unwrap_or_else(|e| panic!("Encoding message failed: {:?}: {:?}", msg, e));
            match Message::decode(buf.freeze()) {
                Ok(m) => assert_eq!(m, msg),
                Err(e) => panic!("Decoding failed: {:?}", e),
            }
        }
        quickcheck(prop as fn(_))
    }
}