multistream_select/
length_delimited.rs

1// Copyright 2017 Parity Technologies (UK) Ltd.
2//
3// Permission is hereby granted, free of charge, to any person obtaining a
4// copy of this software and associated documentation files (the "Software"),
5// to deal in the Software without restriction, including without limitation
6// the rights to use, copy, modify, merge, publish, distribute, sublicense,
7// and/or sell copies of the Software, and to permit persons to whom the
8// Software is furnished to do so, subject to the following conditions:
9//
10// The above copyright notice and this permission notice shall be included in
11// all copies or substantial portions of the Software.
12//
13// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
14// OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
18// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
19// DEALINGS IN THE SOFTWARE.
20
21use bytes::{Bytes, BytesMut, Buf as _, BufMut as _};
22use futures::{prelude::*, io::IoSlice};
23use std::{convert::TryFrom as _, io, pin::Pin, task::{Poll, Context}, u16};
24
/// Maximum number of bytes an unsigned-varint length prefix may occupy.
const MAX_LEN_BYTES: u16 = 2;
/// Largest frame payload length that fits into `MAX_LEN_BYTES` prefix bytes.
///
/// Each prefix byte contributes 7 bits to the length (the MSB is the
/// continuation flag), so the limit is `2^(7 * MAX_LEN_BYTES) - 1`.
const MAX_FRAME_SIZE: u16 = (1 << (MAX_LEN_BYTES * 7)) - 1;
/// Initial capacity, in bytes, of the read and write buffers.
const DEFAULT_BUFFER_SIZE: usize = 64;
28
/// A `Stream` and `Sink` for unsigned-varint length-delimited frames,
/// wrapping an underlying `AsyncRead + AsyncWrite` I/O resource.
///
/// We purposely only support frame sizes up to 16KiB (2 bytes unsigned varint
/// frame length). Frames mostly consist of a short protocol name, which is highly
/// unlikely to be more than 16KiB long.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct LengthDelimited<R> {
    /// The inner I/O resource.
    #[pin]
    inner: R,
    /// Read buffer for a single incoming unsigned-varint length-delimited frame.
    ///
    /// It is sized to the frame length once the length prefix has been decoded
    /// and is emptied again when the completed frame is handed out.
    read_buffer: BytesMut,
    /// Write buffer for outgoing unsigned-varint length-delimited frames.
    ///
    /// Frames are appended here (length prefix followed by payload) and are
    /// only written to `inner` when the buffer is drained.
    write_buffer: BytesMut,
    /// The current read state, alternating between reading a frame
    /// length and reading a frame payload.
    read_state: ReadState,
}
49
/// The state of the frame decoder that drives `Stream::poll_next`.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
enum ReadState {
    /// We are currently reading the length of the next frame of data.
    ///
    /// `buf` holds the length-prefix bytes and `pos` counts how many of
    /// them have been read so far.
    ReadLength { buf: [u8; MAX_LEN_BYTES as usize], pos: usize },
    /// We are currently reading the frame of data itself.
    ///
    /// `len` is the decoded frame length and `pos` counts how many payload
    /// bytes have been read into the read buffer so far.
    ReadData { len: u16, pos: usize },
}
57
58impl Default for ReadState {
59    fn default() -> Self {
60        ReadState::ReadLength {
61            buf: [0; MAX_LEN_BYTES as usize],
62            pos: 0
63        }
64    }
65}
66
67impl<R> LengthDelimited<R> {
68    /// Creates a new I/O resource for reading and writing unsigned-varint
69    /// length delimited frames.
70    pub fn new(inner: R) -> LengthDelimited<R> {
71        LengthDelimited {
72            inner,
73            read_state: ReadState::default(),
74            read_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE),
75            write_buffer: BytesMut::with_capacity(DEFAULT_BUFFER_SIZE + MAX_LEN_BYTES as usize),
76        }
77    }
78
79    /// Drops the [`LengthDelimited`] resource, yielding the underlying I/O stream.
80    ///
81    /// # Panic
82    ///
83    /// Will panic if called while there is data in the read or write buffer.
84    /// The read buffer is guaranteed to be empty whenever `Stream::poll` yields
85    /// a new `Bytes` frame. The write buffer is guaranteed to be empty after
86    /// flushing.
87    pub fn into_inner(self) -> R {
88        assert!(self.read_buffer.is_empty());
89        assert!(self.write_buffer.is_empty());
90        self.inner
91    }
92
93    /// Converts the [`LengthDelimited`] into a [`LengthDelimitedReader`], dropping the
94    /// uvi-framed `Sink` in favour of direct `AsyncWrite` access to the underlying
95    /// I/O stream.
96    ///
97    /// This is typically done if further uvi-framed messages are expected to be
98    /// received but no more such messages are written, allowing the writing of
99    /// follow-up protocol data to commence.
100    pub fn into_reader(self) -> LengthDelimitedReader<R> {
101        LengthDelimitedReader { inner: self }
102    }
103
104    /// Writes all buffered frame data to the underlying I/O stream,
105    /// _without flushing it_.
106    ///
107    /// After this method returns `Poll::Ready`, the write buffer of frames
108    /// submitted to the `Sink` is guaranteed to be empty.
109    pub fn poll_write_buffer(self: Pin<&mut Self>, cx: &mut Context<'_>)
110        -> Poll<Result<(), io::Error>>
111    where
112        R: AsyncWrite
113    {
114        let mut this = self.project();
115
116        while !this.write_buffer.is_empty() {
117            match this.inner.as_mut().poll_write(cx, &this.write_buffer) {
118                Poll::Pending => return Poll::Pending,
119                Poll::Ready(Ok(0)) => {
120                    return Poll::Ready(Err(io::Error::new(
121                        io::ErrorKind::WriteZero,
122                        "Failed to write buffered frame.")))
123                }
124                Poll::Ready(Ok(n)) => this.write_buffer.advance(n),
125                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
126            }
127        }
128
129        Poll::Ready(Ok(()))
130    }
131}
132
133impl<R> Stream for LengthDelimited<R>
134where
135    R: AsyncRead
136{
137    type Item = Result<Bytes, io::Error>;
138
139    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
140        let mut this = self.project();
141
142        loop {
143            match this.read_state {
144                ReadState::ReadLength { buf, pos } => {
145                    match this.inner.as_mut().poll_read(cx, &mut buf[*pos .. *pos + 1]) {
146                        Poll::Ready(Ok(0)) => {
147                            if *pos == 0 {
148                                return Poll::Ready(None);
149                            } else {
150                                return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into())));
151                            }
152                        }
153                        Poll::Ready(Ok(n)) => {
154                            debug_assert_eq!(n, 1);
155                            *pos += n;
156                        }
157                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
158                        Poll::Pending => return Poll::Pending,
159                    };
160
161                    if (buf[*pos - 1] & 0x80) == 0 {
162                        // MSB is not set, indicating the end of the length prefix.
163                        let (len, _) = unsigned_varint::decode::u16(buf)
164                            .map_err(|e| {
165                                log::debug!("invalid length prefix: {}", e);
166                                io::Error::new(io::ErrorKind::InvalidData, "invalid length prefix")
167                            })?;
168
169                        if len >= 1 {
170                            *this.read_state = ReadState::ReadData { len, pos: 0 };
171                            this.read_buffer.resize(len as usize, 0);
172                        } else {
173                            debug_assert_eq!(len, 0);
174                            *this.read_state = ReadState::default();
175                            return Poll::Ready(Some(Ok(Bytes::new())));
176                        }
177                    } else if *pos == MAX_LEN_BYTES as usize {
178                        // MSB signals more length bytes but we have already read the maximum.
179                        // See the module documentation about the max frame len.
180                        return Poll::Ready(Some(Err(io::Error::new(
181                            io::ErrorKind::InvalidData,
182                            "Maximum frame length exceeded"))));
183                    }
184                }
185                ReadState::ReadData { len, pos } => {
186                    match this.inner.as_mut().poll_read(cx, &mut this.read_buffer[*pos..]) {
187                        Poll::Ready(Ok(0)) => return Poll::Ready(Some(Err(io::ErrorKind::UnexpectedEof.into()))),
188                        Poll::Ready(Ok(n)) => *pos += n,
189                        Poll::Pending => return Poll::Pending,
190                        Poll::Ready(Err(err)) => return Poll::Ready(Some(Err(err))),
191                    };
192
193                    if *pos == *len as usize {
194                        // Finished reading the frame.
195                        let frame = this.read_buffer.split_off(0).freeze();
196                        *this.read_state = ReadState::default();
197                        return Poll::Ready(Some(Ok(frame)));
198                    }
199                }
200            }
201        }
202    }
203}
204
205impl<R> Sink<Bytes> for LengthDelimited<R>
206where
207    R: AsyncWrite,
208{
209    type Error = io::Error;
210
211    fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
212        // Use the maximum frame length also as a (soft) upper limit
213        // for the entire write buffer. The actual (hard) limit is thus
214        // implied to be roughly 2 * MAX_FRAME_SIZE.
215        if self.as_mut().project().write_buffer.len() >= MAX_FRAME_SIZE as usize {
216            match self.as_mut().poll_write_buffer(cx) {
217                Poll::Ready(Ok(())) => {},
218                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
219                Poll::Pending => return Poll::Pending,
220            }
221
222            debug_assert!(self.as_mut().project().write_buffer.is_empty());
223        }
224
225        Poll::Ready(Ok(()))
226    }
227
228    fn start_send(self: Pin<&mut Self>, item: Bytes) -> Result<(), Self::Error> {
229        let this = self.project();
230
231        let len = match u16::try_from(item.len()) {
232            Ok(len) if len <= MAX_FRAME_SIZE => len,
233            _ => {
234                return Err(io::Error::new(
235                    io::ErrorKind::InvalidData,
236                    "Maximum frame size exceeded."))
237            }
238        };
239
240        let mut uvi_buf = unsigned_varint::encode::u16_buffer();
241        let uvi_len = unsigned_varint::encode::u16(len, &mut uvi_buf);
242        this.write_buffer.reserve(len as usize + uvi_len.len());
243        this.write_buffer.put(uvi_len);
244        this.write_buffer.put(item);
245
246        Ok(())
247    }
248
249    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
250        // Write all buffered frame data to the underlying I/O stream.
251        match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
252            Poll::Ready(Ok(())) => {},
253            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
254            Poll::Pending => return Poll::Pending,
255        }
256
257        let this = self.project();
258        debug_assert!(this.write_buffer.is_empty());
259
260        // Flush the underlying I/O stream.
261        this.inner.poll_flush(cx)
262    }
263
264    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
265        // Write all buffered frame data to the underlying I/O stream.
266        match LengthDelimited::poll_write_buffer(self.as_mut(), cx) {
267            Poll::Ready(Ok(())) => {},
268            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
269            Poll::Pending => return Poll::Pending,
270        }
271
272        let this = self.project();
273        debug_assert!(this.write_buffer.is_empty());
274
275        // Close the underlying I/O stream.
276        this.inner.poll_close(cx)
277    }
278}
279
/// A `LengthDelimitedReader` implements a `Stream` of uvi-length-delimited
/// frames on an underlying I/O resource combined with direct `AsyncWrite` access.
#[pin_project::pin_project]
#[derive(Debug)]
pub struct LengthDelimitedReader<R> {
    /// The wrapped framed resource; frames are read through it while writes
    /// bypass its framing.
    #[pin]
    inner: LengthDelimited<R>
}
288
289impl<R> LengthDelimitedReader<R> {
290    /// Destroys the `LengthDelimitedReader` and returns the underlying I/O stream.
291    ///
292    /// This method is guaranteed not to drop any data read from or not yet
293    /// submitted to the underlying I/O stream.
294    ///
295    /// # Panic
296    ///
297    /// Will panic if called while there is data in the read or write buffer.
298    /// The read buffer is guaranteed to be empty whenever [`Stream::poll_next`]
299    /// yield a new `Message`. The write buffer is guaranteed to be empty whenever
300    /// [`LengthDelimited::poll_write_buffer`] yields [`Poll::Ready`] or after
301    /// the [`Sink`] has been completely flushed via [`Sink::poll_flush`].
302    pub fn into_inner(self) -> R {
303        self.inner.into_inner()
304    }
305}
306
307impl<R> Stream for LengthDelimitedReader<R>
308where
309    R: AsyncRead
310{
311    type Item = Result<Bytes, io::Error>;
312
313    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
314        self.project().inner.poll_next(cx)
315    }
316}
317
318impl<R> AsyncWrite for LengthDelimitedReader<R>
319where
320    R: AsyncWrite
321{
322    fn poll_write(self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8])
323        -> Poll<Result<usize, io::Error>>
324    {
325        // `this` here designates the `LengthDelimited`.
326        let mut this = self.project().inner;
327
328        // We need to flush any data previously written with the `LengthDelimited`.
329        match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
330            Poll::Ready(Ok(())) => {},
331            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
332            Poll::Pending => return Poll::Pending,
333        }
334        debug_assert!(this.write_buffer.is_empty());
335
336        this.project().inner.poll_write(cx, buf)
337    }
338
339    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
340        self.project().inner.poll_flush(cx)
341    }
342
343    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
344        self.project().inner.poll_close(cx)
345    }
346
347    fn poll_write_vectored(self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>])
348        -> Poll<Result<usize, io::Error>>
349    {
350        // `this` here designates the `LengthDelimited`.
351        let mut this = self.project().inner;
352
353        // We need to flush any data previously written with the `LengthDelimited`.
354        match LengthDelimited::poll_write_buffer(this.as_mut(), cx) {
355            Poll::Ready(Ok(())) => {},
356            Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
357            Poll::Pending => return Poll::Pending,
358        }
359        debug_assert!(this.write_buffer.is_empty());
360
361        this.project().inner.poll_write_vectored(cx, bufs)
362    }
363}
364
#[cfg(test)]
mod tests {
    use crate::length_delimited::LengthDelimited;
    use async_std::net::{TcpListener, TcpStream};
    use futures::{prelude::*, io::Cursor};
    use quickcheck::*;
    use std::io::ErrorKind;

    /// Asserts that `result` is an I/O error of kind `expected`; panics if it
    /// is `Ok`.
    fn assert_error_kind<T>(result: Result<T, std::io::Error>, expected: ErrorKind) {
        match result {
            Err(io_err) => assert_eq!(io_err.kind(), expected),
            Ok(_) => panic!(),
        }
    }

    /// A single frame with a one-byte length prefix is decoded.
    #[test]
    fn basic_read() {
        let input = vec![6, 9, 8, 7, 6, 5, 4];
        let stream = LengthDelimited::new(Cursor::new(input));
        let frames = futures::executor::block_on(stream.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(frames, vec![vec![9, 8, 7, 6, 5, 4]]);
    }

    /// Two back-to-back frames are decoded separately.
    #[test]
    fn basic_read_two() {
        let input = vec![6, 9, 8, 7, 6, 5, 4, 3, 9, 8, 7];
        let stream = LengthDelimited::new(Cursor::new(input));
        let frames = futures::executor::block_on(stream.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(frames, vec![vec![9, 8, 7, 6, 5, 4], vec![9, 8, 7]]);
    }

    /// A frame whose length needs a two-byte varint prefix is decoded.
    #[test]
    fn two_bytes_long_packet() {
        let len = 5000u16;
        assert!(len < (1 << 15));
        let frame = (0..len).map(|n| (n & 0xff) as u8).collect::<Vec<_>>();
        // Hand-encoded two-byte varint: low 7 bits with the continuation flag,
        // followed by the remaining high bits.
        let mut input = vec![(len & 0x7f) as u8 | 0x80, (len >> 7) as u8];
        input.extend_from_slice(&frame);
        let mut stream = LengthDelimited::new(Cursor::new(input));
        let first = futures::executor::block_on(async move {
            stream.next().await
        }).unwrap();
        assert_eq!(first.unwrap(), frame);
    }

    /// A three-byte length prefix is rejected as invalid data.
    #[test]
    fn packet_len_too_long() {
        let mut input = vec![0x81, 0x81, 0x1];
        input.extend(std::iter::repeat(0).take(16513));
        let mut stream = LengthDelimited::new(Cursor::new(input));
        let first = futures::executor::block_on(async move {
            stream.next().await.unwrap()
        });

        assert_error_kind(first, ErrorKind::InvalidData);
    }

    /// Zero-length frames are yielded as empty items between regular frames.
    #[test]
    fn empty_frames() {
        let input = vec![0, 0, 6, 9, 8, 7, 6, 5, 4, 0, 3, 9, 8, 7];
        let stream = LengthDelimited::new(Cursor::new(input));
        let frames = futures::executor::block_on(stream.try_collect::<Vec<_>>()).unwrap();
        assert_eq!(
            frames,
            vec![
                vec![],
                vec![],
                vec![9, 8, 7, 6, 5, 4],
                vec![],
                vec![9, 8, 7],
            ]
        );
    }

    /// EOF after a length byte with the continuation flag set is an error.
    #[test]
    fn unexpected_eof_in_len() {
        let input = vec![0x89];
        let stream = LengthDelimited::new(Cursor::new(input));
        let outcome = futures::executor::block_on(stream.try_collect::<Vec<_>>());
        assert_error_kind(outcome, ErrorKind::UnexpectedEof);
    }

    /// EOF directly after a complete length prefix is an error.
    #[test]
    fn unexpected_eof_in_data() {
        let input = vec![5];
        let stream = LengthDelimited::new(Cursor::new(input));
        let outcome = futures::executor::block_on(stream.try_collect::<Vec<_>>());
        assert_error_kind(outcome, ErrorKind::UnexpectedEof);
    }

    /// EOF in the middle of a frame payload is an error.
    #[test]
    fn unexpected_eof_in_data2() {
        let input = vec![5, 9, 8, 7];
        let stream = LengthDelimited::new(Cursor::new(input));
        let outcome = futures::executor::block_on(stream.try_collect::<Vec<_>>());
        assert_error_kind(outcome, ErrorKind::UnexpectedEof);
    }

    /// Frames sent through the `Sink` over TCP arrive unchanged on the
    /// reading side.
    #[test]
    fn writing_reading() {
        fn prop(frames: Vec<Vec<u8>>) -> TestResult {
            async_std::task::block_on(async move {
                let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
                let listener_addr = listener.local_addr().unwrap();

                let expected_frames = frames.clone();
                let server = async_std::task::spawn(async move {
                    let (socket, _) = listener.accept().await.unwrap();
                    let mut reader = rw_stream_sink::RwStreamSink::new(LengthDelimited::new(socket));

                    let mut scratch = vec![0u8; 0];
                    // Empty frames are skipped on the reading side.
                    for expected in expected_frames.into_iter().filter(|f| !f.is_empty()) {
                        if scratch.len() < expected.len() {
                            scratch.resize(expected.len(), 0);
                        }
                        let n = reader.read(&mut scratch).await.unwrap();
                        assert_eq!(&scratch[..n], &expected[..]);
                    }
                });

                let client = async_std::task::spawn(async move {
                    let socket = TcpStream::connect(&listener_addr).await.unwrap();
                    let mut writer = LengthDelimited::new(socket);
                    for frame in frames {
                        writer.send(From::from(frame)).await.unwrap();
                    }
                });

                server.await;
                client.await;
            });

            TestResult::passed()
        }

        quickcheck(prop as fn(_) -> _)
    }
}