use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
use crate::Version;
use bytes::{BufMut, Bytes, BytesMut};
use futures::{io::IoSlice, prelude::*, ready};
use std::{
convert::TryFrom,
error::Error,
fmt, io,
pin::Pin,
task::{Context, Poll},
};
use unsigned_varint as uvi;
/// Maximum number of entries `Message::decode` will collect from a single
/// protocol list before failing with `ProtocolError::TooManyProtocols`.
const MAX_PROTOCOLS: usize = 1000;
/// Wire form of `Message::Header(HeaderLine::V1)`, including the trailing newline.
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// Wire form of `Message::NotAvailable`.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// Wire form of `Message::ListProtocols`.
const MSG_LS: &[u8] = b"ls\n";
/// The multistream-select header line carried by `Message::Header`.
///
/// Only version 1 exists; it is encoded as `/multistream/1.0.0\n`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum HeaderLine {
V1,
}
impl From<Version> for HeaderLine {
    /// Picks the header line to send for a negotiation `Version`.
    ///
    /// The eager and lazy V1 variants share the same header line.
    fn from(v: Version) -> HeaderLine {
        match v {
            Version::V1 => HeaderLine::V1,
            Version::V1Lazy => HeaderLine::V1,
        }
    }
}
/// A protocol name exchanged during negotiation.
///
/// The `TryFrom` conversions in this module require the name to start with `/`;
/// the byte-based ones additionally require valid UTF-8.
#[derive(Clone, Debug, PartialEq, Eq)]
pub(crate) struct Protocol(String);
impl AsRef<str> for Protocol {
    /// Borrows the protocol name as a string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl TryFrom<Bytes> for Protocol {
    type Error = ProtocolError;

    /// Validates raw bytes as a protocol name.
    ///
    /// A valid name starts with `/`, contains no `\n` and is valid UTF-8.
    /// Every multistream-select message is newline-terminated. A name with an
    /// embedded newline could therefore never be sent as `Message::Protocol`
    /// and decoded back as the same value, so it is rejected here.
    ///
    /// # Errors
    /// Returns `ProtocolError::InvalidProtocol` if any of the conditions fail.
    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
        let bytes: &[u8] = value.as_ref();
        if !bytes.starts_with(b"/") || bytes.contains(&b'\n') {
            return Err(ProtocolError::InvalidProtocol);
        }
        let protocol_as_string =
            String::from_utf8(value.to_vec()).map_err(|_| ProtocolError::InvalidProtocol)?;
        Ok(Protocol(protocol_as_string))
    }
}
impl TryFrom<&[u8]> for Protocol {
type Error = ProtocolError;
fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
Self::try_from(Bytes::copy_from_slice(value))
}
}
impl TryFrom<&str> for Protocol {
    type Error = ProtocolError;

    /// Validates `value` as a protocol name.
    ///
    /// A valid name starts with `/` and contains no `\n`. Every
    /// multistream-select message is newline-terminated. A name with an embedded
    /// newline would be encoded into bytes that cannot be decoded back as the
    /// same `Message::Protocol`.
    ///
    /// # Errors
    /// Returns `ProtocolError::InvalidProtocol` if either condition fails.
    fn try_from(value: &str) -> Result<Self, Self::Error> {
        if !value.starts_with('/') || value.contains('\n') {
            return Err(ProtocolError::InvalidProtocol);
        }
        Ok(Protocol(value.to_owned()))
    }
}
impl fmt::Display for Protocol {
    /// Writes the bare protocol name; formatter flags such as width are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.0)
    }
}
/// A single multistream-select message; each is carried in one frame.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum Message {
/// The header line, encoded as `/multistream/1.0.0\n`.
Header(HeaderLine),
/// A single protocol name, encoded as the name followed by `\n`.
Protocol(Protocol),
/// A request for the supported protocols, encoded as `ls\n`.
ListProtocols,
/// A list of protocols. Each entry is encoded as a varint length (name + 1),
/// the name and `\n`, and the whole list is terminated by an extra `\n`.
Protocols(Vec<Protocol>),
/// "Protocol not available", encoded as `na\n`.
NotAvailable,
}
impl Message {
fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> {
match self {
Message::Header(HeaderLine::V1) => {
dest.reserve(MSG_MULTISTREAM_1_0.len());
dest.put(MSG_MULTISTREAM_1_0);
Ok(())
}
Message::Protocol(p) => {
let len = p.as_ref().len() + 1; dest.reserve(len);
dest.put(p.0.as_ref());
dest.put_u8(b'\n');
Ok(())
}
Message::ListProtocols => {
dest.reserve(MSG_LS.len());
dest.put(MSG_LS);
Ok(())
}
Message::Protocols(ps) => {
let mut buf = uvi::encode::usize_buffer();
let mut encoded = Vec::with_capacity(ps.len());
for p in ps {
encoded.extend(uvi::encode::usize(p.as_ref().len() + 1, &mut buf)); encoded.extend_from_slice(p.0.as_ref());
encoded.push(b'\n')
}
encoded.push(b'\n');
dest.reserve(encoded.len());
dest.put(encoded.as_ref());
Ok(())
}
Message::NotAvailable => {
dest.reserve(MSG_PROTOCOL_NA.len());
dest.put(MSG_PROTOCOL_NA);
Ok(())
}
}
}
fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
if msg == MSG_MULTISTREAM_1_0 {
return Ok(Message::Header(HeaderLine::V1));
}
if msg == MSG_PROTOCOL_NA {
return Ok(Message::NotAvailable);
}
if msg == MSG_LS {
return Ok(Message::ListProtocols);
}
if msg.first() == Some(&b'/')
&& msg.last() == Some(&b'\n')
&& !msg[..msg.len() - 1].contains(&b'\n')
{
let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
return Ok(Message::Protocol(p));
}
let mut protocols = Vec::new();
let mut remaining: &[u8] = &msg;
loop {
if remaining == [b'\n'] {
break;
} else if protocols.len() == MAX_PROTOCOLS {
return Err(ProtocolError::TooManyProtocols);
}
let (len, tail) = uvi::decode::usize(remaining)?;
if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
return Err(ProtocolError::InvalidMessage);
}
let p = Protocol::try_from(Bytes::copy_from_slice(&tail[..len - 1]))?;
protocols.push(p);
remaining = &tail[len..];
}
Ok(Message::Protocols(protocols))
}
}
/// Message-oriented I/O over a length-delimited stream.
///
/// It implements `Sink<Message>` (encoding each message into one frame) and
/// `Stream` (decoding each received frame into a `Message`).
#[pin_project::pin_project]
pub(crate) struct MessageIO<R> {
// Framing layer; messages are encoded into / decoded from its frames.
#[pin]
inner: LengthDelimited<R>,
}
impl<R> MessageIO<R> {
pub(crate) fn new(inner: R) -> MessageIO<R>
where
R: AsyncRead + AsyncWrite,
{
Self {
inner: LengthDelimited::new(inner),
}
}
pub(crate) fn into_reader(self) -> MessageReader<R> {
MessageReader {
inner: self.inner.into_reader(),
}
}
pub(crate) fn into_inner(self) -> R {
self.inner.into_inner()
}
}
impl<R> Sink<Message> for MessageIO<R>
where
R: AsyncWrite,
{
type Error = ProtocolError;
fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_ready(cx).map_err(From::from)
}
fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
let mut buf = BytesMut::new();
item.encode(&mut buf)?;
self.project()
.inner
.start_send(buf.freeze())
.map_err(From::from)
}
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_flush(cx).map_err(From::from)
}
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.project().inner.poll_close(cx).map_err(From::from)
}
}
impl<R> Stream for MessageIO<R>
where
    R: AsyncRead,
{
    type Item = Result<Message, ProtocolError>;

    /// Polls the next frame and decodes it into a `Message`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `poll_stream` already yields exactly `Poll<Option<Self::Item>>`;
        // re-matching every arm into the same value added nothing.
        poll_stream(self.project().inner, cx)
    }
}
/// Read side obtained from `MessageIO::into_reader`.
///
/// It yields decoded `Message`s as a `Stream` and forwards raw writes to the
/// underlying I/O through its `AsyncWrite` implementation.
#[pin_project::pin_project]
#[derive(Debug)]
pub(crate) struct MessageReader<R> {
// Length-delimited reader the messages are decoded from.
#[pin]
inner: LengthDelimitedReader<R>,
}
impl<R> MessageReader<R> {
    /// Unwraps the framing reader and returns the underlying I/O object.
    pub(crate) fn into_inner(self) -> R {
        let reader = self.inner;
        reader.into_inner()
    }
}
impl<R> Stream for MessageReader<R>
where
    R: AsyncRead,
{
    type Item = Result<Message, ProtocolError>;

    /// Polls the next frame from the reader and decodes it into a `Message`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.project();
        poll_stream(this.inner, cx)
    }
}
/// Forwards every write operation unchanged to the wrapped reader's I/O.
impl<TInner> AsyncWrite for MessageReader<TInner>
where
    TInner: AsyncWrite,
{
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.project();
        this.inner.poll_write(cx, buf)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[IoSlice<'_>],
    ) -> Poll<Result<usize, io::Error>> {
        let this = self.project();
        this.inner.poll_write_vectored(cx, bufs)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let this = self.project();
        this.inner.poll_flush(cx)
    }

    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
        let this = self.project();
        this.inner.poll_close(cx)
    }
}
fn poll_stream<S>(
stream: Pin<&mut S>,
cx: &mut Context<'_>,
) -> Poll<Option<Result<Message, ProtocolError>>>
where
S: Stream<Item = Result<Bytes, io::Error>>,
{
let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
match Message::decode(msg) {
Ok(m) => m,
Err(err) => return Poll::Ready(Some(Err(err))),
}
} else {
return Poll::Ready(None);
};
log::trace!("Received message: {:?}", msg);
Poll::Ready(Some(Ok(msg)))
}
/// Errors produced while encoding, decoding or transporting negotiation messages.
#[derive(Debug)]
pub enum ProtocolError {
/// Failure of the underlying I/O. Malformed length varints are also reported
/// through this variant, with kind `InvalidData`.
IoError(io::Error),
/// A received payload could not be parsed as a valid message.
InvalidMessage,
/// A protocol name was rejected (e.g. missing leading `/` or not UTF-8).
InvalidProtocol,
/// A received protocol list exceeded `MAX_PROTOCOLS` entries.
TooManyProtocols,
}
impl From<io::Error> for ProtocolError {
fn from(err: io::Error) -> ProtocolError {
ProtocolError::IoError(err)
}
}
impl From<ProtocolError> for io::Error {
    /// Converts a `ProtocolError` into an `io::Error`.
    ///
    /// `ProtocolError::IoError` is unwrapped and returned unchanged. Every
    /// other variant becomes an `io::ErrorKind::InvalidData` error that wraps
    /// the original `ProtocolError`. Its message stays in `Display`, and it
    /// remains reachable through `get_ref`/`into_inner`; previously only the
    /// bare kind survived.
    fn from(err: ProtocolError) -> Self {
        match err {
            ProtocolError::IoError(e) => e,
            other => io::Error::new(io::ErrorKind::InvalidData, other),
        }
    }
}
impl From<uvi::decode::Error> for ProtocolError {
fn from(err: uvi::decode::Error) -> ProtocolError {
Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
}
}
impl Error for ProtocolError {
    /// Exposes the wrapped I/O error as the source; other variants have none.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        if let ProtocolError::IoError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}
impl fmt::Display for ProtocolError {
fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match self {
ProtocolError::IoError(e) => write!(fmt, "I/O error: {e}"),
ProtocolError::InvalidMessage => write!(fmt, "Received an invalid message."),
ProtocolError::InvalidProtocol => write!(fmt, "A protocol (name) is invalid."),
ProtocolError::TooManyProtocols => write!(fmt, "Too many protocols received."),
}
}
}
// Property tests: every arbitrary `Message` must survive an encode/decode round trip.
#[cfg(test)]
mod tests {
use super::*;
use quickcheck::*;
use std::iter;
// Generates `/` followed by 1..size ASCII alphanumeric characters.
// NOTE(review): `gen_range(1..g.size())` panics if the generator size is 1;
// `gen_range` may come from a quickcheck extension trait — confirm.
impl Arbitrary for Protocol {
fn arbitrary(g: &mut Gen) -> Protocol {
let n = g.gen_range(1..g.size());
let p: String = iter::repeat(())
.map(|()| char::arbitrary(g))
.filter(|&c| c.is_ascii_alphanumeric())
.take(n)
.collect();
Protocol(format!("/{p}"))
}
}
// Picks one of the five `Message` variants uniformly.
impl Arbitrary for Message {
fn arbitrary(g: &mut Gen) -> Message {
match g.gen_range(0..5u8) {
0 => Message::Header(HeaderLine::V1),
1 => Message::NotAvailable,
2 => Message::ListProtocols,
3 => Message::Protocol(Protocol::arbitrary(g)),
4 => Message::Protocols(Vec::arbitrary(g)),
_ => panic!(),
}
}
}
// Encoding then decoding must reproduce the original message exactly.
#[test]
fn encode_decode_message() {
fn prop(msg: Message) {
let mut buf = BytesMut::new();
msg.encode(&mut buf)
.unwrap_or_else(|_| panic!("Encoding message failed: {msg:?}"));
match Message::decode(buf.freeze()) {
Ok(m) => assert_eq!(m, msg),
Err(e) => panic!("Decoding failed: {e:?}"),
}
}
quickcheck(prop as fn(_))
}
}