1use crate::Version;
29use crate::length_delimited::{LengthDelimited, LengthDelimitedReader};
30
31use bytes::{Bytes, BytesMut, BufMut};
32use futures::{prelude::*, io::IoSlice, ready};
33use std::{convert::TryFrom, io, fmt, error::Error, pin::Pin, task::{Context, Poll}};
34use unsigned_varint as uvi;
35
/// Upper bound on the number of protocol names `Message::decode` collects from
/// a protocol-list message before giving up with
/// `ProtocolError::TooManyProtocols`.
const MAX_PROTOCOLS: usize = 1_000;

/// Wire form of the multistream-select 1.0.0 header line (`Message::Header`).
const MSG_MULTISTREAM_1_0: &[u8] = b"/multistream/1.0.0\n";
/// Wire form of `Message::NotAvailable`.
const MSG_PROTOCOL_NA: &[u8] = b"na\n";
/// Wire form of `Message::ListProtocols`.
const MSG_LS: &[u8] = b"ls\n";
45
/// Header line of the multistream-select protocol, as carried by
/// `Message::Header`.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum HeaderLine {
    /// The `/multistream/1.0.0` header line.
    V1,
}
54
55impl From<Version> for HeaderLine {
56 fn from(v: Version) -> HeaderLine {
57 match v {
58 Version::V1 | Version::V1Lazy => HeaderLine::V1,
59 }
60 }
61}
62
63#[derive(Clone, Debug, PartialEq, Eq)]
65pub struct Protocol(Bytes);
66
67impl AsRef<[u8]> for Protocol {
68 fn as_ref(&self) -> &[u8] {
69 self.0.as_ref()
70 }
71}
72
73impl TryFrom<Bytes> for Protocol {
74 type Error = ProtocolError;
75
76 fn try_from(value: Bytes) -> Result<Self, Self::Error> {
77 if !value.as_ref().starts_with(b"/") {
78 return Err(ProtocolError::InvalidProtocol)
79 }
80 Ok(Protocol(value))
81 }
82}
83
84impl TryFrom<&[u8]> for Protocol {
85 type Error = ProtocolError;
86
87 fn try_from(value: &[u8]) -> Result<Self, Self::Error> {
88 Self::try_from(Bytes::copy_from_slice(value))
89 }
90}
91
92impl fmt::Display for Protocol {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 write!(f, "{}", String::from_utf8_lossy(&self.0))
95 }
96}
97
98#[derive(Debug, Clone, PartialEq, Eq)]
103pub enum Message {
104 Header(HeaderLine),
107 Protocol(Protocol),
109 ListProtocols,
112 Protocols(Vec<Protocol>),
114 NotAvailable,
116}
117
118impl Message {
119 pub fn encode(&self, dest: &mut BytesMut) -> Result<(), ProtocolError> {
121 match self {
122 Message::Header(HeaderLine::V1) => {
123 dest.reserve(MSG_MULTISTREAM_1_0.len());
124 dest.put(MSG_MULTISTREAM_1_0);
125 Ok(())
126 }
127 Message::Protocol(p) => {
128 let len = p.0.as_ref().len() + 1; dest.reserve(len);
130 dest.put(p.0.as_ref());
131 dest.put_u8(b'\n');
132 Ok(())
133 }
134 Message::ListProtocols => {
135 dest.reserve(MSG_LS.len());
136 dest.put(MSG_LS);
137 Ok(())
138 }
139 Message::Protocols(ps) => {
140 let mut buf = uvi::encode::usize_buffer();
141 let mut encoded = Vec::with_capacity(ps.len());
142 for p in ps {
143 encoded.extend(uvi::encode::usize(p.0.as_ref().len() + 1, &mut buf)); encoded.extend_from_slice(p.0.as_ref());
145 encoded.push(b'\n')
146 }
147 encoded.push(b'\n');
148 dest.reserve(encoded.len());
149 dest.put(encoded.as_ref());
150 Ok(())
151 }
152 Message::NotAvailable => {
153 dest.reserve(MSG_PROTOCOL_NA.len());
154 dest.put(MSG_PROTOCOL_NA);
155 Ok(())
156 }
157 }
158 }
159
160 pub fn decode(mut msg: Bytes) -> Result<Message, ProtocolError> {
162 if msg == MSG_MULTISTREAM_1_0 {
163 return Ok(Message::Header(HeaderLine::V1))
164 }
165
166 if msg == MSG_PROTOCOL_NA {
167 return Ok(Message::NotAvailable);
168 }
169
170 if msg == MSG_LS {
171 return Ok(Message::ListProtocols)
172 }
173
174 if msg.get(0) == Some(&b'/') && msg.last() == Some(&b'\n') &&
177 !msg[.. msg.len() - 1].contains(&b'\n')
178 {
179 let p = Protocol::try_from(msg.split_to(msg.len() - 1))?;
180 return Ok(Message::Protocol(p));
181 }
182
183 let mut protocols = Vec::new();
186 let mut remaining: &[u8] = &msg;
187 loop {
188 if remaining == [b'\n'] {
190 break
191 } else if protocols.len() == MAX_PROTOCOLS {
192 return Err(ProtocolError::TooManyProtocols)
193 }
194
195 let (len, tail) = uvi::decode::usize(remaining)?;
198 if len == 0 || len > tail.len() || tail[len - 1] != b'\n' {
199 return Err(ProtocolError::InvalidMessage)
200 }
201
202 let p = Protocol::try_from(Bytes::copy_from_slice(&tail[.. len - 1]))?;
204 protocols.push(p);
205
206 remaining = &tail[len ..];
208 }
209
210 Ok(Message::Protocols(protocols))
211 }
212}
213
214#[pin_project::pin_project]
216pub struct MessageIO<R> {
217 #[pin]
218 inner: LengthDelimited<R>,
219}
220
221impl<R> MessageIO<R> {
222 pub fn new(inner: R) -> MessageIO<R>
224 where
225 R: AsyncRead + AsyncWrite
226 {
227 Self { inner: LengthDelimited::new(inner) }
228 }
229
230 pub fn into_reader(self) -> MessageReader<R> {
238 MessageReader { inner: self.inner.into_reader() }
239 }
240
241 pub fn into_inner(self) -> R {
251 self.inner.into_inner()
252 }
253}
254
255impl<R> Sink<Message> for MessageIO<R>
256where
257 R: AsyncWrite,
258{
259 type Error = ProtocolError;
260
261 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
262 self.project().inner.poll_ready(cx).map_err(From::from)
263 }
264
265 fn start_send(self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
266 let mut buf = BytesMut::new();
267 item.encode(&mut buf)?;
268 self.project().inner.start_send(buf.freeze()).map_err(From::from)
269 }
270
271 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
272 self.project().inner.poll_flush(cx).map_err(From::from)
273 }
274
275 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
276 self.project().inner.poll_close(cx).map_err(From::from)
277 }
278}
279
280impl<R> Stream for MessageIO<R>
281where
282 R: AsyncRead
283{
284 type Item = Result<Message, ProtocolError>;
285
286 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
287 match poll_stream(self.project().inner, cx) {
288 Poll::Pending => Poll::Pending,
289 Poll::Ready(None) => Poll::Ready(None),
290 Poll::Ready(Some(Ok(m))) => Poll::Ready(Some(Ok(m))),
291 Poll::Ready(Some(Err(err))) => Poll::Ready(Some(Err(err))),
292 }
293 }
294}
295
296#[pin_project::pin_project]
299#[derive(Debug)]
300pub struct MessageReader<R> {
301 #[pin]
302 inner: LengthDelimitedReader<R>
303}
304
305impl<R> MessageReader<R> {
306 pub fn into_inner(self) -> R {
318 self.inner.into_inner()
319 }
320}
321
322impl<R> Stream for MessageReader<R>
323where
324 R: AsyncRead
325{
326 type Item = Result<Message, ProtocolError>;
327
328 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
329 poll_stream(self.project().inner, cx)
330 }
331}
332
333impl<TInner> AsyncWrite for MessageReader<TInner>
334where
335 TInner: AsyncWrite
336{
337 fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8]) -> Poll<Result<usize, io::Error>> {
338 self.project().inner.poll_write(cx, buf)
339 }
340
341 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
342 self.project().inner.poll_flush(cx)
343 }
344
345 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
346 self.project().inner.poll_close(cx)
347 }
348
349 fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>]) -> Poll<Result<usize, io::Error>> {
350 self.project().inner.poll_write_vectored(cx, bufs)
351 }
352}
353
354fn poll_stream<S>(stream: Pin<&mut S>, cx: &mut Context<'_>) -> Poll<Option<Result<Message, ProtocolError>>>
355where
356 S: Stream<Item = Result<Bytes, io::Error>>,
357{
358 let msg = if let Some(msg) = ready!(stream.poll_next(cx)?) {
359 match Message::decode(msg) {
360 Ok(m) => m,
361 Err(err) => return Poll::Ready(Some(Err(err))),
362 }
363 } else {
364 return Poll::Ready(None)
365 };
366
367 log::trace!("Received message: {:?}", msg);
368
369 Poll::Ready(Some(Ok(msg)))
370}
371
/// An error raised while encoding, decoding or transporting a `Message`.
#[derive(Debug)]
pub enum ProtocolError {
    /// An I/O error.
    IoError(io::Error),

    /// A received message is invalid.
    InvalidMessage,

    /// A protocol (name) is invalid.
    InvalidProtocol,

    /// Too many protocols were received.
    TooManyProtocols,
}

impl From<io::Error> for ProtocolError {
    /// Wraps the I/O error in `ProtocolError::IoError`.
    fn from(err: io::Error) -> ProtocolError {
        ProtocolError::IoError(err)
    }
}

// Implemented as `From` rather than `Into`: the blanket impl in `std` then
// provides `Into<io::Error> for ProtocolError` for free, so existing
// `.into()` call sites keep working while `io::Error::from(..)` and `?`
// conversions become available as well.
impl From<ProtocolError> for io::Error {
    /// Unwraps an inner I/O error unchanged; every other variant becomes an
    /// error of kind `io::ErrorKind::InvalidData`.
    fn from(err: ProtocolError) -> Self {
        match err {
            ProtocolError::IoError(e) => e,
            ProtocolError::InvalidMessage
            | ProtocolError::InvalidProtocol
            | ProtocolError::TooManyProtocols => io::ErrorKind::InvalidData.into(),
        }
    }
}
402
403impl From<uvi::decode::Error> for ProtocolError {
404 fn from(err: uvi::decode::Error) -> ProtocolError {
405 Self::from(io::Error::new(io::ErrorKind::InvalidData, err.to_string()))
406 }
407}
408
409impl Error for ProtocolError {
410 fn source(&self) -> Option<&(dyn Error + 'static)> {
411 match *self {
412 ProtocolError::IoError(ref err) => Some(err),
413 _ => None,
414 }
415 }
416}
417
418impl fmt::Display for ProtocolError {
419 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
420 match self {
421 ProtocolError::IoError(e) =>
422 write!(fmt, "I/O error: {}", e),
423 ProtocolError::InvalidMessage =>
424 write!(fmt, "Received an invalid message."),
425 ProtocolError::InvalidProtocol =>
426 write!(fmt, "A protocol (name) is invalid."),
427 ProtocolError::TooManyProtocols =>
428 write!(fmt, "Too many protocols received.")
429 }
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;
    use quickcheck::*;
    use rand::Rng;
    use rand::distributions::Alphanumeric;
    use std::iter;

    impl Arbitrary for Protocol {
        /// Generates `/` followed by at least one random alphanumeric
        /// character.
        fn arbitrary<G: Gen>(g: &mut G) -> Protocol {
            // Keep the upper bound above the lower one so the sampled range
            // is never empty, even for a generator of size 0 or 1.
            let upper = g.size().max(2);
            let n = g.gen_range(1, upper);
            let p: String = iter::repeat(())
                .map(|()| g.sample(Alphanumeric))
                .take(n)
                .collect();
            Protocol(Bytes::from(format!("/{}", p)))
        }
    }

    impl Arbitrary for Message {
        /// Picks one of the five message variants at random.
        fn arbitrary<G: Gen>(g: &mut G) -> Message {
            match g.gen_range(0, 5) {
                0 => Message::Header(HeaderLine::V1),
                1 => Message::NotAvailable,
                2 => Message::ListProtocols,
                3 => Message::Protocol(Protocol::arbitrary(g)),
                4 => Message::Protocols(Vec::arbitrary(g)),
                _ => unreachable!("gen_range(0, 5) only yields 0 through 4"),
            }
        }
    }

    /// Every message must survive an encode/decode round trip unchanged.
    #[test]
    fn encode_decode_message() {
        fn prop(msg: Message) {
            let mut buf = BytesMut::new();
            // Build the failure text only if encoding actually fails.
            msg.encode(&mut buf)
                .unwrap_or_else(|e| panic!("Encoding message failed: {:?}: {:?}", msg, e));
            match Message::decode(buf.freeze()) {
                Ok(m) => assert_eq!(m, msg),
                Err(e) => panic!("Decoding failed: {:?}", e),
            }
        }
        quickcheck(prop as fn(_))
    }
}