multistream_select/
protocol.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
//! Multistream-select protocol messages and I/O operations for
//! constructing protocol negotiation flows.
//!
//! A protocol negotiation flow is constructed by using the
//! `Stream` and `Sink` implementations of `MessageIO` and the
//! `Stream` and `AsyncWrite` implementations of `MessageReader`.
27
28use crate::Version;
29use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
30
31use bytes::{Bytes, BytesMut, BufMut};
32use futures::{prelude::*, io::IoSlice, ready};
33use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}};
34use unsigned_varint as uvi;
35
/// Upper bound on the number of protocol entries accepted in a single `ls`
/// response; further data after this many entries is rejected as
/// `ProtocolError::TooManyProtocols`.
const MAX_PROTOCOLS: usize = 1_000;

/// Wire form of the multistream-select 1.0.0 header line.
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// Wire form of the `na` ("protocol not available") reply.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// Wire form of the `ls` ("list protocols") request.
const MSG_LS: &[u8] = b"ls\n";
45
/// A multistream-select header line, sent before protocol negotiation starts.
///
/// Each [`Version`] maps onto exactly one header line.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum HeaderLine {
    /// The header line `/multistream/1.0.0`.
    V1,
}
54
55impl From<Version> for HeaderLine {
56    fn from(v: Version) -> HeaderLine {
57        match v {
58            Version::V1 | Version::V1Lazy => HeaderLine::V1,
59        }
60    }
61}
62
63/// A protocol (name) exchanged during protocol negotiation.
64#[derive(Clone, Debug, PartialEq, Eq)]
65pub struct Protocol(Bytes);
66
67impl AsRef<[u8]> for Protocol {
68    fn as_ref(&self) -> &[u8] {
69        self.0.as_ref()
70    }
71}
72
73impl TryFrom<Bytes> for Protocol {
74    type Error = ProtocolError;
75
76    fn try_from(value: Bytes) -> Result<Self, Self::Error> {
77        if !value.as_ref().starts_with(b"/") {
78            return Err(ProtocolError::InvalidProtocol)
79        }
80        Ok(Protocol(value))
81    }
82}
83
84impl TryFrom<&[u8]> for Protocol {
85    type Error = ProtocolError;
86
87    fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
88        Self::try_from(Bytes::copy_from_slice(value))
89    }
90}
91
92impl fmt::Display for Protocol {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        write!(f, "{}", String::from_utf8_lossy(&self.0))
95    }
96}
97
98/// A multistream-select protocol message.
99///
100/// Multistream-select protocol messages are exchanged with the goal
101/// of agreeing on a application-layer protocol to use on an I/O stream.
102#[derive(Debug, Clone, PartialEq, Eq)]
103pub enum Message {
104    /// A header message identifies the multistream-select protocol
105    /// that the sender wishes to speak.
106    Header(HeaderLine),
107    /// A protocol message identifies a protocol request or acknowledgement.
108    Protocol(Protocol),
109    /// A message through which a peer requests the complete list of
110    /// supported protocols from the remote.
111    ListProtocols,
112    /// A message listing all supported protocols of a peer.
113    Protocols(Vec<Protocol>),
114    /// A message signaling that a requested protocol is not available.
115    NotAvailable,
116}
117
118impl Message {
119    /// Encodes a `Message` into its byte representation.
120    pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> {
121        match self {
122            Message::Header(HeaderLine::V1) => {
123                dest.reserve(MSG_MULTISTREAM_1_0.len());
124                dest.put(MSG_MULTISTREAM_1_0);
125                Ok(())
126            }
127            Message::Protocol(p) => {
128                let len = p.0.as_ref().len() + 1; // + 1 for \n
129                dest.reserve(len);
130                dest.put(p.0.as_ref());
131                dest.put_u8(b'\n');
132                Ok(())
133            }
134            Message::ListProtocols => {
135                dest.reserve(MSG_LS.len());
136                dest.put(MSG_LS);
137                Ok(())
138            }
139            Message::Protocols(ps) => {
140                let mut buf = uvi::encode::usize_buffer();
141                let mut encoded = Vec::with_capacity(ps.len());
142                for p in ps {
143                    encoded.extend(uvi::encode::usize(p.0.as_ref().len() + 1, &mut buf)); // +1 for '\n'
144                    encoded.extend_from_slice(p.0.as_ref());
145                    encoded.push(b'\n')
146                }
147                encoded.push(b'\n');
148                dest.reserve(encoded.len());
149                dest.put(encoded.as_ref());
150                Ok(())
151            }
152            Message::NotAvailable => {
153                dest.reserve(MSG_PROTOCOL_NA.len());
154                dest.put(MSG_PROTOCOL_NA);
155                Ok(())
156            }
157        }
158    }
159
160    /// Decodes a `Message` from its byte representation.
161    pub fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
162        if msg == MSG_MULTISTREAM_1_0 {
163            return Ok(Message::Header(HeaderLine::V1))
164        }
165
166        if msg == MSG_PROTOCOL_NA {
167            return Ok(Message::NotAvailable);
168        }
169
170        if msg == MSG_LS {
171            return Ok(Message::ListProtocols)
172        }
173
174        // If it starts with a `/`, ends with a line feed without any
175        // other line feeds in-between, it must be a protocol name.
176        if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') &&
177            !msg[.. msg.len() - 1].contains(&b'\n')
178        {
179            let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
180            return Ok(Message::Protocol(p));
181        }
182
183        // At this point, it must be an `ls` response, i.e. one or more
184        // length-prefixed, newline-delimited protocol names.
185        let mut protocols = Vec::new();
186        let mut remaining: &[u8] = &msg;
187        loop {
188            // A well-formed message must be terminated with a newline.
189            if remaining == [b'\n'] {
190                break
191            } else if protocols.len() == MAX_PROTOCOLS {
192                return Err(ProtocolError::TooManyProtocols)
193            }
194
195            // Decode the length of the next protocol name and check that
196            // it ends with a line feed.
197            let (len, tail) = uvi::decode::usize(remaining)?;
198            if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
199                return Err(ProtocolError::InvalidMessage)
200            }
201
202            // Parse the protocol name.
203            let p = Protocol::try_from(Bytes::copy_from_slice(&tail[.. len - 1]))?;
204            protocols.push(p);
205
206            // Skip ahead to the next protocol.
207            remaining = &tail[len ..];
208        }
209
210        Ok(Message::Protocols(protocols))
211    }
212}
213
214/// A `MessageIO` implements a [`Stream`] and [`Sink`] of [`Message`]s.
215#[pin_project::pin_project]
216pub struct MessageIO<R> {
217    #[pin]
218    inner: LengthDelimited<R>,
219}
220
221impl<R> MessageIO<R> {
222    /// Constructs a new `MessageIO` resource wrapping the given I/O stream.
223    pub fn new(inner: R) -> MessageIO<R>
224    where
225        R: AsyncRead + AsyncWrite
226    {
227        Self { inner: LengthDelimited::new(inner) }
228    }
229
230    /// Converts the [`MessageIO`] into a [`MessageReader`], dropping the
231    /// [`Message`]-oriented `Sink` in favour of direct `AsyncWrite` access
232    /// to the underlying I/O stream.
233    ///
234    /// This is typically done if further negotiation messages are expected to be
235    /// received but no more messages are written, allowing the writing of
236    /// follow-up protocol data to commence.
237    pub fn into_reader(self) -> MessageReader<R> {
238        MessageReader { inner: self.inner.into_reader() }
239    }
240
241    /// Drops the [`MessageIO`] resource, yielding the underlying I/O stream.
242    ///
243    /// # Panics
244    ///
245    /// Panics if the read buffer or write buffer is not empty, meaning that an incoming
246    /// protocol negotiation frame has been partially read or an outgoing frame
247    /// has not yet been flushed. The read buffer is guaranteed to be empty whenever
248    /// `MessageIO::poll` returned a message. The write buffer is guaranteed to be empty
249    /// when the sink has been flushed.
250    pub fn into_inner(self) -> R {
251        self.inner.into_inner()
252    }
253}
254
255impl<R> Sink<Message> for MessageIO<R>
256where
257    R: AsyncWrite,
258{
259    type Error = ProtocolError;
260
261    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262        self.project().inner.poll_ready(cx).map_err(From::from)
263    }
264
265    fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
266        let mut buf = BytesMut::new();
267        item.encode(&mut buf)?;
268        self.project().inner.start_send(buf.freeze()).map_err(From::from)
269    }
270
271    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
272        self.project().inner.poll_flush(cx).map_err(From::from)
273    }
274
275    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276        self.project().inner.poll_close(cx).map_err(From::from)
277    }
278}
279
280impl<R> Stream for MessageIO<R>
281where
282    R: AsyncRead
283{
284    type Item = Result<Message, ProtocolError>;
285
286    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
287        match poll_stream(self.project().inner, cx) {
288            Poll::Pending => Poll::Pending,
289            Poll::Ready(None) => Poll::Ready(None),
290            Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
291            Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
292        }
293    }
294}
295
296/// A `MessageReader` implements a `Stream` of `Message`s on an underlying
297/// I/O resource combined with direct `AsyncWrite` access.
298#[pin_project::pin_project]
299#[derive(Debug)]
300pub struct MessageReader<R> {
301    #[pin]
302    inner: LengthDelimitedReader<R>
303}
304
305impl<R> MessageReader<R> {
306    /// Drops the `MessageReader` resource, yielding the underlying I/O stream
307    /// together with the remaining write buffer containing the protocol
308    /// negotiation frame data that has not yet been written to the I/O stream.
309    ///
310    /// # Panics
311    ///
312    /// Panics if the read buffer or write buffer is not empty, meaning that either
313    /// an incoming protocol negotiation frame has been partially read, or an
314    /// outgoing frame has not yet been flushed. The read buffer is guaranteed to
315    /// be empty whenever `MessageReader::poll` returned a message. The write
316    /// buffer is guaranteed to be empty whenever the sink has been flushed.
317    pub fn into_inner(self) -> R {
318        self.inner.into_inner()
319    }
320}
321
322impl<R> Stream for MessageReader<R>
323where
324    R: AsyncRead
325{
326    type Item = Result<Message, ProtocolError>;
327
328    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
329        poll_stream(self.project().inner, cx)
330    }
331}
332
333impl<TInner> AsyncWrite for MessageReader<TInner>
334where
335    TInner: AsyncWrite
336{
337    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
338        self.project().inner.poll_write(cx, buf)
339    }
340
341    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
342        self.project().inner.poll_flush(cx)
343    }
344
345    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
346        self.project().inner.poll_close(cx)
347    }
348
349    fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize, io::Error>> {
350        self.project().inner.poll_write_vectored(cx, bufs)
351    }
352}
353
354fn poll_stream<S>(stream: Pin<&mut S>, cx: &mut Context<'_>) -> Poll<Option<Result<Message, ProtocolError>>>
355where
356    S: Stream<Item = Result<Bytes, io::Error>>,
357{
358    let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
359        match Message::decode(msg) {
360            Ok(m) => m,
361            Err(err) => return Poll::Ready(Some(Err(err))),
362        }
363    } else {
364        return Poll::Ready(None)
365    };
366
367    log::trace!("Received message: {:?}", msg);
368
369    Poll::Ready(Some(Ok(msg)))
370}
371
/// An error that can occur while exchanging multistream-select messages.
#[derive(Debug)]
pub enum ProtocolError {
    /// An error reported by the underlying I/O resource.
    IoError(io::Error),

    /// The remote sent a message that could not be decoded.
    InvalidMessage,

    /// A protocol name failed validation.
    InvalidProtocol,

    /// The remote listed more protocols than are accepted.
    TooManyProtocols,
}
387
388impl From<io::Error> for ProtocolError {
389    fn from(err: io::Error) -> ProtocolError {
390        ProtocolError::IoError(err)
391    }
392}
393
394impl Into<io::Error> for ProtocolError {
395    fn into(self) -> io::Error {
396        if let ProtocolError::IoError(e) = self {
397            return e
398        }
399        io::ErrorKind::InvalidData.into()
400    }
401}
402
403impl From<uvi::decode::Error> for ProtocolError {
404    fn from(err: uvi::decode::Error) -> ProtocolError {
405        Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
406    }
407}
408
409impl Error for ProtocolError {
410    fn source(&self) -> Option<&(dyn Error + 'static)> {
411        match *self {
412            ProtocolError::IoError(ref err) => Some(err),
413            _ => None,
414        }
415    }
416}
417
418impl fmt::Display for ProtocolError {
419    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
420        match self {
421            ProtocolError::IoError(e) =>
422                write!(fmt, "I/O error: {}", e),
423            ProtocolError::InvalidMessage =>
424                write!(fmt, "Received an invalid message."),
425            ProtocolError::InvalidProtocol =>
426                write!(fmt, "A protocol (name) is invalid."),
427            ProtocolError::TooManyProtocols =>
428                write!(fmt, "Too many protocols received.")
429        }
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;
    use rand::Rng;
    use rand::distributions::Alphanumeric;
    use std::iter;

    impl Arbitrary for Protocol {
        /// Generates `/` followed by 1..size random alphanumeric characters.
        fn arbitrary<G: Gen>(g: &mut G) -> Protocol {
            // `gen_range(low, high)` panics on an empty range, so keep the upper
            // bound above 1 even when the generator size is tiny.
            let max = g.size().max(2);
            let n = g.gen_range(1, max);
            let p: String = iter::repeat(())
                .map(|()| g.sample(Alphanumeric))
                .take(n)
                .collect();
            Protocol(Bytes::from(format!("/{}", p)))
        }
    }

    impl Arbitrary for Message {
        /// Picks one of the five `Message` variants uniformly at random.
        fn arbitrary<G: Gen>(g: &mut G) -> Message {
            match g.gen_range(0, 5) {
                0 => Message::Header(HeaderLine::V1),
                1 => Message::NotAvailable,
                2 => Message::ListProtocols,
                3 => Message::Protocol(Protocol::arbitrary(g)),
                4 => Message::Protocols(Vec::arbitrary(g)),
                _ => panic!()
            }
        }
    }

    /// Every message must survive an encode/decode round trip unchanged.
    #[test]
    fn encode_decode_message() {
        fn prop(msg: Message) {
            let mut buf = BytesMut::new();
            // Format the failure message lazily, only when encoding actually fails.
            msg.encode(&mut buf)
                .unwrap_or_else(|e| panic!("Encoding message failed: {:?}: {:?}", msg, e));
            match Message::decode(buf.freeze()) {
                Ok(m) => assert_eq!(m, msg),
                Err(e) => panic!("Decoding failed: {:?}", e)
            }
        }
        quickcheck(prop as fn(_))
    }
}